# Source code for span_marker.data_collator

from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Dict, List

import torch
from torch.nn import functional as F

from span_marker.tokenizer import SpanMarkerTokenizer


@dataclass
class SpanMarkerDataCollator:
    """
    Data Collator class responsible for converting the minimal outputs from the tokenizer into complete and
    meaningful inputs to the model. In particular, the ``input_ids`` from the tokenizer features are padded,
    and the correct amount of start and end markers (with padding) are added.

    Furthermore, the position IDs are generated for the input IDs, and ``start_position_ids`` and
    ``end_position_ids`` are used alongside some padding to create a full position ID vector.

    Lastly, the attention matrix is computed.

    Attributes:
        tokenizer (SpanMarkerTokenizer): Provides ``model_max_length``, ``pad_token_id``, ``start_marker_id``
            and ``end_marker_id``, which determine the padded length and the IDs written into ``input_ids``.
        marker_max_length (int): Room reserved for markers; every sample is padded to
            ``tokenizer.model_max_length + 2 * marker_max_length`` positions.

    The expected usage is something like:

    >>> collator = SpanMarkerDataCollator(...)
    >>> tokenized = tokenizer(...)
    >>> batch = collator(tokenized)
    >>> output = model(**batch)
    """

    tokenizer: SpanMarkerTokenizer
    marker_max_length: int

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        """Convert the minimal tokenizer outputs into inputs ready for
        :meth:`~span_marker.modeling.SpanMarkerModel.forward`.

        Every sample is padded to ``total_size = tokenizer.model_max_length + 2 * marker_max_length``
        positions. The ``num_spans`` start markers are placed at index ``len(input_ids)`` rounded up to the
        nearest even number, and the ``num_spans`` end markers directly follow the start markers.

        Args:
            features (List[Dict[str, Any]]): A list of dictionaries, one element per sample in the batch.
                The dictionaries contain the following keys:

                * ``input_ids``: The non-padded input IDs, as a list or a tensor.
                * ``num_spans``: The number of spans that should be encoded in each sample.
                * ``start_position_ids``: The position IDs of the start markers in the sample.
                * ``end_position_ids``: The position IDs of the end markers in the sample.
                * ``labels`` (optional): The labels corresponding to each of the spans in the sample.
                * ``num_words`` (optional): The number of words in the input sample.
                  Required for some evaluation metrics.
                * ``document_id`` (optional): Collected into ``document_ids`` in the output.
                * ``sentence_id`` (optional): Collected into ``sentence_ids`` in the output.

        Returns:
            Dict[str, torch.Tensor]: Batch dictionary ready to be fed into
            :meth:`~span_marker.modeling.SpanMarkerModel.forward`, with keys:

            * ``input_ids``: ``int32`` tensor of shape ``(batch, total_size)``, padded with ``pad_token_id``.
            * ``position_ids``: ``int32`` tensor of shape ``(batch, total_size)``; text positions start at 2
              and padding positions are 1.
            * ``attention_mask``: ``bool`` tensor of shape ``(batch, total_size, total_size)``.
            * ``labels`` (if given): ``int64`` tensor of shape ``(batch, total_size // 2)``, padded with -100.
            * ``num_words``, ``document_ids``, ``sentence_ids`` (if given): 1-D tensors, not padded.
            * ``start_marker_indices``: index of the first start marker in each sample.
            * ``num_marker_pairs``: number of start/end marker pairs in each sample.
        """
        total_size = self.tokenizer.model_max_length + 2 * self.marker_max_length
        batch = defaultdict(list)
        num_words = []
        document_ids = []
        sentence_ids = []
        start_marker_indices = []
        num_marker_pairs = []
        for sample in features:
            num_spans = sample["num_spans"]
            num_tokens = len(sample["input_ids"])
            # The start markers start after the input IDs, rounded up to the nearest even number
            start_marker_idx = num_tokens + num_tokens % 2
            # The end markers directly follow the ``num_spans`` start markers
            end_marker_idx = start_marker_idx + num_spans

            batch["input_ids"].append(
                self._build_input_ids(sample["input_ids"], total_size, start_marker_idx, end_marker_idx, num_spans)
            )
            batch["position_ids"].append(
                self._build_position_ids(
                    sample["start_position_ids"],
                    sample["end_position_ids"],
                    num_tokens,
                    total_size,
                    start_marker_idx,
                    end_marker_idx,
                    num_spans,
                )
            )
            batch["attention_mask"].append(
                self._build_attention_mask(num_tokens, total_size, start_marker_idx, end_marker_idx, num_spans)
            )

            # Add start of the markers, so the model knows where the input IDs end and where the markers start
            start_marker_indices.append(start_marker_idx)
            num_marker_pairs.append(end_marker_idx - start_marker_idx)

            # NOTE(review): these optional keys are collected only from samples that contain them, so mixing
            # samples with and without a key yields a shorter, misaligned list - presumably callers supply
            # them for all samples or none.
            if "num_words" in sample:
                num_words.append(sample["num_words"])
            if "document_id" in sample:
                document_ids.append(sample["document_id"])
            if "sentence_id" in sample:
                sentence_ids.append(sample["sentence_id"])
            if "labels" in sample:
                # Force an integer dtype: ``torch.tensor([])`` would otherwise produce a float32 tensor for a
                # sample without spans, which would make the padded labels float as well.
                labels = torch.as_tensor(sample["labels"], dtype=torch.long)
                labels = F.pad(labels, (0, total_size // 2 - len(labels)), value=-100)
                batch["labels"].append(labels)

        batch = {key: torch.stack(value) for key, value in batch.items()}

        # Used for evaluation, does not need to be padded/stacked
        if num_words:
            batch["num_words"] = torch.tensor(num_words)
        if document_ids:
            batch["document_ids"] = torch.tensor(document_ids)
        if sentence_ids:
            batch["sentence_ids"] = torch.tensor(sentence_ids)
        batch["start_marker_indices"] = torch.tensor(start_marker_indices)
        batch["num_marker_pairs"] = torch.tensor(num_marker_pairs)
        return batch

    def _build_input_ids(
        self,
        input_ids: Any,
        total_size: int,
        start_marker_idx: int,
        end_marker_idx: int,
        num_spans: int,
    ) -> torch.Tensor:
        """Convert ``input_ids`` to ``int32``, pad to ``total_size`` and write the start/end marker IDs."""
        if isinstance(input_ids, torch.Tensor):
            # ``Tensor.to`` returns a new tensor rather than converting in place, so it must be reassigned.
            input_ids = input_ids.to(torch.int)
        else:
            input_ids = torch.tensor(input_ids, dtype=torch.int)
        input_ids = F.pad(input_ids, (0, total_size - len(input_ids)), value=self.tokenizer.pad_token_id)
        input_ids[start_marker_idx : start_marker_idx + num_spans] = self.tokenizer.start_marker_id
        input_ids[end_marker_idx : end_marker_idx + num_spans] = self.tokenizer.end_marker_id
        return input_ids

    @staticmethod
    def _build_position_ids(
        start_position_ids: Any,
        end_position_ids: Any,
        num_tokens: int,
        total_size: int,
        start_marker_idx: int,
        end_marker_idx: int,
        num_spans: int,
    ) -> torch.Tensor:
        """Build ``int32`` position IDs for text tokens and markers, padded to ``total_size`` with 1."""
        # Increase the position_ids by 2, inspired by PL-Marker. The intuition is that these position IDs
        # better match the circumstances under which the underlying encoders are trained.
        position_ids = torch.arange(num_tokens, dtype=torch.int) + 2
        position_ids = F.pad(position_ids, (0, total_size - num_tokens), value=1)
        # Each marker reuses the position ID supplied for it by the sample, shifted by the same offset of 2.
        position_ids[start_marker_idx : start_marker_idx + num_spans] = torch.as_tensor(start_position_ids) + 2
        position_ids[end_marker_idx : end_marker_idx + num_spans] = torch.as_tensor(end_position_ids) + 2
        return position_ids

    @staticmethod
    def _build_attention_mask(
        num_tokens: int,
        total_size: int,
        start_marker_idx: int,
        end_marker_idx: int,
        num_spans: int,
    ) -> torch.Tensor:
        """Build the ``(total_size, total_size)`` boolean attention matrix for one sample."""
        attention_mask = torch.zeros((total_size, total_size), dtype=torch.bool)
        # text tokens self-attention (text tokens never attend to markers)
        attention_mask[:num_tokens, :num_tokens] = 1
        # let markers attend text tokens
        attention_mask[start_marker_idx : start_marker_idx + num_spans, :num_tokens] = 1
        attention_mask[end_marker_idx : end_marker_idx + num_spans, :num_tokens] = 1
        # Each start marker attends to itself and to its paired end marker, and vice versa.
        # Markers belonging to different spans do not attend to each other.
        start_indices = torch.arange(start_marker_idx, start_marker_idx + num_spans)
        end_indices = torch.arange(end_marker_idx, end_marker_idx + num_spans)
        attention_mask[start_indices, start_indices] = 1
        attention_mask[start_indices, end_indices] = 1
        attention_mask[end_indices, start_indices] = 1
        attention_mask[end_indices, end_indices] = 1
        return attention_mask