{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Getting Started with SpanMarker\n", "[SpanMarker](https://github.com/tomaarsen/SpanMarkerNER) is an accessible yet powerful Python module for training Named Entity Recognition models.\n", "\n", "In this notebook, we'll have a look at how to train an NER model using SpanMarker." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Setup\n", "First of all, the `span_marker` Python module needs to be installed. If we want to use [Weights and Biases](https://wandb.ai/) for logging, we can install `span_marker` using the `[wandb]` extra." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%pip install span_marker\n", "# %pip install span_marker[wandb]" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Loading the dataset\n", "For this example, we'll load the challenging [FewNERD supervised dataset](https://huggingface.co/datasets/DFKI-SLT/few-nerd) from the Hugging Face hub using 🤗 Datasets." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DatasetDict({\n", " train: Dataset({\n", " features: ['id', 'tokens', 'ner_tags', 'fine_ner_tags'],\n", " num_rows: 131767\n", " })\n", " validation: Dataset({\n", " features: ['id', 'tokens', 'ner_tags', 'fine_ner_tags'],\n", " num_rows: 18824\n", " })\n", " test: Dataset({\n", " features: ['id', 'tokens', 'ner_tags', 'fine_ner_tags'],\n", " num_rows: 37648\n", " })\n", "})" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from datasets import load_dataset\n", "\n", "dataset_id = \"DFKI-SLT/few-nerd\"\n", "dataset = load_dataset(dataset_id, \"supervised\")\n", "dataset" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Let's inspect some samples to get a feel for the data." 
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'id': '0', 'tokens': ['Paul', 'International', 'airport', '.'], 'ner_tags': [0, 0, 0, 0], 'fine_ner_tags': [0, 0, 0, 0]}\n", "{'id': '1', 'tokens': ['It', 'starred', 'Hicks', \"'s\", 'wife', ',', 'Ellaline', 'Terriss', 'and', 'Edmund', 'Payne', '.'], 'ner_tags': [0, 0, 7, 0, 0, 0, 7, 7, 0, 7, 7, 0], 'fine_ner_tags': [0, 0, 51, 0, 0, 0, 50, 50, 0, 50, 50, 0]}\n", "{'id': '2', 'tokens': ['``', 'Time', '``', 'magazine', 'said', 'the', 'film', 'was', '``', 'a', 'multimillion', 'dollar', 'improvisation', 'that', 'does', 'everything', 'but', 'what', 'the', 'title', 'promises', \"''\", 'and', 'suggested', 'that', '``', 'writer', 'George', 'Axelrod', '(', '``', 'The', 'Seven', 'Year', 'Itch', '``', ')', 'and', 'director', 'Richard', 'Quine', 'should', 'have', 'taken', 'a', 'hint', 'from', 'Holden', '[', \"'s\", 'character', 'Richard', 'Benson', ']', ',', 'who', 'writes', 'his', 'movie', ',', 'takes', 'a', 'long', 'sober', 'look', 'at', 'what', 'he', 'has', 'wrought', ',', 'and', 'burns', 'it', '.', \"''\"], 'ner_tags': [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7, 7, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 7, 7, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 7, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'fine_ner_tags': [0, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 51, 51, 0, 0, 6, 6, 6, 6, 0, 0, 0, 0, 53, 53, 0, 0, 0, 0, 0, 0, 54, 0, 0, 0, 54, 54, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}\n" ] } ], "source": [ "for sample in dataset[\"train\"].select(range(3)):\n", " print(sample)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "As you can see, this dataset contains `tokens`, `ner_tags` and a `fine_ner_tags` columns. Let's have a look at which labels these last two represent using the [Dataset features](https://huggingface.co/docs/datasets/about_dataset_features)." 
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['O', 'art', 'building', 'event', 'location', 'organization', 'other', 'person', 'product']\n" ] } ], "source": [ "labels = dataset[\"train\"].features[\"ner_tags\"].feature.names\n", "print(labels)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['O', 'art-broadcastprogram', 'art-film', 'art-music', 'art-other', 'art-painting', 'art-writtenart', 'building-airport', 'building-hospital', 'building-hotel', 'building-library', 'building-other', 'building-restaurant', 'building-sportsfacility', 'building-theater', 'event-attack/battle/war/militaryconflict', 'event-disaster', 'event-election', 'event-other', 'event-protest', 'event-sportsevent', 'location-GPE', 'location-bodiesofwater', 'location-island', 'location-mountain', 'location-other', 'location-park', 'location-road/railway/highway/transit', 'organization-company', 'organization-education', 'organization-government/governmentagency', 'organization-media/newspaper', 'organization-other', 'organization-politicalparty', 'organization-religion', 'organization-showorganization', 'organization-sportsleague', 'organization-sportsteam', 'other-astronomything', 'other-award', 'other-biologything', 'other-chemicalthing', 'other-currency', 'other-disease', 'other-educationaldegree', 'other-god', 'other-language', 'other-law', 'other-livingthing', 'other-medical', 'person-actor', 'person-artist/author', 'person-athlete', 'person-director', 'person-other', 'person-politician', 'person-scholar', 'person-soldier', 'product-airplane', 'product-car', 'product-food', 'product-game', 'product-other', 'product-ship', 'product-software', 'product-train', 'product-weapon']\n" ] } ], "source": [ "fine_labels = dataset[\"train\"].features[\"fine_ner_tags\"].feature.names\n", "print(fine_labels)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "For the purposes of this tutorial, let's stick with the `ner_tags` coarse-grained labels, but I challenge you to modify this Notebook to train for the fine labels. For the SpanMarker model, any dataset can be used as long as it has a `tokens` and a `ner_tags` column. The `ner_tags` can be annotated using the IOB, IOB2, BIOES or BILOU labeling scheme, but also regular unschemed labels like in this FewNERD example can be used." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Initializing a `SpanMarkerModel`\n", "A SpanMarker model is initialized via [SpanMarkerModel.from_pretrained](https://tomaarsen.github.io/SpanMarkerNER/api/span_marker.modeling.html#span_marker.modeling.SpanMarkerModel.from_pretrained). This method will be familiar to those who know 🤗 Transformers. It accepts either a path to a local model or the name of a model on the [Hugging Face Hub](https://huggingface.co/models).\n", "\n", "Importantly, the model can *either* be an encoder or an already trained and saved SpanMarker model. As we haven't trained anything yet, we will use an encoder. 
To learn how to load and use a saved SpanMarker model, please have a look at the [Loading & Inferencing](model_loading.ipynb) notebook.\n", "\n", "Reasonable options for encoders include BERT, RoBERTa, mBERT, XLM-RoBERTa, etc., which means that the following are all good choices:\n", "\n", "* [prajjwal1/bert-tiny](https://huggingface.co/prajjwal1/bert-tiny)\n", "* [prajjwal1/bert-mini](https://huggingface.co/prajjwal1/bert-mini)\n", "* [prajjwal1/bert-small](https://huggingface.co/prajjwal1/bert-small)\n", "* [prajjwal1/bert-medium](https://huggingface.co/prajjwal1/bert-medium)\n", "* [bert-base-cased](https://huggingface.co/bert-base-cased)\n", "* [bert-large-cased](https://huggingface.co/bert-large-cased)\n", "* [bert-base-multilingual-cased](https://huggingface.co/bert-base-multilingual-cased)\n", "* [bert-base-multilingual-uncased](https://huggingface.co/bert-base-multilingual-uncased)\n", "* [roberta-base](https://huggingface.co/roberta-base)\n", "* [roberta-large](https://huggingface.co/roberta-large)\n", "* [xlm-roberta-base](https://huggingface.co/xlm-roberta-base)\n", "* [xlm-roberta-large](https://huggingface.co/xlm-roberta-large)\n", "\n", "Not all encoders work, though: they **must** accept `position_ids` as an input argument, which disqualifies DistilBERT, T5, DistilRoBERTa, ALBERT & BART.\n", "\n", "Additionally, keep in mind that cased models expect the inference data to be capitalized consistently with the training data. In other words, if your training data consistently uses correct capitalization but your inference data does not, performance may suffer. In such cases, you might find an uncased model more suitable. Although it may exhibit slightly lower F1 scores on the testing set, it remains functional regardless of capitalization, making it potentially more effective in real-world scenarios.\n", "\n", "We'll use `\"bert-base-cased\"` for this notebook. If you're running this on Google Colab, be sure to set the hardware accelerator to \"GPU\" in `Runtime` > `Change runtime type`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from span_marker import SpanMarkerModel, SpanMarkerModelCardData\n", "\n", "encoder_id = \"bert-base-cased\"\n", "model = SpanMarkerModel.from_pretrained(\n", " # Required arguments\n", " encoder_id,\n", " labels=labels,\n", " # Optional arguments\n", " model_max_length=256,\n", " entity_max_length=8,\n", " # To improve the generated model card\n", " model_card_data=SpanMarkerModelCardData(\n", " language=[\"en\"],\n", " license=\"cc-by-sa-4.0\",\n", " encoder_id=encoder_id,\n", " dataset_id=dataset_id,\n", " )\n", ")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "For us, these warnings are expected, as we are initializing `BertModel` for a new task.\n", "\n", "Note that we provided `SpanMarkerModel.from_pretrained` with a list of our labels. This is required when training a new model using an encoder. Furthermore, we can specify some useful configuration parameters from `SpanMarkerConfig`, such as:\n", "\n", "* `model_max_length`: The maximum number of tokens that the model will process. If you only use short sentences for your model, reducing this number may improve training and inference speed with no loss in performance. Defaults to the encoder maximum, or 512 if the encoder doesn't have a maximum.\n", "* `entity_max_length`: The maximum number of words that one entity can span. 
Defaults to 8.\n", "* `model_card_data`: A [SpanMarkerModelCardData](https://tomaarsen.github.io/SpanMarkerNER/api/span_marker.model_card.html#span_marker.model_card.SpanMarkerModelCardData) instance where you can provide a lot of useful data about your model. This data will be automatically included in a generated model card whenever a model is saved or pushed to the Hugging Face Hub.\n", " * Consider adding `language`, `license`, `model_id`, `encoder_id` and `dataset_id` to improve the generated model card README.md file." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Training\n", "At this point, our model is ready for training! We can import [TrainingArguments](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments) directly from 🤗 Transformers as well as the SpanMarker `Trainer`. The `Trainer` is a subclass of the 🤗 Transformers [Trainer](https://huggingface.co/docs/transformers/main_classes/trainer) that simplifies some tasks for you, but otherwise it works just like the regular `Trainer`.\n", "\n", "This next snippet shows some reasonable defaults. Feel free to adjust the batch size to a lower value if you experience out-of-memory exceptions." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "from transformers import TrainingArguments\n", "\n", "args = TrainingArguments(\n", " output_dir=\"models/span-marker-bert-base-fewnerd-coarse-super\",\n", " learning_rate=5e-5,\n", " gradient_accumulation_steps=2,\n", " per_device_train_batch_size=4,\n", " per_device_eval_batch_size=4,\n", " num_train_epochs=1,\n", " evaluation_strategy=\"steps\",\n", " save_strategy=\"steps\",\n", " eval_steps=200,\n", " push_to_hub=False,\n", " logging_steps=50,\n", " fp16=True,\n", " warmup_ratio=0.1,\n", " dataloader_num_workers=2,\n", ")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Now we can create a SpanMarker [Trainer](https://tomaarsen.github.io/SpanMarkerNER/api/span_marker.trainer.html#span_marker.trainer.Trainer) in the same way that you would initialize a 🤗 Transformers `Trainer`. We'll train on a subset of the data to save some time. Amazingly, this `Trainer` will automatically create logs using exactly the logging tools that you have installed. In other words, if you prefer logging with [TensorBoard](https://www.tensorflow.org/tensorboard), all you have to do is install it." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from span_marker import Trainer\n", "\n", "trainer = Trainer(\n", " model=model,\n", " args=args,\n", " train_dataset=dataset[\"train\"].select(range(8000)),\n", " eval_dataset=dataset[\"validation\"].select(range(2000)),\n", ")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Let's start training!"
] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'loss': 0.6974, 'learning_rate': 1.991869918699187e-05, 'epoch': 0.04}\n", "{'loss': 0.0896, 'learning_rate': 4.0243902439024395e-05, 'epoch': 0.08}\n", "{'loss': 0.0584, 'learning_rate': 4.8822463768115946e-05, 'epoch': 0.12}\n", "{'loss': 0.0382, 'learning_rate': 4.655797101449276e-05, 'epoch': 0.16}\n", "{'eval_loss': 0.03181104362010956, 'eval_overall_precision': 0.6967930029154519, 'eval_overall_recall': 0.5989974937343359, 'eval_overall_f1': 0.6442048517520216, 'eval_overall_accuracy': 0.8993717106605198, 'eval_runtime': 29.16, 'eval_samples_per_second': 83.985, 'eval_steps_per_second': 21.022, 'epoch': 0.16}\n", "{'loss': 0.0333, 'learning_rate': 4.429347826086957e-05, 'epoch': 0.2}\n", "{'loss': 0.0303, 'learning_rate': 4.202898550724638e-05, 'epoch': 0.24}\n", "{'loss': 0.032, 'learning_rate': 3.976449275362319e-05, 'epoch': 0.29}\n", "{'loss': 0.0304, 'learning_rate': 3.7500000000000003e-05, 'epoch': 0.33}\n", "{'eval_loss': 0.02394717186689377, 'eval_overall_precision': 0.7350157728706624, 'eval_overall_recall': 0.7187198766146135, 'eval_overall_f1': 0.7267764889365436, 'eval_overall_accuracy': 0.9227489698502713, 'eval_runtime': 29.481, 'eval_samples_per_second': 83.07, 'eval_steps_per_second': 20.793, 'epoch': 0.33}\n", "{'loss': 0.0265, 'learning_rate': 3.5235507246376816e-05, 'epoch': 0.37}\n", "{'loss': 0.0254, 'learning_rate': 3.297101449275363e-05, 'epoch': 0.41}\n", "{'loss': 0.0249, 'learning_rate': 3.0706521739130435e-05, 'epoch': 0.45}\n", "{'loss': 0.0242, 'learning_rate': 2.8442028985507245e-05, 'epoch': 0.49}\n", "{'eval_loss': 0.02163967303931713, 'eval_overall_precision': 0.762808736476832, 'eval_overall_recall': 0.7204549836128783, 'eval_overall_f1': 0.7410271663692247, 'eval_overall_accuracy': 0.9293582473175309, 'eval_runtime': 29.0261, 'eval_samples_per_second': 84.372, 'eval_steps_per_second': 21.119, 'epoch': 0.49}\n", "{'loss': 0.0224, 'learning_rate': 2.6177536231884058e-05, 'epoch': 0.53}\n", "{'loss': 0.0242, 'learning_rate': 2.391304347826087e-05, 'epoch': 0.57}\n", "{'loss': 0.0226, 'learning_rate': 2.164855072463768e-05, 'epoch': 0.61}\n", "{'loss': 0.0245, 'learning_rate': 1.9384057971014493e-05, 'epoch': 0.65}\n", "{'eval_loss': 0.020556513220071793, 'eval_overall_precision': 0.7680876026593665, 'eval_overall_recall': 0.7572778099093889, 'eval_overall_f1': 0.7626444034559751, 'eval_overall_accuracy': 0.9338052303047611, 'eval_runtime': 29.7545, 'eval_samples_per_second': 82.307, 'eval_steps_per_second': 20.602, 'epoch': 0.65}\n", "{'loss': 0.0231, 'learning_rate': 1.7119565217391306e-05, 'epoch': 0.69}\n", "{'loss': 0.0209, 'learning_rate': 1.4855072463768116e-05, 'epoch': 0.73}\n", "{'loss': 0.0202, 'learning_rate': 1.2590579710144929e-05, 'epoch': 0.77}\n", "{'loss': 0.0212, 'learning_rate': 1.032608695652174e-05, 'epoch': 0.81}\n", "{'eval_loss': 0.01960749179124832, 'eval_overall_precision': 0.7743021183923976, 'eval_overall_recall': 0.7540003855793329, 'eval_overall_f1': 0.7640164094549716, 'eval_overall_accuracy': 0.9358247317530904, 'eval_runtime': 29.6794, 'eval_samples_per_second': 82.515, 'eval_steps_per_second': 20.654, 'epoch': 0.81}\n", "{'loss': 0.0202, 'learning_rate': 8.061594202898551e-06, 'epoch': 0.86}\n", "{'loss': 0.0196, 'learning_rate': 5.797101449275362e-06, 'epoch': 0.9}\n", "{'loss': 0.0232, 'learning_rate': 3.5326086956521736e-06, 'epoch': 0.94}\n", "{'loss': 0.0183, 'learning_rate': 
1.2681159420289857e-06, 'epoch': 0.98}\n", "{'eval_loss': 0.019303549081087112, 'eval_overall_precision': 0.7719162141194724, 'eval_overall_recall': 0.7673028725660305, 'eval_overall_f1': 0.769602629797931, 'eval_overall_accuracy': 0.9378442332014197, 'eval_runtime': 29.1715, 'eval_samples_per_second': 83.952, 'eval_steps_per_second': 21.014, 'epoch': 0.98}\n", "{'train_runtime': 450.609, 'train_samples_per_second': 21.788, 'train_steps_per_second': 2.723, 'train_loss': 0.056268237500824186, 'epoch': 1.0}\n" ] }, { "data": { "text/plain": [ "TrainOutput(global_step=1227, training_loss=0.056268237500824186, metrics={'train_runtime': 450.609, 'train_samples_per_second': 21.788, 'train_steps_per_second': 2.723, 'train_loss': 0.056268237500824186, 'epoch': 1.0})" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trainer.train()" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Now we can compute the model's performance on the validation set." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'eval_loss': 0.019375888630747795,\n", " 'eval_art': {'precision': 0.7661290322580645,\n", " 'recall': 0.7723577235772358,\n", " 'f1': 0.7692307692307692,\n", " 'number': 246},\n", " 'eval_building': {'precision': 0.5842293906810035,\n", " 'recall': 0.6127819548872181,\n", " 'f1': 0.5981651376146789,\n", " 'number': 266},\n", " 'eval_event': {'precision': 0.5497382198952879,\n", " 'recall': 0.5965909090909091,\n", " 'f1': 0.5722070844686648,\n", " 'number': 176},\n", " 'eval_location': {'precision': 0.8036732108929703,\n", " 'recall': 0.8409542743538767,\n", " 'f1': 0.8218911917098446,\n", " 'number': 1509},\n", " 'eval_organization': {'precision': 0.7474226804123711,\n", " 'recall': 0.6998069498069498,\n", " 'f1': 0.7228315054835494,\n", " 'number': 1036},\n", " 'eval_other': {'precision': 0.6775818639798489,\n", " 'recall': 0.5604166666666667,\n", " 'f1': 0.61345496009122,\n", " 'number': 480},\n", " 'eval_person': {'precision': 0.8636363636363636,\n", " 'recall': 0.9063313096270599,\n", " 'f1': 0.8844688954718578,\n", " 'number': 1153},\n", " 'eval_product': {'precision': 0.7366666666666667,\n", " 'recall': 0.6884735202492211,\n", " 'f1': 0.7117552334943639,\n", " 'number': 321},\n", " 'eval_overall_precision': 0.7705836876691148,\n", " 'eval_overall_recall': 0.7686524002313476,\n", " 'eval_overall_f1': 0.7696168323520897,\n", " 'eval_overall_accuracy': 0.9381502182693484,\n", " 'eval_runtime': 28.5583,\n", " 'eval_samples_per_second': 85.754,\n", " 'eval_steps_per_second': 21.465,\n", " 'epoch': 1.0}" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "metrics = trainer.evaluate()\n", "metrics" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Additionally, we should evaluate using the test set." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "This SpanMarker model won't be able to predict 0.285605% of all annotated entities in the evaluation dataset. 
This is caused by the SpanMarkerModel maximum entity length of 8 words.\n", "These are the frequencies of the missed entities due to maximum entity length out of 5252 total entities:\n", "- 5 missed entities with 9 words (0.095202%)\n", "- 1 missed entities with 10 words (0.019040%)\n", "- 3 missed entities with 11 words (0.057121%)\n", "- 2 missed entities with 12 words (0.038081%)\n", "- 1 missed entities with 13 words (0.019040%)\n", "- 1 missed entities with 17 words (0.019040%)\n", "- 1 missed entities with 19 words (0.019040%)\n", "- 1 missed entities with 40 words (0.019040%)\n" ] }, { "data": { "text/plain": [ "{'test_loss': 0.01918497122824192,\n", " 'test_art': {'precision': 0.7419354838709677,\n", " 'recall': 0.7488372093023256,\n", " 'f1': 0.7453703703703703,\n", " 'number': 215},\n", " 'test_building': {'precision': 0.6236559139784946,\n", " 'recall': 0.710204081632653,\n", " 'f1': 0.6641221374045801,\n", " 'number': 245},\n", " 'test_event': {'precision': 0.6153846153846154,\n", " 'recall': 0.5529953917050692,\n", " 'f1': 0.5825242718446603,\n", " 'number': 217},\n", " 'test_location': {'precision': 0.812192118226601,\n", " 'recall': 0.8515171078114913,\n", " 'f1': 0.8313898518751971,\n", " 'number': 1549},\n", " 'test_organization': {'precision': 0.7320754716981132,\n", " 'recall': 0.6897777777777778,\n", " 'f1': 0.7102974828375286,\n", " 'number': 1125},\n", " 'test_other': {'precision': 0.7375886524822695,\n", " 'recall': 0.6328600405679513,\n", " 'f1': 0.6812227074235807,\n", " 'number': 493},\n", " 'test_person': {'precision': 0.8805309734513275,\n", " 'recall': 0.9061930783242259,\n", " 'f1': 0.8931777378815081,\n", " 'number': 1098},\n", " 'test_product': {'precision': 0.6641221374045801,\n", " 'recall': 0.5898305084745763,\n", " 'f1': 0.6247755834829445,\n", " 'number': 295},\n", " 'test_overall_precision': 0.7766859344894027,\n", " 'test_overall_recall': 0.7697154859652473,\n", " 'test_overall_f1': 0.7731850004795243,\n", " 'test_overall_accuracy': 0.938954021816699,\n", " 'test_runtime': 29.8808,\n", " 'test_samples_per_second': 81.658,\n", " 'test_steps_per_second': 20.414,\n", " 'epoch': 1.0}" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trainer.evaluate(dataset[\"test\"].select(range(2000)), metric_key_prefix=\"test\")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Let's try out the model with some predictions. For this, we can use the `model.predict` method, which accepts either:\n", "\n", "* A sentence as a string.\n", "* A tokenized sentence as a list of strings.\n", "* A list of sentences as a list of strings.\n", "* A list of tokenized sentences as a list of lists of strings.\n", "\n", "For each sentence, the method returns a list of dictionaries, one per predicted entity, with the following keys:\n", "\n", "* `\"label\"`: The string label for the found entity.\n", "* `\"score\"`: The probability score indicating the model's confidence.\n", "* `\"span\"`: The entity span as a string.\n", "* `\"word_start_index\"` and `\"word_end_index\"`: Integers useful for indexing the entity from a tokenized sentence.\n", "* `\"char_start_index\"` and `\"char_end_index\"`: Integers useful for indexing the entity from a string sentence."
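, "\n", "\n", "For example, the character-level indices make it easy to recover an entity from the original sentence string. Below is a minimal sketch; the sentence is just an arbitrary example input:\n", "\n", "```python\n", "sentence = \"He was born in Wellingborough, Northamptonshire.\"\n", "\n", "# Predicting on a single string returns a single list of entity dictionaries\n", "entities = model.predict(sentence)\n", "for entity in entities:\n", "    # Slice the original string using the character-level indices\n", "    span = sentence[entity[\"char_start_index\"]:entity[\"char_end_index\"]]\n", "    print(span, \"=>\", entity[\"label\"], f\"({entity['score']:.2%})\")\n", "```"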
] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Battle of Camulodunum => event\n", "Quintus Petillius Cerialis => person\n", "Boudica => location\n", "Camulodunum => location\n", "Colchester => location\n", "\n", "Wellingborough => location\n", "Northamptonshire => location\n", "Victoria Junior School => organization\n", "Westfield Boys School => organization\n", "Sir Christopher Hatton School => organization\n", "\n", "Nintendo => organization\n", "Wii => product\n", "Wii Mini => product\n", "Wii U => product\n", "Wii U => product\n", "\n", "Dorsa => person\n", "Bachelor of Music in Composition => other\n", "California State University => organization\n", "Northridge => location\n", "Master of Music in Harpsichord Performance => other\n", "Cal State Northridge => organization\n", "Doctor of Musical Arts => other\n", "University of Michigan => organization\n", "Ann Arbor => location" ] } ], "source": [ "sentences = [\n", " \"The Ninth suffered a serious defeat at the Battle of Camulodunum under Quintus Petillius Cerialis in the rebellion of Boudica (61), when most of the foot-soldiers were killed in a disastrous attempt to relieve the besieged city of Camulodunum (Colchester).\",\n", " \"He was born in Wellingborough, Northamptonshire, where he attended Victoria Junior School, Westfield Boys School and Sir Christopher Hatton School.\",\n", " \"Nintendo continued to sell the revised Wii model and the Wii Mini alongside the Wii U during the Wii U's first release year.\",\n", " \"Dorsa has a Bachelor of Music in Composition from California State University, Northridge in 2001, Master of Music in Harpsichord Performance at Cal State Northridge in 2004, and a Doctor of Musical Arts at the University of Michigan, Ann Arbor in 2008.\"\n", "]\n", "\n", "entities_per_sentence = model.predict(sentences)\n", "\n", "for entities in entities_per_sentence:\n", " for entity in entities:\n", " print(entity[\"span\"], \"=>\", entity[\"label\"])\n", " print()" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Very impressive performance for less than 8 minutes trained! 🎉" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Once trained, we can save our new model locally. The saved model also comes with a flashy `README.md` such as [this one](https://huggingface.co/tomaarsen/span-marker-bert-base-uncased-bionlp)." ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "trainer.save_model(\"models/span-marker-bert-base-fewnerd-coarse-super/checkpoint-final\")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Or we can push it to the 🤗 Hub like so." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "trainer.push_to_hub(repo_id=\"span-marker-bert-base-fewnerd-coarse-super\")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "If we want to use it again, we can just load it using the checkpoint or using the model name on the Hub. This is how it would be done using a local checkpoint." ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "# model = SpanMarkerModel.from_pretrained(\"models/span-marker-bert-base-fewnerd-coarse-super/checkpoint-final\")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "That was all! As simple as that. 
If we put it all together into a single script, it looks something like this:\n", "```python\n", "from datasets import load_dataset\n", "from span_marker import SpanMarkerModel, SpanMarkerModelCardData, Trainer\n", "from transformers import TrainingArguments\n", "\n", "def main():\n", " dataset_id = \"DFKI-SLT/few-nerd\"\n", " dataset = load_dataset(dataset_id, \"supervised\")\n", " labels = dataset[\"train\"].features[\"ner_tags\"].feature.names\n", "\n", " encoder_id = \"bert-base-cased\"\n", " model = SpanMarkerModel.from_pretrained(\n", " # Required arguments\n", " encoder_id,\n", " labels=labels,\n", " # Optional arguments\n", " model_max_length=256,\n", " entity_max_length=8,\n", " # To improve the generated model card\n", " model_card_data=SpanMarkerModelCardData(\n", " language=[\"en\"],\n", " license=\"cc-by-sa-4.0\",\n", " encoder_id=encoder_id,\n", " dataset_id=dataset_id,\n", " )\n", " )\n", "\n", " args = TrainingArguments(\n", " output_dir=\"models/span-marker-bert-base-fewnerd-coarse-super\",\n", " learning_rate=5e-5,\n", " gradient_accumulation_steps=2,\n", " per_device_train_batch_size=4,\n", " per_device_eval_batch_size=4,\n", " num_train_epochs=1,\n", " evaluation_strategy=\"steps\",\n", " save_strategy=\"steps\",\n", " eval_steps=200,\n", " push_to_hub=False,\n", " logging_steps=50,\n", " fp16=True,\n", " warmup_ratio=0.1,\n", " dataloader_num_workers=2,\n", " )\n", "\n", " trainer = Trainer(\n", " model=model,\n", " args=args,\n", " train_dataset=dataset[\"train\"].select(range(8000)),\n", " eval_dataset=dataset[\"validation\"].select(range(2000)),\n", " )\n", " trainer.train()\n", "\n", " metrics = trainer.evaluate()\n", " print(metrics)\n", "\n", " trainer.save_model(\"models/span-marker-bert-base-fewnerd-coarse-super/checkpoint-final\")\n", "\n", "if __name__ == \"__main__\":\n", " main()\n", "```" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "With `wandb` initialized, you can view its very useful training graphs straight in your browser. They end up looking something like this.\n", "![image](https://user-images.githubusercontent.com/37621491/235196250-15d595f4-6d72-4625-bde9-f3783484997b.png)\n", "![image](https://user-images.githubusercontent.com/37621491/235196335-6f36a7fb-5274-4ce5-a1f3-1d2ad35b26a4.png)\n", "\n", "Furthermore, you can use the `wandb` hyperparameter search functionality by following the tutorial in the Hugging Face documentation [here](https://huggingface.co/docs/transformers/hpo_train). It transfers very well to the SpanMarker `Trainer`." ] } ], "metadata": { "kernelspec": { "display_name": "span-marker-ner", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.16" }, "orig_nbformat": 4, "vscode": { "interpreter": { "hash": "c231fc6d0de0df4a232423539031d78e3a72f0f8d848d7b948e520fe3bfbe8ca" } } }, "nbformat": 4, "nbformat_minor": 2 }