SpanPruner(scorer: torch.nn.modules.module.Module) → None
This module scores and prunes span-based representations using a parameterised scoring function, keeping only a fixed number of the highest-scoring spans.
scorer (torch.nn.Module): A module which, given a tensor of shape (batch_size, num_spans, embedding_size), produces a tensor of shape (batch_size, num_spans, 1), representing a scalar score per span in the tensor.
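A scorer only has to map each span embedding to a single number. The sketch below assumes a linear scoring function (a dot product with a weight vector plus a bias) and uses nested Python lists in place of tensors so it runs without PyTorch; `linear_scorer`, `weights` and `bias` are hypothetical names, not part of this API. In PyTorch, a `torch.nn.Linear(embedding_size, 1)` applied to the last dimension has the same input and output shapes.

```python
from typing import Callable, List

# Nested lists stand in for tensors: span_embeddings[b][n] is the embedding
# vector (length embedding_size) of span n in batch element b.
Tensor3D = List[List[List[float]]]


def linear_scorer(weights: List[float], bias: float) -> Callable[[Tensor3D], Tensor3D]:
    """Hypothetical scorer: (batch, num_spans, embedding_size) -> (batch, num_spans, 1)."""

    def score(span_embeddings: Tensor3D) -> Tensor3D:
        # One scalar per span, wrapped in a length-1 list to mirror the trailing dim of 1.
        return [
            [[sum(w * x for w, x in zip(weights, span)) + bias] for span in batch]
            for batch in span_embeddings
        ]

    return score


scorer = linear_scorer(weights=[1.0, -0.5], bias=0.1)
embeddings = [[[2.0, 0.0], [0.0, 2.0], [1.0, 1.0]]]  # batch_size=1, num_spans=3
scores = scorer(embeddings)  # shape (1, 3, 1)
```

Any module with this shape contract can be plugged in; a deeper feedforward network is an equally valid choice.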
forward(span_embeddings: torch.FloatTensor, span_mask: torch.LongTensor, num_spans_to_keep: int) → typing.Tuple[torch.FloatTensor, torch.LongTensor, torch.LongTensor, torch.FloatTensor]
Extracts the top-k scoring spans with respect to the scorer. We additionally return the indices of the top-k spans sorted by their original position rather than by score, so that later code can rely on this ordering to consider the previous k spans as antecedents for each span.
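The ordering guarantee can be illustrated on a single batch element, with plain Python lists standing in for tensors. `top_k_in_original_order` and `candidate_antecedents` are hypothetical helpers, and capping antecedents at `max_antecedents` is an assumption made for illustration; this is a sketch of the idea, not the module's implementation.

```python
from typing import List


def top_k_in_original_order(scores: List[float], k: int) -> List[int]:
    """Indices of the k highest scores, returned in ascending (document) order."""
    # Rank positions by score, highest first; sorted() is stable even with
    # reverse=True, so tied scores keep the lower index first.
    by_score = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    # Keep the best k, then restore original order so earlier spans come first.
    return sorted(by_score[:k])


def candidate_antecedents(kept: List[int], max_antecedents: int) -> List[List[int]]:
    """For each kept span, the (up to) max_antecedents kept spans preceding it."""
    return [kept[max(0, i - max_antecedents):i] for i in range(len(kept))]


scores = [0.2, 1.5, -0.3, 0.9, 2.0]
kept = top_k_in_original_order(scores, k=3)
print(kept)                            # [1, 3, 4]
print(candidate_antecedents(kept, 2))  # [[], [1], [1, 3]]
```

Had the indices been left in score order ([4, 1, 3]), slicing the preceding entries would not correspond to spans that occur earlier in the text.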
span_embeddings (torch.FloatTensor): A tensor of shape (batch_size, num_spans, embedding_size), representing the set of embedded span representations.
span_mask (torch.LongTensor): A tensor of shape (batch_size, num_spans), denoting the unpadded elements of span_embeddings.
num_spans_to_keep (int): The number of spans to keep when pruning.
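The text does not say exactly how span_mask enters the selection. One common approach, assumed here rather than taken from the implementation, is to push the scores of padded spans far below any real score before taking the top k, so padding is selected only when num_spans_to_keep exceeds the number of real spans. `masked_top_k` and the `-1e20` fill value are hypothetical; lists again stand in for one batch element's tensors.

```python
from typing import List


def masked_top_k(scores: List[float], mask: List[int], k: int) -> List[int]:
    """Top-k span indices in original order, avoiding padded spans where possible."""
    # Padded positions (mask == 0) get a score far below any real one.
    very_negative = -1e20
    masked = [s if m else very_negative for s, m in zip(scores, mask)]
    by_score = sorted(range(len(masked)), key=lambda i: masked[i], reverse=True)
    return sorted(by_score[:k])


scores = [0.5, 3.0, 1.0, 9.9]  # the last span is padding with a spurious high score
mask = [1, 1, 1, 0]
kept = masked_top_k(scores, mask, k=2)
kept_mask = [mask[i] for i in kept]
print(kept, kept_mask)  # [1, 2] [1, 1]
```

When k is larger than the number of real spans, padded positions are kept as well; the mask gathered for those positions is 0, which is what lets downstream code ignore them.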
Returns a tuple of four tensors, in this order:
top_span_embeddings (torch.FloatTensor): The span representations of the top-k scoring spans. Has shape (batch_size, num_spans_to_keep, embedding_size).
Mask (torch.LongTensor): The corresponding mask for top_span_embeddings. Has shape (batch_size, num_spans_to_keep).
Indices (torch.LongTensor): The indices of the top-k scoring spans into the original span_embeddings tensor. Retaining these pointers to the original spans can be useful, for instance, when each span is scored by multiple distinct scorers. Has shape (batch_size, num_spans_to_keep).
Scores (torch.FloatTensor): The scores of the top-k spans. Has shape (batch_size, num_spans_to_keep, 1).
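Once the kept indices are known, the other outputs are selections along the span dimension. The sketch below shows this for one batch element with lists in place of tensors; `gather`, the `kept_*` variables, and `other_scores` (standing in for a second, hypothetical scorer) are illustrative names only. In PyTorch the batched equivalent would be an index-select or `torch.gather` along the span dimension.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def gather(values: Sequence[T], indices: Sequence[int]) -> List[T]:
    """Select values[i] for each i in indices, preserving the order of indices."""
    return [values[i] for i in indices]


# One batch element with num_spans=4 and embedding_size=2.
span_embeddings = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]]
span_mask = [1, 1, 1, 0]
span_scores = [[0.5], [3.0], [1.0], [0.0]]  # shape (num_spans, 1)
kept_indices = [1, 2]                       # assumed result of the top-k step

kept_embeddings = gather(span_embeddings, kept_indices)  # shape (2, 2)
kept_mask = gather(span_mask, kept_indices)              # shape (2,)
kept_scores = gather(span_scores, kept_indices)          # shape (2, 1)

# The indices still point into the original spans, so scores produced by
# another scorer over all spans can be aligned with the kept spans.
other_scores = [0.0, -1.0, 4.0, 2.0]
print(gather(other_scores, kept_indices))  # [-1.0, 4.0]
```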