from dataclasses import dataclass, field
from typing import TYPE_CHECKING, List, Optional

import numpy as np
from pydantic import Field, SecretStr

from unstructured.documents.elements import (
    Element,
)
from unstructured.embed.interfaces import BaseEmbeddingEncoder, EmbeddingConfig
from unstructured.utils import requires_dependencies

if TYPE_CHECKING:
    from openai import OpenAI


class OctoAiEmbeddingConfig(EmbeddingConfig):
    """Connection settings for OctoAI's embeddings endpoint.

    Attributes:
        api_key: OctoAI API key, kept as a ``SecretStr`` so it is not echoed in reprs.
        model_name: Embedding model identifier; defaults to "thenlper/gte-large".
        base_url: Endpoint passed to the OpenAI SDK; defaults to
            "https://text.octoai.run/v1".
    """

    api_key: SecretStr
    model_name: str = Field(default="thenlper/gte-large")
    base_url: str = Field(default="https://text.octoai.run/v1")

    @requires_dependencies(
        ["openai", "tiktoken"],
        extras="embed-octoai",
    )
    def get_client(self) -> "OpenAI":
        """Creates an OpenAI python client to embed elements. Uses the OpenAI SDK.

        The client is pointed at ``base_url`` so requests go to OctoAI rather
        than OpenAI.
        """
        from openai import OpenAI

        return OpenAI(api_key=self.api_key.get_secret_value(), base_url=self.base_url)


@dataclass
class OctoAIEmbeddingEncoder(BaseEmbeddingEncoder):
    """Embeds text and document elements through OctoAI using the OpenAI SDK."""

    config: OctoAiEmbeddingConfig
    # Uses the OpenAI SDK
    # Cached embedding of the probe query "Q". It is filled on first use by
    # get_exemplary_embedding, so repeated dimension/unit-norm checks cost a
    # single API call instead of one call each.
    _exemplary_embedding: Optional[List[float]] = field(init=False, default=None)

    def get_exemplary_embedding(self) -> List[float]:
        """Return (and cache) the embedding of a fixed probe query."""
        if self._exemplary_embedding is None:
            self._exemplary_embedding = self.embed_query("Q")
        return self._exemplary_embedding

    def initialize(self):
        # The client is created on demand; nothing to set up eagerly.
        pass

    def num_of_dimensions(self):
        """Return ``np.shape`` of the probe embedding, i.e. a 1-tuple ``(dim,)``."""
        exemplary_embedding = self.get_exemplary_embedding()
        return np.shape(exemplary_embedding)

    def is_unit_vector(self):
        """Return whether the probe embedding has (approximately) unit L2 norm."""
        exemplary_embedding = self.get_exemplary_embedding()
        return np.isclose(np.linalg.norm(exemplary_embedding), 1.0)

    def embed_query(self, query):
        """Embed a single query; ``query`` is converted with ``str()`` before sending."""
        return self._embed(self.config.get_client(), query)

    def embed_documents(self, elements: List[Element]) -> List[Element]:
        """Embed each element's ``str()`` form and attach it as ``element.embeddings``.

        Returns the same list object that was passed in, with its elements mutated.
        """
        # Build one client for the whole batch instead of one per element.
        client = self.config.get_client()
        embeddings = [self._embed(client, element) for element in elements]
        return self._add_embeddings_to_elements(elements, embeddings)

    def _embed(self, client: "OpenAI", text) -> List[float]:
        """Send one embeddings request with ``client`` and return the first vector."""
        response = client.embeddings.create(input=str(text), model=self.config.model_name)
        return response.data[0].embedding

    def _add_embeddings_to_elements(self, elements, embeddings) -> List[Element]:
        """Assign ``embeddings[i]`` to ``elements[i].embeddings`` in place.

        Raises:
            ValueError: if the two sequences differ in length.
        """
        # Explicit check instead of ``assert``, which is stripped under ``python -O``.
        if len(elements) != len(embeddings):
            raise ValueError(
                f"Got {len(embeddings)} embeddings for {len(elements)} elements; "
                "counts must match."
            )
        for element, embedding in zip(elements, embeddings):
            element.embeddings = embedding
        return elements
Memory