from __future__ import annotations
import json
import os
import torch
from safetensors.torch import load_model as load_safetensors_model
from safetensors.torch import save_model as save_safetensors_model
from torch import Tensor, nn
class LayerNorm(nn.Module):
def __init__(self, dimension: int):
super().__init__()
self.dimension = dimension
self.norm = nn.LayerNorm(dimension)
def forward(self, features: dict[str, Tensor]):
features["sentence_embedding"] = self.norm(features["sentence_embedding"])
return features
def get_sentence_embedding_dimension(self):
return self.dimension
def save(self, output_path, safe_serialization: bool = True) -> None:
with open(os.path.join(output_path, "config.json"), "w") as fOut:
json.dump({"dimension": self.dimension}, fOut, indent=2)
if safe_serialization:
save_safetensors_model(self, os.path.join(output_path, "model.safetensors"))
else:
torch.save(self.state_dict(), os.path.join(output_path, "pytorch_model.bin"))
@staticmethod
def load(input_path):
with open(os.path.join(input_path, "config.json")) as fIn:
config = json.load(fIn)
model = LayerNorm(**config)
if os.path.exists(os.path.join(input_path, "model.safetensors")):
load_safetensors_model(model, os.path.join(input_path, "model.safetensors"))
else:
model.load_state_dict(
torch.load(
os.path.join(input_path, "pytorch_model.bin"), map_location=torch.device("cpu"), weights_only=True
)
)
return model