import dataclasses
import logging
import math
import os
from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
from datasets import Dataset
from torch.utils.data import DataLoader
from tqdm.autonotebook import tqdm
from transformers import (
EvalPrediction,
TrainerCallback,
TrainingArguments,
)
from transformers import Trainer as TransformersTrainer
from transformers.trainer_utils import PredictionOutput
from span_marker.evaluation import compute_f1_via_seqeval
from span_marker.label_normalizer import AutoLabelNormalizer, LabelNormalizer
from span_marker.model_card import ModelCardCallback
from span_marker.modeling import SpanMarkerModel
from span_marker.tokenizer import SpanMarkerTokenizer
# Module-level logger for the warnings and informational messages emitted by the Trainer below.
logger = logging.getLogger(__name__)
class Trainer(TransformersTrainer):
    """
    Trainer is a simple but feature-complete training and eval loop for SpanMarker,
    built tightly on top of the 🤗 Transformers :external:doc:`Trainer <main_classes/trainer>`.

    Args:
        model (Optional[SpanMarkerModel]):
            The model to train, evaluate or use for predictions. If not provided, a ``model_init`` must be passed.
        args (Optional[~transformers.TrainingArguments]):
            The arguments to tweak for training. Will default to a basic instance of :class:`~transformers.TrainingArguments` with the
            ``output_dir`` set to a directory named *models/my_span_marker_model* in the current directory if not provided.
            In all cases, ``include_inputs_for_metrics=True`` and ``remove_unused_columns=False`` are enforced.
        train_dataset (Optional[~datasets.Dataset]):
            The dataset to use for training. Must contain ``tokens`` and ``ner_tags`` columns, and may contain
            ``document_id`` and ``sentence_id`` columns for document-level context during training.
        eval_dataset (Optional[~datasets.Dataset]):
            The dataset to use for evaluation. Must contain ``tokens`` and ``ner_tags`` columns, and may contain
            ``document_id`` and ``sentence_id`` columns for document-level context during evaluation.
        model_init (Optional[Callable[[], SpanMarkerModel]]):
            A function that instantiates the model to be used. If provided, each call to :meth:`Trainer.train` will start
            from a new instance of the model as given by this function.
            The function may have zero argument, or a single one containing the optuna/Ray Tune/SigOpt trial object, to
            be able to choose different architectures according to hyper parameters (such as layer count, sizes of
            inner layers, dropout probabilities etc).
        compute_metrics (Optional[Callable[[~transformers.EvalPrediction], Dict]]):
            The function that will be used to compute metrics at evaluation, in addition to the always-computed
            seqeval-based F1 metrics. Must take a :class:`~transformers.EvalPrediction` and return
            a dictionary mapping strings to metric values.
        callbacks (Optional[List[~transformers.TrainerCallback]]):
            A list of callbacks to customize the training loop. Will add those to the list of default callbacks
            detailed in the Hugging Face :external:doc:`Callback documentation <main_classes/callback>`.
            If you want to remove one of the default callbacks used, use the :meth:`~Trainer.remove_callback` method.
        optimizers (Tuple[Optional[~torch.optim.Optimizer], Optional[~torch.optim.lr_scheduler.LambdaLR]]): A tuple
            containing the optimizer and the scheduler to use. Will default to an instance of ``AdamW`` on your model
            and a scheduler given by ``get_linear_schedule_with_warmup`` controlled by ``args``.
        preprocess_logits_for_metrics (Optional[Callable[[~torch.Tensor, ~torch.Tensor], ~torch.Tensor]]):
            A function that preprocesses the logits right before caching them at each evaluation step. Must take two
            tensors, the logits and the labels, and return the logits once processed as desired. The modifications made
            by this function will be reflected in the predictions received by ``compute_metrics``.

            Note that the labels (second parameter) will be ``None`` if the dataset does not have them.

    Raises:
        RuntimeError: If neither ``model`` nor ``model_init`` is provided.

    Important attributes:

        - **model** -- Always points to the core model.
        - **model_wrapped** -- Always points to the most external model in case one or more other modules wrap the
          original model. This is the model that should be used for the forward pass. For example, under ``DeepSpeed``,
          the inner model is wrapped in ``DeepSpeed`` and then again in :class:`torch.nn.DistributedDataParallel`. If the
          inner model hasn't been wrapped, then ``self.model_wrapped`` is the same as ``self.model``.
        - **is_model_parallel** -- Whether or not a model has been switched to a model parallel mode (different from
          data parallelism, this means some of the model layers are split on different GPUs).
        - **place_model_on_device** -- Whether or not to automatically place the model on the device - it will be set
          to ``False`` if model parallel or deepspeed is used, or if the default
          `TrainingArguments.place_model_on_device` is overridden to return ``False``.
        - **is_in_train** -- Whether or not a model is currently running :meth:`~Trainer.train` (e.g. when ``evaluate`` is called while
          in ``train``)
    """

    REQUIRED_COLUMNS: Tuple[str, ...] = ("tokens", "ner_tags")
    OPTIONAL_COLUMNS: Tuple[str, ...] = ("document_id", "sentence_id")

    def __init__(
        self,
        model: Optional[SpanMarkerModel] = None,
        args: Optional[TrainingArguments] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Dataset] = None,
        model_init: Optional[Callable[[], SpanMarkerModel]] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        optimizers: Tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
    ) -> None:
        # Extract the model from an initializer function
        if model_init:
            self.model_init = model_init
            model = self.call_model_init()
        if model is None:
            # Fail early with the base Trainer's message instead of an opaque AttributeError on `model.config`.
            raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument")

        # To convert dataset labels to a common format (list of label-start-end tuples)
        self.label_normalizer = AutoLabelNormalizer.from_config(model.config)

        # Set some Training arguments that must be set for SpanMarker
        if args is None:
            args = TrainingArguments(
                output_dir="models/my_span_marker_model", include_inputs_for_metrics=True, remove_unused_columns=False
            )
        else:
            args = dataclasses.replace(args, include_inputs_for_metrics=True, remove_unused_columns=False)

        # Always compute `compute_f1_via_seqeval` - optionally merge in user-provided metrics.
        # `self.is_in_train` is read when metrics are computed, not when this function is defined.
        def compute_metrics_func(eval_prediction: EvalPrediction) -> Dict[str, Any]:
            metrics = compute_f1_via_seqeval(model.tokenizer, eval_prediction, self.is_in_train)
            if compute_metrics is not None:
                metrics = {**metrics, **compute_metrics(eval_prediction)}
            return metrics

        # If the model ID is set via the TrainingArguments, but not via the SpanMarkerModelCardData,
        # then we can set it here for the model card regardless
        if args.hub_model_id and not model.model_card_data.model_id:
            model.model_card_data.model_id = args.hub_model_id

        if not model.model_card_data.dataset_id:
            # Inferring is hacky - it may break in the future, so it stays best-effort,
            # but the failure is logged rather than silently discarded.
            try:
                model.model_card_data.infer_dataset_id(train_dataset)
            except Exception:
                logger.debug("Could not infer the dataset ID for the model card.", exc_info=True)

        super().__init__(
            model=model,
            args=args,
            data_collator=model.data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=model.tokenizer,
            model_init=None,
            compute_metrics=compute_metrics_func,
            callbacks=callbacks,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        )
        # We have to provide the __init__ with None for model_init and then override it here again
        # We do this because we need `model` to already be defined in this SpanMarker Trainer class
        # and the Transformers Trainer would complain if we provide both a model and a model_init
        # in its __init__.
        self.model_init = model_init
        # Override the type hint
        self.model: SpanMarkerModel

        # The raw training dataset and its preprocessed version, used by `get_train_dataloader`
        # to avoid preprocessing an already preprocessed dataset a second time.
        self._raw_train_dataset: Optional[Dataset] = None
        self._preprocessed_train_dataset: Optional[Dataset] = None

        # Add the callback for filling the model card data with hyperparameters
        # and evaluation results
        self.add_callback(ModelCardCallback(self))

    def preprocess_dataset(
        self,
        dataset: Dataset,
        label_normalizer: LabelNormalizer,
        tokenizer: SpanMarkerTokenizer,
        dataset_name: str = "train",
        is_evaluate: bool = False,
    ) -> Dataset:
        """Normalize the ``ner_tags`` labels and call tokenizer on ``tokens``.

        Args:
            dataset (~datasets.Dataset): A Hugging Face dataset with ``tokens`` and ``ner_tags`` columns.
            label_normalizer (LabelNormalizer): A callable that normalizes ``ner_tags`` into start-end-label tuples.
            tokenizer (SpanMarkerTokenizer): The tokenizer responsible for tokenizing ``tokens`` into input IDs,
                and adding start and end markers.
            dataset_name (str, optional): The name of the dataset, used in error and progress messages.
                Defaults to "train".
            is_evaluate (bool, optional): Whether the dataset is used for evaluation. If True, the tokenizer also
                returns the number of words for each sample (required for evaluation) and model card widget
                examples are taken from this dataset. If False, training-related model card data and the
                ``trained_with_document_context`` config flag are set. Defaults to False.

        Raises:
            ValueError: If the ``dataset`` does not contain ``tokens`` and ``ner_tags`` columns.

        Returns:
            Dataset: The normalized and tokenized version of the input dataset.
        """
        for column in self.REQUIRED_COLUMNS:
            if column not in dataset.column_names:
                raise ValueError(f"The {dataset_name} dataset must contain a {column!r} column.")

        # Drop all unused columns, only keep "tokens", "ner_tags", "document_id", "sentence_id".
        # Lists in column order (rather than sets) keep the arguments deterministic between runs.
        kept_columns = set(self.REQUIRED_COLUMNS) | set(self.OPTIONAL_COLUMNS)
        dataset = dataset.remove_columns([column for column in dataset.column_names if column not in kept_columns])

        # Normalize the labels to a common format (list of label-start-end tuples)
        # Also add "entity_count" and "word_count" labels
        dataset = dataset.map(
            label_normalizer,
            input_columns=("tokens", "ner_tags"),
            desc=f"Label normalizing the {dataset_name} dataset",
            batched=True,
        )

        # Setting model card data based on training data
        if not is_evaluate:
            # Pick some example entities from each entity class for the model card.
            if not self.model.model_card_data.label_example_list:
                self.model.model_card_data.set_label_examples(
                    dataset, self.model.config.id2label, self.model.config.outside_id
                )
            if not self.model.model_card_data.train_set_metrics_list:
                self.model.model_card_data.set_train_set_metrics(dataset)

        # Set some example sentences for the model card widget
        if is_evaluate and not self.model.model_card_data.widget:
            self.model.model_card_data.set_widget_examples(dataset)

        # Remove dataset columns that are only used for model card
        dataset = dataset.remove_columns(["entity_count", "word_count"])

        # Tokenize and add start/end markers
        with tokenizer.entity_tracker(split=dataset_name):
            dataset = dataset.map(
                tokenizer,
                batched=True,
                remove_columns=[column for column in dataset.column_names if column not in self.OPTIONAL_COLUMNS],
                desc=f"Tokenizing the {dataset_name} dataset",
                fn_kwargs={"return_num_words": is_evaluate},
            )

        # If "document_id" AND "sentence_id" exist in the training dataset
        if {"document_id", "sentence_id"} <= set(dataset.column_names):
            # If training, set the config flag that this model is trained with document context
            if not is_evaluate:
                self.model.config.trained_with_document_context = True
            # If evaluating and the model was not trained with document context, warn
            elif not self.model.config.trained_with_document_context:
                logger.warning(
                    "This model was trained without document-level context: "
                    "evaluation with document-level context may cause decreased performance."
                )
            dataset = dataset.sort(column_names=["document_id", "sentence_id"])
            dataset = self.add_context(
                dataset,
                tokenizer.model_max_length,
                max_prev_context=self.model.config.max_prev_context,
                max_next_context=self.model.config.max_next_context,
            )
        elif is_evaluate and self.model.config.trained_with_document_context:
            logger.warning(
                "This model was trained with document-level context: "
                "evaluation without document-level context may cause decreased performance."
            )

        # Spread between multiple samples where needed
        original_length = len(dataset)
        dataset = dataset.map(
            Trainer.spread_sample,
            batched=True,
            desc="Spreading data between multiple samples",
            fn_kwargs={
                "model_max_length": tokenizer.model_max_length,
                "marker_max_length": self.model.config.marker_max_length,
            },
        )
        new_length = len(dataset)
        # An empty dataset would otherwise raise ZeroDivisionError when computing the relative increase.
        if original_length:
            logger.info(
                f"Spread {original_length} sentences across {new_length} samples, "
                f"a {(new_length / original_length) - 1:%} increase. You can increase "
                "`model_max_length` or `marker_max_length` to decrease the number of samples, "
                "but recognize that longer samples are slower."
            )
        return dataset

    @staticmethod
    def add_context(
        dataset: Dataset,
        model_max_length: int,
        max_prev_context: Optional[int] = None,
        max_next_context: Optional[int] = None,
        show_progress_bar: bool = True,
    ) -> Dataset:
        """Add document-level context from previous and next sentences in the same document.

        Context is added alternately from the next and the previous sentence until the sample reaches
        ``model_max_length`` tokens or the context limits are hit. Start and end position IDs are shifted
        by the number of prepended tokens so they keep pointing at the original sentence's spans.

        Args:
            dataset (`Dataset`): The partially processed dataset, containing `"input_ids"`, `"start_position_ids"`,
                `"end_position_ids"`, `"document_id"` and `"sentence_id"` columns, sorted by document and sentence.
            model_max_length (`int`): The total number of tokens that can be processed before
                truncation.
            max_prev_context (`Optional[int]`): The maximum number of previous sentences to include. Defaults to None,
                representing as many previous sentences as fits.
            max_next_context (`Optional[int]`): The maximum number of next sentences to include. Defaults to None,
                representing as many next sentences as fits.
            show_progress_bar (`bool`): Whether to show a progress bar. Defaults to `True`.

        Returns:
            Dataset: A copy of the Dataset with additional previous and next sentences added to input_ids.
        """
        # Pull the needed columns out once: `dataset[i]` decodes a whole row on every access, and the loop
        # below looks up neighbouring rows several times per sample.
        input_ids_column = list(dataset["input_ids"])
        start_ids_column = list(dataset["start_position_ids"])
        end_ids_column = list(dataset["end_position_ids"])
        document_ids = list(dataset["document_id"])
        num_samples = len(input_ids_column)

        all_input_ids = []
        all_start_position_ids = []
        all_end_position_ids = []
        for sample_idx in tqdm(
            range(num_samples),
            desc="Adding document-level context",
            leave=False,
            disable=not show_progress_bar,
        ):
            input_ids = input_ids_column[sample_idx]
            document_id = document_ids[sample_idx]
            # Sequentially add next context, previous context, next context, previous context, etc. until
            # max token length or max_prev/next_context. The first and last input IDs are kept aside
            # and re-attached around the extended token sequence at the end.
            tokens = input_ids[1:-1]
            start_position_ids = start_ids_column[sample_idx]
            end_position_ids = end_ids_column[sample_idx]
            next_context_added = 0
            prev_context_added = 0
            remaining_space = model_max_length - len(tokens) - 2
            while remaining_space > 0:
                next_context_index = sample_idx + next_context_added + 1
                should_add_next = (
                    (max_next_context is None or next_context_added < max_next_context)
                    and next_context_index < num_samples
                    and document_ids[next_context_index] == document_id
                )
                if should_add_next:
                    # TODO: [1:-1][:remaining_space] is not efficient
                    tokens += input_ids_column[next_context_index][1:-1][:remaining_space]
                    next_context_added += 1

                remaining_space = model_max_length - len(tokens) - 2
                if remaining_space <= 0:
                    break

                prev_context_index = sample_idx - prev_context_added - 1
                should_add_prev = (
                    (max_prev_context is None or prev_context_added < max_prev_context)
                    and prev_context_index >= 0
                    and document_ids[prev_context_index] == document_id
                )
                if should_add_prev:
                    # TODO: [1:-1][remaining_space:] is not efficient
                    prepended_tokens = input_ids_column[prev_context_index][1:-1][-remaining_space:]
                    tokens = prepended_tokens + tokens
                    # Prepending shifts every span of the original sentence to the right.
                    start_position_ids = [index + len(prepended_tokens) for index in start_position_ids]
                    end_position_ids = [index + len(prepended_tokens) for index in end_position_ids]
                    prev_context_added += 1

                if not should_add_next and not should_add_prev:
                    break
                remaining_space = model_max_length - len(tokens) - 2

            all_input_ids.append([input_ids[0]] + tokens + [input_ids[-1]])
            all_start_position_ids.append(start_position_ids)
            all_end_position_ids.append(end_position_ids)

        dataset = dataset.remove_columns(("input_ids", "start_position_ids", "end_position_ids"))
        dataset = dataset.add_column("input_ids", all_input_ids)
        dataset = dataset.add_column("start_position_ids", all_start_position_ids)
        dataset = dataset.add_column("end_position_ids", all_end_position_ids)
        return dataset

    @staticmethod
    def spread_sample(
        batch: Dict[str, List[Any]], model_max_length: int, marker_max_length: int
    ) -> Dict[str, List[Any]]:
        """Spread sentences between multiple samples if lack of space per sample requires it.

        Each sample's span candidates (``start_position_ids``, ``end_position_ids`` and, if present, ``labels``)
        are split into consecutive chunks sized by the marker space left after the sample's ``input_ids``; every
        other column is copied unchanged into each resulting sample, and ``num_spans`` is set per chunk.

        Args:
            batch (`Dict[str, List[Any]]`): A dictionary of dataset keys to lists of values.
            model_max_length (`int`): The total number of tokens that can be processed before
                truncation.
            marker_max_length (`int`): The maximum length for each of the span markers. A value of 128
                means that each training and inferencing sample contains a maximum of 128 start markers
                and 128 end markers, for a total of 256 markers per sample.

        Returns:
            Dict[str, List[Any]]: A dictionary of dataset keys to lists of values.
        """
        keys = batch.keys()
        values = batch.values()
        total_sample_length = model_max_length + 2 * marker_max_length
        batch_samples = {key: [] for key in keys}
        for sample in zip(*values):
            sample = dict(zip(keys, sample))
            # Each span needs one start and one end marker, hence the halving.
            sample_marker_space = (total_sample_length - len(sample["input_ids"])) // 2
            spread_between_n = math.ceil(len(sample["start_position_ids"]) / sample_marker_space)
            for i in range(spread_between_n):
                sample_copy = sample.copy()
                start = i * sample_marker_space
                end = (i + 1) * sample_marker_space
                sample_copy["start_position_ids"] = sample["start_position_ids"][start:end]
                sample_copy["end_position_ids"] = sample["end_position_ids"][start:end]
                if "labels" in sample:
                    sample_copy["labels"] = sample["labels"][start:end]
                # NOTE(review): `batch_samples` only has the incoming batch's keys, so this assumes the batch
                # already carries a "num_spans" column (presumably from the tokenizer) - TODO confirm.
                sample_copy["num_spans"] = len(sample_copy["start_position_ids"])
                for key, value in sample_copy.items():
                    batch_samples[key].append(value)
        return batch_samples

    def get_train_dataloader(self) -> DataLoader:
        """Return the preprocessed training DataLoader.

        The raw ``train_dataset`` is passed through :meth:`preprocess_dataset`, after which ``self.train_dataset``
        points to the preprocessed dataset. If a later call finds that preprocessed dataset (e.g. when training
        is started again), preprocessing restarts from the original raw dataset, because an already tokenized
        dataset no longer has the ``tokens`` and ``ner_tags`` columns.
        """
        train_dataset = self.train_dataset
        if train_dataset is not None and train_dataset is getattr(self, "_preprocessed_train_dataset", None):
            train_dataset = getattr(self, "_raw_train_dataset", None)
        if train_dataset is not None:
            self._raw_train_dataset = train_dataset
            self.train_dataset = self.preprocess_dataset(train_dataset, self.label_normalizer, self.tokenizer)
            self._preprocessed_train_dataset = self.train_dataset
        # Without a training dataset, the base class raises its own descriptive error.
        return super().get_train_dataloader()

    def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
        """Return the preprocessed evaluation DataLoader.

        Args:
            eval_dataset (Optional[~datasets.Dataset]): The raw dataset to evaluate on.
                Defaults to ``self.eval_dataset``.
        """
        # Explicit None check: an empty Dataset is falsy, and `or` would silently swap it for `self.eval_dataset`.
        if eval_dataset is None:
            eval_dataset = self.eval_dataset
        if eval_dataset is not None:
            eval_dataset = self.preprocess_dataset(
                eval_dataset, self.label_normalizer, self.tokenizer, dataset_name="evaluation", is_evaluate=True
            )
        return super().get_eval_dataloader(eval_dataset)

    def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
        """Return the preprocessed test DataLoader.

        Args:
            test_dataset (~datasets.Dataset): The raw dataset to predict on, with ``tokens`` and ``ner_tags`` columns.
        """
        test_dataset = self.preprocess_dataset(
            test_dataset, self.label_normalizer, self.tokenizer, dataset_name="test", is_evaluate=True
        )
        return super().get_test_dataloader(test_dataset)

    def predict(
        self, test_dataset: Dataset, ignore_keys: Optional[List[str]] = None, metric_key_prefix: str = "test"
    ) -> PredictionOutput:
        """Run the base Transformers prediction loop on ``test_dataset``, after logging a warning that
        the model's own ``predict`` method is the recommended alternative.
        """
        logger.warning(
            f"`Trainer.predict` is not recommended for a {self.model.__class__.__name__}. "
            f"Consider using `{self.model.__class__.__name__}.predict` instead."
        )
        return super().predict(test_dataset, ignore_keys, metric_key_prefix)

    def create_model_card(self, *_args, **_kwargs) -> None:
        """
        Creates a draft of a model card using the information available to the `Trainer`,
        the `SpanMarkerModel` and the `SpanMarkerModelCardData`, and writes it to ``README.md``
        in ``args.output_dir`` (creating that directory if needed). Any positional or keyword
        arguments are accepted but ignored.
        """
        os.makedirs(self.args.output_dir, exist_ok=True)
        with open(os.path.join(self.args.output_dir, "README.md"), "w", encoding="utf8") as f:
            f.write(self.model.generate_model_card())