span_marker.data_collator module¶
- class span_marker.data_collator.SpanMarkerDataCollator(tokenizer, marker_max_length)[source]¶
Bases:
object
Data collator class responsible for converting the minimal outputs from the tokenizer into complete and meaningful inputs for the model. In particular, the input_ids from the tokenizer features are padded, and the correct number of start and end markers (with padding) is added. Furthermore, position IDs are generated for the input IDs, and start_position_ids and end_position_ids are used alongside some padding to create a full position ID vector. Lastly, the attention matrix is computed.
The expected usage is something like:
>>> collator = SpanMarkerDataCollator(...)
>>> tokenized = tokenizer(...)
>>> batch = collator(tokenized)
>>> output = model(**batch)
- Parameters:
tokenizer (SpanMarkerTokenizer) –
marker_max_length (int) –
- tokenizer: SpanMarkerTokenizer¶
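The padding and position-ID logic described above can be sketched in plain Python. This is a simplified illustration of the general idea, not the library's actual implementation; the helper names (pad, collate) and the exact layout are assumptions for the sake of the example.

```python
# Simplified sketch of a span-marker-style collation step: pad token IDs
# to a common length, then build one position-ID vector per sample by
# appending the (padded) start and end marker position IDs.
# Illustrative only; SpanMarkerDataCollator's real logic differs.

def pad(sequence, length, pad_value=0):
    """Right-pad a list to the given length with pad_value."""
    return sequence + [pad_value] * (length - len(sequence))

def collate(features, marker_max_length):
    max_len = max(len(f["input_ids"]) for f in features)
    batch = {"input_ids": [], "position_ids": []}
    for f in features:
        batch["input_ids"].append(pad(f["input_ids"], max_len))
        # Token position IDs, followed by start and end marker position
        # IDs, each padded out to marker_max_length.
        position_ids = list(range(max_len))
        position_ids += pad(f["start_position_ids"], marker_max_length)
        position_ids += pad(f["end_position_ids"], marker_max_length)
        batch["position_ids"].append(position_ids)
    return batch
```

In this sketch every sample in the batch ends up with input_ids of the same length and a position-ID vector of length max_len + 2 * marker_max_length, mirroring the "full position ID vector" the class description mentions.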
- __call__(features)[source]¶
Convert the minimal tokenizer outputs into inputs ready for forward().
- Parameters:
features (List[Dict[str, Any]]) –
A list of dictionaries, one element per sample in the batch. Each dictionary contains the following keys:
- input_ids: The non-padded input IDs.
- num_spans: The number of spans that should be encoded in each sample.
- start_position_ids: The position IDs of the start markers in the sample.
- end_position_ids: The position IDs of the end markers in the sample.
- labels (optional): The labels corresponding to each of the spans in the sample.
- num_words (optional): The number of words in the input sample. Required for some evaluation metrics.
- Returns:
Batch dictionary ready to be fed into forward().
- Return type:
Dict[str, torch.Tensor]
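For illustration, a single element of the features list might look like the dictionary below. The token IDs, span positions, and labels are made-up values, not the output of a real tokenizer.

```python
# Hypothetical example of one element of `features`, following the key
# descriptions above. All values are invented for illustration.
feature = {
    "input_ids": [101, 1996, 4248, 4419, 102],  # non-padded input IDs
    "num_spans": 2,                             # spans encoded in this sample
    "start_position_ids": [1, 2],               # start marker positions
    "end_position_ids": [2, 4],                 # end marker positions
    "labels": [0, 3],                           # optional: one label per span
    "num_words": 3,                             # optional: word count
}
```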