from typing import Dict, Optional
import grpc
from overrides import overrides
from chromadb.api.types import GetResult, Metadata, QueryResult
from chromadb.config import System
from chromadb.execution.executor.abstract import Executor
from chromadb.execution.expression.operator import Scan
from chromadb.execution.expression.plan import CountPlan, GetPlan, KNNPlan
from chromadb.proto import convert
from chromadb.proto.query_executor_pb2_grpc import QueryExecutorStub
from chromadb.proto.utils import RetryOnRpcErrorClientInterceptor
from chromadb.segment.impl.manager.distributed import DistributedSegmentManager
from chromadb.telemetry.opentelemetry.grpc import OtelInterceptor
def _clean_metadata(metadata: Optional[Metadata]) -> Optional[Metadata]:
    """Strip reserved ``chroma:``-prefixed keys from a metadata map.

    Args:
        metadata: The metadata map to filter, or ``None``.

    Returns:
        A new dict holding only the entries whose key does not start with
        ``"chroma:"`` (in their original order), or ``None`` when the input is
        ``None``/empty or every key was reserved.
    """
    visible = {
        key: value
        for key, value in (metadata or {}).items()
        if not key.startswith("chroma:")
    }
    # An empty dict is falsy, so "nothing left to show" maps to None.
    return visible or None
def _uri(metadata: Optional[Metadata]) -> Optional[str]:
    """Return the ``chroma:uri`` entry of a metadata map as a string.

    Args:
        metadata: The metadata map to inspect, or ``None``.

    Returns:
        ``str(metadata["chroma:uri"])`` when that key is present, otherwise
        ``None`` (including when ``metadata`` is ``None`` or empty).
    """
    if not metadata or "chroma:uri" not in metadata:
        return None
    return str(metadata["chroma:uri"])
class DistributedExecutor(Executor):
    """Executor that forwards query plans to a remote query executor over gRPC.

    Each plan is converted to its protobuf form via ``convert``, sent to the
    endpoint that ``DistributedSegmentManager.get_endpoint`` returns for the
    plan's record segment, and the protobuf response is converted back into a
    client-facing result. Reserved ``chroma:``-prefixed keys are stripped from
    returned metadata, and the ``chroma:uri`` entry is surfaced separately as
    the record's URI.
    """

    # gRPC stubs cached by endpoint URL so each endpoint's channel is built once.
    _grpc_stub_pool: Dict[str, QueryExecutorStub]
    # Used to resolve the gRPC endpoint that serves a scan's record segment.
    _manager: DistributedSegmentManager
    # Deadline (seconds) applied to every outgoing query RPC; read from the
    # ``chroma_query_request_timeout_seconds`` setting.
    _request_timeout_seconds: int

    def __init__(self, system: System):
        super().__init__(system)
        self._grpc_stub_pool = {}
        self._manager = self.require(DistributedSegmentManager)
        self._request_timeout_seconds = system.settings.require(
            "chroma_query_request_timeout_seconds"
        )

    @overrides
    def count(self, plan: CountPlan) -> int:
        """Return the number of records matched by ``plan``.

        Raises:
            grpc.RpcError: If the remote call fails, including when the
                configured request deadline is exceeded.
        """
        executor = self._grpc_executuor_stub(plan.scan)
        # Apply the configured deadline so a stalled query node cannot block
        # the caller indefinitely.
        count_result = executor.Count(
            convert.to_proto_count_plan(plan),
            timeout=self._request_timeout_seconds,
        )
        return convert.from_proto_count_result(count_result)

    @overrides
    def get(self, plan: GetPlan) -> GetResult:
        """Fetch the records selected by ``plan``.

        ``ids`` is always populated and ``data`` is always ``None``. Each of
        ``embeddings``, ``documents``, ``uris`` and ``metadatas`` is a list with
        one entry per record when the matching flag on ``plan.projection`` is
        set, and ``None`` otherwise.

        Raises:
            grpc.RpcError: If the remote call fails, including when the
                configured request deadline is exceeded.
        """
        executor = self._grpc_executuor_stub(plan.scan)
        get_result = executor.Get(
            convert.to_proto_get_plan(plan),
            timeout=self._request_timeout_seconds,
        )
        records = convert.from_proto_get_result(get_result)
        projection = plan.projection

        ids = [record["id"] for record in records]
        embeddings = (
            [record["embedding"] for record in records]
            if projection.embedding
            else None
        )
        documents = (
            [record["document"] for record in records]
            if projection.document
            else None
        )
        # URIs live inside metadata under the reserved "chroma:uri" key.
        uris = (
            [_uri(record["metadata"]) for record in records]
            if projection.uri
            else None
        )
        # Reserved "chroma:" keys are hidden from the client-facing metadata.
        metadatas = (
            [_clean_metadata(record["metadata"]) for record in records]
            if projection.metadata
            else None
        )
        # TODO: Fix typing
        return GetResult(
            ids=ids,
            embeddings=embeddings,  # type: ignore[typeddict-item]
            documents=documents,  # type: ignore[typeddict-item]
            uris=uris,  # type: ignore[typeddict-item]
            data=None,
            metadatas=metadatas,  # type: ignore[typeddict-item]
            included=projection.included,
        )

    @overrides
    def knn(self, plan: KNNPlan) -> QueryResult:
        """Run the nearest-neighbour query described by ``plan``.

        The response is a batch: one list of hits per query. Every field of the
        returned ``QueryResult`` is a list of per-query lists. ``ids`` is
        always populated and ``data`` is always ``None``. The other fields
        (``embeddings``, ``documents``, ``uris``, ``metadatas``, and
        ``distances`` via ``projection.rank``) are ``None`` unless the matching
        flag on ``plan.projection`` is set.

        Raises:
            grpc.RpcError: If the remote call fails, including when the
                configured request deadline is exceeded.
        """
        executor = self._grpc_executuor_stub(plan.scan)
        knn_result = executor.KNN(
            convert.to_proto_knn_plan(plan),
            timeout=self._request_timeout_seconds,
        )
        results = convert.from_proto_knn_batch_result(knn_result)
        projection = plan.projection

        ids = [[hit["record"]["id"] for hit in hits] for hits in results]
        embeddings = (
            [[hit["record"]["embedding"] for hit in hits] for hits in results]
            if projection.embedding
            else None
        )
        documents = (
            [[hit["record"]["document"] for hit in hits] for hits in results]
            if projection.document
            else None
        )
        # URIs live inside metadata under the reserved "chroma:uri" key.
        uris = (
            [[_uri(hit["record"]["metadata"]) for hit in hits] for hits in results]
            if projection.uri
            else None
        )
        # Reserved "chroma:" keys are hidden from the client-facing metadata.
        metadatas = (
            [
                [_clean_metadata(hit["record"]["metadata"]) for hit in hits]
                for hits in results
            ]
            if projection.metadata
            else None
        )
        distances = (
            [[hit["distance"] for hit in hits] for hits in results]
            if projection.rank
            else None
        )
        # TODO: Fix typing
        return QueryResult(
            ids=ids,
            embeddings=embeddings,  # type: ignore[typeddict-item]
            documents=documents,  # type: ignore[typeddict-item]
            uris=uris,  # type: ignore[typeddict-item]
            data=None,
            metadatas=metadatas,  # type: ignore[typeddict-item]
            distances=distances,  # type: ignore[typeddict-item]
            included=projection.included,
        )

    def _grpc_executuor_stub(self, scan: Scan) -> QueryExecutorStub:
        """Return the cached query-executor stub for the endpoint serving ``scan``.

        A stub (and its intercepted channel) is created on first use of an
        endpoint URL and reused afterwards.

        NOTE: "executuor" is a historical misspelling; the name is left
        unchanged to avoid breaking any external references.
        """
        # The gRPC endpoint is determined by the collection UUID, so every
        # segment of the same collection resolves to the same endpoint. Looking
        # it up through the record segment is therefore sufficient.
        grpc_url = self._manager.get_endpoint(scan.record)
        stub = self._grpc_stub_pool.get(grpc_url)
        if stub is None:
            channel = grpc.intercept_channel(
                grpc.insecure_channel(grpc_url),
                OtelInterceptor(),
                RetryOnRpcErrorClientInterceptor(),
            )
            stub = QueryExecutorStub(channel)
            self._grpc_stub_pool[grpc_url] = stub
        return stub