import abc
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import ujson
from pymilvus.exceptions import DataTypeNotMatchException, ExceptionsMessage, MilvusException
from pymilvus.grpc_gen import common_pb2, schema_pb2
from pymilvus.settings import Config
from . import entity_helper, utils
from .constants import DEFAULT_CONSISTENCY_LEVEL, RANKER_TYPE_RRF, RANKER_TYPE_WEIGHTED
from .types import DataType, FunctionType
class FieldSchema:
    """Client-side view of one collection field, unpacked from a raw field-schema message.

    The constructor copies the scalar attributes of ``raw`` and normalizes its
    ``type_params`` / ``index_params`` key-value lists into ``params`` and
    ``indexes``.
    """

    def __init__(self, raw: Any):
        self._raw = raw
        self.field_id = 0
        self.name = None
        self.is_primary = False
        self.description = None
        self.auto_id = False
        self.type = DataType.UNKNOWN
        self.indexes = []
        self.params = {}
        self.is_partition_key = False
        self.is_dynamic = False
        self.nullable = False
        self.default_value = None
        self.is_function_output = False
        # For array field
        self.element_type = None
        self.is_clustering_key = False
        self.__pack(self._raw)

    def __pack(self, raw: Any):
        """Copy attributes from ``raw`` and decode its type/index parameter lists."""
        # Imported once here instead of on every "params" entry inside the loops.
        import json

        self.field_id = raw.fieldID
        self.name = raw.name
        self.is_primary = raw.is_primary_key
        self.description = raw.description
        self.auto_id = raw.autoID
        self.type = DataType(raw.data_type)
        self.is_partition_key = raw.is_partition_key
        self.element_type = DataType(raw.element_type)
        self.is_clustering_key = raw.is_clustering_key
        self.default_value = raw.default_value
        # An unset "data" oneof means no default value; normalize that to None once.
        if raw.default_value is not None and raw.default_value.WhichOneof("data") is None:
            self.default_value = None
        self.is_dynamic = raw.is_dynamic
        self.nullable = raw.nullable
        self.is_function_output = raw.is_function_output

        for type_param in raw.type_params:
            key, value = type_param.key, type_param.value
            if key == "params":
                self.params[key] = json.loads(value)
                continue
            if key == "mmap.enabled":
                # Exposed as "mmap_enabled"; "false" (any case) maps to False,
                # anything else to bool(value).
                self.params["mmap_enabled"] = bool(value) if value.lower() != "false" else False
                continue
            self.params[key] = value
            # Numeric parameters are converted to int.
            if key == "dim":
                self.params[key] = int(value)
            if key == Config.MaxVarCharLengthKey and raw.data_type in (
                DataType.VARCHAR,
                DataType.ARRAY,
            ):
                self.params[key] = int(value)
            # TO-DO: use constants defined in orm
            if key == "max_capacity" and raw.data_type == DataType.ARRAY:
                self.params[key] = int(value)

        index_dict = {}
        for index_param in raw.index_params:
            if index_param.key == "params":
                index_dict[index_param.key] = json.loads(index_param.value)
            else:
                index_dict[index_param.key] = index_param.value
        self.indexes.append(index_dict)

    def dict(self):
        """Return a plain-dict summary of the field.

        Optional flags (partition key, dynamic, auto_id, ...) are included only
        when set. This method does not modify the object.
        """
        _dict = {
            "field_id": self.field_id,
            "name": self.name,
            "description": self.description,
            "type": self.type,
            "params": self.params or {},
        }
        # A default_value whose "data" oneof is unset counts as "no default".
        if self.default_value is not None and self.default_value.WhichOneof("data") is not None:
            _dict["default_value"] = self.default_value
        if self.element_type:
            _dict["element_type"] = self.element_type
        if self.is_partition_key:
            _dict["is_partition_key"] = True
        if self.is_dynamic:
            _dict["is_dynamic"] = True
        if self.auto_id:
            _dict["auto_id"] = True
        if self.nullable:
            _dict["nullable"] = True
        if self.is_primary:
            _dict["is_primary"] = self.is_primary
        if self.is_clustering_key:
            _dict["is_clustering_key"] = True
        if self.is_function_output:
            _dict["is_function_output"] = True
        return _dict
class FunctionSchema:
    """Client-side view of a schema function, unpacked from its raw message.

    ``params`` is rebuilt from the raw key/value list. The input/output field
    name and id lists are taken from ``raw`` as-is.
    """

    def __init__(self, raw: Any):
        self._raw = raw
        self.name = None
        self.description = None
        self.type = None
        self.params = {}
        self.input_field_names = []
        self.input_field_ids = []
        self.output_field_names = []
        self.output_field_ids = []
        self.id = 0
        self.__pack(self._raw)

    def __pack(self, raw: Any):
        """Copy every attribute of ``raw`` onto this object."""
        self.name, self.description, self.id = raw.name, raw.description, raw.id
        self.type = FunctionType(raw.type)
        self.params = {kv.key: kv.value for kv in raw.params}
        self.input_field_names = raw.input_field_names
        self.input_field_ids = raw.input_field_ids
        self.output_field_names = raw.output_field_names
        self.output_field_ids = raw.output_field_ids

    def dict(self):
        """Return every attribute of the function as a plain dict."""
        return dict(
            name=self.name,
            id=self.id,
            description=self.description,
            type=self.type,
            params=self.params,
            input_field_names=self.input_field_names,
            input_field_ids=self.input_field_ids,
            output_field_names=self.output_field_names,
            output_field_ids=self.output_field_ids,
        )
class CollectionSchema:
    """Client-side description of a collection, unpacked from ``raw``.

    ``raw`` is presumably a describe-collection response (TODO confirm against
    callers). When ``raw`` is falsy, nothing is unpacked: every attribute keeps
    its default and ``dict()`` returns ``{}``.
    """

    def __init__(self, raw: Any):
        self._raw = raw
        self.collection_name = None
        self.description = None
        self.params = {}
        self.fields = []
        self.functions = []
        self.statistics = {}
        self.auto_id = False  # auto_id is not in collection level any more later
        self.aliases = []
        self.collection_id = 0
        self.consistency_level = DEFAULT_CONSISTENCY_LEVEL  # by default
        self.properties = {}
        self.num_shards = 0
        self.num_partitions = 0
        self.enable_dynamic_field = False
        if self._raw:
            self.__pack(self._raw)

    def __pack(self, raw: Any):
        """Copy collection metadata, fields, functions and properties from ``raw``."""
        schema = raw.schema
        self.collection_name = schema.name
        self.description = schema.description
        self.aliases = list(raw.aliases)
        self.collection_id = raw.collectionID
        self.num_shards = raw.shards_num
        self.num_partitions = raw.num_partitions
        # Older Milvus responses may lack these attributes; fall back to the
        # defaults. Only a missing attribute is tolerated, not arbitrary errors.
        self.consistency_level = getattr(raw, "consistency_level", DEFAULT_CONSISTENCY_LEVEL)
        self.enable_dynamic_field = getattr(schema, "enable_dynamic_field", False)
        # TODO: extra_params and statistics are not unpacked yet.
        self.fields = [FieldSchema(f) for f in schema.fields]
        self.functions = [FunctionSchema(f) for f in schema.functions]
        # Flag every field that some function writes to (set for O(1) lookups).
        function_outputs = {name for fn in self.functions for name in fn.output_field_names}
        for field in self.fields:
            if field.name in function_outputs:
                field.is_function_output = True
        for p in raw.properties:
            self.properties[p.key] = p.value

    @classmethod
    def _rewrite_schema_dict(cls, schema_dict: Dict):
        """Copy the first explicit field-level ``auto_id`` onto ``schema_dict`` itself."""
        fields = schema_dict.get("fields", [])
        if not fields:
            return
        for field_dict in fields:
            if field_dict.get("auto_id", None) is not None:
                schema_dict["auto_id"] = field_dict["auto_id"]
                return

    def dict(self):
        """Return the collection description as a plain dict (``{}`` when unpacked from nothing)."""
        if not self._raw:
            return {}
        _dict = {
            "collection_name": self.collection_name,
            "auto_id": self.auto_id,
            "num_shards": self.num_shards,
            "description": self.description,
            "fields": [f.dict() for f in self.fields],
            "functions": [f.dict() for f in self.functions],
            "aliases": self.aliases,
            "collection_id": self.collection_id,
            "consistency_level": self.consistency_level,
            "properties": self.properties,
            "num_partitions": self.num_partitions,
            "enable_dynamic_field": self.enable_dynamic_field,
        }
        self._rewrite_schema_dict(_dict)
        return _dict

    def __str__(self):
        return self.dict().__str__()
class MutationResult:
    """Outcome of an insert/delete/upsert request, unpacked from the raw response.

    Exposes the affected primary keys, per-operation counts, the timestamp,
    success/error row indexes and the reported ``cost``.
    """

    def __init__(self, raw: Any):
        self._raw = raw
        self._primary_keys = []
        self._insert_cnt = 0
        self._delete_cnt = 0
        self._upsert_cnt = 0
        self._timestamp = 0
        self._succ_index = []
        self._err_index = []
        self._cost = 0
        self._pack(raw)

    @property
    def primary_keys(self):
        return self._primary_keys

    @property
    def insert_count(self):
        return self._insert_cnt

    @property
    def delete_count(self):
        return self._delete_cnt

    @property
    def upsert_count(self):
        return self._upsert_cnt

    @property
    def timestamp(self):
        return self._timestamp

    @property
    def succ_count(self):
        return len(self._succ_index)

    @property
    def err_count(self):
        return len(self._err_index)

    @property
    def succ_index(self):
        return self._succ_index

    @property
    def err_index(self):
        return self._err_index

    # The unit of this cost is vcu, similar to token
    @property
    def cost(self):
        return self._cost

    def __str__(self):
        text = (
            f"(insert count: {self._insert_cnt}, delete count: {self._delete_cnt}, upsert count: {self._upsert_cnt}, "
            f"timestamp: {self._timestamp}, success count: {self.succ_count}, err count: {self.err_count}"
        )
        if self.cost:
            text += f", cost: {self._cost}"
        # Both variants close the parenthesis (the cost-less one used to omit it).
        return text + ")"

    __repr__ = __str__

    def _pack(self, raw: Any):
        """Copy ids, counts, indexes and cost from the raw mutation response."""
        which = raw.IDs.WhichOneof("id_field")
        if which == "int_id":
            self._primary_keys = raw.IDs.int_id.data
        elif which == "str_id":
            self._primary_keys = raw.IDs.str_id.data

        self._insert_cnt = raw.insert_cnt
        self._delete_cnt = raw.delete_cnt
        self._upsert_cnt = raw.upsert_cnt
        self._timestamp = raw.timestamp
        self._succ_index = raw.succ_index
        self._err_index = raw.err_index
        # extra_info may be present without "report_value"; treat that as zero
        # cost rather than raising KeyError.
        extra_info = raw.status.extra_info if raw.status else None
        self._cost = int(extra_info.get("report_value", "0")) if extra_info else 0
class SequenceIterator:
    """Forward iterator over an indexable sequence, yielding ``seq[0]``, ``seq[1]``, ..."""

    def __init__(self, seq: Sequence[Any]):
        self._seq = seq
        self._idx = 0

    def __iter__(self) -> "SequenceIterator":
        # Iterator protocol: an iterator must also be iterable and return itself.
        # Without this, list(iter(result)) and similar calls raise TypeError.
        return self

    def __next__(self) -> Any:
        if self._idx < len(self._seq):
            res = self._seq[self._idx]
            self._idx += 1
            return res
        raise StopIteration
class BaseRanker:
    """Base class for hybrid-search rerankers; subclasses override ``dict()``."""

    def __init__(self):
        # Was misspelled as ``__int__``, which made int(ranker) fail with a
        # confusing "returned non-int" error instead of acting as a constructor.
        return

    def dict(self):
        """Return the ranker's request representation (empty for the base class)."""
        return {}

    def __str__(self):
        return self.dict().__str__()
class RRFRanker(BaseRanker):
    """Reciprocal-rank-fusion ranker.

    Args:
        k: value sent to the server as ``params["k"]`` (defaults to 60).
    """

    def __init__(
        self,
        k: int = 60,
    ):
        self._k = k
        self._strategy = RANKER_TYPE_RRF

    def dict(self):
        """Return ``{"strategy": ..., "params": {"k": ...}}``."""
        return {"strategy": self._strategy, "params": {"k": self._k}}
class WeightedRanker(BaseRanker):
    """Weighted ranker; each positional argument is one weight, kept in order."""

    def __init__(self, *nums):
        self._strategy = RANKER_TYPE_WEIGHTED
        self._weights = list(nums)

    def dict(self):
        """Return ``{"strategy": ..., "params": {"weights": [...]}}``."""
        return {"strategy": self._strategy, "params": {"weights": self._weights}}
class AnnSearchRequest:
    """One ANN sub-request: query data, target vector field, search params, limit and optional filter.

    Raises:
        DataTypeNotMatchException: if ``expr`` is neither None nor a str.
    """

    def __init__(
        self,
        data: Union[List, utils.SparseMatrixInputType],
        anns_field: str,
        param: Dict,
        limit: int,
        expr: Optional[str] = None,
    ):
        self._data, self._anns_field = data, anns_field
        self._param, self._limit = param, limit
        if not (expr is None or isinstance(expr, str)):
            raise DataTypeNotMatchException(message=ExceptionsMessage.ExprType % type(expr))
        self._expr = expr

    @property
    def data(self):
        return self._data

    @property
    def anns_field(self):
        return self._anns_field

    @property
    def param(self):
        return self._param

    @property
    def limit(self):
        return self._limit

    @property
    def expr(self):
        return self._expr

    def __str__(self):
        # The query data itself is deliberately left out of the summary.
        summary = {
            "anns_field": self.anns_field,
            "param": self.param,
            "limit": self.limit,
            "expr": self.expr,
        }
        return str(summary)
class SearchResult(list):
    """Search results for ``nq`` queries: a list holding one ``Hits`` per query.

    ``res`` packs the results of all queries back to back. ``res.topks`` gives,
    in order, how many rows belong to each query.
    """

    def __init__(
        self,
        res: schema_pb2.SearchResultData,
        round_decimal: Optional[int] = None,
        status: Optional[common_pb2.Status] = None,
        session_ts: Optional[int] = 0,
    ):
        """
        Args:
            res: raw result data for all queries.
            round_decimal: when a positive int, scores are rounded to this many decimals.
            status: response status; ``extra_info["report_value"]``, if present, becomes ``cost``.
            session_ts: value returned by ``get_session_ts``.
        """
        self._nq = res.num_queries
        all_topks = res.topks
        # extra_info may carry other keys without "report_value"; treat a
        # missing entry as zero cost instead of raising KeyError.
        extra_info = status.extra_info if status else None
        self.cost = int(extra_info.get("report_value", "0")) if extra_info else 0
        output_fields = res.output_fields
        fields_data = res.fields_data

        all_pks: List[Union[str, int]] = []
        if res.ids.HasField("int_id"):
            all_pks = res.ids.int_id.data
        elif res.ids.HasField("str_id"):
            all_pks = res.ids.str_id.data

        all_scores: List[float]
        if isinstance(round_decimal, int) and round_decimal > 0:
            all_scores = [round(x, round_decimal) for x in res.scores]
        else:
            all_scores = res.scores

        data = []
        offset = 0
        for topk in all_topks:
            start, end = offset, offset + topk
            nq_th_fields = self.get_fields_by_range(start, end, fields_data)
            data.append(
                Hits(topk, all_pks[start:end], all_scores[start:end], nq_th_fields, output_fields)
            )
            offset = end
        self._session_ts = session_ts
        super().__init__(data)

    def get_session_ts(self):
        return self._session_ts

    @staticmethod
    def _scalar_data_attr(dtype: int) -> Optional[str]:
        """Return the ScalarField attribute that stores a plain scalar column of ``dtype``, else None."""
        if dtype == DataType.BOOL:
            return "bool_data"
        if dtype in (DataType.INT8, DataType.INT16, DataType.INT32):
            return "int_data"
        if dtype == DataType.INT64:
            return "long_data"
        if dtype == DataType.FLOAT:
            return "float_data"
        if dtype == DataType.DOUBLE:
            return "double_data"
        if dtype == DataType.VARCHAR:
            return "string_data"
        return None

    def get_fields_by_range(
        self, start: int, end: int, all_fields_data: List[schema_pb2.FieldData]
    ) -> Dict[str, Tuple[List[Any], schema_pb2.FieldData]]:
        """Cut rows ``[start, end)`` out of every returned column.

        Returns:
            Mapping of field name to ``(values, field_meta)``. Scalar, JSON,
            array and sparse columns give one value per row, with rows flagged
            invalid in ``valid_data`` replaced by None where applied. Dense vector
            columns give the flat slice of their buffer; ``Hits`` later splits it
            per row using ``field_meta.vectors.dim``. Columns of any other type
            are omitted.
        """
        field2data: Dict[str, Tuple[List[Any], schema_pb2.FieldData]] = {}
        for field in all_fields_data:
            name, scalars, dtype = field.field_name, field.scalars, field.type
            field_meta = schema_pb2.FieldData(
                type=dtype,
                field_name=name,
                field_id=field.field_id,
                is_dynamic=field.is_dynamic,
            )

            scalar_attr = self._scalar_data_attr(dtype)
            if scalar_attr is not None:
                column = getattr(scalars, scalar_attr).data[start:end]
                field2data[name] = (
                    apply_valid_data(column, field.valid_data, start, end),
                    field_meta,
                )
                continue

            if dtype == DataType.JSON:
                res = apply_valid_data(
                    scalars.json_data.data[start:end], field.valid_data, start, end
                )
                json_dict_list = [ujson.loads(item) if item is not None else item for item in res]
                field2data[name] = json_dict_list, field_meta
                continue

            if dtype == DataType.ARRAY:
                res = apply_valid_data(
                    scalars.array_data.data[start:end], field.valid_data, start, end
                )
                field2data[name] = (
                    extract_array_row_data(res, scalars.array_data.element_type),
                    field_meta,
                )
                continue

            # vectors
            dim, vectors = field.vectors.dim, field.vectors
            field_meta.vectors.dim = dim
            if dtype == DataType.FLOAT_VECTOR:
                if start == 0 and (end - start) * dim >= len(vectors.float_vector.data):
                    # When the range covers all of vectors.float_vector.data, return the
                    # container directly to avoid a copy. This improves performance by
                    # about 25% when retrieving 1536-dim embeddings with topk=16384.
                    field2data[name] = vectors.float_vector.data, field_meta
                else:
                    field2data[name] = (
                        vectors.float_vector.data[start * dim : end * dim],
                        field_meta,
                    )
                continue

            if dtype == DataType.BINARY_VECTOR:
                # Binary vectors use dim / 8 bytes per row.
                field2data[name] = (
                    vectors.binary_vector[start * (dim // 8) : end * (dim // 8)],
                    field_meta,
                )
                continue

            # TODO(SPARSE): do we want to allow the user to specify the return format?
            if dtype == DataType.SPARSE_FLOAT_VECTOR:
                field2data[name] = (
                    entity_helper.sparse_proto_to_rows(vectors.sparse_float_vector, start, end),
                    field_meta,
                )
                continue

            # Half-precision vectors use 2 bytes per dimension.
            if dtype == DataType.BFLOAT16_VECTOR:
                field2data[name] = (
                    vectors.bfloat16_vector[start * (dim * 2) : end * (dim * 2)],
                    field_meta,
                )
                continue

            if dtype == DataType.FLOAT16_VECTOR:
                field2data[name] = (
                    vectors.float16_vector[start * (dim * 2) : end * (dim * 2)],
                    field_meta,
                )
                continue
        return field2data

    def __iter__(self) -> SequenceIterator:
        return SequenceIterator(self)

    def __str__(self) -> str:
        """Only print at most 10 query results"""
        reminder = f" ... and {len(self) - 10} results remaining" if len(self) > 10 else ""
        if self.cost:
            return f"data: {list(map(str, self[:10]))}{reminder}, cost: {self.cost}"
        return f"data: {list(map(str, self[:10]))}{reminder}"

    __repr__ = __str__
class Hits(list):
    """Hits of a single query: a list of ``Hit`` objects, one per returned row."""

    ids: List[Union[str, int]]
    distances: List[float]

    def __init__(
        self,
        topk: int,
        pks: Sequence[Union[int, str]],
        distances: List[float],
        fields: Dict[str, Tuple[List[Any], schema_pb2.FieldData]],
        output_fields: List[str],
    ):
        """
        Args:
            topk: number of rows in this query's result.
            pks: primary keys of the rows, in order.
            distances: scores of the rows, in order.
            fields(Dict[str, Tuple[List[Any], schema_pb2.FieldData]]):
                field name to a tuple of topk data and field meta
            output_fields: requested output field names. Names that are not
                keys of ``fields`` are looked up inside the dynamic JSON field.
        """
        self.ids = pks
        self.distances = distances
        # Set for O(1) membership in the per-key filter below.
        dynamic_fields = set(output_fields) - set(fields.keys())
        hits = []
        for i in range(topk):
            curr_field = {}
            for fname, (data, field_meta) in fields.items():
                if len(data) <= i:
                    # No value for this row: report null. Skip the row lookups
                    # below, which would otherwise raise IndexError.
                    curr_field[fname] = None
                    continue

                # Dense vectors are stored flat; cut out this row's element span.
                if field_meta.type in (
                    DataType.FLOAT_VECTOR,
                    DataType.BINARY_VECTOR,
                    DataType.BFLOAT16_VECTOR,
                    DataType.FLOAT16_VECTOR,
                ):
                    dim = field_meta.vectors.dim
                    if field_meta.type in [DataType.BINARY_VECTOR]:
                        dim = dim // 8
                    elif field_meta.type in [DataType.BFLOAT16_VECTOR, DataType.FLOAT16_VECTOR]:
                        dim = dim * 2
                    curr_field[fname] = data[i * dim : (i + 1) * dim]
                    continue

                # Dynamic fields: flatten the requested keys of the JSON blob into the entity.
                if field_meta.type == DataType.JSON and field_meta.is_dynamic:
                    if dynamic_fields:
                        curr_field.update({k: v for k, v in data[i].items() if k in dynamic_fields})
                        continue
                    if fname in output_fields:
                        curr_field.update(data[i])
                        continue

                # sparse float vector and other fields
                curr_field[fname] = data[i]
            hits.append(Hit(pks[i], distances[i], curr_field))
        super().__init__(hits)

    def __iter__(self) -> SequenceIterator:
        return SequenceIterator(self)

    def __str__(self) -> str:
        """Only print at most 10 query results"""
        reminder = f" ... and {len(self) - 10} entities remaining" if len(self) > 10 else ""
        return f"{list(map(str, self[:10]))!s}{reminder}"

    __repr__ = __str__
class Hit:
    """A single search hit: primary key, distance, and the returned field values.

    Field values are readable as attributes (``hit.<field>``) or via ``get``.
    """

    id: Union[int, str]
    distance: float
    fields: Dict[str, Any]

    def __init__(self, pk: Union[int, str], distance: float, fields: Dict[str, Any]):
        self.id = pk
        self.distance = distance
        self.fields = fields

    def __getattr__(self, item: str):
        """Expose entity fields as attributes.

        Raises:
            AttributeError: for dunder names, or when ``fields`` has not been set
                yet (e.g. on an instance being rebuilt by copy/pickle).
            MilvusException: when ``item`` is not one of the hit's fields.
        """
        # Protocol probes such as copy.deepcopy's getattr(x, "__deepcopy__", None)
        # require AttributeError; any other exception aborts copy/pickle.
        if item.startswith("__") and item.endswith("__"):
            raise AttributeError(item)
        # Read via __dict__ so a half-built instance cannot recurse into
        # __getattr__("fields") forever.
        fields = self.__dict__.get("fields")
        if fields is None:
            raise AttributeError(item)
        if item not in fields:
            raise MilvusException(message=f"Field {item} is not in the hit entity")
        return fields[item]

    @property
    def entity(self):
        return self

    @property
    def pk(self) -> Union[str, int]:
        return self.id

    @property
    def score(self) -> float:
        return self.distance

    def get(self, field_name: str) -> Any:
        """Return the value of ``field_name``, or None when absent."""
        return self.fields.get(field_name)

    def __str__(self) -> str:
        return f"id: {self.id}, distance: {self.distance}, entity: {self.fields}"

    __repr__ = __str__

    def to_dict(self):
        return {
            "id": self.id,
            "distance": self.distance,
            "entity": self.fields,
        }
def extract_array_row_data(
scalars: List[schema_pb2.ScalarField], element_type: DataType
) -> List[List[Any]]:
row = []
for ith_array in scalars:
if ith_array is None:
row.append(None)
continue
if element_type == DataType.INT64:
row.append(ith_array.long_data.data)
continue
if element_type == DataType.BOOL:
row.append(ith_array.bool_data.data)
continue
if element_type in (DataType.INT8, DataType.INT16, DataType.INT32):
row.append(ith_array.int_data.data)
continue
if element_type == DataType.FLOAT:
row.append(ith_array.float_data.data)
continue
if element_type == DataType.DOUBLE:
row.append(ith_array.double_data.data)
continue
if element_type in (DataType.STRING, DataType.VARCHAR):
row.append(ith_array.string_data.data)
continue
return row
def apply_valid_data(
    data: List[Any], valid_data: Union[None, List[bool]], start: int, end: int
) -> List[Any]:
    """Replace with None each row of ``data`` whose validity flag is False.

    ``data`` is the already-sliced window ``[start, end)`` of a column and is
    modified in place. ``valid_data`` is the whole column's mask, so it is cut
    with the same bounds before being lined up with ``data``. An empty or
    missing mask leaves ``data`` untouched.

    Returns:
        The same ``data`` object.
    """
    if not valid_data:
        return data
    window = valid_data[start:end]
    for offset in [pos for pos, flag in enumerate(window) if not flag]:
        data[offset] = None
    return data
class LoopBase:
    """Mixin providing indexing, slicing and iteration on top of ``get__item``.

    Subclasses must implement ``__len__`` and ``get__item(index)``.
    """

    def __init__(self):
        self.__index = 0

    def __iter__(self):
        return self

    def __getitem__(self, item: Any):
        """Return one element, or a list of elements for a slice.

        Raises:
            IndexError: when an int index is >= len(self).
        """
        if isinstance(item, slice):
            # slice.indices resolves None/negative bounds and clamps them like
            # list slicing does. The earlier hand-rolled version treated stop=0
            # as "to the end" and passed negative bounds straight to get__item.
            return [self.get__item(i) for i in range(*item.indices(self.__len__()))]
        if item >= self.__len__():
            msg = "Index out of range"
            raise IndexError(msg)
        return self.get__item(item)

    def __next__(self):
        if self.__index < self.__len__():
            self.__index += 1
            return self.__getitem__(self.__index - 1)
        # Exhausted: rewind so the object can be iterated again, then stop.
        self.__index = 0
        raise StopIteration

    def __str__(self):
        return str(list(map(str, self.__getitem__(slice(0, 10)))))

    @abc.abstractmethod
    def get__item(self, item: Any):
        raise NotImplementedError