# Source code for span_marker.modeling

import logging
import os
import re
from typing import Callable, Dict, List, Optional, Type, TypeVar, Union

import torch
import torch.nn.functional as F
from datasets import Dataset, disable_progress_bar, enable_progress_bar
from packaging.version import Version, parse
from torch import device, nn
from tqdm.autonotebook import trange
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
from typing_extensions import Self

from span_marker import __version__ as span_marker_version
from span_marker.configuration import SpanMarkerConfig
from span_marker.data_collator import SpanMarkerDataCollator
from span_marker.model_card import SpanMarkerModelCardData, generate_model_card
from span_marker.output import SpanMarkerOutput
from span_marker.tokenizer import SpanMarkerTokenizer

# Module-level logger, used for the version-mismatch, CPU-inference and document-context warnings below.
logger = logging.getLogger(__name__)

# Type variable bound to SpanMarkerModel, so that `from_pretrained` is annotated as returning
# an instance of whichever (sub)class it was called on.
T = TypeVar("T", bound="SpanMarkerModel")

# Matches the message of a TypeError complaining about an unexpected keyword argument and captures
# the name of that keyword in group 1. `_load_encoder_with_kwargs` uses it to drop the offending
# keyword and retry loading the encoder.
UNEXPECTED_KEYWORD_PATTERN = re.compile(r"\S+ got an unexpected keyword argument '([^']*)'")


class SpanMarkerModel(PreTrainedModel):
    """
    This SpanMarker model allows for Named Entity Recognition (NER) using a variety of underlying encoders,
    such as BERT and RoBERTa. The model should be initialized using :meth:`~SpanMarkerModel.from_pretrained`,
    e.g. like so:

        >>> # Initialize a SpanMarkerModel using a pretrained encoder
        >>> model = SpanMarkerModel.from_pretrained("bert-base-cased", labels=["O", "B-PER", "I-PER", "B-ORG", "I-ORG", ...])
        >>> # Load a pretrained SpanMarker model
        >>> model = SpanMarkerModel.from_pretrained("tomaarsen/span-marker-bert-base-fewnerd-fine-super")

    After the model is loaded (and finetuned if it wasn't already), it can be used to predict entities:

    >>> model.predict("A prototype was fitted in the mid-'60s in a one-off DB5 extended 4'' after the doors and "
    ... "driven by Marek personally, and a normally 6-cylinder Aston Martin DB7 was equipped with a V8 unit in 1998.")
    [{'span': 'DB5', 'label': 'product-car', 'score': 0.8675689101219177, 'char_start_index': 52, 'char_end_index': 55},
     {'span': 'Marek', 'label': 'person-other', 'score': 0.9100819230079651, 'char_start_index': 99, 'char_end_index': 104},
     {'span': 'Aston Martin DB7', 'label': 'product-car', 'score': 0.9931442737579346, 'char_start_index': 143, 'char_end_index': 159}]
    """

    config_class = SpanMarkerConfig
    base_model_prefix = "encoder"
    _no_split_modules = []  # To support `load_in_8bit=True` and `device_map="auto"`

    def __init__(
        self,
        config: SpanMarkerConfig,
        encoder: Optional[PreTrainedModel] = None,
        model_card_data: Optional[SpanMarkerModelCardData] = None,
        **kwargs,
    ) -> None:
        """Initialize a SpanMarkerModel using configuration.

        Do not manually initialize a SpanMarkerModel this way! Use :meth:`~SpanMarkerModel.from_pretrained` instead.

        Args:
            config (SpanMarkerConfig): The configuration for this model.
            encoder (Optional[PreTrainedModel]): A PreTrainedModel acting as the underlying encoder.
                Defaults to None.
            model_card_data (Optional[SpanMarkerModelCardData]): Data used to generate the model card.
                A fresh :class:`SpanMarkerModelCardData` is created when omitted. Defaults to None.
            **kwargs: Accepted for call-site compatibility; not used by this initializer.
        """
        super().__init__(config)
        self.config = config
        # `encoder` will be specified if this Model is initialized via .from_pretrained with an encoder.
        # If .from_pretrained is called with a SpanMarkerModel instance, then we use the "traditional"
        # PreTrainedModel.from_pretrained, which won't include an encoder keyword argument. In that case,
        # we must create an "empty" encoder for PreTrainedModel.from_pretrained to fill with the correct
        # weights.
        if encoder is None:
            # Load the encoder via the Config to prevent having to use AutoModel.from_pretrained, which
            # could load e.g. all of `roberta-large` from the Hub unnecessarily.
            # However, use the SpanMarkerModel updated vocab_size
            encoder_config = AutoConfig.from_pretrained(self.config.encoder["_name_or_path"], **self.config.encoder)
            encoder = AutoModel.from_config(encoder_config)
        self.encoder = encoder

        # A falsy dropout rate (0 / 0.0 / None) disables dropout entirely.
        dropout_rate = self.config.get(["hidden_dropout_prob", "dropout_rate"], default=0.1)
        if dropout_rate:
            self.dropout = nn.Dropout(dropout_rate)
        else:
            self.dropout = nn.Identity()
        # TODO: Get a less arbitrary default
        hidden_size = self.config.get("hidden_size", default=768)
        # The classifier consumes the concatenation of a start marker and an end marker embedding,
        # hence twice the hidden size.
        self.classifier = nn.Linear(hidden_size * 2, self.config.num_labels)
        self.loss_func = nn.CrossEntropyLoss()

        # tokenizer and data collator are filled using set_tokenizer
        self.tokenizer = None
        self.data_collator = None

        self.model_card_data = model_card_data or SpanMarkerModelCardData()
        self.model_card_data.register_model(self)

        # Initialize weights and apply final processing
        self.post_init()

    def set_tokenizer(self, tokenizer: SpanMarkerTokenizer) -> None:
        """Attach a tokenizer to the model and build the matching data collator.

        Args:
            tokenizer (SpanMarkerTokenizer): The tokenizer to use for :meth:`predict` and
                :meth:`save_pretrained`.
        """
        self.tokenizer = tokenizer
        self.data_collator = SpanMarkerDataCollator(
            tokenizer=tokenizer, marker_max_length=self.config.marker_max_length
        )

    def _init_weights(self, module) -> None:
        """Initialize the weights of ``module``.

        Linear and Embedding weights are drawn from a normal distribution whose standard deviation is
        the configured ``initializer_range`` (0.02 if unset), LayerNorm is reset to weight 1 / bias 0,
        and Linear biases are zeroed.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            initializer_range = self.config.get("initializer_range", default=0.02)
            module.weight.data.normal_(mean=0.0, std=initializer_range)
        elif isinstance(module, torch.nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        position_ids: torch.Tensor,
        start_marker_indices: torch.Tensor,
        num_marker_pairs: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        num_words: Optional[torch.Tensor] = None,
        document_ids: Optional[torch.Tensor] = None,
        sentence_ids: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> SpanMarkerOutput:
        """Forward call of the SpanMarkerModel.

        Args:
            input_ids (~torch.Tensor): Input IDs including start/end markers.
            attention_mask (~torch.Tensor): Attention mask matrix including one-directional attention for markers.
            position_ids (~torch.Tensor): Position IDs including start/end markers.
            start_marker_indices (~torch.Tensor): The indices where the start markers begin per batch sample.
            num_marker_pairs (~torch.Tensor): The number of start/end marker pairs per batch sample.
            labels (Optional[~torch.Tensor]): The labels for each span candidate. Defaults to None.
            num_words (Optional[~torch.Tensor]): The number of words for each batch sample. Defaults to None.
            document_ids (Optional[~torch.Tensor]): The document ID of each batch sample. Defaults to None.
            sentence_ids (Optional[~torch.Tensor]): The index of each sentence in their respective document.
                Defaults to None.

        Returns:
            SpanMarkerOutput: The output dataclass.
        """
        token_type_ids = torch.zeros_like(input_ids)
        outputs = self.encoder(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
        )
        last_hidden_state = outputs[0]
        last_hidden_state = self.dropout(last_hidden_state)
        batch_size = last_hidden_state.size(0)
        sequence_length = last_hidden_state.size(1)
        # Get the indices where the end markers start
        end_marker_indices = start_marker_indices + num_marker_pairs

        # The start marker embeddings concatenated with the end marker embeddings.
        # This is kind of breaking the cardinal rule of GPU-based ML, as this is processing
        # the batch iteratively per sample, but every sample produces a different shape matrix
        # and this is the most convenient way to recombine them into a matrix.
        embeddings = []
        for i in range(batch_size):
            embeddings.append(
                torch.cat(
                    (
                        last_hidden_state[i, start_marker_indices[i] : end_marker_indices[i]],
                        last_hidden_state[i, end_marker_indices[i] : end_marker_indices[i] + num_marker_pairs[i]],
                    ),
                    dim=-1,
                )
            )
        # Pad every sample along the marker-pair dimension up to `sequence_length // 2` rows
        # so that the per-sample matrices can be stacked.
        padded_embeddings = [
            F.pad(embedding, (0, 0, 0, sequence_length // 2 - embedding.shape[0])) for embedding in embeddings
        ]
        feature_vector = torch.stack(padded_embeddings)

        # NOTE: This was wrong in the older tests
        feature_vector = self.dropout(feature_vector)
        logits = self.classifier(feature_vector)

        if labels is not None:
            loss = self.loss_func(logits.view(-1, self.config.num_labels), labels.view(-1))

        # NOTE(review): `*outputs[2:]` is passed positionally next to keyword arguments. If the encoder
        # ever returns more than two outputs, those extras would bind to the leading fields of
        # SpanMarkerOutput and presumably clash with `loss=` - confirm against its definition.
        return SpanMarkerOutput(
            loss=loss if labels is not None else None,
            logits=logits,
            *outputs[2:],
            num_marker_pairs=num_marker_pairs,
            num_words=num_words,
            document_ids=document_ids,
            sentence_ids=sentence_ids,
        )

    @classmethod
    def from_pretrained(
        cls: Type[T],
        pretrained_model_name_or_path: Union[str, os.PathLike],
        *model_args,
        labels: Optional[List[str]] = None,
        config: Optional[SpanMarkerConfig] = None,
        model_card_data: Optional[SpanMarkerModelCardData] = None,
        **kwargs,
    ) -> T:
        """Instantiate a pretrained pytorch model from a pre-trained model configuration.

        Example:
            >>> # Initialize a SpanMarkerModel using a pretrained encoder
            >>> model = SpanMarkerModel.from_pretrained("bert-base-cased", labels=["O", "B-PER", "I-PER", "B-ORG", "I-ORG", ...])
            >>> # Load a pretrained SpanMarker model
            >>> model = SpanMarkerModel.from_pretrained("tomaarsen/span-marker-bert-base-fewnerd-fine-super")

        Args:
            pretrained_model_name_or_path (Union[str, os.PathLike]):
                Either a pretrained encoder (e.g. ``bert-base-cased``, ``roberta-large``, etc.), or a pretrained
                SpanMarkerModel. Can be either:

                - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
                  Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced
                  under a user or organization name, like ``dbmdz/bert-base-german-cased``.
                - A path to a *directory* containing model weights saved using
                  :meth:`SpanMarkerModel.save_pretrained`, e.g., ``./my_model_directory/``.
                - A path or url to a *tensorflow index checkpoint file* (e.g, ``./tf_model/model.ckpt.index``).
                  In this case, ``from_tf`` should be set to ``True`` and a configuration object should be
                  provided as ``config`` argument. This loading path is slower than converting the TensorFlow
                  checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch
                  model afterwards.
                - A path or url to a model folder containing a *flax checkpoint file* in *.msgpack* format
                  (e.g, ``./flax_model/`` containing ``flax_model.msgpack``). In this case, ``from_flax``
                  should be set to ``True``.
            labels (List[str], optional): A list of string labels corresponding to the ``ner_tags`` in your
                datasets. Only necessary when loading a SpanMarker model using a pretrained encoder.
                Defaults to None.
            config (Optional[SpanMarkerConfig]): A ready-made configuration. When omitted, it is loaded from
                ``pretrained_model_name_or_path`` via :class:`~transformers.AutoConfig`. Defaults to None.
            model_card_data (Optional[SpanMarkerModelCardData]): Model card data that is forwarded to the
                model initializer. Defaults to None.

        Additional arguments are passed to :class:`~span_marker.configuration.SpanMarkerConfig` and the
        ``from_pretrained`` methods of :class:`~transformers.AutoConfig`, :class:`~transformers.AutoModel` and
        :class:`~span_marker.tokenizer.SpanMarkerTokenizer`.

        Raises:
            ValueError: If ``pretrained_model_name_or_path`` refers to a plain encoder and no ``labels``
                were provided.

        Returns:
            SpanMarkerModel: A :class:`SpanMarkerModel` instance, either ready for training using the
            :class:`Trainer` or for inference via :meth:`SpanMarkerModel.predict`.
        """
        # If loading a SpanMarkerConfig, then we don't want to override id2label and label2id
        # Create an encoder or SpanMarker config
        config: PretrainedConfig = config or AutoConfig.from_pretrained(
            pretrained_model_name_or_path, *model_args, **kwargs
        )

        # if 'pretrained_model_name_or_path' refers to a SpanMarkerModel instance, initialize it directly
        loading_span_marker = isinstance(config, cls.config_class)
        if loading_span_marker:
            model_span_marker_version = config.get("span_marker_version") or "0.1.0"
            if parse(model_span_marker_version) < Version("1.0.0.dev"):
                logger.warning(
                    f"Loading a model trained using SpanMarker v{model_span_marker_version},"
                    f" while SpanMarker v{span_marker_version} is installed. Due to large changes"
                    " introduced in v1.0.0, this is not recommended. Either retrain your model for"
                    f" v{span_marker_version}, or install `span_marker < 1.0.0`."
                )
            model = super().from_pretrained(
                pretrained_model_name_or_path, *model_args, config=config, **kwargs, model_card_data=model_card_data
            )

        # If 'pretrained_model_name_or_path' refers to an encoder (roberta, bert, distilbert, electra, etc.),
        # then initialize it and create the SpanMarker config and model using the encoder and its config.
        else:
            # Validate `labels` before touching the encoder weights: loading the encoder may download
            # and instantiate a large model, which is wasted work if we are going to raise anyway.
            if labels is None:
                raise ValueError(
                    "Please provide a `labels` list to `SpanMarkerModel.from_pretrained()`, e.g.\n"
                    ">>> SpanMarkerModel.from_pretrained(\n"
                    f'...     "{pretrained_model_name_or_path}",\n'
                    '...     labels=["O", "B-PER", "I-PER", "B-ORG", "I-ORG", ...]\n'
                    "... )\n"
                    "or\n"
                    ">>> SpanMarkerModel.from_pretrained(\n"
                    f'...     "{pretrained_model_name_or_path}",\n'
                    '...     labels=["O", "PER", "ORG", "LOC", "MISC"]\n'
                    "... )"
                )
            encoder = cls._load_encoder_with_kwargs(pretrained_model_name_or_path, config, *model_args, **kwargs)
            config.id2label = dict(enumerate(labels))
            config.label2id = {v: k for k, v in config.id2label.items()}
            # Set the span_marker version for freshly initialized models
            config = cls.config_class(
                encoder_config=config.to_dict(), span_marker_version=span_marker_version, **kwargs
            )
            model = cls(config, encoder, *model_args, **kwargs, model_card_data=model_card_data)

        # Pass the tokenizer directly to the model for convenience, this way the user doesn't have to
        # make it themselves.
        tokenizer = SpanMarkerTokenizer.from_pretrained(
            config.encoder.get("_name_or_path", pretrained_model_name_or_path), config=config, **kwargs
        )
        model.set_tokenizer(tokenizer)
        # Since transformers 4.32.0 we should use `pad_to_multiple_of=8`.
        # That'll fail for earlier versions, so we try-except it.
        try:
            model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8)
        except TypeError:
            model.resize_token_embeddings(len(tokenizer))

        return model

    @classmethod
    def _load_encoder_with_kwargs(
        cls, pretrained_model_name_or_path: str, config: PretrainedConfig, *model_args, **kwargs
    ) -> PreTrainedModel:
        """Load an underlying encoder using keyword arguments, even if those kwargs are not (all) supported.

        Whenever loading fails with a ``TypeError`` about an unexpected keyword argument that is present in
        ``kwargs``, that keyword is dropped and loading is retried.

        Args:
            pretrained_model_name_or_path (str): The model name or path, e.g. ``bert-base-cased``.
            config (PretrainedConfig): The config corresponding with the encoder.

        Raises:
            TypeError: If loading fails with a ``TypeError`` that cannot be attributed to one of ``kwargs``.

        Returns:
            PreTrainedModel: The loaded encoder.
        """
        try:
            return AutoModel.from_pretrained(
                pretrained_model_name_or_path,
                *model_args,
                config=config,
                **kwargs,
            )
        except TypeError as exc:
            # If we can find a specific keyword that is causing issues, remove it and try again
            match = UNEXPECTED_KEYWORD_PATTERN.search(str(exc))
            if match:
                problematic_keyword = match.group(1)
                if problematic_keyword in kwargs:
                    # `kwargs` is this call's own dict, so popping does not affect the caller.
                    kwargs.pop(problematic_keyword)
                    return cls._load_encoder_with_kwargs(
                        pretrained_model_name_or_path, config, *model_args, **kwargs
                    )
            # Otherwise, just re-raise the exception
            raise

    def predict(
        self,
        inputs: Union[str, List[str], List[List[str]], Dataset],
        batch_size: int = 4,
        show_progress_bar: bool = False,
    ) -> Union[List[Dict[str, Union[str, int, float]]], List[List[Dict[str, Union[str, int, float]]]]]:
        """Predict named entities from input texts.

        Example::

            >>> model = SpanMarkerModel.from_pretrained(...)
            >>> model.predict("Amelia Earhart flew her single engine Lockheed Vega 5B across the Atlantic to Paris.")
            [{'span': 'Amelia Earhart', 'label': 'person-other', 'score': 0.7629689574241638, 'char_start_index': 0, 'char_end_index': 14},
             {'span': 'Lockheed Vega 5B', 'label': 'product-airplane', 'score': 0.9833564758300781, 'char_start_index': 38, 'char_end_index': 54},
             {'span': 'Atlantic', 'label': 'location-bodiesofwater', 'score': 0.7621214389801025, 'char_start_index': 66, 'char_end_index': 74},
             {'span': 'Paris', 'label': 'location-GPE', 'score': 0.9807717204093933, 'char_start_index': 78, 'char_end_index': 83}]
            >>> model.predict(['Caesar', 'led', 'the', 'Roman', 'armies', 'in', 'the', 'Gallic', 'Wars', 'before', 'defeating', 'his', 'political', 'rival', 'Pompey', 'in', 'a', 'civil', 'war'])
            [{'span': ['Caesar'], 'label': 'person-politician', 'score': 0.683479905128479, 'word_start_index': 0, 'word_end_index': 1},
             {'span': ['Roman'], 'label': 'location-GPE', 'score': 0.7114525437355042, 'word_start_index': 3, 'word_end_index': 4},
             {'span': ['Gallic', 'Wars'], 'label': 'event-attack/battle/war/militaryconflict', 'score': 0.9015670418739319, 'word_start_index': 7, 'word_end_index': 9},
             {'span': ['Pompey'], 'label': 'person-politician', 'score': 0.9601260423660278, 'word_start_index': 14, 'word_end_index': 15}]

        Args:
            inputs (Union[str, List[str], List[List[str]], Dataset]): Input sentences from which to extract
                entities. Valid datastructures are:

                * str: a string sentence.
                * List[str]: a pre-tokenized string sentence, i.e. a list of words.
                * List[str]: a list of multiple string sentences.
                * List[List[str]]: a list of multiple pre-tokenized string sentences, i.e. a list with lists of words.
                * Dataset: A 🤗 :class:`~datasets.Dataset` with a ``tokens`` column and optionally
                  ``document_id`` and ``sentence_id`` columns. If the optional columns are provided,
                  they will be used to provide document-level context.

            batch_size (int): The number of samples to include in a batch, a higher batch size is faster,
                but requires more memory. Defaults to 4
            show_progress_bar (bool): Whether to show a progress bar, useful for longer inputs. Defaults to `False`.

        Raises:
            ValueError: If ``inputs`` is not one of the supported datastructures.

        Returns:
            Union[List[Dict[str, Union[str, int, float]]], List[List[Dict[str, Union[str, int, float]]]]]:
                If the input is a single sentence, then we output a list of dictionaries. Each dictionary
                represents one predicted entity, and contains the following keys:

                * ``label``: The predicted entity label.
                * ``span``: The text that the model deems an entity.
                * ``score``: The model its confidence.
                * ``word_start_index`` & ``word_end_index``: The word indices for the start/end of the entity,
                  if the input is pre-tokenized.
                * ``char_start_index`` & ``char_end_index``: The character indices for the start/end of the
                  entity, if the input is a string.

                If the input is multiple sentences, then we return a list containing multiple of the
                aforementioned lists.
        """
        # Imported locally rather than at module level - presumably to avoid a circular import
        # between the trainer and modeling modules; verify before moving it.
        from span_marker.trainer import Trainer

        if torch.cuda.is_available() and self.device == torch.device("cpu"):
            logger.warning(
                "SpanMarker model predictions are being computed on the CPU while CUDA is available."
                " Moving the model to CUDA using `model.cuda()` before performing predictions is heavily"
                " recommended to significantly boost prediction speeds.",
            )

        # Disable dropout, etc.
        self.eval()

        if not inputs:
            return []

        # Track whether the input was a string sentence or a list of tokens
        single_input = False
        # Check if inputs is a string, i.e. a string sentence, or
        # if it is a list of strings without spaces, i.e. if it's 1 tokenized sentence
        if isinstance(inputs, str) or (
            isinstance(inputs, list) and all(isinstance(element, str) and " " not in element for element in inputs)
        ):
            single_input = True
            dataset = Dataset.from_dict({"tokens": [inputs]})

        # Otherwise, we likely have a list of strings, i.e. a list of string sentences,
        # or a list of lists of strings, i.e. a list of tokenized sentences
        elif isinstance(inputs, list):
            dataset = Dataset.from_dict({"tokens": inputs})

        elif isinstance(inputs, Dataset):
            dataset = inputs

        else:
            raise ValueError(
                "`SpanMarkerModel.predict` could not recognize your input. It accepts the following:\n"
                "* str: a string sentence.\n"
                "* List[str]: a pre-tokenized string sentence, i.e. a list of words.\n"
                "* List[str]: a list of multiple string sentences.\n"
                "* List[List[str]]: a list of multiple pre-tokenized string sentences, i.e. a list with lists of words.\n"
                "* Dataset: A 🤗 Dataset with `tokens` column and optionally `document_id` and `sentence_id` columns.\n"
                " If the optional columns are provided, they will be used to provide document-level context."
            )

        # Only keep the columns that are used below; `id` maps every (spread) sample back to its input.
        dataset = dataset.remove_columns(set(dataset.column_names) - {"tokens", "document_id", "sentence_id"})
        num_inputs = len(dataset)
        dataset: Dataset = dataset.add_column("id", range(num_inputs))
        results = [
            {
                "tokens": tokens,
                "scores": [],
                "labels": [],
                "num_words": None,
            }
            for tokens in dataset["tokens"]
        ]

        # Tokenize & add start/end markers
        tokenizer_dict = self.tokenizer(
            {"tokens": dataset["tokens"]}, return_num_words=True, return_batch_encoding=True
        )
        batch_encoding = tokenizer_dict.pop("batch_encoding")
        dataset = dataset.remove_columns("tokens")
        for key, value in tokenizer_dict.items():
            dataset = dataset.add_column(key, value)

        # Add context if possible
        if {"document_id", "sentence_id"} <= set(dataset.column_names):
            if not self.config.trained_with_document_context:
                logger.warning(
                    "This model was trained without document-level context: "
                    "inference with document-level context may cause decreased performance."
                )
            # Add column to be able to revert sorting later
            dataset = dataset.add_column("__sort_id", range(len(dataset)))
            # Sorting by doc ID and then sentence ID is required for add_context
            dataset = dataset.sort(column_names=["document_id", "sentence_id"])
            dataset = Trainer.add_context(
                dataset,
                self.tokenizer.model_max_length,
                max_prev_context=self.config.max_prev_context,
                max_next_context=self.config.max_next_context,
                show_progress_bar=show_progress_bar,
            )
            dataset = dataset.sort(column_names=["__sort_id"])
            dataset = dataset.remove_columns("__sort_id")
        elif self.config.trained_with_document_context:
            logger.warning(
                "This model was trained with document-level context: "
                "inference without document-level context may cause decreased performance."
            )

        if not show_progress_bar:
            disable_progress_bar()
        try:
            dataset = dataset.map(
                Trainer.spread_sample,
                batched=True,
                desc="Spreading data between multiple samples",
                fn_kwargs={
                    "model_max_length": self.tokenizer.model_max_length,
                    "marker_max_length": self.config.marker_max_length,
                },
            )
        finally:
            # Re-enable the progress bars even if `map` raised, so that a failed prediction
            # does not leave them disabled for the rest of the process.
            if not show_progress_bar:
                enable_progress_bar()

        # Read the `id` column once, instead of materializing it again for every single sample.
        sample_ids = dataset["id"]
        for batch_start_idx in trange(0, len(dataset), batch_size, leave=True, disable=not show_progress_bar):
            batch = dataset.select(range(batch_start_idx, min(len(dataset), batch_start_idx + batch_size)))
            # Expanding the small tokenized output into full-scale input_ids, position_ids and attention_mask matrices.
            batch = self.data_collator(batch)
            # Moving the inputs to the right device
            batch = {key: value.to(self.device) for key, value in batch.items()}
            with torch.no_grad():
                output = self(**batch)
            # Computing probabilities based on the logits
            probs = output.logits.softmax(-1)
            # Get the labels and the corresponding probability scores
            scores, labels = probs.max(-1)
            # TODO: Iterate over output.num_marker_pairs instead with enumerate
            for iter_idx in range(output.num_marker_pairs.size(0)):
                input_id = sample_ids[batch_start_idx + iter_idx]
                num_marker_pairs = output.num_marker_pairs[iter_idx]
                # Only the first `num_marker_pairs` rows are real span candidates, the rest is padding.
                results[input_id]["scores"].extend(scores[iter_idx, :num_marker_pairs].tolist())
                results[input_id]["labels"].extend(labels[iter_idx, :num_marker_pairs].tolist())
                results[input_id]["num_words"] = output.num_words[iter_idx]

        all_entities = []
        id2label = self.config.id2label
        for sample_idx, sample in enumerate(results):
            scores = sample["scores"]
            labels = sample["labels"]
            num_words = sample["num_words"]
            sentence = sample["tokens"]
            # Get all of the valid spans to match with the score and labels
            spans = list(self.tokenizer.get_all_valid_spans(num_words, self.config.entity_max_length))

            word_selected = [False] * num_words
            sentence_entities = []
            assert len(spans) == len(scores) and len(spans) == len(labels)
            # Greedily accept spans from the highest to the lowest score, skipping any span that
            # overlaps with a word that already belongs to an accepted entity.
            for (word_start_index, word_end_index), score, label_id in sorted(
                zip(spans, scores, labels), key=lambda tup: tup[1], reverse=True
            ):
                if label_id != self.config.outside_id and not any(word_selected[word_start_index:word_end_index]):
                    char_start_index = batch_encoding.word_to_chars(sample_idx, word_start_index).start
                    char_end_index = batch_encoding.word_to_chars(sample_idx, word_end_index - 1).end
                    entity = {
                        "span": sentence[char_start_index:char_end_index]
                        if isinstance(sentence, str)
                        else sentence[word_start_index:word_end_index],
                        "label": id2label[label_id],
                        "score": score,
                    }
                    if isinstance(sentence, str):
                        entity["char_start_index"] = char_start_index
                        entity["char_end_index"] = char_end_index
                    else:
                        entity["word_start_index"] = word_start_index
                        entity["word_end_index"] = word_end_index
                    sentence_entities.append(entity)

                    word_selected[word_start_index:word_end_index] = [True] * (word_end_index - word_start_index)
            # Report the entities in reading order rather than in score order
            all_entities.append(
                sorted(
                    sentence_entities,
                    key=lambda entity: entity["char_start_index"]
                    if isinstance(sentence, str)
                    else entity["word_start_index"],
                )
            )
        # if the input was a string or a list of tokens, return a list of dictionaries
        if single_input and len(all_entities) == 1:
            return all_entities[0]
        return all_entities

    def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],
        is_main_process: bool = True,
        state_dict: Optional[dict] = None,
        save_function: Callable = torch.save,
        push_to_hub: bool = False,
        max_shard_size: Union[int, str] = "10GB",
        safe_serialization: bool = False,
        variant: Optional[str] = None,
        **kwargs,
    ) -> None:
        """Save the model, its tokenizer and a generated model card to ``save_directory``.

        All arguments are forwarded unchanged to both the parent ``save_pretrained`` and the tokenizer's
        ``save_pretrained``. Afterwards the output of :meth:`generate_model_card` is written to
        ``README.md`` inside ``save_directory``.

        Args:
            save_directory (Union[str, os.PathLike]): The directory to save to.
            is_main_process (bool): Forwarded as-is. Defaults to True.
            state_dict (Optional[dict]): Forwarded as-is. Defaults to None.
            save_function (Callable): Forwarded as-is. Defaults to ``torch.save``.
            push_to_hub (bool): Forwarded as-is. Defaults to False.
            max_shard_size (Union[int, str]): Forwarded as-is. Defaults to "10GB".
            safe_serialization (bool): Forwarded as-is. Defaults to False.
            variant (Optional[str]): Forwarded as-is. Defaults to None.
        """
        super().save_pretrained(
            save_directory,
            is_main_process=is_main_process,
            state_dict=state_dict,
            save_function=save_function,
            push_to_hub=push_to_hub,
            max_shard_size=max_shard_size,
            safe_serialization=safe_serialization,
            variant=variant,
            **kwargs,
        )
        self.tokenizer.save_pretrained(
            save_directory,
            is_main_process=is_main_process,
            state_dict=state_dict,
            save_function=save_function,
            push_to_hub=push_to_hub,
            max_shard_size=max_shard_size,
            safe_serialization=safe_serialization,
            variant=variant,
            **kwargs,
        )
        with open(os.path.join(save_directory, "README.md"), "w", encoding="utf-8") as f:
            f.write(self.generate_model_card())

    def generate_model_card(self) -> str:
        """Generate and return a model card string based on the model card data.

        Returns:
            str: The model card string.
        """
        return generate_model_card(self)

    def try_cuda(self, device: Optional[Union[int, device]] = None) -> Self:
        """Try to move all model parameters and buffers to the GPU, do nothing if failed.

        .. note::
            This method modifies the module in-place.

        Args:
            device (int, optional): if specified, all parameters will be copied to that device

        Returns:
            Module: self
        """
        try:
            return self.cuda(device)
        except (AssertionError, RuntimeError):
            return self