import os
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, List, Optional
import numpy as np
from pydantic import Field, SecretStr
from unstructured.documents.elements import Element
from unstructured.embed.interfaces import BaseEmbeddingEncoder, EmbeddingConfig
from unstructured.utils import requires_dependencies
# Value sent as the "User-Agent" header on every request issued by the encoder.
USER_AGENT = "@mixedbread-ai/unstructured"
# Maximum number of texts sent to the embeddings endpoint in a single request.
BATCH_SIZE = 128
# Per-request timeout, in seconds (passed as `timeout_in_seconds`).
TIMEOUT = 60
# Passed as `max_retries` in the request options.
MAX_RETRIES = 3
# Passed as `encoding_format` to the embeddings endpoint.
ENCODING_FORMAT = "float"
# Passed as `truncation_strategy` to the embeddings endpoint.
# NOTE(review): presumably controls which side of an over-long input is cut -- confirm
# against the Mixedbread AI API documentation.
TRUNCATION_STRATEGY = "end"
if TYPE_CHECKING:
from mixedbread_ai.client import MixedbreadAI
from mixedbread_ai.core import RequestOptions
class MixedbreadAIEmbeddingConfig(EmbeddingConfig):
    """
    Settings consumed by the Mixedbread AI embedding encoder.

    Attributes:
        api_key (SecretStr): API key for accessing Mixedbread AI. When it is not
            supplied, the default is built from the ``MXBAI_API_KEY`` environment
            variable.
        model_name (str): Name of the model to use for embeddings.
    """

    api_key: SecretStr = Field(default_factory=lambda: SecretStr(os.getenv("MXBAI_API_KEY")))
    model_name: str = Field(default="mixedbread-ai/mxbai-embed-large-v1")

    @requires_dependencies(["mixedbread_ai"], extras="embed-mixedbreadai")
    def get_client(self) -> "MixedbreadAI":
        """
        Build a Mixedbread AI client authenticated with the configured API key.

        Returns:
            MixedbreadAI: Initialized client.
        """
        # Imported lazily so the module can be loaded without the optional dependency.
        from mixedbread_ai.client import MixedbreadAI

        secret = self.api_key.get_secret_value()
        return MixedbreadAI(api_key=secret)
@dataclass
class MixedbreadAIEmbeddingEncoder(BaseEmbeddingEncoder):
    """
    Embedding encoder for Mixedbread AI.

    Attributes:
        config (MixedbreadAIEmbeddingConfig): Configuration for the embedding encoder.
    """

    config: MixedbreadAIEmbeddingConfig

    # Cache for the probe embedding produced by get_exemplary_embedding().
    _exemplary_embedding: Optional[List[float]] = field(init=False, default=None)
    # Request options (retries, timeout, headers); populated by initialize().
    _request_options: Optional["RequestOptions"] = field(init=False, default=None)

    def get_exemplary_embedding(self) -> List[float]:
        """
        Get an exemplary embedding to determine dimensions and unit vector status.

        The probe text is embedded on the first call only; later calls reuse the
        cached vector, so reading `num_of_dimensions` and `is_unit_vector` does not
        trigger a new API request each time. The cache is not invalidated if the
        config (e.g. `model_name`) is changed afterwards.

        Returns:
            List[float]: A copy of the embedding of the probe text.
        """
        if self._exemplary_embedding is None:
            self._exemplary_embedding = self._embed(["Q"])[0]
        # Hand out a copy so callers cannot mutate the cached vector.
        return list(self._exemplary_embedding)

    def initialize(self):
        """
        Validate the API key and prepare the request options used for embedding calls.

        Raises:
            ValueError: If no API key is configured.
        """
        api_key = self.config.api_key
        # The config default wraps the environment lookup in a SecretStr, so a missing
        # key appears as a SecretStr holding None rather than as None itself; the
        # wrapped value has to be checked as well.
        if api_key is None or not api_key.get_secret_value():
            raise ValueError(
                "The Mixedbread AI API key must be specified. "
                "You either pass it in the constructor using 'api_key' "
                "or via the 'MXBAI_API_KEY' environment variable."
            )

        from mixedbread_ai.core import RequestOptions

        self._request_options = RequestOptions(
            max_retries=MAX_RETRIES,
            timeout_in_seconds=TIMEOUT,
            additional_headers={"User-Agent": USER_AGENT},
        )

    @property
    def num_of_dimensions(self):
        """
        Get the number of dimensions for the embeddings.

        Returns:
            The result of `np.shape` on an exemplary embedding, i.e. a shape tuple
            rather than a bare integer.
        """
        return np.shape(self.get_exemplary_embedding())

    @property
    def is_unit_vector(self) -> bool:
        """
        Check if the embedding is a unit vector.

        Returns:
            bool: True when the Euclidean norm of an exemplary embedding is close to 1.0.
        """
        norm = np.linalg.norm(self.get_exemplary_embedding())
        # np.isclose yields a numpy boolean; convert to match the annotated return type.
        return bool(np.isclose(norm, 1.0))

    def _embed(self, texts: List[str]) -> List[List[float]]:
        """
        Embed a list of texts using the Mixedbread AI API.

        Texts are sent in consecutive batches of at most `BATCH_SIZE` items, and the
        embeddings are returned in the order the API responses provide them.

        Args:
            texts (List[str]): List of texts to embed.

        Returns:
            List[List[float]]: List of embeddings.
        """
        if not texts:
            # Nothing to embed: skip client construction and any network traffic.
            return []

        client = self.config.get_client()
        embeddings: List[List[float]] = []
        for start in range(0, len(texts), BATCH_SIZE):
            response = client.embeddings(
                model=self.config.model_name,
                normalized=True,
                encoding_format=ENCODING_FORMAT,
                truncation_strategy=TRUNCATION_STRATEGY,
                request_options=self._request_options,
                input=texts[start : start + BATCH_SIZE],
            )
            embeddings.extend(item.embedding for item in response.data)
        return embeddings

    @staticmethod
    def _add_embeddings_to_elements(
        elements: List[Element], embeddings: List[List[float]]
    ) -> List[Element]:
        """
        Add embeddings to elements.

        Each element is updated in place by setting its `embeddings` attribute, and
        the same list that was passed in is returned.

        Args:
            elements (List[Element]): List of elements.
            embeddings (List[List[float]]): List of embeddings.

        Returns:
            List[Element]: Elements with embeddings added.

        Raises:
            ValueError: If the number of embeddings differs from the number of elements.
        """
        # Explicit check instead of `assert`, which is removed when Python runs with -O.
        if len(elements) != len(embeddings):
            raise ValueError(
                f"Expected one embedding per element, but got {len(embeddings)} "
                f"embeddings for {len(elements)} elements."
            )
        for element, embedding in zip(elements, embeddings):
            element.embeddings = embedding
        return elements

    def embed_documents(self, elements: List[Element]) -> List[Element]:
        """
        Embed a list of document elements.

        The text sent for each element is its `str()` representation.

        Args:
            elements (List[Element]): List of document elements.

        Returns:
            List[Element]: Elements with embeddings.
        """
        embeddings = self._embed([str(element) for element in elements])
        return self._add_embeddings_to_elements(elements, embeddings)

    def embed_query(self, query: str) -> List[float]:
        """
        Embed a query string.

        Args:
            query (str): Query string to embed.

        Returns:
            List[float]: Embedding of the query.
        """
        return self._embed([query])[0]