from dataclasses import dataclass
from typing import TYPE_CHECKING, List

import numpy as np
from pydantic import SecretStr

from unstructured.documents.elements import (
    Element,
)
from unstructured.embed.interfaces import BaseEmbeddingEncoder, EmbeddingConfig
from unstructured.utils import requires_dependencies

if TYPE_CHECKING:
    from langchain_community.embeddings import BedrockEmbeddings


class BedrockEmbeddingConfig(EmbeddingConfig):
    """Credentials and region used to build a Bedrock embeddings client."""

    # Both keys are SecretStr; they are only unwrapped inside get_client().
    aws_access_key_id: SecretStr
    aws_secret_access_key: SecretStr
    region_name: str = "us-west-2"

    @requires_dependencies(
        ["boto3", "numpy", "langchain_community"],
        extras="bedrock",
    )
    def get_client(self) -> "BedrockEmbeddings":
        """Build a new ``BedrockEmbeddings`` client from this configuration.

        A fresh ``bedrock-runtime`` boto3 client is created on every call and
        wrapped in a langchain ``BedrockEmbeddings`` instance.

        Returns:
            A ``BedrockEmbeddings`` object backed by the new boto3 client.
        """
        # delay import only when needed
        import boto3
        from langchain_community.embeddings import BedrockEmbeddings

        bedrock_runtime = boto3.client(
            service_name="bedrock-runtime",
            aws_access_key_id=self.aws_access_key_id.get_secret_value(),
            aws_secret_access_key=self.aws_secret_access_key.get_secret_value(),
            region_name=self.region_name,
        )

        bedrock_client = BedrockEmbeddings(client=bedrock_runtime)
        return bedrock_client


@dataclass
class BedrockEmbeddingEncoder(BaseEmbeddingEncoder):
    """Embedding encoder that delegates to a Bedrock embeddings client."""

    config: BedrockEmbeddingConfig

    def get_exemplary_embedding(self) -> List[float]:
        """Embed the fixed probe query ``"Q"`` and return the result.

        The value comes straight from ``embed_query``, which wraps the client
        response in ``np.array`` (so at runtime this is an array, not a list).
        """
        return self.embed_query(query="Q")

    def __post_init__(self):
        # NOTE(review): initialize() is not defined in this class; presumably
        # it is provided by BaseEmbeddingEncoder -- confirm.
        self.initialize()

    def num_of_dimensions(self):
        """Return ``np.shape`` of the exemplary embedding (a tuple)."""
        exemplary_embedding = self.get_exemplary_embedding()
        return np.shape(exemplary_embedding)

    def is_unit_vector(self):
        """Return whether the exemplary embedding's norm is close to 1.0."""
        exemplary_embedding = self.get_exemplary_embedding()
        return np.isclose(np.linalg.norm(exemplary_embedding), 1.0)

    def embed_query(self, query):
        """Embed a single query and return the result as a NumPy array.

        Args:
            query: The value forwarded to the client's ``embed_query``.

        Returns:
            ``np.array`` built from the client's response.
        """
        bedrock_client = self.config.get_client()
        return np.array(bedrock_client.embed_query(query))

    def embed_documents(self, elements: List[Element]) -> List[Element]:
        """Embed the string form of each element and attach the embeddings.

        Args:
            elements: Elements to embed; each one is converted with ``str()``
                before being sent to the client.

        Returns:
            The same ``elements`` list, with ``embeddings`` set on every item.
        """
        bedrock_client = self.config.get_client()
        embeddings = bedrock_client.embed_documents([str(e) for e in elements])
        elements_with_embeddings = self._add_embeddings_to_elements(elements, embeddings)
        return elements_with_embeddings

    def _add_embeddings_to_elements(self, elements, embeddings) -> List[Element]:
        """Assign ``embeddings[i]`` to ``elements[i].embeddings`` in place.

        Args:
            elements: Elements to annotate; they are mutated.
            embeddings: One embedding per element, in the same order.

        Returns:
            The ``elements`` list that was passed in.

        Raises:
            AssertionError: If the two sequences differ in length.
        """
        # Raised explicitly (rather than via ``assert``) so the check still
        # runs under ``python -O`` and zip() below can never silently truncate.
        if len(elements) != len(embeddings):
            raise AssertionError(
                f"got {len(embeddings)} embeddings for {len(elements)} elements"
            )
        for element, embedding in zip(elements, embeddings):
            element.embeddings = embedding
        return elements
Memory