from typing import (
Any,
Awaitable,
Callable,
cast,
Dict,
Sequence,
Optional,
Type,
TypeVar,
Tuple,
)
import fastapi
import orjson
from anyio import (
to_thread,
CapacityLimiter,
)
from fastapi import FastAPI as _FastAPI, Response, Request
from fastapi.openapi.utils import get_openapi
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import ORJSONResponse
from fastapi.routing import APIRoute
from fastapi import HTTPException, status
from functools import wraps
from chromadb.api.configuration import CollectionConfigurationInternal
from pydantic import BaseModel
from chromadb.api.types import (
Embedding,
GetResult,
QueryResult,
Embeddings,
convert_list_embeddings_to_np,
)
from chromadb.auth import UserIdentity
from chromadb.auth import (
AuthzAction,
AuthzResource,
ServerAuthenticationProvider,
ServerAuthorizationProvider,
)
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings, System
from chromadb.api import ServerAPI
from chromadb.errors import (
ChromaError,
InvalidDimensionException,
InvalidHTTPVersion,
RateLimitError,
QuotaError,
)
from chromadb.quota import QuotaEnforcer
from chromadb.rate_limit import AsyncRateLimitEnforcer
from chromadb.server import Server
from chromadb.server.fastapi.types import (
AddEmbedding,
CreateDatabase,
CreateTenant,
DeleteEmbedding,
GetEmbedding,
QueryEmbedding,
CreateCollection,
UpdateCollection,
UpdateEmbedding,
)
from starlette.datastructures import Headers
import logging
import importlib.metadata
from chromadb.telemetry.product.events import ServerStartEvent
from chromadb.utils.fastapi import fastapi_json_response, string_to_uuid as _uuid
from opentelemetry import trace
from chromadb.telemetry.opentelemetry.fastapi import instrument_fastapi
from chromadb.types import Database, Tenant
from chromadb.telemetry.product import ServerContext, ProductTelemetryClient
from chromadb.telemetry.opentelemetry import (
OpenTelemetryClient,
OpenTelemetryGranularity,
add_attributes_to_current_span,
trace_method,
)
from chromadb.types import Collection as CollectionModel
# Module-level logger; catch_exceptions_middleware uses it to record
# unexpected exceptions before answering with an HTTP 500.
logger = logging.getLogger(__name__)
def rate_limit(
    func: Callable[..., Awaitable[Any]]
) -> Callable[..., Awaitable[Any]]:
    """Route every call of an async method through its owner's rate limiter.

    The first positional argument of each call is treated as ``self``. Its
    ``_async_rate_limit_enforcer.rate_limit(func)`` wrapper is built per call
    and awaited with the original arguments.
    """

    @wraps(func)
    async def rate_limited(*args: Any, **kwargs: Any) -> Any:
        owner = args[0]
        limited_call = owner._async_rate_limit_enforcer.rate_limit(func)
        return await limited_call(*args, **kwargs)

    return rate_limited
def use_route_names_as_operation_ids(app: _FastAPI) -> None:
    """
    Set each API route's OpenAPI operation ID to the route's name.

    This gives generated API clients simpler function names. Only routes
    already registered on ``app`` are touched, so call this after every route
    has been added.
    """
    api_routes = [r for r in app.routes if isinstance(r, APIRoute)]
    for api_route in api_routes:
        api_route.operation_id = api_route.name
async def add_trace_id_to_response_middleware(
    request: Request, call_next: Callable[[Request], Any]
) -> Response:
    """HTTP middleware that exposes the active trace ID as a response header.

    The trace ID is read from the current span *before* the request is passed
    down the stack. It is then written in lowercase hex to the
    ``Chroma-Trace-Id`` header of the downstream response.
    """
    span_context = trace.get_current_span().get_span_context()
    current_trace_id = span_context.trace_id
    downstream_response = await call_next(request)
    downstream_response.headers["Chroma-Trace-Id"] = format(current_trace_id, "x")
    return downstream_response
async def catch_exceptions_middleware(
    request: Request, call_next: Callable[[Request], Any]
) -> Response:
    """HTTP middleware that turns exceptions raised downstream into JSON responses.

    - ``ChromaError``: serialized by ``fastapi_json_response``.
    - ``ValueError`` / ``TypeError``: HTTP 400 with error ``InvalidArgumentError``
      and the exception text as the message.
    - Any other ``Exception``: logged with traceback, then HTTP 500 carrying
      the exception's ``repr``.
    """
    try:
        return await call_next(request)
    except ChromaError as e:
        return fastapi_json_response(e)
    except (ValueError, TypeError) as e:
        # Both exception types get the same client-error response.
        return ORJSONResponse(
            content={"error": "InvalidArgumentError", "message": str(e)},
            status_code=400,
        )
    except Exception as e:
        logger.exception(e)
        return ORJSONResponse(content={"error": repr(e)}, status_code=500)
async def check_http_version_middleware(
    request: Request, call_next: Callable[[Request], Any]
) -> Response:
    """Only let HTTP/1.1 and HTTP/2 requests through to the rest of the stack.

    Raises:
        InvalidHTTPVersion: if the ASGI scope's ``http_version`` is missing or
            is anything other than "1.1" or "2".
    """
    version = request.scope.get("http_version")
    supported_versions = ("1.1", "2")
    if version in supported_versions:
        return await call_next(request)
    raise InvalidHTTPVersion(f"HTTP version {version} is not supported")
# Type variable for Pydantic request-body models; used by validate_model and
# FastAPI.get_openapi_extras_for_body_model.
# NOTE(review): D is declared contravariant but validate_model returns it,
# which type checkers reject (hence the `# type: ignore` on that function);
# an invariant TypeVar may have been intended — confirm before changing.
D = TypeVar("D", bound=BaseModel, contravariant=True)
def validate_model(model: Type[D], data: Any) -> D:  # type: ignore
    """Build a ``model`` instance from ``data`` under either Pydantic major version.

    Tries the Pydantic 2.x ``model_validate`` API first. If that raises
    ``AttributeError``, falls back to the Pydantic 1.x ``parse_obj`` API.
    Kept for backward compatibility with Pydantic 1.x.
    """
    try:
        instance = model.model_validate(data)  # pydantic 2.x
    except AttributeError:
        instance = model.parse_obj(data)  # pydantic 1.x
    return instance
class ChromaAPIRouter(fastapi.APIRouter):  # type: ignore
    """APIRouter that answers each path both with and without a trailing "/".

    Every route is registered twice: once as given and once with the trailing
    slash toggled. Only the slash-less variant may appear in the generated
    docs, and neither does when the caller passes ``include_in_schema=False``.
    """

    def add_api_route(self, path: str, *args: Any, **kwargs: Any) -> None:
        hidden_from_docs = (
            "include_in_schema" in kwargs and not kwargs["include_in_schema"]
        )
        toggled_path = path[:-1] if path.endswith("/") else path + "/"
        for variant in (path, toggled_path):
            kwargs["include_in_schema"] = (
                not hidden_from_docs and not variant.endswith("/")
            )
            super().add_api_route(variant, *args, **kwargs)
class FastAPI(Server):
def __init__(self, settings: Settings):
    """
    Build the Chroma HTTP server from ``settings``.

    Creates the FastAPI app and the Chroma ``System``, starts the system,
    installs HTTP middleware, CORS and exception handlers, registers the v1
    and v2 routes, instruments the app with OpenTelemetry and emits a
    ``ServerStartEvent`` to product telemetry.
    """
    # Tag product-telemetry events from this process as FastAPI-server events.
    ProductTelemetryClient.SERVER_CONTEXT = ServerContext.FASTAPI
    # https://fastapi.tiangolo.com/advanced/custom-response/#use-orjsonresponse
    # NOTE(review): debug=True asks Starlette to render tracebacks for
    # unhandled errors — confirm this is intended outside development.
    self._app = fastapi.FastAPI(debug=True, default_response_class=ORJSONResponse)
    self._system = System(settings)
    self._api: ServerAPI = self._system.instance(ServerAPI)
    # Filled by get_openapi_extras_for_body_model and merged into the spec by
    # generate_openapi, which replaces FastAPI's default schema generator.
    self._extra_openapi_schemas: Dict[str, Any] = {}
    self._app.openapi = self.generate_openapi
    self._opentelemetry_client = self._api.require(OpenTelemetryClient)
    # Shared limiter handed to to_thread.run_sync by the handlers, bounding how
    # many blocking handler bodies run on worker threads at once.
    self._capacity_limiter = CapacityLimiter(
        settings.chroma_server_thread_pool_size
    )
    self._quota_enforcer = self._system.require(QuotaEnforcer)
    self._system.start()
    # NOTE(review): the nesting order of these middlewares is decided by
    # Starlette (believed: last added = outermost) — confirm before reordering.
    self._app.middleware("http")(check_http_version_middleware)
    self._app.middleware("http")(catch_exceptions_middleware)
    self._app.middleware("http")(add_trace_id_to_response_middleware)
    self._app.add_middleware(
        CORSMiddleware,
        allow_headers=["*"],
        allow_origins=settings.chroma_server_cors_allow_origins,
        allow_methods=["*"],
    )
    self._app.add_exception_handler(QuotaError, self.quota_exception_handler)
    self._app.add_exception_handler(
        RateLimitError, self.rate_limit_exception_handler
    )
    # Read by the module-level `rate_limit` decorator on each wrapped call.
    self._async_rate_limit_enforcer = self._system.require(AsyncRateLimitEnforcer)
    self._app.on_event("shutdown")(self.shutdown)
    # Auth providers are optional; when left as None the auth helpers skip
    # the corresponding authentication/authorization step.
    self.authn_provider = None
    if settings.chroma_server_authn_provider:
        self.authn_provider = self._system.require(ServerAuthenticationProvider)
    self.authz_provider = None
    if settings.chroma_server_authz_provider:
        self.authz_provider = self._system.require(ServerAuthorizationProvider)
    # Router that serves every path with and without a trailing "/".
    self.router = ChromaAPIRouter()
    self.setup_v1_routes()
    self.setup_v2_routes()
    self._app.include_router(self.router)
    # Must run after all routes are registered (see that function's docstring).
    use_route_names_as_operation_ids(self._app)
    instrument_fastapi(self._app)
    telemetry_client = self._system.instance(ProductTelemetryClient)
    telemetry_client.capture(ServerStartEvent())
def generate_openapi(self) -> Dict[str, Any]:
    """Build the OpenAPI schema, including manually collected component schemas.

    Used instead of FastAPI's default ``openapi()`` handler. Schemas recorded
    in ``self._extra_openapi_schemas`` (nested model definitions gathered
    while registering routes) are published under ``components.schemas``.
    If the generated schema has no ``components``/``schemas`` section, it is
    created instead of raising ``KeyError``. With no extras, the generated
    schema is returned untouched.
    """
    schema: Dict[str, Any] = get_openapi(
        title="Chroma",
        routes=self._app.routes,
        version=importlib.metadata.version("chromadb"),
    )
    if self._extra_openapi_schemas:
        component_schemas = schema.setdefault("components", {}).setdefault(
            "schemas", {}
        )
        component_schemas.update(self._extra_openapi_schemas)
    return schema
def get_openapi_extras_for_body_model(
    self, request_model: Type[D]
) -> Dict[str, Any]:
    """Describe ``request_model`` as a required JSON request body for OpenAPI.

    Nested definitions emitted under ``$defs`` are stored in
    ``self._extra_openapi_schemas`` (generate_openapi publishes them). They
    are referenced through ``#/components/schemas/{model}``.

    Returns:
        A mapping suitable for a route's ``openapi_extra`` argument.
    """
    body_schema = request_model.model_json_schema(
        ref_template="#/components/schemas/{model}"
    )
    for def_name, def_schema in body_schema.get("$defs", {}).items():
        self._extra_openapi_schemas[def_name] = def_schema
    json_content = {"application/json": {"schema": body_schema}}
    return {"requestBody": {"content": json_content, "required": True}}
def setup_v2_routes(self) -> None:
    """
    Register the tenant/database-scoped v2 HTTP API on ``self.router``.

    Collection-level GET and DELETE address a collection by
    ``{collection_name}``. PUT (update) and the record endpoints
    (add/update/upsert/get/delete/count/query) address it by
    ``{collection_id}``. Routes with a JSON body describe it in the OpenAPI
    spec via ``get_openapi_extras_for_body_model``.
    """
    # --- Server-level endpoints ---
    self.router.add_api_route("/api/v2", self.root, methods=["GET"])
    self.router.add_api_route("/api/v2/reset", self.reset, methods=["POST"])
    self.router.add_api_route("/api/v2/version", self.version, methods=["GET"])
    self.router.add_api_route("/api/v2/heartbeat", self.heartbeat, methods=["GET"])
    self.router.add_api_route(
        "/api/v2/pre-flight-checks", self.pre_flight_checks, methods=["GET"]
    )
    self.router.add_api_route(
        "/api/v2/auth/identity",
        self.get_user_identity,
        methods=["GET"],
        response_model=None,
    )
    # --- Databases ---
    self.router.add_api_route(
        "/api/v2/tenants/{tenant}/databases",
        self.create_database,
        methods=["POST"],
        response_model=None,
        openapi_extra=self.get_openapi_extras_for_body_model(CreateDatabase),
    )
    self.router.add_api_route(
        "/api/v2/tenants/{tenant}/databases/{database_name}",
        self.get_database,
        methods=["GET"],
        response_model=None,
    )
    self.router.add_api_route(
        "/api/v2/tenants/{tenant}/databases/{database_name}",
        self.delete_database,
        methods=["DELETE"],
        response_model=None,
    )
    # --- Tenants ---
    self.router.add_api_route(
        "/api/v2/tenants",
        self.create_tenant,
        methods=["POST"],
        response_model=None,
        openapi_extra=self.get_openapi_extras_for_body_model(CreateTenant),
    )
    self.router.add_api_route(
        "/api/v2/tenants/{tenant}",
        self.get_tenant,
        methods=["GET"],
        response_model=None,
    )
    self.router.add_api_route(
        "/api/v2/tenants/{tenant}/databases",
        self.list_databases,
        methods=["GET"],
        response_model=None,
    )
    # --- Collections ---
    self.router.add_api_route(
        "/api/v2/tenants/{tenant}/databases/{database_name}/collections",
        self.list_collections,
        methods=["GET"],
        response_model=None,
    )
    self.router.add_api_route(
        "/api/v2/tenants/{tenant}/databases/{database_name}/collections_count",
        self.count_collections,
        methods=["GET"],
        response_model=None,
    )
    self.router.add_api_route(
        "/api/v2/tenants/{tenant}/databases/{database_name}/collections",
        self.create_collection,
        methods=["POST"],
        response_model=None,
        openapi_extra=self.get_openapi_extras_for_body_model(CreateCollection),
    )
    # --- Records within a collection (addressed by collection_id) ---
    # `add` is the only route here that answers 201 Created.
    self.router.add_api_route(
        "/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_id}/add",
        self.add,
        methods=["POST"],
        status_code=status.HTTP_201_CREATED,
        response_model=None,
        openapi_extra=self.get_openapi_extras_for_body_model(AddEmbedding),
    )
    self.router.add_api_route(
        "/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_id}/update",
        self.update,
        methods=["POST"],
        response_model=None,
        openapi_extra=self.get_openapi_extras_for_body_model(UpdateEmbedding),
    )
    # upsert shares the AddEmbedding body model with add.
    self.router.add_api_route(
        "/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_id}/upsert",
        self.upsert,
        methods=["POST"],
        response_model=None,
        openapi_extra=self.get_openapi_extras_for_body_model(AddEmbedding),
    )
    self.router.add_api_route(
        "/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_id}/get",
        self.get,
        methods=["POST"],
        response_model=None,
        openapi_extra=self.get_openapi_extras_for_body_model(GetEmbedding),
    )
    self.router.add_api_route(
        "/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_id}/delete",
        self.delete,
        methods=["POST"],
        response_model=None,
        openapi_extra=self.get_openapi_extras_for_body_model(DeleteEmbedding),
    )
    self.router.add_api_route(
        "/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_id}/count",
        self.count,
        methods=["GET"],
        response_model=None,
    )
    self.router.add_api_route(
        "/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_id}/query",
        self.get_nearest_neighbors,
        methods=["POST"],
        response_model=None,
        openapi_extra=self.get_openapi_extras_for_body_model(
            request_model=QueryEmbedding
        ),
    )
    # --- Single collection: GET/DELETE by name, PUT by id ---
    self.router.add_api_route(
        "/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_name}",
        self.get_collection,
        methods=["GET"],
        response_model=None,
    )
    self.router.add_api_route(
        "/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_id}",
        self.update_collection,
        methods=["PUT"],
        response_model=None,
        openapi_extra=self.get_openapi_extras_for_body_model(UpdateCollection),
    )
    self.router.add_api_route(
        "/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_name}",
        self.delete_collection,
        methods=["DELETE"],
        response_model=None,
    )
def shutdown(self) -> None:
    """Stop the Chroma ``System`` that backs this server."""
    system = self._system
    system.stop()
def app(self) -> fastapi.FastAPI:
    """Return the FastAPI application served by this instance."""
    application = self._app
    return application
async def rate_limit_exception_handler(
    self, request: Request, exc: RateLimitError
) -> ORJSONResponse:
    """Answer a ``RateLimitError`` with an HTTP 429 JSON body."""
    body = {"message": "Rate limit exceeded."}
    return ORJSONResponse(content=body, status_code=429)
def root(self) -> Dict[str, int]:
    """Return the backing API's heartbeat under the "nanosecond heartbeat" key."""
    beat = self._api.heartbeat()
    return {"nanosecond heartbeat": beat}
async def quota_exception_handler(
    self, request: Request, exc: QuotaError
) -> ORJSONResponse:
    """Answer a ``QuotaError`` with an HTTP 400 JSON body carrying its message."""
    body = {"message": exc.message()}
    return ORJSONResponse(content=body, status_code=400)
async def heartbeat(self) -> Dict[str, int]:
    """Async endpoint that returns the same payload as ``root``."""
    payload = self.root()
    return payload
async def version(self) -> str:
    """Return the version string reported by the backing API."""
    reported = self._api.get_version()
    return reported
def _set_request_context(self, request: Request) -> None:
    """
    Share the current request with components that need request context
    (here: the quota enforcer).
    """
    context = {"request": request}
    self._quota_enforcer.set_context(context=context)
@trace_method(
    "auth_request",
    OpenTelemetryGranularity.OPERATION,
)
@rate_limit
async def auth_request(
    self,
    headers: Headers,
    action: AuthzAction,
    tenant: Optional[str],
    database: Optional[str],
    collection: Optional[str],
) -> None:
    """Rate-limited async front for ``sync_auth_request``, run on a worker thread.

    Any exception raised by ``sync_auth_request`` (auth failure) propagates
    to the awaiting handler.
    """
    auth_args = (headers, action, tenant, database, collection)
    return await to_thread.run_sync(self.sync_auth_request, *auth_args)
def sync_auth_request(
    self,
    headers: Headers,
    action: AuthzAction,
    tenant: Optional[str],
    database: Optional[str],
    collection: Optional[str],
) -> None:
    """
    Authenticate and authorize a request from its headers and target resource.

    Authentication runs only when an authn provider is configured.
    Authorization runs only when both authn and authz providers are
    configured. If the request cannot be authenticated or authorized by the
    configured providers, the provider's exception propagates (documented as
    an HTTP 401).

    Once every configured check has passed, the tenant, database and
    collection are recorded on the current trace span. This happens for
    every successful configuration, including authn without authz.
    """
    if self.authn_provider:
        user_identity = self.authn_provider.authenticate_or_raise(dict(headers))
        if self.authz_provider:
            authz_resource = AuthzResource(
                tenant=tenant,
                database=database,
                collection=collection,
            )
            self.authz_provider.authorize_or_raise(
                user_identity, action, authz_resource
            )
    # Tag the span only after all configured checks succeeded, so every
    # successful path carries the same attributes.
    add_attributes_to_current_span(
        {
            "tenant": tenant,
            "database": database,
            "collection": collection,
        }
    )
@trace_method("FastAPI.get_user_identity", OpenTelemetryGranularity.OPERATION)
async def get_user_identity(
    self,
    request: Request,
) -> UserIdentity:
    """Return the caller's identity.

    Without an authentication provider, returns an identity with an empty
    user id, the default tenant and the default database. Otherwise the
    provider authenticates the request headers on a worker thread, and its
    result (or exception) is passed through.
    """
    if not self.authn_provider:
        return UserIdentity(
            user_id="", tenant=DEFAULT_TENANT, databases=[DEFAULT_DATABASE]
        )

    def authenticate() -> UserIdentity:
        provider = cast(ServerAuthenticationProvider, self.authn_provider)
        return provider.authenticate_or_raise(dict(request.headers))  # type: ignore

    identity = await to_thread.run_sync(authenticate)
    return cast(UserIdentity, identity)
@trace_method("FastAPI.create_database", OpenTelemetryGranularity.OPERATION)
async def create_database(
    self,
    request: Request,
    tenant: str,
) -> None:
    """Create a database in ``tenant`` from a JSON ``CreateDatabase`` body.

    Body parsing, the CREATE_DATABASE auth check and the API call all run on
    a worker thread bounded by the server's capacity limiter.
    """

    def create_on_worker(tenant: str, headers: Headers, raw_body: bytes) -> None:
        payload = validate_model(CreateDatabase, orjson.loads(raw_body))
        self.sync_auth_request(
            headers, AuthzAction.CREATE_DATABASE, tenant, payload.name, None
        )
        self._set_request_context(request=request)
        return self._api.create_database(payload.name, tenant)

    body = await request.body()
    await to_thread.run_sync(
        create_on_worker,
        tenant,
        request.headers,
        body,
        limiter=self._capacity_limiter,
    )
@trace_method("FastAPI.get_database", OpenTelemetryGranularity.OPERATION)
async def get_database(
    self,
    request: Request,
    database_name: str,
    tenant: str,
) -> Database:
    """Return database ``database_name`` of ``tenant`` after the GET_DATABASE check."""
    await self.auth_request(
        request.headers, AuthzAction.GET_DATABASE, tenant, database_name, None
    )
    found = await to_thread.run_sync(
        self._api.get_database,
        database_name,
        tenant,
        limiter=self._capacity_limiter,
    )
    return cast(Database, found)
@trace_method("FastAPI.delete_database", OpenTelemetryGranularity.OPERATION)
async def delete_database(
    self,
    request: Request,
    database_name: str,
    tenant: str,
) -> None:
    """Delete database ``database_name`` of ``tenant``.

    The DELETE_DATABASE auth check is awaited before the deletion. An
    authentication or authorization failure therefore aborts the request
    and nothing is deleted.
    """
    # auth_request is a coroutine function; without `await` the check never
    # executes and the database would be deleted unauthenticated.
    await self.auth_request(
        request.headers,
        AuthzAction.DELETE_DATABASE,
        tenant,
        database_name,
        None,
    )
    await to_thread.run_sync(
        self._api.delete_database,
        database_name,
        tenant,
        limiter=self._capacity_limiter,
    )
@trace_method("FastAPI.create_tenant", OpenTelemetryGranularity.OPERATION)
async def create_tenant(
    self,
    request: Request,
) -> None:
    """Create a tenant from a JSON ``CreateTenant`` body.

    Parsing, the CREATE_TENANT auth check and creation run on a worker
    thread bounded by the capacity limiter.
    """

    def create_on_worker(request: Request, raw_body: bytes) -> None:
        new_tenant = validate_model(CreateTenant, orjson.loads(raw_body))
        tenant_name = new_tenant.name
        self.sync_auth_request(
            request.headers, AuthzAction.CREATE_TENANT, tenant_name, None, None
        )
        return self._api.create_tenant(tenant_name)

    body = await request.body()
    await to_thread.run_sync(
        create_on_worker, request, body, limiter=self._capacity_limiter
    )
@trace_method("FastAPI.get_tenant", OpenTelemetryGranularity.OPERATION)
async def get_tenant(
    self,
    request: Request,
    tenant: str,
) -> Tenant:
    """Return tenant ``tenant`` after the GET_TENANT auth check."""
    await self.auth_request(
        request.headers, AuthzAction.GET_TENANT, tenant, None, None
    )
    found = await to_thread.run_sync(
        self._api.get_tenant, tenant, limiter=self._capacity_limiter
    )
    return cast(Tenant, found)
@trace_method("FastAPI.list_databases", OpenTelemetryGranularity.OPERATION)
async def list_databases(
    self,
    request: Request,
    tenant: str,
    limit: Optional[int] = None,
    offset: Optional[int] = None,
) -> Sequence[Database]:
    """List databases of ``tenant`` (optionally paged) after the LIST_DATABASES check."""
    await self.auth_request(
        request.headers, AuthzAction.LIST_DATABASES, tenant, None, None
    )
    databases = await to_thread.run_sync(
        self._api.list_databases,
        limit,
        offset,
        tenant,
        limiter=self._capacity_limiter,
    )
    return cast(Sequence[Database], databases)
@trace_method("FastAPI.list_collections", OpenTelemetryGranularity.OPERATION)
async def list_collections(
    self,
    request: Request,
    tenant: str,
    database_name: str,
    limit: Optional[int] = None,
    offset: Optional[int] = None,
) -> Sequence[CollectionModel]:
    """List collections of a database (optionally paged).

    The LIST_COLLECTIONS check, request-context setup and the listing all
    run on a worker thread bounded by the capacity limiter.
    """

    def list_on_worker(
        limit: Optional[int], offset: Optional[int], tenant: str, database_name: str
    ) -> Sequence[CollectionModel]:
        self.sync_auth_request(
            request.headers,
            AuthzAction.LIST_COLLECTIONS,
            tenant,
            database_name,
            None,
        )
        self._set_request_context(request=request)
        add_attributes_to_current_span({"tenant": tenant})
        return self._api.list_collections(
            tenant=tenant, database=database_name, limit=limit, offset=offset
        )

    collections = await to_thread.run_sync(
        list_on_worker,
        limit,
        offset,
        tenant,
        database_name,
        limiter=self._capacity_limiter,
    )
    return cast(Sequence[CollectionModel], collections)
@trace_method("FastAPI.count_collections", OpenTelemetryGranularity.OPERATION)
async def count_collections(
    self,
    request: Request,
    tenant: str,
    database_name: str,
) -> int:
    """Return how many collections a database holds, after the COUNT_COLLECTIONS check."""
    await self.auth_request(
        request.headers,
        AuthzAction.COUNT_COLLECTIONS,
        tenant,
        database_name,
        None,
    )
    add_attributes_to_current_span({"tenant": tenant})
    total = await to_thread.run_sync(
        self._api.count_collections,
        tenant,
        database_name,
        limiter=self._capacity_limiter,
    )
    return cast(int, total)
@trace_method("FastAPI.create_collection", OpenTelemetryGranularity.OPERATION)
async def create_collection(
    self,
    request: Request,
    tenant: str,
    database_name: str,
) -> CollectionModel:
    """Create (or get, if requested) a collection from a JSON ``CreateCollection`` body.

    An absent or empty ``configuration`` yields a default
    ``CollectionConfigurationInternal``. Otherwise the configuration is
    parsed from JSON. Parsing, the CREATE_COLLECTION check and creation run
    on a worker thread bounded by the capacity limiter.
    """

    def create_on_worker(
        request: Request, tenant: str, database: str, raw_body: bytes
    ) -> CollectionModel:
        spec = validate_model(CreateCollection, orjson.loads(raw_body))
        if spec.configuration:
            configuration = CollectionConfigurationInternal.from_json(
                spec.configuration
            )
        else:
            configuration = CollectionConfigurationInternal()
        self.sync_auth_request(
            request.headers,
            AuthzAction.CREATE_COLLECTION,
            tenant,
            database,
            spec.name,
        )
        self._set_request_context(request=request)
        add_attributes_to_current_span({"tenant": tenant})
        return self._api.create_collection(
            name=spec.name,
            configuration=configuration,
            metadata=spec.metadata,
            get_or_create=spec.get_or_create,
            tenant=tenant,
            database=database,
        )

    body = await request.body()
    created = await to_thread.run_sync(
        create_on_worker,
        request,
        tenant,
        database_name,
        body,
        limiter=self._capacity_limiter,
    )
    return cast(CollectionModel, created)
@trace_method("FastAPI.get_collection", OpenTelemetryGranularity.OPERATION)
async def get_collection(
    self,
    request: Request,
    tenant: str,
    database_name: str,
    collection_name: str,
) -> CollectionModel:
    """Return a collection by name after the GET_COLLECTION auth check."""
    await self.auth_request(
        request.headers,
        AuthzAction.GET_COLLECTION,
        tenant,
        database_name,
        collection_name,
    )
    add_attributes_to_current_span({"tenant": tenant})
    collection = await to_thread.run_sync(
        self._api.get_collection,
        collection_name,
        tenant,
        database_name,
        limiter=self._capacity_limiter,
    )
    return cast(CollectionModel, collection)
@trace_method("FastAPI.update_collection", OpenTelemetryGranularity.OPERATION)
async def update_collection(
    self,
    tenant: str,
    database_name: str,
    collection_id: str,
    request: Request,
) -> None:
    """Rename and/or replace the metadata of a collection, addressed by id.

    The body is an ``UpdateCollection`` model. Parsing, the UPDATE_COLLECTION
    check and the modification run on a worker thread bounded by the
    capacity limiter.
    """

    def modify_on_worker(
        request: Request, collection_id: str, raw_body: bytes
    ) -> None:
        changes = validate_model(UpdateCollection, orjson.loads(raw_body))
        self.sync_auth_request(
            request.headers,
            AuthzAction.UPDATE_COLLECTION,
            tenant,
            database_name,
            collection_id,
        )
        self._set_request_context(request=request)
        add_attributes_to_current_span({"tenant": tenant})
        return self._api._modify(
            id=_uuid(collection_id),
            new_name=changes.new_name,
            new_metadata=changes.new_metadata,
            tenant=tenant,
            database=database_name,
        )

    body = await request.body()
    await to_thread.run_sync(
        modify_on_worker,
        request,
        collection_id,
        body,
        limiter=self._capacity_limiter,
    )
@trace_method("FastAPI.delete_collection", OpenTelemetryGranularity.OPERATION)
async def delete_collection(
    self,
    request: Request,
    collection_name: str,
    tenant: str,
    database_name: str,
) -> None:
    """Delete a collection by name after the DELETE_COLLECTION auth check."""
    await self.auth_request(
        request.headers,
        AuthzAction.DELETE_COLLECTION,
        tenant,
        database_name,
        collection_name,
    )
    add_attributes_to_current_span({"tenant": tenant})
    await to_thread.run_sync(
        self._api.delete_collection,
        collection_name,
        tenant,
        database_name,
        limiter=self._capacity_limiter,
    )
@trace_method("FastAPI.add", OpenTelemetryGranularity.OPERATION)
@rate_limit
async def add(
    self,
    request: Request,
    tenant: str,
    database_name: str,
    collection_id: str,
) -> bool:
    """Add records from a JSON ``AddEmbedding`` body to a collection.

    Embeddings, when present, are converted to NumPy arrays before the API
    call. An ``InvalidDimensionException`` from the API is reported as an
    ``HTTPException`` with status 500.
    """

    def add_on_worker(request: Request, raw_body: bytes) -> bool:
        records = validate_model(AddEmbedding, orjson.loads(raw_body))
        self.sync_auth_request(
            request.headers,
            AuthzAction.ADD,
            tenant,
            database_name,
            collection_id,
        )
        self._set_request_context(request=request)
        add_attributes_to_current_span({"tenant": tenant})
        target_id = _uuid(collection_id)
        embeddings = cast(
            Embeddings,
            convert_list_embeddings_to_np(records.embeddings)
            if records.embeddings
            else None,
        )
        return self._api._add(
            collection_id=target_id,
            ids=records.ids,
            embeddings=embeddings,
            metadatas=records.metadatas,  # type: ignore
            documents=records.documents,  # type: ignore
            uris=records.uris,  # type: ignore
            tenant=tenant,
            database=database_name,
        )

    try:
        body = await request.body()
        added = await to_thread.run_sync(
            add_on_worker,
            request,
            body,
            limiter=self._capacity_limiter,
        )
    except InvalidDimensionException as e:
        raise HTTPException(status_code=500, detail=str(e))
    return cast(bool, added)
@trace_method("FastAPI.update", OpenTelemetryGranularity.OPERATION)
@rate_limit
async def update(
    self,
    request: Request,
    tenant: str,
    database_name: str,
    collection_id: str,
) -> None:
    """Update existing records of a collection from a JSON ``UpdateEmbedding`` body."""

    def update_on_worker(request: Request, raw_body: bytes) -> bool:
        changes = validate_model(UpdateEmbedding, orjson.loads(raw_body))
        self.sync_auth_request(
            request.headers,
            AuthzAction.UPDATE,
            tenant,
            database_name,
            collection_id,
        )
        self._set_request_context(request=request)
        add_attributes_to_current_span({"tenant": tenant})
        target_id = _uuid(collection_id)
        if changes.embeddings:
            embeddings = convert_list_embeddings_to_np(changes.embeddings)
        else:
            embeddings = None
        return self._api._update(
            collection_id=target_id,
            ids=changes.ids,
            embeddings=embeddings,
            metadatas=changes.metadatas,  # type: ignore
            documents=changes.documents,  # type: ignore
            uris=changes.uris,  # type: ignore
            tenant=tenant,
            database=database_name,
        )

    body = await request.body()
    await to_thread.run_sync(
        update_on_worker,
        request,
        body,
        limiter=self._capacity_limiter,
    )
@trace_method("FastAPI.upsert", OpenTelemetryGranularity.OPERATION)
@rate_limit
async def upsert(
    self,
    request: Request,
    tenant: str,
    database_name: str,
    collection_id: str,
) -> None:
    """Insert-or-update records of a collection from a JSON ``AddEmbedding`` body."""

    def upsert_on_worker(request: Request, raw_body: bytes) -> bool:
        records = validate_model(AddEmbedding, orjson.loads(raw_body))
        self.sync_auth_request(
            request.headers,
            AuthzAction.UPSERT,
            tenant,
            database_name,
            collection_id,
        )
        self._set_request_context(request=request)
        add_attributes_to_current_span({"tenant": tenant})
        target_id = _uuid(collection_id)
        embeddings = cast(
            Embeddings,
            convert_list_embeddings_to_np(records.embeddings)
            if records.embeddings
            else None,
        )
        return self._api._upsert(
            collection_id=target_id,
            ids=records.ids,
            embeddings=embeddings,
            metadatas=records.metadatas,  # type: ignore
            documents=records.documents,  # type: ignore
            uris=records.uris,  # type: ignore
            tenant=tenant,
            database=database_name,
        )

    body = await request.body()
    await to_thread.run_sync(
        upsert_on_worker,
        request,
        body,
        limiter=self._capacity_limiter,
    )
@trace_method("FastAPI.get", OpenTelemetryGranularity.OPERATION)
@rate_limit
async def get(
    self,
    collection_id: str,
    tenant: str,
    database_name: str,
    request: Request,
) -> GetResult:
    """Fetch records of a collection selected by a JSON ``GetEmbedding`` body.

    If the result has embeddings, each one is converted with ``.tolist()``
    into a plain Python list before it is returned.
    """

    def get_on_worker(request: Request, raw_body: bytes) -> GetResult:
        selector = validate_model(GetEmbedding, orjson.loads(raw_body))
        self.sync_auth_request(
            request.headers,
            AuthzAction.GET,
            tenant,
            database_name,
            collection_id,
        )
        self._set_request_context(request=request)
        add_attributes_to_current_span({"tenant": tenant})
        return self._api._get(
            collection_id=_uuid(collection_id),
            ids=selector.ids,
            where=selector.where,
            sort=selector.sort,
            limit=selector.limit,
            offset=selector.offset,
            where_document=selector.where_document,
            include=selector.include,
            tenant=tenant,
            database=database_name,
        )

    body = await request.body()
    result = cast(
        GetResult,
        await to_thread.run_sync(
            get_on_worker,
            request,
            body,
            limiter=self._capacity_limiter,
        ),
    )
    raw_embeddings = result["embeddings"]
    if raw_embeddings is not None:
        result["embeddings"] = [
            cast(Embedding, row).tolist() for row in raw_embeddings
        ]
    return result
@trace_method("FastAPI.delete", OpenTelemetryGranularity.OPERATION)
@rate_limit
async def delete(
    self,
    collection_id: str,
    tenant: str,
    database_name: str,
    request: Request,
) -> None:
    """Delete records of a collection selected by a JSON ``DeleteEmbedding`` body."""

    def delete_on_worker(request: Request, raw_body: bytes) -> None:
        selector = validate_model(DeleteEmbedding, orjson.loads(raw_body))
        self.sync_auth_request(
            request.headers,
            AuthzAction.DELETE,
            tenant,
            database_name,
            collection_id,
        )
        self._set_request_context(request=request)
        add_attributes_to_current_span({"tenant": tenant})
        return self._api._delete(
            collection_id=_uuid(collection_id),
            ids=selector.ids,
            where=selector.where,
            where_document=selector.where_document,
            tenant=tenant,
            database=database_name,
        )

    body = await request.body()
    await to_thread.run_sync(
        delete_on_worker,
        request,
        body,
        limiter=self._capacity_limiter,
    )
@trace_method("FastAPI.count", OpenTelemetryGranularity.OPERATION)
@rate_limit
async def count(
    self,
    request: Request,
    tenant: str,
    database_name: str,
    collection_id: str,
) -> int:
    """Return the number of records in a collection after the COUNT auth check."""
    await self.auth_request(
        request.headers,
        AuthzAction.COUNT,
        tenant,
        database_name,
        collection_id,
    )
    add_attributes_to_current_span({"tenant": tenant})
    target_id = _uuid(collection_id)
    total = await to_thread.run_sync(
        self._api._count,
        target_id,
        tenant,
        database_name,
        limiter=self._capacity_limiter,
    )
    return cast(int, total)
@trace_method("FastAPI.reset", OpenTelemetryGranularity.OPERATION)
@rate_limit
async def reset(
    self,
    request: Request,
) -> bool:
    """Reset the backing API after the RESET auth check and return its result."""
    await self.auth_request(request.headers, AuthzAction.RESET, None, None, None)
    was_reset = await to_thread.run_sync(
        self._api.reset, limiter=self._capacity_limiter
    )
    return cast(bool, was_reset)
@trace_method("FastAPI.get_nearest_neighbors", OpenTelemetryGranularity.OPERATION)
@rate_limit
async def get_nearest_neighbors(
    self,
    tenant: str,
    database_name: str,
    collection_id: str,
    request: Request,
) -> QueryResult:
    """Run a nearest-neighbour query described by a JSON ``QueryEmbedding`` body.

    Query embeddings, when present, are converted to NumPy arrays for the
    API call. Returned embeddings (one group per query) are converted
    element-wise with ``.tolist()`` into plain Python lists.
    """

    def query_on_worker(request: Request, raw_body: bytes) -> QueryResult:
        spec = validate_model(QueryEmbedding, orjson.loads(raw_body))
        self.sync_auth_request(
            request.headers,
            AuthzAction.QUERY,
            tenant,
            database_name,
            collection_id,
        )
        self._set_request_context(request=request)
        add_attributes_to_current_span({"tenant": tenant})
        target_id = _uuid(collection_id)
        query_embeddings = cast(
            Embeddings,
            convert_list_embeddings_to_np(spec.query_embeddings)
            if spec.query_embeddings
            else None,
        )
        return self._api._query(
            collection_id=target_id,
            query_embeddings=query_embeddings,
            n_results=spec.n_results,
            where=spec.where,
            where_document=spec.where_document,
            include=spec.include,
            tenant=tenant,
            database=database_name,
        )

    body = await request.body()
    neighbours = cast(
        QueryResult,
        await to_thread.run_sync(
            query_on_worker,
            request,
            body,
            limiter=self._capacity_limiter,
        ),
    )
    grouped_embeddings = neighbours["embeddings"]
    if grouped_embeddings is not None:
        neighbours["embeddings"] = [
            [cast(Embedding, row).tolist() for row in group]
            for group in grouped_embeddings
        ]
    return neighbours
async def pre_flight_checks(self) -> Dict[str, Any]:
    """Report client-facing limits; currently only the API's maximum batch size."""

    def collect_limits() -> Dict[str, Any]:
        max_batch_size = self._api.get_max_batch_size()
        return {"max_batch_size": max_batch_size}

    checks = await to_thread.run_sync(
        collect_limits, limiter=self._capacity_limiter
    )
    return cast(Dict[str, Any], checks)
# =========================================================================
# OLD ROUTES FOR BACKWARDS COMPATIBILITY — WILL BE REMOVED
# =========================================================================
def setup_v1_routes(self) -> None:
    """
    Register the legacy v1 HTTP API on ``self.router``.

    Unlike v2, these paths carry no tenant or database segments. Each route
    is served by the corresponding ``*_v1`` handler method, except for the
    server-level endpoints, which share their handlers with v2. As in v2,
    collection GET/DELETE use ``{collection_name}`` while PUT and the record
    endpoints use ``{collection_id}``.
    """
    # --- Server-level endpoints (same handlers as v2) ---
    self.router.add_api_route("/api/v1", self.root, methods=["GET"])
    self.router.add_api_route("/api/v1/reset", self.reset, methods=["POST"])
    self.router.add_api_route("/api/v1/version", self.version, methods=["GET"])
    self.router.add_api_route("/api/v1/heartbeat", self.heartbeat, methods=["GET"])
    self.router.add_api_route(
        "/api/v1/pre-flight-checks", self.pre_flight_checks, methods=["GET"]
    )
    # --- Databases and tenants ---
    self.router.add_api_route(
        "/api/v1/databases",
        self.create_database_v1,
        methods=["POST"],
        response_model=None,
        openapi_extra=self.get_openapi_extras_for_body_model(CreateDatabase),
    )
    self.router.add_api_route(
        "/api/v1/databases/{database}",
        self.get_database_v1,
        methods=["GET"],
        response_model=None,
    )
    self.router.add_api_route(
        "/api/v1/tenants",
        self.create_tenant_v1,
        methods=["POST"],
        response_model=None,
        openapi_extra=self.get_openapi_extras_for_body_model(CreateTenant),
    )
    self.router.add_api_route(
        "/api/v1/tenants/{tenant}",
        self.get_tenant_v1,
        methods=["GET"],
        response_model=None,
    )
    # --- Collections ---
    self.router.add_api_route(
        "/api/v1/collections",
        self.list_collections_v1,
        methods=["GET"],
        response_model=None,
    )
    self.router.add_api_route(
        "/api/v1/count_collections",
        self.count_collections_v1,
        methods=["GET"],
        response_model=None,
    )
    self.router.add_api_route(
        "/api/v1/collections",
        self.create_collection_v1,
        methods=["POST"],
        response_model=None,
        openapi_extra=self.get_openapi_extras_for_body_model(CreateCollection),
    )
    # --- Records within a collection (addressed by collection_id) ---
    self.router.add_api_route(
        "/api/v1/collections/{collection_id}/add",
        self.add_v1,
        methods=["POST"],
        status_code=status.HTTP_201_CREATED,
        response_model=None,
        openapi_extra=self.get_openapi_extras_for_body_model(AddEmbedding),
    )
    self.router.add_api_route(
        "/api/v1/collections/{collection_id}/update",
        self.update_v1,
        methods=["POST"],
        response_model=None,
        openapi_extra=self.get_openapi_extras_for_body_model(UpdateEmbedding),
    )
    self.router.add_api_route(
        "/api/v1/collections/{collection_id}/upsert",
        self.upsert_v1,
        methods=["POST"],
        response_model=None,
        openapi_extra=self.get_openapi_extras_for_body_model(AddEmbedding),
    )
    self.router.add_api_route(
        "/api/v1/collections/{collection_id}/get",
        self.get_v1,
        methods=["POST"],
        response_model=None,
        openapi_extra=self.get_openapi_extras_for_body_model(GetEmbedding),
    )
    self.router.add_api_route(
        "/api/v1/collections/{collection_id}/delete",
        self.delete_v1,
        methods=["POST"],
        response_model=None,
        openapi_extra=self.get_openapi_extras_for_body_model(DeleteEmbedding),
    )
    self.router.add_api_route(
        "/api/v1/collections/{collection_id}/count",
        self.count_v1,
        methods=["GET"],
        response_model=None,
    )
    self.router.add_api_route(
        "/api/v1/collections/{collection_id}/query",
        self.get_nearest_neighbors_v1,
        methods=["POST"],
        response_model=None,
        openapi_extra=self.get_openapi_extras_for_body_model(QueryEmbedding),
    )
    # --- Single collection: GET/DELETE by name, PUT by id ---
    self.router.add_api_route(
        "/api/v1/collections/{collection_name}",
        self.get_collection_v1,
        methods=["GET"],
        response_model=None,
    )
    self.router.add_api_route(
        "/api/v1/collections/{collection_id}",
        self.update_collection_v1,
        methods=["PUT"],
        response_model=None,
        openapi_extra=self.get_openapi_extras_for_body_model(UpdateCollection),
    )
    self.router.add_api_route(
        "/api/v1/collections/{collection_name}",
        self.delete_collection_v1,
        methods=["DELETE"],
        response_model=None,
    )
@trace_method(
    "auth_and_get_tenant_and_database_for_request_v1",
    OpenTelemetryGranularity.OPERATION,
)
@rate_limit
async def auth_and_get_tenant_and_database_for_request(
    self,
    headers: Headers,
    action: AuthzAction,
    tenant: Optional[str],
    database: Optional[str],
    collection: Optional[str],
) -> Tuple[Optional[str], Optional[str]]:
    """
    Authenticate and authorize a request, resolving its tenant and database.

    This is the event-loop-side entry point. The actual (blocking) work is
    done by ``sync_auth_and_get_tenant_and_database_for_request`` on a worker
    thread. If the configured authn/authz providers reject the request, the
    provider's exception propagates to the caller (presumably mapped to an
    HTTP 401/403 by the server's error handling -- confirm).

    The returned tenant and database differ from the passed values if and
    only if:
    - an authentication provider is configured and authenticates the request,
    - the passed tenant (resp. database) is None/empty or the default value
      (each is overwritten independently), and
    - the provider's ``singleton_tenant_database_if_applicable`` reports a
      single tenant (resp. database) for the user (intended to be gated by the
      ``chroma_overwrite_singleton_tenant_database_access_from_auth`` setting
      inside the provider -- not visible here).

    Returns:
        The ``(tenant, database)`` pair to use for the request.
    """
    # Delegate to the *synchronous* implementation. Handing this coroutine
    # function itself to ``run_sync`` would return an un-awaited coroutine
    # from the worker thread instead of the (tenant, database) tuple.
    return await to_thread.run_sync(
        self.sync_auth_and_get_tenant_and_database_for_request,
        headers,
        action,
        tenant,
        database,
        collection,
    )
def sync_auth_and_get_tenant_and_database_for_request(
    self,
    headers: Headers,
    action: AuthzAction,
    tenant: Optional[str],
    database: Optional[str],
    collection: Optional[str],
) -> Tuple[Optional[str], Optional[str]]:
    """
    Synchronously authenticate/authorize a request and resolve tenant/database.

    This call blocks. Coroutines should use the async
    ``auth_and_get_tenant_and_database_for_request`` wrapper. Code already
    running on a worker thread calls this directly.

    Steps:
    1. With no authn provider configured, the inputs are used unchanged and
       no authorization is performed.
    2. Otherwise the headers are passed to ``authenticate_or_raise``.
       - If the passed tenant is None/empty or ``DEFAULT_TENANT``, it is
         replaced by the provider's singleton tenant for the user, if any.
       - The database is handled the same way, independently.
    3. If an authz provider is also configured, ``authorize_or_raise`` is
       called for ``action`` on the resolved tenant/database/collection.

    On every successful path the resolved tenant, database and collection
    are recorded on the current trace span.

    Returns:
        The ``(tenant, database)`` pair to use for the request.
    """
    if self.authn_provider:
        user_identity = self.authn_provider.authenticate_or_raise(dict(headers))

        (
            new_tenant,
            new_database,
        ) = self.authn_provider.singleton_tenant_database_if_applicable(user_identity)

        # Only fill in values the caller left unset/default; an explicitly
        # requested tenant or database is never overwritten.
        if (not tenant or tenant == DEFAULT_TENANT) and new_tenant:
            tenant = new_tenant
        if (not database or database == DEFAULT_DATABASE) and new_database:
            database = new_database

        if self.authz_provider:
            authz_resource = AuthzResource(
                tenant=tenant,
                database=database,
                collection=collection,
            )
            self.authz_provider.authorize_or_raise(
                user_identity, action, authz_resource
            )

    # Recorded once here so that the "authn configured, no authz" path is
    # traced the same way as the other successful paths.
    add_attributes_to_current_span(
        {
            "tenant": tenant,
            "database": database,
            "collection": collection,
        }
    )
    return (tenant, database)
@trace_method("FastAPI.create_database_v1", OpenTelemetryGranularity.OPERATION)
@rate_limit
async def create_database_v1(
    self,
    request: Request,
    tenant: str = DEFAULT_TENANT,
) -> None:
    """
    Create a database in ``tenant`` from a ``CreateDatabase`` JSON body.

    The following run on a single worker thread limited by
    ``self._capacity_limiter``:
    - body validation;
    - ``AuthzAction.CREATE_DATABASE`` authorization;
    - creation of the database.

    The auth step may substitute the tenant and/or database name resolved
    from the caller's identity.
    """

    def _create_database_blocking(
        requested_tenant: str, request_headers: Headers, body: bytes
    ) -> None:
        new_db = validate_model(CreateDatabase, orjson.loads(body))
        resolved_tenant, resolved_database = (
            self.sync_auth_and_get_tenant_and_database_for_request(
                request_headers,
                AuthzAction.CREATE_DATABASE,
                requested_tenant,
                new_db.name,
                None,
            )
        )
        if resolved_database:
            new_db.name = resolved_database
        # A falsy resolved tenant means "keep the one that was requested".
        return self._api.create_database(
            new_db.name, resolved_tenant or requested_tenant
        )

    headers = request.headers
    body = await request.body()
    await to_thread.run_sync(
        _create_database_blocking,
        tenant,
        headers,
        body,
        limiter=self._capacity_limiter,
    )
@trace_method("FastAPI.get_database_v1", OpenTelemetryGranularity.OPERATION)
@rate_limit
async def get_database_v1(
    self,
    request: Request,
    database: str,
    tenant: str = DEFAULT_TENANT,
) -> Database:
    """
    Look up ``database`` within ``tenant``.

    The caller must be authorized for ``AuthzAction.GET_DATABASE``. The auth
    step may substitute the tenant and/or database resolved from the
    caller's identity. The lookup runs on a worker thread limited by
    ``self._capacity_limiter``.
    """
    auth_tenant, auth_database = (
        await self.auth_and_get_tenant_and_database_for_request(
            request.headers,
            AuthzAction.GET_DATABASE,
            tenant,
            database,
            None,
        )
    )
    # A falsy value from the auth step means "keep what was requested".
    effective_tenant = auth_tenant or tenant
    effective_database = auth_database or database

    found = await to_thread.run_sync(
        self._api.get_database,
        effective_database,
        effective_tenant,
        limiter=self._capacity_limiter,
    )
    return cast(Database, found)
@trace_method("FastAPI.create_tenant_v1", OpenTelemetryGranularity.OPERATION)
@rate_limit
async def create_tenant_v1(
    self,
    request: Request,
) -> None:
    """
    Create a tenant described by a ``CreateTenant`` JSON body.

    The following run on one worker thread limited by
    ``self._capacity_limiter``:
    - body parsing;
    - ``AuthzAction.CREATE_TENANT`` authorization;
    - creation of the tenant.

    The auth step may replace the requested tenant name with one resolved
    from the caller's identity.
    """

    def _create_tenant_blocking(req: Request, body: bytes) -> None:
        new_tenant = validate_model(CreateTenant, orjson.loads(body))
        resolved_name, _ = self.sync_auth_and_get_tenant_and_database_for_request(
            req.headers,
            AuthzAction.CREATE_TENANT,
            new_tenant.name,
            None,
            None,
        )
        if resolved_name:
            new_tenant.name = resolved_name
        return self._api.create_tenant(new_tenant.name)

    body = await request.body()
    await to_thread.run_sync(
        _create_tenant_blocking,
        request,
        body,
        limiter=self._capacity_limiter,
    )
@trace_method("FastAPI.get_tenant_v1", OpenTelemetryGranularity.OPERATION)
@rate_limit
async def get_tenant_v1(
    self,
    request: Request,
    tenant: str,
) -> Tenant:
    """
    Look up ``tenant`` after authorizing ``AuthzAction.GET_TENANT``.

    The auth step may substitute a tenant resolved from the caller's
    identity. The lookup runs on a worker thread limited by
    ``self._capacity_limiter``.
    """
    auth_tenant, _database = await self.auth_and_get_tenant_and_database_for_request(
        request.headers,
        AuthzAction.GET_TENANT,
        tenant,
        None,
        None,
    )
    lookup_name = auth_tenant or tenant

    found = await to_thread.run_sync(
        self._api.get_tenant,
        lookup_name,
        limiter=self._capacity_limiter,
    )
    return cast(Tenant, found)
@trace_method("FastAPI.list_collections_v1", OpenTelemetryGranularity.OPERATION)
@rate_limit
async def list_collections_v1(
    self,
    request: Request,
    limit: Optional[int] = None,
    offset: Optional[int] = None,
    tenant: str = DEFAULT_TENANT,
    database: str = DEFAULT_DATABASE,
) -> Sequence[CollectionModel]:
    """
    List collections in ``tenant``/``database``.

    ``limit`` and ``offset`` are passed through unchanged. The caller must be
    authorized for ``AuthzAction.LIST_COLLECTIONS``. The auth step may
    substitute the tenant and/or database. The listing runs on a worker
    thread limited by ``self._capacity_limiter``.
    """
    auth_tenant, auth_database = (
        await self.auth_and_get_tenant_and_database_for_request(
            request.headers,
            AuthzAction.LIST_COLLECTIONS,
            tenant,
            database,
            None,
        )
    )

    collections = await to_thread.run_sync(
        self._api.list_collections,
        limit,
        offset,
        auth_tenant or tenant,
        auth_database or database,
        limiter=self._capacity_limiter,
    )
    return cast(Sequence[CollectionModel], collections)
@trace_method("FastAPI.count_collections_v1", OpenTelemetryGranularity.OPERATION)
@rate_limit
async def count_collections_v1(
    self,
    request: Request,
    tenant: str = DEFAULT_TENANT,
    database: str = DEFAULT_DATABASE,
) -> int:
    """
    Return the number of collections in ``tenant``/``database``.

    The caller must be authorized for ``AuthzAction.COUNT_COLLECTIONS``. The
    auth step may substitute the tenant and/or database. Counting runs on a
    worker thread limited by ``self._capacity_limiter``.
    """
    auth_tenant, auth_database = (
        await self.auth_and_get_tenant_and_database_for_request(
            request.headers,
            AuthzAction.COUNT_COLLECTIONS,
            tenant,
            database,
            None,
        )
    )
    effective_tenant = auth_tenant or tenant
    effective_database = auth_database or database

    total = await to_thread.run_sync(
        self._api.count_collections,
        effective_tenant,
        effective_database,
        limiter=self._capacity_limiter,
    )
    return cast(int, total)
@trace_method("FastAPI.create_collection_v1", OpenTelemetryGranularity.OPERATION)
@rate_limit
async def create_collection_v1(
    self,
    request: Request,
    tenant: str = DEFAULT_TENANT,
    database: str = DEFAULT_DATABASE,
) -> CollectionModel:
    """
    Create (or get, if ``get_or_create`` is set) a collection.

    The request body is a ``CreateCollection`` JSON model. An absent/empty
    ``configuration`` yields a default ``CollectionConfigurationInternal``.
    The configuration is parsed before the caller is authorized for
    ``AuthzAction.CREATE_COLLECTION``. The auth step may substitute the
    tenant and/or database. All of this runs on one worker thread limited
    by ``self._capacity_limiter``.
    """

    def _create_collection_blocking(
        req: Request, requested_tenant: str, requested_database: str, body: bytes
    ) -> CollectionModel:
        spec = validate_model(CreateCollection, orjson.loads(body))
        if spec.configuration:
            config = CollectionConfigurationInternal.from_json(spec.configuration)
        else:
            config = CollectionConfigurationInternal()

        auth_tenant, auth_database = (
            self.sync_auth_and_get_tenant_and_database_for_request(
                req.headers,
                AuthzAction.CREATE_COLLECTION,
                requested_tenant,
                requested_database,
                spec.name,
            )
        )
        return self._api.create_collection(
            name=spec.name,
            configuration=config,
            metadata=spec.metadata,
            get_or_create=spec.get_or_create,
            tenant=auth_tenant or requested_tenant,
            database=auth_database or requested_database,
        )

    body = await request.body()
    created = await to_thread.run_sync(
        _create_collection_blocking,
        request,
        tenant,
        database,
        body,
        limiter=self._capacity_limiter,
    )
    return cast(CollectionModel, created)
@trace_method("FastAPI.get_collection_v1", OpenTelemetryGranularity.OPERATION)
@rate_limit
async def get_collection_v1(
    self,
    request: Request,
    collection_name: str,
    tenant: str = DEFAULT_TENANT,
    database: str = DEFAULT_DATABASE,
) -> CollectionModel:
    """
    Return the collection named ``collection_name``.

    The caller must first be authorized for ``AuthzAction.GET_COLLECTION``.
    That step may replace the tenant and/or database with values resolved
    from the caller's identity. The lookup then runs on a worker thread
    limited by ``self._capacity_limiter``.
    """
    (
        maybe_tenant,
        maybe_database,
    ) = await self.auth_and_get_tenant_and_database_for_request(
        request.headers,
        AuthzAction.GET_COLLECTION,
        tenant,
        database,
        collection_name,
    )
    if maybe_tenant:
        tenant = maybe_tenant
    if maybe_database:
        database = maybe_database

    api_collection_model = cast(
        CollectionModel,
        await to_thread.run_sync(
            self._api.get_collection,
            collection_name,
            tenant,
            database,
            limiter=self._capacity_limiter,
        ),
    )
    return api_collection_model
@trace_method("FastAPI.update_collection_v1", OpenTelemetryGranularity.OPERATION)
@rate_limit
async def update_collection_v1(
    self,
    collection_id: str,
    request: Request,
) -> None:
    """
    Rename and/or replace the metadata of collection ``collection_id``.

    The body is an ``UpdateCollection`` JSON model. The model is validated
    before the caller is authorized for ``AuthzAction.UPDATE_COLLECTION``
    (scoped by collection id only). All of this runs on a worker thread
    limited by ``self._capacity_limiter``.
    """

    def _modify_blocking(req: Request, target_id: str, body: bytes) -> None:
        changes = validate_model(UpdateCollection, orjson.loads(body))
        self.sync_auth_and_get_tenant_and_database_for_request(
            req.headers,
            AuthzAction.UPDATE_COLLECTION,
            None,
            None,
            target_id,
        )
        return self._api._modify(
            id=_uuid(target_id),
            new_name=changes.new_name,
            new_metadata=changes.new_metadata,
        )

    body = await request.body()
    await to_thread.run_sync(
        _modify_blocking,
        request,
        collection_id,
        body,
        limiter=self._capacity_limiter,
    )
@trace_method("FastAPI.delete_collection_v1", OpenTelemetryGranularity.OPERATION)
@rate_limit
async def delete_collection_v1(
    self,
    request: Request,
    collection_name: str,
    tenant: str = DEFAULT_TENANT,
    database: str = DEFAULT_DATABASE,
) -> None:
    """
    Delete the collection named ``collection_name``.

    The caller must be authorized for ``AuthzAction.DELETE_COLLECTION``. The
    auth step may substitute the tenant and/or database. Deletion runs on a
    worker thread limited by ``self._capacity_limiter``.
    """
    auth_tenant, auth_database = (
        await self.auth_and_get_tenant_and_database_for_request(
            request.headers,
            AuthzAction.DELETE_COLLECTION,
            tenant,
            database,
            collection_name,
        )
    )
    target_tenant = auth_tenant or tenant
    target_database = auth_database or database

    await to_thread.run_sync(
        self._api.delete_collection,
        collection_name,
        target_tenant,
        target_database,
        limiter=self._capacity_limiter,
    )
@trace_method("FastAPI.add_v1", OpenTelemetryGranularity.OPERATION)
@rate_limit
async def add_v1(
    self,
    request: Request,
    collection_id: str,
) -> bool:
    """
    Add records (an ``AddEmbedding`` JSON body) to collection ``collection_id``.

    The following run on a worker thread limited by
    ``self._capacity_limiter``:
    - body validation;
    - ``AuthzAction.ADD`` authorization;
    - the insert itself.

    Embeddings, when present, are converted with
    ``convert_list_embeddings_to_np``.

    Raises:
        HTTPException: status 500 when the insert raises
            ``InvalidDimensionException``. The original exception is chained
            as ``__cause__``.
    """

    def process_add(request: Request, raw_body: bytes) -> bool:
        add = validate_model(AddEmbedding, orjson.loads(raw_body))
        self.sync_auth_and_get_tenant_and_database_for_request(
            request.headers,
            AuthzAction.ADD,
            None,
            None,
            collection_id,
        )
        return self._api._add(
            collection_id=_uuid(collection_id),
            ids=add.ids,
            embeddings=cast(
                Embeddings,
                convert_list_embeddings_to_np(add.embeddings)
                if add.embeddings
                else None,
            ),
            metadatas=add.metadatas,  # type: ignore
            documents=add.documents,  # type: ignore
            uris=add.uris,  # type: ignore
        )

    raw_body = await request.body()
    # Keep the try scoped to the call that can raise InvalidDimensionException.
    try:
        result = await to_thread.run_sync(
            process_add,
            request,
            raw_body,
            limiter=self._capacity_limiter,
        )
    except InvalidDimensionException as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return cast(bool, result)
@trace_method("FastAPI.update_v1", OpenTelemetryGranularity.OPERATION)
@rate_limit
async def update_v1(
    self,
    request: Request,
    collection_id: str,
) -> None:
    """
    Update existing records in collection ``collection_id``.

    The body is an ``UpdateEmbedding`` JSON model. The following run on a
    worker thread limited by ``self._capacity_limiter``:
    - body validation;
    - ``AuthzAction.UPDATE`` authorization;
    - the update itself.

    Rate-limited like every other v1 data endpoint (add/upsert/get/delete/
    query).
    """

    def process_update(request: Request, raw_body: bytes) -> bool:
        update = validate_model(UpdateEmbedding, orjson.loads(raw_body))
        self.sync_auth_and_get_tenant_and_database_for_request(
            request.headers,
            AuthzAction.UPDATE,
            None,
            None,
            collection_id,
        )
        return self._api._update(
            collection_id=_uuid(collection_id),
            ids=update.ids,
            embeddings=cast(
                Embeddings,
                convert_list_embeddings_to_np(update.embeddings)
                if update.embeddings
                else None,
            ),
            metadatas=update.metadatas,  # type: ignore
            documents=update.documents,  # type: ignore
            uris=update.uris,  # type: ignore
        )

    await to_thread.run_sync(
        process_update,
        request,
        await request.body(),
        limiter=self._capacity_limiter,
    )
@trace_method("FastAPI.upsert_v1", OpenTelemetryGranularity.OPERATION)
@rate_limit
async def upsert_v1(
    self,
    request: Request,
    collection_id: str,
) -> None:
    """
    Insert-or-update records (an ``AddEmbedding`` body) in ``collection_id``.

    The following run on a worker thread limited by
    ``self._capacity_limiter``:
    - body validation;
    - ``AuthzAction.UPSERT`` authorization;
    - the upsert itself.

    Embeddings, when present, are converted with
    ``convert_list_embeddings_to_np``.
    """

    def _upsert_blocking(req: Request, body: bytes) -> bool:
        records = validate_model(AddEmbedding, orjson.loads(body))
        self.sync_auth_and_get_tenant_and_database_for_request(
            req.headers,
            AuthzAction.UPSERT,
            None,
            None,
            collection_id,
        )
        target = _uuid(collection_id)
        if records.embeddings:
            vectors = convert_list_embeddings_to_np(records.embeddings)
        else:
            vectors = None
        return self._api._upsert(
            collection_id=target,
            ids=records.ids,
            embeddings=cast(Embeddings, vectors),
            metadatas=records.metadatas,  # type: ignore
            documents=records.documents,  # type: ignore
            uris=records.uris,  # type: ignore
        )

    body = await request.body()
    await to_thread.run_sync(
        _upsert_blocking,
        request,
        body,
        limiter=self._capacity_limiter,
    )
@trace_method("FastAPI.get_v1", OpenTelemetryGranularity.OPERATION)
@rate_limit
async def get_v1(
    self,
    collection_id: str,
    request: Request,
) -> GetResult:
    """
    Fetch records from collection ``collection_id``.

    The body is a ``GetEmbedding`` JSON model: ids, where, sort, limit,
    offset, where_document and include. Validation, ``AuthzAction.GET``
    authorization and the read run on a worker thread limited by
    ``self._capacity_limiter``.

    Returned embeddings, if any, are converted element-wise with
    ``.tolist()``. They are presumably NumPy arrays being turned into plain
    lists for serialization.
    """

    def _get_blocking(req: Request, body: bytes) -> GetResult:
        params = validate_model(GetEmbedding, orjson.loads(body))
        self.sync_auth_and_get_tenant_and_database_for_request(
            req.headers,
            AuthzAction.GET,
            None,
            None,
            collection_id,
        )
        return self._api._get(
            collection_id=_uuid(collection_id),
            ids=params.ids,
            where=params.where,
            sort=params.sort,
            limit=params.limit,
            offset=params.offset,
            where_document=params.where_document,
            include=params.include,
        )

    body = await request.body()
    result = cast(
        GetResult,
        await to_thread.run_sync(
            _get_blocking,
            request,
            body,
            limiter=self._capacity_limiter,
        ),
    )

    vectors = result["embeddings"]
    if vectors is not None:
        result["embeddings"] = [cast(Embedding, vec).tolist() for vec in vectors]
    return result
@trace_method("FastAPI.delete_v1", OpenTelemetryGranularity.OPERATION)
@rate_limit
async def delete_v1(
    self,
    collection_id: str,
    request: Request,
) -> None:
    """
    Delete records from collection ``collection_id``.

    The body is a ``DeleteEmbedding`` JSON model with ids and where /
    where_document filters. Validation, ``AuthzAction.DELETE`` authorization
    and the delete run on a worker thread limited by
    ``self._capacity_limiter``.
    """

    def _delete_blocking(req: Request, body: bytes) -> None:
        criteria = validate_model(DeleteEmbedding, orjson.loads(body))
        self.sync_auth_and_get_tenant_and_database_for_request(
            req.headers,
            AuthzAction.DELETE,
            None,
            None,
            collection_id,
        )
        return self._api._delete(
            collection_id=_uuid(collection_id),
            ids=criteria.ids,
            where=criteria.where,
            where_document=criteria.where_document,
        )

    body = await request.body()
    await to_thread.run_sync(
        _delete_blocking,
        request,
        body,
        limiter=self._capacity_limiter,
    )
@trace_method("FastAPI.count_v1", OpenTelemetryGranularity.OPERATION)
@rate_limit
async def count_v1(
    self,
    request: Request,
    collection_id: str,
) -> int:
    """
    Return the number of records in collection ``collection_id``.

    The caller must be authorized for ``AuthzAction.COUNT`` (scoped by
    collection id only). Counting runs on a worker thread limited by
    ``self._capacity_limiter``.
    """
    await self.auth_and_get_tenant_and_database_for_request(
        request.headers,
        AuthzAction.COUNT,
        None,
        None,
        collection_id,
    )
    target = _uuid(collection_id)
    record_count = await to_thread.run_sync(
        self._api._count,
        target,
        limiter=self._capacity_limiter,
    )
    return cast(int, record_count)
@trace_method("FastAPI.reset_v1", OpenTelemetryGranularity.OPERATION)
@rate_limit
async def reset_v1(
    self,
    request: Request,
) -> bool:
    """
    Reset the server via ``self._api.reset``.

    The caller must be authorized for ``AuthzAction.RESET``; no tenant,
    database or collection scope is passed. The reset runs on a worker
    thread limited by ``self._capacity_limiter``.
    """
    await self.auth_and_get_tenant_and_database_for_request(
        request.headers,
        AuthzAction.RESET,
        None,
        None,
        None,
    )
    outcome = await to_thread.run_sync(
        self._api.reset,
        limiter=self._capacity_limiter,
    )
    return cast(bool, outcome)
@trace_method(
    "FastAPI.get_nearest_neighbors_v1", OpenTelemetryGranularity.OPERATION
)
@rate_limit
async def get_nearest_neighbors_v1(
    self,
    collection_id: str,
    request: Request,
) -> QueryResult:
    """
    Run a nearest-neighbour query against collection ``collection_id``.

    The body is a ``QueryEmbedding`` JSON model. The following run on a
    worker thread limited by ``self._capacity_limiter``:
    - body validation;
    - ``AuthzAction.QUERY`` authorization;
    - the query itself.

    Query embeddings, when present, are converted with
    ``convert_list_embeddings_to_np``. Returned embeddings (one batch per
    query) are converted element-wise with ``.tolist()``. They are presumably
    NumPy arrays being turned into plain lists for serialization.
    """

    def _query_blocking(req: Request, body: bytes) -> QueryResult:
        spec = validate_model(QueryEmbedding, orjson.loads(body))
        self.sync_auth_and_get_tenant_and_database_for_request(
            req.headers,
            AuthzAction.QUERY,
            None,
            None,
            collection_id,
        )
        target = _uuid(collection_id)
        if spec.query_embeddings:
            probes = convert_list_embeddings_to_np(spec.query_embeddings)
        else:
            probes = None
        return self._api._query(
            collection_id=target,
            query_embeddings=cast(Embeddings, probes),
            n_results=spec.n_results,
            where=spec.where,
            where_document=spec.where_document,
            include=spec.include,
        )

    body = await request.body()
    result = cast(
        QueryResult,
        await to_thread.run_sync(
            _query_blocking,
            request,
            body,
            limiter=self._capacity_limiter,
        ),
    )

    batches = result["embeddings"]
    if batches is not None:
        result["embeddings"] = [
            [cast(Embedding, vec).tolist() for vec in batch] for batch in batches
        ]
    return result
# =========================================================================