Getting Started with SpanMarker¶

SpanMarker is an accessible yet powerful Python module for training Named Entity Recognition models.

In this notebook, we’ll have a look at how to train an NER model using SpanMarker.

Setup¶

First of all, the span_marker Python module needs to be installed. If we want to use Weights and Biases for logging, we can install span_marker using the [wandb] extra.

[ ]:
%pip install span_marker
# %pip install span_marker[wandb]

Loading the dataset¶

For this example, we’ll load the challenging FewNERD supervised dataset from the Hugging Face hub using 🤗 Datasets.

[2]:
from datasets import load_dataset

dataset_id = "DFKI-SLT/few-nerd"
dataset = load_dataset(dataset_id, "supervised")
dataset
[2]:
DatasetDict({
    train: Dataset({
        features: ['id', 'tokens', 'ner_tags', 'fine_ner_tags'],
        num_rows: 131767
    })
    validation: Dataset({
        features: ['id', 'tokens', 'ner_tags', 'fine_ner_tags'],
        num_rows: 18824
    })
    test: Dataset({
        features: ['id', 'tokens', 'ner_tags', 'fine_ner_tags'],
        num_rows: 37648
    })
})

Let’s inspect some samples to get a feel for the data.

[3]:
for sample in dataset["train"].select(range(3)):
    print(sample)
{'id': '0', 'tokens': ['Paul', 'International', 'airport', '.'], 'ner_tags': [0, 0, 0, 0], 'fine_ner_tags': [0, 0, 0, 0]}
{'id': '1', 'tokens': ['It', 'starred', 'Hicks', "'s", 'wife', ',', 'Ellaline', 'Terriss', 'and', 'Edmund', 'Payne', '.'], 'ner_tags': [0, 0, 7, 0, 0, 0, 7, 7, 0, 7, 7, 0], 'fine_ner_tags': [0, 0, 51, 0, 0, 0, 50, 50, 0, 50, 50, 0]}
{'id': '2', 'tokens': ['``', 'Time', '``', 'magazine', 'said', 'the', 'film', 'was', '``', 'a', 'multimillion', 'dollar', 'improvisation', 'that', 'does', 'everything', 'but', 'what', 'the', 'title', 'promises', "''", 'and', 'suggested', 'that', '``', 'writer', 'George', 'Axelrod', '(', '``', 'The', 'Seven', 'Year', 'Itch', '``', ')', 'and', 'director', 'Richard', 'Quine', 'should', 'have', 'taken', 'a', 'hint', 'from', 'Holden', '[', "'s", 'character', 'Richard', 'Benson', ']', ',', 'who', 'writes', 'his', 'movie', ',', 'takes', 'a', 'long', 'sober', 'look', 'at', 'what', 'he', 'has', 'wrought', ',', 'and', 'burns', 'it', '.', "''"], 'ner_tags': [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7, 7, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 7, 7, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 7, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'fine_ner_tags': [0, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 51, 51, 0, 0, 6, 6, 6, 6, 0, 0, 0, 0, 53, 53, 0, 0, 0, 0, 0, 0, 54, 0, 0, 0, 54, 54, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}

As you can see, this dataset contains tokens, ner_tags and fine_ner_tags columns. Let’s have a look at which labels the last two represent using the Dataset features.

[4]:
labels = dataset["train"].features["ner_tags"].feature.names
print(labels)
['O', 'art', 'building', 'event', 'location', 'organization', 'other', 'person', 'product']
[5]:
fine_labels = dataset["train"].features["fine_ner_tags"].feature.names
print(fine_labels)
['O', 'art-broadcastprogram', 'art-film', 'art-music', 'art-other', 'art-painting', 'art-writtenart', 'building-airport', 'building-hospital', 'building-hotel', 'building-library', 'building-other', 'building-restaurant', 'building-sportsfacility', 'building-theater', 'event-attack/battle/war/militaryconflict', 'event-disaster', 'event-election', 'event-other', 'event-protest', 'event-sportsevent', 'location-GPE', 'location-bodiesofwater', 'location-island', 'location-mountain', 'location-other', 'location-park', 'location-road/railway/highway/transit', 'organization-company', 'organization-education', 'organization-government/governmentagency', 'organization-media/newspaper', 'organization-other', 'organization-politicalparty', 'organization-religion', 'organization-showorganization', 'organization-sportsleague', 'organization-sportsteam', 'other-astronomything', 'other-award', 'other-biologything', 'other-chemicalthing', 'other-currency', 'other-disease', 'other-educationaldegree', 'other-god', 'other-language', 'other-law', 'other-livingthing', 'other-medical', 'person-actor', 'person-artist/author', 'person-athlete', 'person-director', 'person-other', 'person-politician', 'person-scholar', 'person-soldier', 'product-airplane', 'product-car', 'product-food', 'product-game', 'product-other', 'product-ship', 'product-software', 'product-train', 'product-weapon']

For the purposes of this tutorial, let’s stick with the coarse-grained ner_tags labels, but I challenge you to modify this notebook to train on the fine-grained labels instead (one way to switch is sketched below). A SpanMarker model can be trained on any dataset as long as it has a tokens and a ner_tags column. The ner_tags may be annotated using the IOB, IOB2, BIOES or BILOU labeling scheme, or with plain unschemed labels like in this FewNERD example.
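
If you want to take on that challenge, one possible way to switch to the fine-grained labels is shown below, using standard 🤗 Datasets methods. This is a suggestion, not a cell executed in this notebook:

# Sketch: train on the fine-grained labels instead (hypothetical, not executed above)
fine_dataset = dataset.remove_columns("ner_tags").rename_column("fine_ner_tags", "ner_tags")
fine_labels = fine_dataset["train"].features["ner_tags"].feature.names
# Afterwards, pass labels=fine_labels to SpanMarkerModel.from_pretrained and train on fine_dataset.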

Initializing a SpanMarkerModel¶

A SpanMarker model is initialized via SpanMarkerModel.from_pretrained. This method will be familiar to those who know 🤗 Transformers. It accepts either a path to a local model or the name of a model on the Hugging Face Hub.

Importantly, the model can either be an encoder or an already trained and saved SpanMarker model. As we haven’t trained anything yet, we will use an encoder. To learn how to load and use a saved SpanMarker model, please have a look at the Loading & Inferencing notebook.

Reasonable options for encoders include BERT, RoBERTa, mBERT, XLM-RoBERTa, etc., which means that models such as bert-base-cased, roberta-large, bert-base-multilingual-cased and xlm-roberta-base are all good options.

Not all encoders work, though: they must accept position_ids as an input argument, which disqualifies DistilBERT, T5, DistilRoBERTa, ALBERT and BART.

Additionally, keep in mind that cased models generally expect the capitalization of your inference data to match that of your training data. In other words, if your training data is consistently capitalized correctly but your inference data is not, performance may suffer; in that case an uncased model may be more suitable. Although an uncased model may score a slightly lower F1 on the test set, it works regardless of capitalization, making it potentially more effective in real-world scenarios.

We’ll use "bert-base-cased" for this notebook. If you’re running this on Google Colab, be sure to set hardware accelerator to “GPU” in Runtime > Change runtime type.
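
Optionally, you can verify that a GPU is visible to PyTorch before training starts; a quick check:

import torch

# Prints True when a CUDA-capable GPU is available
print(torch.cuda.is_available())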

[ ]:
from span_marker import SpanMarkerModel, SpanMarkerModelCardData

encoder_id = "bert-base-cased"
model = SpanMarkerModel.from_pretrained(
    # Required arguments
    encoder_id,
    labels=labels,
    # Optional arguments
    model_max_length=256,
    entity_max_length=8,
    # To improve the generated model card
    model_card_data=SpanMarkerModelCardData(
        language=["en"],
        license="cc-by-sa-4.0",
        encoder_id=encoder_id,
        dataset_id=dataset_id,
    )
)

When loading the encoder, 🤗 Transformers may warn that some weights are newly initialized. These warnings are expected, as we are initializing BertModel for a new task.

Note that we provided SpanMarkerModel.from_pretrained with a list of our labels. This is required when training a new model using an encoder. Furthermore, we can specify some useful configuration parameters from SpanMarkerConfig, such as:

  • model_max_length: The maximum number of tokens that the model will process. If you only use short sentences for your model, reducing this number may help training and inference speeds with no loss in performance. Defaults to the encoder maximum, or 512 if the encoder doesn’t have a maximum.

  • entity_max_length: The maximum number of words that one entity can span. Defaults to 8.

  • model_card_data: A SpanMarkerModelCardData instance where you can provide a lot of useful data about your model. This data will be automatically included in a generated model card whenever a model is saved or pushed to the Hugging Face Hub.

    • Consider adding language, license, model_id, encoder_id and dataset_id to improve the generated model card README.md file.

Training¶

At this point, our model is already ready for training! We can import TrainingArguments directly from 🤗 Transformers as well as the SpanMarker Trainer. The Trainer is a subclass of the 🤗 Transformers Trainer that simplifies some tasks for you, but otherwise it works just like the regular Trainer.

This next snippet shows some reasonable defaults. Feel free to adjust the batch size to a lower value if you experience out of memory exceptions.

[7]:
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="models/span-marker-bert-base-fewnerd-coarse-super",
    learning_rate=5e-5,
    gradient_accumulation_steps=2,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=1,
    evaluation_strategy="steps",
    save_strategy="steps",
    eval_steps=200,
    push_to_hub=False,
    logging_steps=50,
    fp16=True,
    warmup_ratio=0.1,
    dataloader_num_workers=2,
)

Now we can create a SpanMarker Trainer in the same way that you would initialize a 🤗 Transformers Trainer. We’ll train on a subset of the data to save some time. Conveniently, this Trainer automatically creates logs using whichever logging tools you have installed. In other words, if you prefer logging with Tensorboard, all that you have to do is install it.

[ ]:
from span_marker import Trainer

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=dataset["train"].select(range(8000)),
    eval_dataset=dataset["validation"].select(range(2000)),
)

Let’s start training!

[9]:
trainer.train()
{'loss': 0.6974, 'learning_rate': 1.991869918699187e-05, 'epoch': 0.04}
{'loss': 0.0896, 'learning_rate': 4.0243902439024395e-05, 'epoch': 0.08}
{'loss': 0.0584, 'learning_rate': 4.8822463768115946e-05, 'epoch': 0.12}
{'loss': 0.0382, 'learning_rate': 4.655797101449276e-05, 'epoch': 0.16}
{'eval_loss': 0.03181104362010956, 'eval_overall_precision': 0.6967930029154519, 'eval_overall_recall': 0.5989974937343359, 'eval_overall_f1': 0.6442048517520216, 'eval_overall_accuracy': 0.8993717106605198, 'eval_runtime': 29.16, 'eval_samples_per_second': 83.985, 'eval_steps_per_second': 21.022, 'epoch': 0.16}
{'loss': 0.0333, 'learning_rate': 4.429347826086957e-05, 'epoch': 0.2}
{'loss': 0.0303, 'learning_rate': 4.202898550724638e-05, 'epoch': 0.24}
{'loss': 0.032, 'learning_rate': 3.976449275362319e-05, 'epoch': 0.29}
{'loss': 0.0304, 'learning_rate': 3.7500000000000003e-05, 'epoch': 0.33}
{'eval_loss': 0.02394717186689377, 'eval_overall_precision': 0.7350157728706624, 'eval_overall_recall': 0.7187198766146135, 'eval_overall_f1': 0.7267764889365436, 'eval_overall_accuracy': 0.9227489698502713, 'eval_runtime': 29.481, 'eval_samples_per_second': 83.07, 'eval_steps_per_second': 20.793, 'epoch': 0.33}
{'loss': 0.0265, 'learning_rate': 3.5235507246376816e-05, 'epoch': 0.37}
{'loss': 0.0254, 'learning_rate': 3.297101449275363e-05, 'epoch': 0.41}
{'loss': 0.0249, 'learning_rate': 3.0706521739130435e-05, 'epoch': 0.45}
{'loss': 0.0242, 'learning_rate': 2.8442028985507245e-05, 'epoch': 0.49}
{'eval_loss': 0.02163967303931713, 'eval_overall_precision': 0.762808736476832, 'eval_overall_recall': 0.7204549836128783, 'eval_overall_f1': 0.7410271663692247, 'eval_overall_accuracy': 0.9293582473175309, 'eval_runtime': 29.0261, 'eval_samples_per_second': 84.372, 'eval_steps_per_second': 21.119, 'epoch': 0.49}
{'loss': 0.0224, 'learning_rate': 2.6177536231884058e-05, 'epoch': 0.53}
{'loss': 0.0242, 'learning_rate': 2.391304347826087e-05, 'epoch': 0.57}
{'loss': 0.0226, 'learning_rate': 2.164855072463768e-05, 'epoch': 0.61}
{'loss': 0.0245, 'learning_rate': 1.9384057971014493e-05, 'epoch': 0.65}
{'eval_loss': 0.020556513220071793, 'eval_overall_precision': 0.7680876026593665, 'eval_overall_recall': 0.7572778099093889, 'eval_overall_f1': 0.7626444034559751, 'eval_overall_accuracy': 0.9338052303047611, 'eval_runtime': 29.7545, 'eval_samples_per_second': 82.307, 'eval_steps_per_second': 20.602, 'epoch': 0.65}
{'loss': 0.0231, 'learning_rate': 1.7119565217391306e-05, 'epoch': 0.69}
{'loss': 0.0209, 'learning_rate': 1.4855072463768116e-05, 'epoch': 0.73}
{'loss': 0.0202, 'learning_rate': 1.2590579710144929e-05, 'epoch': 0.77}
{'loss': 0.0212, 'learning_rate': 1.032608695652174e-05, 'epoch': 0.81}
{'eval_loss': 0.01960749179124832, 'eval_overall_precision': 0.7743021183923976, 'eval_overall_recall': 0.7540003855793329, 'eval_overall_f1': 0.7640164094549716, 'eval_overall_accuracy': 0.9358247317530904, 'eval_runtime': 29.6794, 'eval_samples_per_second': 82.515, 'eval_steps_per_second': 20.654, 'epoch': 0.81}
{'loss': 0.0202, 'learning_rate': 8.061594202898551e-06, 'epoch': 0.86}
{'loss': 0.0196, 'learning_rate': 5.797101449275362e-06, 'epoch': 0.9}
{'loss': 0.0232, 'learning_rate': 3.5326086956521736e-06, 'epoch': 0.94}
{'loss': 0.0183, 'learning_rate': 1.2681159420289857e-06, 'epoch': 0.98}
{'eval_loss': 0.019303549081087112, 'eval_overall_precision': 0.7719162141194724, 'eval_overall_recall': 0.7673028725660305, 'eval_overall_f1': 0.769602629797931, 'eval_overall_accuracy': 0.9378442332014197, 'eval_runtime': 29.1715, 'eval_samples_per_second': 83.952, 'eval_steps_per_second': 21.014, 'epoch': 0.98}
{'train_runtime': 450.609, 'train_samples_per_second': 21.788, 'train_steps_per_second': 2.723, 'train_loss': 0.056268237500824186, 'epoch': 1.0}
[9]:
TrainOutput(global_step=1227, training_loss=0.056268237500824186, metrics={'train_runtime': 450.609, 'train_samples_per_second': 21.788, 'train_steps_per_second': 2.723, 'train_loss': 0.056268237500824186, 'epoch': 1.0})

And now the final step is to compute the model’s performance.

[10]:
metrics = trainer.evaluate()
metrics
[10]:
{'eval_loss': 0.019375888630747795,
 'eval_art': {'precision': 0.7661290322580645,
  'recall': 0.7723577235772358,
  'f1': 0.7692307692307692,
  'number': 246},
 'eval_building': {'precision': 0.5842293906810035,
  'recall': 0.6127819548872181,
  'f1': 0.5981651376146789,
  'number': 266},
 'eval_event': {'precision': 0.5497382198952879,
  'recall': 0.5965909090909091,
  'f1': 0.5722070844686648,
  'number': 176},
 'eval_location': {'precision': 0.8036732108929703,
  'recall': 0.8409542743538767,
  'f1': 0.8218911917098446,
  'number': 1509},
 'eval_organization': {'precision': 0.7474226804123711,
  'recall': 0.6998069498069498,
  'f1': 0.7228315054835494,
  'number': 1036},
 'eval_other': {'precision': 0.6775818639798489,
  'recall': 0.5604166666666667,
  'f1': 0.61345496009122,
  'number': 480},
 'eval_person': {'precision': 0.8636363636363636,
  'recall': 0.9063313096270599,
  'f1': 0.8844688954718578,
  'number': 1153},
 'eval_product': {'precision': 0.7366666666666667,
  'recall': 0.6884735202492211,
  'f1': 0.7117552334943639,
  'number': 321},
 'eval_overall_precision': 0.7705836876691148,
 'eval_overall_recall': 0.7686524002313476,
 'eval_overall_f1': 0.7696168323520897,
 'eval_overall_accuracy': 0.9381502182693484,
 'eval_runtime': 28.5583,
 'eval_samples_per_second': 85.754,
 'eval_steps_per_second': 21.465,
 'epoch': 1.0}

Additionally, we should evaluate using the test set.

[11]:
trainer.evaluate(dataset["test"].select(range(2000)), metric_key_prefix="test")
This SpanMarker model won't be able to predict 0.285605% of all annotated entities in the evaluation dataset. This is caused by the SpanMarkerModel maximum entity length of 8 words.
These are the frequencies of the missed entities due to maximum entity length out of 5252 total entities:
- 5 missed entities with 9 words (0.095202%)
- 1 missed entities with 10 words (0.019040%)
- 3 missed entities with 11 words (0.057121%)
- 2 missed entities with 12 words (0.038081%)
- 1 missed entities with 13 words (0.019040%)
- 1 missed entities with 17 words (0.019040%)
- 1 missed entities with 19 words (0.019040%)
- 1 missed entities with 40 words (0.019040%)
[11]:
{'test_loss': 0.01918497122824192,
 'test_art': {'precision': 0.7419354838709677,
  'recall': 0.7488372093023256,
  'f1': 0.7453703703703703,
  'number': 215},
 'test_building': {'precision': 0.6236559139784946,
  'recall': 0.710204081632653,
  'f1': 0.6641221374045801,
  'number': 245},
 'test_event': {'precision': 0.6153846153846154,
  'recall': 0.5529953917050692,
  'f1': 0.5825242718446603,
  'number': 217},
 'test_location': {'precision': 0.812192118226601,
  'recall': 0.8515171078114913,
  'f1': 0.8313898518751971,
  'number': 1549},
 'test_organization': {'precision': 0.7320754716981132,
  'recall': 0.6897777777777778,
  'f1': 0.7102974828375286,
  'number': 1125},
 'test_other': {'precision': 0.7375886524822695,
  'recall': 0.6328600405679513,
  'f1': 0.6812227074235807,
  'number': 493},
 'test_person': {'precision': 0.8805309734513275,
  'recall': 0.9061930783242259,
  'f1': 0.8931777378815081,
  'number': 1098},
 'test_product': {'precision': 0.6641221374045801,
  'recall': 0.5898305084745763,
  'f1': 0.6247755834829445,
  'number': 295},
 'test_overall_precision': 0.7766859344894027,
 'test_overall_recall': 0.7697154859652473,
 'test_overall_f1': 0.7731850004795243,
 'test_overall_accuracy': 0.938954021816699,
 'test_runtime': 29.8808,
 'test_samples_per_second': 81.658,
 'test_steps_per_second': 20.414,
 'epoch': 1.0}
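
The warning above also suggests a useful sanity check: before settling on entity_max_length, you can inspect how long the annotated entities in your dataset actually are. A rough sketch, which treats each run of identical non-O ner_tags as one entity (an approximation for unschemed labels, since adjacent entities of the same type get merged):

from collections import Counter

def entity_lengths(ner_tags):
    # Yield the word length of each run of identical non-O tags
    run = 0
    previous = 0
    for tag in list(ner_tags) + [0]:  # trailing 0 flushes the final run
        if tag != 0 and tag == previous:
            run += 1
        else:
            if run:
                yield run
            run = 1 if tag != 0 else 0
        previous = tag

length_counts = Counter()
for sample in dataset["train"].select(range(8000)):
    length_counts.update(entity_lengths(sample["ner_tags"]))
print(sorted(length_counts.items()))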

Let’s try the model out with some predictions. For this we can use the model.predict method, which accepts either:

  • A sentence as a string.

  • A tokenized sentence as a list of strings.

  • A list of sentences as a list of strings.

  • A list of tokenized sentences as a list of lists of strings.

For each sentence, the method returns a list of dictionaries, one per predicted entity, with the following keys:

  • "label": The string label for the found entity.

  • "score": The probability score indicating the model its confidence.

  • "span": The entity span as a string.

  • "word_start_index" and "word_end_index": Integers useful for indexing the entity from a tokenized sentence.

  • "char_start_index" and "char_end_index": Integers useful for indexing the entity from a string sentence.

[12]:
sentences = [
    "The Ninth suffered a serious defeat at the Battle of Camulodunum under Quintus Petillius Cerialis in the rebellion of Boudica (61), when most of the foot-soldiers were killed in a disastrous attempt to relieve the besieged city of Camulodunum (Colchester).",
    "He was born in Wellingborough, Northamptonshire, where he attended Victoria Junior School, Westfield Boys School and Sir Christopher Hatton School.",
    "Nintendo continued to sell the revised Wii model and the Wii Mini alongside the Wii U during the Wii U's first release year.",
    "Dorsa has a Bachelor of Music in Composition from California State University, Northridge in 2001, Master of Music in Harpsichord Performance at Cal State Northridge in 2004, and a Doctor of Musical Arts at the University of Michigan, Ann Arbor in 2008."
]

entities_per_sentence = model.predict(sentences)

for entities in entities_per_sentence:
    for entity in entities:
        print(entity["span"], "=>", entity["label"])
    print()
Battle of Camulodunum => event
Quintus Petillius Cerialis => person
Boudica => location
Camulodunum => location
Colchester => location

Wellingborough => location
Northamptonshire => location
Victoria Junior School => organization
Westfield Boys School => organization
Sir Christopher Hatton School => organization

Nintendo => organization
Wii => product
Wii Mini => product
Wii U => product
Wii U => product

Dorsa => person
Bachelor of Music in Composition => other
California State University => organization
Northridge => location
Master of Music in Harpsichord Performance => other
Cal State Northridge => organization
Doctor of Musical Arts => other
University of Michigan => organization
Ann Arbor => location

Very impressive performance for less than 8 minutes of training! 🎉
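
Each returned entity dictionary also contains the score and character offsets listed earlier, so we can, for instance, print the entities of the first sentence together with the model's confidence. A small sketch using those documented keys:

# Show each entity of the first sentence with its confidence and character offsets
first_sentence = sentences[0]
for entity in entities_per_sentence[0]:
    start, end = entity["char_start_index"], entity["char_end_index"]
    print(f"{first_sentence[start:end]!r} ({entity['label']}, score={entity['score']:.2f})")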

Once trained, we can save our new model locally. The saved model also comes with a flashy README.md such as this one.

[13]:
trainer.save_model("models/span-marker-bert-base-fewnerd-coarse-super/checkpoint-final")

Or we can push it to the 🤗 Hub like so.

[ ]:
trainer.push_to_hub(repo_id="span-marker-bert-base-fewnerd-coarse-super")

If we want to use it again, we can just load it using the checkpoint or using the model name on the Hub. This is how it would be done using a local checkpoint.

[15]:
# model = SpanMarkerModel.from_pretrained("models/span-marker-bert-base-fewnerd-coarse-super/checkpoint-final")
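
Loading from the Hub works the same way, except with the repository name instead of a local path. A minimal sketch, assuming a hypothetical username:

# Hypothetical: replace "your-username" with the account the model was pushed to
# model = SpanMarkerModel.from_pretrained("your-username/span-marker-bert-base-fewnerd-coarse-super")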

That was all! As simple as that. If we put it all together into a single script, it looks something like this:

from datasets import load_dataset
from span_marker import SpanMarkerModel, SpanMarkerModelCardData, Trainer
from transformers import TrainingArguments

def main():
    dataset_id = "DFKI-SLT/few-nerd"
    dataset = load_dataset(dataset_id, "supervised")
    labels = dataset["train"].features["ner_tags"].feature.names

    encoder_id = "bert-base-cased"
    model = SpanMarkerModel.from_pretrained(
        # Required arguments
        encoder_id,
        labels=labels,
        # Optional arguments
        model_max_length=256,
        entity_max_length=8,
        # To improve the generated model card
        model_card_data=SpanMarkerModelCardData(
            language=["en"],
            license="cc-by-sa-4.0",
            encoder_id=encoder_id,
            dataset_id=dataset_id,
        )
    )

    args = TrainingArguments(
        output_dir="models/span-marker-bert-base-fewnerd-coarse-super",
        learning_rate=5e-5,
        gradient_accumulation_steps=2,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        num_train_epochs=1,
        evaluation_strategy="steps",
        save_strategy="steps",
        eval_steps=200,
        push_to_hub=False,
        logging_steps=50,
        fp16=True,
        warmup_ratio=0.1,
        dataloader_num_workers=2,
    )

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=dataset["train"].select(range(8000)),
        eval_dataset=dataset["validation"].select(range(2000)),
    )
    trainer.train()

    metrics = trainer.evaluate()
    print(metrics)

    trainer.save_model("models/span-marker-bert-base-fewnerd-coarse-super/checkpoint-final")

if __name__ == "__main__":
    main()

With wandb initialized, you can enjoy its very useful training graphs straight in your browser.
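
If wandb is installed but you want to be explicit about where the logs end up, you can set the project name and reporting target yourself. A minimal sketch, assuming a hypothetical project name:

import os
from transformers import TrainingArguments

# Hypothetical project name; the 🤗 Transformers wandb integration reads this environment variable
os.environ["WANDB_PROJECT"] = "span-marker-fewnerd"

args = TrainingArguments(
    output_dir="models/span-marker-bert-base-fewnerd-coarse-super",
    report_to="wandb",  # explicitly send training and evaluation metrics to Weights & Biases
    # ... plus the remaining arguments shown earlier
)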

Furthermore, you can use the wandb hyperparameter search functionality using the tutorial from the Hugging Face documentation here. This transfers very well to the SpanMarker Trainer.