Text2SqlParser(vocab: allennlp.data.vocabulary.Vocabulary, utterance_embedder: allennlp.modules.text_field_embedders.text_field_embedder.TextFieldEmbedder, action_embedding_dim: int, encoder: allennlp.modules.seq2seq_encoders.seq2seq_encoder.Seq2SeqEncoder, decoder_beam_search: allennlp.state_machines.beam_search.BeamSearch, max_decoding_steps: int, input_attention: allennlp.modules.attention.attention.Attention, add_action_bias: bool = True, dropout: float = 0.0, initializer: allennlp.nn.initializers.InitializerApplicator = InitializerApplicator(), regularizer: typing.Optional[allennlp.nn.regularizers.regularizer_applicator.RegularizerApplicator] = None) → None
- vocab : Vocabulary
- utterance_embedder : TextFieldEmbedder
Embedder for utterances.
- action_embedding_dim : int
Dimension to use for action embeddings.
- encoder : Seq2SeqEncoder
The encoder to use for the input utterance.
- decoder_beam_search : BeamSearch
Beam search used to retrieve best sequences after training.
- max_decoding_steps : int
The maximum number of steps to take when decoding with beam search. This only applies at evaluation time, not during training.
- input_attention : Attention
We compute an attention over the input utterance at each step of the decoder, using the decoder hidden state as the query. Passed to the transition function.
- add_action_bias : bool, optional (default=True)
If True, we will learn a bias weight for each action that is used when predicting that action, in addition to its embedding.
- dropout : float, optional (default=0)
If greater than 0, we apply dropout with this probability after all encoders (PyTorch LSTMs do not apply dropout to their last layer).
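A minimal sketch of one decoder step, to show how input_attention, action_embedding_dim and add_action_bias fit together. It assumes dot-product attention and scores each action by the dot product of the attended context vector with that action's embedding; the real transition function and Attention module are configurable and are not reproduced here. All function names and production rules below are hypothetical.

```python
import math
from typing import Dict, List, Optional


def dot(a: List[float], b: List[float]) -> float:
    # Plain dot product; both vectors must have the same dimension.
    return sum(x * y for x, y in zip(a, b))


def softmax(scores: List[float]) -> List[float]:
    # Subtract the max score for numerical stability.
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attend(query: List[float], encoder_outputs: List[List[float]]) -> List[float]:
    # Dot-product attention (an assumption): weight each encoded utterance token
    # by its similarity to the decoder hidden state, then take the weighted sum.
    weights = softmax([dot(query, h) for h in encoder_outputs])
    dim = len(encoder_outputs[0])
    return [sum(w * h[i] for w, h in zip(weights, encoder_outputs)) for i in range(dim)]


def score_actions(context: List[float],
                  action_embeddings: Dict[str, List[float]],
                  action_biases: Optional[Dict[str, float]] = None) -> Dict[str, float]:
    # Each action's score is the similarity of its embedding to the context vector.
    # Passing action_biases mimics add_action_bias=True: one learned scalar per action.
    scores = {}
    for action, embedding in action_embeddings.items():
        score = dot(context, embedding)
        if action_biases is not None:
            score += action_biases[action]
        scores[action] = score
    return scores


# Toy example: two encoded utterance tokens and two candidate actions (embedding dim 2).
encoder_outputs = [[1.0, 0.0], [0.0, 1.0]]
hidden_state = [2.0, 0.0]
context = attend(hidden_state, encoder_outputs)
action_embeddings = {"col -> ['*']": [1.0, 0.0], "col -> ['name']": [0.0, 1.0]}
unbiased = score_actions(context, action_embeddings)
biased = score_actions(context, action_embeddings,
                       {"col -> ['*']": 0.0, "col -> ['name']": 1.0})
```

In the toy example the per-action bias is large enough to flip which action scores highest, which is the kind of context-independent preference add_action_bias lets the model learn.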
decode(output_dict: typing.Dict[str, torch.Tensor]) → typing.Dict[str, torch.Tensor]
This method overrides Model.decode, which gets called after Model.forward, at test time, to finalize predictions. This is (confusingly) a separate notion from the “decoder” in “encoder/decoder”, where that decoder logic lives in
This method trims the output predictions to the first end symbol, replaces indices with corresponding tokens, and adds a field called
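The trimming and index-to-token step can be sketched in plain Python. This is an illustration under assumptions, not the model's code: predictions are taken to be a flat list of integer indices, the index-to-token mapping is a hypothetical dictionary, and index 0 stands in for the end symbol.

```python
from typing import Dict, List


def trim_and_detokenize(predicted_indices: List[int],
                        index_to_token: Dict[int, str],
                        end_index: int) -> List[str]:
    # Keep everything before the first end symbol; anything after it is discarded.
    if end_index in predicted_indices:
        predicted_indices = predicted_indices[:predicted_indices.index(end_index)]
    # Replace each remaining index with its token.
    return [index_to_token[i] for i in predicted_indices]


# Hypothetical index-to-token mapping; index 0 plays the role of the end symbol.
index_to_token = {0: "<end>", 1: "SELECT", 2: "name", 3: "FROM", 4: "city"}
tokens = trim_and_detokenize([1, 2, 3, 4, 0, 2, 2], index_to_token, end_index=0)
```

Here `tokens` is `["SELECT", "name", "FROM", "city"]`: the trailing indices after the end symbol are dropped before lookup.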
forward(tokens: typing.Dict[str, torch.LongTensor], valid_actions: typing.List[typing.List[allennlp.data.fields.production_rule_field.ProductionRule]], action_sequence: torch.LongTensor = None) → typing.Dict[str, torch.Tensor]
We set up the initial state for the decoder, and pass that state off to either a DecoderTrainer, if we’re training, or a BeamSearch for inference, if we’re not.
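The inference branch relies on beam search bounded by max_decoding_steps. The following is a generic sketch of that idea and not AllenNLP's BeamSearch API: `step` is a hypothetical function that returns (next_action, log_probability) pairs for a partial action sequence, and any hypothesis that has not produced the end action within max_decoding_steps is dropped.

```python
import heapq
import math
from typing import Callable, List, Tuple

# A hypothesis is (cumulative log-probability, action sequence so far).
Hypothesis = Tuple[float, List[str]]


def beam_search(step: Callable[[List[str]], List[Tuple[str, float]]],
                beam_size: int,
                max_decoding_steps: int,
                end_action: str) -> List[Hypothesis]:
    beam: List[Hypothesis] = [(0.0, [])]
    finished: List[Hypothesis] = []
    for _ in range(max_decoding_steps):
        # Expand every live hypothesis by every allowed next action.
        candidates = [(score + log_prob, seq + [action])
                      for score, seq in beam
                      for action, log_prob in step(seq)]
        # Keep only the beam_size best expansions.
        beam = []
        for score, seq in heapq.nlargest(beam_size, candidates, key=lambda h: h[0]):
            if seq[-1] == end_action:
                finished.append((score, seq))  # complete: set aside
            else:
                beam.append((score, seq))  # still decoding
        if not beam:
            break
    # Hypotheses still unfinished when the step budget runs out are discarded.
    return sorted(finished, key=lambda h: h[0], reverse=True)


def toy_step(seq: List[str]) -> List[Tuple[str, float]]:
    # Hypothetical next-action distribution for a tiny query grammar.
    if not seq:
        return [("SELECT", math.log(0.9)), ("<end>", math.log(0.1))]
    if seq[-1] == "SELECT":
        return [("*", math.log(0.6)), ("name", math.log(0.4))]
    return [("<end>", 0.0)]


results = beam_search(toy_step, beam_size=2, max_decoding_steps=5, end_action="<end>")
```

With max_decoding_steps=5 the best sequence is `["SELECT", "*", "<end>"]`. Lowering it to 2 leaves only the one-action hypothesis finished, because the longer sequences run out of time steps before reaching the end action.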
- tokens : Dict[str, torch.LongTensor]
The output of TextField.as_array() applied on the tokens TextField. This will be passed through a TextFieldEmbedder and then through an encoder.
- valid_actions : List[List[ProductionRule]]
A list of all possible actions for each World in the batch, indexed into a ProductionRuleField. We will embed all of these and use the embeddings to determine which action to take at each timestep in the decoder.
- action_sequence : torch.LongTensor, optional (default=None)
The correct action sequence, where each action is an index into the list of possible actions. This tensor has shape (batch_size, sequence_length, 1); we remove the trailing dimension.
- sql_queries : List[List[str]], optional (default=None)
A list of the SQL queries that are given during training or validation.
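To make the shapes above concrete, here is a plain-Python sketch of removing the trailing dimension of the action-sequence tensor and looking each index up in that instance's list of valid actions. For a real tensor the first step would be something like `action_sequence.squeeze(-1)`; the nested lists and production-rule strings here are hypothetical stand-ins.

```python
from typing import List

# Hypothetical production rules for one instance (the real ones come from the
# ProductionRuleField for that instance's World).
valid_actions: List[List[str]] = [
    ["statement -> [query]", "query -> [select, from]",
     "select -> ['SELECT', 'name']", "from -> ['FROM', 'city']"],
]

# Shape (batch_size=1, sequence_length=4, 1): one action index per step.
action_sequence = [[[0], [1], [2], [3]]]


def squeeze_last(batch: List[List[List[int]]]) -> List[List[int]]:
    # (batch_size, sequence_length, 1) -> (batch_size, sequence_length).
    return [[step[0] for step in sequence] for sequence in batch]


def indices_to_rules(batch_indices: List[List[int]],
                     batch_valid_actions: List[List[str]]) -> List[List[str]]:
    # Each index refers into the valid-action list of its own batch instance.
    return [[actions[i] for i in indices]
            for indices, actions in zip(batch_indices, batch_valid_actions)]


rules = indices_to_rules(squeeze_last(action_sequence), valid_actions)
```

Because indices are per instance, the same integer can denote different production rules in different batch elements; pairing each sequence with its own list avoids mixing them up.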
get_metrics(reset: bool = False) → typing.Dict[str, float]
We track four metrics here:
1. exact_match, which is the percentage of the time that our best output action sequence matches the SQL query exactly.
2. denotation_acc, which is the percentage of examples where we get the correct denotation. This is the typical “accuracy” metric, and it is what you should usually report in an experimental result. You need to be careful, though, that you’re computing this on the full data, and not just the subset that can be parsed: make sure you pass keep_if_unparseable=True to the dataset reader, which we do for validation data but not for training data.
3. valid_sql_query, which is the percentage of time that decoding actually produces a valid SQL query. We might fail to produce a valid SQL query if, for example, the decoder gets into a repetitive loop, or we’re trying to produce a very long SQL query and run out of time steps.
4. action_similarity, which is how similar the predicted action sequence is to the actual action sequence. This is a soft version of exact_match.
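A rough, self-contained sketch of how these four quantities could be computed for a single example. Assumptions: whitespace-split SQL tokens stand in for action sequences, denotations come from running both queries against an in-memory SQLite database with row order ignored, a query counts as valid if SQLite can execute it, and action_similarity uses `difflib.SequenceMatcher.ratio()`. The helper names are hypothetical and not part of AllenNLP.

```python
import difflib
import sqlite3
from typing import List, Optional


def exact_match(predicted: List[str], gold: List[str]) -> bool:
    # Credit only for an identical sequence.
    return predicted == gold


def action_similarity(predicted: List[str], gold: List[str]) -> float:
    # Soft exact match: 1.0 for identical sequences, lower as they diverge.
    return difflib.SequenceMatcher(None, predicted, gold).ratio()


def execute(conn: sqlite3.Connection, query: str) -> Optional[list]:
    # Return the result rows (sorted, so row order is ignored), or None if the
    # query fails; here "fails to execute" stands in for "not a valid SQL query".
    try:
        return sorted(conn.execute(query).fetchall())
    except sqlite3.Error:
        return None


def denotation_correct(conn: sqlite3.Connection, predicted_sql: str, gold_sql: str) -> bool:
    # Correct if the predicted query runs and returns the same rows as the gold query.
    predicted_rows = execute(conn, predicted_sql)
    return predicted_rows is not None and predicted_rows == execute(conn, gold_sql)


# Toy database with one table.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE city (name TEXT, population INTEGER)")
conn.executemany("INSERT INTO city VALUES (?, ?)", [("Austin", 950000), ("Dallas", 1300000)])

gold_sql = "SELECT name FROM city WHERE population > 1000000"
predicted_sql = "SELECT name FROM city WHERE population >= 1000001"

metrics = {
    "exact_match": exact_match(predicted_sql.split(), gold_sql.split()),
    "denotation_acc": denotation_correct(conn, predicted_sql, gold_sql),
    "valid_sql_query": execute(conn, predicted_sql) is not None,
    "action_similarity": action_similarity(predicted_sql.split(), gold_sql.split()),
}
```

In this example exact_match is False while denotation_acc is True: the two queries differ token by token but return the same rows. action_similarity comes out at 0.75, since six of the eight tokens in each sequence line up.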
- vocab :