Initializing & Training with SpanMarker¶

SpanMarker is an accessible yet powerful Python module for training Named Entity Recognition models.

In this short notebook, we’ll have a look at how to initialize and train an NER model using SpanMarker. For a larger and more general tutorial on how to use SpanMarker, please have a look at the Getting Started notebook.

Setup¶

First of all, the span_marker Python module needs to be installed. If we want to use Weights and Biases for logging, we can install span_marker using the [wandb] extra.

[ ]:
%pip install span_marker
# %pip install span_marker[wandb]

Loading the dataset¶

For this example, we’ll load the commonly used CoNLL2003 dataset from the Hugging Face hub using 🤗 Datasets.

[2]:
from datasets import load_dataset

dataset_id = "conll2003"
dataset = load_dataset(dataset_id)
dataset
[2]:
DatasetDict({
    train: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],
        num_rows: 14041
    })
    validation: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],
        num_rows: 3250
    })
    test: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],
        num_rows: 3453
    })
})
[3]:
labels = dataset["train"].features["ner_tags"].feature.names
labels
[3]:
['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC']

SpanMarker accepts any dataset as long as it has tokens and ner_tags columns. The ner_tags can be annotated using the IOB, IOB2, BIOES or BILOU labeling scheme, as well as with regular unschemed labels. This CoNLL dataset uses the common IOB/IOB2 labeling scheme, with PER, ORG, LOC and MISC labels.
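
For illustration, this is what a minimal custom dataset in that format could look like. Note that this is a hypothetical sketch: the sentence and its labels are made up, and labels refers to the label list loaded above.

from datasets import Dataset

# Hypothetical two-column dataset: word-tokenized sentences plus per-token label indices
custom_dataset = Dataset.from_dict({
    "tokens": [["John", "lives", "in", "Berlin", "."]],
    "ner_tags": [[labels.index("B-PER"), 0, 0, labels.index("B-LOC"), 0]],
})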

Initializing a SpanMarkerModel¶

A SpanMarker model is initialized via SpanMarkerModel.from_pretrained. This method will be familiar to those who know 🤗 Transformers. It accepts either a path to a local model or the name of a model on the Hugging Face Hub.

Importantly, the model can either be an encoder or an already trained and saved SpanMarker model. As we haven’t trained anything yet, we will use an encoder. To learn how to load and use a saved SpanMarker model, please have a look at the Loading & Inferencing notebook.

Reasonable options for encoders include BERT and RoBERTa models, such as bert-base-cased or roberta-base.

Not all encoders work, though: they must accept position_ids as an input argument, which disqualifies DistilBERT, T5, DistilRoBERTa, ALBERT & BART.
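
If you’re unsure whether a particular encoder is compatible, one quick (unofficial) way to check is to inspect whether its forward method accepts position_ids:

import inspect

from transformers import AutoModel

encoder = AutoModel.from_pretrained("distilbert-base-uncased")
# DistilBERT's forward method does not accept position_ids, so this prints False
print("position_ids" in inspect.signature(encoder.forward).parameters)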

Additionally, keep in mind that cased models expect the inference data to be capitalized consistently with the training data: if your training data uses correct capitalization but your inference data does not, performance may suffer. In such cases, an uncased model may be more suitable. Although it may show a slightly lower F1 score on the test set, it works regardless of capitalization, which can make it more effective in real-world scenarios.

We’ll use "roberta-base" for this notebook. If you’re running this on Google Colab, be sure to set hardware accelerator to “GPU” in Runtime > Change runtime type.

[6]:
from span_marker import SpanMarkerModel, SpanMarkerModelCardData

encoder_id = "roberta-base"
model = SpanMarkerModel.from_pretrained(
    # Required arguments
    encoder_id,
    labels=labels,
    # Optional arguments
    model_max_length=256,
    entity_max_length=6,
    # To improve the generated model card
    model_card_data=SpanMarkerModelCardData(
        language=["en"],
        license="apache-2.0",
        encoder_id=encoder_id,
        dataset_id=dataset_id,
    )
)
Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

For us, these warnings are expected, as we are initializing RobertaModel for a new task.

Note that we provided SpanMarkerModel.from_pretrained with a list of our labels. This is required when training a new model using an encoder. Furthermore, we can specify some useful configuration parameters from SpanMarkerConfig, such as:

  • model_max_length: The maximum number of tokens that the model will process. If you only use short sentences for your model, reducing this number may help training and inference speeds with no loss in performance. Defaults to the encoder maximum, or 512 if the encoder doesn’t have a maximum.

  • entity_max_length: The maximum number of words that one entity can contain. Defaults to 8.

  • model_card_data: A SpanMarkerModelCardData instance where you can provide a lot of useful data about your model. This data will be automatically included in a generated model card whenever a model is saved or pushed to the Hugging Face Hub.

    • Consider adding language, license, model_id, encoder_id and dataset_id to improve the generated model card README.md file.

Training¶

At this point, our model is already ready for training! We can import TrainingArguments directly from 🤗 Transformers as well as the SpanMarker Trainer. The Trainer is a subclass of the 🤗 Transformers Trainer that simplifies some tasks for you, but otherwise it works just like the regular Trainer.

This next snippet shows some reasonable defaults. Feel free to lower the batch size if you run into out-of-memory exceptions; you can raise gradient_accumulation_steps correspondingly to keep the effective batch size unchanged.

[7]:
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="models/span-marker-roberta-base-conll03",
    learning_rate=1e-5,
    gradient_accumulation_steps=2,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=1,
    evaluation_strategy="steps",
    save_strategy="steps",
    eval_steps=500,
    push_to_hub=False,
    logging_steps=50,
    fp16=True,
    warmup_ratio=0.1,
)

Now we can create a SpanMarker Trainer in the same way that you would initialize a 🤗 Transformers Trainer. We’ll evaluate on just a subset of the validation data during training to save some time. Amazingly, this Trainer will automatically create logs using exactly the logging tools that you have installed. In other words, if you prefer logging with TensorBoard, all you have to do is install it.

[8]:
from span_marker import Trainer

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"].select(range(2000)),
)
trainer.train()
This SpanMarker model will ignore 0.097877% of all annotated entities in the train dataset. This is caused by the SpanMarkerModel maximum entity length of 6 words.
These are the frequencies of the missed entities due to maximum entity length out of 23499 total entities:
- 18 missed entities with 7 words (0.076599%)
- 2 missed entities with 8 words (0.008511%)
- 3 missed entities with 10 words (0.012767%)
{'loss': 1.1135, 'learning_rate': 2.707182320441989e-06, 'epoch': 0.03}
{'loss': 0.245, 'learning_rate': 5.469613259668509e-06, 'epoch': 0.06}
{'loss': 0.1466, 'learning_rate': 8.232044198895029e-06, 'epoch': 0.08}
{'loss': 0.1077, 'learning_rate': 9.888957433682912e-06, 'epoch': 0.11}
{'loss': 0.0839, 'learning_rate': 9.58050586057989e-06, 'epoch': 0.14}
{'loss': 0.0702, 'learning_rate': 9.272054287476866e-06, 'epoch': 0.17}
{'loss': 0.0614, 'learning_rate': 8.963602714373844e-06, 'epoch': 0.19}
{'loss': 0.0476, 'learning_rate': 8.65515114127082e-06, 'epoch': 0.22}
{'loss': 0.0446, 'learning_rate': 8.346699568167798e-06, 'epoch': 0.25}
{'loss': 0.0327, 'learning_rate': 8.038247995064774e-06, 'epoch': 0.28}
This SpanMarker model won't be able to predict 0.172563% of all annotated entities in the evaluation dataset. This is caused by the SpanMarkerModel maximum entity length of 6 words.
These are the frequencies of the missed entities due to maximum entity length out of 3477 total entities:
- 5 missed entities with 7 words (0.143802%)
- 1 missed entities with 10 words (0.028760%)
{'eval_loss': 0.02650175243616104, 'eval_overall_precision': 0.8974691758598313, 'eval_overall_recall': 0.7968885047536733, 'eval_overall_f1': 0.8441934991606898, 'eval_overall_accuracy': 0.9632217370208637, 'eval_runtime': 20.1351, 'eval_samples_per_second': 102.656, 'eval_steps_per_second': 25.676, 'epoch': 0.28}
{'loss': 0.0348, 'learning_rate': 7.729796421961752e-06, 'epoch': 0.31}
{'loss': 0.0378, 'learning_rate': 7.42134484885873e-06, 'epoch': 0.33}
{'loss': 0.0275, 'learning_rate': 7.112893275755707e-06, 'epoch': 0.36}
{'loss': 0.0242, 'learning_rate': 6.804441702652684e-06, 'epoch': 0.39}
{'loss': 0.0255, 'learning_rate': 6.495990129549661e-06, 'epoch': 0.42}
{'loss': 0.0235, 'learning_rate': 6.187538556446638e-06, 'epoch': 0.44}
{'loss': 0.0223, 'learning_rate': 5.879086983343616e-06, 'epoch': 0.47}
{'loss': 0.0183, 'learning_rate': 5.570635410240592e-06, 'epoch': 0.5}
{'loss': 0.0194, 'learning_rate': 5.26218383713757e-06, 'epoch': 0.53}
{'loss': 0.0191, 'learning_rate': 4.953732264034547e-06, 'epoch': 0.55}
This SpanMarker model won't be able to predict 0.172563% of all annotated entities in the evaluation dataset. This is caused by the SpanMarkerModel maximum entity length of 6 words.
These are the frequencies of the missed entities due to maximum entity length out of 3477 total entities:
- 5 missed entities with 7 words (0.143802%)
- 1 missed entities with 10 words (0.028760%)
{'eval_loss': 0.016905048862099648, 'eval_overall_precision': 0.9247838616714698, 'eval_overall_recall': 0.9245174301354077, 'eval_overall_f1': 0.9246506267108485, 'eval_overall_accuracy': 0.9844412097687207, 'eval_runtime': 20.2213, 'eval_samples_per_second': 102.219, 'eval_steps_per_second': 25.567, 'epoch': 0.55}
{'loss': 0.0206, 'learning_rate': 4.645280690931524e-06, 'epoch': 0.58}
{'loss': 0.0198, 'learning_rate': 4.336829117828501e-06, 'epoch': 0.61}
{'loss': 0.0184, 'learning_rate': 4.028377544725479e-06, 'epoch': 0.64}
{'loss': 0.0203, 'learning_rate': 3.7199259716224557e-06, 'epoch': 0.67}
{'loss': 0.0206, 'learning_rate': 3.4114743985194327e-06, 'epoch': 0.69}
{'loss': 0.0187, 'learning_rate': 3.1030228254164097e-06, 'epoch': 0.72}
{'loss': 0.015, 'learning_rate': 2.794571252313387e-06, 'epoch': 0.75}
{'loss': 0.0221, 'learning_rate': 2.486119679210364e-06, 'epoch': 0.78}
{'loss': 0.0189, 'learning_rate': 2.177668106107341e-06, 'epoch': 0.8}
{'loss': 0.0158, 'learning_rate': 1.8692165330043186e-06, 'epoch': 0.83}
{'eval_loss': 0.01296199019998312, 'eval_overall_precision': 0.9394202898550724, 'eval_overall_recall': 0.933736675309709, 'eval_overall_f1': 0.9365698598468429, 'eval_overall_accuracy': 0.9868348698043021, 'eval_runtime': 20.2701, 'eval_samples_per_second': 101.973, 'eval_steps_per_second': 25.506, 'epoch': 0.83}
{'loss': 0.0165, 'learning_rate': 1.5607649599012956e-06, 'epoch': 0.86}
{'loss': 0.017, 'learning_rate': 1.2523133867982728e-06, 'epoch': 0.89}
{'loss': 0.0183, 'learning_rate': 9.438618136952499e-07, 'epoch': 0.92}
{'loss': 0.0164, 'learning_rate': 6.35410240592227e-07, 'epoch': 0.94}
{'loss': 0.0162, 'learning_rate': 3.2695866748920424e-07, 'epoch': 0.97}
{'loss': 0.021, 'learning_rate': 1.850709438618137e-08, 'epoch': 1.0}
{'train_runtime': 479.9392, 'train_samples_per_second': 30.033, 'train_steps_per_second': 3.755, 'train_loss': 0.06940532092560087, 'epoch': 1.0}
[8]:
TrainOutput(global_step=1802, training_loss=0.06940532092560087, metrics={'train_runtime': 479.9392, 'train_samples_per_second': 30.033, 'train_steps_per_second': 3.755, 'train_loss': 0.06940532092560087, 'epoch': 1.0})

And now the final step is to compute the model’s performance.

[9]:
metrics = trainer.evaluate()
metrics
[9]:
{'eval_loss': 0.012707239016890526,
 'eval_LOC': {'precision': 0.9642857142857143,
  'recall': 0.9503610108303249,
  'f1': 0.9572727272727273,
  'number': 1108},
 'eval_MISC': {'precision': 0.8805309734513275,
  'recall': 0.8378947368421052,
  'f1': 0.8586839266450916,
  'number': 475},
 'eval_ORG': {'precision': 0.8736842105263158,
  'recall': 0.9021739130434783,
  'f1': 0.8877005347593583,
  'number': 736},
 'eval_PER': {'precision': 0.9776247848537005,
  'recall': 0.9861111111111112,
  'f1': 0.9818496110630942,
  'number': 1152},
 'eval_overall_precision': 0.9379688401615696,
 'eval_overall_recall': 0.9366176894266782,
 'eval_overall_f1': 0.9372927778578637,
 'eval_overall_accuracy': 0.9872553776483908,
 'eval_runtime': 19.9052,
 'eval_samples_per_second': 103.842,
 'eval_steps_per_second': 25.973,
 'epoch': 1.0}

Additionally, we should evaluate using the test set.

[10]:
trainer.evaluate(dataset["test"], metric_key_prefix="test")
[10]:
{'test_loss': 0.029485255479812622,
 'test_LOC': {'precision': 0.9335384615384615,
  'recall': 0.9094724220623501,
  'f1': 0.9213483146067416,
  'number': 1668},
 'test_MISC': {'precision': 0.7503429355281207,
  'recall': 0.7792022792022792,
  'f1': 0.76450034940601,
  'number': 702},
 'test_ORG': {'precision': 0.8538243626062323,
  'recall': 0.9072847682119205,
  'f1': 0.87974314068885,
  'number': 1661},
 'test_PER': {'precision': 0.9658808933002482,
  'recall': 0.9628942486085343,
  'f1': 0.964385258593992,
  'number': 1617},
 'test_overall_precision': 0.8947827604257547,
 'test_overall_recall': 0.9079320113314447,
 'test_overall_f1': 0.9013094296511117,
 'test_overall_accuracy': 0.9782276300204588,
 'test_runtime': 33.9555,
 'test_samples_per_second': 104.401,
 'test_steps_per_second': 26.122,
 'epoch': 1.0}

Great performance for just 8 minutes of training! 🎉

Once trained, we can save our new model locally.

[9]:
trainer.save_model("models/span-marker-roberta-base-conll03/checkpoint-final")

Or we can push it to the 🤗 Hub like so.

[ ]:
trainer.push_to_hub(repo_id="span-marker-roberta-base-conll03")

If we want to use it again, we can just load it using the checkpoint or using the model name on the Hub. This is how it would be done using a local checkpoint. See the Loading & Inferencing notebook for more details.

[11]:
# model = SpanMarkerModel.from_pretrained("models/span-marker-roberta-base-conll03/checkpoint-final")
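
Either way, inference afterwards is a single call to SpanMarkerModel.predict. As a quick sketch, with a made-up example sentence:

# Sketch: predict the entities in a (made-up) example sentence using the trained model
entities = model.predict("Raquel Martínez flew from Madrid to New York for the United Nations summit.")
print(entities)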

That was all! As simple as that. If we put it all together into a single script, it looks something like this:

from datasets import load_dataset
from span_marker import SpanMarkerModel, SpanMarkerModelCardData, Trainer
from transformers import TrainingArguments

def main() -> None:
    dataset_id = "conll2003"
    dataset = load_dataset(dataset_id)
    labels = dataset["train"].features["ner_tags"].feature.names

    encoder_id = "roberta-base"
    model = SpanMarkerModel.from_pretrained(
        # Required arguments
        encoder_id,
        labels=labels,
        # Optional arguments
        model_max_length=256,
        entity_max_length=6,
        # To improve the generated model card
        model_card_data=SpanMarkerModelCardData(
            language=["en"],
            license="apache-2.0",
            encoder_id=encoder_id,
            dataset_id=dataset_id,
        )
    )

    args = TrainingArguments(
        output_dir="models/span-marker-roberta-base-conll03",
        learning_rate=1e-5,
        gradient_accumulation_steps=2,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        num_train_epochs=1,
        evaluation_strategy="steps",
        save_strategy="steps",
        eval_steps=500,
        push_to_hub=False,
        logging_steps=50,
        fp16=True,
        warmup_ratio=0.1,
    )

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=dataset["train"].select(range(8000)),
        eval_dataset=dataset["validation"].select(range(2000)),
    )
    trainer.train()

    metrics = trainer.evaluate()
    print(metrics)

    trainer.save_model("models/span-marker-roberta-base-conll03/checkpoint-final")

if __name__ == "__main__":
    main()

With wandb initialized, you can enjoy its very useful training graphs straight in your browser.

Furthermore, you can use wandb’s hyperparameter search functionality by following the hyperparameter search tutorial in the Hugging Face documentation. This transfers very well to the SpanMarker Trainer.
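
As a rough illustration, a wandb-backed search could look like the sketch below. This assumes that the SpanMarker Trainer accepts model_init and supports hyperparameter_search just like the 🤗 Transformers Trainer it subclasses, that wandb is installed and you are logged in, and that labels, args and dataset are defined as in the snippets above. Treat it as a starting point rather than a definitive recipe.

from span_marker import SpanMarkerModel, Trainer

def model_init() -> SpanMarkerModel:
    # Re-initialize the model from scratch for every trial
    return SpanMarkerModel.from_pretrained(
        "roberta-base",
        labels=labels,
        model_max_length=256,
        entity_max_length=6,
    )

def wandb_hp_space(trial):
    # Search space in the wandb sweep configuration format
    return {
        "method": "random",
        "metric": {"name": "eval_overall_f1", "goal": "maximize"},
        "parameters": {
            "learning_rate": {"distribution": "uniform", "min": 1e-6, "max": 5e-5},
            "per_device_train_batch_size": {"values": [4, 8, 16]},
        },
    }

trainer = Trainer(
    model_init=model_init,
    args=args,
    # Use subsets to keep each trial short
    train_dataset=dataset["train"].select(range(4000)),
    eval_dataset=dataset["validation"].select(range(1000)),
)
best_trial = trainer.hyperparameter_search(
    direction="maximize",
    backend="wandb",
    hp_space=wandb_hp_space,
    n_trials=10,
)
print(best_trial)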