import logging
import os
import random
from dataclasses import dataclass, field, fields
from pathlib import Path
from platform import python_version
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import datasets
import tokenizers
import torch
import transformers
from datasets import Dataset
from huggingface_hub import (
CardData,
DatasetFilter,
ModelCard,
dataset_info,
list_datasets,
model_info,
)
from huggingface_hub.repocard_data import EvalResult, eval_results_to_model_index
from huggingface_hub.utils import yaml_dump
from transformers import TrainerCallback
from transformers.integrations import CodeCarbonCallback
from transformers.modelcard import (
extract_hyperparameters_from_trainer,
make_markdown_table,
)
from transformers.trainer_callback import TrainerControl, TrainerState
from transformers.training_args import TrainingArguments
import span_marker
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# These imports are only needed for type annotations; guarding them keeps
# them out of the runtime import graph.
if TYPE_CHECKING:
    from span_marker.modeling import SpanMarkerModel
    from span_marker.trainer import Trainer
class ModelCardCallback(TrainerCallback):
    """Trainer callback that records training information for the model card.

    The callback stores the hyperparameters when training begins and the
    evaluation metrics after every evaluation on ``model.model_card_data``.

    Args:
        trainer (Trainer): The trainer whose callbacks, hyperparameters and
            training state are inspected.
    """

    def __init__(self, trainer: "Trainer") -> None:
        super().__init__()
        self.trainer = trainer

        # If the trainer already has a CodeCarbon callback registered, expose the
        # first one on the model card data.
        code_carbon_callback = next(
            (
                callback
                for callback in self.trainer.callback_handler.callbacks
                if isinstance(callback, CodeCarbonCallback)
            ),
            None,
        )
        if code_carbon_callback is not None:
            trainer.model.model_card_data.code_carbon_callback = code_carbon_callback

    def on_train_begin(
        self, args: TrainingArguments, state: TrainerState, control: TrainerControl, model: "SpanMarkerModel", **kwargs
    ):
        """Store the trainer's hyperparameters on the model card data."""
        model.model_card_data.hyperparameters = extract_hyperparameters_from_trainer(self.trainer)

    def on_evaluate(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        model: "SpanMarkerModel",
        metrics: Dict[str, float],
        **kwargs,
    ):
        """Record evaluation metrics on the model card data.

        During training, one row per evaluation is appended to ``eval_lines_list``
        (only when ``"eval_loss"`` is among the metrics). Outside of training, the
        per-label and overall precision/recall/F1 are stored in ``metric_lines``.

        Args:
            metrics: The metrics of the evaluation that just finished. Values are
                either numbers or, for per-label results, dictionaries with
                ``"precision"``, ``"recall"`` and ``"f1"`` keys.
        """
        card_data = model.model_card_data

        # Set the most recent evaluation scores for the metadata
        card_data.eval_results_dict = metrics

        if self.trainer.is_in_train:
            # Either set mid-training evaluation metrics
            if "eval_loss" in metrics:
                card_data.eval_lines_list.append(
                    {
                        "Epoch": state.epoch,
                        "Step": state.global_step,
                        "Validation Loss": metrics["eval_loss"],
                        "Validation Precision": metrics["eval_overall_precision"],
                        "Validation Recall": metrics["eval_overall_recall"],
                        "Validation F1": metrics["eval_overall_f1"],
                        "Validation Accuracy": metrics["eval_overall_accuracy"],
                    }
                )
            return

        # Or set the post-training metrics.
        # The dataset split is the prefix of the "<split>_runtime" metric key.
        runtime_keys = [key for key in metrics if key.endswith("_runtime")]
        if not runtime_keys:
            return
        dataset_split = runtime_keys[0][: -len("_runtime")]

        # The overall scores always form the first row of the table.
        metric_lines = [
            {
                "Label": "**all**",
                "Precision": metrics[f"{dataset_split}_overall_precision"],
                "Recall": metrics[f"{dataset_split}_overall_recall"],
                "F1": metrics[f"{dataset_split}_overall_f1"],
            }
        ]
        for key, value in metrics.items():
            # Only per-label results are dictionaries. Checking for a dict (rather
            # than "anything that is not a float") avoids subscripting ints,
            # strings or other scalar metric values.
            if isinstance(value, dict):
                metric_lines.append(
                    {
                        # Strip the "<split>_" prefix to get the label name
                        "Label": key[len(dataset_split) + 1 :],
                        "Precision": value["precision"],
                        "Recall": value["recall"],
                        "F1": value["f1"],
                    }
                )
        card_data.metric_lines = metric_lines
# Keys of the model card data dictionary that are allowed into the YAML
# metadata; `to_yaml` drops every key that is not listed here.
YAML_FIELDS = [
    "language",
    "license",
    "library_name",
    "tags",
    "datasets",
    "metrics",
    "pipeline_tag",
    "widget",
    "model-index",
    "co2_eq_emissions",
    "base_model",
]

# Dataclass fields that `to_dict` removes from its output.
IGNORED_FIELDS = [
    "model",
]
@dataclass
class SpanMarkerModelCardData(CardData):
    """A dataclass storing data used in the model card.

    Args:
        language (Optional[Union[str, List[str]]]): The model language, either a string or a list,
            e.g. "en" or ["en", "de", "nl"]
        license: (Optional[str]): The license of the model, e.g. "apache-2.0", "mit"
            or "cc-by-nc-sa-4.0"
        model_name: (Optional[str]): The pretty name of the model, e.g. "SpanMarker with mBERT-base on CoNLL03".
            If not defined, uses encoder_name/encoder_id and dataset_name/dataset_id to generate a model name.
        model_id: (Optional[str]): The model ID when pushing the model to the Hub,
            e.g. "tomaarsen/span-marker-mbert-base-multinerd".
        encoder_name: (Optional[str]): The pretty name of the encoder, e.g. "mBERT-base".
        encoder_id: (Optional[str]): The model ID of the encoder, e.g. "bert-base-multilingual-cased".
        dataset_name: (Optional[str]): The pretty name of the dataset, e.g. "CoNLL03".
        dataset_id: (Optional[str]): The dataset ID of the dataset, e.g. "tner/bionlp2004".
        dataset_revision: (Optional[str]): The dataset revision/commit that was used for training/evaluation.

    Note:
        Install ``nltk`` to detokenize the examples used in the model card, i.e. attach punctuation and brackets.
        Additionally, ``codecarbon`` can be installed to automatically track carbon emission usage.

    Example::

        >>> model = SpanMarkerModel.from_pretrained(
        ...     "bert-base-uncased",
        ...     labels=["O", "B-DNA", "I-DNA", "B-protein", ...],
        ...     # SpanMarker hyperparameters:
        ...     model_max_length=256,
        ...     marker_max_length=128,
        ...     entity_max_length=8,
        ...     # Model card variables
        ...     model_card_data=SpanMarkerModelCardData(
        ...         model_id="tomaarsen/span-marker-bbu-bionlp",
        ...         encoder_id="bert-base-uncased",
        ...         dataset_name="BioNLP2004",
        ...         dataset_id="tner/bionlp2004",
        ...         license="apache-2.0",
        ...         language="en",
        ...     ),
        ... )
    """

    # Potentially provided by the user
    language: Optional[Union[str, List[str]]] = None
    license: Optional[str] = None
    tags: Optional[List[str]] = field(
        default_factory=lambda: [
            "span-marker",
            "token-classification",
            "ner",
            "named-entity-recognition",
            "generated_from_span_marker_trainer",
        ]
    )
    model_name: Optional[str] = None
    model_id: Optional[str] = None
    encoder_name: Optional[str] = None
    encoder_id: Optional[str] = None
    dataset_name: Optional[str] = None
    dataset_id: Optional[str] = None
    dataset_revision: Optional[str] = None
    task_name: str = "Named Entity Recognition"

    # Automatically filled by `ModelCardCallback` and the Trainer directly
    hyperparameters: Dict[str, Any] = field(default_factory=dict, init=False)
    eval_results_dict: Optional[Dict[str, Any]] = field(default_factory=dict, init=False)
    eval_lines_list: List[Dict[str, float]] = field(default_factory=list, init=False)
    metric_lines: List[Dict[str, float]] = field(default_factory=list, init=False)
    widget: List[Dict[str, str]] = field(default_factory=list, init=False)
    predict_example: Optional[str] = field(default=None, init=False)
    label_example_list: List[Dict[str, str]] = field(default_factory=list, init=False)
    tokenizer_warning: bool = field(default=False, init=False)
    train_set_metrics_list: List[Dict[str, str]] = field(default_factory=list, init=False)
    code_carbon_callback: Optional[CodeCarbonCallback] = field(default=None, init=False)

    # Computed once, always unchanged
    pipeline_tag: str = field(default="token-classification", init=False)
    library_name: str = field(default="span-marker", init=False)
    version: Dict[str, str] = field(
        default_factory=lambda: {
            "python": python_version(),
            "span_marker": span_marker.__version__,
            "transformers": transformers.__version__,
            "torch": torch.__version__,
            "datasets": datasets.__version__,
            "tokenizers": tokenizers.__version__,
        },
        init=False,
    )
    metrics: List[str] = field(default_factory=lambda: ["precision", "recall", "f1"], init=False)

    # Passed via `register_model` only
    model: Optional["SpanMarkerModel"] = field(default=None, init=False, repr=False)

    def __post_init__(self):
        """Validate the user-provided IDs and infer the language when possible.

        IDs that cannot be found on the Hugging Face Hub, or that are malformed,
        are reset to ``None`` with a warning.
        """
        if self.dataset_id:
            if is_on_huggingface(self.dataset_id, is_model=False):
                if self.language is None:
                    # if languages are not set, try to determine the language from the dataset on the Hub
                    try:
                        info = dataset_info(self.dataset_id)
                    except Exception:
                        # Best effort only: the language simply stays unset. `Exception`
                        # (not a bare `except`) so that Ctrl-C still interrupts.
                        logger.debug(
                            "Could not fetch dataset info for %r to infer the language.",
                            self.dataset_id,
                            exc_info=True,
                        )
                    else:
                        if info.cardData:
                            self.language = info.cardData.get("language", self.language)
            else:
                logger.warning(
                    f"The provided {self.dataset_id!r} dataset could not be found on the Hugging Face Hub."
                    " Setting `dataset_id` to None."
                )
                self.dataset_id = None

        if self.encoder_id and not is_on_huggingface(self.encoder_id):
            logger.warning(
                f"The provided {self.encoder_id!r} model could not be found on the Hugging Face Hub."
                " Setting `encoder_id` to None."
            )
            self.encoder_id = None

        if self.model_id and self.model_id.count("/") != 1:
            logger.warning(
                f"The provided {self.model_id!r} model ID should include the organization or user,"
                ' such as "tomaarsen/span-marker-mbert-base-multinerd". Setting `model_id` to None.'
            )
            self.model_id = None

    def set_widget_examples(self, dataset: Dataset) -> None:
        """Pick example sentences for the widget and the inference snippet.

        Args:
            dataset: Dataset with "tokens", "ner_tags", "entity_count" and
                "word_count" columns, as read by this method.
        """
        # If NLTK is installed, use its detokenization. Otherwise, join by spaces.
        try:
            from nltk.tokenize.treebank import TreebankWordDetokenizer
        except ImportError:
            detokenize = " ".join
        else:
            detokenize = TreebankWordDetokenizer().detokenize

        def map_detokenize(tokens) -> Dict[str, str]:
            return {"text": detokenize(tokens)}

        # Out of `sample_subset_size=100` random samples, select `example_count=5` good examples
        # based on the number of unique entity classes.
        # The shortest example is used in the inference example
        sample_subset_size = 100
        example_count = 5

        if len(dataset) > sample_subset_size:
            example_dataset = dataset.select(random.sample(range(len(dataset)), k=sample_subset_size))
        else:
            example_dataset = dataset

        def count_entities(sample: Dict[str, Any]) -> Dict[str, int]:
            unique_labels = {reduced_label_id for reduced_label_id, _, _ in sample["ner_tags"]}
            return {"unique_entity_count": len(unique_labels)}

        example_dataset = (
            example_dataset.map(count_entities)
            .sort(("unique_entity_count", "entity_count"), reverse=True)
            .select(range(min(len(example_dataset), example_count)))
            .map(map_detokenize, input_columns="tokens")
        )
        self.widget = [{"text": sample["text"]} for sample in example_dataset]

        shortest_example = example_dataset.sort("word_count")[0]["text"]
        self.predict_example = shortest_example

    def set_train_set_metrics(self, dataset: Dataset) -> None:
        """Compute min/median/max statistics of the training set.

        Args:
            dataset: Dataset with "word_count" and "entity_count" columns.
        """
        from statistics import median

        def summarize(title: str, values: List[float]) -> Dict[str, Any]:
            # The row is labelled "Median", so report the actual median
            # rather than the arithmetic mean.
            return {
                "Training set": title,
                "Min": min(values),
                "Median": median(values),
                "Max": max(values),
            }

        # Read every column only once instead of once per statistic.
        self.train_set_metrics_list = [
            summarize("Sentence length", list(dataset["word_count"])),
            summarize("Entities per sentence", list(dataset["entity_count"])),
        ]

    def set_label_examples(self, dataset: Dataset, id2label: Dict[int, str], outside_id: int) -> None:
        """Collect up to three distinct example entities per label.

        Args:
            dataset: Dataset with "tokens" and "ner_tags" columns, where every
                tag is an ``(entity_id, start, end)`` triple.
            id2label: Mapping from label ID to label name.
            outside_id: ID of the label that is skipped.
        """
        num_examples_per_label = 3
        # Lists (with an explicit duplicate check) keep the examples in order of
        # first occurrence, so the generated table is the same on every run.
        examples: Dict[str, List[str]] = {
            label: [] for label_id, label in id2label.items() if label_id != outside_id
        }
        unfinished_entity_ids = set(id2label) - {outside_id}
        for sample in dataset:
            for entity_id, start, end in sample["ner_tags"]:
                if entity_id not in unfinished_entity_ids:
                    continue
                label = id2label[entity_id]
                example = " ".join(sample["tokens"][start:end])
                quoted_example = f'"{example}"'
                if quoted_example not in examples[label]:
                    examples[label].append(quoted_example)
                if len(examples[label]) >= num_examples_per_label:
                    unfinished_entity_ids.remove(entity_id)
            # Stop scanning once every label has enough examples
            if not unfinished_entity_ids:
                break
        self.label_example_list = [
            {"Label": label, "Examples": ", ".join(label_examples)} for label, label_examples in examples.items()
        ]

    def infer_dataset_id(self, dataset: Dataset) -> None:
        """Try to set ``dataset_id`` based on the location of the dataset's cache files.

        ``dataset_id`` is only set if the first cache file lives under a
        "huggingface/datasets" directory and exactly one dataset listed on the
        Hub matches the cached dataset's folder name.
        """

        def subtuple_finder(parts: Tuple[str, ...], subtuple: Tuple[str, ...]) -> int:
            """Return the first index at which `subtuple` occurs in `parts`, or -1."""
            for i, element in enumerate(parts):
                if element == subtuple[0] and parts[i : i + len(subtuple)] == subtuple:
                    return i
            return -1

        def normalize(dataset_id: str) -> str:
            """Lowercase `dataset_id` and strip the separators that differ between cache and Hub."""
            for token in "/\\_-":
                dataset_id = dataset_id.replace(token, "")
            return dataset_id.lower()

        cache_files = dataset.cache_files
        if not cache_files or "filename" not in cache_files[0]:
            return

        cache_path_parts = Path(cache_files[0]["filename"]).parts
        # Check if the cachefile is under "huggingface/datasets"
        subtuple = ("huggingface", "datasets")
        index = subtuple_finder(cache_path_parts, subtuple)
        if index == -1:
            return

        # Get the folder after "huggingface/datasets"
        folder_index = index + len(subtuple)
        if folder_index >= len(cache_path_parts):
            return
        cache_dataset_name = cache_path_parts[folder_index]

        # If the dataset has an author:
        if "___" in cache_dataset_name:
            # Split on the first separator only, so that names containing another
            # "___" do not break the unpacking.
            author, dataset_name = cache_dataset_name.split("___", 1)
        else:
            author = None
            dataset_name = cache_dataset_name

        # Make sure the normalized dataset IDs match
        normalized_cache_name = normalize(cache_dataset_name)
        matches = [
            candidate
            for candidate in list_datasets(filter=DatasetFilter(author=author, dataset_name=dataset_name))
            if normalize(candidate.id) == normalized_cache_name
        ]
        # If there's only one match, get the ID from it
        if len(matches) == 1:
            self.dataset_id = matches[0].id

    def register_model(self, model: "SpanMarkerModel") -> None:
        """Attach `model` and derive ``encoder_id`` and ``model_name`` when they are unset."""
        self.model = model

        if self.encoder_id is None:
            encoder_id_or_path = self.model.config.get("_name_or_path")
            # Only use it as an ID if it is set and does not exist as a local path.
            # (`os.path.exists(None)` would raise a TypeError.)
            if encoder_id_or_path and not os.path.exists(encoder_id_or_path):
                self.encoder_id = encoder_id_or_path

        if not self.model_name:
            if self.encoder_id:
                self.model_name = f"SpanMarker with {self.encoder_name or self.encoder_id}"
                if self.dataset_name or self.dataset_id:
                    self.model_name += f" on {self.dataset_name or self.dataset_id}"
            else:
                self.model_name = "SpanMarker"

    def to_dict(self) -> Dict[str, Any]:
        """Return the model configuration merged with all values needed to render the model card.

        Values of this dataclass take precedence over the model configuration.
        Requires that a model was registered via `register_model`.
        """
        super_dict = {data_field.name: getattr(self, data_field.name) for data_field in fields(self)}

        # Compute required formats from the raw data
        if self.eval_results_dict:
            # The split is the part of the first metric key before the first underscore
            dataset_split = next(iter(self.eval_results_dict)).split("_")[0]
            dataset_id = self.dataset_id or "unknown"
            dataset_name = self.dataset_name or "Unknown"
            eval_results = [
                EvalResult(
                    task_type="token-classification",
                    dataset_type=dataset_id,
                    dataset_name=dataset_name,
                    metric_type=metric_type,
                    metric_value=self.eval_results_dict[f"{dataset_split}_overall_{metric_type}"],
                    task_name="Named Entity Recognition",
                    dataset_split=dataset_split,
                    dataset_revision=self.dataset_revision,
                    metric_name=metric_name,
                )
                for metric_type, metric_name in (("f1", "F1"), ("precision", "Precision"), ("recall", "Recall"))
            ]
            super_dict["model-index"] = eval_results_to_model_index(self.model_name, eval_results)

        super_dict["eval_lines"] = make_markdown_table(self.eval_lines_list)
        # Replace |:---:| with |:---| for left alignment
        super_dict["label_examples"] = make_markdown_table(self.label_example_list).replace("-:|", "--|")
        super_dict["train_set_metrics"] = make_markdown_table(self.train_set_metrics_list).replace("-:|", "--|")
        super_dict["metrics_table"] = make_markdown_table(self.metric_lines).replace("-:|", "--|")

        if self.code_carbon_callback and self.code_carbon_callback.tracker:
            emissions_data = self.code_carbon_callback.tracker._prepare_emissions_data()
            super_dict["co2_eq_emissions"] = {
                # * 1000 to convert kg to g
                "emissions": float(emissions_data.emissions) * 1000,
                "source": "codecarbon",
                "training_type": "fine-tuning",
                "on_cloud": emissions_data.on_cloud == "Y",
                "cpu_model": emissions_data.cpu_model,
                "ram_total_size": emissions_data.ram_total_size,
                # Duration is divided by 3600 to report hours
                "hours_used": round(emissions_data.duration / 3600, 3),
            }
            if emissions_data.gpu_model:
                super_dict["co2_eq_emissions"]["hardware_used"] = emissions_data.gpu_model

        if self.dataset_id:
            super_dict["datasets"] = [self.dataset_id]
        if self.encoder_id:
            super_dict["base_model"] = self.encoder_id
        super_dict["model_max_length"] = self.model.tokenizer.model_max_length

        for key in IGNORED_FIELDS:
            super_dict.pop(key, None)
        return {
            **self.model.config.to_dict(),
            **super_dict,
        }

    def to_yaml(self, line_break: Optional[str] = None) -> str:
        """Dump the fields listed in `YAML_FIELDS` whose value is not ``None`` as YAML.

        Args:
            line_break: Passed through to the YAML dumper.
        """
        return yaml_dump(
            {key: value for key, value in self.to_dict().items() if key in YAML_FIELDS and value is not None},
            sort_keys=False,
            line_break=line_break,
        ).strip()
def is_on_huggingface(repo_id: str, is_model: bool = True) -> bool:
    """Return whether `repo_id` can be fetched from the Hugging Face Hub.

    Args:
        repo_id: The model or dataset ID to look up.
        is_model: If True, look up a model, otherwise a dataset.

    Returns:
        True if the repository information could be fetched, False if the ID
        has more than two "/"-separated sections or if fetching failed.
    """
    # Models with more than two 'sections' certainly are not public models
    if len(repo_id.split("/")) > 2:
        return False

    try:
        if is_model:
            model_info(repo_id)
        else:
            dataset_info(repo_id)
    except Exception:
        # Fetching models can fail for many reasons: Repository not existing, no internet access, HF down, etc.
        # `Exception` rather than a bare `except`, so that KeyboardInterrupt and
        # SystemExit are not reported as "repository not found".
        return False
    return True
def generate_model_card(model: "SpanMarkerModel") -> str:
    """Render the model card of `model` and return its content.

    The card is built from the "model_card_template.md" file located next to
    this module, using ``model.model_card_data`` as the card data.
    """
    card = ModelCard.from_template(
        card_data=model.model_card_data,
        template_path=Path(__file__).parent.joinpath("model_card_template.md"),
        hf_emoji="🤗",
    )
    return card.content