from __future__ import annotations import csv import logging import os from typing import TYPE_CHECKING import torch from torch.utils.data import DataLoader from sentence_transformers.evaluation.SentenceEvaluator import SentenceEvaluator from sentence_transformers.util import batch_to_device if TYPE_CHECKING: from sentence_transformers.SentenceTransformer import SentenceTransformer logger = logging.getLogger(__name__) class LabelAccuracyEvaluator(SentenceEvaluator): """ Evaluate a model based on its accuracy on a labeled dataset This requires a model with LossFunction.SOFTMAX The results are written in a CSV. If a CSV already exists, then values are appended. """ def __init__(self, dataloader: DataLoader, name: str = "", softmax_model=None, write_csv: bool = True): """ Constructs an evaluator for the given dataset Args: dataloader (DataLoader): the data for the evaluation """ super().__init__() self.dataloader = dataloader self.name = name self.softmax_model = softmax_model if name: name = "_" + name self.write_csv = write_csv self.csv_file = "accuracy_evaluation" + name + "_results.csv" self.csv_headers = ["epoch", "steps", "accuracy"] self.primary_metric = "accuracy" def __call__( self, model: SentenceTransformer, output_path: str = None, epoch: int = -1, steps: int = -1 ) -> dict[str, float]: model.eval() total = 0 correct = 0 if epoch != -1: if steps == -1: out_txt = f" after epoch {epoch}:" else: out_txt = f" in epoch {epoch} after {steps} steps:" else: out_txt = ":" logger.info("Evaluation on the " + self.name + " dataset" + out_txt) self.dataloader.collate_fn = model.smart_batching_collate for step, batch in enumerate(self.dataloader): features, label_ids = batch for idx in range(len(features)): features[idx] = batch_to_device(features[idx], model.device) label_ids = label_ids.to(model.device) with torch.no_grad(): _, prediction = self.softmax_model(features, labels=None) total += prediction.size(0) correct += torch.argmax(prediction, dim=1).eq(label_ids).sum().item() accuracy = correct / total logger.info(f"Accuracy: {accuracy:.4f} ({correct}/{total})\n") if output_path is not None and self.write_csv: csv_path = os.path.join(output_path, self.csv_file) if not os.path.isfile(csv_path): with open(csv_path, newline="", mode="w", encoding="utf-8") as f: writer = csv.writer(f) writer.writerow(self.csv_headers) writer.writerow([epoch, steps, accuracy]) else: with open(csv_path, newline="", mode="a", encoding="utf-8") as f: writer = csv.writer(f) writer.writerow([epoch, steps, accuracy]) metrics = {"accuracy": accuracy} metrics = self.prefix_name_to_metrics(metrics, self.name) self.store_metrics_in_model_card_data(model, metrics) return metrics
Memory