from typing import List, Optional, Sequence, Tuple, Union, cast
from uuid import UUID

from overrides import overrides

from chromadb.api.configuration import CollectionConfigurationInternal
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, System, logger
from chromadb.db.system import SysDB
from chromadb.errors import NotFoundError, UniqueConstraintError, InternalError
from chromadb.proto.convert import (
    from_proto_collection,
    from_proto_segment,
    to_proto_update_metadata,
    to_proto_segment,
    to_proto_segment_scope,
)
from chromadb.proto.coordinator_pb2 import (
    CreateCollectionRequest,
    CreateDatabaseRequest,
    CreateSegmentRequest,
    CreateTenantRequest,
    DeleteCollectionRequest,
    DeleteDatabaseRequest,
    DeleteSegmentRequest,
    GetCollectionsRequest,
    GetCollectionsResponse,
    GetCollectionWithSegmentsRequest,
    GetCollectionWithSegmentsResponse,
    GetDatabaseRequest,
    GetSegmentsRequest,
    GetTenantRequest,
    ListDatabasesRequest,
    UpdateCollectionRequest,
    UpdateSegmentRequest,
)
from chromadb.proto.coordinator_pb2_grpc import SysDBStub
from chromadb.proto.utils import RetryOnRpcErrorClientInterceptor
from chromadb.telemetry.opentelemetry.grpc import OtelInterceptor
from chromadb.types import (
    Collection,
    CollectionAndSegments,
    Database,
    Metadata,
    OptionalArgument,
    Segment,
    SegmentScope,
    Tenant,
    Unspecified,
    UpdateMetadata,
)
from google.protobuf.empty_pb2 import Empty
import grpc


class GrpcSysDB(SysDB):
    """A gRPC implementation of the SysDB. In the distributed system, the SysDB is
    also called the 'Coordinator'. This implementation is used by Chroma frontend
    servers to call a remote SysDB (Coordinator) service.

    Every RPC is issued with a per-request timeout
    (``chroma_sysdb_request_timeout_seconds``) and gRPC status codes are mapped to
    Chroma's exception hierarchy: ALREADY_EXISTS -> UniqueConstraintError,
    NOT_FOUND -> NotFoundError, anything else -> InternalError.
    """

    _sys_db_stub: SysDBStub
    _channel: grpc.Channel
    _coordinator_url: str
    _coordinator_port: int
    _request_timeout_seconds: int

    def __init__(self, system: System):
        self._coordinator_url = system.settings.require("chroma_coordinator_host")
        # TODO: break out coordinator_port into a separate setting?
        self._coordinator_port = system.settings.require("chroma_server_grpc_port")
        self._request_timeout_seconds = system.settings.require(
            "chroma_sysdb_request_timeout_seconds"
        )
        return super().__init__(system)

    @overrides
    def start(self) -> None:
        """Open the gRPC channel to the coordinator and wrap it with telemetry and
        retry interceptors."""
        self._channel = grpc.insecure_channel(
            f"{self._coordinator_url}:{self._coordinator_port}",
        )
        interceptors = [OtelInterceptor(), RetryOnRpcErrorClientInterceptor()]
        self._channel = grpc.intercept_channel(self._channel, *interceptors)
        self._sys_db_stub = SysDBStub(self._channel)  # type: ignore
        return super().start()

    @overrides
    def stop(self) -> None:
        """Close the gRPC channel."""
        self._channel.close()
        return super().stop()

    @overrides
    def reset_state(self) -> None:
        """Reset all coordinator state (test/dev only)."""
        self._sys_db_stub.ResetState(Empty())
        return super().reset_state()

    @overrides
    def create_database(
        self, id: UUID, name: str, tenant: str = DEFAULT_TENANT
    ) -> None:
        """Create a database for a tenant.

        Raises:
            UniqueConstraintError: if a database with this name already exists.
            InternalError: on any other RPC failure.
        """
        try:
            request = CreateDatabaseRequest(id=id.hex, name=name, tenant=tenant)
            self._sys_db_stub.CreateDatabase(
                request, timeout=self._request_timeout_seconds
            )
        except grpc.RpcError as e:
            logger.info(
                f"Failed to create database name {name} and database id {id} for tenant {tenant} due to error: {e}"
            )
            if e.code() == grpc.StatusCode.ALREADY_EXISTS:
                raise UniqueConstraintError()
            raise InternalError()

    @overrides
    def get_database(self, name: str, tenant: str = DEFAULT_TENANT) -> Database:
        """Fetch a database by name within a tenant.

        Raises:
            NotFoundError: if no such database exists.
            InternalError: on any other RPC failure.
        """
        try:
            request = GetDatabaseRequest(name=name, tenant=tenant)
            response = self._sys_db_stub.GetDatabase(
                request, timeout=self._request_timeout_seconds
            )
            return Database(
                id=UUID(hex=response.database.id),
                name=response.database.name,
                tenant=response.database.tenant,
            )
        except grpc.RpcError as e:
            logger.info(
                f"Failed to get database {name} for tenant {tenant} due to error: {e}"
            )
            if e.code() == grpc.StatusCode.NOT_FOUND:
                raise NotFoundError()
            raise InternalError()

    @overrides
    def delete_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None:
        """Delete a database by name within a tenant.

        Raises:
            NotFoundError: if no such database exists.
            InternalError: on any other RPC failure.
        """
        try:
            request = DeleteDatabaseRequest(name=name, tenant=tenant)
            self._sys_db_stub.DeleteDatabase(
                request, timeout=self._request_timeout_seconds
            )
        except grpc.RpcError as e:
            logger.info(
                f"Failed to delete database {name} for tenant {tenant} due to error: {e}"
            )
            if e.code() == grpc.StatusCode.NOT_FOUND:
                raise NotFoundError()
            # BUGFIX: was `raise InternalError` (raising the class); instantiate
            # for consistency with every other handler in this class.
            raise InternalError()

    @overrides
    def list_databases(
        self,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        tenant: str = DEFAULT_TENANT,
    ) -> Sequence[Database]:
        """List databases for a tenant, with optional limit/offset pagination.

        Raises:
            InternalError: on any RPC failure.
        """
        try:
            request = ListDatabasesRequest(limit=limit, offset=offset, tenant=tenant)
            response = self._sys_db_stub.ListDatabases(
                request, timeout=self._request_timeout_seconds
            )
            return [
                Database(
                    id=UUID(hex=proto_database.id),
                    name=proto_database.name,
                    tenant=proto_database.tenant,
                )
                for proto_database in response.databases
            ]
        except grpc.RpcError as e:
            logger.info(
                f"Failed to list databases for tenant {tenant} due to error: {e}"
            )
            raise InternalError()

    @overrides
    def create_tenant(self, name: str) -> None:
        """Create a tenant.

        Raises:
            UniqueConstraintError: if the tenant already exists.
            InternalError: on any other RPC failure.
        """
        try:
            request = CreateTenantRequest(name=name)
            self._sys_db_stub.CreateTenant(
                request, timeout=self._request_timeout_seconds
            )
        except grpc.RpcError as e:
            logger.info(f"Failed to create tenant {name} due to error: {e}")
            if e.code() == grpc.StatusCode.ALREADY_EXISTS:
                raise UniqueConstraintError()
            raise InternalError()

    @overrides
    def get_tenant(self, name: str) -> Tenant:
        """Fetch a tenant by name.

        Raises:
            NotFoundError: if no such tenant exists.
            InternalError: on any other RPC failure.
        """
        try:
            request = GetTenantRequest(name=name)
            response = self._sys_db_stub.GetTenant(
                request, timeout=self._request_timeout_seconds
            )
            return Tenant(
                name=response.tenant.name,
            )
        except grpc.RpcError as e:
            logger.info(f"Failed to get tenant {name} due to error: {e}")
            if e.code() == grpc.StatusCode.NOT_FOUND:
                raise NotFoundError()
            raise InternalError()

    @overrides
    def create_segment(self, segment: Segment) -> None:
        """Register a segment with the coordinator.

        Raises:
            UniqueConstraintError: if the segment already exists.
            InternalError: on any other RPC failure.
        """
        try:
            proto_segment = to_proto_segment(segment)
            request = CreateSegmentRequest(
                segment=proto_segment,
            )
            self._sys_db_stub.CreateSegment(
                request, timeout=self._request_timeout_seconds
            )
        except grpc.RpcError as e:
            logger.info(f"Failed to create segment {segment}, error: {e}")
            if e.code() == grpc.StatusCode.ALREADY_EXISTS:
                raise UniqueConstraintError()
            raise InternalError()

    @overrides
    def delete_segment(self, collection: UUID, id: UUID) -> None:
        """Delete a segment belonging to a collection.

        Raises:
            NotFoundError: if no such segment exists.
            InternalError: on any other RPC failure.
        """
        try:
            request = DeleteSegmentRequest(
                id=id.hex,
                collection=collection.hex,
            )
            self._sys_db_stub.DeleteSegment(
                request, timeout=self._request_timeout_seconds
            )
        except grpc.RpcError as e:
            logger.info(
                f"Failed to delete segment with id {id} for collection {collection} due to error: {e}"
            )
            if e.code() == grpc.StatusCode.NOT_FOUND:
                raise NotFoundError()
            raise InternalError()

    @overrides
    def get_segments(
        self,
        collection: UUID,
        id: Optional[UUID] = None,
        type: Optional[str] = None,
        scope: Optional[SegmentScope] = None,
    ) -> Sequence[Segment]:
        """Fetch segments for a collection, optionally filtered by id/type/scope.

        Raises:
            InternalError: on any RPC failure.
        """
        try:
            request = GetSegmentsRequest(
                id=id.hex if id else None,
                type=type,
                scope=to_proto_segment_scope(scope) if scope else None,
                collection=collection.hex,
            )
            response = self._sys_db_stub.GetSegments(
                request, timeout=self._request_timeout_seconds
            )
            return [
                from_proto_segment(proto_segment)
                for proto_segment in response.segments
            ]
        except grpc.RpcError as e:
            logger.info(
                f"Failed to get segment id {id}, type {type}, scope {scope} for collection {collection} due to error: {e}"
            )
            raise InternalError()

    @overrides
    def update_segment(
        self,
        collection: UUID,
        id: UUID,
        metadata: OptionalArgument[Optional[UpdateMetadata]] = Unspecified(),
    ) -> None:
        """Update a segment's metadata.

        Passing ``metadata=None`` explicitly resets the metadata on the server;
        leaving it Unspecified leaves the metadata untouched.

        Raises:
            InternalError: on any RPC failure.
        """
        try:
            write_metadata = None
            if metadata != Unspecified():
                write_metadata = cast(Union[UpdateMetadata, None], metadata)
            request = UpdateSegmentRequest(
                id=id.hex,
                collection=collection.hex,
                metadata=to_proto_update_metadata(write_metadata)
                if write_metadata
                else None,
            )
            if metadata is None:
                # An explicit None means "reset": clear the metadata field and
                # flag the reset so the server distinguishes it from "no change".
                request.ClearField("metadata")
                request.reset_metadata = True
            self._sys_db_stub.UpdateSegment(
                request, timeout=self._request_timeout_seconds
            )
        except grpc.RpcError as e:
            logger.info(
                f"Failed to update segment with id {id} for collection {collection}, error: {e}"
            )
            raise InternalError()

    @overrides
    def create_collection(
        self,
        id: UUID,
        name: str,
        configuration: CollectionConfigurationInternal,
        segments: Sequence[Segment],
        metadata: Optional[Metadata] = None,
        dimension: Optional[int] = None,
        get_or_create: bool = False,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> Tuple[Collection, bool]:
        """Create a collection (and its segments) in one RPC.

        Returns:
            A tuple of the resulting collection and whether it was newly created
            (False when ``get_or_create`` matched an existing collection).

        Raises:
            UniqueConstraintError: if the collection already exists and
                ``get_or_create`` is False.
            InternalError: on any other RPC failure.
        """
        try:
            request = CreateCollectionRequest(
                id=id.hex,
                name=name,
                configuration_json_str=configuration.to_json_str(),
                metadata=to_proto_update_metadata(metadata) if metadata else None,
                dimension=dimension,
                get_or_create=get_or_create,
                tenant=tenant,
                database=database,
                segments=[to_proto_segment(segment) for segment in segments],
            )
            response = self._sys_db_stub.CreateCollection(
                request, timeout=self._request_timeout_seconds
            )
            collection = from_proto_collection(response.collection)
            return collection, response.created
        except grpc.RpcError as e:
            logger.error(
                f"Failed to create collection id {id}, name {name} for database {database} and tenant {tenant} due to error: {e}"
            )
            if e.code() == grpc.StatusCode.ALREADY_EXISTS:
                raise UniqueConstraintError()
            raise InternalError()

    @overrides
    def delete_collection(
        self,
        id: UUID,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> None:
        """Delete a collection by id.

        Raises:
            NotFoundError: if no such collection exists.
            InternalError: on any other RPC failure.
        """
        try:
            request = DeleteCollectionRequest(
                id=id.hex,
                tenant=tenant,
                database=database,
            )
            self._sys_db_stub.DeleteCollection(
                request, timeout=self._request_timeout_seconds
            )
        except grpc.RpcError as e:
            logger.error(
                f"Failed to delete collection id {id} for database {database} and tenant {tenant} due to error: {e}"
            )
            e = cast(grpc.Call, e)
            logger.error(
                f"Error code: {e.code()}, NotFoundError: {grpc.StatusCode.NOT_FOUND}"
            )
            if e.code() == grpc.StatusCode.NOT_FOUND:
                raise NotFoundError()
            raise InternalError()

    @overrides
    def get_collections(
        self,
        id: Optional[UUID] = None,
        name: Optional[str] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
    ) -> Sequence[Collection]:
        """Fetch collections by id, by name (scoped to tenant/database), or list
        all collections in a tenant/database.

        Raises:
            ValueError: if name is given without tenant and database.
            InternalError: on any RPC failure.
        """
        try:
            # TODO: implement limit and offset in the gRPC service
            request = None
            if id is not None:
                request = GetCollectionsRequest(
                    id=id.hex,
                    limit=limit,
                    offset=offset,
                )
            if name is not None:
                if tenant is None and database is None:
                    raise ValueError(
                        "If name is specified, tenant and database must also be specified in order to uniquely identify the collection"
                    )
                request = GetCollectionsRequest(
                    name=name,
                    tenant=tenant,
                    database=database,
                    limit=limit,
                    offset=offset,
                )
            if id is None and name is None:
                request = GetCollectionsRequest(
                    tenant=tenant,
                    database=database,
                    limit=limit,
                    offset=offset,
                )
            response: GetCollectionsResponse = self._sys_db_stub.GetCollections(
                request, timeout=self._request_timeout_seconds
            )
            return [
                from_proto_collection(collection)
                for collection in response.collections
            ]
        except grpc.RpcError as e:
            logger.error(
                f"Failed to get collections with id {id}, name {name}, tenant {tenant}, database {database} due to error: {e}"
            )
            raise InternalError()

    @overrides
    def get_collection_with_segments(
        self, collection_id: UUID
    ) -> CollectionAndSegments:
        """Fetch a collection and all of its segments in one consistent RPC.

        Raises:
            NotFoundError: if no such collection exists.
            InternalError: on any other RPC failure.
        """
        try:
            request = GetCollectionWithSegmentsRequest(id=collection_id.hex)
            # BUGFIX: pass the request timeout like every other RPC in this
            # class; previously this call could block indefinitely.
            response: GetCollectionWithSegmentsResponse = (
                self._sys_db_stub.GetCollectionWithSegments(
                    request, timeout=self._request_timeout_seconds
                )
            )
            return CollectionAndSegments(
                collection=from_proto_collection(response.collection),
                segments=[from_proto_segment(segment) for segment in response.segments],
            )
        except grpc.RpcError as e:
            if e.code() == grpc.StatusCode.NOT_FOUND:
                raise NotFoundError()
            logger.error(
                f"Failed to get collection {collection_id} and its segments due to error: {e}"
            )
            raise InternalError()

    @overrides
    def update_collection(
        self,
        id: UUID,
        name: OptionalArgument[str] = Unspecified(),
        dimension: OptionalArgument[Optional[int]] = Unspecified(),
        metadata: OptionalArgument[Optional[UpdateMetadata]] = Unspecified(),
    ) -> None:
        """Update a collection's name, dimension, and/or metadata.

        Unspecified arguments are left untouched; an explicit ``metadata=None``
        resets the metadata on the server.

        Raises:
            NotFoundError: if no such collection exists.
            UniqueConstraintError: if the new name collides with an existing
                collection.
            InternalError: on any other RPC failure.
        """
        try:
            write_name = None
            if name != Unspecified():
                write_name = cast(str, name)
            write_dimension = None
            if dimension != Unspecified():
                write_dimension = cast(Union[int, None], dimension)
            write_metadata = None
            if metadata != Unspecified():
                write_metadata = cast(Union[UpdateMetadata, None], metadata)
            request = UpdateCollectionRequest(
                id=id.hex,
                name=write_name,
                dimension=write_dimension,
                metadata=to_proto_update_metadata(write_metadata)
                if write_metadata
                else None,
            )
            if metadata is None:
                # An explicit None means "reset": clear the metadata field and
                # flag the reset so the server distinguishes it from "no change".
                request.ClearField("metadata")
                request.reset_metadata = True
            self._sys_db_stub.UpdateCollection(
                request, timeout=self._request_timeout_seconds
            )
        except grpc.RpcError as e:
            e = cast(grpc.Call, e)
            logger.error(
                f"Failed to update collection id {id}, name {name} due to error: {e}"
            )
            if e.code() == grpc.StatusCode.NOT_FOUND:
                raise NotFoundError()
            if e.code() == grpc.StatusCode.ALREADY_EXISTS:
                raise UniqueConstraintError()
            raise InternalError()

    def reset_and_wait_for_ready(self) -> None:
        """Reset coordinator state, blocking until the channel is ready."""
        self._sys_db_stub.ResetState(Empty(), wait_for_ready=True)