import datetime
import sys
from typing import Any, Callable, Dict, Union

from pymilvus.exceptions import ParamError
from pymilvus.grpc_gen import milvus_pb2 as milvus_types

from . import entity_helper
from .singleton_utils import Singleton


def validate_strs(**kwargs):
    """Validate that every keyword value is a legal non-empty str.

    Raises:
        ParamError: if any value fails ``validate_str``; the message lists the
            offending name/value pairs.
    """
    invalid_pair = {k: v for k, v in kwargs.items() if not validate_str(v)}
    if invalid_pair:
        msg = f"Illegal str variables: {invalid_pair}, expect non-empty str"
        raise ParamError(message=msg)


def validate_nullable_strs(**kwargs):
    """Validate that every keyword value is either None or a legal non-empty str.

    Raises:
        ParamError: if any non-None value fails ``validate_str``.
    """
    invalid_pair = {k: v for k, v in kwargs.items() if v is not None and not validate_str(v)}
    if invalid_pair:
        msg = f"Illegal nullable str variables: {invalid_pair}, expect None or non-empty str"
        raise ParamError(message=msg)


def validate_str(var: Any) -> bool:
    """Return True only when ``var`` is a non-empty str.

    The type test runs first so that values with ambiguous truthiness never
    get evaluated in a boolean context, and the result is always a real bool.
    """
    return isinstance(var, str) and bool(var)


def is_legal_address(addr: Any) -> bool:
    """Return True when ``addr`` is a ``"host:port"`` str with a legal host and port."""
    if not isinstance(addr, str):
        return False

    a = addr.split(":")
    if len(a) != 2:
        return False

    return is_legal_host(a[0]) and is_legal_port(a[1])


def is_legal_host(host: Any) -> bool:
    """Return True when ``host`` is a non-empty str that contains no ``":"``."""
    return isinstance(host, str) and len(host) > 0 and (":" not in host)


def is_legal_port(port: Any) -> bool:
    """Return True when ``port`` is an int, or a str that ``int()`` can parse."""
    if isinstance(port, (str, int)):
        try:
            int(port)
        except ValueError:
            return False
        else:
            return True
    return False


def int_or_str(item: Union[int, str]) -> str:
    """Convert an int to its str form; return any other value unchanged."""
    if isinstance(item, int):
        return str(item)
    return item


def is_correct_date_str(param: str) -> bool:
    """Return True when ``param`` parses with the ``%Y-%m-%d`` format."""
    try:
        datetime.datetime.strptime(param, "%Y-%m-%d")
    except (TypeError, ValueError):
        # TypeError: strptime was handed something that is not a str.
        return False

    return True


def is_legal_dimension(dim: Any) -> bool:
    """Return True when ``int(dim)`` succeeds."""
    try:
        int(dim)
    except (TypeError, ValueError, OverflowError):
        # TypeError: None/list/etc.; OverflowError: float infinity.
        return False

    return True


def is_legal_index_size(index_size: Any) -> bool:
    """Return True when ``index_size`` is an int."""
    return isinstance(index_size, int)


def is_legal_table_name(table_name: Any) -> bool:
    """Return True when ``table_name`` is a non-empty str."""
    return validate_str(table_name)


def is_legal_db_name(db_name: Any) -> bool:
    """Return True when ``db_name`` is a str; the empty str is accepted."""
    # you can connect to the default database "".
    return isinstance(db_name, str)


def is_legal_field_name(field_name: Any) -> bool:
    """Return True when ``field_name`` is a non-empty str."""
    return validate_str(field_name)


def is_legal_index_name(index_name: Any) -> bool:
    """Return True when ``index_name`` is a non-empty str."""
    return validate_str(index_name)


def is_legal_timeout(timeout: Any) -> bool:
    """Return True when ``timeout`` is None, an int or a float."""
    return timeout is None or isinstance(timeout, (int, float))


def is_legal_nlist(nlist: Any) -> bool:
    """Return True when ``nlist`` is an int and not a bool."""
    return not isinstance(nlist, bool) and isinstance(nlist, int)


def is_legal_topk(topk: Any) -> bool:
    """Return True when ``topk`` is an int and not a bool."""
    return not isinstance(topk, bool) and isinstance(topk, int)


def is_legal_ids(ids: Any) -> bool:
    """Return True when ``ids`` is a non-empty list of legal ids.

    Each element must be an int or str whose integer value lies in
    ``[0, sys.maxsize]``.
    """
    # Type test first: a non-list value must not be evaluated for truthiness.
    if not isinstance(ids, list) or not ids:
        return False

    # TODO: Here check id valid value range may not match other SDK
    for i in ids:
        if not isinstance(i, (int, str)):
            return False
        try:
            i_ = int(i)
        except ValueError:
            # Only a str element can get here: it is not an integer literal.
            return False
        if i_ < 0 or i_ > sys.maxsize:
            return False

    return True


def is_legal_nprobe(nprobe: Any) -> bool:
    """Return True when ``nprobe`` is an int."""
    return isinstance(nprobe, int)


def is_legal_itopk_size(itopk_size: Any) -> bool:
    """Return True when ``itopk_size`` is an int."""
    return isinstance(itopk_size, int)


def is_legal_search_width(search_width: Any) -> bool:
    """Return True when ``search_width`` is an int."""
    return isinstance(search_width, int)


def is_legal_min_iterations(min_iterations: Any) -> bool:
    """Return True when ``min_iterations`` is an int."""
    return isinstance(min_iterations, int)


def is_legal_max_iterations(max_iterations: Any) -> bool:
    """Return True when ``max_iterations`` is an int."""
    return isinstance(max_iterations, int)


def is_legal_drop_ratio(drop_ratio: Any) -> bool:
    """Return True when ``drop_ratio`` is a float in the half-open range [0, 1)."""
    return isinstance(drop_ratio, float) and 0 <= drop_ratio < 1


def is_legal_team_size(team_size: Any) -> bool:
    """Return True when ``team_size`` is an int."""
    return isinstance(team_size, int)


def is_legal_cmd(cmd: Any) -> bool:
    """Return True when ``cmd`` is a non-empty str."""
    return validate_str(cmd)


def parser_range_date(date: Union[str, datetime.date]) -> str:
    """Normalize a date to a ``YYYY-MM-DD`` string.

    Args:
        date: a ``datetime.date`` (which includes ``datetime.datetime``), or a
            str already in ``%Y-%m-%d`` format.

    Returns:
        The formatted date for date objects, or the input str unchanged.

    Raises:
        ParamError: if a str does not match the format, or ``date`` is of any
            other type.
    """
    if isinstance(date, datetime.date):
        return date.strftime("%Y-%m-%d")

    if isinstance(date, str):
        if not is_correct_date_str(date):
            raise ParamError(message="Date string should be YYYY-MM-DD format!")

        return date

    raise ParamError(
        message="Date should be YYYY-MM-DD format string or datetime.date, "
        "or datetime.datetime object"
    )


def is_legal_date_range(start: str, end: str) -> bool:
    """Return True when ``end`` is on or after ``start``.

    Both arguments are parsed with ``%Y-%m-%d``; a string that does not match
    makes ``strptime`` raise ``ValueError``, which is not caught here.
    """
    start_date = datetime.datetime.strptime(start, "%Y-%m-%d")
    end_date = datetime.datetime.strptime(end, "%Y-%m-%d")

    return (end_date - start_date).days >= 0


def is_legal_partition_name(tag: Any) -> bool:
    """Return True when ``tag`` is a str; the empty str is accepted."""
    return isinstance(tag, str)


def is_legal_limit(limit: Any) -> bool:
    """Return True when ``limit`` is an int greater than zero."""
    return isinstance(limit, int) and limit > 0


def is_legal_anns_field(field: Any) -> bool:
    """Return True when ``field`` is None or a str."""
    return field is None or isinstance(field, str)


def is_legal_search_data(data: Any) -> bool:
    """Return True when ``data`` is acceptable search input.

    Accepted: anything ``entity_helper.entity_is_sparse_matrix`` recognizes, or
    a list / ``numpy.ndarray`` whose elements are each a list, bytes,
    ``numpy.ndarray`` or str.
    """
    import numpy as np

    if entity_helper.entity_is_sparse_matrix(data):
        return True

    if not isinstance(data, (list, np.ndarray)):
        return False

    return all(isinstance(vector, (list, bytes, np.ndarray, str)) for vector in data)


def is_legal_output_fields(output_fields: Any) -> bool:
    """Return True when ``output_fields`` is None or a list of legal field names."""
    if output_fields is None:
        return True

    if not isinstance(output_fields, list):
        return False

    return all(is_legal_field_name(field) for field in output_fields)


def is_legal_partition_name_array(tag_array: Any) -> bool:
    """Return True when ``tag_array`` is None or a list of legal partition names."""
    if tag_array is None:
        return True

    if not isinstance(tag_array, list):
        return False

    return all(is_legal_partition_name(tag) for tag in tag_array)


def is_legal_replica_number(replica_number: int) -> bool:
    """Return True when ``replica_number`` is an int."""
    return isinstance(replica_number, int)


def _raise_param_error(param_name: str, param_value: Any) -> None:
    """Raise a ParamError naming the rejected parameter and its value."""
    raise ParamError(message=f"`{param_name}` value {param_value} is illegal")


def is_legal_round_decimal(round_decimal: Any) -> bool:
    """Return True when ``round_decimal`` is an int with -2 < value < 7."""
    return isinstance(round_decimal, int) and -2 < round_decimal < 7


def is_legal_guarantee_timestamp(ts: Any) -> bool:
    """Return True when ``ts`` is None or a non-negative int."""
    return (ts is None) or (isinstance(ts, int) and ts >= 0)


def is_legal_user(user: Any) -> bool:
    """Return True when ``user`` is a str."""
    return isinstance(user, str)


def is_legal_password(password: Any) -> bool:
    """Return True when ``password`` is a str."""
    return isinstance(password, str)


def is_legal_role_name(role_name: Any) -> bool:
    """Return True when ``role_name`` is a non-empty str."""
    return validate_str(role_name)


def is_legal_operate_user_role_type(operate_user_role_type: Any) -> bool:
    """Return True when the value is AddUserToRole or RemoveUserFromRole."""
    return operate_user_role_type in (
        milvus_types.OperateUserRoleType.AddUserToRole,
        milvus_types.OperateUserRoleType.RemoveUserFromRole,
    )


def is_legal_include_user_info(include_user_info: Any) -> bool:
    """Return True when ``include_user_info`` is a bool."""
    return isinstance(include_user_info, bool)


def is_legal_include_role_info(include_role_info: Any) -> bool:
    """Return True when ``include_role_info`` is a bool."""
    return isinstance(include_role_info, bool)


def is_legal_object(object: Any) -> bool:
    """Return True when ``object`` is a non-empty str."""
    # The parameter name shadows the builtin; it is kept so that existing
    # keyword callers keep working.
    return validate_str(object)


def is_legal_object_name(object_name: Any) -> bool:
    """Return True when ``object_name`` is a non-empty str."""
    return validate_str(object_name)


def is_legal_privilege(privilege: Any) -> bool:
    """Return True when ``privilege`` is a non-empty str."""
    return validate_str(privilege)


def is_legal_collection_properties(properties: Any) -> bool:
    """Return True when ``properties`` is a non-empty dict."""
    return isinstance(properties, dict) and bool(properties)


def is_legal_operate_privilege_type(operate_privilege_type: Any) -> bool:
    """Return True when the value is Grant or Revoke."""
    return operate_privilege_type in (
        milvus_types.OperatePrivilegeType.Grant,
        milvus_types.OperatePrivilegeType.Revoke,
    )


def is_legal_privilege_group(privilege_group: Any) -> bool:
    """Return True when ``privilege_group`` is a non-empty str."""
    return validate_str(privilege_group)


def is_legal_privileges(privileges: Any) -> bool:
    """Return True when ``privileges`` is a non-empty list of legal privileges."""
    return (
        isinstance(privileges, list)
        and bool(privileges)
        and all(is_legal_privilege(p) for p in privileges)
    )


def is_legal_operate_privilege_group_type(operate_privilege_group_type: Any) -> bool:
    """Return True when the value is AddPrivilegesToGroup or RemovePrivilegesFromGroup."""
    return operate_privilege_group_type in (
        milvus_types.OperatePrivilegeGroupType.AddPrivilegesToGroup,
        milvus_types.OperatePrivilegeGroupType.RemovePrivilegesFromGroup,
    )


class ParamChecker(metaclass=Singleton):
    """Dispatch table mapping a parameter name to its validator function."""

    def __init__(self) -> None:
        # Parameter name -> predicate that returns truthy for a legal value.
        self.check_dict: Dict[str, Callable[[Any], bool]] = {
            "db_name": is_legal_db_name,
            "collection_name": is_legal_table_name,
            "alias": is_legal_table_name,
            "field_name": is_legal_field_name,
            "dimension": is_legal_dimension,
            "index_file_size": is_legal_index_size,
            "topk": is_legal_topk,
            "ids": is_legal_ids,
            "nprobe": is_legal_nprobe,
            "nlist": is_legal_nlist,
            "cmd": is_legal_cmd,
            "partition_name": is_legal_partition_name,
            "partition_name_array": is_legal_partition_name_array,
            "limit": is_legal_limit,
            "anns_field": is_legal_anns_field,
            "search_data": is_legal_search_data,
            "output_fields": is_legal_output_fields,
            "round_decimal": is_legal_round_decimal,
            "guarantee_timestamp": is_legal_guarantee_timestamp,
            "user": is_legal_user,
            "password": is_legal_password,
            "role_name": is_legal_role_name,
            "operate_user_role_type": is_legal_operate_user_role_type,
            "include_user_info": is_legal_include_user_info,
            "include_role_info": is_legal_include_role_info,
            "object": is_legal_object,
            "object_name": is_legal_object_name,
            "privilege": is_legal_privilege,
            "operate_privilege_type": is_legal_operate_privilege_type,
            "properties": is_legal_collection_properties,
            "replica_number": is_legal_replica_number,
            "resource_group_name": is_legal_table_name,
            "itopk_size": is_legal_itopk_size,
            "search_width": is_legal_search_width,
            "min_iterations": is_legal_min_iterations,
            "max_iterations": is_legal_max_iterations,
            "team_size": is_legal_team_size,
            "index_name": is_legal_index_name,
            "timeout": is_legal_timeout,
            "drop_ratio_build": is_legal_drop_ratio,
            "drop_ratio_search": is_legal_drop_ratio,
            "privilege_group": is_legal_privilege_group,
            "privileges": is_legal_privileges,
            "operate_privilege_group_type": is_legal_operate_privilege_group_type,
        }

    def check(self, key: str, value: Any) -> None:
        """Validate ``value`` with the validator registered under ``key``.

        Raises:
            ParamError: if ``key`` has no registered validator, or the
                validator rejects ``value``.
        """
        if key in self.check_dict:
            if not self.check_dict[key](value):
                _raise_param_error(key, value)
        else:
            raise ParamError(message=f"unknown param `{key}`")


def _get_param_checker():
    """Return the ParamChecker instance (the class uses the Singleton metaclass)."""
    return ParamChecker()


def check_pass_param(*_args: Any, **kwargs: Any) -> None:
    """Validate every keyword argument against its registered validator.

    Positional arguments are accepted and ignored.

    Raises:
        ParamError: if a keyword is unknown or its value is illegal.
    """
    # ``**kwargs`` always binds a dict, so no None check is needed here.
    checker = _get_param_checker()
    for key, value in kwargs.items():
        checker.check(key, value)
Memory