from __future__ import annotations import math import os from pathlib import Path import numpy as np import torch from safetensors.torch import load_file as load_safetensors_file from safetensors.torch import save_file as save_safetensors_file from tokenizers import Tokenizer from torch import nn from transformers import PreTrainedTokenizerFast from sentence_transformers.util import get_device_name class StaticEmbedding(nn.Module): def __init__( self, tokenizer: Tokenizer | PreTrainedTokenizerFast, embedding_weights: np.array | torch.Tensor | None = None, embedding_dim: int | None = None, **kwargs, ) -> None: """ Initializes the StaticEmbedding model given a tokenizer. The model is a simple embedding bag model that takes the mean of trained per-token embeddings to compute text embeddings. Args: tokenizer (Tokenizer | PreTrainedTokenizerFast): The tokenizer to be used. Must be a fast tokenizer from ``transformers`` or ``tokenizers``. embedding_weights (np.array | torch.Tensor | None, optional): Pre-trained embedding weights. Defaults to None. embedding_dim (int | None, optional): Dimension of the embeddings. Required if embedding_weights is not provided. Defaults to None. Example:: from sentence_transformers import SentenceTransformer from sentence_transformers.models import StaticEmbedding from tokenizers import Tokenizer # Pre-distilled embeddings: static_embedding = StaticEmbedding.from_model2vec("minishlab/M2V_base_output") # or distill your own embeddings: static_embedding = StaticEmbedding.from_distillation("BAAI/bge-base-en-v1.5", device="cuda") # or start with randomized embeddings: tokenizer = Tokenizer.from_pretrained("FacebookAI/xlm-roberta-base") static_embedding = StaticEmbedding(tokenizer, embedding_dim=512) model = SentenceTransformer(modules=[static_embedding]) embeddings = model.encode(["What are Pandas?", "The giant panda (Ailuropoda melanoleuca; Chinese: 大熊猫; pinyin: dàxióngmāo), also known as the panda bear or simply the panda, is a bear native to south central China."]) similarity = model.similarity(embeddings[0], embeddings[1]) # tensor([[0.9177]]) (If you use the distilled bge-base) Raises: ValueError: If the tokenizer is not a fast tokenizer. ValueError: If neither `embedding_weights` nor `embedding_dim` is provided. """ super().__init__() if isinstance(tokenizer, PreTrainedTokenizerFast): tokenizer = tokenizer._tokenizer elif not isinstance(tokenizer, Tokenizer): raise ValueError( "The tokenizer must be fast (i.e. Rust-backed) to use this class. " "Use Tokenizer.from_pretrained() from `tokenizers` to load a fast tokenizer." ) if embedding_weights is not None: if isinstance(embedding_weights, np.ndarray): embedding_weights = torch.from_numpy(embedding_weights) self.embedding = nn.EmbeddingBag.from_pretrained(embedding_weights, freeze=False) elif embedding_dim is not None: self.embedding = nn.EmbeddingBag(tokenizer.get_vocab_size(), embedding_dim) else: raise ValueError("Either `embedding_weights` or `embedding_dim` must be provided.") self.num_embeddings = self.embedding.num_embeddings self.embedding_dim = self.embedding.embedding_dim self.tokenizer: Tokenizer = tokenizer self.tokenizer.no_padding() # For the model card self.base_model = kwargs.get("base_model", None) def tokenize(self, texts: list[str], **kwargs) -> dict[str, torch.Tensor]: encodings = self.tokenizer.encode_batch(texts, add_special_tokens=False) encodings_ids = [encoding.ids for encoding in encodings] offsets = torch.from_numpy(np.cumsum([0] + [len(token_ids) for token_ids in encodings_ids[:-1]])) input_ids = torch.tensor([token_id for token_ids in encodings_ids for token_id in token_ids], dtype=torch.long) return {"input_ids": input_ids, "offsets": offsets} def forward(self, features: dict[str, torch.Tensor], **kwargs) -> dict[str, torch.Tensor]: features["sentence_embedding"] = self.embedding(features["input_ids"], features["offsets"]) return features def get_config_dict(self) -> dict[str, float]: return {} @property def max_seq_length(self) -> int: return math.inf def get_sentence_embedding_dimension(self) -> int: return self.embedding_dim def save(self, save_dir: str, safe_serialization: bool = True, **kwargs) -> None: if safe_serialization: save_safetensors_file(self.state_dict(), os.path.join(save_dir, "model.safetensors")) else: torch.save(self.state_dict(), os.path.join(save_dir, "pytorch_model.bin")) self.tokenizer.save(str(Path(save_dir) / "tokenizer.json")) def load(load_dir: str, **kwargs) -> StaticEmbedding: tokenizer = Tokenizer.from_file(str(Path(load_dir) / "tokenizer.json")) if os.path.exists(os.path.join(load_dir, "model.safetensors")): weights = load_safetensors_file(os.path.join(load_dir, "model.safetensors")) else: weights = torch.load( os.path.join(load_dir, "pytorch_model.bin"), map_location=torch.device("cpu"), weights_only=True ) weights = weights["embedding.weight"] return StaticEmbedding(tokenizer, embedding_weights=weights) @classmethod def from_distillation( cls, model_name: str, vocabulary: list[str] | None = None, device: str | None = None, pca_dims: int | None = 256, apply_zipf: bool = True, use_subword: bool = True, ) -> StaticEmbedding: """ Creates a StaticEmbedding instance from a distillation process using the `model2vec` package. Args: model_name (str): The name of the model to distill. vocabulary (list[str] | None, optional): A list of vocabulary words to use. Defaults to None. device (str): The device to run the distillation on (e.g., 'cpu', 'cuda'). If not specified, the strongest device is automatically detected. Defaults to None. pca_dims (int | None, optional): The number of dimensions for PCA reduction. Defaults to 256. apply_zipf (bool): Whether to apply Zipf's law during distillation. Defaults to True. use_subword (bool): Whether to use subword tokenization. Defaults to True. Returns: StaticEmbedding: An instance of StaticEmbedding initialized with the distilled model's tokenizer and embedding weights. Raises: ImportError: If the `model2vec` package is not installed. """ try: from model2vec.distill import distill except ImportError: raise ImportError( "To use this method, please install the `model2vec` package: `pip install model2vec[distill]`" ) device = get_device_name() static_model = distill( model_name, vocabulary=vocabulary, device=device, pca_dims=pca_dims, apply_zipf=apply_zipf, use_subword=use_subword, ) if isinstance(static_model.embedding, np.ndarray): embedding_weights = torch.from_numpy(static_model.embedding) else: embedding_weights = static_model.embedding.weight tokenizer: Tokenizer = static_model.tokenizer return cls(tokenizer, embedding_weights=embedding_weights, base_model=model_name) @classmethod def from_model2vec(cls, model_id_or_path: str) -> StaticEmbedding: """ Create a StaticEmbedding instance from a model2vec model. This method loads a pre-trained model2vec model and extracts the embedding weights and tokenizer to create a StaticEmbedding instance. Args: model_id_or_path (str): The identifier or path to the pre-trained model2vec model. Returns: StaticEmbedding: An instance of StaticEmbedding initialized with the tokenizer and embedding weights the model2vec model. Raises: ImportError: If the `model2vec` package is not installed. """ try: from model2vec import StaticModel except ImportError: raise ImportError("To use this method, please install the `model2vec` package: `pip install model2vec`") static_model = StaticModel.from_pretrained(model_id_or_path) if isinstance(static_model.embedding, np.ndarray): embedding_weights = torch.from_numpy(static_model.embedding) else: embedding_weights = static_model.embedding.weight tokenizer: Tokenizer = static_model.tokenizer return cls(tokenizer, embedding_weights=embedding_weights, base_model=model_id_or_path)
Memory