allennlp.state_machines.states

This module contains the State abstraction for defining state-machine-based decoders, and some pre-built concrete State classes for various kinds of decoding (e.g., a GrammarBasedState for doing grammar-based decoding, where the output is a sequence of production rules from a grammar).

The module also has some Statelet classes that help represent the State by grouping together related pieces: RnnStatelet, which you can use to keep track of a decoder RNN’s internal state; GrammarStatelet, which keeps track of what actions are allowed at each timestep of decoding (if your outputs are production rules from a grammar); and ChecklistStatelet, which keeps track of coverage information if you are training a coverage-based parser.

class allennlp.state_machines.states.state.State(batch_indices: typing.List[int], action_history: typing.List[typing.List[int]], score: typing.List[torch.Tensor]) → None

Bases: typing.Generic

Represents the (batched) state of a transition-based decoder.

There are two different kinds of batching we need to distinguish here. First, there’s the batch of training instances passed to model.forward(). We’ll use “batch” and batch_size to refer to this throughout the docs and code. We additionally batch together computation for several states at the same time, where each state could be from the same training instance in the original batch, or from different instances. We use “group” and group_size in the docs and code to refer to this kind of batching, to distinguish it from the batch of training instances.

So, using this terminology, a single State object represents a grouped collection of states. Because different states in this group might finish at different timesteps, we have methods and member variables to handle some bookkeeping around this, to split and regroup things.
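As a concrete illustration of the two kinds of batching, suppose batch_size is 2 and we keep two states per instance (say, a small beam), so group_size is 4. The plain-Python sketch below is hypothetical: the data and helper names are invented for illustration, and real States hold tensors rather than floats and lists. It shows how batch_indices links each group element back to its batch instance, for example to look up that instance’s gold action sequence.

```python
from typing import Dict, List

# Hypothetical data: one gold action sequence per batch instance (batch_size = 2).
gold_actions: List[List[int]] = [[3, 1, 4], [2, 7]]

# A group of 4 states; batch_indices has group_size entries, each naming the
# batch instance that group element came from.
batch_indices: List[int] = [0, 0, 1, 1]
action_history: List[List[int]] = [[3], [5], [2], [9]]


def group_sizes_per_instance(batch_indices: List[int]) -> Dict[int, int]:
    """Count how many group elements belong to each batch instance."""
    counts: Dict[int, int] = {}
    for batch_index in batch_indices:
        counts[batch_index] = counts.get(batch_index, 0) + 1
    return counts


def on_gold_prefix(group_index: int) -> bool:
    """Check whether a group element's history is a prefix of its instance's gold sequence."""
    gold = gold_actions[batch_indices[group_index]]
    history = action_history[group_index]
    return gold[: len(history)] == history
```

Here group_sizes_per_instance(batch_indices) gives {0: 2, 1: 2}, and only group elements 0 and 2 are still on their gold sequences.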

Parameters:
batch_indices : List[int]

A group_size-length list, where each element specifies which batch_index that group element came from.

Our internal variables (like scores, action histories, hidden states, whatever) are grouped, and our group_size is likely different from the original batch_size. This variable keeps track of which batch instance each group element came from (e.g., to know what the correct action sequences are, or which encoder outputs to use).

action_history : List[List[int]]

The list of actions taken so far in this state. This is also grouped, so each state in the group has a list of actions.

score : List[torch.Tensor]

This state’s score, one per group element. The scores are tensors rather than plain numbers because we typically compute a loss from them and backpropagate through them during training. Like the other variables here, this is a group_size-length list.

classmethod combine_states(states: typing.List[T]) → T

Combines a list of states, each with their own group size, into a single state.

is_finished() → bool

If this state has a group_size of 1, this returns whether the single action sequence in this state is finished or not. If this state has a group_size other than 1, this method raises an error.
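The two methods above can be sketched with a toy class. This is a hypothetical stand-in, not the library’s implementation: ToyState, its plain-float scores, and its per-element finished flag are assumptions made for illustration (real subclasses decide completion from their own state, such as a GrammarStatelet). combine_states concatenates every grouped field, so the group sizes add up, and is_finished refuses to answer unless group_size is 1.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ToyState:
    """Hypothetical stand-in for State: every field is a group_size-length list."""
    batch_indices: List[int]
    action_history: List[List[int]]
    score: List[float]  # real States hold torch.Tensor scores
    finished: List[bool]  # hypothetical per-element "done" flag

    @classmethod
    def combine_states(cls, states: List["ToyState"]) -> "ToyState":
        # Concatenate each grouped field, so the new group_size is the sum of the old ones.
        return cls(
            batch_indices=[i for s in states for i in s.batch_indices],
            action_history=[h for s in states for h in s.action_history],
            score=[x for s in states for x in s.score],
            finished=[f for s in states for f in s.finished],
        )

    def is_finished(self) -> bool:
        # Mirrors the documented contract: only defined when group_size == 1.
        if len(self.batch_indices) != 1:
            raise RuntimeError("is_finished() requires a group_size of 1")
        return self.finished[0]
```

Combining a group of one with a group of two yields a group of three, on which is_finished() raises; each single-element state can still be queried on its own.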

class allennlp.state_machines.states.grammar_based_state.GrammarBasedState(batch_indices: typing.List[int], action_history: typing.List[typing.List[int]], score: typing.List[torch.Tensor], rnn_state: typing.List[allennlp.state_machines.states.rnn_statelet.RnnStatelet], grammar_state: typing.List[allennlp.state_machines.states.grammar_statelet.GrammarStatelet], possible_actions: typing.List[typing.List[allennlp.data.fields.production_rule_field.ProductionRule]], extras: typing.List[typing.Any] = None, debug_info: typing.List = None) → None

Bases: allennlp.state_machines.states.state.State

A generic State that’s suitable for most models that do grammar-based decoding. We keep around a group of states, and each element in the group has a few things: a batch index, an action history, a score, an RnnStatelet, and a GrammarStatelet. We additionally have some information that’s independent of any particular group element: a list of all possible actions for all batch instances passed to model.forward(), and an extras field that you can use if you really need some extra information about each batch instance (like a string description, or other metadata).

Finally, we also have a specially-treated, optional debug_info field. If this is given, it should be an empty list for each group instance when the initial state is created. In that case, we will keep around information about the actions considered at each timestep of decoding and other things that you might want to visualize in a demo. This probably isn’t necessary for training, and to get it right we need to copy a bunch of data structures for each new state, so it’s best used only at evaluation / demo time.

Parameters:
batch_indices : List[int]

Passed to super class; see docs there.

action_history : List[List[int]]

Passed to super class; see docs there.

score : List[torch.Tensor]

Passed to super class; see docs there.

rnn_state : List[RnnStatelet]

An RnnStatelet for every group element. This keeps track of the current decoder hidden state, the previous decoder output, the output from the encoder (for computing attentions), and other things that are typical seq2seq decoder state things.

grammar_state : List[GrammarStatelet]

This holds the current grammar state for each element of the group. The GrammarStatelet keeps track of which actions are currently valid.

possible_actions : List[List[ProductionRule]]

The list of all possible actions that were passed to model.forward(). We need this so we can recover production strings, which we need to update grammar states.

extras : List[Any], optional (default=None)

If you need to keep around some extra data for each instance in the batch, you can put that in here, without adding another field. This should be used very sparingly, as there is no type checking or anything done on the contents of this field, and it will just be passed around between States as-is, without copying.

debug_info : List[Any], optional (default=None)

classmethod combine_states(states: typing.Sequence[GrammarBasedState]) → allennlp.state_machines.states.grammar_based_state.GrammarBasedState

get_valid_actions() → typing.List[typing.Dict[str, typing.Tuple[torch.Tensor, torch.Tensor, typing.List[int]]]]

Returns a list of valid actions for each element of the group.

is_finished() → bool

new_state_from_group_index(group_index: int, action: int, new_score: torch.Tensor, new_rnn_state: allennlp.state_machines.states.rnn_statelet.RnnStatelet, considered_actions: typing.List[int] = None, action_probabilities: typing.List[float] = None, attention_weights: torch.Tensor = None) → allennlp.state_machines.states.grammar_based_state.GrammarBasedState

print_action_history(group_index: int = None) → None
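possible_actions is indexed by batch instance and then by action id, while decoding works with group indices, so recovering the production string for a chosen action goes through batch_indices. The sketch below is a hypothetical simplification: rules are stored as plain “LHS -> RHS” strings rather than ProductionRule objects, and production_for is an invented helper. The resulting string is the kind of input GrammarStatelet.take_action expects.

```python
from typing import List

# Hypothetical possible_actions: per batch instance, production-rule strings indexed
# by action id (real code stores ProductionRule objects here).
possible_actions: List[List[str]] = [
    ["@start@ -> r", "r -> [<e,r>, e]", "e -> x"],
    ["@start@ -> d", "d -> [<e,d>, e]"],
]
batch_indices: List[int] = [1, 0, 1]  # group_size = 3


def production_for(group_index: int, action: int) -> str:
    """Map a (group element, action id) pair to its production-rule string."""
    batch_index = batch_indices[group_index]
    return possible_actions[batch_index][action]
```

For example, production_for(0, 1) returns "d -> [<e,d>, e]", because group element 0 came from batch instance 1.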
class allennlp.state_machines.states.coverage_state.CoverageState(batch_indices: typing.List[int], action_history: typing.List[typing.List[int]], score: typing.List[torch.Tensor], rnn_state: typing.List[allennlp.state_machines.states.rnn_statelet.RnnStatelet], grammar_state: typing.List[allennlp.state_machines.states.grammar_statelet.GrammarStatelet], checklist_state: typing.List[allennlp.state_machines.states.checklist_statelet.ChecklistStatelet], possible_actions: typing.List[typing.List[allennlp.data.fields.production_rule_field.ProductionRule]], extras: typing.List[typing.Any] = None, debug_info: typing.List = None) → None

Bases: allennlp.state_machines.states.grammar_based_state.GrammarBasedState

This State adds one field to a GrammarBasedState: a ChecklistStatelet that is used to specify a set of actions that should be taken during decoding, and to keep track of which of those actions have already been selected.

We only provide documentation for the ChecklistStatelet here; for the rest, see GrammarBasedState.

Parameters:
batch_indices : List[int]
action_history : List[List[int]]
score : List[torch.Tensor]
rnn_state : List[RnnStatelet]
grammar_state : List[GrammarStatelet]
checklist_state : List[ChecklistStatelet]

This holds the current checklist state for each element of the group. The ChecklistStatelet keeps track of which actions are preferred by some agenda, and which of those have already been selected during decoding.

possible_actions : List[List[ProductionRule]]
extras : List[Any], optional (default=None)
debug_info : List[Any], optional (default=None)

classmethod combine_states(states: typing.Sequence[CoverageState]) → allennlp.state_machines.states.coverage_state.CoverageState

new_state_from_group_index(group_index: int, action: int, new_score: torch.Tensor, new_rnn_state: allennlp.state_machines.states.rnn_statelet.RnnStatelet, considered_actions: typing.List[int] = None, action_probabilities: typing.List[float] = None, attention_weights: torch.Tensor = None) → allennlp.state_machines.states.coverage_state.CoverageState
class allennlp.state_machines.states.rnn_statelet.RnnStatelet(hidden_state: torch.Tensor, memory_cell: torch.Tensor, previous_action_embedding: torch.Tensor, attended_input: torch.Tensor, encoder_outputs: typing.List[torch.Tensor], encoder_output_mask: typing.List[torch.Tensor]) → None

Bases: object

This class keeps track of all of the decoder-RNN-related variables that you need during decoding. This includes things like the current decoder hidden state, the memory cell (for LSTM decoders), the encoder output that you need for computing attentions, and so on.

This is intended to be used inside a State, which likely has other things it has to keep track of for doing constrained decoding.

Parameters:
hidden_state : torch.Tensor

This holds the LSTM hidden state, with shape (decoder_output_dim,) if the decoder has 1 layer and (num_layers, decoder_output_dim) otherwise.

memory_cell : torch.Tensor

This holds the LSTM memory cell, with shape (decoder_output_dim,) if the decoder has 1 layer and (num_layers, decoder_output_dim) otherwise.

previous_action_embedding : torch.Tensor

This holds the embedding for the action we took at the last timestep (which gets input to the decoder). Has shape (action_embedding_dim,).

attended_input : torch.Tensor

This holds the attention-weighted sum over the input representations that we computed in the previous timestep. We keep this as part of the state because we use the previous attention as part of our decoder cell update. Has shape (encoder_output_dim,).

encoder_outputs : List[torch.Tensor]

A list of tensors, each of shape (input_sequence_length, encoder_output_dim), containing the encoder outputs at each timestep. The list is over batch elements; we pass the input this way so we can easily do a torch.cat on a list of indices into this batched list.

Note that all of the above parameters are single tensors, while the encoder outputs and mask are lists of length batch_size. We always pass around the encoder outputs and mask unmodified, regardless of what’s in the grouping for this state. We’ll use the batch_indices for the group to pull pieces out of these lists when we’re ready to actually do some computation.

encoder_output_mask : List[torch.Tensor]

A list of tensors, each of shape (input_sequence_length,), containing a mask over question tokens for each batch instance. This is a list over batch elements, for the same reasons as above.
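The note above about keeping encoder outputs and masks as unmodified per-instance lists can be sketched as follows. This is a hypothetical plain-Python version: nested lists stand in for tensors, and gather_for_group is an invented helper. In PyTorch one would typically torch.stack the selected tensors to get a (group_size, input_sequence_length, encoder_output_dim) tensor.

```python
from typing import List, Tuple

# Hypothetical per-instance encoder outputs (batch_size = 2), each a list of
# input_sequence_length vectors; real code keeps one torch.Tensor per instance.
encoder_outputs: List[List[List[float]]] = [
    [[0.1, 0.2], [0.3, 0.4], [0.0, 0.0]],  # instance 0 (last position is padding)
    [[0.5, 0.6], [0.7, 0.8], [0.9, 1.0]],  # instance 1
]
encoder_output_mask: List[List[int]] = [[1, 1, 0], [1, 1, 1]]


def gather_for_group(
    batch_indices: List[int],
) -> Tuple[List[List[List[float]]], List[List[int]]]:
    """Select encoder outputs and masks for each group element.

    The batched lists are never modified; grouping happens only here, when the
    values are needed for computation (for example, attention).
    """
    grouped_outputs = [encoder_outputs[i] for i in batch_indices]
    grouped_mask = [encoder_output_mask[i] for i in batch_indices]
    return grouped_outputs, grouped_mask
```

With batch_indices [1, 1, 0], the first two group elements share instance 1’s outputs and the third gets instance 0’s, padding mask included.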

class allennlp.state_machines.states.grammar_statelet.GrammarStatelet(nonterminal_stack: typing.List[str], valid_actions: typing.Dict[str, ActionRepresentation], is_nonterminal: typing.Callable[[str], bool], reverse_productions: bool = True) → None

Bases: typing.Generic

A GrammarStatelet keeps track of the currently valid actions at every step of decoding.

This class is relatively simple: we have a non-terminal stack which tracks which non-terminals we still need to expand. At every timestep of decoding, we take an action that pops something off of the non-terminal stack, and possibly pushes more things on. The grammar state is “finished” when the non-terminal stack is empty.

At any point during decoding, you can query this object to get a representation of all of the valid actions in the current state. The representation is something that you provide when constructing the initial state, in whatever form you want, and we just hold on to it for you and return it when you ask. Putting this in here is purely for convenience, to group together pieces of state that are related to taking actions; if you want to handle the action representations outside of this class, that would work just fine too.

Parameters:
nonterminal_stack : List[str]

Holds the list of non-terminals that still need to be expanded. This starts out as [START_SYMBOL], and decoding ends when this is empty. Every time we take an action, we update the non-terminal stack and the context-dependent valid actions, and we use what’s on the stack to decide which actions are valid in the current state.

valid_actions : Dict[str, ActionRepresentation]

A mapping from non-terminals (represented as strings) to all valid expansions of that non-terminal. The class that constructs this object can pick how it wants the actions to be represented.

is_nonterminal : Callable[[str], bool]

A function that is used to determine whether each piece of the RHS of the action string is a non-terminal that needs to be added to the non-terminal stack. You can use type_declaration.is_nonterminal here, or write your own function if that one doesn’t work for your domain.

reverse_productions : bool, optional (default=True)

When True, the non-terminals on the right-hand side of each production are pushed onto the stack in reverse order, so the first non-terminal in the production is popped off the stack first, giving us left-to-right production. If this is False, you will get right-to-left production.

get_valid_actions() → ActionRepresentation

Returns the valid actions in the current grammar state. The Model determines what exactly this looks like when it constructs the valid_actions dictionary.

is_finished() → bool

Have we finished producing our logical form? We have finished producing the logical form if and only if there are no more non-terminals on the stack.

take_action(production_rule: str) → allennlp.state_machines.states.grammar_statelet.GrammarStatelet

Takes an action in the current grammar state, returning a new grammar state with whatever updates are necessary. The production rule is assumed to be formatted as “LHS -> RHS”.

This will update the non-terminal stack. Updating the non-terminal stack involves popping the expanded non-terminal off the stack, then pushing any non-terminals from the production rule’s right-hand side onto the stack.

For example, if our current nonterminal_stack is ["r", "<e,r>", "d"], and action is d -> [<e,d>, e], the resulting stack will be ["r", "<e,r>", "e", "<e,d>"].

If self._reverse_productions is set to False, then we push the non-terminals on in their given order, which means that the first non-terminal in the production rule gets popped off the stack last.
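The stack update described above, including the effect of reverse_productions and the “finished when the stack is empty” rule, can be sketched as a standalone function. This is a hypothetical illustration, not the GrammarStatelet code: take_action here operates on a bare list, the “[a, b]” right-hand-side parsing is a simplifying assumption, and the toy is_nonterminal stands in for something like type_declaration.is_nonterminal.

```python
from typing import Callable, List


def parse_right_side(right_side: str) -> List[str]:
    """Split an RHS like '[<e,d>, e]' into ['<e,d>', 'e']; a bare 'e' gives ['e']."""
    if right_side.startswith("[") and right_side.endswith("]"):
        return right_side[1:-1].split(", ")
    return [right_side]


def take_action(
    nonterminal_stack: List[str],
    production_rule: str,
    is_nonterminal: Callable[[str], bool],
    reverse_productions: bool = True,
) -> List[str]:
    """Return the new non-terminal stack after expanding the top of the stack."""
    left_side, right_side = production_rule.split(" -> ")
    # The expanded non-terminal must be on top of the stack (the end of the list).
    assert nonterminal_stack[-1] == left_side, "rule does not expand the top of the stack"
    new_stack = nonterminal_stack[:-1]  # a copy, so the input stack is left unchanged
    productions = parse_right_side(right_side)
    if reverse_productions:
        # Push in reverse so the first RHS non-terminal ends up on top (left-to-right).
        productions = list(reversed(productions))
    for production in productions:
        if is_nonterminal(production):
            new_stack.append(production)
    return new_stack


# Hypothetical non-terminal test for this toy grammar.
NONTERMINALS = {"r", "d", "e", "<e,r>", "<e,d>"}


def is_nonterminal(symbol: str) -> bool:
    return symbol in NONTERMINALS
```

take_action(["r", "<e,r>", "d"], "d -> [<e,d>, e]", is_nonterminal) reproduces the example above, ["r", "<e,r>", "e", "<e,d>"]; with reverse_productions=False the last two entries swap. A terminal expansion such as "e -> fb:cell" on the stack ["e"] leaves an empty stack, i.e. a finished grammar state.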

class allennlp.state_machines.states.lambda_grammar_statelet.LambdaGrammarStatelet(nonterminal_stack: typing.List[str], lambda_stacks: typing.Dict[typing.Tuple[str, str], typing.List[str]], valid_actions: typing.Dict[str, typing.Dict[str, typing.Tuple[torch.Tensor, torch.Tensor, typing.List[int]]]], context_actions: typing.Dict[str, typing.Tuple[torch.Tensor, torch.Tensor, int]], is_nonterminal: typing.Callable[[str], bool]) → None

Bases: object

A LambdaGrammarStatelet is a GrammarStatelet that adds lambda productions. These productions change the valid actions depending on the current state (you can produce lambda variables inside the scope of a lambda expression), so we need some extra bookkeeping to keep track of them.

We only use this for the WikiTablesSemanticParser, and so we just hard-code the action representation type here, because the way we handle the context / global / linked action representations is a little convoluted. It would be hard to make this generic in the way that we use it, so we won’t worry about that until there are other use cases of this class that need it.

Parameters:
nonterminal_stack : List[str]

Holds the list of non-terminals that still need to be expanded. This starts out as [START_SYMBOL], and decoding ends when this is empty. Every time we take an action, we update the non-terminal stack and the context-dependent valid actions, and we use what’s on the stack to decide which actions are valid in the current state.

lambda_stacks : Dict[Tuple[str, str], List[str]]

The lambda stack keeps track of when we’re in the scope of a lambda function. The dictionary is keyed by the production rule we are adding (like “r -> x”, separated into left hand side and right hand side, where the LHS is the type of the lambda variable and the RHS is the variable itself), and the value is a nonterminal stack much like nonterminal_stack. When the stack becomes empty, we remove the lambda entry.

valid_actions : Dict[str, Dict[str, Tuple[torch.Tensor, torch.Tensor, List[int]]]]

A mapping from non-terminals (represented as strings) to all valid expansions of that non-terminal. The way we represent the valid expansions is a little complicated: we use a dictionary of action types, where the key is the action type (like “global”, “linked”, or whatever your model is expecting), and the value is a tuple representing all actions of that type. The tuple is (input tensor, output tensor, action ids). The input tensor has the representation that is used when selecting actions, for all actions of this type. The output tensor has the representation that is used when feeding the action to the next step of the decoder (this could just be the same as the input tensor). The action ids are a list of indices into the main action list for each batch instance.

context_actions : Dict[str, Tuple[torch.Tensor, torch.Tensor, int]]

Variable actions are never included in the valid_actions dictionary, because they are only valid depending on the current grammar state. This dictionary maps from the string representation of all such actions to the tensor representations of the actions. These will get added onto the “global” key in the valid_actions when they are allowed.

is_nonterminal : Callable[[str], bool]

A function that is used to determine whether each piece of the RHS of the action string is a non-terminal that needs to be added to the non-terminal stack. You can use type_declaration.is_nonterminal here, or write your own function if that one doesn’t work for your domain.

get_valid_actions() → typing.Dict[str, typing.Tuple[torch.Tensor, torch.Tensor, typing.List[int]]]

Returns the valid actions in the current grammar state. See the class docstring for a description of what we’re returning here.
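Combining the class docstring’s pieces, valid actions are the precomputed expansions of the top non-terminal plus any lambda variable in scope whose type matches that non-terminal. The sketch below is a hypothetical simplification, not the library code: each action type maps to a list of strings instead of an (input tensor, output tensor, action ids) tuple, and appending to the “global” list stands in for concatenating tensors onto the “global” entry.

```python
from typing import Dict, List, Tuple

# Simplified stand-in: action type -> list of action strings.
ActionsByType = Dict[str, List[str]]


def get_valid_actions(
    nonterminal_stack: List[str],
    valid_actions: Dict[str, ActionsByType],
    lambda_stacks: Dict[Tuple[str, str], List[str]],
    context_actions: Dict[str, str],
) -> ActionsByType:
    """Valid actions for the top non-terminal, plus any in-scope lambda variables."""
    top = nonterminal_stack[-1]
    # Copy so that adding context actions never mutates the shared dictionary.
    actions = {key: list(value) for key, value in valid_actions[top].items()}
    # Each lambda_stacks key is (variable type, variable name), e.g. ("r", "x").
    for variable_type, variable in lambda_stacks:
        if variable_type == top:
            production = f"{variable_type} -> {variable}"
            actions.setdefault("global", []).append(context_actions[production])
    return actions


valid_actions = {"r": {"global": ["r -> all_rows"]}, "e": {"global": ["e -> 2004"]}}
context_actions = {"r -> x": "r -> x"}
```

With a lambda variable x of type r in scope (lambda_stacks = {("r", "x"): ["r"]}), a stack topped by "r" yields {"global": ["r -> all_rows", "r -> x"]}, while a stack topped by "e" gets only its precomputed actions.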

is_finished() → bool

Have we finished producing our logical form? We have finished producing the logical form if and only if there are no more non-terminals on the stack.

take_action(production_rule: str) → allennlp.state_machines.states.lambda_grammar_statelet.LambdaGrammarStatelet

Takes an action in the current grammar state, returning a new grammar state with whatever updates are necessary. The production rule is assumed to be formatted as “LHS -> RHS”.

This will update the non-terminal stack and the context-dependent actions. Updating the non-terminal stack involves popping the expanded non-terminal off the stack, then pushing any non-terminals from the production rule’s right-hand side onto the stack. We push the non-terminals in reverse order, so that the first non-terminal in the production rule gets popped off the stack first.

For example, if our current nonterminal_stack is ["r", "<e,r>", "d"], and action is d -> [<e,d>, e], the resulting stack will be ["r", "<e,r>", "e", "<e,d>"].
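The lambda-scope bookkeeping described for lambda_stacks can be sketched as a standalone function. This is a hypothetical simplification, not the library’s implementation, and it rests on a labeled assumption: lambda productions are written like "<r,d> -> ['lambda x', d]", where the "r" in "<r,d>" is the type of the lambda variable x. Every open scope mirrors the main stack’s pops and pushes, and a scope is dropped once its stack empties.

```python
from typing import Callable, Dict, List, Tuple

LambdaStacks = Dict[Tuple[str, str], List[str]]


def update_lambda_stacks(
    lambda_stacks: LambdaStacks,
    production_rule: str,
    is_nonterminal: Callable[[str], bool],
) -> LambdaStacks:
    """Return new lambda stacks after taking production_rule (inputs are not mutated).

    Assumes lambda productions look like "<r,d> -> ['lambda x', d]".
    """
    left_side, right_side = production_rule.split(" -> ")
    right_parts = right_side[1:-1].split(", ") if right_side.startswith("[") else [right_side]
    # Every open scope loses the non-terminal that was just expanded.
    new_stacks = {key: stack[:-1] for key, stack in lambda_stacks.items()}
    first = right_parts[0].strip("'")
    if first.startswith("lambda "):
        # Naive parse of "<r,d>" -> "r"; assumes the variable type contains no comma.
        variable_type = left_side[1:-1].split(",")[0]
        new_stacks[(variable_type, first.split(" ")[1])] = []
    # New non-terminals are pushed (in reverse) onto every open scope, like the main stack.
    for part in reversed(right_parts):
        if is_nonterminal(part):
            for stack in new_stacks.values():
                stack.append(part)
    # A scope closes once its stack is empty.
    return {key: stack for key, stack in new_stacks.items() if stack}


NONTERMINALS = {"d", "r", "<r,d>"}  # hypothetical toy grammar


def is_nonterminal(symbol: str) -> bool:
    return symbol in NONTERMINALS
```

Starting from no open scopes, the sequence "<r,d> -> ['lambda x', d]", "d -> [<r,d>, r]", "<r,d> -> max", "r -> x" opens the scope ("r", "x"), grows and shrinks its stack, and closes it once the variable x has been produced.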

class allennlp.state_machines.states.checklist_statelet.ChecklistStatelet(terminal_actions: torch.Tensor, checklist_target: torch.Tensor, checklist_mask: torch.Tensor, checklist: torch.Tensor, terminal_indices_dict: typing.Dict[int, int] = None) → None

Bases: object

This class keeps track of checklist-related variables that are used while training a coverage-based semantic parser (or any other kind of transition-based constrained decoder). This is intended to be used within a State.

Parameters:
terminal_actions : torch.Tensor

A vector containing the indices of terminal actions, required for computing checklists for next states based on current actions. The idea is that we will build checklists corresponding to the presence or absence of just the terminal actions. But in principle, they can be all actions that are relevant to checklist computation.

checklist_target : torch.Tensor

Target values for the checklist, indicating the state we would ideally like the checklist to reach. It is the same size as terminal_actions, and it contains 1 for each corresponding action that we want to see in the final logical form and 0 for each corresponding action that we do not.

checklist_mask : torch.Tensor

Mask corresponding to terminal_actions, indicating which of those actions are relevant for checklist computation. For example, if the parser is penalizing non-agenda terminal actions, all the terminal actions are relevant.

checklist : torch.Tensor

A checklist indicating how many times each action in its agenda has been chosen previously. It contains the actual counts of the agenda actions.

terminal_indices_dict : Dict[int, int], optional

Mapping from batch action indices to indices in any of the four vectors above. If not provided, this mapping will be computed here.

get_balance() → torch.Tensor

update(action: torch.Tensor) → allennlp.state_machines.states.checklist_statelet.ChecklistStatelet

Takes an action index, updates the checklist, and returns an updated state.
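The checklist bookkeeping can be sketched with plain lists in place of tensors. ToyChecklist is hypothetical, and two behaviors are assumptions rather than documented facts: update increments the count at the action’s position when the action is tracked, and get_balance is taken to be checklist_mask * (checklist_target - checklist), i.e. how far each relevant action still is from its target. The default terminal_indices_dict maps each batch action index to its position in the vectors.

```python
from typing import Dict, List, Optional


class ToyChecklist:
    """Plain-list sketch of ChecklistStatelet-style bookkeeping (real code uses tensors)."""

    def __init__(
        self,
        terminal_actions: List[int],
        checklist_target: List[int],
        checklist_mask: List[int],
        checklist: List[int],
        terminal_indices_dict: Optional[Dict[int, int]] = None,
    ) -> None:
        self.terminal_actions = terminal_actions
        self.checklist_target = checklist_target
        self.checklist_mask = checklist_mask
        self.checklist = checklist
        if terminal_indices_dict is None:
            # Map each batch action index to its position in the vectors above.
            terminal_indices_dict = {action: i for i, action in enumerate(terminal_actions)}
        self.terminal_indices_dict = terminal_indices_dict

    def update(self, action: int) -> "ToyChecklist":
        """Return a new checklist with the count for `action` incremented, if tracked."""
        new_checklist = list(self.checklist)
        if action in self.terminal_indices_dict:
            new_checklist[self.terminal_indices_dict[action]] += 1
        return ToyChecklist(self.terminal_actions, self.checklist_target,
                            self.checklist_mask, new_checklist, self.terminal_indices_dict)

    def get_balance(self) -> List[int]:
        """Assumed definition: mask * (target - checklist), elementwise."""
        return [m * (t - c) for m, t, c in
                zip(self.checklist_mask, self.checklist_target, self.checklist)]
```

With terminal actions [5, 8, 13], target [1, 1, 0], and an all-ones mask, selecting action 8 fills its slot (balance [1, 0, 0]), and then selecting the non-agenda action 13 drives its balance negative ([1, 0, -1]), which is the kind of signal a parser penalizing non-agenda terminals could use.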