from tenacity import retry, stop_after_attempt, retry_if_exception, wait_fixed
from chromadb.api import ServerAPI
from chromadb.api.configuration import CollectionConfigurationInternal
from chromadb.auth import UserIdentity
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings, System
from chromadb.db.system import SysDB
from chromadb.quota import QuotaEnforcer, Action
from chromadb.rate_limit import RateLimitEnforcer
from chromadb.segment import SegmentManager
from chromadb.execution.executor.abstract import Executor
from chromadb.execution.expression.operator import Scan, Filter, Limit, KNN, Projection
from chromadb.execution.expression.plan import CountPlan, GetPlan, KNNPlan
from chromadb.telemetry.opentelemetry import (
add_attributes_to_current_span,
OpenTelemetryClient,
OpenTelemetryGranularity,
trace_method,
)
from chromadb.telemetry.product import ProductTelemetryClient
from chromadb.ingest import Producer
from chromadb.types import Collection as CollectionModel
from chromadb import __version__
from chromadb.errors import (
InvalidDimensionException,
InvalidCollectionException,
VersionMismatchError,
)
from chromadb.api.types import (
CollectionMetadata,
IDs,
Embeddings,
Metadatas,
Documents,
URIs,
Where,
WhereDocument,
Include,
IncludeEnum,
GetResult,
QueryResult,
validate_metadata,
validate_update_metadata,
validate_where,
validate_where_document,
validate_batch,
)
from chromadb.telemetry.product.events import (
CollectionAddEvent,
CollectionDeleteEvent,
CollectionGetEvent,
CollectionUpdateEvent,
CollectionQueryEvent,
ClientCreateCollectionEvent,
)
import chromadb.types as t
from typing import (
Optional,
Sequence,
Generator,
List,
Any,
Callable,
TypeVar,
)
from overrides import override
from uuid import UUID, uuid4
from functools import wraps
import time
import logging
import re
T = TypeVar("T", bound=Callable[..., Any])
logger = logging.getLogger(__name__)
# Mimics S3 bucket naming requirements.
def check_index_name(index_name: str) -> None:
msg = (
"Expected collection name that "
"(1) contains 3-63 characters, "
"(2) starts and ends with an alphanumeric character, "
"(3) otherwise contains only alphanumeric characters, underscores or hyphens (-), "
"(4) contains no two consecutive periods (..) and "
"(5) is not a valid IPv4 address, "
f"got {index_name}"
)
if len(index_name) < 3 or len(index_name) > 63:
raise ValueError(msg)
if not re.match("^[a-zA-Z0-9][a-zA-Z0-9._-]*[a-zA-Z0-9]$", index_name):
raise ValueError(msg)
if ".." in index_name:
raise ValueError(msg)
if re.match("^[0-9]{1,3}\\.[0-9]{1,3}\\.[0-9]{1,3}\\.[0-9]{1,3}$", index_name):
raise ValueError(msg)
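# For example: "my-collection" and "data_2024" pass; "ab" (too short),
# "a..b" (consecutive periods) and "10.0.0.1" (an IPv4 address) raise ValueError.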
def rate_limit(func: T) -> T:
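    """Decorate an instance method so calls are routed through the instance's
    RateLimitEnforcer; assumes the first positional argument is `self`."""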
@wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Any:
self = args[0]
return self._rate_limit_enforcer.rate_limit(func)(*args, **kwargs)
return wrapper # type: ignore
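# Note: decorators on the API methods below apply bottom-up, so @rate_limit
# wraps the method itself, @override marks the override, and @trace_method is
# outermost; its span therefore also covers any rate-limiting wait.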
class SegmentAPI(ServerAPI):
"""API implementation utilizing the new segment-based internal architecture"""
_settings: Settings
_sysdb: SysDB
_manager: SegmentManager
_executor: Executor
_producer: Producer
_product_telemetry_client: ProductTelemetryClient
_opentelemetry_client: OpenTelemetryClient
_tenant_id: str
_topic_ns: str
    _rate_limit_enforcer: RateLimitEnforcer
    _quota_enforcer: QuotaEnforcer
def __init__(self, system: System):
super().__init__(system)
self._settings = system.settings
self._sysdb = self.require(SysDB)
self._manager = self.require(SegmentManager)
self._executor = self.require(Executor)
self._quota_enforcer = self.require(QuotaEnforcer)
self._product_telemetry_client = self.require(ProductTelemetryClient)
self._opentelemetry_client = self.require(OpenTelemetryClient)
self._producer = self.require(Producer)
        self._rate_limit_enforcer = self.require(RateLimitEnforcer)
@override
def heartbeat(self) -> int:
return int(time.time_ns())
@trace_method("SegmentAPI.create_database", OpenTelemetryGranularity.OPERATION)
@override
def create_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None:
if len(name) < 3:
raise ValueError("Database name must be at least 3 characters long")
self._quota_enforcer.enforce(
action=Action.CREATE_DATABASE,
tenant=tenant,
name=name,
)
self._sysdb.create_database(
id=uuid4(),
name=name,
tenant=tenant,
)
@trace_method("SegmentAPI.get_database", OpenTelemetryGranularity.OPERATION)
@override
def get_database(self, name: str, tenant: str = DEFAULT_TENANT) -> t.Database:
return self._sysdb.get_database(name=name, tenant=tenant)
@trace_method("SegmentAPI.delete_database", OpenTelemetryGranularity.OPERATION)
@override
def delete_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None:
self._sysdb.delete_database(name=name, tenant=tenant)
@trace_method("SegmentAPI.list_databases", OpenTelemetryGranularity.OPERATION)
@override
def list_databases(
self,
limit: Optional[int] = None,
offset: Optional[int] = None,
tenant: str = DEFAULT_TENANT,
) -> Sequence[t.Database]:
return self._sysdb.list_databases(limit=limit, offset=offset, tenant=tenant)
@trace_method("SegmentAPI.create_tenant", OpenTelemetryGranularity.OPERATION)
@override
def create_tenant(self, name: str) -> None:
if len(name) < 3:
raise ValueError("Tenant name must be at least 3 characters long")
self._sysdb.create_tenant(
name=name,
)
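    # Single-node SegmentAPI has no per-request auth context, so this returns
    # the default tenant and database identity.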
@override
def get_user_identity(self) -> UserIdentity:
return UserIdentity(
user_id="",
tenant=DEFAULT_TENANT,
databases=[DEFAULT_DATABASE],
)
@trace_method("SegmentAPI.get_tenant", OpenTelemetryGranularity.OPERATION)
@override
def get_tenant(self, name: str) -> t.Tenant:
return self._sysdb.get_tenant(name=name)
# TODO: Actually fix CollectionMetadata type to remove type: ignore flags. This is
    # necessary because changing the value type from `Any` to `Union[str, int, float]`
# causes the system to somehow convert all values to strings.
@trace_method("SegmentAPI.create_collection", OpenTelemetryGranularity.OPERATION)
@override
@rate_limit
def create_collection(
self,
name: str,
configuration: Optional[CollectionConfigurationInternal] = None,
metadata: Optional[CollectionMetadata] = None,
get_or_create: bool = False,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> CollectionModel:
if metadata is not None:
validate_metadata(metadata)
# TODO: remove backwards compatibility in naming requirements
check_index_name(name)
self._quota_enforcer.enforce(
action=Action.CREATE_COLLECTION,
tenant=tenant,
name=name,
metadata=metadata,
)
id = uuid4()
model = CollectionModel(
id=id,
name=name,
metadata=metadata,
configuration=configuration
if configuration is not None
else CollectionConfigurationInternal(), # Use default configuration if none is provided
tenant=tenant,
database=database,
dimension=None,
)
# TODO: Let sysdb create the collection directly from the model
coll, created = self._sysdb.create_collection(
id=model.id,
name=model.name,
configuration=model.get_configuration(),
            segments=[],  # Passing empty until the backend changes are deployed.
metadata=model.metadata,
dimension=None, # This is lazily populated on the first add
get_or_create=get_or_create,
tenant=tenant,
database=database,
)
if created:
segments = self._manager.prepare_segments_for_new_collection(coll)
for segment in segments:
self._sysdb.create_segment(segment)
else:
logger.debug(
f"Collection {name} already exists, returning existing collection."
)
# TODO: This event doesn't capture the get_or_create case appropriately
# TODO: Re-enable embedding function tracking in create_collection
self._product_telemetry_client.capture(
ClientCreateCollectionEvent(
collection_uuid=str(id),
# embedding_function=embedding_function.__class__.__name__,
)
)
add_attributes_to_current_span({"collection_uuid": str(id)})
return coll
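    # Hypothetical caller-side sketch (not part of this module): calling
    #   api.create_collection("my-docs", get_or_create=True)
    # a second time returns the existing CollectionModel rather than raising.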
@trace_method(
"SegmentAPI.get_or_create_collection", OpenTelemetryGranularity.OPERATION
)
@override
@rate_limit
def get_or_create_collection(
self,
name: str,
configuration: Optional[CollectionConfigurationInternal] = None,
metadata: Optional[CollectionMetadata] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> CollectionModel:
return self.create_collection(
name=name,
metadata=metadata,
configuration=configuration,
get_or_create=True,
tenant=tenant,
database=database,
)
# TODO: Actually fix CollectionMetadata type to remove type: ignore flags. This is
    # necessary because changing the value type from `Any` to `Union[str, int, float]`
    # causes the system to somehow convert all values to strings.
@trace_method("SegmentAPI.get_collection", OpenTelemetryGranularity.OPERATION)
@override
@rate_limit
def get_collection(
self,
name: Optional[str] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> CollectionModel:
existing = self._sysdb.get_collections(
name=name, tenant=tenant, database=database
)
if existing:
return existing[0]
else:
raise InvalidCollectionException(f"Collection {name} does not exist.")
@trace_method("SegmentAPI.list_collection", OpenTelemetryGranularity.OPERATION)
@override
@rate_limit
def list_collections(
self,
limit: Optional[int] = None,
offset: Optional[int] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> Sequence[CollectionModel]:
self._quota_enforcer.enforce(
action=Action.LIST_COLLECTIONS,
tenant=tenant,
limit=limit,
)
return self._sysdb.get_collections(
limit=limit, offset=offset, tenant=tenant, database=database
)
@trace_method("SegmentAPI.count_collections", OpenTelemetryGranularity.OPERATION)
@override
@rate_limit
def count_collections(
self,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> int:
collection_count = len(
self._sysdb.get_collections(tenant=tenant, database=database)
)
return collection_count
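    # Note: the count above materializes every collection row for the tenant
    # and database; fine locally, though potentially costly at large scale.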
@trace_method("SegmentAPI._modify", OpenTelemetryGranularity.OPERATION)
@override
@rate_limit
def _modify(
self,
id: UUID,
new_name: Optional[str] = None,
new_metadata: Optional[CollectionMetadata] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> None:
if new_name:
# backwards compatibility in naming requirements (for now)
check_index_name(new_name)
if new_metadata:
validate_update_metadata(new_metadata)
# Ensure the collection exists
_ = self._get_collection(id)
self._quota_enforcer.enforce(
action=Action.UPDATE_COLLECTION,
tenant=tenant,
name=new_name,
metadata=new_metadata,
)
        # TODO: eventually we'll want to use OptionalArgument and Unspecified in
        # the signature of `_modify`, but we're not changing the API right now.
if new_name and new_metadata:
self._sysdb.update_collection(id, name=new_name, metadata=new_metadata)
elif new_name:
self._sysdb.update_collection(id, name=new_name)
elif new_metadata:
self._sysdb.update_collection(id, metadata=new_metadata)
@trace_method("SegmentAPI.delete_collection", OpenTelemetryGranularity.OPERATION)
@override
@rate_limit
def delete_collection(
self,
name: str,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> None:
existing = self._sysdb.get_collections(
name=name, tenant=tenant, database=database
)
if existing:
self._sysdb.delete_collection(
existing[0].id, tenant=tenant, database=database
)
self._manager.delete_segments(existing[0].id)
else:
raise ValueError(f"Collection {name} does not exist.")
@trace_method("SegmentAPI._add", OpenTelemetryGranularity.OPERATION)
@override
@rate_limit
def _add(
self,
ids: IDs,
collection_id: UUID,
embeddings: Embeddings,
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> bool:
coll = self._get_collection(collection_id)
self._manager.hint_use_collection(collection_id, t.Operation.ADD)
validate_batch(
(ids, embeddings, metadatas, documents, uris),
{"max_batch_size": self.get_max_batch_size()},
)
records_to_submit = list(
_records(
t.Operation.ADD,
ids=ids,
embeddings=embeddings,
metadatas=metadatas,
documents=documents,
uris=uris,
)
)
self._validate_embedding_record_set(coll, records_to_submit)
self._quota_enforcer.enforce(
action=Action.ADD,
tenant=tenant,
ids=ids,
embeddings=embeddings,
metadatas=metadatas,
documents=documents,
uris=uris,
)
self._producer.submit_embeddings(collection_id, records_to_submit)
self._product_telemetry_client.capture(
CollectionAddEvent(
collection_uuid=str(collection_id),
add_amount=len(ids),
with_metadata=len(ids) if metadatas is not None else 0,
with_documents=len(ids) if documents is not None else 0,
with_uris=len(ids) if uris is not None else 0,
)
)
return True
@trace_method("SegmentAPI._update", OpenTelemetryGranularity.OPERATION)
@override
@rate_limit
def _update(
self,
collection_id: UUID,
ids: IDs,
embeddings: Optional[Embeddings] = None,
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> bool:
coll = self._get_collection(collection_id)
self._manager.hint_use_collection(collection_id, t.Operation.UPDATE)
validate_batch(
(ids, embeddings, metadatas, documents, uris),
{"max_batch_size": self.get_max_batch_size()},
)
records_to_submit = list(
_records(
t.Operation.UPDATE,
ids=ids,
embeddings=embeddings,
metadatas=metadatas,
documents=documents,
uris=uris,
)
)
self._validate_embedding_record_set(coll, records_to_submit)
self._quota_enforcer.enforce(
action=Action.UPDATE,
tenant=tenant,
ids=ids,
embeddings=embeddings,
metadatas=metadatas,
documents=documents,
uris=uris,
)
self._producer.submit_embeddings(collection_id, records_to_submit)
self._product_telemetry_client.capture(
CollectionUpdateEvent(
collection_uuid=str(collection_id),
update_amount=len(ids),
with_embeddings=len(embeddings) if embeddings else 0,
with_metadata=len(metadatas) if metadatas else 0,
with_documents=len(documents) if documents else 0,
with_uris=len(uris) if uris else 0,
)
)
return True
@trace_method("SegmentAPI._upsert", OpenTelemetryGranularity.OPERATION)
@override
@rate_limit
def _upsert(
self,
collection_id: UUID,
ids: IDs,
embeddings: Embeddings,
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> bool:
coll = self._get_collection(collection_id)
self._manager.hint_use_collection(collection_id, t.Operation.UPSERT)
validate_batch(
(ids, embeddings, metadatas, documents, uris),
{"max_batch_size": self.get_max_batch_size()},
)
records_to_submit = list(
_records(
t.Operation.UPSERT,
ids=ids,
embeddings=embeddings,
metadatas=metadatas,
documents=documents,
uris=uris,
)
)
self._validate_embedding_record_set(coll, records_to_submit)
self._quota_enforcer.enforce(
action=Action.UPSERT,
tenant=tenant,
ids=ids,
embeddings=embeddings,
metadatas=metadatas,
documents=documents,
uris=uris,
)
self._producer.submit_embeddings(collection_id, records_to_submit)
return True
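    # Note: unlike _add and _update, _upsert emits no product telemetry event.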
@trace_method("SegmentAPI._get", OpenTelemetryGranularity.OPERATION)
@retry( # type: ignore[misc]
retry=retry_if_exception(lambda e: isinstance(e, VersionMismatchError)),
wait=wait_fixed(2),
stop=stop_after_attempt(5),
reraise=True,
)
@override
@rate_limit
def _get(
self,
collection_id: UUID,
ids: Optional[IDs] = None,
where: Optional[Where] = None,
sort: Optional[str] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
page: Optional[int] = None,
page_size: Optional[int] = None,
where_document: Optional[WhereDocument] = None,
include: Include = ["embeddings", "metadatas", "documents"], # type: ignore[list-item]
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> GetResult:
add_attributes_to_current_span(
{
"collection_id": str(collection_id),
"ids_count": len(ids) if ids else 0,
}
)
scan = self._scan(collection_id)
# TODO: Replace with unified validation
if where is not None:
validate_where(where)
if where_document is not None:
validate_where_document(where_document)
self._quota_enforcer.enforce(
action=Action.GET,
tenant=tenant,
ids=ids,
where=where,
where_document=where_document,
limit=limit,
)
if sort is not None:
raise NotImplementedError("Sorting is not yet supported")
if page and page_size:
offset = (page - 1) * page_size
limit = page_size
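        # e.g. page=3 with page_size=20 becomes offset=40, limit=20.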
ids_amount = len(ids) if ids else 0
self._product_telemetry_client.capture(
CollectionGetEvent(
collection_uuid=str(collection_id),
ids_count=ids_amount,
limit=limit if limit else 0,
include_metadata=ids_amount if "metadatas" in include else 0,
include_documents=ids_amount if "documents" in include else 0,
include_uris=ids_amount if "uris" in include else 0,
)
)
return self._executor.get(
GetPlan(
scan,
Filter(ids, where, where_document),
Limit(offset or 0, limit),
Projection(
IncludeEnum.documents in include,
IncludeEnum.embeddings in include,
IncludeEnum.metadatas in include,
False,
IncludeEnum.uris in include,
),
)
)
@trace_method("SegmentAPI._delete", OpenTelemetryGranularity.OPERATION)
@override
@rate_limit
def _delete(
self,
collection_id: UUID,
ids: Optional[IDs] = None,
where: Optional[Where] = None,
where_document: Optional[WhereDocument] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> None:
add_attributes_to_current_span(
{
"collection_id": str(collection_id),
"ids_count": len(ids) if ids else 0,
}
)
# TODO: Replace with unified validation
if where is not None:
validate_where(where)
if where_document is not None:
validate_where_document(where_document)
        # At least one of ids, where, or where_document must be non-empty.
        if not ids and not where and not where_document:
            raise ValueError(
                "You must provide at least one of ids, where, or where_document "
                "to delete. To delete all data in a collection, delete the "
                "collection itself via delete_collection, or get() all the "
                "relevant ids and then delete them."
            )
scan = self._scan(collection_id)
self._quota_enforcer.enforce(
action=Action.DELETE,
tenant=tenant,
ids=ids,
where=where,
where_document=where_document,
)
self._manager.hint_use_collection(collection_id, t.Operation.DELETE)
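        # Resolve the target ids: a filtered (or id-less) delete must first
        # read the matching ids via the executor; a plain id-list delete is
        # submitted as-is.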
if (where or where_document) or not ids:
ids_to_delete = self._executor.get(
GetPlan(scan, Filter(ids, where, where_document))
)["ids"]
else:
ids_to_delete = ids
if len(ids_to_delete) == 0:
return
records_to_submit = list(
_records(operation=t.Operation.DELETE, ids=ids_to_delete)
)
self._validate_embedding_record_set(scan.collection, records_to_submit)
self._producer.submit_embeddings(collection_id, records_to_submit)
self._product_telemetry_client.capture(
CollectionDeleteEvent(
collection_uuid=str(collection_id), delete_amount=len(ids_to_delete)
)
)
@trace_method("SegmentAPI._count", OpenTelemetryGranularity.OPERATION)
@retry( # type: ignore[misc]
retry=retry_if_exception(lambda e: isinstance(e, VersionMismatchError)),
wait=wait_fixed(2),
stop=stop_after_attempt(5),
reraise=True,
)
@override
@rate_limit
def _count(
self,
collection_id: UUID,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> int:
add_attributes_to_current_span({"collection_id": str(collection_id)})
return self._executor.count(CountPlan(self._scan(collection_id)))
@trace_method("SegmentAPI._query", OpenTelemetryGranularity.OPERATION)
# We retry on version mismatch errors because the version of the collection
# may have changed between the time we got the version and the time we
# actually query the collection on the FE. We are fine with fixed
# wait time because the version mismatch error is not a error due to
# network issues or other transient issues. It is a result of the
# collection being updated between the time we got the version and
# the time we actually query the collection on the FE.
@retry( # type: ignore[misc]
retry=retry_if_exception(lambda e: isinstance(e, VersionMismatchError)),
wait=wait_fixed(2),
stop=stop_after_attempt(5),
reraise=True,
)
@override
@rate_limit
def _query(
self,
collection_id: UUID,
query_embeddings: Embeddings,
n_results: int = 10,
where: Optional[Where] = None,
where_document: Optional[WhereDocument] = None,
include: Include = ["documents", "metadatas", "distances"], # type: ignore[list-item]
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> QueryResult:
add_attributes_to_current_span(
{
"collection_id": str(collection_id),
"n_results": n_results,
"where": str(where),
}
)
query_amount = len(query_embeddings)
self._product_telemetry_client.capture(
CollectionQueryEvent(
collection_uuid=str(collection_id),
query_amount=query_amount,
n_results=n_results,
with_metadata_filter=query_amount if where is not None else 0,
with_document_filter=query_amount if where_document is not None else 0,
include_metadatas=query_amount if "metadatas" in include else 0,
include_documents=query_amount if "documents" in include else 0,
include_uris=query_amount if "uris" in include else 0,
include_distances=query_amount if "distances" in include else 0,
)
)
# TODO: Replace with unified validation
if where is not None:
validate_where(where)
if where_document is not None:
validate_where_document(where_document)
scan = self._scan(collection_id)
for embedding in query_embeddings:
self._validate_dimension(scan.collection, len(embedding), update=False)
self._quota_enforcer.enforce(
action=Action.QUERY,
tenant=tenant,
where=where,
where_document=where_document,
query_embeddings=query_embeddings,
n_results=n_results,
)
return self._executor.knn(
KNNPlan(
scan,
KNN(query_embeddings, n_results),
Filter(None, where, where_document),
Projection(
IncludeEnum.documents in include,
IncludeEnum.embeddings in include,
IncludeEnum.metadatas in include,
IncludeEnum.distances in include,
IncludeEnum.uris in include,
),
)
)
@trace_method("SegmentAPI._peek", OpenTelemetryGranularity.OPERATION)
@override
@rate_limit
def _peek(
self,
collection_id: UUID,
n: int = 10,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> GetResult:
add_attributes_to_current_span({"collection_id": str(collection_id)})
return self._get(collection_id, limit=n) # type: ignore
@override
def get_version(self) -> str:
return __version__
@override
def reset_state(self) -> None:
pass
@override
def reset(self) -> bool:
self._system.reset_state()
return True
@override
def get_settings(self) -> Settings:
return self._settings
@override
def get_max_batch_size(self) -> int:
return self._producer.max_batch_size
# TODO: This could potentially cause race conditions in a distributed version of the
# system, since the cache is only local.
# TODO: promote collection -> topic to a base class method so that it can be
# used for channel assignment in the distributed version of the system.
@trace_method(
"SegmentAPI._validate_embedding_record_set", OpenTelemetryGranularity.ALL
)
def _validate_embedding_record_set(
self, collection: t.Collection, records: List[t.OperationRecord]
) -> None:
"""Validate the dimension of an embedding record before submitting it to the system."""
add_attributes_to_current_span({"collection_id": str(collection["id"])})
for record in records:
if record["embedding"] is not None:
self._validate_dimension(
collection, len(record["embedding"]), update=True
)
    # This method is intentionally left untraced because it can otherwise emit
    # thousands of spans for requests containing many embeddings.
def _validate_dimension(
self, collection: t.Collection, dim: int, update: bool
) -> None:
"""Validate that a collection supports records of the given dimension. If update
is true, update the collection if the collection doesn't already have a
dimension."""
if collection["dimension"] is None:
if update:
                id = collection["id"]
                self._sysdb.update_collection(id=id, dimension=dim)
collection["dimension"] = dim
elif collection["dimension"] != dim:
raise InvalidDimensionException(
f"Embedding dimension {dim} does not match collection dimensionality {collection['dimension']}"
)
else:
return # all is well
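    # Illustration: the first _add to a fresh collection with 384-d embeddings
    # persists dimension=384; a later 512-d embedding then raises
    # InvalidDimensionException.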
@trace_method("SegmentAPI._get_collection", OpenTelemetryGranularity.ALL)
def _get_collection(self, collection_id: UUID) -> t.Collection:
collections = self._sysdb.get_collections(id=collection_id)
        if not collections:
raise InvalidCollectionException(
f"Collection {collection_id} does not exist."
)
return collections[0]
@trace_method("SegmentAPI._scan", OpenTelemetryGranularity.ALL)
def _scan(self, collection_id: UUID) -> Scan:
collection_and_segments = self._sysdb.get_collection_with_segments(
collection_id
)
        # For now, a collection should have exactly one segment per scope:
# - Local scopes: vector, metadata
# - Distributed scopes: vector, metadata, record
scope_to_segment = {
segment["scope"]: segment for segment in collection_and_segments["segments"]
}
return Scan(
collection=collection_and_segments["collection"],
knn=scope_to_segment[t.SegmentScope.VECTOR],
metadata=scope_to_segment[t.SegmentScope.METADATA],
            # Local Chroma has no record segment; it is unused by the local executor.
record=scope_to_segment.get(t.SegmentScope.RECORD, None), # type: ignore[arg-type]
)
def _records(
operation: t.Operation,
ids: IDs,
embeddings: Optional[Embeddings] = None,
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
) -> Generator[t.OperationRecord, None, None]:
"""Convert parallel lists of embeddings, metadatas and documents to a sequence of
SubmitEmbeddingRecords"""
# Presumes that callers were invoked via Collection model, which means
# that we know that the embeddings, metadatas and documents have already been
# normalized and are guaranteed to be consistently named lists.
if embeddings == []:
embeddings = None
for i, id in enumerate(ids):
metadata = None
if metadatas:
metadata = metadatas[i]
if documents:
document = documents[i]
if metadata:
metadata = {**metadata, "chroma:document": document}
else:
metadata = {"chroma:document": document}
if uris:
uri = uris[i]
if metadata:
metadata = {**metadata, "chroma:uri": uri}
else:
metadata = {"chroma:uri": uri}
record = t.OperationRecord(
id=id,
embedding=embeddings[i] if embeddings is not None else None,
encoding=t.ScalarEncoding.FLOAT32, # Hardcode for now
metadata=metadata,
operation=operation,
)
yield record
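# Illustrative example (assumed usage): _records(t.Operation.ADD, ids=["a"],
# documents=["hello"]) yields a single OperationRecord with embedding=None and
# metadata={"chroma:document": "hello"}.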