from __future__ import annotations

from collections.abc import Iterable
from typing import Any

import torch
from torch import Tensor, nn

from sentence_transformers import util
from sentence_transformers.SentenceTransformer import SentenceTransformer


class CoSENTLoss(nn.Module):
    def __init__(self, model: SentenceTransformer, scale: float = 20.0, similarity_fct=util.pairwise_cos_sim) -> None:
        """
        This class implements CoSENT (Cosine Sentence) loss.
        It expects that each of the InputExamples consists of a pair of texts and a float valued label, representing
        the expected similarity score between the pair.

        It computes the following loss function:

        ``loss = log(1 + sum(exp(scale * (s(k,l) - s(i,j)))))``, where ``(i,j)`` and ``(k,l)`` are any of the input
        pairs in the batch such that the expected similarity of ``(i,j)`` is greater than ``(k,l)``. The summation is
        over all possible pairs of input pairs in the batch that match this condition.

        Anecdotal experiments show that this loss function produces a more powerful training signal than
        :class:`CosineSimilarityLoss`, resulting in faster convergence and a final model with superior performance.
        Consequently, CoSENTLoss may be used as a drop-in replacement for :class:`CosineSimilarityLoss` in any
        training script.

        Args:
            model: SentenceTransformerModel
            similarity_fct: Function to compute the PAIRWISE similarity
                between embeddings. Default is
                ``util.pairwise_cos_sim``.
            scale: Output of similarity function is multiplied by scale
                value. Represents the inverse temperature.

        References:
            - For further details, see: https://kexue.fm/archives/8847

        Requirements:
            - Sentence pairs with corresponding similarity scores in range of the similarity function. Default is [-1,1].

        Inputs:
            +--------------------------------+------------------------+
            | Texts                          | Labels                 |
            +================================+========================+
            | (sentence_A, sentence_B) pairs | float similarity score |
            +--------------------------------+------------------------+

        Relations:
            - :class:`AnglELoss` is CoSENTLoss with ``pairwise_angle_sim`` as the metric, rather than ``pairwise_cos_sim``.
            - :class:`CosineSimilarityLoss` seems to produce a weaker training signal than CoSENTLoss. In our experiments, CoSENTLoss is recommended.

        Example:
            ::

                from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, losses
                from datasets import Dataset

                model = SentenceTransformer("microsoft/mpnet-base")
                train_dataset = Dataset.from_dict({
                    "sentence1": ["It's nice weather outside today.", "He drove to work."],
                    "sentence2": ["It's so sunny.", "She walked to the store."],
                    "score": [1.0, 0.3],
                })
                loss = losses.CoSENTLoss(model)

                trainer = SentenceTransformerTrainer(
                    model=model,
                    train_dataset=train_dataset,
                    loss=loss,
                )
                trainer.train()
        """
        super().__init__()
        self.model = model
        self.similarity_fct = similarity_fct
        self.scale = scale

    def forward(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor) -> Tensor:
        """
        Compute the CoSENT loss for one batch.

        Args:
            sentence_features: Tokenized inputs; the first two entries are
                encoded with ``self.model`` and compared pairwise.
            labels: Gold similarity score per text pair. Presumably a 1-D
                tensor with one entry per pair in the batch (it is
                broadcast against itself below) -- confirm with the caller.

        Returns:
            A scalar tensor holding the loss.
        """
        embeddings = [self.model(sentence_feature)["sentence_embedding"] for sentence_feature in sentence_features]

        scores = self.similarity_fct(embeddings[0], embeddings[1])
        scores = scores * self.scale
        # Entry [a, b] is the scaled score of pair a minus the scaled score of pair b.
        scores = scores[:, None] - scores[None, :]

        # Entry [a, b] is 1.0 exactly when pair a has a lower gold label than pair b,
        # i.e. when a positive score difference above is a ranking violation.
        relevant = (labels[:, None] < labels[None, :]).float()

        # mask out irrelevant pairs so they are negligible after exp()
        scores = scores - (1 - relevant) * 1e12

        # Prepend a zero as e^0 = 1, giving log(1 + sum(exp(...))). The zero is created
        # directly on the scores' device and in their dtype, so no host allocation and
        # device transfer is needed and the concatenated dtypes always agree.
        zero = torch.zeros(1, device=scores.device, dtype=scores.dtype)
        scores = torch.cat((zero, scores.view(-1)), dim=0)
        loss = torch.logsumexp(scores, dim=0)

        return loss

    def get_config_dict(self) -> dict[str, Any]:
        """Return the scale and the name of the similarity function as a dict."""
        return {"scale": self.scale, "similarity_fct": self.similarity_fct.__name__}

    @property
    def citation(self) -> str:
        """BibTeX entry for the blog post that introduced CoSENT."""
        return """
@online{kexuefm-8847,
    title={CoSENT: A more efficient sentence vector scheme than Sentence-BERT},
    author={Su Jianlin},
    year={2022},
    month={Jan},
    url={https://kexue.fm/archives/8847},
}
"""
Memory