Source code for span_marker.label_normalizer
import logging
from abc import ABC, abstractmethod
from typing import Any, Dict, Iterator, List, Tuple
from span_marker.configuration import SpanMarkerConfig
logger = logging.getLogger(name=__name__)

Entity = Tuple[int, int, int]
"""
A single entity, stored as a 3-tuple of integers:

* Entity label
* Word start index
* Word end index
"""
[docs]
class LabelNormalizer(ABC):
    """Class to convert NER training data into a common format used in the :class:`~span_marker.tokenizer.SpanMarkerTokenizer`.

    The common format involves 3-tuples with entity labels, word start indices and word end indices.
    Subclasses provide :meth:`ner_tags_to_entities`, which performs the per-sentence conversion.
    """

    def __init__(self, config: SpanMarkerConfig) -> None:
        super().__init__()
        self.config = config

    def ner_tags_to_entities(self, ner_tags: List[int]) -> Iterator[Entity]:
        """Convert the label IDs of a single sentence into entity 3-tuples.

        Args:
            ner_tags: The label ID of every word in one sentence.

        Returns:
            An iterator of ``(entity label, word start index, word end index)`` tuples.

        Raises:
            NotImplementedError: If the subclass does not override this method.
        """
        # Deliberately a concrete method rather than an @abstractmethod, so existing subclasses stay
        # instantiable; calling it without an override now gives a clear error instead of AttributeError.
        raise NotImplementedError(f"{type(self).__name__} must implement `ner_tags_to_entities`.")

    def __call__(self, tokens: List[List[str]], ner_tags: List[List[int]]) -> Dict[str, List[Any]]:
        """Normalize a batch of sentences into the common entity format.

        Args:
            tokens: One list of words per sentence.
            ner_tags: One list of label IDs per sentence, aligned with ``tokens``.

        Returns:
            A dict with, per sentence, ``"ner_tags"`` (the list of entity tuples produced by
            :meth:`ner_tags_to_entities`), ``"entity_count"`` (the number of those entities) and
            ``"word_count"`` (the number of words).

        Raises:
            ValueError: If ``tokens`` and ``ner_tags`` hold a different number of sentences.
        """
        # zip() would silently drop the surplus sentences and misalign the output columns.
        if len(tokens) != len(ner_tags):
            raise ValueError(
                "Expected `tokens` and `ner_tags` to contain the same number of sentences, "
                f"but got {len(tokens)} and {len(ner_tags)}."
            )
        output: Dict[str, List[Any]] = {"ner_tags": [], "entity_count": [], "word_count": []}
        # Distinct loop names avoid shadowing the batch-level parameters.
        for sentence_tokens, sentence_tags in zip(tokens, ner_tags):
            entities = list(self.ner_tags_to_entities(sentence_tags))
            output["ner_tags"].append(entities)
            output["entity_count"].append(len(entities))
            output["word_count"].append(len(sentence_tokens))
        return output
[docs]
class LabelNormalizerScheme(LabelNormalizer):
    """Shared base for normalizers whose labels follow a tagging scheme such as IOB, BIOES or BILOU.

    On construction, the label IDs are grouped by scheme tag (``"B"``, ``"I"``, ``"O"``, ...) via
    ``config.group_label_ids_by_tag()``. ``start_ids`` and ``end_ids`` start out empty and are
    populated by the scheme-specific subclasses.

    NOTE(review): no ``ner_tags_to_entities`` implementation is visible for this class in this file;
    confirm where ``start_ids`` / ``end_ids`` are consumed.
    """

    def __init__(self, config: SpanMarkerConfig) -> None:
        super().__init__(config)
        self.label_ids_by_tag = self.config.group_label_ids_by_tag()
        self.start_ids, self.end_ids = set(), set()
[docs]
class LabelNormalizerIOB(LabelNormalizerScheme):
    """Normalizer for data labeled with the IOB or IOB2 scheme."""

    def __init__(self, config: SpanMarkerConfig) -> None:
        super().__init__(config)
        logger.info("Detected the IOB or IOB2 labeling scheme.")
        by_tag = self.label_ids_by_tag
        # "B" covers IOB2, where entities open with a B- tag; "I" covers plain IOB, where they may open with I-.
        self.start_ids = by_tag["B"] | by_tag["I"]
        self.end_ids = by_tag["B"] | by_tag["O"]
[docs]
class LabelNormalizerBIOES(LabelNormalizerScheme):
    """Normalizer for data labeled with the BIOES scheme."""

    def __init__(self, config: SpanMarkerConfig) -> None:
        super().__init__(config)
        logger.info("Detected the BIOES labeling scheme.")
        by_tag = self.label_ids_by_tag
        begin_ids, single_ids = by_tag["B"], by_tag["S"]
        self.start_ids = begin_ids | single_ids
        self.end_ids = begin_ids | by_tag["O"] | single_ids
[docs]
class LabelNormalizerBILOU(LabelNormalizerScheme):
    """Normalizer for data labeled with the BILOU scheme, or with BILO (BILOU without the "U" tag)."""

    def __init__(self, config: SpanMarkerConfig) -> None:
        super().__init__(config)
        logger.info("Detected the BILOU or BILO labeling scheme.")
        by_tag = self.label_ids_by_tag
        # BILO data has no "U" tag, so fall back to an empty group of IDs.
        unit_ids = by_tag.get("U", set())
        self.start_ids = by_tag["B"] | unit_ids
        self.end_ids = by_tag["B"] | by_tag["O"] | unit_ids
[docs]
class LabelNormalizerNoScheme(LabelNormalizer):
    """Normalizer for data without a labeling scheme: every label ID denotes its own entity class.

    NOTE(review): no ``ner_tags_to_entities`` implementation is visible for this class in this file;
    confirm it is provided elsewhere.
    """

    def __init__(self, config: SpanMarkerConfig) -> None:
        super().__init__(config)
        message = "No labeling scheme detected: all label IDs belong to individual entity classes."
        logger.info(message)
[docs]
class AutoLabelNormalizer:
    """Factory class to return the correct LabelNormalizer subclass."""

    @staticmethod
    def from_config(config: SpanMarkerConfig) -> LabelNormalizer:
        """Instantiate the :class:`LabelNormalizer` subclass that matches the labeling scheme in ``config``.

        Args:
            config: The configuration whose labels determine the scheme.

        Returns:
            ``LabelNormalizerNoScheme`` if the labels are not schemed. Otherwise, based on the scheme tags:
            ``LabelNormalizerIOB`` for {B, I, O}, ``LabelNormalizerBIOES`` for {B, I, O, E, S}, and
            ``LabelNormalizerBILOU`` for {B, I, L, O, U} or {B, I, L, O}.

        Raises:
            ValueError: If the labels are schemed but their tags match none of the supported schemes.
        """
        if not config.are_labels_schemed():
            return LabelNormalizerNoScheme(config)

        tags = config.get_scheme_tags()
        if tags == set("BIO"):
            return LabelNormalizerIOB(config)
        if tags == set("BIOES"):
            return LabelNormalizerBIOES(config)
        # BILO is BILOU without the "U" tag; LabelNormalizerBILOU handles both.
        if tags in (set("BILOU"), set("BILO")):
            return LabelNormalizerBILOU(config)
        # Report what was detected so misconfigured label names are easy to diagnose.
        raise ValueError(
            f"Data labeling scheme not recognized (detected scheme tags: {sorted(tags)}). "
            "Expected either IOB, IOB2, BIOES, BILOU, BILO "
            "or no scheme (i.e. one label ID per class, no B-, I- label prefixes, etc.)"
        )