import base64 import datetime from typing import Any, Dict, Iterable, List, Mapping, Optional, Union import numpy as np import ujson from pymilvus.exceptions import DataNotMatchException, ExceptionsMessage, ParamError from pymilvus.grpc_gen import common_pb2 as common_types from pymilvus.grpc_gen import milvus_pb2 as milvus_types from pymilvus.grpc_gen import schema_pb2 as schema_types from pymilvus.orm.schema import CollectionSchema from pymilvus.orm.types import infer_dtype_by_scalar_data from . import __version__, blob, check, entity_helper, ts_utils, utils from .check import check_pass_param, is_legal_collection_properties from .constants import ( DEFAULT_CONSISTENCY_LEVEL, DYNAMIC_FIELD_NAME, GROUP_BY_FIELD, GROUP_SIZE, ITERATOR_FIELD, PAGE_RETAIN_ORDER_FIELD, RANK_GROUP_SCORER, REDUCE_STOP_FOR_BEST, STRICT_GROUP_SIZE, ) from .types import ( DataType, PlaceholderType, ResourceGroupConfig, get_consistency_level, ) from .utils import traverse_info, traverse_upsert_info class Prepare: @classmethod def create_collection_request( cls, collection_name: str, fields: Union[Dict[str, Iterable], CollectionSchema], **kwargs, ) -> milvus_types.CreateCollectionRequest: """ Args: fields (Union(Dict[str, Iterable], CollectionSchema)). 
{"fields": [ {"name": "A", "type": DataType.INT32} {"name": "B", "type": DataType.INT64, "auto_id": True, "is_primary": True}, {"name": "C", "type": DataType.FLOAT}, {"name": "Vec", "type": DataType.FLOAT_VECTOR, "params": {"dim": 128}}] } Returns: milvus_types.CreateCollectionRequest """ if isinstance(fields, CollectionSchema): schema = cls.get_schema_from_collection_schema(collection_name, fields) else: schema = cls.get_schema(collection_name, fields, **kwargs) consistency_level = get_consistency_level( kwargs.get("consistency_level", DEFAULT_CONSISTENCY_LEVEL) ) req = milvus_types.CreateCollectionRequest( collection_name=collection_name, schema=bytes(schema.SerializeToString()), consistency_level=consistency_level, ) properties = kwargs.get("properties") if is_legal_collection_properties(properties): properties = [ common_types.KeyValuePair(key=str(k), value=str(v)) for k, v in properties.items() ] req.properties.extend(properties) same_key = set(kwargs.keys()).intersection({"num_shards", "shards_num"}) if len(same_key) > 0: if len(same_key) > 1: msg = "got both num_shards and shards_num in kwargs, expected only one of them" raise ParamError(message=msg) num_shards = kwargs[next(iter(same_key))] if not isinstance(num_shards, int): msg = f"invalid num_shards type, got {type(num_shards)}, expected int" raise ParamError(message=msg) req.shards_num = num_shards num_partitions = kwargs.get("num_partitions") if num_partitions is not None: if not isinstance(num_partitions, int) or isinstance(num_partitions, bool): msg = f"invalid num_partitions type, got {type(num_partitions)}, expected int" raise ParamError(message=msg) if num_partitions < 1: msg = f"The specified num_partitions should be greater than or equal to 1, got {num_partitions}" raise ParamError(message=msg) req.num_partitions = num_partitions return req @classmethod def get_schema_from_collection_schema( cls, collection_name: str, fields: CollectionSchema, ) -> schema_types.CollectionSchema: coll_description 
= fields.description if not isinstance(coll_description, (str, bytes)): msg = ( f"description [{coll_description}] has type {type(coll_description).__name__}, " "but expected one of: bytes, str" ) raise ParamError(message=msg) schema = schema_types.CollectionSchema( name=collection_name, autoID=fields.auto_id, description=coll_description, enable_dynamic_field=fields.enable_dynamic_field, ) for f in fields.fields: field_schema = schema_types.FieldSchema( name=f.name, data_type=f.dtype, description=f.description, is_primary_key=f.is_primary, default_value=f.default_value, nullable=f.nullable, autoID=f.auto_id, is_partition_key=f.is_partition_key, is_dynamic=f.is_dynamic, element_type=f.element_type, is_clustering_key=f.is_clustering_key, is_function_output=f.is_function_output, ) for k, v in f.params.items(): kv_pair = common_types.KeyValuePair( key=str(k) if k != "mmap_enabled" else "mmap.enabled", value=ujson.dumps(v) ) field_schema.type_params.append(kv_pair) schema.fields.append(field_schema) for f in fields.functions: function_schema = schema_types.FunctionSchema( name=f.name, description=f.description, type=f.type, input_field_names=f.input_field_names, output_field_names=f.output_field_names, ) for k, v in f.params.items(): kv_pair = common_types.KeyValuePair(key=str(k), value=str(v)) function_schema.params.append(kv_pair) schema.functions.append(function_schema) return schema @staticmethod def get_field_schema( field: Dict, primary_field: Any, auto_id_field: Any, ) -> (schema_types.FieldSchema, Any, Any): field_name = field.get("name") if field_name is None: raise ParamError(message="You should specify the name of field!") data_type = field.get("type") if data_type is None: raise ParamError(message="You should specify the data type of field!") if not isinstance(data_type, (int, DataType)): raise ParamError(message="Field type must be of DataType!") is_primary = field.get("is_primary", False) if not isinstance(is_primary, bool): raise 
ParamError(message="is_primary must be boolean") if is_primary: if primary_field is not None: raise ParamError(message="A collection should only have one primary field") if DataType(data_type) not in [DataType.INT64, DataType.VARCHAR]: msg = "int64 and varChar are the only supported types of primary key" raise ParamError(message=msg) primary_field = field_name nullable = field.get("nullable", False) if not isinstance(nullable, bool): raise ParamError(message="nullable must be boolean") auto_id = field.get("auto_id", False) if not isinstance(auto_id, bool): raise ParamError(message="auto_id must be boolean") if auto_id: if auto_id_field is not None: raise ParamError(message="A collection should only have one autoID field") if DataType(data_type) != DataType.INT64: msg = "int64 is the only supported type of automatic generated id" raise ParamError(message=msg) auto_id_field = field_name field_schema = schema_types.FieldSchema( name=field_name, data_type=data_type, description=field.get("description", ""), is_primary_key=is_primary, autoID=auto_id, is_partition_key=field.get("is_partition_key", False), is_clustering_key=field.get("is_clustering_key", False), ) type_params = field.get("params", {}) if not isinstance(type_params, dict): raise ParamError(message="params should be dictionary type") kvs = [ common_types.KeyValuePair( key=str(k) if k != "mmap_enabled" else "mmap.enabled", value=str(v) ) for k, v in type_params.items() ] field_schema.type_params.extend(kvs) return field_schema, primary_field, auto_id_field @classmethod def get_schema( cls, collection_name: str, fields: Dict[str, Iterable], **kwargs, ) -> schema_types.CollectionSchema: if not isinstance(fields, dict): raise ParamError(message="Param fields must be a dict") all_fields = fields.get("fields") if all_fields is None: raise ParamError(message="Param fields must contain key 'fields'") if len(all_fields) == 0: raise ParamError(message="Param fields value cannot be empty") enable_dynamic_field = 
kwargs.get("enable_dynamic_field", False) if "enable_dynamic_field" in fields: enable_dynamic_field = fields["enable_dynamic_field"] schema = schema_types.CollectionSchema( name=collection_name, autoID=False, description=fields.get("description", ""), enable_dynamic_field=enable_dynamic_field, ) primary_field, auto_id_field = None, None for field in all_fields: (field_schema, primary_field, auto_id_field) = cls.get_field_schema( field, primary_field, auto_id_field ) schema.fields.append(field_schema) return schema @classmethod def drop_collection_request(cls, collection_name: str) -> milvus_types.DropCollectionRequest: return milvus_types.DropCollectionRequest(collection_name=collection_name) @classmethod def describe_collection_request( cls, collection_name: str, ) -> milvus_types.DescribeCollectionRequest: return milvus_types.DescribeCollectionRequest(collection_name=collection_name) @classmethod def alter_collection_request( cls, collection_name: str, properties: Dict, ) -> milvus_types.AlterCollectionRequest: kvs = [common_types.KeyValuePair(key=k, value=str(v)) for k, v in properties.items()] return milvus_types.AlterCollectionRequest(collection_name=collection_name, properties=kvs) @classmethod def collection_stats_request(cls, collection_name: str): return milvus_types.CollectionStatsRequest(collection_name=collection_name) @classmethod def show_collections_request(cls, collection_names: Optional[List[str]] = None): req = milvus_types.ShowCollectionsRequest() if collection_names: if not isinstance(collection_names, (list,)): msg = f"collection_names must be a list of strings, but got: {collection_names}" raise ParamError(message=msg) for collection_name in collection_names: check_pass_param(collection_name=collection_name) req.collection_names.extend(collection_names) req.type = milvus_types.ShowType.InMemory return req @classmethod def rename_collections_request(cls, old_name: str, new_name: str, new_db_name: str): return 
milvus_types.RenameCollectionRequest( oldName=old_name, newName=new_name, newDBName=new_db_name ) @classmethod def create_partition_request(cls, collection_name: str, partition_name: str): return milvus_types.CreatePartitionRequest( collection_name=collection_name, partition_name=partition_name ) @classmethod def drop_partition_request(cls, collection_name: str, partition_name: str): return milvus_types.DropPartitionRequest( collection_name=collection_name, partition_name=partition_name ) @classmethod def has_partition_request(cls, collection_name: str, partition_name: str): return milvus_types.HasPartitionRequest( collection_name=collection_name, partition_name=partition_name ) @classmethod def partition_stats_request(cls, collection_name: str, partition_name: str): return milvus_types.PartitionStatsRequest( collection_name=collection_name, partition_name=partition_name ) @classmethod def show_partitions_request( cls, collection_name: str, partition_names: Optional[List[str]] = None, type_in_memory: bool = False, ): check_pass_param(collection_name=collection_name, partition_name_array=partition_names) req = milvus_types.ShowPartitionsRequest(collection_name=collection_name) if partition_names: if not isinstance(partition_names, (list,)): msg = f"partition_names must be a list of strings, but got: {partition_names}" raise ParamError(message=msg) for partition_name in partition_names: check_pass_param(partition_name=partition_name) req.partition_names.extend(partition_names) if type_in_memory is False: req.type = milvus_types.ShowType.All else: req.type = milvus_types.ShowType.InMemory return req @classmethod def get_loading_progress( cls, collection_name: str, partition_names: Optional[List[str]] = None ): check_pass_param(collection_name=collection_name, partition_name_array=partition_names) req = milvus_types.GetLoadingProgressRequest(collection_name=collection_name) if partition_names: req.partition_names.extend(partition_names) return req @classmethod def 
get_load_state(cls, collection_name: str, partition_names: Optional[List[str]] = None): check_pass_param(collection_name=collection_name, partition_name_array=partition_names) req = milvus_types.GetLoadStateRequest(collection_name=collection_name) if partition_names: req.partition_names.extend(partition_names) return req @classmethod def empty(cls): msg = "no empty request later" raise DeprecationWarning(msg) @classmethod def register_link_request(cls): return milvus_types.RegisterLinkRequest() @classmethod def partition_name(cls, collection_name: str, partition_name: str): if not isinstance(collection_name, str): raise ParamError(message="collection_name must be of str type") if not isinstance(partition_name, str): raise ParamError(message="partition_name must be of str type") return milvus_types.PartitionName(collection_name=collection_name, tag=partition_name) @staticmethod def _is_input_field(field: Dict, is_upsert: bool): return (not field.get("auto_id", False) or is_upsert) and not field.get( "is_function_output", False ) @staticmethod def _function_output_field_names(fields_info: List[Dict]): return [field["name"] for field in fields_info if field.get("is_function_output", False)] @staticmethod def _num_input_fields(fields_info: List[Dict], is_upsert: bool): return len([field for field in fields_info if Prepare._is_input_field(field, is_upsert)]) @staticmethod def _parse_row_request( request: Union[milvus_types.InsertRequest, milvus_types.UpsertRequest], fields_info: List[Dict], enable_dynamic: bool, entities: List, ): input_fields_info = [ field for field in fields_info if Prepare._is_input_field(field, is_upsert=False) ] function_output_field_names = Prepare._function_output_field_names(fields_info) fields_data = { field["name"]: schema_types.FieldData(field_name=field["name"], type=field["type"]) for field in input_fields_info } field_info_map = {field["name"]: field for field in input_fields_info} if enable_dynamic: d_field = schema_types.FieldData( 
field_name=DYNAMIC_FIELD_NAME, is_dynamic=True, type=DataType.JSON ) fields_data[d_field.field_name] = d_field field_info_map[d_field.field_name] = d_field try: for entity in entities: if not isinstance(entity, Dict): msg = f"expected Dict, got '{type(entity).__name__}'" raise TypeError(msg) for k, v in entity.items(): if k not in fields_data: if k in function_output_field_names: raise DataNotMatchException( message=ExceptionsMessage.InsertUnexpectedFunctionOutputField % k ) if not enable_dynamic: raise DataNotMatchException( message=ExceptionsMessage.InsertUnexpectedField % k ) if k in fields_data: field_info, field_data = field_info_map[k], fields_data[k] if field_info.get("nullable", False) or field_info.get( "default_value", None ): field_data.valid_data.append(v is not None) entity_helper.pack_field_value_to_field_data(v, field_data, field_info) for field in input_fields_info: key = field["name"] if key in entity: continue field_info, field_data = field_info_map[key], fields_data[key] if field_info.get("nullable", False) or field_info.get("default_value", None): field_data.valid_data.append(False) entity_helper.pack_field_value_to_field_data(None, field_data, field_info) else: raise DataNotMatchException( message=ExceptionsMessage.InsertMissedField % key ) json_dict = { k: v for k, v in entity.items() if k not in fields_data and enable_dynamic } if enable_dynamic: json_value = entity_helper.convert_to_json(json_dict) d_field.scalars.json_data.data.append(json_value) except (TypeError, ValueError) as e: raise DataNotMatchException(message=ExceptionsMessage.DataTypeInconsistent) from e request.fields_data.extend(fields_data.values()) expected_num_input_fields = len(input_fields_info) + (1 if enable_dynamic else 0) if len(fields_data) != expected_num_input_fields: msg = f"{ExceptionsMessage.FieldsNumInconsistent}, expected {expected_num_input_fields} fields, got {len(fields_data)}" raise ParamError(message=msg) return request @staticmethod def 
_parse_upsert_row_request( request: Union[milvus_types.InsertRequest, milvus_types.UpsertRequest], fields_info: List[Dict], enable_dynamic: bool, entities: List, ): input_fields_info = [ field for field in fields_info if Prepare._is_input_field(field, is_upsert=True) ] function_output_field_names = Prepare._function_output_field_names(fields_info) fields_data = { field["name"]: schema_types.FieldData(field_name=field["name"], type=field["type"]) for field in input_fields_info } field_info_map = {field["name"]: field for field in input_fields_info} if enable_dynamic: d_field = schema_types.FieldData( field_name=DYNAMIC_FIELD_NAME, is_dynamic=True, type=DataType.JSON ) fields_data[d_field.field_name] = d_field field_info_map[d_field.field_name] = d_field try: for entity in entities: if not isinstance(entity, Dict): msg = f"expected Dict, got '{type(entity).__name__}'" raise TypeError(msg) for k, v in entity.items(): if k not in fields_data: if k in function_output_field_names: raise DataNotMatchException( message=ExceptionsMessage.InsertUnexpectedFunctionOutputField % k ) if not enable_dynamic: raise DataNotMatchException( message=ExceptionsMessage.InsertUnexpectedField % k ) if k in fields_data: field_info, field_data = field_info_map[k], fields_data[k] if field_info.get("nullable", False) or field_info.get( "default_value", None ): field_data.valid_data.append(v is not None) entity_helper.pack_field_value_to_field_data(v, field_data, field_info) for field in input_fields_info: key = field["name"] if key in entity: continue field_info, field_data = field_info_map[key], fields_data[key] if field_info.get("nullable", False) or field_info.get("default_value", None): field_data.valid_data.append(False) entity_helper.pack_field_value_to_field_data(None, field_data, field_info) else: raise DataNotMatchException( message=ExceptionsMessage.InsertMissedField % key ) json_dict = { k: v for k, v in entity.items() if k not in fields_data and enable_dynamic } if enable_dynamic: 
json_value = entity_helper.convert_to_json(json_dict) d_field.scalars.json_data.data.append(json_value) except (TypeError, ValueError) as e: raise DataNotMatchException(message=ExceptionsMessage.DataTypeInconsistent) from e request.fields_data.extend(fields_data.values()) for _, field in enumerate(input_fields_info): is_dynamic = False field_name = field["name"] if field.get("is_dynamic", False): is_dynamic = True for j, entity in enumerate(entities): if is_dynamic and field_name in entity: raise ParamError( message=f"dynamic field enabled, {field_name} shouldn't in entities[{j}]" ) expected_num_input_fields = len(input_fields_info) + (1 if enable_dynamic else 0) if len(fields_data) != expected_num_input_fields: msg = f"{ExceptionsMessage.FieldsNumInconsistent}, expected {expected_num_input_fields} fields, got {len(fields_data)}" raise ParamError(message=msg) return request @classmethod def row_insert_param( cls, collection_name: str, entities: List, partition_name: str, fields_info: Dict, enable_dynamic: bool = False, ): if not fields_info: raise ParamError(message="Missing collection meta to validate entities") # insert_request.hash_keys won't be filled in client. p_name = partition_name if isinstance(partition_name, str) else "" request = milvus_types.InsertRequest( collection_name=collection_name, partition_name=p_name, num_rows=len(entities), ) return cls._parse_row_request(request, fields_info, enable_dynamic, entities) @classmethod def row_upsert_param( cls, collection_name: str, entities: List, partition_name: str, fields_info: Any, enable_dynamic: bool = False, ): if not fields_info: raise ParamError(message="Missing collection meta to validate entities") # upsert_request.hash_keys won't be filled in client. 
p_name = partition_name if isinstance(partition_name, str) else "" request = milvus_types.UpsertRequest( collection_name=collection_name, partition_name=p_name, num_rows=len(entities), ) return cls._parse_upsert_row_request(request, fields_info, enable_dynamic, entities) @staticmethod def _pre_insert_batch_check( entities: List, fields_info: Any, ): for entity in entities: if ( entity.get("name") is None or entity.get("values") is None or entity.get("type") is None ): raise ParamError( message="Missing param in entities, a field must have type, name and values" ) if not fields_info: raise ParamError(message="Missing collection meta to validate entities") location, primary_key_loc, _ = traverse_info(fields_info) # though impossible from sdk if primary_key_loc is None: raise ParamError(message="primary key not found") expected_num_input_fields = Prepare._num_input_fields(fields_info, is_upsert=False) if len(entities) != expected_num_input_fields: msg = f"expected number of fields: {expected_num_input_fields}, actual number of fields in entities: {len(entities)}" raise ParamError(message=msg) return location @staticmethod def _pre_upsert_batch_check( entities: List, fields_info: Any, ): for entity in entities: if ( entity.get("name") is None or entity.get("values") is None or entity.get("type") is None ): raise ParamError( message="Missing param in entities, a field must have type, name and values" ) if not fields_info: raise ParamError(message="Missing collection meta to validate entities") location, primary_key_loc = traverse_upsert_info(fields_info) # though impossible from sdk if primary_key_loc is None: raise ParamError(message="primary key not found") expected_num_input_fields = Prepare._num_input_fields(fields_info, is_upsert=True) if len(entities) != expected_num_input_fields: msg = f"expected number of fields: {expected_num_input_fields}, actual number of fields in entities: {len(entities)}" raise ParamError(message=msg) return location @staticmethod def 
_parse_batch_request( request: Union[milvus_types.InsertRequest, milvus_types.UpsertRequest], entities: List, fields_info: Any, location: Dict, ): pre_field_size = 0 try: for entity in entities: latest_field_size = entity_helper.get_input_num_rows(entity.get("values")) if latest_field_size != 0: if pre_field_size not in (0, latest_field_size): raise ParamError( message=( f"Field data size misaligned for field [{entity.get('name')}] ", f"got size=[{latest_field_size}] ", f"alignment size=[{pre_field_size}]", ) ) pre_field_size = latest_field_size if pre_field_size == 0: raise ParamError(message=ExceptionsMessage.NumberRowsInvalid) request.num_rows = pre_field_size for entity in entities: field_data = entity_helper.entity_to_field_data( entity, fields_info[location[entity.get("name")]], request.num_rows ) request.fields_data.append(field_data) except (TypeError, ValueError) as e: raise DataNotMatchException(message=ExceptionsMessage.DataTypeInconsistent) from e if pre_field_size == 0: raise ParamError(message=ExceptionsMessage.NumberRowsInvalid) request.num_rows = pre_field_size return request @classmethod def batch_insert_param( cls, collection_name: str, entities: List, partition_name: str, fields_info: Any, ): location = cls._pre_insert_batch_check(entities, fields_info) tag = partition_name if isinstance(partition_name, str) else "" request = milvus_types.InsertRequest(collection_name=collection_name, partition_name=tag) return cls._parse_batch_request(request, entities, fields_info, location) @classmethod def batch_upsert_param( cls, collection_name: str, entities: List, partition_name: str, fields_info: Any, ): location = cls._pre_upsert_batch_check(entities, fields_info) tag = partition_name if isinstance(partition_name, str) else "" request = milvus_types.UpsertRequest(collection_name=collection_name, partition_name=tag) return cls._parse_batch_request(request, entities, fields_info, location) @classmethod def delete_request( cls, collection_name: str, 
filter: str, partition_name: Optional[str] = None, consistency_level: Optional[Union[int, str]] = None, **kwargs, ): check.validate_strs( collection_name=collection_name, filter=filter, ) check.validate_nullable_strs(partition_name=partition_name) return milvus_types.DeleteRequest( collection_name=collection_name, partition_name=partition_name, expr=filter, consistency_level=get_consistency_level(consistency_level), expr_template_values=cls.prepare_expression_template(kwargs.get("expr_params", {})), ) @classmethod def _prepare_placeholder_str(cls, data: Any): # sparse vector if entity_helper.entity_is_sparse_matrix(data): pl_type = PlaceholderType.SparseFloatVector pl_values = entity_helper.sparse_rows_to_proto(data).contents elif isinstance(data[0], np.ndarray): dtype = data[0].dtype if dtype == "bfloat16": pl_type = PlaceholderType.BFLOAT16_VECTOR pl_values = (array.tobytes() for array in data) elif dtype == "float16": pl_type = PlaceholderType.FLOAT16_VECTOR pl_values = (array.tobytes() for array in data) elif dtype in ("float32", "float64"): pl_type = PlaceholderType.FloatVector pl_values = (blob.vector_float_to_bytes(entity) for entity in data) elif dtype == "byte": pl_type = PlaceholderType.BinaryVector pl_values = data else: err_msg = f"unsupported data type: {dtype}" raise ParamError(message=err_msg) elif isinstance(data[0], bytes): pl_type = PlaceholderType.BinaryVector pl_values = data # data is already a list of bytes elif isinstance(data[0], str): pl_type = PlaceholderType.VARCHAR pl_values = (value.encode("utf-8") for value in data) else: pl_type = PlaceholderType.FloatVector pl_values = (blob.vector_float_to_bytes(entity) for entity in data) pl = common_types.PlaceholderValue(tag="$0", type=pl_type, values=pl_values) return common_types.PlaceholderGroup.SerializeToString( common_types.PlaceholderGroup(placeholders=[pl]) ) @classmethod def prepare_expression_template(cls, values: Dict) -> Any: def all_elements_same_type(lst: List): return 
all(isinstance(item, type(lst[0])) for item in lst) def add_array_data(v: List) -> schema_types.TemplateArrayValue: data = schema_types.TemplateArrayValue() if len(v) == 0: return data element_type = ( infer_dtype_by_scalar_data(v[0]) if all_elements_same_type(v) else schema_types.JSON ) if element_type in (schema_types.Bool,): data.bool_data.data.extend(v) return data if element_type in ( schema_types.Int8, schema_types.Int16, schema_types.Int32, schema_types.Int64, ): data.long_data.data.extend(v) return data if element_type in (schema_types.Float, schema_types.Double): data.double_data.data.extend(v) return data if element_type in (schema_types.VarChar, schema_types.String): data.string_data.data.extend(v) return data if element_type in (schema_types.Array,): for e in v: data.array_data.data.append(add_array_data(e)) return data if element_type in (schema_types.JSON,): for e in v: data.json_data.data.append(entity_helper.convert_to_json(e)) return data raise ParamError(message=f"Unsupported element type: {element_type}") def add_data(v: Any) -> schema_types.TemplateValue: dtype = infer_dtype_by_scalar_data(v) data = schema_types.TemplateValue() if dtype in (schema_types.Bool,): data.bool_val = v return data if dtype in ( schema_types.Int8, schema_types.Int16, schema_types.Int32, schema_types.Int64, ): data.int64_val = v return data if dtype in (schema_types.Float, schema_types.Double): data.float_val = v return data if dtype in (schema_types.VarChar, schema_types.String): data.string_val = v return data if dtype in (schema_types.Array,): data.array_val.CopyFrom(add_array_data(v)) return data raise ParamError(message=f"Unsupported element type: {dtype}") expression_template_values = {} for k, v in values.items(): expression_template_values[k] = add_data(v) return expression_template_values @classmethod def search_requests_with_expr( cls, collection_name: str, data: Union[List, utils.SparseMatrixInputType], anns_field: str, param: Dict, limit: int, expr: 
Optional[str] = None, partition_names: Optional[List[str]] = None, output_fields: Optional[List[str]] = None, round_decimal: int = -1, **kwargs, ) -> milvus_types.SearchRequest: use_default_consistency = ts_utils.construct_guarantee_ts(collection_name, kwargs) ignore_growing = param.get("ignore_growing", False) or kwargs.get("ignore_growing", False) params = param.get("params", {}) if not isinstance(params, dict): raise ParamError(message=f"Search params must be a dict, got {type(params)}") if PAGE_RETAIN_ORDER_FIELD in kwargs and PAGE_RETAIN_ORDER_FIELD in param: raise ParamError( message="Provide page_retain_order both in kwargs and param, expect just one" ) page_retain_order = kwargs.get(PAGE_RETAIN_ORDER_FIELD) or param.get( PAGE_RETAIN_ORDER_FIELD ) if page_retain_order is not None: if not isinstance(page_retain_order, bool): raise ParamError( message=f"wrong type for page_retain_order, expect bool, got {type(page_retain_order)}" ) params[PAGE_RETAIN_ORDER_FIELD] = page_retain_order search_params = { "topk": limit, "params": params, "round_decimal": round_decimal, "ignore_growing": ignore_growing, } # parse offset if "offset" in kwargs and "offset" in param: raise ParamError(message="Provide offset both in kwargs and param, expect just one") offset = kwargs.get("offset") or param.get("offset") if offset is not None: if not isinstance(offset, int): raise ParamError(message=f"wrong type for offset, expect int, got {type(offset)}") search_params["offset"] = offset is_iterator = kwargs.get(ITERATOR_FIELD) if is_iterator is not None: search_params[ITERATOR_FIELD] = is_iterator group_by_field = kwargs.get(GROUP_BY_FIELD) if group_by_field is not None: search_params[GROUP_BY_FIELD] = group_by_field group_size = kwargs.get(GROUP_SIZE) if group_size is not None: search_params[GROUP_SIZE] = group_size strict_group_size = kwargs.get(STRICT_GROUP_SIZE) if strict_group_size is not None: search_params[STRICT_GROUP_SIZE] = strict_group_size if param.get("metric_type") is not 
        # NOTE(review): continuation of a search-request builder whose head is
        # above this chunk; the leading "None:" closes an `... is not None` test
        # whose condition is out of view.  Tokens preserved verbatim.
        None:
            search_params["metric_type"] = param["metric_type"]
        if anns_field:
            search_params["anns_field"] = anns_field
        # Every search parameter is shipped as a string-valued KeyValuePair.
        req_params = [
            common_types.KeyValuePair(key=str(key), value=utils.dumps(value))
            for key, value in search_params.items()
        ]
        nq = entity_helper.get_input_num_rows(data)
        plg_str = cls._prepare_placeholder_str(data)
        request = milvus_types.SearchRequest(
            collection_name=collection_name,
            partition_names=partition_names,
            output_fields=output_fields,
            guarantee_timestamp=kwargs.get("guarantee_timestamp", 0),
            use_default_consistency=use_default_consistency,
            consistency_level=kwargs.get("consistency_level", 0),
            nq=nq,
            placeholder_group=plg_str,
            dsl_type=common_types.DslType.BoolExprV1,
            search_params=req_params,
            expr_template_values=cls.prepare_expression_template(kwargs.get("expr_params", {})),
        )
        if expr is not None:
            request.dsl = expr
        return request

    @classmethod
    def hybrid_search_request_with_ranker(
        cls,
        collection_name: str,
        reqs: List,
        rerank_param: Dict,
        limit: int,
        partition_names: Optional[List[str]] = None,
        output_fields: Optional[List[str]] = None,
        round_decimal: int = -1,
        **kwargs,
    ) -> milvus_types.HybridSearchRequest:
        """Build a HybridSearchRequest combining multiple sub-search ``reqs``.

        The rerank configuration (``rerank_param`` plus limit/round_decimal/
        offset and optional grouping kwargs) is flattened into ``rank_params``
        as string KeyValuePairs.
        """
        use_default_consistency = ts_utils.construct_guarantee_ts(collection_name, kwargs)
        # limit/round_decimal/offset ride along inside the rerank params.
        rerank_param["limit"] = limit
        rerank_param["round_decimal"] = round_decimal
        rerank_param["offset"] = kwargs.get("offset", 0)
        request = milvus_types.HybridSearchRequest(
            collection_name=collection_name,
            partition_names=partition_names,
            requests=reqs,
            output_fields=output_fields,
            guarantee_timestamp=kwargs.get("guarantee_timestamp", 0),
            use_default_consistency=use_default_consistency,
            consistency_level=kwargs.get("consistency_level", 0),
        )
        request.rank_params.extend(
            [
                common_types.KeyValuePair(key=str(key), value=utils.dumps(value))
                for key, value in rerank_param.items()
            ]
        )
        # Optional grouping controls are appended only when supplied.
        # NOTE(review): RANK_GROUP_SCORER is passed through raw, the others are
        # serialized with utils.dumps — presumably the scorer is already a
        # string; confirm against callers.
        if kwargs.get(RANK_GROUP_SCORER) is not None:
            request.rank_params.extend(
                [
                    common_types.KeyValuePair(
                        key=RANK_GROUP_SCORER, value=kwargs.get(RANK_GROUP_SCORER)
                    )
                ]
            )
        if kwargs.get(GROUP_BY_FIELD) is not None:
            request.rank_params.extend(
                [
                    common_types.KeyValuePair(
                        key=GROUP_BY_FIELD, value=utils.dumps(kwargs.get(GROUP_BY_FIELD))
                    )
                ]
            )
        if kwargs.get(GROUP_SIZE) is not None:
            request.rank_params.extend(
                [
                    common_types.KeyValuePair(
                        key=GROUP_SIZE, value=utils.dumps(kwargs.get(GROUP_SIZE))
                    )
                ]
            )
        if kwargs.get(STRICT_GROUP_SIZE) is not None:
            request.rank_params.extend(
                [
                    common_types.KeyValuePair(
                        key=STRICT_GROUP_SIZE, value=utils.dumps(kwargs.get(STRICT_GROUP_SIZE))
                    )
                ]
            )
        return request

    @classmethod
    def create_alias_request(cls, collection_name: str, alias: str):
        """Build a CreateAliasRequest binding ``alias`` to a collection."""
        return milvus_types.CreateAliasRequest(collection_name=collection_name, alias=alias)

    @classmethod
    def drop_alias_request(cls, alias: str):
        """Build a DropAliasRequest removing ``alias``."""
        return milvus_types.DropAliasRequest(alias=alias)

    @classmethod
    def alter_alias_request(cls, collection_name: str, alias: str):
        """Build an AlterAliasRequest re-pointing ``alias`` to a collection."""
        return milvus_types.AlterAliasRequest(collection_name=collection_name, alias=alias)

    @classmethod
    def describe_alias_request(cls, alias: str):
        """Build a DescribeAliasRequest for ``alias``."""
        return milvus_types.DescribeAliasRequest(alias=alias)

    @classmethod
    def list_aliases_request(cls, collection_name: str, db_name: str = ""):
        """Build a ListAliasesRequest for a collection (optionally scoped to a db)."""
        return milvus_types.ListAliasesRequest(collection_name=collection_name, db_name=db_name)

    @classmethod
    def create_index_request(cls, collection_name: str, field_name: str, params: Dict, **kwargs):
        """Build a CreateIndexRequest; index params become string KeyValuePairs.

        Raises ParamError when a ``dim`` param is falsy or not an int.
        """
        index_params = milvus_types.CreateIndexRequest(
            collection_name=collection_name,
            field_name=field_name,
            index_name=kwargs.get("index_name", ""),
        )
        if isinstance(params, dict):
            for tk, tv in params.items():
                # dim must be a positive int; anything else is rejected early.
                if tk == "dim" and (not tv or not isinstance(tv, int)):
                    raise ParamError(message="dim must be of int!")
                kv_pair = common_types.KeyValuePair(key=str(tk), value=utils.dumps(tv))
                index_params.extra_params.append(kv_pair)
        return index_params

    @classmethod
    def alter_index_request(cls, collection_name: str, index_name: str, extra_params: dict):
        """Build an AlterIndexRequest with ``extra_params`` serialized to strings."""
        params = []
        for k, v in extra_params.items():
            params.append(common_types.KeyValuePair(key=str(k), value=utils.dumps(v)))
        return milvus_types.AlterIndexRequest(
            collection_name=collection_name, index_name=index_name, extra_params=params
        )

    @classmethod
    def describe_index_request(
        cls, collection_name: str, index_name: str, timestamp: Optional[int] = None
    ):
        """Build a DescribeIndexRequest, optionally pinned to ``timestamp``."""
        return milvus_types.DescribeIndexRequest(
            collection_name=collection_name, index_name=index_name, timestamp=timestamp
        )

    @classmethod
    def get_index_build_progress(cls, collection_name: str, index_name: str):
        """Build a GetIndexBuildProgressRequest for one index."""
        return milvus_types.GetIndexBuildProgressRequest(
            collection_name=collection_name, index_name=index_name
        )

    @classmethod
    def get_index_state_request(cls, collection_name: str, index_name: str):
        """Build a GetIndexStateRequest for one index."""
        return milvus_types.GetIndexStateRequest(
            collection_name=collection_name, index_name=index_name
        )

    @classmethod
    def load_collection(
        cls,
        db_name: str,
        collection_name: str,
        replica_number: int,
        refresh: bool,
        resource_groups: List[str],
        load_fields: List[str],
        skip_load_dynamic_field: bool,
    ):
        """Build a LoadCollectionRequest (pure field pass-through)."""
        return milvus_types.LoadCollectionRequest(
            db_name=db_name,
            collection_name=collection_name,
            replica_number=replica_number,
            refresh=refresh,
            resource_groups=resource_groups,
            load_fields=load_fields,
            skip_load_dynamic_field=skip_load_dynamic_field,
        )

    @classmethod
    def release_collection(cls, db_name: str, collection_name: str):
        """Build a ReleaseCollectionRequest."""
        return milvus_types.ReleaseCollectionRequest(
            db_name=db_name, collection_name=collection_name
        )

    @classmethod
    def load_partitions(
        cls,
        db_name: str,
        collection_name: str,
        partition_names: List[str],
        replica_number: int,
        refresh: bool,
        resource_groups: List[str],
        load_fields: List[str],
        skip_load_dynamic_field: bool,
    ):
        """Build a LoadPartitionsRequest (pure field pass-through)."""
        return milvus_types.LoadPartitionsRequest(
            db_name=db_name,
            collection_name=collection_name,
            partition_names=partition_names,
            replica_number=replica_number,
            refresh=refresh,
            resource_groups=resource_groups,
            load_fields=load_fields,
            skip_load_dynamic_field=skip_load_dynamic_field,
        )

    @classmethod
    def release_partitions(cls, db_name: str, collection_name: str, partition_names: List[str]):
        """Build a ReleasePartitionsRequest."""
        return milvus_types.ReleasePartitionsRequest(
            db_name=db_name, collection_name=collection_name, partition_names=partition_names
        )

    @classmethod
    def get_collection_stats_request(cls, collection_name: str):
        """Build a GetCollectionStatisticsRequest."""
        return milvus_types.GetCollectionStatisticsRequest(collection_name=collection_name)

    @classmethod
    def get_persistent_segment_info_request(cls, collection_name: str):
        """Build a GetPersistentSegmentInfoRequest (proto uses camelCase field)."""
        return milvus_types.GetPersistentSegmentInfoRequest(collectionName=collection_name)

    @classmethod
    def get_flush_state_request(cls, segment_ids: List[int], collection_name: str, flush_ts: int):
        """Build a GetFlushStateRequest for the given segments/flush timestamp."""
        return milvus_types.GetFlushStateRequest(
            segmentIDs=segment_ids, collection_name=collection_name, flush_ts=flush_ts
        )

    @classmethod
    def get_query_segment_info_request(cls, collection_name: str):
        """Build a GetQuerySegmentInfoRequest (proto uses camelCase field)."""
        return milvus_types.GetQuerySegmentInfoRequest(collectionName=collection_name)

    @classmethod
    def flush_param(cls, collection_names: List[str]):
        """Build a FlushRequest covering ``collection_names``."""
        return milvus_types.FlushRequest(collection_names=collection_names)

    @classmethod
    def drop_index_request(cls, collection_name: str, field_name: str, index_name: str):
        """Build a DropIndexRequest (db_name intentionally empty)."""
        return milvus_types.DropIndexRequest(
            db_name="",
            collection_name=collection_name,
            field_name=field_name,
            index_name=index_name,
        )

    @classmethod
    def get_partition_stats_request(cls, collection_name: str, partition_name: str):
        """Build a GetPartitionStatisticsRequest (db_name intentionally empty)."""
        return milvus_types.GetPartitionStatisticsRequest(
            db_name="", collection_name=collection_name, partition_name=partition_name
        )

    @classmethod
    def dummy_request(cls, request_type: Any):
        """Build a DummyRequest carrying ``request_type``."""
        return milvus_types.DummyRequest(request_type=request_type)

    @classmethod
    def retrieve_request(
        cls,
        collection_name: str,
        ids: List[str],
        output_fields: List[str],
        partition_names: List[str],
    ):
        """Build a RetrieveRequest fetching entities by primary-key ids.

        NOTE(review): ``ids`` is annotated List[str] but is wrapped in a
        LongArray, which implies integer ids — annotation looks stale; verify
        against callers.
        """
        ids = schema_types.IDs(int_id=schema_types.LongArray(data=ids))
        return milvus_types.RetrieveRequest(
            db_name="",
            collection_name=collection_name,
            ids=ids,
            output_fields=output_fields,
            partition_names=partition_names,
        )

    @classmethod
    def query_request(
        cls,
        collection_name: str,
        expr: str,
        output_fields: List[str],
        partition_names: List[str],
        **kwargs,
    ):
        """Build a QueryRequest for ``expr``; paging/iterator options come via kwargs.

        Recognized kwargs: guarantee_timestamp, consistency_level, expr_params,
        limit, offset, ignore_growing, the REDUCE_STOP_FOR_BEST flag, and
        ITERATOR_FIELD — the latter appended to query_params when present.
        """
        use_default_consistency = ts_utils.construct_guarantee_ts(collection_name, kwargs)
        req = milvus_types.QueryRequest(
            db_name="",
            collection_name=collection_name,
            expr=expr,
            output_fields=output_fields,
            partition_names=partition_names,
            guarantee_timestamp=kwargs.get("guarantee_timestamp", 0),
            use_default_consistency=use_default_consistency,
            consistency_level=kwargs.get("consistency_level", 0),
            expr_template_values=cls.prepare_expression_template(kwargs.get("expr_params", {})),
        )
        # limit/offset are only sent when explicitly provided.
        limit = kwargs.get("limit")
        if limit is not None:
            req.query_params.append(common_types.KeyValuePair(key="limit", value=str(limit)))
        offset = kwargs.get("offset")
        if offset is not None:
            req.query_params.append(common_types.KeyValuePair(key="offset", value=str(offset)))
        ignore_growing = kwargs.get("ignore_growing", False)
        stop_reduce_for_best = kwargs.get(REDUCE_STOP_FOR_BEST, False)
        is_iterator = kwargs.get(ITERATOR_FIELD)
        if is_iterator is not None:
            req.query_params.append(
                common_types.KeyValuePair(key=ITERATOR_FIELD, value=is_iterator)
            )
        # These two are always sent, defaulting to False.
        req.query_params.append(
            common_types.KeyValuePair(key="ignore_growing", value=str(ignore_growing))
        )
        req.query_params.append(
            common_types.KeyValuePair(key=REDUCE_STOP_FOR_BEST, value=str(stop_reduce_for_best))
        )
        return req

    @classmethod
    def load_balance_request(
        cls,
        collection_name: str,
        src_node_id: int,
        dst_node_ids: List[int],
        sealed_segment_ids: List[int],
    ):
        """Build a LoadBalanceRequest moving sealed segments between nodes."""
        return milvus_types.LoadBalanceRequest(
            collectionName=collection_name,
            src_nodeID=src_node_id,
            dst_nodeIDs=dst_node_ids,
            sealed_segmentIDs=sealed_segment_ids,
        )

    @classmethod
    def manual_compaction(cls, collection_id: int, is_clustering: bool):
        """Build a ManualCompactionRequest; is_clustering maps to majorCompaction.

        Raises ParamError when either argument has the wrong type.
        """
        if collection_id is None or not isinstance(collection_id, int):
            raise ParamError(message=f"collection_id value {collection_id} is illegal")
        if is_clustering is None or not isinstance(is_clustering, bool):
            raise ParamError(message=f"is_clustering value {is_clustering} is illegal")
        request = milvus_types.ManualCompactionRequest()
        request.collectionID = collection_id
        request.majorCompaction = is_clustering
        return request

    @classmethod
    def get_compaction_state(cls, compaction_id: int):
        """Build a GetCompactionStateRequest; validates compaction_id is int."""
        if compaction_id is None or not isinstance(compaction_id, int):
            raise ParamError(message=f"compaction_id value {compaction_id} is illegal")
        request = milvus_types.GetCompactionStateRequest()
        request.compactionID = compaction_id
        return request

    @classmethod
    def get_compaction_state_with_plans(cls, compaction_id: int):
        """Build a GetCompactionPlansRequest; validates compaction_id is int."""
        if compaction_id is None or not isinstance(compaction_id, int):
            raise ParamError(message=f"compaction_id value {compaction_id} is illegal")
        request = milvus_types.GetCompactionPlansRequest()
        request.compactionID = compaction_id
        return request

    @classmethod
    def get_replicas(cls, collection_id: int):
        """Build a GetReplicasRequest (always with shard-node detail)."""
        if collection_id is None or not isinstance(collection_id, int):
            raise ParamError(message=f"collection_id value {collection_id} is illegal")
        return milvus_types.GetReplicasRequest(
            collectionID=collection_id,
            with_shard_nodes=True,
        )

    @classmethod
    def do_bulk_insert(cls, collection_name: str, partition_name: str, files: list, **kwargs):
        """Build an ImportRequest; whitelisted kwargs become import options."""
        channel_names = kwargs.get("channel_names")
        req = milvus_types.ImportRequest(
            collection_name=collection_name,
            partition_name=partition_name,
            files=files,
        )
        if channel_names is not None:
            req.channel_names.extend(channel_names)
        # Only these option keys are forwarded to the server.
        for k, v in kwargs.items():
            if k in ("bucket", "backup", "sep", "nullkey"):
                kv_pair = common_types.KeyValuePair(key=str(k), value=str(v))
                req.options.append(kv_pair)
        return req

    @classmethod
    def get_bulk_insert_state(cls, task_id: int):
        """Build a GetImportStateRequest; validates task_id is int."""
        if task_id is None or not isinstance(task_id, int):
            msg = f"task_id value {task_id} is not an integer"
            raise ParamError(message=msg)
        return milvus_types.GetImportStateRequest(task=task_id)

    @classmethod
    def list_bulk_insert_tasks(cls, limit: int, collection_name: str):
        """Build a ListImportTasksRequest; validates limit is int."""
        if limit is None or not isinstance(limit, int):
            msg = f"limit value {limit} is not an integer"
            raise ParamError(message=msg)
        return milvus_types.ListImportTasksRequest(
            collection_name=collection_name,
            limit=limit,
        )

    @classmethod
    def create_user_request(cls, user: str, password: str):
        """Build a CreateCredentialRequest; password is base64-encoded for transport."""
        check_pass_param(user=user, password=password)
        return milvus_types.CreateCredentialRequest(
            username=user, password=base64.b64encode(password.encode("utf-8"))
        )

    @classmethod
    def update_password_request(cls, user: str, old_password: str, new_password: str):
        """Build an UpdateCredentialRequest with base64-encoded old/new passwords."""
        check_pass_param(user=user)
        check_pass_param(password=old_password)
        check_pass_param(password=new_password)
        return milvus_types.UpdateCredentialRequest(
            username=user,
            oldPassword=base64.b64encode(old_password.encode("utf-8")),
            newPassword=base64.b64encode(new_password.encode("utf-8")),
        )

    @classmethod
    def delete_user_request(cls, user: str):
        """Build a DeleteCredentialRequest; user must be a str."""
        if not isinstance(user, str):
            raise ParamError(message=f"invalid user {user}")
        return milvus_types.DeleteCredentialRequest(username=user)

    @classmethod
    def list_usernames_request(cls):
        """Build a ListCredUsersRequest (no arguments)."""
        return milvus_types.ListCredUsersRequest()

    @classmethod
    def create_role_request(cls, role_name: str):
        """Build a CreateRoleRequest wrapping ``role_name`` in a RoleEntity."""
        check_pass_param(role_name=role_name)
        return milvus_types.CreateRoleRequest(entity=milvus_types.RoleEntity(name=role_name))

    @classmethod
    def drop_role_request(cls, role_name: str):
        """Build a DropRoleRequest."""
        check_pass_param(role_name=role_name)
        return milvus_types.DropRoleRequest(role_name=role_name)

    @classmethod
    def operate_user_role_request(cls, username: str, role_name: str, operate_user_role_type: Any):
        """Build an OperateUserRoleRequest binding/unbinding a user and role."""
        check_pass_param(user=username)
        check_pass_param(role_name=role_name)
        check_pass_param(operate_user_role_type=operate_user_role_type)
        return milvus_types.OperateUserRoleRequest(
            username=username, role_name=role_name, type=operate_user_role_type
        )

    @classmethod
    def select_role_request(cls, role_name: str, include_user_info: bool):
        """Build a SelectRoleRequest; empty role_name selects all roles."""
        if role_name:
            check_pass_param(role_name=role_name)
        check_pass_param(include_user_info=include_user_info)
        return milvus_types.SelectRoleRequest(
            role=milvus_types.RoleEntity(name=role_name) if role_name else None,
            include_user_info=include_user_info,
        )

    @classmethod
    def select_user_request(cls, username: str, include_role_info: bool):
        """Build a SelectUserRequest; empty username selects all users."""
        if username:
            check_pass_param(user=username)
        check_pass_param(include_role_info=include_role_info)
        return milvus_types.SelectUserRequest(
            user=milvus_types.UserEntity(name=username) if username else None,
            include_role_info=include_role_info,
        )

    @classmethod
    def operate_privilege_request(
        cls,
        role_name: str,
        object: Any,  # NOTE(review): shadows builtin `object`; kept for API compat
        object_name: str,
        privilege: str,
        db_name: str,
        operate_privilege_type: Any,
    ):
        """Build an OperatePrivilegeRequest granting/revoking one privilege."""
        check_pass_param(role_name=role_name)
        check_pass_param(object=object)
        check_pass_param(object_name=object_name)
        check_pass_param(privilege=privilege)
        check_pass_param(operate_privilege_type=operate_privilege_type)
        return milvus_types.OperatePrivilegeRequest(
            entity=milvus_types.GrantEntity(
                role=milvus_types.RoleEntity(name=role_name),
                object=milvus_types.ObjectEntity(name=object),
                object_name=object_name,
                db_name=db_name,
                grantor=milvus_types.GrantorEntity(
                    privilege=milvus_types.PrivilegeEntity(name=privilege)
                ),
            ),
            type=operate_privilege_type,
        )

    @classmethod
    def operate_privilege_v2_request(
        cls,
        role_name: str,
        privilege: str,
        operate_privilege_type: Any,
        db_name: str,
        collection_name: str,
    ):
        """Build an OperatePrivilegeV2Request scoped to db/collection."""
        check_pass_param(
            role_name=role_name,
            privilege=privilege,
            operate_privilege_type=operate_privilege_type,
            db_name=db_name,
            collection_name=collection_name,
        )
        return milvus_types.OperatePrivilegeV2Request(
            role=milvus_types.RoleEntity(name=role_name),
            grantor=milvus_types.GrantorEntity(
                privilege=milvus_types.PrivilegeEntity(name=privilege)
            ),
            type=operate_privilege_type,
            db_name=db_name,
            collection_name=collection_name,
        )

    @classmethod
    def select_grant_request(cls, role_name: str, object: str, object_name: str, db_name: str):
        """Build a SelectGrantRequest; object/object_name are optional filters."""
        check_pass_param(role_name=role_name)
        if object:
            check_pass_param(object=object)
        if object_name:
            check_pass_param(object_name=object_name)
        return milvus_types.SelectGrantRequest(
            entity=milvus_types.GrantEntity(
                role=milvus_types.RoleEntity(name=role_name),
                object=milvus_types.ObjectEntity(name=object) if object else None,
                object_name=object_name if object_name else None,
                db_name=db_name,
            ),
        )

    @classmethod
    def get_server_version(cls):
        """Build a GetVersionRequest (no arguments)."""
        return milvus_types.GetVersionRequest()

    @classmethod
    def create_resource_group(cls, name: str, **kwargs):
        """Build a CreateResourceGroupRequest; optional ``config`` kwarg."""
        check_pass_param(resource_group_name=name)
        return milvus_types.CreateResourceGroupRequest(
            resource_group=name,
            config=kwargs.get("config"),
        )

    @classmethod
    def update_resource_groups(cls, configs: Mapping[str, ResourceGroupConfig]):
        """Build an UpdateResourceGroupsRequest from a name->config mapping."""
        return milvus_types.UpdateResourceGroupsRequest(
            resource_groups=configs,
        )

    @classmethod
    def drop_resource_group(cls, name: str):
        """Build a DropResourceGroupRequest."""
        check_pass_param(resource_group_name=name)
        return milvus_types.DropResourceGroupRequest(resource_group=name)

    @classmethod
    def list_resource_groups(cls):
        """Build a ListResourceGroupsRequest (no arguments)."""
        return milvus_types.ListResourceGroupsRequest()

    @classmethod
    def describe_resource_group(cls, name: str):
        """Build a DescribeResourceGroupRequest."""
        check_pass_param(resource_group_name=name)
        return milvus_types.DescribeResourceGroupRequest(resource_group=name)

    @classmethod
    def transfer_node(cls, source: str, target: str, num_node: int):
        """Build a TransferNodeRequest moving nodes between resource groups."""
        check_pass_param(resource_group_name=source)
        check_pass_param(resource_group_name=target)
        return milvus_types.TransferNodeRequest(
            source_resource_group=source, target_resource_group=target, num_node=num_node
        )

    @classmethod
    def transfer_replica(cls, source: str, target: str, collection_name: str, num_replica: int):
        """Build a TransferReplicaRequest moving replicas between resource groups."""
        check_pass_param(resource_group_name=source)
        check_pass_param(resource_group_name=target)
        return milvus_types.TransferReplicaRequest(
            source_resource_group=source,
            target_resource_group=target,
            collection_name=collection_name,
            num_replica=num_replica,
        )

    @classmethod
    def flush_all_request(cls, db_name: str):
        """Build a FlushAllRequest for every collection in ``db_name``."""
        return milvus_types.FlushAllRequest(db_name=db_name)

    @classmethod
    def get_flush_all_state_request(cls, flush_all_ts: int, db_name: str):
        """Build a GetFlushAllStateRequest for a prior flush-all timestamp."""
        return milvus_types.GetFlushAllStateRequest(flush_all_ts=flush_all_ts, db_name=db_name)
@classmethod def register_request(cls, user: str, host: str, **kwargs): reserved = {} for k, v in kwargs.items(): reserved[k] = v now = datetime.datetime.now() this = common_types.ClientInfo( sdk_type="Python", sdk_version=__version__, local_time=now.__str__(), reserved=reserved, ) if user is not None: this.user = user if host is not None: this.host = host return milvus_types.ConnectRequest( client_info=this, ) @classmethod def create_database_req(cls, db_name: str, **kwargs): check_pass_param(db_name=db_name) req = milvus_types.CreateDatabaseRequest(db_name=db_name) properties = kwargs.get("properties") if is_legal_collection_properties(properties): properties = [ common_types.KeyValuePair(key=str(k), value=str(v)) for k, v in properties.items() ] req.properties.extend(properties) return req @classmethod def drop_database_req(cls, db_name: str): check_pass_param(db_name=db_name) return milvus_types.DropDatabaseRequest(db_name=db_name) @classmethod def list_database_req(cls): return milvus_types.ListDatabasesRequest() @classmethod def alter_database_req(cls, db_name: str, properties: Dict): check_pass_param(db_name=db_name) kvs = [common_types.KeyValuePair(key=k, value=str(v)) for k, v in properties.items()] return milvus_types.AlterDatabaseRequest(db_name=db_name, properties=kvs) @classmethod def describe_database_req(cls, db_name: str): check_pass_param(db_name=db_name) return milvus_types.DescribeDatabaseRequest(db_name=db_name) @classmethod def create_privilege_group_req(cls, privilege_group: str): check_pass_param(privilege_group=privilege_group) return milvus_types.CreatePrivilegeGroupRequest(group_name=privilege_group) @classmethod def drop_privilege_group_req(cls, privilege_group: str): check_pass_param(privilege_group=privilege_group) return milvus_types.DropPrivilegeGroupRequest(group_name=privilege_group) @classmethod def list_privilege_groups_req(cls): return milvus_types.ListPrivilegeGroupsRequest() @classmethod def operate_privilege_group_req( cls, 
privilege_group: str, privileges: List[str], operate_privilege_group_type: Any ): check_pass_param(privilege_group=privilege_group) check_pass_param(privileges=privileges) check_pass_param(operate_privilege_group_type=operate_privilege_group_type) return milvus_types.OperatePrivilegeGroupRequest( group_name=privilege_group, privileges=[milvus_types.PrivilegeEntity(name=p) for p in privileges], type=operate_privilege_group_type, )
Memory