# type: ignore import json import os from dataclasses import dataclass from typing import TYPE_CHECKING, List, Optional import numpy as np from pydantic import Field, SecretStr from unstructured.documents.elements import ( Element, ) from unstructured.embed.interfaces import BaseEmbeddingEncoder, EmbeddingConfig from unstructured.utils import FileHandler, requires_dependencies if TYPE_CHECKING: from langchain_google_vertexai import VertexAIEmbeddings class VertexAIEmbeddingConfig(EmbeddingConfig): api_key: SecretStr model_name: Optional[str] = Field(default="textembedding-gecko@001") def register_application_credentials(self): application_credentials_path = os.path.join("/tmp", "google-vertex-app-credentials.json") credentials_file = FileHandler(application_credentials_path) credentials_file.write_file(json.dumps(json.loads(self.api_key.get_secret_value()))) os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = application_credentials_path @requires_dependencies( ["langchain", "langchain_google_vertexai"], extras="embed-vertexai", ) def get_client(self) -> "VertexAIEmbeddings": """Creates a Langchain VertexAI python client to embed elements.""" from langchain_google_vertexai import VertexAIEmbeddings self.register_application_credentials() vertexai_client = VertexAIEmbeddings(model_name=self.model_name) return vertexai_client @dataclass class VertexAIEmbeddingEncoder(BaseEmbeddingEncoder): config: VertexAIEmbeddingConfig def get_exemplary_embedding(self) -> List[float]: return self.embed_query(query="A sample query.") def initialize(self): pass def num_of_dimensions(self): exemplary_embedding = self.get_exemplary_embedding() return np.shape(exemplary_embedding) def is_unit_vector(self): exemplary_embedding = self.get_exemplary_embedding() return np.isclose(np.linalg.norm(exemplary_embedding), 1.0) def embed_query(self, query): client = self.config.get_client() result = client.embed_query(str(query)) return result def embed_documents(self, elements: List[Element]) -> List[Element]: client = self.config.get_client() embeddings = client.embed_documents([str(e) for e in elements]) elements_with_embeddings = self._add_embeddings_to_elements(elements, embeddings) return elements_with_embeddings def _add_embeddings_to_elements(self, elements, embeddings) -> List[Element]: assert len(elements) == len(embeddings) elements_w_embedding = [] for i, element in enumerate(elements): element.embeddings = embeddings[i] elements_w_embedding.append(element) return elements
Memory