"""VoyageAI embedding support: a config model and an encoder for elements."""

from dataclasses import dataclass
from typing import TYPE_CHECKING, Iterable, List, Optional, cast

import numpy as np
from pydantic import Field, SecretStr

from unstructured.documents.elements import Element
from unstructured.embed.interfaces import BaseEmbeddingEncoder, EmbeddingConfig
from unstructured.utils import requires_dependencies

if TYPE_CHECKING:
    from voyageai import Client

# Batch sizes used when the config does not set `batch_size` explicitly.
DEFAULT_VOYAGE_2_BATCH_SIZE = 72
DEFAULT_VOYAGE_3_LITE_BATCH_SIZE = 30
DEFAULT_VOYAGE_3_BATCH_SIZE = 10
DEFAULT_BATCH_SIZE = 7

# Model name -> default batch size; models not listed fall back to
# DEFAULT_BATCH_SIZE.
_MODEL_DEFAULT_BATCH_SIZES = {
    "voyage-2": DEFAULT_VOYAGE_2_BATCH_SIZE,
    "voyage-02": DEFAULT_VOYAGE_2_BATCH_SIZE,
    "voyage-3-lite": DEFAULT_VOYAGE_3_LITE_BATCH_SIZE,
    "voyage-3": DEFAULT_VOYAGE_3_BATCH_SIZE,
}


class VoyageAIEmbeddingConfig(EmbeddingConfig):
    """Settings for the VoyageAI embedding encoder.

    `truncation` and `output_dimension` are forwarded unchanged to
    `Client.embed`; `batch_size` falls back to a per-model default when left
    as `None` (see `get_batch_size`).
    """

    api_key: SecretStr
    model_name: str
    show_progress_bar: bool = False
    batch_size: Optional[int] = Field(default=None)
    truncation: Optional[bool] = Field(default=None)
    output_dimension: Optional[int] = Field(default=None)

    @requires_dependencies(
        ["voyageai"],
        extras="embed-voyageai",
    )
    def get_client(self) -> "Client":
        """Creates a VoyageAI python client to embed elements."""
        # Imported lazily so the module can be loaded without `voyageai`.
        from voyageai import Client

        return Client(
            api_key=self.api_key.get_secret_value(),
        )

    def get_batch_size(self) -> int:
        """Return the batch size, resolving a per-model default if unset.

        When `batch_size` is `None`, the default for `model_name` is looked
        up and stored back on the config, so later calls return the same
        value without repeating the lookup.
        """
        if self.batch_size is None:
            self.batch_size = _MODEL_DEFAULT_BATCH_SIZES.get(
                self.model_name, DEFAULT_BATCH_SIZE
            )
        return self.batch_size


@dataclass
class VoyageAIEmbeddingEncoder(BaseEmbeddingEncoder):
    """Embedding encoder that sends text to VoyageAI via the configured client."""

    config: VoyageAIEmbeddingConfig

    def get_exemplary_embedding(self) -> List[float]:
        """Embed a fixed sample query and return the resulting vector."""
        return self.embed_query(query="A sample query.")

    def initialize(self):
        """No-op: the client is created on demand by the embed methods."""
        pass

    @property
    def num_of_dimensions(self) -> tuple[int, ...]:
        """Shape of an exemplary embedding (embeds a sample query on access)."""
        exemplary_embedding = self.get_exemplary_embedding()
        return np.shape(exemplary_embedding)

    @property
    def is_unit_vector(self) -> bool:
        """Whether an exemplary embedding has a norm close to 1.0.

        Embeds a sample query on access.
        """
        exemplary_embedding = self.get_exemplary_embedding()
        # np.isclose yields a numpy bool; convert so the declared `bool`
        # return type actually holds.
        return bool(np.isclose(np.linalg.norm(exemplary_embedding), 1.0))

    def embed_documents(self, elements: List[Element]) -> List[Element]:
        """Embed `elements` in batches and attach the vectors to them.

        Each batch is sent as `str(element)` texts with
        `input_type="document"`. The same element objects are returned, with
        their `embeddings` attribute set.
        """
        client = self.config.get_client()
        # Loop-invariant: resolve once instead of on every batch.
        batch_size = self.config.get_batch_size()
        embeddings: List[List[float]] = []
        for start in self._get_batch_iterator(elements):
            batch = elements[start : start + batch_size]
            response = client.embed(
                texts=[str(e) for e in batch],
                model=self.config.model_name,
                input_type="document",
                truncation=self.config.truncation,
                output_dimension=self.config.output_dimension,
            )
            embeddings.extend(cast(Iterable[List[float]], response.embeddings))
        return self._add_embeddings_to_elements(elements, embeddings)

    def embed_query(self, query: str) -> List[float]:
        """Embed a single query string (`input_type="query"`)."""
        client = self.config.get_client()
        response = client.embed(
            texts=[query],
            model=self.config.model_name,
            input_type="query",
            truncation=self.config.truncation,
            output_dimension=self.config.output_dimension,
        )
        return response.embeddings[0]

    @staticmethod
    def _add_embeddings_to_elements(elements, embeddings) -> List[Element]:
        """Set `element.embeddings` pairwise, in order, and return `elements`.

        The elements are mutated in place; the returned object is the same
        `elements` sequence that was passed in.
        """
        assert len(elements) == len(embeddings)
        for element, embedding in zip(elements, embeddings):
            element.embeddings = embedding
        return elements

    def _get_batch_iterator(self, elements: List[Element]) -> Iterable:
        """Return an iterable of batch start offsets into `elements`.

        Offsets step by the configured batch size. When `show_progress_bar`
        is set, the range is wrapped in a tqdm progress bar.

        Raises:
            ImportError: if `show_progress_bar` is set and tqdm is missing.
        """
        if not self.config.show_progress_bar:
            return range(0, len(elements), self.config.get_batch_size())

        try:
            from tqdm.auto import tqdm  # type: ignore
        except ImportError as e:
            raise ImportError(
                "Must have tqdm installed if `show_progress_bar` is set to True. "
                "Please install with `pip install tqdm`."
            ) from e

        return tqdm(range(0, len(elements), self.config.get_batch_size()))
Memory