import abc
import threading
from typing import Any, Callable, Optional

from pymilvus.exceptions import MilvusException
from pymilvus.grpc_gen import milvus_pb2

from .abstract import MutationResult, SearchResult
from .types import Status
from .utils import check_status


# TODO: remove this to a common util
def _parameter_is_empty(func: Callable) -> bool:
    """Return True if ``func`` declares no parameters at all.

    Only the parameter count is inspected. Defaults and parameter kinds
    (positional-only, keyword-only, ``*args``/``**kwargs``) are not
    considered, so a callable whose parameters all have defaults is still
    reported as non-empty.
    """
    import inspect

    # TODO: take parameter kinds and default values into account.
    return len(inspect.signature(func).parameters) == 0


def _run_callbacks(callbacks: list, results: Any) -> None:
    """Call every non-empty callback in ``callbacks`` with ``results``.

    - A tuple ``results`` is unpacked into positional arguments.
    - A callback that declares no parameters is called without arguments.
    - Otherwise ``results`` is passed as the single argument.

    :raises MilvusException: if a callback expects arguments but
        ``results`` is None.
    """
    for cb in callbacks:
        if not cb:
            # add_callback() may have stored None (e.g. no done_callback given).
            continue
        if isinstance(results, tuple):
            cb(*results)
        elif _parameter_is_empty(cb):
            cb()
        elif results is not None:
            cb(results)
        else:
            raise MilvusException(message="callback function is not legal!")


class AbstractFuture:
    """Interface shared by the asynchronous request futures in this module.

    NOTE(review): the methods are decorated with ``abc.abstractmethod`` but the
    class does not derive from ``abc.ABC``, so subclasses are not forced to
    implement them when instantiated.
    """

    @abc.abstractmethod
    def result(self, **kwargs):
        """Return the deserialized result.

        This is a synchronous call: it blocks until the server responds or
        the timeout (if specified) expires.

        This API is thread-safe.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def cancel(self):
        """Cancel the gRPC future.

        This API is thread-safe.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def done(self):
        """Wait for the request to finish.

        This API is thread-safe.
        """
        raise NotImplementedError


class Future(AbstractFuture):
    """Wrap an underlying future and convert its response via ``on_response``.

    Subclasses implement ``on_response`` to validate the raw response and build
    the user-facing result. Registered callbacks run at most once, after the
    first successful conversion.
    """

    def __init__(
        self,
        future: Any,
        done_callback: Optional[Callable] = None,
        pre_exception: Optional[Exception] = None,
        **kwargs,
    ) -> None:
        """
        :param future: underlying future; its ``result``, ``exception``,
            ``done`` and ``cancel`` methods are used.
        :param done_callback: optional callback run once the result is ready.
        :param pre_exception: if set, ``result``/``exception`` raise it instead
            of waiting on ``future``.
        :param kwargs: extra options; ``timeout`` is the default wait used by
            ``result`` when the call itself gives none.
        """
        self._future = future
        # keep compatible (such as Future(future, done_callback)), deprecated later
        self._done_cb = done_callback
        self._done_cb_list = []
        self.add_callback(done_callback)
        self._condition = threading.Condition()
        self._canceled = False
        self._done = False
        self._response = None
        self._results = None
        self._exception = pre_exception
        self._callback_called = False  # callback function should be called only once
        self._kwargs = kwargs

    def add_callback(self, func: Callable):
        """Register ``func`` to run when the result becomes available.

        ``None`` is accepted and skipped when callbacks are invoked.
        """
        self._done_cb_list.append(func)

    def __del__(self) -> None:
        self._future = None

    @abc.abstractmethod
    def on_response(self, response: Any):
        """Parse the response from the gRPC server and return results."""
        raise NotImplementedError

    def _callback(self):
        """Invoke the registered callbacks with the results, at most once."""
        if not self._callback_called:
            _run_callbacks(self._done_cb_list, self._results)
        # Not reached if a callback raises, so a later call retries them.
        self._callback_called = True

    def _raise_if_failed(self) -> None:
        """Non-blocking counterpart of ``exception`` used before waiting.

        Raises the pre-set exception, if any. The underlying future is
        consulted only once it reports ``done()`` (e.g. it was cancelled), so
        this never blocks and the timeout given to ``result`` is honored.
        """
        if self._exception:
            raise self._exception
        if self._future and self._future.done():
            self._future.exception()

    def result(self, **kwargs):
        """Wait for the response and return the converted result.

        :param timeout: seconds to wait on the underlying future; falls back
            to the ``timeout`` passed to the constructor, and ``None`` waits
            indefinitely.
        :param raw: if True, return the raw response object instead.
        :raises MilvusException: if waiting on the underlying future fails,
            including when the timeout expires.
        """
        # Must not block here: a blocking wait would ignore ``timeout``.
        self._raise_if_failed()
        with self._condition:
            to = kwargs.get("timeout")
            if to is None:
                to = self._kwargs.get("timeout", None)

            # Future not resolved yet: wait for it and run the callbacks.
            if self._future and self._results is None:
                try:
                    self._response = self._future.result(timeout=to)
                except Exception as e:
                    raise MilvusException(message=str(e)) from e
                self._results = self.on_response(self._response)
                self._callback()

            self._done = True
            self._condition.notify_all()

        self.exception()
        if kwargs.get("raw", False) is True:
            # just return response object received from gRPC
            return self._response

        # ``is not None``: a valid but falsy result must not be rebuilt.
        if self._results is not None:
            return self._results

        return self.on_response(self._response)

    def cancel(self):
        """Cancel the underlying future, if any, and wake waiters."""
        with self._condition:
            if self._future:
                self._future.cancel()
            self._condition.notify_all()

    def is_done(self):
        """Return True once ``result`` or ``done`` has finished processing."""
        return self._done

    def done(self):
        """Wait (without timeout) for the response and convert it.

        Errors from waiting, conversion or callbacks are stored in
        ``self._exception`` rather than raised; a later ``result`` or
        ``exception`` call raises them.
        """
        with self._condition:
            if self._future and self._results is None:
                try:
                    self._response = self._future.result()
                    self._results = self.on_response(self._response)
                    self._callback()  # https://github.com/milvus-io/milvus/issues/6160
                except Exception as e:
                    self._exception = e

            self._done = True
            self._condition.notify_all()

    def exception(self):
        """Raise the stored exception, if any; otherwise query the underlying future.

        NOTE: for a gRPC future, ``exception()`` blocks until the call finishes
        and raises only if it was cancelled; its return value (the RPC error,
        if any) is discarded here.
        """
        if self._exception:
            raise self._exception
        if self._future:
            self._future.exception()


class SearchFuture(Future):
    """Future for a search request; resolves to a ``SearchResult``."""

    def on_response(self, response: milvus_pb2.SearchResults):
        check_status(response.status)
        return SearchResult(response.results, status=response.status)


class MutationFuture(Future):
    """Future for a mutation request; resolves to a ``MutationResult``."""

    def on_response(self, response: Any):
        check_status(response.status)
        return MutationResult(response)


class CreateIndexFuture(Future):
    """Future for an index-creation request; resolves to a ``Status``."""

    def on_response(self, response: Any):
        # The response itself is the status object here.
        check_status(response)
        return Status(response.code, response.reason)


class CreateFlatIndexFuture(AbstractFuture):
    """Already-completed future that wraps the precomputed result ``res``."""

    def __init__(
        self,
        res: Any,
        done_callback: Optional[Callable] = None,
        pre_exception: Optional[Exception] = None,
    ) -> None:
        self._results = res
        self._done_cb = done_callback
        self._done_cb_list = []
        self.add_callback(done_callback)
        self._condition = threading.Condition()
        self._exception = pre_exception
        # Mirror Future: callbacks run only once even if result() is repeated.
        self._callback_called = False

    def add_callback(self, func: Callable):
        """Register ``func`` to run on the first ``result`` call (``None`` is skipped)."""
        self._done_cb_list.append(func)

    def __del__(self) -> None:
        self._results = None

    def on_response(self, response: Any):
        pass

    def result(self, **kwargs):
        """Return the wrapped result, running the callbacks on the first call.

        ``kwargs`` (e.g. ``timeout``) are accepted for compatibility with
        ``AbstractFuture.result`` and ignored, because the result is already
        available.

        :raises MilvusException: if a callback expects arguments but the
            result is None.
        """
        self.exception()
        with self._condition:
            if not self._callback_called:
                _run_callbacks(self._done_cb_list, self._results)
                self._callback_called = True
            return self._results

    def cancel(self):
        """Nothing to cancel; only wakes waiters on the condition."""
        with self._condition:
            self._condition.notify_all()

    def is_done(self):
        """Always True: the result is available at construction."""
        return True

    def done(self):
        """Nothing to wait for; only wakes waiters on the condition."""
        with self._condition:
            self._condition.notify_all()

    def exception(self):
        """Raise the pre-set exception, if any."""
        if self._exception:
            raise self._exception


class FlushFuture(Future):
    """Future for a flush request; only validates the response status."""

    def on_response(self, response: Any):
        check_status(response.status)


class LoadCollectionFuture(Future):
    """Future for a load-collection request; only validates the response status."""

    def on_response(self, response: Any):
        check_status(response.status)


class LoadPartitionsFuture(Future):
    """Future for a load-partitions request; only validates the response status."""

    def on_response(self, response: Any):
        check_status(response.status)
Memory