import abc
import threading
from typing import Any, Callable, Optional
from pymilvus.exceptions import MilvusException
from pymilvus.grpc_gen import milvus_pb2
from .abstract import MutationResult, SearchResult
from .types import Status
from .utils import check_status
# TODO: move this to a common util
def _parameter_is_empty(func: Callable):
import inspect
sig = inspect.signature(func)
# todo: add more check to parameter, such as `default parameter`,
# `positional-only`, `positional-or-keyword`, `keyword-only`, `var-positional`, `var-keyword`
# if len(params) == 0:
# for param in params.values():
# if (param.kind == inspect.Parameter.POSITIONAL_ONLY or
# param.default == inspect._empty:
return len(sig.parameters) == 0
class AbstractFuture:
    """Interface implemented by the asynchronous result handles in this module.

    The methods are tagged with ``abc.abstractmethod``; each default body
    raises ``NotImplementedError``.
    """

    @abc.abstractmethod
    def result(self, **kwargs):
        """Return the deserialized result.

        Synchronous: blocks until the server responds or, if a timeout was
        specified, until that timeout elapses.

        Implementations are expected to be thread-safe.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def cancel(self):
        """Cancel the gRPC future.

        Implementations are expected to be thread-safe.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def done(self):
        """Wait for the request to be done.

        Implementations are expected to be thread-safe.
        """
        raise NotImplementedError
class Future(AbstractFuture):
    """Wrap an underlying future and parse its response once it is available.

    Subclasses implement :meth:`on_response` to turn the raw response into a
    result object. Registered callbacks are run through :meth:`_callback`,
    which is guarded by ``_callback_called``.
    """

    def __init__(
        self,
        future: Any,
        done_callback: Optional[Callable] = None,
        pre_exception: Optional[Callable] = None,
        **kwargs,
    ) -> None:
        """Store the wrapped future and initialise the bookkeeping state.

        Args:
            future: Underlying future object. This class calls
                ``result(timeout=...)``, ``exception()`` and ``cancel()`` on it.
            done_callback: Optional callable registered as the first callback
                (``None`` is stored too and skipped when callbacks run).
            pre_exception: If truthy, it is raised by :meth:`exception`, and
                therefore by :meth:`result`, before the wrapped future is used.
                NOTE(review): annotated as ``Callable`` but raised like an
                exception instance -- confirm and tighten the annotation.
            **kwargs: Kept on the instance; ``timeout`` is read by
                :meth:`result` as the default wait time.
        """
        self._future = future
        # keep compatible (such as Future(future, done_callback)), deprecated later
        self._done_cb = done_callback
        self._done_cb_list = []
        self.add_callback(done_callback)
        self._condition = threading.Condition()
        # NOTE(review): `_canceled` is never updated by any method of this class.
        self._canceled = False
        self._done = False
        self._response = None
        self._results = None
        self._exception = pre_exception
        self._callback_called = False  # callback function should be called only once
        self._kwargs = kwargs

    def add_callback(self, func: Callable):
        """Append *func* to the list of callbacks run by :meth:`_callback`."""
        self._done_cb_list.append(func)

    def __del__(self) -> None:
        """Drop the reference to the wrapped future."""
        self._future = None

    @abc.abstractmethod
    def on_response(self, response: Callable):
        """Parse response from gRPC server and return results."""
        raise NotImplementedError

    def _callback(self):
        """Invoke the registered callbacks with the parsed results.

        Falsy entries (such as a ``None`` done_callback) are skipped. For each
        remaining callback:

        * tuple results are unpacked into positional arguments;
        * otherwise a callback that declares no parameters is called bare;
        * otherwise a non-``None`` result is passed as the single argument.

        Raises:
            MilvusException: when the callback declares parameters but the
                result is ``None``.
        """
        if not self._callback_called:
            for cb in self._done_cb_list:
                if cb:
                    # necessary to check parameter signature of cb?
                    if isinstance(self._results, tuple):
                        cb(*self._results)
                    elif _parameter_is_empty(cb):
                        cb()
                    elif self._results is not None:
                        cb(self._results)
                    else:
                        raise MilvusException(message="callback function is not legal!")
        # Not reached if a callback raised above, so the flag stays False then.
        self._callback_called = True

    def result(self, **kwargs):
        """Wait for the wrapped future and return the parsed result.

        Keyword Args:
            timeout: Passed to the wrapped future's ``result``. When absent or
                ``None``, the ``timeout`` given to the constructor (if any) is
                used instead.
            raw: When exactly ``True``, return the unparsed response object.

        Returns:
            The raw response when ``raw=True``; otherwise the stored parsed
            result when it is truthy, else a fresh ``on_response`` call on the
            stored response.

        Raises:
            MilvusException: wrapping any exception raised by the wrapped
                future's ``result``. A stored exception, or anything raised
                by ``on_response`` or a callback, propagates unchanged.
        """
        self.exception()
        with self._condition:
            # future not finished. wait callback being called.
            to = kwargs.get("timeout")
            if to is None:
                to = self._kwargs.get("timeout", None)
            # Only query the wrapped future while no parsed result is stored.
            if self._future and self._results is None:
                try:
                    self._response = self._future.result(timeout=to)
                except Exception as e:
                    raise MilvusException(message=str(e)) from e
                self._results = self.on_response(self._response)
                self._callback()
            self._done = True
            self._condition.notify_all()
        self.exception()
        if kwargs.get("raw", False) is True:
            # just return response object received from gRPC
            return self._response
        if self._results:
            return self._results
        # Reached when the stored result is falsy (e.g. on_response returned None).
        return self.on_response(self._response)

    def cancel(self):
        """Cancel the wrapped future, if any, and notify waiters on the condition."""
        with self._condition:
            if self._future:
                self._future.cancel()
            self._condition.notify_all()

    def is_done(self):
        """Return the ``_done`` flag set by :meth:`result` and :meth:`done`."""
        return self._done

    def done(self):
        """Resolve the wrapped future without a timeout and run the callbacks.

        Unlike :meth:`result`, any exception raised while fetching, parsing or
        running callbacks is stored in ``_exception`` rather than propagated;
        it is raised by the next :meth:`exception` / :meth:`result` call.
        """
        with self._condition:
            if self._future and self._results is None:
                try:
                    self._response = self._future.result()
                    self._results = self.on_response(self._response)
                    self._callback()  # https://github.com/milvus-io/milvus/issues/6160
                except Exception as e:
                    self._exception = e
            self._done = True
            self._condition.notify_all()

    def exception(self):
        """Raise the stored exception, if there is one.

        Otherwise, when a wrapped future is present, call its ``exception()``
        and discard the return value.
        NOTE(review): presumably this call is only used to wait for the
        wrapped future to settle -- confirm against the future's API.
        """
        if self._exception:
            raise self._exception
        if self._future:
            self._future.exception()
class SearchFuture(Future):
    """Future whose parsed result is a ``SearchResult``."""

    def on_response(self, response: milvus_pb2.SearchResults):
        """Check the response status, then wrap ``response.results``.

        The status is passed to ``check_status`` first and then handed to
        ``SearchResult`` together with the results.
        """
        status = response.status
        check_status(status)
        return SearchResult(response.results, status=status)
class MutationFuture(Future):
    """Future whose parsed result is a ``MutationResult``."""

    def on_response(self, response: Any):
        """Check ``response.status``, then wrap the whole response."""
        check_status(response.status)
        wrapped = MutationResult(response)
        return wrapped
class CreateIndexFuture(Future):
    """Future whose parsed result is a ``Status``."""

    def on_response(self, response: Any):
        """Check the response, then build a ``Status`` from its code and reason.

        Here the response object itself (not a ``status`` attribute) is what
        gets passed to ``check_status``.
        """
        check_status(response)
        code, reason = response.code, response.reason
        return Status(code, reason)
class CreateFlatIndexFuture(AbstractFuture):
    """Future-like wrapper around a result that is already available.

    The result is supplied at construction time, so :meth:`is_done` always
    returns ``True`` and :meth:`result` never waits on a remote call.
    """

    def __init__(
        self,
        res: Any,
        done_callback: Optional[Callable] = None,
        pre_exception: Optional[Callable] = None,
    ) -> None:
        """Store the ready result and the optional callback / exception.

        Args:
            res: The value returned by :meth:`result`.
            done_callback: Optional callable registered as the first callback
                (``None`` is stored too and skipped when callbacks run).
            pre_exception: If truthy, raised by :meth:`exception` and hence by
                :meth:`result`.
                NOTE(review): annotated as ``Callable`` but raised like an
                exception instance -- confirm.
        """
        self._results = res
        self._done_cb = done_callback
        self._done_cb_list = []
        self.add_callback(done_callback)
        self._condition = threading.Condition()
        self._exception = pre_exception

    def add_callback(self, func: Callable):
        """Append *func* to the callbacks invoked by :meth:`result`."""
        self._done_cb_list.append(func)

    def __del__(self) -> None:
        """Drop the reference to the stored result."""
        self._results = None

    def on_response(self, response: Any):
        """No-op: there is no response to parse for this future."""

    def result(self, **kwargs):
        """Run the registered callbacks and return the stored result.

        Keyword arguments (for example ``timeout`` or ``raw``) are accepted
        and ignored, so this method can be called the same way as
        ``AbstractFuture.result(**kwargs)`` and ``Future.result(**kwargs)``
        without raising ``TypeError``.

        Callbacks run on every call. Falsy entries are skipped; tuple results
        are unpacked into positional arguments, a callback that declares no
        parameters is called bare, and any other non-``None`` result is passed
        as the single argument.

        Raises:
            MilvusException: when a callback declares parameters but the
                stored result is ``None``. A stored exception, if any, is
                raised before the callbacks run.
        """
        self.exception()
        with self._condition:
            for cb in self._done_cb_list:
                if not cb:
                    continue
                # necessary to check parameter signature of cb?
                if isinstance(self._results, tuple):
                    cb(*self._results)
                elif _parameter_is_empty(cb):
                    cb()
                elif self._results is not None:
                    cb(self._results)
                else:
                    raise MilvusException(message="callback function is not legal!")
            return self._results

    def cancel(self):
        """Nothing to cancel; only notifies waiters on the condition."""
        with self._condition:
            self._condition.notify_all()

    def is_done(self):
        """Always ``True``: the result was available at construction time."""
        return True

    def done(self):
        """Nothing to wait for; only notifies waiters on the condition."""
        with self._condition:
            self._condition.notify_all()

    def exception(self):
        """Raise the stored exception, if there is one."""
        if self._exception:
            raise self._exception
class FlushFuture(Future):
    """Future that only validates the response status."""

    def on_response(self, response: Any):
        """Pass ``response.status`` to ``check_status``; return ``None``."""
        status = response.status
        check_status(status)
class LoadCollectionFuture(Future):
    """Future that only validates the response status."""

    def on_response(self, response: Any):
        """Pass ``response.status`` to ``check_status``; return ``None``."""
        status = response.status
        check_status(status)
class LoadPartitionsFuture(Future):
    """Future that only validates the response status."""

    def on_response(self, response: Any):
        """Pass ``response.status`` to ``check_status``; return ``None``."""
        status = response.status
        check_status(status)