allennlp.modules.span_pruner

class allennlp.modules.span_pruner.SpanPruner(scorer: torch.nn.modules.module.Module) → None

Bases: torch.nn.modules.module.Module

This module scores span-based representations using a parameterised scoring function and prunes them to a fixed number of top-scoring spans.

Parameters:
scorer : torch.nn.Module, required.

A module which, given a tensor of shape (batch_size, num_spans, embedding_size), produces a tensor of shape (batch_size, num_spans, 1), representing a scalar score per span in the tensor.
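As a minimal sketch of a module that satisfies this shape contract, a single linear layer can map each span embedding to one scalar. The layer choice and the tensor sizes below are illustrative assumptions, not a scorer prescribed by the library.

```python
import torch

# Illustrative sizes only (assumptions, not library defaults).
batch_size, num_spans, embedding_size = 2, 5, 8

# Assumption: a single linear layer stands in for a real scorer.
# torch.nn.Linear acts on the last dimension, so the leading
# (batch_size, num_spans) dimensions are preserved.
scorer = torch.nn.Linear(embedding_size, 1)

span_embeddings = torch.randn(batch_size, num_spans, embedding_size)
span_scores = scorer(span_embeddings)

print(span_scores.shape)  # torch.Size([2, 5, 1])
```

Any module with the same input and output shapes, such as a small feed-forward network followed by a projection to size 1, should fit the scorer argument equally well.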

forward(span_embeddings: torch.FloatTensor, span_mask: torch.LongTensor, num_spans_to_keep: int) → typing.Tuple[torch.FloatTensor, torch.LongTensor, torch.LongTensor, torch.FloatTensor]

Extracts the top-k scoring spans according to the scorer. The indices of the top-k spans are additionally returned in their original order, not ordered by score, so that later steps can rely on that ordering, for example to consider the previous k spans as antecedents for each span.
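The selection logic described above can be sketched in plain Python for a single batch element. This is not the library's implementation: the helper name top_k_in_original_order is hypothetical, and setting padded spans to negative infinity is an assumption about how the mask is applied.

```python
import math
from typing import List, Tuple


def top_k_in_original_order(
    scores: List[float], mask: List[int], k: int
) -> Tuple[List[int], List[float]]:
    """Pick the k best-scoring spans, then report them in original span order."""
    # Assumption: padded spans (mask == 0) are pushed to -inf so that
    # real spans always outrank padding.
    masked = [s if m else -math.inf for s, m in zip(scores, mask)]
    # Indices of the k highest scores, ordered by score (best first).
    by_score = sorted(range(len(masked)), key=lambda i: masked[i], reverse=True)[:k]
    # Re-sort the kept indices so they follow the original span order.
    kept = sorted(by_score)
    return kept, [masked[i] for i in kept]


indices, kept_scores = top_k_in_original_order(
    scores=[0.1, 2.5, -0.3, 1.7, 0.9], mask=[1, 1, 1, 1, 0], k=3
)
print(indices)      # [0, 1, 3]
print(kept_scores)  # [0.1, 2.5, 1.7]
```

Ranked by score alone, the kept indices would come out as 1, 3, 0. Sorting them afterwards restores the order in which the spans appear in the input, which is what lets a later step treat the preceding kept spans as antecedent candidates.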

Parameters:
span_embeddings : torch.FloatTensor, required.

A tensor of shape (batch_size, num_spans, embedding_size), representing the set of embedded span representations.

span_mask : torch.LongTensor, required.

A tensor of shape (batch_size, num_spans), denoting unpadded elements of span_embeddings.

num_spans_to_keep : int, required.

The number of spans to keep when pruning.

Returns:
top_span_embeddings : torch.FloatTensor

The span representations of the top-k scoring spans. Has shape (batch_size, num_spans_to_keep, embedding_size).

top_span_mask : torch.LongTensor

The corresponding mask for top_span_embeddings. Has shape (batch_size, num_spans_to_keep).

top_span_indices : torch.LongTensor

The indices of the top-k scoring spans into the original span_embeddings tensor. These are returned because it can be useful to retain pointers to the original spans, for instance when each span is scored by multiple distinct scorers. Has shape (batch_size, num_spans_to_keep).

top_span_scores : torch.FloatTensor

The values of the top-k scoring spans. Has shape (batch_size, num_spans_to_keep, 1).