allennlp.nn.decoding

This module contains code for transition-based decoding. “Transition-based decoding” is where you start in some state, iteratively transition between states, and have some kind of supervision signal that tells you which end states, or which transition sequences, are “good”.

If you want to do decoding for a vocabulary-based model, where the allowable outputs are the same at every timestep of decoding, this code is not what you are looking for, and it will be quite inefficient compared to other things you could do.

The key abstractions in this code are the following:

  • DecoderState represents the current state of decoding, containing a list of all of the actions taken so far, and a current score for the state. It also has methods for determining whether the state is “finished” and for combining states for batched computation.
  • DecoderStep is a torch.nn.Module that models the transition function between states. Its main method is take_step, which generates a ranked list of next states given a current state.
  • DecoderTrainer is an algorithm for training the transition function with some kind of supervision signal. There are many options for training algorithms and supervision signals; this is an abstract class that is generic over the type of the supervision signal.

The module also has some classes to help represent the DecoderState: RnnState, which you can use to keep track of a decoder RNN’s internal state; GrammarState, which keeps track of which actions are allowed at each timestep of decoding if your outputs are production rules from a grammar; and ChecklistState, which keeps track of coverage information if you are training a coverage-based parser.

There is also a generic BeamSearch class for finding the k highest-scoring transition sequences given a trained DecoderStep and an initial DecoderState.

class allennlp.nn.decoding.decoder_step.DecoderStep

Bases: torch.nn.modules.module.Module, typing.Generic

A DecoderStep is a module that assigns scores to state transitions in a transition-based decoder.

The DecoderStep takes a DecoderState and outputs a ranked list of next states, ordered by the state’s score.

The intention with this class is that a model will implement a subclass of DecoderStep that defines how the input is handled, what computation is done at each step of decoding, and how states are scored. This subclass then gets passed to a DecoderTrainer to have its parameters trained.

take_step(state: StateType, max_actions: int = None, allowed_actions: typing.List[typing.Set] = None) → typing.List[StateType]

The main method in the DecoderStep API. This function defines the computation done at each step of decoding and returns a ranked list of next states.

The input state is grouped, to allow for efficient computation, but the output states should all have a group_size of 1, to make things easier on the decoding algorithm. They will get regrouped later as needed.

Because of the way we handle grouping in the decoder states, constructing a new state is actually a relatively expensive operation. If you know a priori that only some of the states will be needed (either because you have a set of gold action sequences, or you have a fixed beam size), passing that information into this function will keep us from constructing more states than we need, which will greatly speed up your computation.

IMPORTANT: This method must return states already sorted by their score; otherwise BeamSearch and other methods will break. For efficiency, we do not perform an additional sort in those methods.

ALSO IMPORTANT: When allowed_actions is given and max_actions is not, we assume you want to evaluate all possible states and do not need any sorting (e.g., this is true for maximum marginal likelihood training that does not use a beam search). In this case, we may skip the sorting step for efficiency reasons.

Parameters:
state : DecoderState

The current state of the decoder, which we will take a step from. We may be grouping together computation for several states here. Because we can have several states for each instance in the original batch being evaluated at the same time, we use group_size for this kind of batching, and batch_size for the original batch in model.forward.

max_actions : int, optional

If you know that you will only need a certain number of states out of this (e.g., in a beam search), you can pass in the max number of actions that you need, and we will only construct that many states (for each batch instance - not for each group instance!). This can save a whole lot of computation if you have an action space that’s much larger than your beam size.

allowed_actions : List[Set], optional

If the DecoderTrainer has constraints on which actions need to be evaluated (e.g., maximum marginal likelihood only needs to evaluate action sequences in a given set), you can pass those constraints here, to avoid constructing state objects unnecessarily. If there are no constraints from the trainer, passing a value of None here will allow all actions to be considered.

This is a list because it is grouped: every element of the group has its own set of allowed actions. Note that the size of this list is the group_size in the DecoderState, not the batch_size of model.forward. The training algorithm needs to convert from the batched allowed action sequences that it has to a grouped allowed action sequence list.

Returns:
next_states : List[DecoderState]

A list of next states, ordered by score.
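To make the take_step contract concrete, here is a minimal pure-Python sketch. It is not AllenNLP code: ToyState, ACTION_SCORES and this take_step are hypothetical stand-ins (plain floats instead of tensors, a fixed score table instead of a learned model). It scores candidates cheaply before building any state objects, returns group_size-1 states sorted by score, limits output to max_actions per batch instance, skips the sort when only allowed_actions is given, and shows the conversion from batched to grouped constraints.

```python
from dataclasses import dataclass
from typing import Dict, Iterable, List, Optional, Set, Tuple


@dataclass
class ToyState:
    """A grouped decoder state: every list has length group_size."""
    batch_indices: List[int]
    action_history: List[List[int]]
    score: List[float]


# Hypothetical per-action log-probabilities standing in for a learned model.
ACTION_SCORES: Dict[int, float] = {0: -0.1, 1: -1.0, 2: -2.5, 3: -0.7}


def take_step(state: ToyState,
              max_actions: Optional[int] = None,
              allowed_actions: Optional[List[Set[int]]] = None) -> List[ToyState]:
    # Score every candidate as a cheap tuple; no state objects are built yet.
    candidates: List[Tuple[float, int, List[int]]] = []
    for group_index, batch_index in enumerate(state.batch_indices):
        actions: Iterable[int] = (allowed_actions[group_index]
                                  if allowed_actions is not None else ACTION_SCORES)
        for action in actions:
            candidates.append((state.score[group_index] + ACTION_SCORES[action],
                               batch_index,
                               state.action_history[group_index] + [action]))
    if allowed_actions is not None and max_actions is None:
        # Every allowed state is needed (e.g. maximum marginal likelihood): skip sorting.
        ordered = candidates
    else:
        ordered = sorted(candidates, key=lambda candidate: candidate[0], reverse=True)
    next_states: List[ToyState] = []
    kept: Dict[int, int] = {}  # number of states kept so far, per batch instance
    for score, batch_index, history in ordered:
        if max_actions is not None and kept.get(batch_index, 0) >= max_actions:
            continue
        kept[batch_index] = kept.get(batch_index, 0) + 1
        # Output states always have a group_size of 1.
        next_states.append(ToyState([batch_index], [history], [score]))
    return next_states


# Converting batched constraints (one set per instance) into grouped ones.
example = ToyState(batch_indices=[0, 0, 1], action_history=[[1], [3], []],
                   score=[-1.0, -0.7, 0.0])
batched_allowed = {0: {0, 3}, 1: {2}}  # hypothetical per-instance constraints
grouped_allowed = [batched_allowed[i] for i in example.batch_indices]
constrained = take_step(example, allowed_actions=grouped_allowed)
```

Building the ToyState objects only after pruning is the point of passing max_actions: candidates outside the per-instance limit never become state objects.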

class allennlp.nn.decoding.decoder_state.DecoderState(batch_indices: typing.List[int], action_history: typing.List[typing.List[int]], score: typing.List[torch.Tensor]) → None

Bases: typing.Generic

Represents the (batched) state of a transition-based decoder.

There are two different kinds of batching we need to distinguish here. First, there’s the batch of training instances passed to model.forward(). We’ll use “batch” and batch_size to refer to this throughout the docs and code. We additionally batch together computation for several states at the same time, where each state could be from the same training instance in the original batch, or different instances. We use “group” and group_size in the docs and code to refer to this kind of batching, to distinguish it from the batch of training instances.

So, using this terminology, a single DecoderState object represents a grouped collection of states. Because different states in this group might finish at different timesteps, we have methods and member variables to handle some bookkeeping around this, to split and regroup things.

Parameters:
batch_indices : List[int]

A group_size-length list, where each element specifies which batch_index that group element came from.

Our internal variables (like scores, action histories, hidden states, whatever) are grouped, and our group_size is likely different from the original batch_size. This variable keeps track of which batch instance each group element came from (e.g., to know what the correct action sequences are, or which encoder outputs to use).

action_history : List[List[int]]

The list of actions taken so far in this state. This is also grouped, so each state in the group has a list of actions.

score : List[torch.Tensor]

This state’s score. Each entry is a tensor (rather than a plain number) because typically we’ll be computing a loss based on this score and using it for backprop during training. Like the other variables here, this is a group_size-length list.

classmethod combine_states(states: typing.List[T]) → T

Combines a list of states, each with their own group size, into a single state.

is_finished() → bool

If this state has a group_size of 1, this returns whether the single action sequence in this state is finished or not. If this state has a group_size other than 1, this method raises an error.
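A minimal sketch of this grouping bookkeeping, assuming plain Python lists and floats in place of tensors. ToyDecoderState and END_ACTION are hypothetical; in particular, what counts as “finished” is model-specific, and this sketch simply assumes a sequence ends with action 0.

```python
from dataclasses import dataclass
from typing import List

END_ACTION = 0  # hypothetical action index that ends a sequence


@dataclass
class ToyDecoderState:
    batch_indices: List[int]         # which batch instance each group element came from
    action_history: List[List[int]]  # one action list per group element
    score: List[float]               # one score per group element

    @classmethod
    def combine_states(cls, states: List["ToyDecoderState"]) -> "ToyDecoderState":
        # Concatenate the grouped fields of all states into one larger group.
        return cls(batch_indices=[b for s in states for b in s.batch_indices],
                   action_history=[h for s in states for h in s.action_history],
                   score=[x for s in states for x in s.score])

    def is_finished(self) -> bool:
        if len(self.batch_indices) != 1:
            raise RuntimeError("is_finished() requires a group_size of 1")
        history = self.action_history[0]
        return len(history) > 0 and history[-1] == END_ACTION
```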

class allennlp.nn.decoding.rnn_state.RnnState(hidden_state: torch.Tensor, memory_cell: torch.Tensor, previous_action_embedding: torch.Tensor, attended_input: torch.Tensor, encoder_outputs: typing.List[torch.Tensor], encoder_output_mask: typing.List[torch.Tensor]) → None

Bases: object

This class keeps track of all of the decoder-RNN-related variables that you need during decoding. This includes things like the current decoder hidden state, the memory cell (for LSTM decoders), the encoder output that you need for computing attentions, and so on.

This is intended to be used inside a DecoderState, which likely has other things it has to keep track of for doing constrained decoding.

Parameters:
hidden_state : torch.Tensor

This holds the LSTM hidden state, with shape (decoder_output_dim,).

memory_cell : torch.Tensor

This holds the LSTM memory cell, with shape (decoder_output_dim,).

previous_action_embedding : torch.Tensor

This holds the embedding for the action we took at the last timestep (which gets input to the decoder). Has shape (action_embedding_dim,).

attended_input : torch.Tensor

This holds the attention-weighted sum over the input representations that we computed in the previous timestep. We keep this as part of the state because we use the previous attention as part of our decoder cell update. Has shape (encoder_output_dim,).

encoder_outputs : List[torch.Tensor]

A list of tensors, each of shape (input_sequence_length, encoder_output_dim), containing the encoder outputs at each timestep. The list is over batch elements; we store the outputs this way so that we can easily index into the list with a group’s batch_indices and torch.cat the results.

Note that all of the above parameters are single tensors, while the encoder outputs and mask are lists of length batch_size. We always pass around the encoder outputs and mask unmodified, regardless of what’s in the grouping for this state. We’ll use the batch_indices for the group to pull pieces out of these lists when we’re ready to actually do some computation.

encoder_output_mask : List[torch.Tensor]

A list of tensors, each of shape (input_sequence_length,), containing a mask over question tokens for each batch instance. This is a list over batch elements, for the same reasons as above.
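The following sketch illustrates how batch_indices pull per-instance encoder outputs and masks out of the batch-length lists when computing an attended input for each group element. It is hypothetical: it uses plain lists for tensors and plain dot-product attention, which assumes the decoder hidden size equals encoder_output_dim; a real decoder step would typically use a learned attention function.

```python
import math
from typing import List

Vector = List[float]


def attend(hidden_states: List[Vector],
           batch_indices: List[int],
           encoder_outputs: List[List[Vector]],
           encoder_output_mask: List[List[int]]) -> List[Vector]:
    """Dot-product attention for each group element (hypothetical scoring)."""
    attended_inputs = []
    for hidden, batch_index in zip(hidden_states, batch_indices):
        # Encoder outputs and masks are stored per batch instance; pick ours by index.
        outputs = encoder_outputs[batch_index]
        mask = encoder_output_mask[batch_index]
        logits = [sum(h * o for h, o in zip(hidden, output)) for output in outputs]
        # Masked softmax: padded positions get zero weight.
        top = max(logit for logit, keep in zip(logits, mask) if keep)
        exps = [math.exp(logit - top) if keep else 0.0 for logit, keep in zip(logits, mask)]
        total = sum(exps)
        weights = [e / total for e in exps]
        attended_inputs.append([sum(w * output[d] for w, output in zip(weights, outputs))
                                for d in range(len(outputs[0]))])
    return attended_inputs
```

Because the lists are indexed rather than copied per group element, a group such as batch_indices=[0, 0, 1] reuses instance 0's encoder outputs twice without duplicating them in the state.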

class allennlp.nn.decoding.grammar_state.GrammarState(nonterminal_stack: typing.List[str], lambda_stacks: typing.Dict[typing.Tuple[str, str], typing.List[str]], valid_actions: typing.Dict[str, typing.Dict[str, typing.Tuple[torch.Tensor, torch.Tensor, typing.List[int]]]], context_actions: typing.Dict[str, typing.Tuple[torch.Tensor, torch.Tensor, int]], is_nonterminal: typing.Callable[[str], bool]) → None

Bases: object

A GrammarState specifies the currently valid actions at every step of decoding.

If we had a global context-free grammar, this would not be necessary - the currently valid actions would always be the same, and we would not need to represent the current state. However, our grammar is not context free (we have lambda expressions that introduce context-dependent production rules), and it is not global (each instance can have its own entities of a particular type, or its own functions).

We thus recognize three different sources of valid actions. The first are actions that come from the type declaration; these are defined once by the model and shared across all GrammarStates produced by that model. The second are actions that come from the current instance; these are defined by the World that corresponds to each instance, and are shared across all decoding states for that instance. The last are actions that come from the current state of the decoder; these are updated after every action taken by the decoder, though only some actions initiate changes.

In practice, we use the World class to get the first two sources of valid actions at the same time, and we take as input a valid_actions dictionary that is computed by the World. These will not change during the course of decoding. The GrammarState object itself maintains the context-dependent valid actions.

Parameters:
nonterminal_stack : List[str]

Holds the list of non-terminals that still need to be expanded. This starts out as [START_SYMBOL], and decoding ends when this is empty. Every time we take an action, we update the non-terminal stack and the context-dependent valid actions, and we use what’s on the stack to decide which actions are valid in the current state.

lambda_stacks : Dict[Tuple[str, str], List[str]]

The lambda stack keeps track of when we’re in the scope of a lambda function. The dictionary is keyed by the production rule we are adding (like “r -> x”, separated into left hand side and right hand side, where the LHS is the type of the lambda variable and the RHS is the variable itself), and the value is a nonterminal stack much like nonterminal_stack. When the stack becomes empty, we remove the lambda entry.

valid_actions : Dict[str, Dict[str, Tuple[torch.Tensor, torch.Tensor, List[int]]]]

A mapping from non-terminals (represented as strings) to all valid expansions of that non-terminal. The way we represent the valid expansions is a little complicated: we use a dictionary of action types, where the key is the action type (like “global”, “linked”, or whatever your model is expecting), and the value is a tuple representing all actions of that type. The tuple is (input tensor, output tensor, action ids). The input tensor has the representation that is used when selecting actions, for all actions of this type. The output tensor has the representation that is used when feeding the action to the next step of the decoder (this could just be the same as the input tensor). The action ids are a list of indices into the main action list for each batch instance.

context_actions : Dict[str, Tuple[torch.Tensor, torch.Tensor, int]]

Variable actions (productions such as “r -> x” whose right-hand side is a lambda variable) are never included in the valid_actions dictionary, because whether they are valid depends on the current grammar state. This dictionary maps from the string representation of all such actions to the tensor representations of the actions. These will get added onto the “global” key in the valid_actions when they are allowed.

is_nonterminal : Callable[[str], bool]

A function that is used to determine whether each piece of the RHS of the action string is a non-terminal that needs to be added to the non-terminal stack. You can use type_declaration.is_nonterminal here, or write your own function if that one doesn’t work for your domain.

get_valid_actions() → typing.Dict[str, typing.Tuple[torch.Tensor, torch.Tensor, typing.List[int]]]

Returns the valid actions in the current grammar state. See the class docstring for a description of what we’re returning here.
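A pure-Python sketch of how the context-dependent actions could be merged into the “global” action type. It assumes lists stand in for the input and output tensors, and that context_actions is keyed by strings of the form “type -> variable” built from the lambda_stacks keys. It illustrates the behavior described above; it is not the library's implementation.

```python
from typing import Dict, List, Tuple

Row = List[float]
# (input representations, output representations, action ids); lists stand in for tensors.
ActionGroup = Tuple[List[Row], List[Row], List[int]]


def get_valid_actions(nonterminal_stack: List[str],
                      lambda_stacks: Dict[Tuple[str, str], List[str]],
                      valid_actions: Dict[str, Dict[str, ActionGroup]],
                      context_actions: Dict[str, Tuple[Row, Row, int]]) -> Dict[str, ActionGroup]:
    current = nonterminal_stack[-1]
    actions = valid_actions[current]
    # A lambda variable in scope is a valid expansion if its type matches the current non-terminal.
    extra = [context_actions[f"{type_} -> {variable}"]
             for type_, variable in lambda_stacks
             if type_ == current]
    if not extra:
        return actions
    inputs, outputs, ids = actions.get("global", ([], [], []))
    new_actions = dict(actions)  # copy: valid_actions is shared across decoding states
    new_actions["global"] = (inputs + [e[0] for e in extra],
                             outputs + [e[1] for e in extra],
                             ids + [e[2] for e in extra])
    return new_actions
```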

is_finished() → bool

Have we finished producing our logical form? We have finished producing the logical form if and only if there are no more non-terminals on the stack.

take_action(production_rule: str) → allennlp.nn.decoding.grammar_state.GrammarState

Takes an action in the current grammar state, returning a new grammar state with whatever updates are necessary. The production rule is assumed to be formatted as “LHS -> RHS”.

This will update the non-terminal stack and the context-dependent actions. Updating the non-terminal stack involves popping the non-terminal that was expanded off of the stack, then pushing any non-terminals from the production rule’s right-hand side onto the stack. We push the non-terminals in reverse order, so that the first non-terminal in the production rule gets popped off the stack first.

For example, if our current nonterminal_stack is ["r", "<e,r>", "d"], and action is d -> [<e,d>, e], the resulting stack will be ["r", "<e,r>", "e", "<e,d>"].
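The stack update described above can be sketched in a few lines of Python. This sketch ignores lambda_stacks entirely; BASIC_TYPES and default_is_nonterminal are hypothetical stand-ins for a real is_nonterminal function, and the right-hand-side parsing assumes the “[a, b]” formatting used in the example.

```python
from typing import Callable, List, Tuple

BASIC_TYPES = {"d", "e", "r"}  # hypothetical basic types for this sketch


def default_is_nonterminal(symbol: str) -> bool:
    # Hypothetical rule: basic types and function types like "<e,r>" are non-terminals.
    return symbol in BASIC_TYPES or (symbol.startswith("<") and symbol.endswith(">"))


def split_production(production_rule: str) -> Tuple[str, List[str]]:
    left_side, right_side = production_rule.split(" -> ")
    if right_side.startswith("[") and right_side.endswith("]"):
        return left_side, right_side[1:-1].split(", ")
    return left_side, [right_side]


def take_action(nonterminal_stack: List[str],
                production_rule: str,
                is_nonterminal: Callable[[str], bool] = default_is_nonterminal) -> List[str]:
    left_side, right_side = split_production(production_rule)
    if not nonterminal_stack or nonterminal_stack[-1] != left_side:
        raise ValueError(f"Cannot expand {left_side!r} given stack {nonterminal_stack}")
    # Pop the expanded non-terminal (copying, so the old state is left untouched).
    new_stack = nonterminal_stack[:-1]
    # Push right-hand-side non-terminals in reverse, so the first one ends up on top.
    for symbol in reversed(right_side):
        if is_nonterminal(symbol):
            new_stack.append(symbol)
    return new_stack


def is_finished(nonterminal_stack: List[str]) -> bool:
    return not nonterminal_stack
```

On the example above, popping “d” and then pushing “e” followed by “<e,d>” leaves “<e,d>” on top, so it is expanded next.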

class allennlp.nn.decoding.checklist_state.ChecklistState(terminal_actions: torch.Tensor, checklist_target: torch.Tensor, checklist_mask: torch.Tensor, checklist: torch.Tensor, terminal_indices_dict: typing.Dict[int, int] = None) → None

Bases: object

This class keeps track of checklist-related variables that are used while training a coverage-based semantic parser (or any other kind of transition-based constrained decoder). This is intended to be used within a DecoderState.

Parameters:
terminal_actions : torch.Tensor

A vector containing the indices of terminal actions, required for computing checklists for next states based on current actions. The idea is that we will build checklists corresponding to the presence or absence of just the terminal actions. But in principle, they can be all actions that are relevant to checklist computation.

checklist_target : torch.Tensor

Targets for the checklist, indicating the state we would ideally like the checklist to be in. It is the same size as terminal_actions, and it contains 1 for each corresponding action in the list that we want to see in the final logical form, and 0 for each corresponding action that we do not.

checklist_mask : torch.Tensor

Mask corresponding to terminal_actions, indicating which of those actions are relevant for checklist computation. For example, if the parser is penalizing non-agenda terminal actions, all the terminal actions are relevant.

checklist : torch.Tensor

A checklist indicating how many times each action in its agenda has been chosen previously. It contains the actual counts of the agenda actions.

terminal_indices_dict : Dict[int, int], optional

Mapping from batch action indices to indices in any of the four vectors above. If not provided, this mapping will be computed here.

get_balance() → torch.Tensor

update(action: torch.Tensor) → allennlp.nn.decoding.checklist_state.ChecklistState

Takes an action index, updates checklist and returns an updated state.
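A minimal sketch of the checklist bookkeeping using plain lists. ToyChecklistState is hypothetical and drops terminal_indices_dict, and the get_balance formula (mask * (target - checklist)) is an assumption, since its behavior is not documented above.

```python
from dataclasses import dataclass, replace
from typing import List


@dataclass(frozen=True)
class ToyChecklistState:
    terminal_actions: List[int]    # action indices tracked by the checklist
    checklist_target: List[float]  # 1 if we want the action in the final logical form
    checklist_mask: List[float]    # 1 if the action is relevant for the checklist
    checklist: List[float]         # how many times each action has been chosen so far

    def update(self, action: int) -> "ToyChecklistState":
        # Increment the count at every tracked position that matches the action,
        # returning a new state instead of modifying this one.
        new_checklist = [count + (1.0 if terminal == action else 0.0)
                         for count, terminal in zip(self.checklist, self.terminal_actions)]
        return replace(self, checklist=new_checklist)

    def get_balance(self) -> List[float]:
        # Assumption: balance = mask * (target - checklist), i.e. how far each
        # relevant action still is from its target count.
        return [m * (t - c) for m, t, c in
                zip(self.checklist_mask, self.checklist_target, self.checklist)]
```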

class allennlp.nn.decoding.beam_search.BeamSearch(beam_size: int) → None

Bases: allennlp.common.from_params.FromParams, typing.Generic

This class implements beam search over transition sequences given an initial DecoderState and a DecoderStep, returning the highest scoring final states found by the beam (the states will keep track of the transition sequence themselves).

The initial DecoderState is assumed to be batched. The value we return from the search is a dictionary from batch indices to ranked finished states.

IMPORTANT: We assume that the DecoderStep that you are using returns possible next states in sorted order, so we do not do an additional sort inside of BeamSearch.search(). If you’re implementing your own DecoderStep, you must ensure that you’ve sorted the states that you return.

search(num_steps: int, initial_state: StateType, decoder_step: allennlp.nn.decoding.decoder_step.DecoderStep, keep_final_unfinished_states: bool = True) → typing.Mapping[int, typing.Sequence[StateType]]
Parameters:
num_steps : int

How many steps should we take in our search? This is an upper bound, as it’s possible for the search to run out of valid actions before hitting this number, or for all states on the beam to finish.

initial_state : StateType

The starting state of our search. This is assumed to be batched, and our beam search is batch-aware - we’ll keep beam_size states around for each instance in the batch.

decoder_step : DecoderStep

The DecoderStep object that defines and scores transitions from one state to the next.

keep_final_unfinished_states : bool, optional (default=True)

If we run out of steps before a state is “finished”, should we return that state in our search results?

Returns:
best_states : Dict[int, List[StateType]]

This is a mapping from batch index to the top states for that instance.
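A batch-aware beam search over a toy transition function, as a sketch of the search loop described above. States are hypothetical (batch_index, action_history, score) tuples with group_size 1, and toy_step is a stand-in for a DecoderStep that already returns its states sorted by score, so the loop only prunes and never re-sorts between steps.

```python
from collections import defaultdict
from typing import Callable, Dict, List, Tuple

# Hypothetical ungrouped state: (batch_index, action_history, score).
State = Tuple[int, Tuple[int, ...], float]


def beam_search(num_steps: int,
                initial_states: List[State],
                take_step: Callable[[List[State], int], List[State]],
                is_finished: Callable[[State], bool],
                beam_size: int,
                keep_final_unfinished_states: bool = True) -> Dict[int, List[State]]:
    finished: Dict[int, List[State]] = defaultdict(list)
    states = initial_states
    for _ in range(num_steps):
        if not states:
            break
        per_instance: Dict[int, List[State]] = defaultdict(list)
        # take_step must already return states sorted by score; we never re-sort them here.
        for state in take_step(states, beam_size):
            target = finished if is_finished(state) else per_instance
            target[state[0]].append(state)
        # Batch-aware pruning: keep at most beam_size unfinished states per batch instance.
        states = [s for batch_states in per_instance.values() for s in batch_states[:beam_size]]
    if keep_final_unfinished_states:
        for state in states:
            finished[state[0]].append(state)
    # Rank the collected states for each batch instance and keep the best beam_size.
    return {batch_index: sorted(batch_states, key=lambda s: s[2], reverse=True)[:beam_size]
            for batch_index, batch_states in finished.items()}


# Hypothetical toy model: action 0 ends the sequence.
LOG_PROBS = {0: -1.5, 1: -0.5, 2: -1.0}


def toy_step(states: List[State], max_actions: int) -> List[State]:
    # max_actions is unused here; this toy step expands every action and sorts.
    expanded = [(b, history + (action,), score + lp)
                for b, history, score in states for action, lp in LOG_PROBS.items()]
    return sorted(expanded, key=lambda s: s[2], reverse=True)


def toy_finished(state: State) -> bool:
    return len(state[1]) > 0 and state[1][-1] == 0
```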

class allennlp.nn.decoding.constrained_beam_search.ConstrainedBeamSearch(beam_size: typing.Union[int, NoneType], allowed_sequences: torch.Tensor, allowed_sequence_mask: torch.Tensor) → None

Bases: object

This class implements beam search over transition sequences given an initial DecoderState, a DecoderStep, and a list of allowed transition sequences. We will do a beam search over the list of allowed sequences and return the highest scoring states found by the beam. This is only actually a beam search if your beam size is smaller than the list of allowed transition sequences; otherwise, we are just scoring and sorting the sequences using a prefix tree.

The initial DecoderState is assumed to be batched. The value we return from the search is a dictionary from batch indices to ranked finished states.

IMPORTANT: We assume that the DecoderStep that you are using returns possible next states in sorted order, so we do not do an additional sort inside of ConstrainedBeamSearch.search(). If you’re implementing your own DecoderStep, you must ensure that you’ve sorted the states that you return.

Parameters:
beam_size : Optional[int]

The beam size to use. Because this is a constrained beam search, we allow for the case where you just want to evaluate all options in the constrained set. In that case, you don’t need a beam, and you can pass a beam size of None, and we will just evaluate everything. This lets us be more efficient in DecoderStep.take_step() and skip the sorting that is typically done there.

allowed_sequences : torch.Tensor

A (batch_size, num_sequences, sequence_length) tensor containing the transition sequences that we will search in. The values in this tensor must match whatever the DecoderState keeps in its action_history variable (typically this is action indices).

allowed_sequence_mask : torch.Tensor

A (batch_size, num_sequences, sequence_length) tensor indicating whether each entry in the allowed_sequences tensor is padding. The allowed sequences could be padded both on the num_sequences dimension and the sequence_length dimension.

search(initial_state: allennlp.nn.decoding.decoder_state.DecoderState, decoder_step: allennlp.nn.decoding.decoder_step.DecoderStep) → typing.Dict[int, typing.List[allennlp.nn.decoding.decoder_state.DecoderState]]
Parameters:
initial_state : DecoderState

The starting state of our search. This is assumed to be batched, and our beam search is batch-aware - we’ll keep beam_size states around for each instance in the batch.

decoder_step : DecoderStep

The DecoderStep object that defines and scores transitions from one state to the next.

Returns:
best_states : Dict[int, List[DecoderState]]

This is a mapping from batch index to the top states for that instance.
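A sketch of constrained search: the allowed sequences (with padding removed via the mask) are turned into a prefix map, and at each step only actions that extend an allowed prefix are scored. Everything here is hypothetical: ACTION_LOG_PROBS stands in for a DecoderStep, and a state is assumed to be finished once its history has no allowed continuation (i.e. no allowed sequence is a strict prefix of another). With beam_size=None every allowed sequence is scored and nothing is pruned or sorted during the search.

```python
from collections import defaultdict
from typing import Dict, List, Optional, Set, Tuple

State = Tuple[int, Tuple[int, ...], float]  # (batch_index, action_history, score)
ACTION_LOG_PROBS = {1: -0.2, 2: -0.9, 3: -0.4, 4: -1.3, 5: -0.1}  # hypothetical model scores


def prefix_maps(sequences: List[List[List[int]]],
                mask: List[List[List[int]]]) -> List[Dict[Tuple[int, ...], Set[int]]]:
    maps = []
    for instance_sequences, instance_mask in zip(sequences, mask):
        allowed: Dict[Tuple[int, ...], Set[int]] = defaultdict(set)
        for sequence, sequence_mask in zip(instance_sequences, instance_mask):
            prefix: Tuple[int, ...] = ()
            for action, keep in zip(sequence, sequence_mask):
                if not keep:  # padding: the rest of this sequence is ignored
                    break
                allowed[prefix].add(action)
                prefix += (action,)
        maps.append(dict(allowed))
    return maps


def constrained_beam_search(allowed_sequences: List[List[List[int]]],
                            allowed_sequence_mask: List[List[List[int]]],
                            beam_size: Optional[int]) -> Dict[int, List[State]]:
    allowed = prefix_maps(allowed_sequences, allowed_sequence_mask)
    states: List[State] = [(b, (), 0.0) for b in range(len(allowed_sequences))]
    finished: Dict[int, List[State]] = defaultdict(list)
    while states:
        per_instance: Dict[int, List[State]] = defaultdict(list)
        for batch_index, history, score in states:
            # Only actions that extend a prefix of some allowed sequence are scored.
            for action in allowed[batch_index].get(history, set()):
                per_instance[batch_index].append(
                    (batch_index, history + (action,), score + ACTION_LOG_PROBS[action]))
        states = []
        for batch_index, candidates in per_instance.items():
            if beam_size is not None:
                candidates = sorted(candidates, key=lambda s: s[2], reverse=True)[:beam_size]
            for state in candidates:
                # Assumption: finished once the history has no allowed continuation.
                if state[1] in allowed[batch_index]:
                    states.append(state)
                else:
                    finished[batch_index].append(state)
    # Slicing with [:None] keeps everything when beam_size is None.
    return {b: sorted(ss, key=lambda s: s[2], reverse=True)[:beam_size]
            for b, ss in finished.items()}
```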

allennlp.nn.decoding.chu_liu_edmonds.chu_liu_edmonds(length: int, score_matrix: numpy.ndarray, current_nodes: typing.List[bool], final_edges: typing.Dict[int, int], old_input: numpy.ndarray, old_output: numpy.ndarray, representatives: typing.List[typing.Set[int]])

Applies the Chu-Liu-Edmonds algorithm recursively to a graph with edge weights defined by score_matrix.

Note that this function operates in place, so its arguments will be modified.

Parameters:
length : int, required.

The number of nodes.

score_matrix : numpy.ndarray, required.

The score matrix representing the scores for pairs of nodes.

current_nodes : List[bool], required.

The nodes which are representatives in the graph. A representative at its most basic represents a node, but as the algorithm progresses, individual nodes will represent collapsed cycles in the graph.

final_edges : Dict[int, int], required.

An empty dictionary which will be populated with the nodes which are connected in the maximum spanning tree.

old_input : numpy.ndarray, required.
old_output : numpy.ndarray, required.
representatives : List[Set[int]], required.

A list containing the nodes that a particular node is representing at this iteration in the graph.

Returns:
Nothing - all variables are modified in place.
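For reference, here is a compact recursive sketch of the Chu-Liu-Edmonds algorithm for a maximum spanning arborescence. It is not a port of the in-place function above: it works on plain nested lists, assumes scores[head][child] orientation and node 0 as the root, and returns a {child: head} dictionary rather than filling final_edges. Each round picks the best incoming edge for every non-root node, contracts any cycle into a single node (re-scoring edges that enter the cycle by the cycle edge they would replace), solves the smaller problem, and then expands the cycle again.

```python
from typing import Dict, List, Optional

Scores = List[List[float]]  # scores[head][child] is the score of the edge head -> child


def find_cycle(heads: Dict[int, int], root: int) -> Optional[List[int]]:
    """Return the nodes of one cycle formed by the chosen head edges, or None."""
    reaches_root = set()  # nodes already known to lead back to the root
    for start in heads:
        path: List[int] = []
        node = start
        while node != root and node not in reaches_root:
            if node in path:
                return path[path.index(node):]
            path.append(node)
            node = heads[node]
        reaches_root.update(path)
    return None


def max_spanning_arborescence(scores: Scores, root: int = 0) -> Dict[int, int]:
    """Return {child: head} for the highest-scoring tree rooted at `root`."""
    n = len(scores)
    # 1. Greedily pick the best incoming edge for every non-root node.
    heads = {child: max((h for h in range(n) if h != child), key=lambda h: scores[h][child])
             for child in range(n) if child != root}
    cycle = find_cycle(heads, root)
    if cycle is None:
        return heads
    # 2. Contract the cycle into a single new node c (the root is never on a cycle).
    in_cycle = set(cycle)
    others = [v for v in range(n) if v not in in_cycle]
    index = {v: i for i, v in enumerate(others)}
    c = len(others)
    contracted = [[float("-inf")] * (c + 1) for _ in range(c + 1)]
    enter_at: Dict[int, int] = {}    # outside node u -> cycle node that u's best edge enters
    leave_from: Dict[int, int] = {}  # outside node u -> cycle node with the best edge to u
    for u in others:
        for v in others:
            if u != v:
                contracted[index[u]][index[v]] = scores[u][v]
        # Entering the cycle at v replaces v's cycle edge, so score the difference.
        entry = max(cycle, key=lambda v: scores[u][v] - scores[heads[v]][v])
        enter_at[u] = entry
        contracted[index[u]][c] = scores[u][entry] - scores[heads[entry]][entry]
        exit_node = max(cycle, key=lambda w: scores[w][u])
        leave_from[u] = exit_node
        contracted[c][index[u]] = scores[exit_node][u]
    # 3. Solve the smaller problem, then expand the contracted node.
    sub_heads = max_spanning_arborescence(contracted, index[root])
    result: Dict[int, int] = {}
    for child, head in sub_heads.items():
        if child == c:
            parent = others[head]
            result[enter_at[parent]] = parent  # this edge breaks the cycle
        else:
            result[others[child]] = leave_from[others[child]] if head == c else others[head]
    for v in cycle:
        result.setdefault(v, heads[v])  # remaining cycle nodes keep their cycle edges
    return result
```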

allennlp.nn.decoding.chu_liu_edmonds.decode_mst(energy: numpy.ndarray, length: int, has_labels: bool = True) → typing.Tuple[numpy.ndarray, numpy.ndarray]

Note: Counter to typical intuition, this function decodes the _maximum_ spanning tree.

Decodes the optimal maximum spanning tree with the Chu-Liu-Edmonds algorithm for maximum spanning arborescences on graphs.

Parameters:
energy : numpy.ndarray, required.

An array with shape (num_labels, timesteps, timesteps) containing the energy of each edge. If has_labels is False, the array should have shape (timesteps, timesteps) instead.

length : int, required.

The length of this sequence, as the energy may have come from a padded batch.

has_labels : bool, optional, (default = True)

Whether the graph has labels or not.
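When has_labels is True, one plausible way to reduce the labelled energy to a single score per edge is to keep the best label for each edge, after clipping away padding with length. The sketch below assumes energy[label][head][child] indexing and uses nested lists instead of numpy arrays; it is an illustration, not the library's code.

```python
from typing import List, Optional, Tuple

Matrix = List[List[float]]


def collapse_labels(energy, length: int,
                    has_labels: bool = True) -> Tuple[Matrix, Optional[List[List[int]]]]:
    """Clip padding and, if labelled, keep the best label's energy for each edge."""
    if not has_labels:
        # energy is (timesteps, timesteps): only clip away padding.
        return [row[:length] for row in energy[:length]], None
    # energy is (num_labels, timesteps, timesteps); energy[label][head][child] is assumed.
    scores: Matrix = []
    best_labels: List[List[int]] = []
    for head in range(length):
        score_row, label_row = [], []
        for child in range(length):
            per_label = [label_energy[head][child] for label_energy in energy]
            best = max(range(len(per_label)), key=per_label.__getitem__)
            score_row.append(per_label[best])
            label_row.append(best)
        scores.append(score_row)
        best_labels.append(label_row)
    return scores, best_labels
```

The resulting score matrix could then be passed to a maximum spanning arborescence routine, and the label of each chosen edge read off as best_labels[head][child].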

allennlp.nn.decoding.util.construct_prefix_tree(targets: typing.Union[torch.Tensor, typing.List[typing.List[typing.List[int]]]], target_mask: typing.Union[torch.Tensor, NoneType] = None) → typing.List[typing.Dict[typing.Tuple[int, ...], typing.Set[int]]]

Takes a list of valid target action sequences and creates a mapping from all possible (valid) action prefixes to allowed actions given that prefix. While the method is called construct_prefix_tree, we’re actually returning a map that has as keys the paths to all internal nodes of the trie, and as values all of the outgoing edges from that node.

targets is assumed to be a tensor of shape (batch_size, num_valid_sequences, sequence_length). If the mask is not None, it is assumed to have the same shape, and we will ignore any value in targets that has a value of 0 in the corresponding position in the mask. We assume that the mask has the format 1*0* for each item in targets - that is, once we see our first zero, we stop processing that target.

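For example, if targets contains a single batch instance whose valid sequences are [[1, 2, 3], [1, 4, 5]] (that is, targets is [[[1, 2, 3], [1, 4, 5]]]), the return value will be a one-element list: [{(): set([1]), (1,): set([2, 4]), (1, 2): set([3]), (1, 4): set([5])}].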

This could be used, e.g., to do an efficient constrained beam search, or to efficiently evaluate the probability of all of the target sequences.
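A pure-Python sketch of the prefix map described above, handling only the nested-list form of targets (a tensor would first need to be converted to lists). The mask follows the 1*0* convention: processing of a sequence stops at its first zero.

```python
from collections import defaultdict
from typing import Dict, List, Optional, Set, Tuple


def construct_prefix_tree(targets: List[List[List[int]]],
                          target_mask: Optional[List[List[List[int]]]] = None
                          ) -> List[Dict[Tuple[int, ...], Set[int]]]:
    batched = []
    for instance_index, instance_targets in enumerate(targets):
        allowed: Dict[Tuple[int, ...], Set[int]] = defaultdict(set)
        for sequence_index, sequence in enumerate(instance_targets):
            prefix: Tuple[int, ...] = ()
            for position, action in enumerate(sequence):
                # Mask has the form 1*0*: stop at the first zero.
                if (target_mask is not None
                        and not target_mask[instance_index][sequence_index][position]):
                    break
                # Record that `action` is an outgoing edge of the node at `prefix`.
                allowed[prefix].add(action)
                prefix += (action,)
        batched.append(dict(allowed))
    return batched
```

Full sequences never appear as keys, since they have no outgoing edges; only internal nodes of the trie do.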