from typing import Optional, Sequence

from overrides import overrides

from chromadb.api.types import GetResult, IncludeEnum, Metadata, QueryResult
from chromadb.config import System
from chromadb.execution.executor.abstract import Executor
from chromadb.execution.expression.plan import CountPlan, GetPlan, KNNPlan
from chromadb.segment import MetadataReader, VectorReader
from chromadb.segment.impl.manager.local import LocalSegmentManager
from chromadb.types import Collection, VectorQuery, VectorQueryResult


def _clean_metadata(metadata: Optional[Metadata]) -> Optional[Metadata]:
    """Remove chroma-internal keys from a metadata map.

    Keys prefixed with ``"chroma:"`` (such as ``"chroma:document"`` and
    ``"chroma:uri"``) are internal bookkeeping and must not be shown to the
    client.

    Returns:
        A new dict containing only the user-visible keys, or ``None`` when the
        input is empty/``None`` or has no user-visible keys left.
    """
    if not metadata:
        return None
    result = {k: v for k, v in metadata.items() if not k.startswith("chroma:")}
    return result or None


def _doc(metadata: Optional[Metadata]) -> Optional[str]:
    """Return the document stored under ``"chroma:document"`` (as ``str``),
    or ``None`` if the metadata map is empty/``None`` or lacks that key."""
    if metadata and "chroma:document" in metadata:
        return str(metadata["chroma:document"])
    return None


def _uri(metadata: Optional[Metadata]) -> Optional[str]:
    """Return the uri stored under ``"chroma:uri"`` (as ``str``), or ``None``
    if the metadata map is empty/``None`` or lacks that key."""
    if metadata and "chroma:uri" in metadata:
        return str(metadata["chroma:uri"])
    return None


class LocalExecutor(Executor):
    """Executor that evaluates query plans against segments obtained from a
    ``LocalSegmentManager``.

    The collection's ``MetadataReader`` handles filtering, counting, and
    document/uri/metadata hydration. Its ``VectorReader`` handles embedding
    lookup and nearest-neighbor search.
    """

    _manager: LocalSegmentManager

    def __init__(self, system: System):
        super().__init__(system)
        self._manager = self.require(LocalSegmentManager)

    @overrides
    def count(self, plan: CountPlan) -> int:
        """Return the count reported by the collection's metadata segment at
        the plan's scan version."""
        return self._metadata_segment(plan.scan.collection).count(plan.scan.version)

    @overrides
    def get(self, plan: GetPlan) -> GetResult:
        """Fetch records matching ``plan.filter`` from the metadata segment.

        ``plan.limit.fetch`` and ``plan.limit.skip`` are forwarded as the
        limit and offset. Embeddings are read from the vector segment only
        when projected, and only if at least one record matched. Fields that
        are not projected are returned as ``None`` and left out of
        ``included``.
        """
        records = self._metadata_segment(plan.scan.collection).get_metadata(
            request_version_context=plan.scan.version,
            where=plan.filter.where,
            where_document=plan.filter.where_document,
            ids=plan.filter.user_ids,
            limit=plan.limit.fetch,
            offset=plan.limit.skip,
            include_metadata=True,
        )

        ids = [r["id"] for r in records]
        embeddings = None
        documents = None
        uris = None
        metadatas = None
        included = []

        if plan.projection.embedding:
            if records:
                vectors = self._vector_segment(plan.scan.collection).get_vectors(
                    ids=ids, request_version_context=plan.scan.version
                )
                # NOTE(review): embeddings are aligned with ``ids`` by position,
                # which assumes get_vectors returns one record per requested id
                # in request order — confirm against the VectorReader impls.
                embeddings = [v["embedding"] for v in vectors]
            else:
                # Nothing matched: skip the vector-segment round trip.
                embeddings = []
            included.append(IncludeEnum.embeddings)

        if plan.projection.document:
            documents = [_doc(r["metadata"]) for r in records]
            included.append(IncludeEnum.documents)

        if plan.projection.uri:
            uris = [_uri(r["metadata"]) for r in records]
            included.append(IncludeEnum.uris)

        if plan.projection.metadata:
            metadatas = [_clean_metadata(r["metadata"]) for r in records]
            included.append(IncludeEnum.metadatas)

        # TODO: Fix typing
        return GetResult(
            ids=ids,
            embeddings=embeddings,
            documents=documents,  # type: ignore[typeddict-item]
            uris=uris,  # type: ignore[typeddict-item]
            data=None,
            metadatas=metadatas,  # type: ignore[typeddict-item]
            included=included,
        )

    @overrides
    def knn(self, plan: KNNPlan) -> QueryResult:
        """Run a nearest-neighbor search for each embedding in ``plan.knn``.

        If the plan carries an id/where/where_document filter, the matching
        ids are first resolved through the metadata segment. They are then
        passed to the vector query as ``allowed_ids``. If the filter matches
        nothing, the vector segment is not queried and every query gets an
        empty result list.

        Documents, uris and metadatas, when projected, are hydrated with a
        single metadata lookup over the union of all result ids.
        """
        prefiltered_ids = None
        # Truthiness check: empty filter values count as "no filter".
        # NOTE(review): an empty ``user_ids`` list is therefore ignored here,
        # whereas ``get`` forwards it unchanged to the metadata segment —
        # confirm this asymmetry is intended.
        if plan.filter.user_ids or plan.filter.where or plan.filter.where_document:
            records = self._metadata_segment(plan.scan.collection).get_metadata(
                request_version_context=plan.scan.version,
                where=plan.filter.where,
                where_document=plan.filter.where_document,
                ids=plan.filter.user_ids,
                limit=None,
                offset=0,
                include_metadata=False,
            )
            prefiltered_ids = [r["id"] for r in records]

        # One independent empty result list per query embedding
        # (``[[]] * n`` would alias a single list n times).
        knns: Sequence[Sequence[VectorQueryResult]] = [
            [] for _ in range(len(plan.knn.embeddings))
        ]

        # Query vectors only when the user did not specify a filter or when the
        # filter yields non-empty ids. Otherwise, the user specified a filter but
        # it yields no matching ids, in which case we can return an empty result.
        if prefiltered_ids is None or len(prefiltered_ids) > 0:
            query = VectorQuery(
                vectors=plan.knn.embeddings,
                k=plan.knn.fetch,
                allowed_ids=prefiltered_ids,
                include_embeddings=plan.projection.embedding,
                options=None,
                request_version_context=plan.scan.version,
            )
            knns = self._vector_segment(plan.scan.collection).query_vectors(query)

        ids = [[r["id"] for r in result] for result in knns]
        embeddings = None
        documents = None
        uris = None
        metadatas = None
        distances = None
        included = []

        if plan.projection.embedding:
            embeddings = [[r["embedding"] for r in result] for result in knns]
            included.append(IncludeEnum.embeddings)

        if plan.projection.rank:
            distances = [[r["distance"] for r in result] for result in knns]
            included.append(IncludeEnum.distances)

        if plan.projection.document or plan.projection.metadata or plan.projection.uri:
            # Deduplicate across queries so each record is hydrated only once.
            merged_ids = list({rid for result in ids for rid in result})
            hydrated_records = self._metadata_segment(
                plan.scan.collection
            ).get_metadata(
                request_version_context=plan.scan.version,
                where=None,
                where_document=None,
                ids=merged_ids,
                limit=None,
                offset=0,
                include_metadata=True,
            )
            metadata_by_id = {r["id"]: r["metadata"] for r in hydrated_records}

            if plan.projection.document:
                documents = [
                    [_doc(metadata_by_id.get(rid)) for rid in result]
                    for result in ids
                ]
                included.append(IncludeEnum.documents)

            if plan.projection.uri:
                uris = [
                    [_uri(metadata_by_id.get(rid)) for rid in result]
                    for result in ids
                ]
                included.append(IncludeEnum.uris)

            if plan.projection.metadata:
                metadatas = [
                    [_clean_metadata(metadata_by_id.get(rid)) for rid in result]
                    for result in ids
                ]
                included.append(IncludeEnum.metadatas)

        # TODO: Fix typing
        return QueryResult(
            ids=ids,
            embeddings=embeddings,  # type: ignore[typeddict-item]
            documents=documents,  # type: ignore[typeddict-item]
            uris=uris,  # type: ignore[typeddict-item]
            data=None,
            metadatas=metadatas,  # type: ignore[typeddict-item]
            distances=distances,
            included=included,
        )

    def _metadata_segment(self, collection: Collection) -> MetadataReader:
        """Return the collection's metadata reader from the segment manager."""
        return self._manager.get_segment(collection.id, MetadataReader)

    def _vector_segment(self, collection: Collection) -> VectorReader:
        """Return the collection's vector reader from the segment manager."""
        return self._manager.get_segment(collection.id, VectorReader)
Memory