"""Source code for ``span_marker.tokenizer``: the SpanMarker tokenizer wrapper and its entity tracker."""

import logging
import os
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union

import numpy as np
from tokenizers.pre_tokenizers import Punctuation, Sequence
from transformers import AutoTokenizer, PreTrainedTokenizer, XLMRobertaTokenizerFast

from span_marker.configuration import SpanMarkerConfig

# Module-level logger, named after this module, used for user-facing warnings.
logger: logging.Logger = logging.getLogger(name=__name__)


@dataclass
class EntityTracker:
    """For giving a warning about what percentage of entities are ignored/skipped.

    Example::

        This SpanMarker model won't be able to predict 5.930931% of all annotated entities in the evaluation dataset.
        This is caused by the SpanMarkerModel maximum entity length of 6 words and the maximum model input length
        of 64 tokens.
        These are the frequencies of the missed entities due to maximum entity length out of 1332 total entities:
        - 7 missed entities with 7 words (0.525526%)
        - 2 missed entities with 8 words (0.150150%)
        - 2 missed entities with 9 words (0.150150%)
        - 2 missed entities with 13 words (0.150150%)
        Additionally, a total of 66 (4.954955%) entities were missed due to the maximum input length.
    """

    # Missed entities longer than this many words are attributed to the maximum entity length.
    entity_max_length: int
    # Maximum number of input tokens; within this class it is only used in the warning text.
    model_max_length: int
    split: str = "train"  # or "evaluation" or "test"
    # Running count of all annotated entities seen (see ``add``).
    total_num_entities: int = 0
    # Maps entity length in words -> number of missed entities with that length (see ``missed``).
    skipped_entities: Dict[int, int] = field(default_factory=lambda: defaultdict(int))
    # Whether tracking is active; set by ``__enter__`` and cleared by ``reset``.
    enabled: bool = False

    def __call__(self, split: Optional[str] = None) -> "EntityTracker":
        """Update the current split, which affects the warning message.

        Example::

            with tokenizer.entity_tracker(split=dataset_name):
                dataset = dataset.map(
                    tokenizer,
                    ...
                )

        Args:
            split (Optional[str]): The new split string, either "train", "evaluation" or "test".
                Defaults to None, which keeps the current split.

        Returns:
            EntityTracker: This instance, so the call can be used directly in a ``with`` statement.
        """
        if split:
            self.split = split
        return self

    def __enter__(self) -> "EntityTracker":
        """Start tracking (ignored) entities on enter."""
        self.enabled = True
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Log the ignored-entities warning (if anything was missed) and reset the tracker.

        Returns None, so exceptions raised inside the ``with`` block are never suppressed.
        """
        if self.skipped_entities:
            logger.warning(self._build_message())
        self.reset()

    def _build_message(self) -> str:
        """Build the human-readable warning describing which entities were missed and why.

        Returns:
            str: The warning text. Only meaningful once at least one entity was recorded via ``missed``.
        """
        entity_max_length_missed_freq = sorted(
            [(length, freq) for length, freq in self.skipped_entities.items() if length > self.entity_max_length],
            key=lambda item: item[0],
        )
        # Entities short enough to fit the entity length limit are attributed to the model input length.
        model_max_length_missed_total = sum(
            freq for length, freq in self.skipped_entities.items() if length <= self.entity_max_length
        )
        total_num_missed_entities = sum(self.skipped_entities.values())
        # In normal use ``add`` has already counted every missed entity, so this equals ``total_num_entities``.
        # The max() only matters if ``missed`` was called without ``add`` and prevents ratios above 100%.
        total_num_entities = max(self.total_num_entities, total_num_missed_entities)
        # Avoid ZeroDivisionError in the degenerate case where nothing was counted at all.
        denominator = total_num_entities or 1

        if self.split == "train":
            message = "This SpanMarker model will ignore"
        else:
            message = "This SpanMarker model won't be able to predict"
        message += (
            f" {total_num_missed_entities / denominator:%} of all annotated entities in the {self.split}"
            " dataset. This is caused by the SpanMarkerModel "
        )
        if entity_max_length_missed_freq:
            message += (
                f"maximum entity length of {self.entity_max_length} word{'s' if self.entity_max_length > 1 else ''}"
            )
            if model_max_length_missed_total:
                message += " and the "
        if model_max_length_missed_total:
            message += (
                f"maximum model input length of {self.model_max_length} token{'s' if self.model_max_length > 1 else ''}"
            )
        message += "."

        if entity_max_length_missed_freq:
            message += (
                "\nThese are the frequencies of the missed entities due to maximum entity length out of"
                f" {total_num_entities} total entities:\n"
            )
            message += "\n".join(
                f"- {freq} missed entities with {length} word{'s' if length > 1 else ''}"
                f" ({freq / denominator:%})"
                for length, freq in entity_max_length_missed_freq
            )
        if model_max_length_missed_total:
            message += "\nAdditionally, a " if entity_max_length_missed_freq else "\nA "
            message += (
                f"total of {model_max_length_missed_total} ({model_max_length_missed_total / denominator:%})"
                " entities were missed due to the maximum input length."
            )
        return message

    def add(self, num_entities: int) -> None:
        """Add to the counter of total number of entities.

        Args:
            num_entities (int): How many entities to increment by.
        """
        self.total_num_entities += num_entities

    def missed(self, length: int) -> None:
        """Record a single missed/ignored/skipped entity.

        Args:
            length (int): The length of the missed entity, in words.
        """
        self.skipped_entities[length] += 1

    def reset(self) -> None:
        """Reset to defaults, stops tracking."""
        self.total_num_entities = 0
        self.skipped_entities = defaultdict(int)
        self.enabled = False
class SpanMarkerTokenizer:
    """Wrapper around a Hugging Face tokenizer that produces SpanMarker model inputs.

    Besides tokenizing, it enumerates every candidate span of up to ``config.entity_max_length`` words and
    maps each span to the token positions of its first and last word. Attributes that this wrapper does not
    define itself are looked up on the wrapped ``tokenizer``.
    """

    def __init__(self, tokenizer: PreTrainedTokenizer, config: SpanMarkerConfig, **kwargs) -> None:
        """Register the span marker tokens and resolve the effective maximum input length.

        Args:
            tokenizer (PreTrainedTokenizer): The tokenizer to wrap. ``<start>`` and ``<end>`` are added to it
                in place as special tokens.
            config (SpanMarkerConfig): Provides ``model_max_length``, ``model_max_length_default``,
                ``entity_max_length`` and ``outside_id``.
            **kwargs: Accepted for compatibility with ``from_pretrained``; ignored here.
        """
        self.tokenizer = tokenizer
        self.config = config
        tokenizer.add_tokens(["<start>", "<end>"], special_tokens=True)
        self.start_marker_id, self.end_marker_id = self.tokenizer.convert_tokens_to_ids(["<start>", "<end>"])

        # NOTE(review): a value this large presumably means the tokenizer has no configured limit
        # (transformers uses a ~1e30 sentinel) — confirm for the tokenizers in use.
        if self.tokenizer.model_max_length > 1e29 and self.config.model_max_length is None:
            logger.warning(
                f"The underlying {self.tokenizer.__class__.__name__!r} tokenizer nor {self.config.__class__.__name__!r}"
                f" specify `model_max_length`: defaulting to {self.config.model_max_length_default} tokens."
            )
        self.model_max_length = min(
            self.tokenizer.model_max_length, self.config.model_max_length or self.config.model_max_length_default
        )
        self.entity_tracker = EntityTracker(self.config.entity_max_length, self.model_max_length)

    def get_all_valid_spans(self, num_words: int, entity_max_length: int) -> Iterator[Tuple[int, int]]:
        """Yield every word span ``(start, end)`` (``end`` exclusive) of 1 to ``entity_max_length`` words.

        Spans never extend past ``num_words``.
        """
        for start_idx in range(num_words):
            for end_idx in range(start_idx + 1, min(num_words + 1, start_idx + 1 + entity_max_length)):
                yield (start_idx, end_idx)

    def get_all_valid_spans_and_labels(
        self, num_words: int, span_to_label: Dict[Tuple[int, int], int], entity_max_length: int, outside_id: int
    ) -> Iterator[Tuple[Tuple[int, int], int]]:
        """Yield each valid span together with its label, defaulting to ``outside_id``.

        Matched spans are popped from ``span_to_label``, so after exhausting the iterator the dict only
        holds the annotated entities that no valid span covered.
        """
        for span in self.get_all_valid_spans(num_words, entity_max_length):
            yield span, span_to_label.pop(span, outside_id)

    def __getattr__(self, key: str) -> Any:
        """Delegate attributes not found on this wrapper to the wrapped tokenizer.

        Only invoked after normal lookup fails. ``getattr`` is used so the wrapped tokenizer's own
        ``__getattr__`` fallback (if any) is honoured as well.
        """
        # Read ``tokenizer`` directly from ``__dict__``: on a half-constructed instance (e.g. while copying or
        # unpickling) it may not exist yet, and ``self.tokenizer`` would recurse into ``__getattr__`` forever.
        try:
            tokenizer = self.__dict__["tokenizer"]
        except KeyError:
            raise AttributeError(key) from None
        return getattr(tokenizer, key)

    @staticmethod
    def _is_split_into_words(tokens: Any) -> bool:
        """Guess whether ``tokens`` holds pre-split words rather than raw sentences.

        A plain string, or any element containing a space, means the text still needs word splitting.
        """
        if isinstance(tokens, str):
            return False
        if tokens:
            return not any(" " in token for token in tokens)
        return True

    @staticmethod
    def _span_position_ids(
        batch_encoding: Any,
        sample_idx: int,
        spans: Union[List[Tuple[int, int]], Tuple[Tuple[int, int], ...]],
    ) -> Tuple[List[int], List[int]]:
        """Map each word span to the index of its first token and of its last token.

        Words that produce no tokens (e.g. '\u2063') or that fell outside the truncated input map to 0.
        """
        start_position_ids, end_position_ids = [], []
        for start_word_i, end_word_i in spans:
            start_token_span = batch_encoding.word_to_tokens(sample_idx, word_index=start_word_i)
            start_position_ids.append(start_token_span.start if start_token_span else 0)
            # ``end_word_i`` is exclusive, so the span's last word is ``end_word_i - 1``.
            end_token_span = batch_encoding.word_to_tokens(sample_idx, word_index=end_word_i - 1)
            end_position_ids.append(end_token_span.end - 1 if end_token_span else 0)
        return start_position_ids, end_position_ids

    def __call__(
        self, batch: Dict[str, List[Any]], return_num_words: bool = False, return_batch_encoding: bool = False, **kwargs
    ) -> Dict[str, List]:
        """Tokenize a batch and build the span inputs for the SpanMarker model.

        Args:
            batch (Dict[str, List[Any]]): Must contain ``"tokens"``; may contain ``"ner_tags"``, where each
                sample is an iterable of ``(label, start_idx, end_idx)`` word-span triples.
            return_num_words (bool): Also return the number of words per sample under ``"num_words"``.
            return_batch_encoding (bool): Also return the underlying ``BatchEncoding`` under ``"batch_encoding"``.
            **kwargs: Forwarded to the wrapped tokenizer.

        Returns:
            Dict[str, List]: ``input_ids`` (padding stripped), ``num_spans``, ``start_position_ids`` and
            ``end_position_ids``, plus ``labels`` when labels were given and the optional keys above.

        Raises:
            ValueError: If a sample contains no words.
        """
        tokens = batch["tokens"]
        labels = batch.get("ner_tags", None)

        batch_encoding = self.tokenizer(
            tokens,
            **kwargs,
            is_split_into_words=self._is_split_into_words(tokens),
            padding="max_length",
            truncation=True,
            max_length=self.model_max_length,
            return_tensors="pt",
        )

        all_input_ids = []
        all_num_spans = []
        all_start_position_ids = []
        all_end_position_ids = []
        all_labels = []
        all_num_words = []
        for sample_idx, input_ids in enumerate(batch_encoding["input_ids"]):
            # ``word_ids`` contains None for tokens that belong to no word (e.g. special tokens).
            word_ids = [word_id for word_id in batch_encoding.word_ids(sample_idx) if word_id is not None]
            if not word_ids:
                raise ValueError("The `SpanMarkerTokenizer` detected an empty sentence, please remove it.")
            # Computed after truncation, so words cut off by ``max_length`` are not counted.
            num_words = max(word_ids) + 1

            # NOTE(review): assumes right-side padding; the first pad token marks the end of the real input.
            if self.tokenizer.pad_token_id in input_ids:
                num_tokens = list(input_ids).index(self.tokenizer.pad_token_id)
            else:
                num_tokens = len(input_ids)

            if labels:
                span_to_label = {(start_idx, end_idx): label for label, start_idx, end_idx in labels[sample_idx]}
                if self.entity_tracker.enabled:
                    self.entity_tracker.add(len(span_to_label))

                spans, span_labels = zip(
                    *list(
                        self.get_all_valid_spans_and_labels(
                            num_words, span_to_label, self.config.entity_max_length, self.config.outside_id
                        )
                    )
                )
                # `self.get_all_valid_spans_and_labels` popped `span_to_label`, so if it's non-empty, then that
                # entity was ignored, and we may want to track it for a useful warning
                if self.entity_tracker.enabled:
                    for start, end in span_to_label.keys():
                        self.entity_tracker.missed(end - start)
            else:
                spans = list(self.get_all_valid_spans(num_words, self.config.entity_max_length))

            start_position_ids, end_position_ids = self._span_position_ids(batch_encoding, sample_idx, spans)

            all_input_ids.append(input_ids[:num_tokens].tolist())
            all_num_spans.append(len(spans))
            all_start_position_ids.append(start_position_ids)
            all_end_position_ids.append(end_position_ids)

            if labels:
                all_labels.append(span_labels)

            if return_num_words:
                all_num_words.append(num_words)

        output = {
            "input_ids": all_input_ids,
            "num_spans": all_num_spans,
            "start_position_ids": all_start_position_ids,
            "end_position_ids": all_end_position_ids,
        }
        if labels:
            output["labels"] = all_labels
        if return_num_words:
            # Store the number of words, useful for computing the spans in the evaluation and model.predict() method
            output["num_words"] = all_num_words
        if return_batch_encoding:
            # Store the batch encoding, useful for converting word IDs to characters in the model.predict() method
            output["batch_encoding"] = batch_encoding
        return output

    def __len__(self) -> int:
        """Return ``len`` of the wrapped tokenizer."""
        return len(self.tokenizer)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], *inputs, config=None, **kwargs):
        """Load a tokenizer with ``AutoTokenizer`` (forcing ``add_prefix_space=True``) and wrap it.

        Args:
            pretrained_model_name_or_path (Union[str, os.PathLike]): Model id or local path.
            *inputs: Forwarded to ``AutoTokenizer.from_pretrained``.
            config (SpanMarkerConfig): The SpanMarker configuration for the wrapper.
            **kwargs: Forwarded to ``AutoTokenizer.from_pretrained`` and to the constructor.

        Returns:
            SpanMarkerTokenizer: The wrapped tokenizer.
        """
        tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path, *inputs, **kwargs, add_prefix_space=True
        )
        # XLM-R is known to have some tokenization issues, so be sure to also split on punctuation.
        # Strictly required for inference, shouldn't affect training.
        if isinstance(tokenizer, XLMRobertaTokenizerFast):
            tokenizer._tokenizer.pre_tokenizer = Sequence([Punctuation(), tokenizer._tokenizer.pre_tokenizer])
        return cls(tokenizer, config=config, **kwargs)