from typing import Dict, Optional, Sequence, Tuple, TypedDict, Union, cast
from uuid import UUID

import numpy as np
from numpy.typing import NDArray

import chromadb.proto.chroma_pb2 as chroma_pb
import chromadb.proto.query_executor_pb2 as query_pb
from chromadb.api.configuration import CollectionConfigurationInternal
from chromadb.api.types import Embedding, Where, WhereDocument
from chromadb.execution.expression.operator import (
    KNN,
    Filter,
    Limit,
    Projection,
    Scan,
)
from chromadb.execution.expression.plan import CountPlan, GetPlan, KNNPlan
from chromadb.types import (
    Collection,
    LogRecord,
    Metadata,
    Operation,
    OperationRecord,
    RequestVersionContext,
    ScalarEncoding,
    Segment,
    SegmentScope,
    SeqId,
    UpdateMetadata,
    Vector,
    VectorEmbeddingRecord,
    VectorQueryResult,
)


class ProjectionRecord(TypedDict):
    """A single record decoded from a ``query_pb.ProjectionRecord``."""

    id: str
    document: Optional[str]
    embedding: Optional[Vector]
    metadata: Optional[Metadata]


class KNNProjectionRecord(TypedDict):
    """A projected record paired with its (optional) KNN distance."""

    record: ProjectionRecord
    distance: Optional[float]


# TODO: Unit tests for this file, handling optional states etc
def to_proto_vector(vector: Vector, encoding: ScalarEncoding) -> chroma_pb.Vector:
    """Serialize ``vector`` into a ``chroma_pb.Vector`` using ``encoding``.

    Args:
        vector: The values to serialize. Anything ``np.array`` accepts works.
        encoding: ``ScalarEncoding.FLOAT32`` or ``ScalarEncoding.INT32``.

    Returns:
        A proto vector carrying the raw bytes, the element count and the
        proto-side encoding.

    Raises:
        ValueError: If ``encoding`` is not one of the supported encodings.
    """
    as_array: Union[NDArray[np.int32], NDArray[np.float32]]
    if encoding == ScalarEncoding.FLOAT32:
        as_array = np.array(vector, dtype=np.float32)
        proto_encoding = chroma_pb.ScalarEncoding.FLOAT32
    elif encoding == ScalarEncoding.INT32:
        as_array = np.array(vector, dtype=np.int32)
        proto_encoding = chroma_pb.ScalarEncoding.INT32
    else:
        raise ValueError(
            f"Unknown encoding {encoding}, expected one of "
            f"{ScalarEncoding.FLOAT32} or {ScalarEncoding.INT32}"
        )

    # Take the dimension from the converted array rather than from `vector`
    # itself, so plain sequences (which have no `.size`) are handled as well.
    return chroma_pb.Vector(
        dimension=as_array.size,
        vector=as_array.tobytes(),
        encoding=proto_encoding,
    )


def from_proto_vector(vector: chroma_pb.Vector) -> Tuple[Embedding, ScalarEncoding]:
    """Decode a ``chroma_pb.Vector`` into an array and its scalar encoding.

    Returns:
        A ``(array, encoding)`` tuple; the array is built with
        ``np.frombuffer`` over the proto's byte payload.

    Raises:
        ValueError: If the proto carries an unsupported encoding.
    """
    encoding = vector.encoding
    as_array: Union[NDArray[np.int32], NDArray[np.float32]]
    if encoding == chroma_pb.ScalarEncoding.FLOAT32:
        as_array = np.frombuffer(vector.vector, dtype=np.float32)
        out_encoding = ScalarEncoding.FLOAT32
    elif encoding == chroma_pb.ScalarEncoding.INT32:
        as_array = np.frombuffer(vector.vector, dtype=np.int32)
        out_encoding = ScalarEncoding.INT32
    else:
        raise ValueError(
            f"Unknown encoding {encoding}, expected one of "
            f"{chroma_pb.ScalarEncoding.FLOAT32} or "
            f"{chroma_pb.ScalarEncoding.INT32}"
        )

    return (as_array, out_encoding)


def from_proto_operation(operation: chroma_pb.Operation) -> Operation:
    """Map a proto ``Operation`` enum value onto the internal ``Operation``.

    Raises:
        RuntimeError: If the proto value is not ADD, UPDATE, UPSERT or DELETE.
    """
    if operation == chroma_pb.Operation.ADD:
        return Operation.ADD
    elif operation == chroma_pb.Operation.UPDATE:
        return Operation.UPDATE
    elif operation == chroma_pb.Operation.UPSERT:
        return Operation.UPSERT
    elif operation == chroma_pb.Operation.DELETE:
        return Operation.DELETE
    else:
        # TODO: full error
        raise RuntimeError(f"Unknown operation {operation}")


def from_proto_metadata(metadata: chroma_pb.UpdateMetadata) -> Optional[Metadata]:
    """Decode proto metadata, rejecting entries that carry no value.

    Returns ``None`` when the proto metadata map is empty.
    """
    return cast(Optional[Metadata], _from_proto_metadata_handle_none(metadata, False))


def from_proto_update_metadata(
    metadata: chroma_pb.UpdateMetadata,
) -> Optional[UpdateMetadata]:
    """Decode proto update-metadata, mapping value-less entries to ``None``.

    Returns ``None`` when the proto metadata map is empty.
    """
    return cast(
        Optional[UpdateMetadata], _from_proto_metadata_handle_none(metadata, True)
    )


def _from_proto_metadata_handle_none(
    metadata: chroma_pb.UpdateMetadata, is_update: bool
) -> Optional[Union[UpdateMetadata, Metadata]]:
    """Shared decoder behind the two public metadata converters.

    Args:
        metadata: The proto metadata message to decode.
        is_update: When true, an entry with none of the typed value fields
            set decodes to ``None``; when false such an entry is an error.

    Returns:
        The decoded mapping, or ``None`` if the proto map is empty.

    Raises:
        ValueError: If ``is_update`` is false and an entry has no value set.
    """
    if not metadata.metadata:
        return None

    out_metadata: Dict[str, Union[str, int, float, bool, None]] = {}
    for key, value in metadata.metadata.items():
        if value.HasField("bool_value"):
            out_metadata[key] = value.bool_value
        elif value.HasField("string_value"):
            out_metadata[key] = value.string_value
        elif value.HasField("int_value"):
            out_metadata[key] = value.int_value
        elif value.HasField("float_value"):
            out_metadata[key] = value.float_value
        elif is_update:
            out_metadata[key] = None
        else:
            raise ValueError(f"Metadata key {key} value cannot be None")
    return out_metadata


def to_proto_update_metadata(metadata: UpdateMetadata) -> chroma_pb.UpdateMetadata:
    """Encode a metadata mapping as a ``chroma_pb.UpdateMetadata`` message."""
    return chroma_pb.UpdateMetadata(
        metadata={k: to_proto_metadata_update_value(v) for k, v in metadata.items()}
    )


def from_proto_submit(
    operation_record: chroma_pb.OperationRecord, seq_id: SeqId
) -> LogRecord:
    """Decode a proto operation record into a ``LogRecord`` at ``seq_id``."""
    embedding, encoding = from_proto_vector(operation_record.vector)
    record = LogRecord(
        log_offset=seq_id,
        record=OperationRecord(
            id=operation_record.id,
            embedding=embedding,
            encoding=encoding,
            metadata=from_proto_update_metadata(operation_record.metadata),
            operation=from_proto_operation(operation_record.operation),
        ),
    )
    return record


def from_proto_segment(segment: chroma_pb.Segment) -> Segment:
    """Decode a ``chroma_pb.Segment`` into the internal ``Segment`` type.

    The segment and collection ids are parsed from their hex form; metadata
    is decoded only when the proto has the ``metadata`` field set.
    """
    return Segment(
        id=UUID(hex=segment.id),
        type=segment.type,
        scope=from_proto_segment_scope(segment.scope),
        collection=UUID(hex=segment.collection),
        metadata=from_proto_metadata(segment.metadata)
        if segment.HasField("metadata")
        else None,
        file_paths={
            name: list(paths.paths) for name, paths in segment.file_paths.items()
        },
    )


def to_proto_segment(segment: Segment) -> chroma_pb.Segment:
    """Encode an internal ``Segment`` as a ``chroma_pb.Segment`` message."""
    return chroma_pb.Segment(
        id=segment["id"].hex,
        type=segment["type"],
        scope=to_proto_segment_scope(segment["scope"]),
        collection=segment["collection"].hex,
        metadata=None
        if segment["metadata"] is None
        else to_proto_update_metadata(segment["metadata"]),
        file_paths={
            name: chroma_pb.FilePaths(paths=paths)
            for name, paths in segment["file_paths"].items()
        },
    )


def from_proto_segment_scope(segment_scope: chroma_pb.SegmentScope) -> SegmentScope:
    """Map a proto ``SegmentScope`` onto the internal ``SegmentScope``.

    Raises:
        RuntimeError: If the scope is not VECTOR, METADATA or RECORD.
    """
    if segment_scope == chroma_pb.SegmentScope.VECTOR:
        return SegmentScope.VECTOR
    elif segment_scope == chroma_pb.SegmentScope.METADATA:
        return SegmentScope.METADATA
    elif segment_scope == chroma_pb.SegmentScope.RECORD:
        return SegmentScope.RECORD
    else:
        raise RuntimeError(f"Unknown segment scope {segment_scope}")


def to_proto_segment_scope(segment_scope: SegmentScope) -> chroma_pb.SegmentScope:
    """Map an internal ``SegmentScope`` onto the proto ``SegmentScope``.

    Raises:
        RuntimeError: If the scope is not VECTOR, METADATA or RECORD.
    """
    if segment_scope == SegmentScope.VECTOR:
        return chroma_pb.SegmentScope.VECTOR
    elif segment_scope == SegmentScope.METADATA:
        return chroma_pb.SegmentScope.METADATA
    elif segment_scope == SegmentScope.RECORD:
        return chroma_pb.SegmentScope.RECORD
    else:
        raise RuntimeError(f"Unknown segment scope {segment_scope}")


def to_proto_metadata_update_value(
    value: Union[str, int, float, bool, None]
) -> chroma_pb.UpdateMetadataValue:
    """Encode one metadata value as a ``chroma_pb.UpdateMetadataValue``.

    ``None`` is encoded as a message with no value field set.

    Raises:
        ValueError: If ``value`` is not a bool, str, int, float or ``None``.
    """
    # Be careful with the order here. Since bools are a subtype of int in python,
    # isinstance(value, bool) and isinstance(value, int) both return true
    # for a value of bool type.
    if isinstance(value, bool):
        return chroma_pb.UpdateMetadataValue(bool_value=value)
    elif isinstance(value, str):
        return chroma_pb.UpdateMetadataValue(string_value=value)
    elif isinstance(value, int):
        return chroma_pb.UpdateMetadataValue(int_value=value)
    elif isinstance(value, float):
        return chroma_pb.UpdateMetadataValue(float_value=value)
    # None is used to delete the metadata key.
    elif value is None:
        return chroma_pb.UpdateMetadataValue()
    else:
        raise ValueError(
            f"Unknown metadata value type {type(value)}, expected one of "
            "str, int, float, bool, or None"
        )


def from_proto_collection(collection: chroma_pb.Collection) -> Collection:
    """Decode a ``chroma_pb.Collection`` into the internal ``Collection``.

    Metadata is decoded only when the proto has it set; ``dimension`` becomes
    ``None`` when the field is unset or falsy (e.g. zero).
    """
    return Collection(
        id=UUID(hex=collection.id),
        name=collection.name,
        configuration=CollectionConfigurationInternal.from_json_str(
            collection.configuration_json_str
        ),
        metadata=from_proto_metadata(collection.metadata)
        if collection.HasField("metadata")
        else None,
        dimension=collection.dimension
        if collection.HasField("dimension") and collection.dimension
        else None,
        database=collection.database,
        tenant=collection.tenant,
        version=collection.version,
        log_position=collection.log_position,
    )


def to_proto_collection(collection: Collection) -> chroma_pb.Collection:
    """Encode an internal ``Collection`` as a ``chroma_pb.Collection``."""
    return chroma_pb.Collection(
        id=collection["id"].hex,
        name=collection["name"],
        configuration_json_str=collection.get_configuration().to_json_str(),
        metadata=None
        if collection["metadata"] is None
        else to_proto_update_metadata(collection["metadata"]),
        dimension=collection["dimension"],
        tenant=collection["tenant"],
        database=collection["database"],
        log_position=collection["log_position"],
        version=collection["version"],
    )


def to_proto_operation(operation: Operation) -> chroma_pb.Operation:
    """Map an internal ``Operation`` onto the proto ``Operation`` enum.

    Raises:
        ValueError: If the operation is not ADD, UPDATE, UPSERT or DELETE.
    """
    if operation == Operation.ADD:
        return chroma_pb.Operation.ADD
    elif operation == Operation.UPDATE:
        return chroma_pb.Operation.UPDATE
    elif operation == Operation.UPSERT:
        return chroma_pb.Operation.UPSERT
    elif operation == Operation.DELETE:
        return chroma_pb.Operation.DELETE
    else:
        raise ValueError(
            f"Unknown operation {operation}, expected one of {Operation.ADD}, "
            f"{Operation.UPDATE}, {Operation.UPSERT}, or {Operation.DELETE}"
        )


def to_proto_submit(
    submit_record: OperationRecord,
) -> chroma_pb.OperationRecord:
    """Encode an ``OperationRecord`` as a ``chroma_pb.OperationRecord``.

    The vector is encoded only when both the embedding and its encoding are
    present; metadata is encoded only when it is not ``None``.
    """
    vector = None
    if submit_record["embedding"] is not None and submit_record["encoding"] is not None:
        vector = to_proto_vector(submit_record["embedding"], submit_record["encoding"])

    metadata = None
    if submit_record["metadata"] is not None:
        metadata = to_proto_update_metadata(submit_record["metadata"])

    return chroma_pb.OperationRecord(
        id=submit_record["id"],
        vector=vector,
        metadata=metadata,
        operation=to_proto_operation(submit_record["operation"]),
    )


def from_proto_vector_embedding_record(
    embedding_record: chroma_pb.VectorEmbeddingRecord,
) -> VectorEmbeddingRecord:
    """Decode a proto embedding record; the scalar encoding is discarded."""
    return VectorEmbeddingRecord(
        id=embedding_record.id,
        embedding=from_proto_vector(embedding_record.vector)[0],
    )


def to_proto_vector_embedding_record(
    embedding_record: VectorEmbeddingRecord,
    encoding: ScalarEncoding,
) -> chroma_pb.VectorEmbeddingRecord:
    """Encode an embedding record, serializing its vector with ``encoding``."""
    return chroma_pb.VectorEmbeddingRecord(
        id=embedding_record["id"],
        vector=to_proto_vector(embedding_record["embedding"], encoding),
    )


def from_proto_vector_query_result(
    vector_query_result: chroma_pb.VectorQueryResult,
) -> VectorQueryResult:
    """Decode a proto vector query result; the scalar encoding is discarded."""
    return VectorQueryResult(
        id=vector_query_result.id,
        distance=vector_query_result.distance,
        embedding=from_proto_vector(vector_query_result.vector)[0],
    )


def from_proto_request_version_context(
    request_version_context: chroma_pb.RequestVersionContext,
) -> RequestVersionContext:
    """Decode a proto request-version context into the internal type."""
    return RequestVersionContext(
        collection_version=request_version_context.collection_version,
        log_position=request_version_context.log_position,
    )


def to_proto_request_version_context(
    request_version_context: RequestVersionContext,
) -> chroma_pb.RequestVersionContext:
    """Encode an internal request-version context as its proto message."""
    return chroma_pb.RequestVersionContext(
        collection_version=request_version_context["collection_version"],
        log_position=request_version_context["log_position"],
    )


def to_proto_where(where: Where) -> chroma_pb.Where:
    """Encode a ``where`` filter expression as a ``chroma_pb.Where`` message.

    ``where`` must contain exactly one entry, which is either a boolean
    combinator (``{"$and": [...]}`` / ``{"$or": [...]}``, encoded recursively)
    or a direct comparison on a metadata key, written as ``{"key": value}``
    or ``{"key": {"$operator": operand}}``.

    Raises:
        ValueError: If the expression does not have exactly one entry, a key
            is not a string, or a value/operand/operator combination is not
            supported.
    """
    response = chroma_pb.Where()
    if len(where) != 1:
        raise ValueError(f"Expected where to have exactly one operator, got {where}")

    for key, value in where.items():
        if not isinstance(key, str):
            raise ValueError(f"Expected where key to be a str, got {key}")

        if key == "$and" or key == "$or":
            if not isinstance(value, list):
                raise ValueError(
                    "Expected where value for $and or $or to be a list of "
                    f"where expressions, got {value}"
                )
            children: chroma_pb.WhereChildren = chroma_pb.WhereChildren(
                children=[to_proto_where(w) for w in value]
            )
            if key == "$and":
                children.operator = chroma_pb.BooleanOperator.AND
            else:
                children.operator = chroma_pb.BooleanOperator.OR

            response.children.CopyFrom(children)
            return response

        # At this point we know we're at a direct comparison. It can either
        # be of the form {"key": "value"} or {"key": {"$operator": "value"}}.
        dc = chroma_pb.DirectComparison()
        dc.key = key

        if not isinstance(value, dict):
            # {'key': 'value'} case: an implicit equality comparison.
            # Exact `type(...) is` checks keep bool and int apart, since bool
            # is a subclass of int.
            if type(value) is str:
                ssc = chroma_pb.SingleStringComparison()
                ssc.value = value
                ssc.comparator = chroma_pb.GenericComparator.EQ
                dc.single_string_operand.CopyFrom(ssc)
            elif type(value) is bool:
                sbc = chroma_pb.SingleBoolComparison()
                sbc.value = value
                sbc.comparator = chroma_pb.GenericComparator.EQ
                dc.single_bool_operand.CopyFrom(sbc)
            elif type(value) is int:
                sic = chroma_pb.SingleIntComparison()
                sic.value = value
                sic.generic_comparator = chroma_pb.GenericComparator.EQ
                dc.single_int_operand.CopyFrom(sic)
            elif type(value) is float:
                sdc = chroma_pb.SingleDoubleComparison()
                sdc.value = value
                sdc.generic_comparator = chroma_pb.GenericComparator.EQ
                dc.single_double_operand.CopyFrom(sdc)
            else:
                raise ValueError(
                    f"Expected where value to be a string, int, or float, got {value}"
                )
        else:
            for operator, operand in value.items():
                if operator in ["$in", "$nin"]:
                    if not isinstance(operand, list):
                        raise ValueError(
                            "Expected where value for $in or $nin to be a list "
                            f"of values, got {value}"
                        )
                    if len(operand) == 0 or not all(
                        isinstance(x, type(operand[0])) for x in operand
                    ):
                        raise ValueError(
                            "Expected where operand value to be a non-empty list, "
                            "and all values to be of the same type "
                            f"got {operand}"
                        )
                    list_operator = (
                        chroma_pb.ListOperator.IN
                        if operator == "$in"
                        else chroma_pb.ListOperator.NIN
                    )
                    # The first element's exact type selects the proto list type.
                    if type(operand[0]) is str:
                        slo = chroma_pb.StringListComparison()
                        slo.values.extend(operand)  # type: ignore
                        slo.list_operator = list_operator
                        dc.string_list_operand.CopyFrom(slo)
                    elif type(operand[0]) is bool:
                        blo = chroma_pb.BoolListComparison()
                        blo.values.extend(operand)  # type: ignore
                        blo.list_operator = list_operator
                        dc.bool_list_operand.CopyFrom(blo)
                    elif type(operand[0]) is int:
                        ilo = chroma_pb.IntListComparison()
                        ilo.values.extend(operand)  # type: ignore
                        ilo.list_operator = list_operator
                        dc.int_list_operand.CopyFrom(ilo)
                    elif type(operand[0]) is float:
                        dlo = chroma_pb.DoubleListComparison()
                        dlo.values.extend(operand)  # type: ignore
                        dlo.list_operator = list_operator
                        dc.double_list_operand.CopyFrom(dlo)
                    else:
                        raise ValueError(
                            "Expected where operand value to be a list of "
                            f"strings, ints, or floats, got {operand}"
                        )
                elif operator in ["$eq", "$ne", "$gt", "$lt", "$gte", "$lte"]:
                    # Direct comparison to a single value.
                    if type(operand) is str:
                        ssc = chroma_pb.SingleStringComparison()
                        ssc.value = operand
                        if operator == "$eq":
                            ssc.comparator = chroma_pb.GenericComparator.EQ
                        elif operator == "$ne":
                            ssc.comparator = chroma_pb.GenericComparator.NE
                        else:
                            raise ValueError(
                                "Expected where operator to be $eq or $ne, "
                                f"got {operator}"
                            )
                        dc.single_string_operand.CopyFrom(ssc)
                    elif type(operand) is bool:
                        sbc = chroma_pb.SingleBoolComparison()
                        sbc.value = operand
                        if operator == "$eq":
                            sbc.comparator = chroma_pb.GenericComparator.EQ
                        elif operator == "$ne":
                            sbc.comparator = chroma_pb.GenericComparator.NE
                        else:
                            raise ValueError(
                                "Expected where operator to be $eq or $ne, "
                                f"got {operator}"
                            )
                        dc.single_bool_operand.CopyFrom(sbc)
                    elif type(operand) is int:
                        sic = chroma_pb.SingleIntComparison()
                        sic.value = operand
                        if operator == "$eq":
                            sic.generic_comparator = chroma_pb.GenericComparator.EQ
                        elif operator == "$ne":
                            sic.generic_comparator = chroma_pb.GenericComparator.NE
                        elif operator == "$gt":
                            sic.number_comparator = chroma_pb.NumberComparator.GT
                        elif operator == "$lt":
                            sic.number_comparator = chroma_pb.NumberComparator.LT
                        elif operator == "$gte":
                            sic.number_comparator = chroma_pb.NumberComparator.GTE
                        elif operator == "$lte":
                            sic.number_comparator = chroma_pb.NumberComparator.LTE
                        else:
                            raise ValueError(
                                "Expected where operator to be one of $eq, $ne, "
                                f"$gt, $lt, $gte, $lte, got {operator}"
                            )
                        dc.single_int_operand.CopyFrom(sic)
                    elif type(operand) is float:
                        sfc = chroma_pb.SingleDoubleComparison()
                        sfc.value = operand
                        if operator == "$eq":
                            sfc.generic_comparator = chroma_pb.GenericComparator.EQ
                        elif operator == "$ne":
                            sfc.generic_comparator = chroma_pb.GenericComparator.NE
                        elif operator == "$gt":
                            sfc.number_comparator = chroma_pb.NumberComparator.GT
                        elif operator == "$lt":
                            sfc.number_comparator = chroma_pb.NumberComparator.LT
                        elif operator == "$gte":
                            sfc.number_comparator = chroma_pb.NumberComparator.GTE
                        elif operator == "$lte":
                            sfc.number_comparator = chroma_pb.NumberComparator.LTE
                        else:
                            raise ValueError(
                                "Expected where operator to be one of $eq, $ne, "
                                f"$gt, $lt, $gte, $lte, got {operator}"
                            )
                        dc.single_double_operand.CopyFrom(sfc)
                    else:
                        raise ValueError(
                            "Expected where operand value to be a string, int, "
                            f"or float, got {operand}"
                        )
                else:
                    # This case should never happen, as we've already
                    # handled the case for direct comparisons.
                    # NOTE(review): an unrecognized operator is silently
                    # skipped here, leaving `dc` without an operand for it —
                    # confirm callers validate operators before this point.
                    pass

        response.direct_comparison.CopyFrom(dc)
    return response


def to_proto_where_document(where_document: WhereDocument) -> chroma_pb.WhereDocument:
    """Encode a ``where_document`` filter as a ``chroma_pb.WhereDocument``.

    ``where_document`` must contain exactly one entry: either a boolean
    combinator (``$and`` / ``$or`` over a list of nested expressions, encoded
    recursively) or a direct ``$contains`` / ``$not_contains`` comparison
    against a string.

    Raises:
        ValueError: If the expression does not have exactly one entry, or an
            operand or operator is not supported.
    """
    response = chroma_pb.WhereDocument()
    if len(where_document) != 1:
        raise ValueError(
            "Expected where_document to have exactly one operator, "
            f"got {where_document}"
        )

    for operator, operand in where_document.items():
        if operator == "$and" or operator == "$or":
            # Nested "$and" or "$or" expression.
            if not isinstance(operand, list):
                raise ValueError(
                    "Expected where_document value for $and or $or to be a list "
                    f"of where_document expressions, got {operand}"
                )
            children: chroma_pb.WhereDocumentChildren = chroma_pb.WhereDocumentChildren(
                children=[to_proto_where_document(w) for w in operand]
            )
            if operator == "$and":
                children.operator = chroma_pb.BooleanOperator.AND
            else:
                children.operator = chroma_pb.BooleanOperator.OR

            response.children.CopyFrom(children)
        else:
            # Direct "$contains" or "$not_contains" comparison to a single
            # value.
            if not isinstance(operand, str):
                raise ValueError(
                    f"Expected where_document operand to be a string, got {operand}"
                )

            dwd = chroma_pb.DirectWhereDocument()
            dwd.document = operand
            if operator == "$contains":
                dwd.operator = chroma_pb.WhereDocumentOperator.CONTAINS
            elif operator == "$not_contains":
                dwd.operator = chroma_pb.WhereDocumentOperator.NOT_CONTAINS
            else:
                raise ValueError(
                    "Expected where_document operator to be one of $contains, "
                    f"$not_contains, got {operator}"
                )

            response.direct.CopyFrom(dwd)

    return response


def to_proto_scan(scan: Scan) -> query_pb.ScanOperator:
    """Encode a ``Scan`` (collection plus its three segments) as a proto."""
    return query_pb.ScanOperator(
        collection=to_proto_collection(scan.collection),
        knn=to_proto_segment(scan.knn),
        metadata=to_proto_segment(scan.metadata),
        record=to_proto_segment(scan.record),
    )


def to_proto_filter(filter: Filter) -> query_pb.FilterOperator:
    """Encode a ``Filter`` as a ``query_pb.FilterOperator``.

    ``ids`` is sent whenever ``user_ids`` is not ``None`` (an empty list is
    still sent); ``where`` and ``where_document`` are sent only when truthy.
    """
    return query_pb.FilterOperator(
        ids=chroma_pb.UserIds(ids=filter.user_ids)
        if filter.user_ids is not None
        else None,
        where=to_proto_where(filter.where) if filter.where else None,
        where_document=to_proto_where_document(filter.where_document)
        if filter.where_document
        else None,
    )


def to_proto_knn(knn: KNN) -> query_pb.KNNOperator:
    """Encode a ``KNN`` operator; every query embedding is sent as FLOAT32."""
    return query_pb.KNNOperator(
        embeddings=[
            to_proto_vector(vector=embedding, encoding=ScalarEncoding.FLOAT32)
            for embedding in knn.embeddings
        ],
        fetch=knn.fetch,
    )


def to_proto_limit(limit: Limit) -> query_pb.LimitOperator:
    """Encode a ``Limit`` (skip / fetch) as a ``query_pb.LimitOperator``."""
    return query_pb.LimitOperator(skip=limit.skip, fetch=limit.fetch)


def to_proto_projection(projection: Projection) -> query_pb.ProjectionOperator:
    """Encode the document / embedding / metadata flags of a ``Projection``."""
    return query_pb.ProjectionOperator(
        document=projection.document,
        embedding=projection.embedding,
        metadata=projection.metadata,
    )


def to_proto_knn_projection(projection: Projection) -> query_pb.KNNProjectionOperator:
    """Encode a ``Projection`` for KNN; ``rank`` is sent as ``distance``."""
    return query_pb.KNNProjectionOperator(
        projection=to_proto_projection(projection), distance=projection.rank
    )


def to_proto_count_plan(count: CountPlan) -> query_pb.CountPlan:
    """Encode a ``CountPlan`` as a ``query_pb.CountPlan``."""
    return query_pb.CountPlan(scan=to_proto_scan(count.scan))


def from_proto_count_result(result: query_pb.CountResult) -> int:
    """Return the count carried by a ``query_pb.CountResult``."""
    return result.count


def to_proto_get_plan(get: GetPlan) -> query_pb.GetPlan:
    """Encode a ``GetPlan`` (scan, filter, limit, projection) as a proto."""
    return query_pb.GetPlan(
        scan=to_proto_scan(get.scan),
        filter=to_proto_filter(get.filter),
        limit=to_proto_limit(get.limit),
        projection=to_proto_projection(get.projection),
    )


def from_proto_projection_record(record: query_pb.ProjectionRecord) -> ProjectionRecord:
    """Decode a ``query_pb.ProjectionRecord`` into a ``ProjectionRecord``.

    An empty document, an unset embedding and empty metadata all decode to
    ``None``.
    """
    return ProjectionRecord(
        id=record.id,
        document=record.document if record.document else None,
        # Reading an unset protobuf sub-message yields a default instance,
        # never None, so presence has to be checked with HasField.
        embedding=from_proto_vector(record.embedding)[0]
        if record.HasField("embedding")
        else None,
        metadata=from_proto_metadata(record.metadata),
    )


def from_proto_get_result(result: query_pb.GetResult) -> Sequence[ProjectionRecord]:
    """Decode every record of a ``query_pb.GetResult``."""
    return [from_proto_projection_record(record) for record in result.records]


def to_proto_knn_plan(knn: KNNPlan) -> query_pb.KNNPlan:
    """Encode a ``KNNPlan`` (scan, filter, knn, projection) as a proto."""
    return query_pb.KNNPlan(
        scan=to_proto_scan(knn.scan),
        filter=to_proto_filter(knn.filter),
        knn=to_proto_knn(knn.knn),
        projection=to_proto_knn_projection(knn.projection),
    )


def from_proto_knn_projection_record(
    record: query_pb.KNNProjectionRecord,
) -> KNNProjectionRecord:
    """Decode a KNN projection record together with its distance."""
    return KNNProjectionRecord(
        record=from_proto_projection_record(record.record), distance=record.distance
    )


def from_proto_knn_batch_result(
    results: query_pb.KNNBatchResult,
) -> Sequence[Sequence[KNNProjectionRecord]]:
    """Decode a KNN batch result into one list of records per result."""
    return [
        [from_proto_knn_projection_record(record) for record in result.records]
        for result in results.results
    ]
Memory