Source code for span_marker.evaluation

import warnings
from typing import Dict

import evaluate
import torch
from sklearn.exceptions import UndefinedMetricWarning
from transformers import EvalPrediction

from span_marker.tokenizer import SpanMarkerTokenizer


[docs]def compute_f1_via_seqeval( tokenizer: SpanMarkerTokenizer, eval_prediction: EvalPrediction, is_in_train: bool ) -> Dict[str, float]: """Compute micro-F1, recall, precision and accuracy scores using ``seqeval`` for the evaluation predictions. Note: We assume that samples are not shuffled for the evaluation/prediction. With other words, don't use this on the (shuffled) train dataset! Args: tokenizer (SpanMarkerTokenizer): The model its tokenizer. eval_prediction (~transformers.EvalPrediction): The predictions resulting from the evaluations. Returns: Dict[str, float]: Dictionary with ``"overall_precision"``, ``"overall_recall"``, ``"overall_f1"`` and ``"overall_accuracy"`` keys. """ inputs = eval_prediction.inputs gold_labels = eval_prediction.label_ids logits = eval_prediction.predictions[0] num_words = eval_prediction.predictions[2] has_document_context = len(eval_prediction.predictions) == 5 if has_document_context: document_ids = eval_prediction.predictions[3] sentence_ids = eval_prediction.predictions[4] # Compute probabilities via softmax and extract 'winning' scores/labels probs = torch.tensor(logits, dtype=torch.float32).softmax(dim=-1) scores, pred_labels = probs.max(-1) # Collect all samples in one dict. We do this because some samples are spread between multiple inputs sample_list = [] for sample_idx in range(inputs.shape[0]): tokens = inputs[sample_idx] text = tokenizer.decode(tokens, skip_special_tokens=True) token_hash = hash(text) if not has_document_context else (document_ids[sample_idx], sentence_ids[sample_idx]) if ( not sample_list or sample_list[-1]["hash"] != token_hash or len(sample_list[-1]["spans"]) == len(sample_list[-1]["gold_labels"]) ): mask = gold_labels[sample_idx] != -100 spans = list(tokenizer.get_all_valid_spans(num_words[sample_idx], tokenizer.config.entity_max_length)) sample_list.append( { "text": text, "gold_labels": gold_labels[sample_idx][mask].tolist(), "pred_labels": pred_labels[sample_idx][mask].tolist(), "scores": scores[sample_idx].tolist(), "num_words": num_words[sample_idx], "hash": token_hash, "spans": spans, } ) else: mask = gold_labels[sample_idx] != -100 sample_list[-1]["gold_labels"] += gold_labels[sample_idx][mask].tolist() sample_list[-1]["pred_labels"] += pred_labels[sample_idx][mask].tolist() sample_list[-1]["scores"] += scores[sample_idx].tolist() outside_id = tokenizer.config.outside_id id2label = tokenizer.config.id2label # seqeval works wonders for NER evaluation seqeval = evaluate.load("seqeval") for sample in sample_list: scores = sample["scores"] num_words = sample["num_words"] spans = sample["spans"] gold_labels = sample["gold_labels"] pred_labels = sample["pred_labels"] assert len(gold_labels) == len(pred_labels) and len(spans) == len(pred_labels) # Construct IOB2 format for gold labels, useful for seqeval gold_labels_per_tokens = ["O"] * num_words for span, gold_label in zip(spans, gold_labels): if gold_label != outside_id: gold_labels_per_tokens[span[0]] = "B-" + id2label[gold_label] gold_labels_per_tokens[span[0] + 1 : span[1]] = ["I-" + id2label[gold_label]] * (span[1] - span[0] - 1) # Same for predictions, note that we place most likely spans first and we disallow overlapping spans for now. pred_labels_per_tokens = ["O"] * num_words for _, span, pred_label in sorted(zip(scores, spans, pred_labels), key=lambda tup: tup[0], reverse=True): if pred_label != outside_id and all(pred_labels_per_tokens[i] == "O" for i in range(span[0], span[1])): pred_labels_per_tokens[span[0]] = "B-" + id2label[pred_label] pred_labels_per_tokens[span[0] + 1 : span[1]] = ["I-" + id2label[pred_label]] * (span[1] - span[0] - 1) seqeval.add(prediction=pred_labels_per_tokens, reference=gold_labels_per_tokens) with warnings.catch_warnings(): warnings.simplefilter("ignore", UndefinedMetricWarning) results = seqeval.compute() if is_in_train: return {key: value for key, value in results.items() if isinstance(value, float)} return results