from threading import Lock
from typing import Dict, Sequence
from uuid import UUID, uuid4

from overrides import override

from chromadb.config import System
from chromadb.db.system import SysDB
from chromadb.segment import (
    SegmentImplementation,
    SegmentManager,
    SegmentType,
)
from chromadb.segment.distributed import SegmentDirectory
from chromadb.segment.impl.vector.hnsw_params import PersistentHnswParams
from chromadb.telemetry.opentelemetry import (
    OpenTelemetryGranularity,
    trace_method,
)
from chromadb.types import (
    Collection,
    Operation,
    Segment,
    SegmentScope,
)


class DistributedSegmentManager(SegmentManager):
    """Segment manager that describes segments for distributed collections.

    It builds the vector / record / metadata segment descriptors for a new
    collection, looks up a collection's segment ids through the ``SysDB``
    and resolves segment endpoints through the ``SegmentDirectory``.
    """

    _sysdb: SysDB
    _system: System
    # Initialised empty in ``__init__``; no method of this class reads it.
    _instances: Dict[UUID, SegmentImplementation]
    _segment_directory: SegmentDirectory
    # Created in ``__init__``; no method of this class acquires it.
    _lock: Lock

    def __init__(self, system: System):
        """Resolve the required components from *system* and set up state."""
        super().__init__(system)
        # Keep the ``require`` calls in this order: SysDB first, then the
        # segment directory.
        self._sysdb = self.require(SysDB)
        self._segment_directory = self.require(SegmentDirectory)
        self._system = system
        self._instances = {}
        self._lock = Lock()

    @trace_method(
        "DistributedSegmentManager.prepare_segments_for_new_collection",
        OpenTelemetryGranularity.OPERATION_AND_SEGMENT,
    )
    @override
    def prepare_segments_for_new_collection(
        self, collection: Collection
    ) -> Sequence[Segment]:
        """Build the three segment descriptors for a new collection.

        Each segment gets a fresh random id and an empty ``file_paths``
        mapping. Only the vector segment carries metadata: the HNSW
        parameters extracted from the collection metadata, or ``None`` when
        the collection metadata is missing or empty.

        Returns:
            The segments in the order ``[vector, record, metadata]``.
        """
        owner_id = collection.id

        vector_id = uuid4()
        hnsw_metadata = None
        if collection.metadata:
            hnsw_metadata = PersistentHnswParams.extract(collection.metadata)
        vector = Segment(
            id=vector_id,
            type=SegmentType.HNSW_DISTRIBUTED.value,
            scope=SegmentScope.VECTOR,
            collection=owner_id,
            metadata=hnsw_metadata,
            file_paths={},
        )

        metadata = Segment(
            id=uuid4(),
            type=SegmentType.BLOCKFILE_METADATA.value,
            scope=SegmentScope.METADATA,
            collection=owner_id,
            metadata=None,
            file_paths={},
        )

        record = Segment(
            id=uuid4(),
            type=SegmentType.BLOCKFILE_RECORD.value,
            scope=SegmentScope.RECORD,
            collection=owner_id,
            metadata=None,
            file_paths={},
        )

        return [vector, record, metadata]

    @override
    def delete_segments(self, collection_id: UUID) -> Sequence[UUID]:
        """Return the ids of the segments the SysDB lists for *collection_id*.

        Nothing is removed here; the segments are only looked up.
        """
        # TODO: this should be a pass, delete_collection is expected to
        # delete segments in distributed
        return [
            segment["id"]
            for segment in self._sysdb.get_segments(collection=collection_id)
        ]

    @trace_method(
        "DistributedSegmentManager.get_endpoint",
        OpenTelemetryGranularity.OPERATION_AND_SEGMENT,
    )
    def get_endpoint(self, segment: Segment) -> str:
        """Return the endpoint the segment directory reports for *segment*."""
        endpoint = self._segment_directory.get_segment_endpoint(segment)
        return endpoint

    @trace_method(
        "DistributedSegmentManager.hint_use_collection",
        OpenTelemetryGranularity.OPERATION_AND_SEGMENT,
    )
    @override
    def hint_use_collection(self, collection_id: UUID, hint_type: Operation) -> None:
        """Accept the usage hint and do nothing with it."""
        return None
Memory