from tenacity import retry, stop_after_attempt, retry_if_exception, wait_fixed
from chromadb.api import ServerAPI
from chromadb.api.configuration import CollectionConfigurationInternal
from chromadb.auth import UserIdentity
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings, System
from chromadb.db.system import SysDB
from chromadb.quota import QuotaEnforcer, Action
from chromadb.rate_limit import RateLimitEnforcer, AsyncRateLimitEnforcer
from chromadb.segment import SegmentManager
from chromadb.execution.executor.abstract import Executor
from chromadb.execution.expression.operator import Scan, Filter, Limit, KNN, Projection
from chromadb.execution.expression.plan import CountPlan, GetPlan, KNNPlan
from chromadb.telemetry.opentelemetry import (
    add_attributes_to_current_span,
    OpenTelemetryClient,
    OpenTelemetryGranularity,
    trace_method,
)
from chromadb.telemetry.product import ProductTelemetryClient
from chromadb.ingest import Producer
from chromadb.types import Collection as CollectionModel
from chromadb import __version__
from chromadb.errors import (
    InvalidDimensionException,
    InvalidCollectionException,
    VersionMismatchError,
)
from chromadb.api.types import (
    CollectionMetadata,
    IDs,
    Embeddings,
    Metadatas,
    Documents,
    URIs,
    Where,
    WhereDocument,
    Include,
    IncludeEnum,
    GetResult,
    QueryResult,
    validate_metadata,
    validate_update_metadata,
    validate_where,
    validate_where_document,
    validate_batch,
)
from chromadb.telemetry.product.events import (
    CollectionAddEvent,
    CollectionDeleteEvent,
    CollectionGetEvent,
    CollectionUpdateEvent,
    CollectionQueryEvent,
    ClientCreateCollectionEvent,
)

import chromadb.types as t

from typing import (
    Optional,
    Sequence,
    Generator,
    List,
    Any,
    Callable,
    TypeVar,
)
from overrides import override
from uuid import UUID, uuid4
from functools import wraps
import time
import logging
import re

T = TypeVar("T", bound=Callable[..., Any])

logger = logging.getLogger(__name__)


# mimics s3 bucket requirements for naming
def check_index_name(index_name: str) -> None:
    """Validate a collection name against S3-bucket-like naming rules.

    Raises:
        ValueError: if ``index_name`` violates any of the rules spelled out in
            the error message.
    """
    msg = (
        "Expected collection name that "
        "(1) contains 3-63 characters, "
        "(2) starts and ends with an alphanumeric character, "
        "(3) otherwise contains only alphanumeric characters, underscores, "
        "hyphens (-) or periods (.), "
        "(4) contains no two consecutive periods (..) and "
        "(5) is not a valid IPv4 address, "
        f"got {index_name}"
    )
    if len(index_name) < 3 or len(index_name) > 63:
        raise ValueError(msg)
    # fullmatch (rather than match + "$") so a trailing newline is rejected:
    # "$" also matches immediately before a final "\n".
    if not re.fullmatch(r"[a-zA-Z0-9][a-zA-Z0-9._-]*[a-zA-Z0-9]", index_name):
        raise ValueError(msg)
    if ".." in index_name:
        raise ValueError(msg)
    # Any dotted quad of 1-3 digit groups is rejected (not only in-range IPs).
    if re.fullmatch(r"[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}", index_name):
        raise ValueError(msg)


def rate_limit(func: T) -> T:
    """Route a method call through the instance's ``_rate_limit_enforcer``.

    The enforcer is looked up on ``self`` (the first positional argument) at
    call time, which lets the decorator be applied at class-definition time.
    """

    @wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        self = args[0]
        return self._rate_limit_enforcer.rate_limit(func)(*args, **kwargs)

    return wrapper  # type: ignore


class SegmentAPI(ServerAPI):
    """API implementation utilizing the new segment-based internal architecture"""

    _settings: Settings
    _sysdb: SysDB
    _manager: SegmentManager
    _executor: Executor
    _producer: Producer
    _quota_enforcer: QuotaEnforcer
    _product_telemetry_client: ProductTelemetryClient
    _opentelemetry_client: OpenTelemetryClient
    _tenant_id: str
    _topic_ns: str
    _rate_limit_enforcer: RateLimitEnforcer

    def __init__(self, system: System):
        """Resolve all collaborating components from ``system``."""
        super().__init__(system)
        self._settings = system.settings
        self._sysdb = self.require(SysDB)
        self._manager = self.require(SegmentManager)
        self._executor = self.require(Executor)
        self._quota_enforcer = self.require(QuotaEnforcer)
        self._product_telemetry_client = self.require(ProductTelemetryClient)
        self._opentelemetry_client = self.require(OpenTelemetryClient)
        self._producer = self.require(Producer)
        self._rate_limit_enforcer = self._system.require(RateLimitEnforcer)

    @override
    def heartbeat(self) -> int:
        """Return the current wall-clock time in nanoseconds."""
        return int(time.time_ns())

    @trace_method("SegmentAPI.create_database", OpenTelemetryGranularity.OPERATION)
    @override
    def create_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None:
        """Create a database named ``name`` under ``tenant``.

        Raises:
            ValueError: if ``name`` is shorter than 3 characters.
        """
        if len(name) < 3:
            raise ValueError("Database name must be at least 3 characters long")

        self._quota_enforcer.enforce(
            action=Action.CREATE_DATABASE,
            tenant=tenant,
            name=name,
        )

        self._sysdb.create_database(
            id=uuid4(),
            name=name,
            tenant=tenant,
        )

    @trace_method("SegmentAPI.get_database", OpenTelemetryGranularity.OPERATION)
    @override
    def get_database(self, name: str, tenant: str = DEFAULT_TENANT) -> t.Database:
        """Fetch a database by name from the sysdb."""
        return self._sysdb.get_database(name=name, tenant=tenant)

    @trace_method("SegmentAPI.delete_database", OpenTelemetryGranularity.OPERATION)
    @override
    def delete_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None:
        """Delete a database by name via the sysdb."""
        self._sysdb.delete_database(name=name, tenant=tenant)

    @trace_method("SegmentAPI.list_databases", OpenTelemetryGranularity.OPERATION)
    @override
    def list_databases(
        self,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        tenant: str = DEFAULT_TENANT,
    ) -> Sequence[t.Database]:
        """List the databases of ``tenant`` with optional pagination."""
        return self._sysdb.list_databases(limit=limit, offset=offset, tenant=tenant)

    @trace_method("SegmentAPI.create_tenant", OpenTelemetryGranularity.OPERATION)
    @override
    def create_tenant(self, name: str) -> None:
        """Create a tenant.

        Raises:
            ValueError: if ``name`` is shorter than 3 characters.
        """
        if len(name) < 3:
            raise ValueError("Tenant name must be at least 3 characters long")

        self._sysdb.create_tenant(
            name=name,
        )

    @override
    def get_user_identity(self) -> UserIdentity:
        """Return a fixed identity bound to the default tenant and database."""
        return UserIdentity(
            user_id="",
            tenant=DEFAULT_TENANT,
            databases=[DEFAULT_DATABASE],
        )

    @trace_method("SegmentAPI.get_tenant", OpenTelemetryGranularity.OPERATION)
    @override
    def get_tenant(self, name: str) -> t.Tenant:
        """Fetch a tenant by name from the sysdb."""
        return self._sysdb.get_tenant(name=name)

    # TODO: Actually fix CollectionMetadata type to remove type: ignore flags. This is
    # necessary because changing the value type from `Any` to `Union[str, int, float]`
    # causes the system to somehow convert all values to strings.
    @trace_method("SegmentAPI.create_collection", OpenTelemetryGranularity.OPERATION)
    @override
    @rate_limit
    def create_collection(
        self,
        name: str,
        configuration: Optional[CollectionConfigurationInternal] = None,
        metadata: Optional[CollectionMetadata] = None,
        get_or_create: bool = False,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> CollectionModel:
        """Create a collection (or, with ``get_or_create``, fetch an existing one).

        Segments are only prepared and registered when the sysdb reports that
        the collection was newly created. Whether an existing name without
        ``get_or_create`` raises is decided by the sysdb (not visible here).

        Raises:
            ValueError: if ``name`` fails :func:`check_index_name`.
        """
        if metadata is not None:
            validate_metadata(metadata)

        # TODO: remove backwards compatibility in naming requirements
        check_index_name(name)

        self._quota_enforcer.enforce(
            action=Action.CREATE_COLLECTION,
            tenant=tenant,
            name=name,
            metadata=metadata,
        )

        new_id = uuid4()

        model = CollectionModel(
            id=new_id,
            name=name,
            metadata=metadata,
            configuration=configuration
            if configuration is not None
            else CollectionConfigurationInternal(),  # Use default configuration if none is provided
            tenant=tenant,
            database=database,
            dimension=None,
        )

        # TODO: Let sysdb create the collection directly from the model
        coll, created = self._sysdb.create_collection(
            id=model.id,
            name=model.name,
            configuration=model.get_configuration(),
            segments=[],  # Passing empty till backend changes are deployed.
            metadata=model.metadata,
            dimension=None,  # This is lazily populated on the first add
            get_or_create=get_or_create,
            tenant=tenant,
            database=database,
        )

        if created:
            segments = self._manager.prepare_segments_for_new_collection(coll)
            for segment in segments:
                self._sysdb.create_segment(segment)
        else:
            logger.debug(
                "Collection %s already exists, returning existing collection.", name
            )

        # Report the id of the collection actually returned: when an existing
        # collection is returned, `new_id` was never persisted.
        # TODO: This event still fires for the get_or_create (not created) case
        # TODO: Re-enable embedding function tracking in create_collection
        self._product_telemetry_client.capture(
            ClientCreateCollectionEvent(
                collection_uuid=str(coll.id),
                # embedding_function=embedding_function.__class__.__name__,
            )
        )
        add_attributes_to_current_span({"collection_uuid": str(coll.id)})

        return coll

    @trace_method(
        "SegmentAPI.get_or_create_collection", OpenTelemetryGranularity.OPERATION
    )
    @override
    @rate_limit
    def get_or_create_collection(
        self,
        name: str,
        configuration: Optional[CollectionConfigurationInternal] = None,
        metadata: Optional[CollectionMetadata] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> CollectionModel:
        """Delegate to :meth:`create_collection` with ``get_or_create=True``."""
        return self.create_collection(
            name=name,
            metadata=metadata,
            configuration=configuration,
            get_or_create=True,
            tenant=tenant,
            database=database,
        )

    # TODO: Actually fix CollectionMetadata type to remove type: ignore flags. This is
    # necessary because changing the value type from `Any` to `Union[str, int, float]`
    # causes the system to somehow convert all values to strings
    @trace_method("SegmentAPI.get_collection", OpenTelemetryGranularity.OPERATION)
    @override
    @rate_limit
    def get_collection(
        self,
        name: Optional[str] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> CollectionModel:
        """Return the first collection matching ``name``.

        Raises:
            InvalidCollectionException: if no collection matches.
        """
        existing = self._sysdb.get_collections(
            name=name, tenant=tenant, database=database
        )

        if existing:
            return existing[0]
        else:
            raise InvalidCollectionException(f"Collection {name} does not exist.")

    @trace_method("SegmentAPI.list_collection", OpenTelemetryGranularity.OPERATION)
    @override
    @rate_limit
    def list_collections(
        self,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> Sequence[CollectionModel]:
        """List collections in a database with optional pagination."""
        self._quota_enforcer.enforce(
            action=Action.LIST_COLLECTIONS,
            tenant=tenant,
            limit=limit,
        )
        return self._sysdb.get_collections(
            limit=limit, offset=offset, tenant=tenant, database=database
        )

    @trace_method("SegmentAPI.count_collections", OpenTelemetryGranularity.OPERATION)
    @override
    @rate_limit
    def count_collections(
        self,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> int:
        """Return the number of collections in a database.

        Note: this fetches every collection and counts them locally.
        """
        collection_count = len(
            self._sysdb.get_collections(tenant=tenant, database=database)
        )

        return collection_count

    @trace_method("SegmentAPI._modify", OpenTelemetryGranularity.OPERATION)
    @override
    @rate_limit
    def _modify(
        self,
        id: UUID,
        new_name: Optional[str] = None,
        new_metadata: Optional[CollectionMetadata] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> None:
        """Rename a collection and/or replace its metadata.

        Falsy ``new_name``/``new_metadata`` values (including ``{}``) are
        treated as "not provided".

        Raises:
            InvalidCollectionException: if the collection does not exist.
            ValueError: if ``new_name`` fails :func:`check_index_name`.
        """
        if new_name:
            # backwards compatibility in naming requirements (for now)
            check_index_name(new_name)

        if new_metadata:
            validate_update_metadata(new_metadata)

        # Ensure the collection exists
        _ = self._get_collection(id)

        self._quota_enforcer.enforce(
            action=Action.UPDATE_COLLECTION,
            tenant=tenant,
            name=new_name,
            metadata=new_metadata,
        )

        # TODO eventually we'll want to use OptionalArgument and Unspecified in the
        # signature of `_modify` but not changing the API right now.
        if new_name and new_metadata:
            self._sysdb.update_collection(id, name=new_name, metadata=new_metadata)
        elif new_name:
            self._sysdb.update_collection(id, name=new_name)
        elif new_metadata:
            self._sysdb.update_collection(id, metadata=new_metadata)

    @trace_method("SegmentAPI.delete_collection", OpenTelemetryGranularity.OPERATION)
    @override
    @rate_limit
    def delete_collection(
        self,
        name: str,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> None:
        """Delete a collection by name, then delete its segments.

        Raises:
            ValueError: if no collection with ``name`` exists.
        """
        existing = self._sysdb.get_collections(
            name=name, tenant=tenant, database=database
        )

        if existing:
            self._sysdb.delete_collection(
                existing[0].id, tenant=tenant, database=database
            )
            self._manager.delete_segments(existing[0].id)
        else:
            raise ValueError(f"Collection {name} does not exist.")

    @trace_method("SegmentAPI._add", OpenTelemetryGranularity.OPERATION)
    @override
    @rate_limit
    def _add(
        self,
        ids: IDs,
        collection_id: UUID,
        embeddings: Embeddings,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        uris: Optional[URIs] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> bool:
        """Validate and submit ADD records for ``ids`` to the producer."""
        coll = self._get_collection(collection_id)
        self._manager.hint_use_collection(collection_id, t.Operation.ADD)
        validate_batch(
            (ids, embeddings, metadatas, documents, uris),
            {"max_batch_size": self.get_max_batch_size()},
        )
        records_to_submit = list(
            _records(
                t.Operation.ADD,
                ids=ids,
                embeddings=embeddings,
                metadatas=metadatas,
                documents=documents,
                uris=uris,
            )
        )
        self._validate_embedding_record_set(coll, records_to_submit)
        self._quota_enforcer.enforce(
            action=Action.ADD,
            tenant=tenant,
            ids=ids,
            embeddings=embeddings,
            metadatas=metadatas,
            documents=documents,
            uris=uris,
        )
        self._producer.submit_embeddings(collection_id, records_to_submit)

        self._product_telemetry_client.capture(
            CollectionAddEvent(
                collection_uuid=str(collection_id),
                add_amount=len(ids),
                with_metadata=len(ids) if metadatas is not None else 0,
                with_documents=len(ids) if documents is not None else 0,
                with_uris=len(ids) if uris is not None else 0,
            )
        )
        return True

    @trace_method("SegmentAPI._update", OpenTelemetryGranularity.OPERATION)
    @override
    @rate_limit
    def _update(
        self,
        collection_id: UUID,
        ids: IDs,
        embeddings: Optional[Embeddings] = None,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        uris: Optional[URIs] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> bool:
        """Validate and submit UPDATE records for ``ids`` to the producer."""
        coll = self._get_collection(collection_id)
        self._manager.hint_use_collection(collection_id, t.Operation.UPDATE)
        validate_batch(
            (ids, embeddings, metadatas, documents, uris),
            {"max_batch_size": self.get_max_batch_size()},
        )
        records_to_submit = list(
            _records(
                t.Operation.UPDATE,
                ids=ids,
                embeddings=embeddings,
                metadatas=metadatas,
                documents=documents,
                uris=uris,
            )
        )
        self._validate_embedding_record_set(coll, records_to_submit)
        self._quota_enforcer.enforce(
            action=Action.UPDATE,
            tenant=tenant,
            ids=ids,
            embeddings=embeddings,
            metadatas=metadatas,
            documents=documents,
            uris=uris,
        )
        self._producer.submit_embeddings(collection_id, records_to_submit)

        self._product_telemetry_client.capture(
            CollectionUpdateEvent(
                collection_uuid=str(collection_id),
                update_amount=len(ids),
                with_embeddings=len(embeddings) if embeddings else 0,
                with_metadata=len(metadatas) if metadatas else 0,
                with_documents=len(documents) if documents else 0,
                with_uris=len(uris) if uris else 0,
            )
        )

        return True

    @trace_method("SegmentAPI._upsert", OpenTelemetryGranularity.OPERATION)
    @override
    @rate_limit
    def _upsert(
        self,
        collection_id: UUID,
        ids: IDs,
        embeddings: Embeddings,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        uris: Optional[URIs] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> bool:
        """Validate and submit UPSERT records for ``ids`` to the producer."""
        coll = self._get_collection(collection_id)
        self._manager.hint_use_collection(collection_id, t.Operation.UPSERT)
        validate_batch(
            (ids, embeddings, metadatas, documents, uris),
            {"max_batch_size": self.get_max_batch_size()},
        )
        records_to_submit = list(
            _records(
                t.Operation.UPSERT,
                ids=ids,
                embeddings=embeddings,
                metadatas=metadatas,
                documents=documents,
                uris=uris,
            )
        )
        self._validate_embedding_record_set(coll, records_to_submit)
        self._quota_enforcer.enforce(
            action=Action.UPSERT,
            tenant=tenant,
            ids=ids,
            embeddings=embeddings,
            metadatas=metadatas,
            documents=documents,
            uris=uris,
        )
        self._producer.submit_embeddings(collection_id, records_to_submit)

        return True

    @trace_method("SegmentAPI._get", OpenTelemetryGranularity.OPERATION)
    @retry(  # type: ignore[misc]
        retry=retry_if_exception(lambda e: isinstance(e, VersionMismatchError)),
        wait=wait_fixed(2),
        stop=stop_after_attempt(5),
        reraise=True,
    )
    @override
    @rate_limit
    def _get(
        self,
        collection_id: UUID,
        ids: Optional[IDs] = None,
        where: Optional[Where] = None,
        sort: Optional[str] = None,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        page: Optional[int] = None,
        page_size: Optional[int] = None,
        where_document: Optional[WhereDocument] = None,
        include: Include = ["embeddings", "metadatas", "documents"],  # type: ignore[list-item]
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> GetResult:
        """Fetch records matching the ids/where/where_document filters.

        ``page``/``page_size`` (when both are truthy) override ``offset`` and
        ``limit``. Retried up to 5 times, 2 seconds apart, on
        ``VersionMismatchError``.

        Raises:
            NotImplementedError: if ``sort`` is given.
        """
        add_attributes_to_current_span(
            {
                "collection_id": str(collection_id),
                "ids_count": len(ids) if ids else 0,
            }
        )

        scan = self._scan(collection_id)

        # TODO: Replace with unified validation
        if where is not None:
            validate_where(where)
        if where_document is not None:
            validate_where_document(where_document)

        self._quota_enforcer.enforce(
            action=Action.GET,
            tenant=tenant,
            ids=ids,
            where=where,
            where_document=where_document,
            limit=limit,
        )

        if sort is not None:
            raise NotImplementedError("Sorting is not yet supported")

        if page and page_size:
            offset = (page - 1) * page_size
            limit = page_size

        ids_amount = len(ids) if ids else 0
        self._product_telemetry_client.capture(
            CollectionGetEvent(
                collection_uuid=str(collection_id),
                ids_count=ids_amount,
                limit=limit if limit else 0,
                include_metadata=ids_amount if "metadatas" in include else 0,
                include_documents=ids_amount if "documents" in include else 0,
                include_uris=ids_amount if "uris" in include else 0,
            )
        )

        return self._executor.get(
            GetPlan(
                scan,
                Filter(ids, where, where_document),
                Limit(offset or 0, limit),
                Projection(
                    IncludeEnum.documents in include,
                    IncludeEnum.embeddings in include,
                    IncludeEnum.metadatas in include,
                    False,
                    IncludeEnum.uris in include,
                ),
            )
        )

    @trace_method("SegmentAPI._delete", OpenTelemetryGranularity.OPERATION)
    @override
    @rate_limit
    def _delete(
        self,
        collection_id: UUID,
        ids: Optional[IDs] = None,
        where: Optional[Where] = None,
        where_document: Optional[WhereDocument] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> None:
        """Delete records selected by ids and/or metadata/document filters.

        When a filter is given (or no ids), the matching ids are resolved via the
        executor first; nothing is submitted if no ids match.

        Raises:
            ValueError: if ids, where and where_document are all missing or empty.
        """
        add_attributes_to_current_span(
            {
                "collection_id": str(collection_id),
                "ids_count": len(ids) if ids else 0,
            }
        )

        # TODO: Replace with unified validation
        if where is not None:
            validate_where(where)
        if where_document is not None:
            validate_where_document(where_document)

        # You must have at least one of non-empty ids, where, or where_document.
        if not ids and not where and not where_document:
            raise ValueError(
                """
                You must provide either ids, where, or where_document to delete. If
                you want to delete all data in a collection you can delete the
                collection itself using the delete_collection method. Or alternatively,
                you can get() all the relevant ids and then delete them.
                """
            )

        scan = self._scan(collection_id)

        self._quota_enforcer.enforce(
            action=Action.DELETE,
            tenant=tenant,
            ids=ids,
            where=where,
            where_document=where_document,
        )

        self._manager.hint_use_collection(collection_id, t.Operation.DELETE)
        if (where or where_document) or not ids:
            ids_to_delete = self._executor.get(
                GetPlan(scan, Filter(ids, where, where_document))
            )["ids"]
        else:
            ids_to_delete = ids

        if len(ids_to_delete) == 0:
            return

        records_to_submit = list(
            _records(operation=t.Operation.DELETE, ids=ids_to_delete)
        )
        self._validate_embedding_record_set(scan.collection, records_to_submit)
        self._producer.submit_embeddings(collection_id, records_to_submit)

        self._product_telemetry_client.capture(
            CollectionDeleteEvent(
                collection_uuid=str(collection_id), delete_amount=len(ids_to_delete)
            )
        )

    @trace_method("SegmentAPI._count", OpenTelemetryGranularity.OPERATION)
    @retry(  # type: ignore[misc]
        retry=retry_if_exception(lambda e: isinstance(e, VersionMismatchError)),
        wait=wait_fixed(2),
        stop=stop_after_attempt(5),
        reraise=True,
    )
    @override
    @rate_limit
    def _count(
        self,
        collection_id: UUID,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> int:
        """Return the record count of a collection (retried on version mismatch)."""
        add_attributes_to_current_span({"collection_id": str(collection_id)})
        return self._executor.count(CountPlan(self._scan(collection_id)))

    @trace_method("SegmentAPI._query", OpenTelemetryGranularity.OPERATION)
    # We retry on version mismatch errors because the version of the collection
    # may have changed between the time we got the version and the time we
    # actually query the collection on the FE. We are fine with fixed
    # wait time because the version mismatch error is not a error due to
    # network issues or other transient issues. It is a result of the
    # collection being updated between the time we got the version and
    # the time we actually query the collection on the FE.
    @retry(  # type: ignore[misc]
        retry=retry_if_exception(lambda e: isinstance(e, VersionMismatchError)),
        wait=wait_fixed(2),
        stop=stop_after_attempt(5),
        reraise=True,
    )
    @override
    @rate_limit
    def _query(
        self,
        collection_id: UUID,
        query_embeddings: Embeddings,
        n_results: int = 10,
        where: Optional[Where] = None,
        where_document: Optional[WhereDocument] = None,
        include: Include = ["documents", "metadatas", "distances"],  # type: ignore[list-item]
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> QueryResult:
        """Run a KNN query for each embedding in ``query_embeddings``.

        Raises:
            InvalidDimensionException: if a query embedding's length differs
                from the collection's dimension.
        """
        add_attributes_to_current_span(
            {
                "collection_id": str(collection_id),
                "n_results": n_results,
                "where": str(where),
            }
        )
        query_amount = len(query_embeddings)
        self._product_telemetry_client.capture(
            CollectionQueryEvent(
                collection_uuid=str(collection_id),
                query_amount=query_amount,
                n_results=n_results,
                with_metadata_filter=query_amount if where is not None else 0,
                with_document_filter=query_amount if where_document is not None else 0,
                include_metadatas=query_amount if "metadatas" in include else 0,
                include_documents=query_amount if "documents" in include else 0,
                include_uris=query_amount if "uris" in include else 0,
                include_distances=query_amount if "distances" in include else 0,
            )
        )

        # TODO: Replace with unified validation
        if where is not None:
            validate_where(where)
        if where_document is not None:
            validate_where_document(where_document)

        scan = self._scan(collection_id)
        for embedding in query_embeddings:
            self._validate_dimension(scan.collection, len(embedding), update=False)

        self._quota_enforcer.enforce(
            action=Action.QUERY,
            tenant=tenant,
            where=where,
            where_document=where_document,
            query_embeddings=query_embeddings,
            n_results=n_results,
        )

        return self._executor.knn(
            KNNPlan(
                scan,
                KNN(query_embeddings, n_results),
                Filter(None, where, where_document),
                Projection(
                    IncludeEnum.documents in include,
                    IncludeEnum.embeddings in include,
                    IncludeEnum.metadatas in include,
                    IncludeEnum.distances in include,
                    IncludeEnum.uris in include,
                ),
            )
        )

    @trace_method("SegmentAPI._peek", OpenTelemetryGranularity.OPERATION)
    @override
    @rate_limit
    def _peek(
        self,
        collection_id: UUID,
        n: int = 10,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> GetResult:
        """Return up to ``n`` records via :meth:`_get`.

        ``tenant``/``database`` are forwarded so quota enforcement in ``_get``
        runs against the caller's tenant rather than the default one.
        """
        add_attributes_to_current_span({"collection_id": str(collection_id)})
        return self._get(  # type: ignore
            collection_id, limit=n, tenant=tenant, database=database
        )

    @override
    def get_version(self) -> str:
        """Return the installed chromadb version string."""
        return __version__

    @override
    def reset_state(self) -> None:
        """No per-instance state to reset."""
        pass

    @override
    def reset(self) -> bool:
        """Reset the whole system's state and return True."""
        self._system.reset_state()
        return True

    @override
    def get_settings(self) -> Settings:
        """Return the settings this API was constructed with."""
        return self._settings

    @override
    def get_max_batch_size(self) -> int:
        """Return the producer's maximum submission batch size."""
        return self._producer.max_batch_size

    # TODO: This could potentially cause race conditions in a distributed version of the
    # system, since the cache is only local.
    # TODO: promote collection -> topic to a base class method so that it can be
    # used for channel assignment in the distributed version of the system.
    @trace_method(
        "SegmentAPI._validate_embedding_record_set", OpenTelemetryGranularity.ALL
    )
    def _validate_embedding_record_set(
        self, collection: t.Collection, records: List[t.OperationRecord]
    ) -> None:
        """Validate the dimension of an embedding record before submitting it to the system."""
        add_attributes_to_current_span({"collection_id": str(collection["id"])})
        for record in records:
            if record["embedding"] is not None:
                self._validate_dimension(
                    collection, len(record["embedding"]), update=True
                )

    # This method is intentionally left untraced because otherwise it can emit thousands of spans for requests containing many embeddings.
    def _validate_dimension(
        self, collection: t.Collection, dim: int, update: bool
    ) -> None:
        """Validate that a collection supports records of the given dimension. If update
        is true, update the collection if the collection doesn't already have a
        dimension.

        Raises:
            InvalidDimensionException: if the collection already has a different
                dimension.
        """
        if collection["dimension"] is None:
            if update:
                collection_id = collection.id
                self._sysdb.update_collection(id=collection_id, dimension=dim)
                # Also cache the dimension on the in-memory object so later
                # records in the same batch are checked against it.
                collection["dimension"] = dim
        elif collection["dimension"] != dim:
            raise InvalidDimensionException(
                f"Embedding dimension {dim} does not match collection dimensionality {collection['dimension']}"
            )
        else:
            return  # all is well

    @trace_method("SegmentAPI._get_collection", OpenTelemetryGranularity.ALL)
    def _get_collection(self, collection_id: UUID) -> t.Collection:
        """Look up a collection by id.

        Raises:
            InvalidCollectionException: if no collection has this id.
        """
        collections = self._sysdb.get_collections(id=collection_id)
        if not collections:
            raise InvalidCollectionException(
                f"Collection {collection_id} does not exist."
            )
        return collections[0]

    @trace_method("SegmentAPI._scan", OpenTelemetryGranularity.ALL)
    def _scan(self, collection_id: UUID) -> Scan:
        """Build a Scan from the collection and its segments, keyed by scope.

        Raises ``KeyError`` if the vector or metadata segment is missing.
        """
        collection_and_segments = self._sysdb.get_collection_with_segments(
            collection_id
        )
        # For now collection should have exactly one segment per scope:
        # - Local scopes: vector, metadata
        # - Distributed scopes: vector, metadata, record
        scope_to_segment = {
            segment["scope"]: segment for segment in collection_and_segments["segments"]
        }
        return Scan(
            collection=collection_and_segments["collection"],
            knn=scope_to_segment[t.SegmentScope.VECTOR],
            metadata=scope_to_segment[t.SegmentScope.METADATA],
            # Local chroma do not have record segment, and this is not used by the local executor
            record=scope_to_segment.get(t.SegmentScope.RECORD, None),  # type: ignore[arg-type]
        )


def _records(
    operation: t.Operation,
    ids: IDs,
    embeddings: Optional[Embeddings] = None,
    metadatas: Optional[Metadatas] = None,
    documents: Optional[Documents] = None,
    uris: Optional[URIs] = None,
) -> Generator[t.OperationRecord, None, None]:
    """Convert parallel lists of ids, embeddings, metadatas, documents and uris
    into a stream of OperationRecords.

    Documents and uris are folded into each record's metadata under the
    ``"chroma:document"`` and ``"chroma:uri"`` keys; the caller's metadata
    dicts are copied, not mutated.
    """

    # Presumes that callers were invoked via Collection model, which means
    # that we know that the embeddings, metadatas and documents have already been
    # normalized and are guaranteed to be consistently named lists.

    if embeddings == []:
        embeddings = None

    for i, record_id in enumerate(ids):
        metadata = None
        if metadatas:
            metadata = metadatas[i]

        if documents:
            document = documents[i]
            if metadata:
                metadata = {**metadata, "chroma:document": document}
            else:
                metadata = {"chroma:document": document}

        if uris:
            uri = uris[i]
            if metadata:
                metadata = {**metadata, "chroma:uri": uri}
            else:
                metadata = {"chroma:uri": uri}

        record = t.OperationRecord(
            id=record_id,
            embedding=embeddings[i] if embeddings is not None else None,
            encoding=t.ScalarEncoding.FLOAT32,  # Hardcode for now
            metadata=metadata,
            operation=operation,
        )
        yield record
Memory