import functools from typing import ( TYPE_CHECKING, Callable, Dict, Generic, Optional, Any, Set, TypeVar, Union, cast, ) from chromadb.types import Metadata import numpy as np from uuid import UUID import chromadb.utils.embedding_functions as ef from chromadb.api.types import ( URI, URIs, AddRequest, BaseRecordSet, CollectionMetadata, DataLoader, DeleteRequest, Embedding, Embeddings, FilterSet, GetRequest, IncludeEnum, PyEmbedding, Embeddable, GetResult, Include, Loadable, Document, Image, QueryRequest, QueryResult, IDs, EmbeddingFunction, ID, OneOrMany, UpdateRequest, UpsertRequest, get_default_embeddable_record_set_fields, maybe_cast_one_to_many, normalize_base_record_set, normalize_insert_record_set, validate_base_record_set, validate_ids, validate_include, validate_insert_record_set, validate_metadata, validate_embedding_function, validate_n_results, validate_record_set_contains_any, validate_record_set_for_embedding, validate_filter_set, ) # TODO: We should rename the types in chromadb.types to be Models where # appropriate. This will help to distinguish between manipulation objects # which are essentially API views. And the actual data models which are # stored / retrieved / transmitted. from chromadb.types import Collection as CollectionModel, Where, WhereDocument import logging logger = logging.getLogger(__name__) if TYPE_CHECKING: from chromadb.api import ServerAPI, AsyncServerAPI ClientT = TypeVar("ClientT", "ServerAPI", "AsyncServerAPI") T = TypeVar("T") def validation_context(name: str) -> Callable[[Callable[..., T]], Callable[..., T]]: """A decorator that wraps a method with a try-except block that catches exceptions and adds the method name to the error message. This allows us to provide more context when an error occurs, without rewriting validators. """ def decorator(func: Callable[..., T]) -> Callable[..., T]: @functools.wraps(func) def wrapper(self: Any, *args: Any, **kwargs: Any) -> T: try: return func(self, *args, **kwargs) except Exception as e: msg = f"{str(e)} in {name}." raise type(e)(msg).with_traceback(e.__traceback__) return wrapper return decorator class CollectionCommon(Generic[ClientT]): _model: CollectionModel _client: ClientT _embedding_function: Optional[EmbeddingFunction[Embeddable]] _data_loader: Optional[DataLoader[Loadable]] def __init__( self, client: ClientT, model: CollectionModel, embedding_function: Optional[ EmbeddingFunction[Embeddable] ] = ef.DefaultEmbeddingFunction(), # type: ignore data_loader: Optional[DataLoader[Loadable]] = None, ): """Initializes a new instance of the Collection class.""" self._client = client self._model = model # Check to make sure the embedding function has the right signature, as defined by the EmbeddingFunction protocol if embedding_function is not None: validate_embedding_function(embedding_function) self._embedding_function = embedding_function self._data_loader = data_loader # Expose the model properties as read-only properties on the Collection class @property def id(self) -> UUID: return self._model.id @property def name(self) -> str: return self._model.name @property def configuration_json(self) -> Dict[str, Any]: return self._model.configuration_json @property def metadata(self) -> CollectionMetadata: return cast(CollectionMetadata, self._model.metadata) @property def tenant(self) -> str: return self._model.tenant @property def database(self) -> str: return self._model.database def __eq__(self, other: object) -> bool: if not isinstance(other, CollectionCommon): return False id_match = self.id == other.id name_match = self.name == other.name configuration_match = self.configuration_json == other.configuration_json metadata_match = self.metadata == other.metadata tenant_match = self.tenant == other.tenant database_match = self.database == other.database embedding_function_match = self._embedding_function == other._embedding_function data_loader_match = self._data_loader == other._data_loader return ( id_match and name_match and configuration_match and metadata_match and tenant_match and database_match and embedding_function_match and data_loader_match ) def __repr__(self) -> str: return f"Collection(name={self.name})" def get_model(self) -> CollectionModel: return self._model @validation_context("add") def _validate_and_prepare_add_request( self, ids: OneOrMany[ID], embeddings: Optional[ Union[ OneOrMany[Embedding], OneOrMany[PyEmbedding], ] ], metadatas: Optional[OneOrMany[Metadata]], documents: Optional[OneOrMany[Document]], images: Optional[OneOrMany[Image]], uris: Optional[OneOrMany[URI]], ) -> AddRequest: # Unpack add_records = normalize_insert_record_set( ids=ids, embeddings=embeddings, metadatas=metadatas, documents=documents, images=images, uris=uris, ) # Validate validate_insert_record_set(record_set=add_records) validate_record_set_contains_any(record_set=add_records, contains_any={"ids"}) # Prepare if add_records["embeddings"] is None: validate_record_set_for_embedding(record_set=add_records) add_embeddings = self._embed_record_set(record_set=add_records) else: add_embeddings = add_records["embeddings"] return AddRequest( ids=add_records["ids"], embeddings=add_embeddings, metadatas=add_records["metadatas"], documents=add_records["documents"], uris=add_records["uris"], ) @validation_context("get") def _validate_and_prepare_get_request( self, ids: Optional[OneOrMany[ID]], where: Optional[Where], where_document: Optional[WhereDocument], include: Include, ) -> GetRequest: # Unpack unpacked_ids: Optional[IDs] = maybe_cast_one_to_many(target=ids) filters = FilterSet(where=where, where_document=where_document) # Validate if unpacked_ids is not None: validate_ids(ids=unpacked_ids) validate_filter_set(filter_set=filters) validate_include(include=include, dissalowed=[IncludeEnum.distances]) if IncludeEnum.data in include and self._data_loader is None: raise ValueError( "You must set a data loader on the collection if loading from URIs." ) # Prepare request_include = include # We need to include uris in the result from the API to load datas if IncludeEnum.data in include and IncludeEnum.uris not in include: request_include.append(IncludeEnum.uris) return GetRequest( ids=unpacked_ids, where=filters["where"], where_document=filters["where_document"], include=request_include, ) @validation_context("query") def _validate_and_prepare_query_request( self, query_embeddings: Optional[ Union[ OneOrMany[Embedding], OneOrMany[PyEmbedding], ] ], query_texts: Optional[OneOrMany[Document]], query_images: Optional[OneOrMany[Image]], query_uris: Optional[OneOrMany[URI]], n_results: int, where: Optional[Where], where_document: Optional[WhereDocument], include: Include, ) -> QueryRequest: # Unpack query_records = normalize_base_record_set( embeddings=query_embeddings, documents=query_texts, images=query_images, uris=query_uris, ) filters = FilterSet( where=where, where_document=where_document, ) # Validate validate_base_record_set(record_set=query_records) validate_filter_set(filter_set=filters) validate_include(include=include) validate_n_results(n_results=n_results) # Prepare if query_records["embeddings"] is None: validate_record_set_for_embedding(record_set=query_records) request_embeddings = self._embed_record_set(record_set=query_records) else: request_embeddings = query_records["embeddings"] request_where = filters["where"] request_where_document = filters["where_document"] # We need to manually include uris in the result from the API to load datas request_include = include if "data" in request_include and "uris" not in request_include: request_include.append(IncludeEnum.uris) return QueryRequest( embeddings=request_embeddings, where=request_where, where_document=request_where_document, include=request_include, n_results=n_results, ) @validation_context("update") def _validate_and_prepare_update_request( self, ids: OneOrMany[ID], embeddings: Optional[ Union[ OneOrMany[Embedding], OneOrMany[PyEmbedding], ] ], metadatas: Optional[OneOrMany[Metadata]], documents: Optional[OneOrMany[Document]], images: Optional[OneOrMany[Image]], uris: Optional[OneOrMany[URI]], ) -> UpdateRequest: # Unpack update_records = normalize_insert_record_set( ids=ids, embeddings=embeddings, metadatas=metadatas, documents=documents, images=images, uris=uris, ) # Validate validate_insert_record_set(record_set=update_records) # Prepare if update_records["embeddings"] is None: # TODO: Handle URI updates. if ( update_records["documents"] is not None or update_records["images"] is not None ): validate_record_set_for_embedding( update_records, embeddable_fields={"documents", "images"} ) update_embeddings = self._embed_record_set(record_set=update_records) else: update_embeddings = None else: update_embeddings = update_records["embeddings"] return UpdateRequest( ids=update_records["ids"], embeddings=update_embeddings, metadatas=update_records["metadatas"], documents=update_records["documents"], uris=update_records["uris"], ) @validation_context("upsert") def _validate_and_prepare_upsert_request( self, ids: OneOrMany[ID], embeddings: Optional[ Union[ OneOrMany[Embedding], OneOrMany[PyEmbedding], ] ] = None, metadatas: Optional[OneOrMany[Metadata]] = None, documents: Optional[OneOrMany[Document]] = None, images: Optional[OneOrMany[Image]] = None, uris: Optional[OneOrMany[URI]] = None, ) -> UpsertRequest: # Unpack upsert_records = normalize_insert_record_set( ids=ids, embeddings=embeddings, metadatas=metadatas, documents=documents, images=images, uris=uris, ) # Validate validate_insert_record_set(record_set=upsert_records) # Prepare if upsert_records["embeddings"] is None: validate_record_set_for_embedding( record_set=upsert_records, embeddable_fields={"documents", "images"} ) upsert_embeddings = self._embed_record_set(record_set=upsert_records) else: upsert_embeddings = upsert_records["embeddings"] return UpsertRequest( ids=upsert_records["ids"], metadatas=upsert_records["metadatas"], embeddings=upsert_embeddings, documents=upsert_records["documents"], uris=upsert_records["uris"], ) @validation_context("delete") def _validate_and_prepare_delete_request( self, ids: Optional[IDs], where: Optional[Where], where_document: Optional[WhereDocument], ) -> DeleteRequest: if ids is None and where is None and where_document is None: raise ValueError( "At least one of ids, where, or where_document must be provided" ) # Unpack if ids is not None: request_ids = cast(IDs, maybe_cast_one_to_many(ids)) else: request_ids = None filters = FilterSet(where=where, where_document=where_document) # Validate if request_ids is not None: validate_ids(ids=request_ids) validate_filter_set(filter_set=filters) return DeleteRequest( ids=request_ids, where=where, where_document=where_document ) def _transform_peek_response(self, response: GetResult) -> GetResult: if response["embeddings"] is not None: response["embeddings"] = np.array(response["embeddings"]) return response def _transform_get_response( self, response: GetResult, include: Include ) -> GetResult: if ( "data" in include and self._data_loader is not None and response["uris"] is not None ): response["data"] = self._data_loader(response["uris"]) if "embeddings" in include: response["embeddings"] = np.array(response["embeddings"]) # Remove URIs from the result if they weren't requested if "uris" not in include: response["uris"] = None return response def _transform_query_response( self, response: QueryResult, include: Include ) -> QueryResult: if ( "data" in include and self._data_loader is not None and response["uris"] is not None ): response["data"] = [self._data_loader(uris) for uris in response["uris"]] if "embeddings" in include and response["embeddings"] is not None: response["embeddings"] = [ np.array(embedding) for embedding in response["embeddings"] ] # Remove URIs from the result if they weren't requested if "uris" not in include: response["uris"] = None return response def _validate_modify_request(self, metadata: Optional[CollectionMetadata]) -> None: if metadata is not None: validate_metadata(metadata) if "hnsw:space" in metadata: raise ValueError( "Changing the distance function of a collection once it is created is not supported currently." ) def _update_model_after_modify_success( self, name: Optional[str], metadata: Optional[CollectionMetadata] ) -> None: if name: self._model["name"] = name if metadata: self._model["metadata"] = metadata def _embed_record_set( self, record_set: BaseRecordSet, embeddable_fields: Optional[Set[str]] = None ) -> Embeddings: if embeddable_fields is None: embeddable_fields = get_default_embeddable_record_set_fields() for field in embeddable_fields: if record_set[field] is not None: # type: ignore[literal-required] # uris require special handling if field == "uris": if self._data_loader is None: raise ValueError( "You must set a data loader on the collection if loading from URIs." ) return self._embed( input=self._data_loader(uris=cast(URIs, record_set[field])) # type: ignore[literal-required] ) else: return self._embed(input=record_set[field]) # type: ignore[literal-required] raise ValueError( "Record does not contain any non-None fields that can be embedded." f"Embeddable Fields: {embeddable_fields}" f"Record Fields: {record_set}" ) def _embed(self, input: Any) -> Embeddings: if self._embedding_function is None: raise ValueError( "You must provide an embedding function to compute embeddings." "https://docs.trychroma.com/guides/embeddings" ) return self._embedding_function(input=input)
Memory