span_marker.data_collator module
- class span_marker.data_collator.SpanMarkerDataCollator(tokenizer, marker_max_length)[source]
Bases: object

Data Collator class responsible for converting the minimal outputs from the tokenizer into complete and meaningful inputs to the model. In particular, the input_ids from the tokenizer features are padded, and the correct number of start and end markers (with padding) is added. Furthermore, position IDs are generated for the input IDs, and start_position_ids and end_position_ids are used alongside some padding to create a full position ID vector. Lastly, the attention matrix is computed.
The expected usage is something like:
>>> collator = SpanMarkerDataCollator(...)
>>> tokenized = tokenizer(...)
>>> batch = collator(tokenized)
>>> output = model(**batch)
- Parameters:
tokenizer (SpanMarkerTokenizer) –
marker_max_length (int) –
- tokenizer: SpanMarkerTokenizer
- __call__(features)[source]
Convert the minimal tokenizer outputs into inputs ready for forward().

- Parameters:
features (List[Dict[str, Any]]) – A list of dictionaries, one element per sample in the batch. The dictionaries contain the following keys:

- input_ids: The non-padded input IDs.
- num_spans: The number of spans that should be encoded in each sample.
- start_position_ids: The position IDs of the start markers in the sample.
- end_position_ids: The position IDs of the end markers in the sample.
- labels (optional): The labels corresponding to each of the spans in the sample.
- num_words (optional): The number of words in the input sample. Required for some evaluation metrics.
- Returns:
Batch dictionary ready to be fed into forward().
- Return type:
Dict[str, torch.Tensor]
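To make the expected feature layout concrete, the padding that a collator of this kind performs can be sketched in plain Python. This is an illustrative toy, not the span_marker implementation: the helper names pad and collate are hypothetical, real batches hold torch.Tensor objects rather than lists, and the actual collator additionally inserts marker token IDs and builds the attention matrix.

```python
def pad(sequence, length, value=0):
    """Right-pad `sequence` with `value` up to `length`."""
    return sequence + [value] * (length - len(sequence))


def collate(features, marker_max_length, pad_token_id=0):
    """Toy collator: turn per-sample feature dicts into equal-length batch lists.

    Each feature dict mirrors the keys documented above: input_ids,
    num_spans, start_position_ids and end_position_ids.
    """
    # Pad every sample's input IDs to the longest sequence in the batch.
    seq_len = max(len(f["input_ids"]) for f in features)
    batch = {"input_ids": [], "position_ids": [], "num_spans": []}
    for f in features:
        batch["input_ids"].append(pad(f["input_ids"], seq_len, pad_token_id))
        # Regular tokens get sequential positions; the start/end markers reuse
        # the position IDs of the tokens they point at, each group padded to
        # marker_max_length slots so every sample has the same layout.
        position_ids = (
            list(range(seq_len))
            + pad(f["start_position_ids"], marker_max_length)
            + pad(f["end_position_ids"], marker_max_length)
        )
        batch["position_ids"].append(position_ids)
        batch["num_spans"].append(f["num_spans"])
    return batch
```

With two samples of different lengths and span counts, every key in the returned batch ends up rectangular, which is what allows the result to be converted into tensors and fed into forward().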