import datetime
import functools
import logging
import time
from typing import Any, Callable, Optional
import grpc
from .exceptions import ErrorCode, MilvusException
from .grpc_gen import common_pb2
LOGGER = logging.getLogger(__name__)
# ANSI escape sequence wrapping a single "{}" placeholder: 93 selects the
# bright-yellow foreground, 0 resets attributes afterwards.
WARNING_COLOR = "\033[93m{}\033[0m"


def deprecated(func: Any):
    """Decorate ``func`` so every call logs a deprecation warning for ``Milvus``.

    The warning is emitted through ``LOGGER`` at WARNING level on each call,
    then ``func`` is invoked with the original arguments and its return value
    is passed through unchanged.

    Args:
        func: The callable to wrap.

    Returns:
        A wrapper with ``func``'s metadata (via ``functools.wraps``).
    """

    @functools.wraps(func)
    def inner(*args, **kwargs):
        # WARNING_COLOR has exactly one "{}" placeholder, so the whole message
        # must be passed as ONE argument; str.format silently ignores extra
        # positional arguments, which previously dropped the actual text.
        LOGGER.warning(
            WARNING_COLOR.format(
                "[WARNING] PyMilvus: "
                "class Milvus will be deprecated soon, please use Collection/utility instead"
            )
        )
        return func(*args, **kwargs)

    return inner
# Reference: https://grpc.github.io/grpc/python/grpc.html#grpc-status-code
# gRPC status codes that ``retry_on_rpc_failure`` re-raises immediately
# instead of retrying.
IGNORE_RETRY_CODES = (
    grpc.StatusCode.DEADLINE_EXCEEDED,
    grpc.StatusCode.PERMISSION_DENIED,
    grpc.StatusCode.UNAUTHENTICATED,
    grpc.StatusCode.INVALID_ARGUMENT,
    grpc.StatusCode.ALREADY_EXISTS,
    grpc.StatusCode.RESOURCE_EXHAUSTED,
    grpc.StatusCode.UNIMPLEMENTED,
)


def retry_on_rpc_failure(
    *,
    retry_times: int = 75,
    initial_back_off: float = 0.01,
    max_back_off: float = 3,
    back_off_multiplier: int = 3,
):
    """Build a decorator that retries an RPC call on transient failures.

    The returned decorator also applies ``error_handler`` and
    ``tracing_request`` to the wrapped function. On each failed attempt:

    * ``grpc.RpcError`` whose code is in ``IGNORE_RETRY_CODES`` is re-raised
      immediately; any other ``grpc.RpcError`` is retried after a back-off.
    * ``MilvusException`` is retried only when it signals rate limiting and the
      call did not pass ``retry_on_rate_limit=False``; otherwise it is re-raised.
    * Any other exception propagates immediately.

    The retry budget is chosen per call. If the call passes a numeric
    ``timeout=`` keyword, the budget is that many seconds since the first
    attempt. Otherwise it is the number of retries: the call's ``retry_times=``
    keyword if it is an int, else this factory's ``retry_times``. When the
    budget is spent, a ``MilvusException`` chained to the last error is raised.

    Args:
        retry_times: Default maximum number of retries when the call passes no
            ``timeout`` keyword.
        initial_back_off: Sleep before the first retry, in seconds.
        max_back_off: Upper bound on the sleep between retries, in seconds.
        back_off_multiplier: Factor applied to the sleep after each retry.
    """

    def wrapper(func: Any):
        @functools.wraps(func)
        @error_handler(func_name=func.__name__)
        @tracing_request()
        def handler(*args, **kwargs):
            # Per-call overrides are read from keyword arguments only, so
            # callers must pass them by keyword, e.g. `timeout=10`.
            _timeout = kwargs.get("timeout")
            _retry_times = kwargs.get("retry_times")
            _retry_on_rate_limit = kwargs.get("retry_on_rate_limit", True)

            # A numeric timeout switches the budget from "number of retries" to
            # "elapsed seconds". Floats are accepted too, so e.g. `timeout=0.5`
            # is no longer silently ignored.
            retry_timeout = _timeout if isinstance(_timeout, (int, float)) else None
            final_retry_times = (
                _retry_times if isinstance(_retry_times, int) else retry_times
            )

            counter = 1
            back_off = initial_back_off
            # Monotonic clock: elapsed-time checks are immune to wall-clock jumps.
            start_time = time.monotonic()

            def budget_exhausted() -> bool:
                """Return True once the time- or count-based retry budget is spent."""
                if retry_timeout is not None:
                    return time.monotonic() - start_time >= retry_timeout
                return counter > final_retry_times

            to_msg = (
                f"Retry timeout: {retry_timeout}s"
                if retry_timeout is not None
                else f"Retry run out of {final_retry_times} retry times"
            )

            while True:
                try:
                    return func(*args, **kwargs)
                except grpc.RpcError as e:
                    # Retrying cannot change the outcome for these codes.
                    if e.code() in IGNORE_RETRY_CODES:
                        raise
                    if budget_exhausted():
                        # `code` is a method on grpc.RpcError: pass its result,
                        # not the bound method object.
                        raise MilvusException(
                            e.code(), f"{to_msg}, message={e.details()}"
                        ) from e
                    if counter > 3:
                        retry_msg = (
                            f"[{func.__name__}] retry:{counter}, cost: {back_off:.2f}s, "
                            f"reason: <{e.__class__.__name__}: {e.code()}, {e.details()}>"
                        )
                        # retry msg uses info level
                        LOGGER.info(retry_msg)
                    time.sleep(back_off)
                    back_off = min(back_off * back_off_multiplier, max_back_off)
                except MilvusException as e:
                    if budget_exhausted():
                        LOGGER.warning(WARNING_COLOR.format(to_msg))
                        raise MilvusException(
                            code=e.code, message=f"{to_msg}, message={e.message}"
                        ) from e
                    # Only rate-limit errors are retried, and only if the caller
                    # has not opted out with `retry_on_rate_limit=False`.
                    if not (
                        _retry_on_rate_limit
                        and (
                            e.code == ErrorCode.RATE_LIMIT
                            or e.compatible_code == common_pb2.RateLimit
                        )
                    ):
                        raise
                    time.sleep(back_off)
                    back_off = min(back_off * back_off_multiplier, max_back_off)
                finally:
                    # Counts every attempt (successful or not).
                    counter += 1

        return handler

    return wrapper
def error_handler(func_name: str = ""):
    """Build a decorator that logs failed calls with start/error timestamps.

    Behavior when the wrapped call raises:

    * ``MilvusException``, ``grpc.FutureTimeoutError`` or ``grpc.RpcError``:
      logged at ERROR level and re-raised unchanged.
    * Any other ``Exception``: logged at ERROR level and re-raised as a
      ``MilvusException`` chained to the original.

    Args:
        func_name: Label used in log lines. When empty, the wrapped
            function's ``__name__`` is used.

    Returns:
        A decorator that applies this logging to a function.
    """

    def wrapper(func: Callable):
        # The label is fixed per decorated function, so resolve it once.
        inner_name = func_name or func.__name__

        @functools.wraps(func)
        def handler(*args, **kwargs):
            record_dict = {}
            try:
                record_dict["RPC start"] = str(datetime.datetime.now())
                return func(*args, **kwargs)
            except MilvusException as e:
                record_dict["RPC error"] = str(datetime.datetime.now())
                LOGGER.error(f"RPC error: [{inner_name}], {e}, <Time:{record_dict}>")
                raise
            except grpc.FutureTimeoutError as e:
                record_dict["gRPC timeout"] = str(datetime.datetime.now())
                # FutureTimeoutError is a plain Exception subclass, not an
                # RpcError: it has no code()/details(). Calling them here
                # raised AttributeError, which replaced the real timeout error.
                LOGGER.error(
                    f"grpc Timeout: [{inner_name}], <{e.__class__.__name__}: {e}>, "
                    f"<Time:{record_dict}>"
                )
                raise
            except grpc.RpcError as e:
                record_dict["gRPC error"] = str(datetime.datetime.now())
                LOGGER.error(
                    f"grpc RpcError: [{inner_name}], <{e.__class__.__name__}: "
                    f"{e.code()}, {e.details()}>, <Time:{record_dict}>"
                )
                raise
            except Exception as e:
                record_dict["Exception"] = str(datetime.datetime.now())
                LOGGER.error(f"Unexpected error: [{inner_name}], {e}, <Time: {record_dict}>")
                raise MilvusException(message=f"Unexpected error, message=<{e!s}>") from e

        return handler

    return wrapper
def tracing_request():
    """Return a decorator that applies per-call tracing options to a method.

    Before the wrapped method runs:

    * a truthy ``log_level`` keyword is passed to ``self.set_onetime_loglevel``;
    * a truthy ``client_request_id`` keyword is passed to
      ``self.set_onetime_request_id``.

    Both keywords stay in ``kwargs`` and are forwarded to the method unchanged.
    """

    def decorate(method: Callable):
        @functools.wraps(method)
        def traced(self: Any, *args, **kwargs):
            log_level = kwargs.get("log_level")
            if log_level:
                self.set_onetime_loglevel(log_level)

            request_id = kwargs.get("client_request_id")
            if request_id:
                self.set_onetime_request_id(request_id)

            return method(self, *args, **kwargs)

        return traced

    return decorate
def ignore_unimplemented(default_return_value: Any):
    """Build a decorator that turns an UNIMPLEMENTED RPC into a default result.

    If the wrapped call raises ``grpc.RpcError`` with status code
    ``grpc.StatusCode.UNIMPLEMENTED``, a DEBUG message is logged and
    ``default_return_value`` is returned instead. Every other exception
    propagates unchanged.

    Args:
        default_return_value: Value returned in place of the UNIMPLEMENTED error.
            The same object is returned on every such call, so a mutable value
            is shared between calls.

    Returns:
        A decorator that applies this fallback to a function.
    """

    def wrapper(func: Callable):
        @functools.wraps(func)
        def handler(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except grpc.RpcError as e:
                if e.code() != grpc.StatusCode.UNIMPLEMENTED:
                    # Bare `raise` keeps the original traceback without
                    # making the error its own __cause__.
                    raise
                LOGGER.debug(f"{func.__name__} unimplemented, ignore it")
                return default_return_value

        return handler

    return wrapper
def upgrade_reminder(func: Callable):
    """Decorate ``func`` to explain UNIMPLEMENTED RPC errors to the user.

    If ``func`` raises ``grpc.RpcError`` with status code
    ``grpc.StatusCode.UNIMPLEMENTED``, a ``MilvusException`` is raised instead,
    chained to the original error. Its message suggests checking the port or
    the SDK/server version compatibility. Every other exception propagates
    unchanged.

    Args:
        func: The callable to wrap.

    Returns:
        A wrapper with ``func``'s metadata (via ``functools.wraps``).
    """

    @functools.wraps(func)
    def handler(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except grpc.RpcError as e:
            if e.code() != grpc.StatusCode.UNIMPLEMENTED:
                # Bare `raise` keeps the original traceback without making the
                # error its own __cause__.
                raise
            msg = (
                "Incorrect port or sdk is incompatible with server, "
                "please check your port or downgrade your sdk or upgrade your server"
            )
            raise MilvusException(message=msg) from e

    return handler