from __future__ import annotations

import json
import os

import torch
from safetensors.torch import load_model as load_safetensors_model
from safetensors.torch import save_model as save_safetensors_model
from torch import Tensor, nn


class LayerNorm(nn.Module):
    """Apply ``torch.nn.LayerNorm`` to the ``"sentence_embedding"`` feature.

    The module wraps a single ``nn.LayerNorm(dimension)`` and can be persisted
    to / restored from a directory holding a ``config.json`` plus a weights
    file (``model.safetensors`` or ``pytorch_model.bin``).

    Args:
        dimension: Size passed to ``nn.LayerNorm``; also reported by
            :meth:`get_sentence_embedding_dimension` and stored in the config.
    """

    def __init__(self, dimension: int):
        super().__init__()
        self.dimension = dimension
        self.norm = nn.LayerNorm(dimension)

    def forward(self, features: dict[str, Tensor]) -> dict[str, Tensor]:
        """Normalize ``features["sentence_embedding"]`` and return ``features``.

        The entry is overwritten in the dict that was passed in, and that same
        dict object is returned.
        """
        features["sentence_embedding"] = self.norm(features["sentence_embedding"])
        return features

    def get_sentence_embedding_dimension(self) -> int:
        """Return the ``dimension`` this module was constructed with."""
        return self.dimension

    def save(self, output_path, safe_serialization: bool = True) -> None:
        """Write the config and weights of this module into ``output_path``.

        Args:
            output_path: Target directory. It is created (including parents)
                when it does not exist yet.
            safe_serialization: When true, weights are written to
                ``model.safetensors``; otherwise the state dict is written to
                ``pytorch_model.bin`` with ``torch.save``.
        """
        # Saving into a fresh directory should not fail with FileNotFoundError.
        os.makedirs(output_path, exist_ok=True)

        config_path = os.path.join(output_path, "config.json")
        # Explicit encoding keeps the file independent of the platform locale.
        with open(config_path, "w", encoding="utf-8") as fOut:
            json.dump({"dimension": self.dimension}, fOut, indent=2)

        if safe_serialization:
            save_safetensors_model(self, os.path.join(output_path, "model.safetensors"))
        else:
            torch.save(self.state_dict(), os.path.join(output_path, "pytorch_model.bin"))

    @staticmethod
    def load(input_path) -> LayerNorm:
        """Rebuild a ``LayerNorm`` from a directory written by :meth:`save`.

        ``config.json`` supplies the constructor arguments. Weights are read
        from ``model.safetensors`` when that file exists; otherwise from
        ``pytorch_model.bin``, mapped to CPU and loaded with
        ``weights_only=True``.

        Args:
            input_path: Directory containing ``config.json`` and a weights file.

        Returns:
            The module with the stored weights loaded.
        """
        with open(os.path.join(input_path, "config.json"), encoding="utf-8") as fIn:
            config = json.load(fIn)

        model = LayerNorm(**config)

        safetensors_path = os.path.join(input_path, "model.safetensors")
        if os.path.exists(safetensors_path):
            load_safetensors_model(model, safetensors_path)
        else:
            state_dict = torch.load(
                os.path.join(input_path, "pytorch_model.bin"),
                map_location=torch.device("cpu"),
                weights_only=True,
            )
            model.load_state_dict(state_dict)
        return model
Memory