Using Span Representations in AllenNLP#
Note that this tutorial covers some quite advanced usage of AllenNLP - you may want to familiarize yourself with the repository before working through it.
Many state-of-the-art deep NLP models use representations of spans, rather than representations of words, as their basic building block. In AllenNLP (starting from version 0.4), span representations are extremely easy to use in your model.
Examples of papers which contain span representations include:
- End-to-end Neural Coreference Resolution
- A Minimal Span-Based Neural Constituency Parser
- Learning Recurrent Span Representations for Extractive Question Answering
- Frame-Semantic Parsing with Softmax-Margin Segmental RNNs and a Syntactic Scaffold
- Segmental Recurrent Neural Networks
In order to use span representations in your model, there are three things you probably need to think about: (1) enumerating all possible spans in a DatasetReader as input to your model; (2) extracting embedded span representations for the span indices; and (3) pruning the spans in your model to keep only the most promising ones. We'll describe how to do each of these steps.
Generating SpanFields from text in a DatasetReader#
SpanFields are a type of Field in AllenNLP which take a start index, an end index and a SequenceField which the indices refer to. Once a batch of SpanFields has been converted to a tensor, we will have a matrix of shape (batch_size, 2), where the last dimension contains the start and end indices passed in to the SpanField constructor. However, for many models, you'll want to represent many spans for a single batch element - the way to do this is to use a ListField[SpanField], which will create a tensor of shape (batch_size, num_spans, 2) once indexed.
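As a minimal sketch (the sentence and spans here are made up for illustration, not drawn from any particular DatasetReader), creating these fields might look like this:

```python
from allennlp.data.fields import ListField, SpanField, TextField
from allennlp.data.token_indexers import SingleIdTokenIndexer
from allennlp.data.tokenizers.token import Token

# A made-up sentence; typically these tokens would come from a
# tokenizer inside your DatasetReader.
tokens = [Token(t) for t in ["The", "quick", "brown", "fox"]]
text_field = TextField(tokens, {"tokens": SingleIdTokenIndexer()})

# Each SpanField takes an inclusive start index, an inclusive end index
# and the SequenceField the indices refer to.
spans = [SpanField(0, 1, text_field), SpanField(2, 3, text_field)]

# Wrapping the SpanFields in a ListField gives a tensor of shape
# (batch_size, num_spans, 2) once indexed.
span_list_field = ListField(spans)
```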
Extracting Span Representations from a text sequence#
In many cases, you will want to extract spans from vector representations of sentences.
In order to do this in AllenNLP, you will need to use a SpanExtractor. Broadly, a SpanExtractor takes a sequence tensor of shape (batch_size, sentence_length, embedding_size) and some indices of shape (batch_size, num_spans, 2) and returns an encoded representation of each span as a tensor of shape (batch_size, num_spans, encoded_size).
The simplest SpanExtractor is the EndpointSpanExtractor, which represents spans as a combination of the embeddings of their endpoints.
```python
import torch
from torch.autograd import Variable
from allennlp.modules.span_extractors import EndpointSpanExtractor

sequence_tensor = Variable(torch.randn([2, 5, 7]))
# Concatenate start and end points together to form our representation.
extractor = EndpointSpanExtractor(input_dim=7, combination="x,y")

# Typically these would come from a SpanField,
# rather than being created directly.
indices = Variable(torch.LongTensor([[[1, 3], [2, 4]],
                                     [[0, 2], [3, 4]]]))

# We concatenated the representations for the start and end of
# the span, so the embedded span size is 2 * embedding_size.
# Shape (batch_size, num_spans, 2 * embedding_size).
span_representations = extractor(sequence_tensor, indices)
assert list(span_representations.size()) == [2, 2, 14]
```
There are other types of SpanExtractor - for instance, the SelfAttentiveSpanExtractor, which computes span representations by generating an unnormalized attention score for each word in the sentence. Span representations are then computed with respect to these scores by normalising the attention scores for words inside the span.
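Following the same pattern as the EndpointSpanExtractor example above, here is a short sketch of how this might be used (the tensors here are random placeholders):

```python
import torch
from torch.autograd import Variable
from allennlp.modules.span_extractors import SelfAttentiveSpanExtractor

sequence_tensor = Variable(torch.randn([2, 5, 7]))
extractor = SelfAttentiveSpanExtractor(input_dim=7)
indices = Variable(torch.LongTensor([[[1, 3], [2, 4]],
                                     [[0, 2], [3, 4]]]))

# Each span representation is an attention-weighted sum of the word
# embeddings inside the span, so the encoded size here is just
# embedding_size. Shape (batch_size, num_spans, embedding_size).
span_representations = extractor(sequence_tensor, indices)
assert list(span_representations.size()) == [2, 2, 7]
```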
Scoring and Pruning Spans#
Span-based representations have been effective for modeling and approximating structured prediction problems. However, many models which leverage this type of representation also involve some kind of span enumeration (i.e. considering all possible spans in a sentence or document). For a given sentence of length n, there are on the order of n² possible spans. In itself, this is not too problematic, but the co-reference model in AllenNLP, for instance, compares pairs of spans - meaning that naively we consider on the order of n⁴ span pairs, with potential document lengths of upwards of 3000 tokens.
In order to solve this problem, we need to be able to prune spans as we go inside our model. There are several ways to do this:
Heuristically prune spans in your DatasetReader.#
We have added a utility method for enumerating all spans in a sentence, but excluding those which fulfil some condition based on the input text or any Spacy Token attributes. For instance, for co-reference, spans which are mentions (spans which are co-referent with something) never start or end with punctuation, and never cross sentence boundaries, because of the way the OntoNotes 5.0 dataset was created. This means that we can exclude any span which starts or ends with punctuation using a very simple python function:
```python
from typing import List

from allennlp.data.dataset_readers.dataset_utils import span_utils
from allennlp.data.tokenizers.spacy_tokenizer import SpacyTokenizer
from allennlp.data.tokenizers.token import Token

tokenizer = SpacyTokenizer(pos_tags=True)
sentence = tokenizer.tokenize("This is a sentence.")

def no_prefixed_punctuation(tokens: List[Token]) -> bool:
    # Only include spans which don't start or end with punctuation.
    return tokens[0].pos_ != "PUNCT" and tokens[-1].pos_ != "PUNCT"

spans = span_utils.enumerate_spans(sentence,
                                   max_span_width=3,
                                   min_span_width=2,
                                   filter_function=no_prefixed_punctuation)

# 'spans' won't include (2, 4) or (3, 4) as these have
# punctuation as their last element. Note that these spans
# have inclusive start and end indices!
assert spans == [(0, 1), (0, 2), (1, 2), (1, 3), (2, 3)]
```
There are other helpful functions in span_utils, such as a function to convert between BIO labelings and span-based representations.
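For instance, span_utils.bio_tags_to_spans converts a BIO tag sequence into typed spans with inclusive indices. A small sketch (the tag sequence here is made up):

```python
from allennlp.data.dataset_readers.dataset_utils import span_utils

# A made-up BIO tag sequence for a four-token sentence.
tags = ["B-PER", "I-PER", "O", "B-LOC"]

# Each span is a (label, (start, end)) tuple with inclusive indices.
# The returned list is not guaranteed to be ordered, so we compare sets.
spans = span_utils.bio_tags_to_spans(tags)
assert set(spans) == {("PER", (0, 1)), ("LOC", (3, 3))}
```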
Use a Pruner#
It's not always possible to prune spans before they enter your model. For these cases, AllenNLP contains a Pruner, which allows you to prune spans based on a parameterized function which is trained end-to-end with the rest of your model.
```python
import torch
from torch.autograd import Variable
from allennlp.modules import Pruner

# Create a linear layer which will score our spans.
linear_scorer = torch.nn.Linear(5, 1)
pruner = Pruner(scorer=linear_scorer)

# Here we'll create some spans from a random tensor of shape
# (batch_size, num_spans, embedding_size). Typically this would
# be the output of a SpanExtractor applied to some encoded representation
# of a sentence, such as the output of an LSTM, or word embeddings.
spans = Variable(torch.randn([3, 4, 5]))
mask = Variable(torch.ones([3, 4]))

# There's quite a bit to unpack here.
# See below for a full explanation.
pruned_embeddings, pruned_mask, pruned_indices, pruned_scores = pruner(
    spans, mask, num_items_to_keep=3)
```
The Pruner has four return values:
First, we've got our pruned_embeddings, which are of shape (batch_size, num_items_to_keep, embedding_size). The spans we kept correspond to the top k with respect to the parameterized span scorer. The other spans just get discarded, and your model's eventual loss function won't be a function of the discarded spans!
Secondly, we've got the pruned_mask, which has shape (batch_size, num_items_to_keep). In 99% of cases, this will be all ones. However, if you have masked spans in a batch element and you request that the Pruner keeps more than the number of non-masked spans, there will be some masked elements in the returned spans.
Thirdly, we have the pruned_indices, which has shape (batch_size, num_items_to_keep) and contains the indices of the top k scoring spans in the original spans tensor. This is returned because it can be useful to retain pointers to the original spans, if each span is being scored by multiple distinct scorers, as in the co-reference model, for instance.
Finally, we have the pruned_scores, which have shape (batch_size, num_items_to_keep, 1). These are returned so that you can incorporate the scores of the spans into some loss function.
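To make the last two of these concrete, here is a small sketch continuing from the Pruner example above. batched_index_select is a utility in allennlp.nn.util; the classifier here is just a hypothetical stand-in for whatever scores your model computes for each kept span:

```python
import torch
from torch.autograd import Variable
from allennlp.nn import util

# Use pruned_indices to gather any other per-span tensor - here, a
# hypothetical second feature tensor of shape (batch_size, num_spans, 5).
other_span_features = Variable(torch.randn([3, 4, 5]))
pruned_features = util.batched_index_select(other_span_features,
                                            pruned_indices)
assert list(pruned_features.size()) == [3, 3, 5]

# Add pruned_scores to the logits your model computes for the kept spans,
# so that the scorer inside the Pruner receives a gradient signal.
classifier = torch.nn.Linear(5, 1)  # hypothetical downstream scorer
logits = classifier(pruned_embeddings) + pruned_scores
loss = logits.sum()  # stand-in for a real loss function
loss.backward()
```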
Existing AllenNLP examples for generating SpanFields#
We've already started using SpanFields in AllenNLP - you can see some examples in the Coreference DatasetReader, where we enumerate all possible spans in the sentences of a document, or in the PennTreeBankConstituencySpanDatasetReader, where we enumerate all spans in a sentence in order to classify whether or not they are constituents in a constituency parse of the sentence.