from __future__ import annotations
import json
import logging
import os
import shutil
from collections.abc import Iterable
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable
import numpy as np
import torch
import transformers
from packaging import version
from torch import Tensor, nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
from tqdm.autonotebook import trange
from transformers import TrainerCallback, TrainerControl, TrainerState
from sentence_transformers.datasets.NoDuplicatesDataLoader import NoDuplicatesDataLoader
from sentence_transformers.datasets.SentenceLabelDataset import SentenceLabelDataset
from sentence_transformers.training_args import (
BatchSamplers,
MultiDatasetBatchSamplers,
SentenceTransformerTrainingArguments,
)
from sentence_transformers.util import batch_to_device, fullname, is_datasets_available
from .evaluation import SentenceEvaluator
from .model_card_templates import ModelCardTemplate
if is_datasets_available():
from datasets import Dataset, DatasetDict
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from sentence_transformers.readers.InputExample import InputExample
from sentence_transformers.SentenceTransformer import SentenceTransformer
class SaveModelCallback(TrainerCallback):
"""A Callback to save the model to the `output_dir`.
There are two cases:
    1. If `save_best_model` is True and an evaluator is defined:
        We save on evaluate, but only if the new model is better than the currently saved one
        according to the evaluator.
    2. If no evaluator is defined:
        We save after the model has been trained.
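
    Example:
        A minimal sketch of how this callback can be attached to a Trainer (the output directory is illustrative;
        `trainer` and `evaluator` are assumed to exist already)::

            # `trainer` is a SentenceTransformerTrainer, `evaluator` a SentenceEvaluator (both assumed)
            trainer.add_callback(SaveModelCallback("output/my-model", evaluator, save_best_model=True))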
"""
def __init__(self, output_dir: str, evaluator: SentenceEvaluator | None, save_best_model: bool) -> None:
super().__init__()
self.output_dir = output_dir
self.evaluator = evaluator
self.save_best_model = save_best_model
self.best_metric = None
def is_better(self, new_metric: float) -> bool:
if getattr(self.evaluator, "greater_is_better", True):
return new_metric > self.best_metric
return new_metric < self.best_metric
def on_evaluate(
self,
args: SentenceTransformerTrainingArguments,
state: TrainerState,
control: TrainerControl,
metrics: dict[str, Any],
model: SentenceTransformer,
**kwargs,
) -> None:
if self.evaluator is not None and self.save_best_model:
metric_key = getattr(self.evaluator, "primary_metric", "evaluator")
for key, value in metrics.items():
if key.endswith(metric_key):
if self.best_metric is None or self.is_better(value):
self.best_metric = value
model.save(self.output_dir)
def on_train_end(
self,
args: SentenceTransformerTrainingArguments,
state: TrainerState,
control: TrainerControl,
model: SentenceTransformer,
**kwargs,
) -> None:
if self.evaluator is None:
model.save(self.output_dir)
class EvaluatorCallback(TrainerCallback):
"""The SentenceTransformers.fit method always ran the evaluator on every epoch,
in addition to every "evaluation_steps". This callback is responsible for that.
The `.trainer` must be provided after the trainer has been created.
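
    Example:
        A minimal sketch showing how `.trainer` is set after the trainer exists (the trainer arguments are
        illustrative; `model`, `train_dataset`, `loss`, and `evaluator` are assumed to exist already)::

            # `model`, `train_dataset`, `loss`, and `evaluator` are assumed to exist
            evaluator_callback = EvaluatorCallback(evaluator)
            trainer = SentenceTransformerTrainer(
                model=model, train_dataset=train_dataset, loss=loss, callbacks=[evaluator_callback]
            )
            evaluator_callback.trainer = trainer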
"""
def __init__(self, evaluator: SentenceEvaluator) -> None:
super().__init__()
self.evaluator = evaluator
self.metric_key_prefix = "eval"
self.trainer = None
def on_epoch_end(
self,
args: SentenceTransformerTrainingArguments,
state: TrainerState,
control: TrainerControl,
model: SentenceTransformer,
**kwargs,
) -> None:
evaluator_metrics = self.evaluator(model, epoch=state.epoch)
if not isinstance(evaluator_metrics, dict):
evaluator_metrics = {"evaluator": evaluator_metrics}
# Prefix all keys with metric_key_prefix + '_'
for key in list(evaluator_metrics.keys()):
if not key.startswith(f"{self.metric_key_prefix}_"):
evaluator_metrics[f"{self.metric_key_prefix}_{key}"] = evaluator_metrics.pop(key)
if self.trainer is not None:
self.trainer.callback_handler.on_evaluate(args, state, control, metrics=evaluator_metrics)
class OriginalCallback(TrainerCallback):
"""A Callback to invoke the original callback function that was provided to SentenceTransformer.fit()
This callback has the following signature: `(score: float, epoch: int, steps: int) -> None`
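
    Example:
        A minimal sketch of such a callback function (the print statement is illustrative; `model`,
        `train_dataloader`, `train_loss`, and `evaluator` are assumed to exist already)::

            def my_callback(score: float, epoch: int, steps: int) -> None:
                # Illustrative: log the evaluation score after each evaluation
                print(f"Epoch {epoch}, step {steps}: score={score:.4f}")

            # `model`, `train_dataloader`, `train_loss`, and `evaluator` are assumed to exist
            model.fit(train_objectives=[(train_dataloader, train_loss)], evaluator=evaluator, callback=my_callback)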
"""
def __init__(self, callback: Callable[[float, int, int], None], evaluator: SentenceEvaluator) -> None:
super().__init__()
self.callback = callback
self.evaluator = evaluator
def on_evaluate(
self,
args: transformers.TrainingArguments,
state: TrainerState,
control: TrainerControl,
metrics: dict[str, Any],
**kwargs,
) -> None:
metric_key = getattr(self.evaluator, "primary_metric", "evaluator")
for key, value in metrics.items():
if key.endswith(metric_key):
return self.callback(value, state.epoch, state.global_step)
class FitMixin:
"""Mixin class for injecting the `fit` method into Sentence Transformers"""
def fit(
self,
train_objectives: Iterable[tuple[DataLoader, nn.Module]],
evaluator: SentenceEvaluator = None,
epochs: int = 1,
steps_per_epoch=None,
scheduler: str = "WarmupLinear",
warmup_steps: int = 10000,
optimizer_class: type[Optimizer] = torch.optim.AdamW,
optimizer_params: dict[str, object] = {"lr": 2e-5},
weight_decay: float = 0.01,
evaluation_steps: int = 0,
output_path: str = None,
save_best_model: bool = True,
max_grad_norm: float = 1,
use_amp: bool = False,
callback: Callable[[float, int, int], None] = None,
show_progress_bar: bool = True,
checkpoint_path: str = None,
checkpoint_save_steps: int = 500,
checkpoint_save_total_limit: int = 0,
) -> None:
"""
        Deprecated training method from before Sentence Transformers v3.0; it is recommended to use
        :class:`~sentence_transformers.trainer.SentenceTransformerTrainer` instead. This method uses
        :class:`~sentence_transformers.trainer.SentenceTransformerTrainer` behind the scenes, but does
        not provide as much flexibility as the Trainer itself.

        This training approach uses a list of DataLoaders and Loss functions to train the model. Each DataLoader
        is sampled in turn for one batch. We sample only as many batches from each DataLoader as there are in the
        smallest one, to ensure equal training with each dataset, i.e. round-robin sampling.
This method should produce equivalent results in v3.0+ as before v3.0, but if you encounter any issues
with your existing training scripts, then you may wish to use
:meth:`SentenceTransformer.old_fit <sentence_transformers.SentenceTransformer.old_fit>` instead.
That uses the old training method from before v3.0.
Args:
train_objectives: Tuples of (DataLoader, LossFunction). Pass
more than one for multi-task learning
evaluator: An evaluator (sentence_transformers.evaluation)
evaluates the model performance during training on held-
                out dev data. It is used to determine the best model
                that is saved to disk.
epochs: Number of epochs for training
            steps_per_epoch: Number of training steps per epoch. If set
                to None (default), one epoch equals the number of batches
                in the smallest DataLoader from train_objectives.
scheduler: Learning rate scheduler. Available schedulers:
constantlr, warmupconstant, warmuplinear, warmupcosine,
warmupcosinewithhardrestarts
warmup_steps: Behavior depends on the scheduler. For
WarmupLinear (default), the learning rate is increased
                from 0 up to the maximal learning rate. After these many
training steps, the learning rate is decreased linearly
back to zero.
optimizer_class: Optimizer
optimizer_params: Optimizer parameters
weight_decay: Weight decay for model parameters
            evaluation_steps: If > 0, evaluate the model using the evaluator
                after every `evaluation_steps` training steps
output_path: Storage path for the model and evaluation files
save_best_model: If true, the best model (according to
evaluator) is stored at output_path
            max_grad_norm: Used for gradient clipping.
            use_amp: Use Automatic Mixed Precision (AMP). Only for
                PyTorch >= 1.6.0
callback: Callback function that is invoked after each
evaluation. It must accept the following three
parameters in this order: `score`, `epoch`, `steps`
show_progress_bar: If True, output a tqdm progress bar
checkpoint_path: Folder to save checkpoints during training
            checkpoint_save_steps: Will save a checkpoint after this many
                training steps
checkpoint_save_total_limit: Total number of checkpoints to
store
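
        Example:
            A minimal sketch (the model name, training texts, and loss below are illustrative)::

                from torch.utils.data import DataLoader

                from sentence_transformers import InputExample, SentenceTransformer, losses

                # Illustrative model, data, and loss; substitute your own
                model = SentenceTransformer("all-MiniLM-L6-v2")
                train_examples = [
                    InputExample(texts=["A person is eating.", "Someone eats food."]),
                    InputExample(texts=["A man is riding a horse.", "A man rides an animal."]),
                ]
                train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=2)
                train_loss = losses.MultipleNegativesRankingLoss(model)
                model.fit(train_objectives=[(train_dataloader, train_loss)], epochs=1, warmup_steps=10)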
"""
if not is_datasets_available():
raise ImportError("Please install `datasets` to use this function: `pip install datasets`.")
# Delayed import to counter the SentenceTransformers -> FitMixin -> SentenceTransformerTrainer -> SentenceTransformers circular import
from sentence_transformers.trainer import SentenceTransformerTrainer
data_loaders, loss_fns = zip(*train_objectives)
# Clear the dataloaders from collate functions as we just want raw InputExamples
def identity(batch):
return batch
for data_loader in data_loaders:
data_loader.collate_fn = identity
        # Default batch size and batch sampler; both may be overridden per DataLoader below
        batch_size = 8
        batch_sampler = BatchSamplers.BATCH_SAMPLER
# Convert dataloaders into a DatasetDict
# TODO: This is rather inefficient, as we load all data into memory. We might benefit from a more efficient solution
train_dataset_dict = {}
for loader_idx, data_loader in enumerate(data_loaders, start=1):
if isinstance(data_loader, NoDuplicatesDataLoader):
batch_sampler = BatchSamplers.NO_DUPLICATES
elif hasattr(data_loader, "dataset") and isinstance(data_loader.dataset, SentenceLabelDataset):
batch_sampler = BatchSamplers.GROUP_BY_LABEL
batch_size = getattr(data_loader, "batch_size", batch_size)
texts = []
labels = []
for batch in data_loader:
batch_texts, batch_labels = zip(*[(example.texts, example.label) for example in batch])
texts += batch_texts
labels += batch_labels
dataset = Dataset.from_dict({f"sentence_{idx}": text for idx, text in enumerate(zip(*texts))})
# Add label column, unless all labels are 0 (the default value for `labels` in InputExample)
add_label_column = True
try:
if set(labels) == {0}:
add_label_column = False
except TypeError:
pass
if add_label_column:
dataset = dataset.add_column("label", labels)
train_dataset_dict[f"_dataset_{loader_idx}"] = dataset
train_dataset_dict = DatasetDict(train_dataset_dict)
def _default_checkpoint_dir() -> str:
dir_name = "checkpoints/model"
idx = 1
while Path(dir_name).exists() and len(list(Path(dir_name).iterdir())) != 0:
dir_name = f"checkpoints/model_{idx}"
idx += 1
return dir_name
# Convert loss_fns into a dict with `dataset_{idx}` keys
loss_fn_dict = {f"_dataset_{idx}": loss_fn for idx, loss_fn in enumerate(loss_fns, start=1)}
# Use steps_per_epoch to perhaps set max_steps
max_steps = -1
if steps_per_epoch is not None and steps_per_epoch > 0:
if epochs == 1:
max_steps = steps_per_epoch
else:
logger.warning(
"Setting `steps_per_epoch` alongside `epochs` > 1 no longer works. "
"We will train with the full datasets per epoch."
)
steps_per_epoch = None
# Transformers renamed `evaluation_strategy` to `eval_strategy` in v4.41.0
eval_strategy_key = (
"eval_strategy"
if version.parse(transformers.__version__) >= version.parse("4.41.0")
else "evaluation_strategy"
)
args = SentenceTransformerTrainingArguments(
output_dir=checkpoint_path or _default_checkpoint_dir(),
batch_sampler=batch_sampler,
multi_dataset_batch_sampler=MultiDatasetBatchSamplers.ROUND_ROBIN,
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size,
num_train_epochs=epochs,
max_steps=max_steps,
**{
eval_strategy_key: "steps" if evaluation_steps is not None and evaluation_steps > 0 else "no",
},
eval_steps=evaluation_steps,
# load_best_model_at_end=save_best_model, # <- TODO: Look into a good solution for save_best_model
max_grad_norm=max_grad_norm,
fp16=use_amp,
disable_tqdm=not show_progress_bar,
save_strategy="steps" if checkpoint_path is not None else "no",
save_steps=checkpoint_save_steps,
save_total_limit=checkpoint_save_total_limit,
)
if steps_per_epoch is None or steps_per_epoch == 0:
steps_per_epoch = min([len(train_dataset) // batch_size for train_dataset in train_dataset_dict.values()])
num_train_steps = int(steps_per_epoch * epochs)
# Prepare optimizer & scheduler
param_optimizer = list(self.named_parameters())
no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
{
"params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
"weight_decay": weight_decay,
},
{"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
]
optimizer = optimizer_class(optimizer_grouped_parameters, **optimizer_params)
scheduler_obj = self._get_scheduler(
optimizer, scheduler=scheduler, warmup_steps=warmup_steps, t_total=num_train_steps
)
# Create callbacks
callbacks = []
if evaluator is not None:
callbacks.append(EvaluatorCallback(evaluator))
if callback is not None:
callbacks.append(OriginalCallback(callback, evaluator))
trainer = SentenceTransformerTrainer(
model=self,
args=args,
train_dataset=train_dataset_dict,
eval_dataset=None,
loss=loss_fn_dict,
evaluator=evaluator,
optimizers=(optimizer, scheduler_obj),
callbacks=callbacks,
)
# Set the trainer on the EvaluatorCallback, required for logging the metrics
        for trainer_callback in trainer.callback_handler.callbacks:
            if isinstance(trainer_callback, EvaluatorCallback):
                trainer_callback.trainer = trainer
if output_path is not None:
trainer.add_callback(SaveModelCallback(output_path, evaluator, save_best_model))
trainer.train()
@staticmethod
def _get_scheduler(optimizer, scheduler: str, warmup_steps: int, t_total: int) -> LambdaLR:
"""
        Returns the correct learning rate scheduler. Available schedulers:
- constantlr,
- warmupconstant,
- warmuplinear,
- warmupcosine,
- warmupcosinewithhardrestarts
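
        Example:
            A minimal sketch (the optimizer, warmup, and total step counts are illustrative; `model` is assumed
            to exist already)::

                # `model` is an existing SentenceTransformer (assumed)
                optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
                scheduler = FitMixin._get_scheduler(
                    optimizer, scheduler="warmuplinear", warmup_steps=100, t_total=1000
                )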
"""
scheduler = scheduler.lower()
if scheduler == "constantlr":
return transformers.get_constant_schedule(optimizer)
elif scheduler == "warmupconstant":
return transformers.get_constant_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps)
elif scheduler == "warmuplinear":
return transformers.get_linear_schedule_with_warmup(
optimizer, num_warmup_steps=warmup_steps, num_training_steps=t_total
)
elif scheduler == "warmupcosine":
return transformers.get_cosine_schedule_with_warmup(
optimizer, num_warmup_steps=warmup_steps, num_training_steps=t_total
)
elif scheduler == "warmupcosinewithhardrestarts":
return transformers.get_cosine_with_hard_restarts_schedule_with_warmup(
optimizer, num_warmup_steps=warmup_steps, num_training_steps=t_total
)
else:
raise ValueError(f"Unknown scheduler {scheduler}")
def smart_batching_collate(self, batch: list[InputExample]) -> tuple[list[dict[str, Tensor]], Tensor]:
"""
Transforms a batch from a SmartBatchingDataset to a batch of tensors for the model
Here, batch is a list of InputExample instances: [InputExample(...), ...]
Args:
batch: a batch from a SmartBatchingDataset
Returns:
a batch of tensors for the model
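
        Example:
            Typically assigned as a DataLoader's collate function, as in the pre-v3.0 training loop
            (a minimal sketch; `train_examples` and `model` are assumed to exist already)::

                # `train_examples` is a list of InputExample, `model` a SentenceTransformer (both assumed)
                train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=32)
                train_dataloader.collate_fn = model.smart_batching_collate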
"""
texts = [example.texts for example in batch]
sentence_features = [self.tokenize(sentence) for sentence in zip(*texts)]
labels = [example.label for example in batch]
        # Stack numpy labels and convert via torch.from_numpy, which avoids the slow
        # per-element conversion that torch.tensor performs on a list of numpy arrays
if labels and isinstance(labels[0], np.ndarray):
labels_tensor = torch.from_numpy(np.stack(labels))
else:
labels_tensor = torch.tensor(labels)
return sentence_features, labels_tensor
"""
Temporary methods that will be removed when this refactor is complete:
"""
def old_fit(
self,
train_objectives: Iterable[tuple[DataLoader, nn.Module]],
evaluator: SentenceEvaluator = None,
epochs: int = 1,
steps_per_epoch=None,
scheduler: str = "WarmupLinear",
warmup_steps: int = 10000,
optimizer_class: type[Optimizer] = torch.optim.AdamW,
optimizer_params: dict[str, object] = {"lr": 2e-5},
weight_decay: float = 0.01,
evaluation_steps: int = 0,
output_path: str = None,
save_best_model: bool = True,
max_grad_norm: float = 1,
use_amp: bool = False,
callback: Callable[[float, int, int], None] = None,
show_progress_bar: bool = True,
checkpoint_path: str = None,
checkpoint_save_steps: int = 500,
checkpoint_save_total_limit: int = 0,
) -> None:
"""
        Deprecated training method from before Sentence Transformers v3.0; it is recommended to use
        :class:`sentence_transformers.trainer.SentenceTransformerTrainer` instead. This method should
        only be used if you encounter issues with your existing training scripts after upgrading to v3.0+.

        This training approach uses a list of DataLoaders and Loss functions to train the model. Each DataLoader
        is sampled in turn for one batch. We sample only as many batches from each DataLoader as there are in the
        smallest one, to ensure equal training with each dataset, i.e. round-robin sampling.
Args:
train_objectives: Tuples of (DataLoader, LossFunction). Pass
more than one for multi-task learning
evaluator: An evaluator (sentence_transformers.evaluation)
evaluates the model performance during training on held-
                out dev data. It is used to determine the best model
                that is saved to disk.
epochs: Number of epochs for training
            steps_per_epoch: Number of training steps per epoch. If set
                to None (default), one epoch equals the number of batches
                in the smallest DataLoader from train_objectives.
scheduler: Learning rate scheduler. Available schedulers:
constantlr, warmupconstant, warmuplinear, warmupcosine,
warmupcosinewithhardrestarts
warmup_steps: Behavior depends on the scheduler. For
WarmupLinear (default), the learning rate is increased
                from 0 up to the maximal learning rate. After these many
training steps, the learning rate is decreased linearly
back to zero.
optimizer_class: Optimizer
optimizer_params: Optimizer parameters
weight_decay: Weight decay for model parameters
            evaluation_steps: If > 0, evaluate the model using the evaluator
                after every `evaluation_steps` training steps
output_path: Storage path for the model and evaluation files
save_best_model: If true, the best model (according to
evaluator) is stored at output_path
            max_grad_norm: Used for gradient clipping.
            use_amp: Use Automatic Mixed Precision (AMP). Only for
                PyTorch >= 1.6.0
callback: Callback function that is invoked after each
evaluation. It must accept the following three
parameters in this order: `score`, `epoch`, `steps`
show_progress_bar: If True, output a tqdm progress bar
checkpoint_path: Folder to save checkpoints during training
            checkpoint_save_steps: Will save a checkpoint after this many
                training steps
checkpoint_save_total_limit: Total number of checkpoints to
store
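
        Example:
            A minimal sketch with an evaluator (the evaluator inputs and output path are illustrative; `model`,
            `train_dataloader`, and `train_loss` are assumed to exist already)::

                from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator

                # `dev_sentences1`, `dev_sentences2`, and `dev_scores` are illustrative dev-set lists (assumed)
                evaluator = EmbeddingSimilarityEvaluator(dev_sentences1, dev_sentences2, dev_scores)
                model.old_fit(
                    train_objectives=[(train_dataloader, train_loss)],
                    evaluator=evaluator,
                    epochs=1,
                    evaluation_steps=100,
                    output_path="output/my-model",
                )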
"""
        # Add info to the model card
# info_loss_functions = "\n".join(["- {} with {} training examples".format(str(loss), len(dataloader)) for dataloader, loss in train_objectives])
info_loss_functions = []
for dataloader, loss in train_objectives:
info_loss_functions.extend(ModelCardTemplate.get_train_objective_info(dataloader, loss))
info_loss_functions = "\n\n".join([text for text in info_loss_functions])
info_fit_parameters = json.dumps(
{
"evaluator": fullname(evaluator),
"epochs": epochs,
"steps_per_epoch": steps_per_epoch,
"scheduler": scheduler,
"warmup_steps": warmup_steps,
"optimizer_class": str(optimizer_class),
"optimizer_params": optimizer_params,
"weight_decay": weight_decay,
"evaluation_steps": evaluation_steps,
"max_grad_norm": max_grad_norm,
},
indent=4,
sort_keys=True,
)
self._model_card_text = None
self._model_card_vars["{TRAINING_SECTION}"] = ModelCardTemplate.__TRAINING_SECTION__.replace(
"{LOSS_FUNCTIONS}", info_loss_functions
).replace("{FIT_PARAMETERS}", info_fit_parameters)
if use_amp:
from torch.cuda.amp import autocast
scaler = torch.cuda.amp.GradScaler()
self.to(self.device)
dataloaders = [dataloader for dataloader, _ in train_objectives]
# Use smart batching
for dataloader in dataloaders:
dataloader.collate_fn = self.smart_batching_collate
loss_models = [loss for _, loss in train_objectives]
for loss_model in loss_models:
loss_model.to(self.device)
self.best_score = -9999999
if steps_per_epoch is None or steps_per_epoch == 0:
steps_per_epoch = min([len(dataloader) for dataloader in dataloaders])
num_train_steps = int(steps_per_epoch * epochs)
# Prepare optimizers
optimizers = []
schedulers = []
for loss_model in loss_models:
param_optimizer = list(loss_model.named_parameters())
no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
{
"params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
"weight_decay": weight_decay,
},
{"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
]
optimizer = optimizer_class(optimizer_grouped_parameters, **optimizer_params)
scheduler_obj = self._get_scheduler(
optimizer, scheduler=scheduler, warmup_steps=warmup_steps, t_total=num_train_steps
)
optimizers.append(optimizer)
schedulers.append(scheduler_obj)
global_step = 0
data_iterators = [iter(dataloader) for dataloader in dataloaders]
num_train_objectives = len(train_objectives)
skip_scheduler = False
for epoch in trange(epochs, desc="Epoch", disable=not show_progress_bar):
training_steps = 0
for loss_model in loss_models:
loss_model.zero_grad()
loss_model.train()
for _ in trange(steps_per_epoch, desc="Iteration", smoothing=0.05, disable=not show_progress_bar):
for train_idx in range(num_train_objectives):
loss_model = loss_models[train_idx]
optimizer = optimizers[train_idx]
scheduler = schedulers[train_idx]
data_iterator = data_iterators[train_idx]
try:
data = next(data_iterator)
except StopIteration:
data_iterator = iter(dataloaders[train_idx])
data_iterators[train_idx] = data_iterator
data = next(data_iterator)
features, labels = data
labels = labels.to(self.device)
                    features = [batch_to_device(feature, self.device) for feature in features]
if use_amp:
with autocast():
loss_value = loss_model(features, labels)
scale_before_step = scaler.get_scale()
scaler.scale(loss_value).backward()
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(loss_model.parameters(), max_grad_norm)
scaler.step(optimizer)
scaler.update()
skip_scheduler = scaler.get_scale() != scale_before_step
else:
loss_value = loss_model(features, labels)
loss_value.backward()
torch.nn.utils.clip_grad_norm_(loss_model.parameters(), max_grad_norm)
optimizer.step()
optimizer.zero_grad()
if not skip_scheduler:
scheduler.step()
training_steps += 1
global_step += 1
if evaluation_steps > 0 and training_steps % evaluation_steps == 0:
self._eval_during_training(
evaluator, output_path, save_best_model, epoch, training_steps, callback
)
for loss_model in loss_models:
loss_model.zero_grad()
loss_model.train()
if (
checkpoint_path is not None
and checkpoint_save_steps is not None
and checkpoint_save_steps > 0
and global_step % checkpoint_save_steps == 0
):
self._save_checkpoint(checkpoint_path, checkpoint_save_total_limit, global_step)
self._eval_during_training(evaluator, output_path, save_best_model, epoch, -1, callback)
if evaluator is None and output_path is not None: # No evaluator, but output path: save final model version
self.save(output_path)
if checkpoint_path is not None:
self._save_checkpoint(checkpoint_path, checkpoint_save_total_limit, global_step)
def _eval_during_training(self, evaluator, output_path, save_best_model, epoch, steps, callback) -> None:
"""Runs evaluation during the training"""
eval_path = output_path
if output_path is not None:
os.makedirs(output_path, exist_ok=True)
eval_path = os.path.join(output_path, "eval")
os.makedirs(eval_path, exist_ok=True)
if evaluator is not None:
score = evaluator(self, output_path=eval_path, epoch=epoch, steps=steps)
if callback is not None:
callback(score, epoch, steps)
if score > self.best_score:
self.best_score = score
if save_best_model:
self.save(output_path)
def _save_checkpoint(self, checkpoint_path, checkpoint_save_total_limit, step) -> None:
# Store new checkpoint
self.save(os.path.join(checkpoint_path, str(step)))
# Delete old checkpoints
if checkpoint_save_total_limit is not None and checkpoint_save_total_limit > 0:
old_checkpoints = []
for subdir in os.listdir(checkpoint_path):
if subdir.isdigit():
old_checkpoints.append({"step": int(subdir), "path": os.path.join(checkpoint_path, subdir)})
if len(old_checkpoints) > checkpoint_save_total_limit:
old_checkpoints = sorted(old_checkpoints, key=lambda x: x["step"])
shutil.rmtree(old_checkpoints[0]["path"])