allennlp.state_machines

This module contains code for using state machines in a model to do transition-based decoding. “Transition-based decoding” is where you start in some state, iteratively transition between states, and have some kind of supervision signal that tells you which end states, or which transition sequences, are “good”.

Typical seq2seq decoding, where you have a fixed vocabulary and no constraints on your output, can be done much more efficiently than we do in this code. This is intended for structured models that have constraints on their outputs.

The key abstractions in this code are the following:

  • State represents the current state of decoding, containing a list of all of the actions taken so far, and a current score for the state. It also has methods for determining whether the state is “finished” and for combining states for batched computation.
  • TransitionFunction is a torch.nn.Module that models the transition function between states. Its main method is take_step, which generates a ranked list of next states given a current state.
  • DecoderTrainer is an algorithm for training the transition function with some kind of supervision signal. There are many options for training algorithms and supervision signals; this is an abstract class that is generic over the type of the supervision signal.

There is also a generic BeamSearch class for finding the k highest-scoring transition sequences given a trained TransitionFunction and an initial State.
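To make these abstractions concrete, here is a minimal, library-free sketch. ToyState, toy_take_step and ACTION_SCORES are hypothetical names invented for illustration; they are not AllenNLP classes. The real State is batched, and the real TransitionFunction is a torch.nn.Module that scores actions with learned parameters rather than a fixed table.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple

# Hypothetical log-probabilities for each action; "stop" ends decoding.
ACTION_SCORES: Dict[str, float] = {"a": -0.1, "b": -0.7, "stop": -1.2}


@dataclass(frozen=True)
class ToyState:
    """Stand-in for State: the actions taken so far plus a running score."""
    action_history: Tuple[str, ...] = ()
    score: float = 0.0

    def is_finished(self) -> bool:
        return bool(self.action_history) and self.action_history[-1] == "stop"


def toy_take_step(state: ToyState) -> List[ToyState]:
    """Stand-in for TransitionFunction.take_step: next states, best first."""
    next_states = [
        ToyState(state.action_history + (action,), state.score + action_score)
        for action, action_score in ACTION_SCORES.items()
    ]
    return sorted(next_states, key=lambda s: s.score, reverse=True)


next_states = toy_take_step(ToyState())
```

Each call extends the action history by one action and returns the successors in ranked order, which is the contract the search classes below rely on.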

class allennlp.state_machines.beam_search.BeamSearch(beam_size: int, per_node_beam_size: int = None) → None

Bases: allennlp.common.from_params.FromParams, typing.Generic

This class implements beam search over transition sequences given an initial State and a TransitionFunction, returning the highest scoring final states found by the beam (the states will keep track of the transition sequence themselves).

The initial State is assumed to be batched. The value we return from the search is a dictionary from batch indices to ranked finished states.

IMPORTANT: We assume that the TransitionFunction that you are using returns possible next states in sorted order, so we do not do an additional sort inside of BeamSearch.search(). If you’re implementing your own TransitionFunction, you must ensure that you’ve sorted the states that you return.

Parameters:
beam_size : int

The beam size to use.

per_node_beam_size : int, optional (default = beam_size)

The maximum number of candidates to consider per node, at each step in the search. If not given, this just defaults to beam_size. Setting this parameter to a number smaller than beam_size may give better results, as it can introduce more diversity into the search. See Freitag and Al-Onaizan 2017, “Beam Search Strategies for Neural Machine Translation”.

search(num_steps: int, initial_state: StateType, transition_function: allennlp.state_machines.transition_functions.transition_function.TransitionFunction, keep_final_unfinished_states: bool = True) → typing.Mapping[int, typing.Sequence[StateType]]
Parameters:
num_steps : int

How many steps should we take in our search? This is an upper bound, as it’s possible for the search to run out of valid actions before hitting this number, or for all states on the beam to finish.

initial_state : StateType

The starting state of our search. This is assumed to be batched, and our beam search is batch-aware: we’ll keep beam_size states around for each instance in the batch.

transition_function : TransitionFunction

The TransitionFunction object that defines and scores transitions from one state to the next.

keep_final_unfinished_states : bool, optional (default=True)

If we run out of steps before a state is “finished”, should we return that state in our search results?

Returns:
best_states : Dict[int, List[StateType]]

This is a mapping from batch index to the top states for that instance.
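The behaviour described above can be sketched without the library. Everything in this example (ToyState, toy_take_step, toy_beam_search and the action scores) is hypothetical and only mirrors the documented parameters; it is not the BeamSearch implementation. One deliberate difference: the toy step function ranks successors per node, so the sketch sorts the merged candidates itself, whereas BeamSearch.search() relies on the TransitionFunction returning states already sorted.

```python
from collections import defaultdict
from typing import Callable, Dict, List, NamedTuple, Optional, Tuple


class ToyState(NamedTuple):
    """Hypothetical stand-in for a State belonging to one batch instance."""
    batch_index: int
    history: Tuple[str, ...]
    score: float

    def is_finished(self) -> bool:
        return bool(self.history) and self.history[-1] == "stop"


def toy_take_step(state: ToyState) -> List[ToyState]:
    """Returns successor states best first, using made-up action scores."""
    action_scores = {"a": -0.1, "b": -0.7, "stop": -1.2}
    successors = [
        ToyState(state.batch_index, state.history + (action,), state.score + score)
        for action, score in action_scores.items()
    ]
    return sorted(successors, key=lambda s: s.score, reverse=True)


def toy_beam_search(
    num_steps: int,
    initial_states: List[ToyState],
    take_step: Callable[[ToyState], List[ToyState]],
    beam_size: int,
    per_node_beam_size: Optional[int] = None,
    keep_final_unfinished_states: bool = True,
) -> Dict[int, List[ToyState]]:
    per_node = per_node_beam_size or beam_size
    finished: Dict[int, List[ToyState]] = defaultdict(list)
    beam = list(initial_states)
    for step in range(num_steps):
        candidates: Dict[int, List[ToyState]] = defaultdict(list)
        for state in beam:
            # Each node contributes at most per_node successors.
            for successor in take_step(state)[:per_node]:
                candidates[successor.batch_index].append(successor)
        beam = []
        for batch_index, states in candidates.items():
            # Keep beam_size states per batch instance.
            states.sort(key=lambda s: s.score, reverse=True)
            for state in states[:beam_size]:
                if state.is_finished():
                    finished[batch_index].append(state)
                else:
                    if step == num_steps - 1 and keep_final_unfinished_states:
                        finished[batch_index].append(state)
                    beam.append(state)
        if not beam:
            break  # every state finished, or no valid actions remain
    return {
        index: sorted(states, key=lambda s: s.score, reverse=True)[:beam_size]
        for index, states in finished.items()
    }


starts = [ToyState(0, (), 0.0), ToyState(1, (), 0.0)]
best = toy_beam_search(2, starts, toy_take_step, beam_size=2)
```

With two steps and a beam of two, no state reaches "stop", so the result contains unfinished states only because keep_final_unfinished_states defaults to True. Passing False in the same setting returns an empty mapping.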

class allennlp.state_machines.constrained_beam_search.ConstrainedBeamSearch(beam_size: typing.Union[int, NoneType], allowed_sequences: torch.Tensor, allowed_sequence_mask: torch.Tensor, per_node_beam_size: int = None) → None

Bases: object

This class implements beam search over transition sequences given an initial State, a TransitionFunction, and a list of allowed transition sequences. We will do a beam search over the list of allowed sequences and return the highest scoring states found by the beam. This is only actually a beam search if your beam size is smaller than the list of allowed transition sequences; otherwise, we are just scoring and sorting the sequences using a prefix tree.

The initial State is assumed to be batched. The value we return from the search is a dictionary from batch indices to ranked finished states.

IMPORTANT: We assume that the TransitionFunction that you are using returns possible next states in sorted order, so we do not do an additional sort inside of ConstrainedBeamSearch.search(). If you’re implementing your own TransitionFunction, you must ensure that you’ve sorted the states that you return.

Parameters:
beam_size : Optional[int]

The beam size to use. Because this is a constrained beam search, we allow for the case where you just want to evaluate all options in the constrained set. In that case, you don’t need a beam, and you can pass a beam size of None, and we will just evaluate everything. This lets us be more efficient in TransitionFunction.take_step() and skip the sorting that is typically done there.

allowed_sequences : torch.Tensor

A (batch_size, num_sequences, sequence_length) tensor containing the transition sequences that we will search in. The values in this tensor must match whatever the State keeps in its action_history variable (typically this is action indices).

allowed_sequence_mask : torch.Tensor

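A (batch_size, num_sequences, sequence_length) tensor marking which entries in the allowed_sequences tensor are padding. It is interpreted like the target_mask argument of construct_prefix_tree: entries whose mask value is 0 are treated as padding and ignored. The allowed sequences could be padded both on the num_sequences dimension and the sequence_length dimension.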
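As an illustration of padding on both dimensions, the following sketch pads a ragged batch of action sequences using plain nested lists. pad_allowed_sequences is a hypothetical helper, not part of the library, and the convention that 1 marks a real entry and 0 marks padding is an assumption borrowed from the mask description given for construct_prefix_tree.

```python
from typing import List, Tuple

Sequences = List[List[List[int]]]


def pad_allowed_sequences(
    batch: Sequences, padding_value: int = 0
) -> Tuple[Sequences, Sequences]:
    """Pads a ragged batch to (batch_size, num_sequences, sequence_length)."""
    num_sequences = max(len(instance) for instance in batch)
    sequence_length = max(len(seq) for instance in batch for seq in instance)
    padded, mask = [], []
    for instance in batch:
        padded_instance, mask_instance = [], []
        for i in range(num_sequences):
            # A missing sequence becomes a fully padded row.
            sequence = instance[i] if i < len(instance) else []
            num_padding = sequence_length - len(sequence)
            padded_instance.append(sequence + [padding_value] * num_padding)
            mask_instance.append([1] * len(sequence) + [0] * num_padding)
        padded.append(padded_instance)
        mask.append(mask_instance)
    return padded, mask


# Instance 0 has two allowed sequences; instance 1 has a single, shorter one.
allowed, allowed_mask = pad_allowed_sequences([[[1, 2, 3], [1, 4]], [[2, 5]]])
```

Wrapping each nested list with torch.tensor would give tensors of the documented shape.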

per_node_beam_size : int, optional (default = beam_size)

The maximum number of candidates to consider per node, at each step in the search. If not given, this just defaults to beam_size. Setting this parameter to a number smaller than beam_size may give better results, as it can introduce more diversity into the search. See Freitag and Al-Onaizan 2017, “Beam Search Strategies for Neural Machine Translation”.

search(initial_state: allennlp.state_machines.states.state.State, transition_function: allennlp.state_machines.transition_functions.transition_function.TransitionFunction) → typing.Dict[int, typing.List[allennlp.state_machines.states.state.State]]
Parameters:
initial_state : State

The starting state of our search. This is assumed to be batched, and our beam search is batch-aware: we’ll keep beam_size states around for each instance in the batch.

transition_function : TransitionFunction

The TransitionFunction object that defines and scores transitions from one state to the next.

Returns:
best_states : Dict[int, List[State]]

This is a mapping from batch index to the top states for that instance.
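The difference between a real beam and a beam_size of None can be seen in a toy, single-instance sketch. ACTION_SCORES, ALLOWED and toy_constrained_search are hypothetical names; the fixed scores stand in for a TransitionFunction, and the sketch assumes that no allowed sequence is a prefix of another.

```python
from typing import Dict, List, Optional, Set, Tuple

# Hypothetical per-action log-probabilities standing in for a TransitionFunction.
ACTION_SCORES: Dict[int, float] = {1: -0.1, 2: -0.5, 3: -0.2, 4: -0.3, 5: -2.0}

# Prefix map for the allowed sequences [1, 2, 3] and [1, 4, 5].
ALLOWED: Dict[Tuple[int, ...], Set[int]] = {
    (): {1},
    (1,): {2, 4},
    (1, 2): {3},
    (1, 4): {5},
}


def toy_constrained_search(
    beam_size: Optional[int],
) -> List[Tuple[float, Tuple[int, ...]]]:
    beam: List[Tuple[float, Tuple[int, ...]]] = [(0.0, ())]
    finished: List[Tuple[float, Tuple[int, ...]]] = []
    while beam:
        candidates = []
        for score, history in beam:
            # Only actions that extend an allowed prefix are ever scored.
            for action in sorted(ALLOWED[history]):
                candidates.append((score + ACTION_SCORES[action], history + (action,)))
        candidates.sort(key=lambda c: c[0], reverse=True)
        if beam_size is not None:
            candidates = candidates[:beam_size]
        # A history with no outgoing edges is a complete allowed sequence.
        beam = [c for c in candidates if c[1] in ALLOWED]
        finished += [c for c in candidates if c[1] not in ALLOWED]
    return sorted(finished, key=lambda c: c[0], reverse=True)


everything = toy_constrained_search(beam_size=None)
pruned = toy_constrained_search(beam_size=1)
```

With beam_size=None both allowed sequences are scored and ranked, with (1, 2, 3) first. With beam_size=1 the prefix (1, 2) is pruned after the second step because (1, 4) scores higher there, so only (1, 4, 5) is returned even though its total score is lower.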

allennlp.state_machines.util.construct_prefix_tree(targets: typing.Union[torch.Tensor, typing.List[typing.List[typing.List[int]]]], target_mask: typing.Union[torch.Tensor, NoneType] = None) → typing.List[typing.Dict[typing.Tuple[int, ...], typing.Set[int]]]

Takes a list of valid target action sequences and creates a mapping from all possible (valid) action prefixes to allowed actions given that prefix. While the method is called construct_prefix_tree, we’re actually returning a map that has as keys the paths to all internal nodes of the trie, and as values all of the outgoing edges from that node.

targets is assumed to be a tensor of shape (batch_size, num_valid_sequences, sequence_length). If the mask is not None, it is assumed to have the same shape, and we will ignore any value in targets that has a value of 0 in the corresponding position in the mask. We assume that the mask has the format 1*0* for each item in targets - that is, once we see our first zero, we stop processing that target.

For example, if targets holds a single batch instance with two sequences, [[[1, 2, 3], [1, 4, 5]]], the return value will be a one-element list with one map for that instance: [{(): {1}, (1,): {2, 4}, (1, 2): {3}, (1, 4): {5}}].

This could be used, e.g., to do an efficient constrained beam search, or to efficiently evaluate the probability of all of the target sequences.
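The mapping described above can be reproduced in a few lines of plain Python. toy_construct_prefix_tree is a hypothetical re-implementation written for illustration; it takes nested lists for both arguments rather than tensors and is not the library function itself.

```python
from typing import Dict, List, Optional, Set, Tuple

PrefixMap = Dict[Tuple[int, ...], Set[int]]
Nested = List[List[List[int]]]


def toy_construct_prefix_tree(
    targets: Nested, target_mask: Optional[Nested] = None
) -> List[PrefixMap]:
    """Maps every valid action prefix to the actions allowed after it."""
    batched: List[PrefixMap] = []
    for batch_index, instance_targets in enumerate(targets):
        allowed: PrefixMap = {}
        for seq_index, sequence in enumerate(instance_targets):
            history: Tuple[int, ...] = ()
            for step, action in enumerate(sequence):
                # The mask has the form 1*0*, so stop at the first zero.
                if target_mask is not None and target_mask[batch_index][seq_index][step] == 0:
                    break
                allowed.setdefault(history, set()).add(action)
                history = history + (action,)
        batched.append(allowed)
    return batched


prefix_maps = toy_construct_prefix_tree([[[1, 2, 3], [1, 4, 5]]])
masked_maps = toy_construct_prefix_tree(
    [[[1, 2, 3], [1, 4, 0]]], target_mask=[[[1, 1, 1], [1, 1, 0]]]
)
```

In the masked call, the trailing 0 in the second sequence is padding, so the prefix (1, 4) gets no outgoing edge.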