import datetime
import logging
from copy import deepcopy
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union
from pymilvus.client import entity_helper, utils
from pymilvus.client.abstract import Hits, LoopBase
from pymilvus.exceptions import (
MilvusException,
ParamError,
)
from .connections import Connections
from .constants import (
BATCH_SIZE,
CALC_DIST_BM25,
CALC_DIST_COSINE,
CALC_DIST_HAMMING,
CALC_DIST_IP,
CALC_DIST_JACCARD,
CALC_DIST_L2,
CALC_DIST_TANIMOTO,
DEFAULT_SEARCH_EXTENSION_RATE,
EF,
FIELDS,
GUARANTEE_TIMESTAMP,
INT64_MAX,
IS_PRIMARY,
ITERATOR_FIELD,
ITERATOR_SESSION_CP_FILE,
ITERATOR_SESSION_TS_FIELD,
MAX_BATCH_SIZE,
MAX_FILTERED_IDS_COUNT_ITERATION,
MAX_TRY_TIME,
METRIC_TYPE,
MILVUS_LIMIT,
OFFSET,
PARAMS,
PRINT_ITERATOR_CURSOR,
RADIUS,
RANGE_FILTER,
REDUCE_STOP_FOR_BEST,
UNLIMITED,
)
from .schema import CollectionSchema
from .types import DataType
from .utility import mkts_from_datetime
LOGGER = logging.getLogger(__name__)
LOGGER.setLevel(logging.INFO)
# alias of the same module-level logger; both names are used below
log = LOGGER
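# Fall back to a client-generated timestamp (about one second ahead of now)
# when the server-side mvccTs cannot be obtained.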
def fall_back_to_latest_session_ts():
d = datetime.datetime.now()
return mkts_from_datetime(d, milliseconds=1000.0)
def assert_info(condition: bool, message: str):
if not condition:
raise MilvusException(message)
def io_operation(io_func: Callable[[Any], None], message: str):
try:
io_func()
except OSError as ose:
raise MilvusException(message=message) from ose
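# Extend the requested batch size by DEFAULT_SEARCH_EXTENSION_RATE when probing,
# capped at MAX_BATCH_SIZE and, when an ef parameter is present, also at ef.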
def extend_batch_size(batch_size: int, next_param: dict, to_extend_batch_size: bool) -> int:
extend_rate = 1
if to_extend_batch_size:
extend_rate = DEFAULT_SEARCH_EXTENSION_RATE
if EF in next_param[PARAMS]:
real_batch = min(MAX_BATCH_SIZE, batch_size * extend_rate, next_param[PARAMS][EF])
next_param[PARAMS][EF] = min(next_param[PARAMS][EF], real_batch)
return real_batch
return min(MAX_BATCH_SIZE, batch_size * extend_rate)
def check_set_flag(obj: Any, flag_name: str, kwargs: Dict[str, Any], key: str):
setattr(obj, flag_name, kwargs.get(key, False))
class QueryIterator:
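    """Iterate over a collection's entities batch by batch.

    Pagination is driven by a primary-key cursor (pk > last_seen_pk) combined
    with a fixed guarantee timestamp (mvccTs) so every batch observes the same
    snapshot; progress can optionally be checkpointed to a cp file.
    """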
def __init__(
self,
connection: Connections,
collection_name: str,
batch_size: Optional[int] = 1000,
limit: Optional[int] = -1,
expr: Optional[str] = None,
output_fields: Optional[List[str]] = None,
partition_names: Optional[List[str]] = None,
schema: Optional[CollectionSchema] = None,
timeout: Optional[float] = None,
**kwargs,
    ) -> None:
self._conn = connection
self._collection_name = collection_name
self._output_fields = output_fields
self._partition_names = partition_names
self._schema = schema
self._timeout = timeout
self._session_ts = 0
self._kwargs = kwargs
self._kwargs[ITERATOR_FIELD] = "True"
self.__check_set_batch_size(batch_size)
self._limit = limit
self.__check_set_reduce_stop_for_best()
check_set_flag(self, "_print_iterator_cursor", self._kwargs, PRINT_ITERATOR_CURSOR)
self._returned_count = 0
self.__setup__pk_prop()
self.__set_up_expr(expr)
self._next_id = None
self._cache_id_in_use = NO_CACHE_ID
self._cp_file_handler = None
self.__set_up_ts_cp()
self.__seek_to_offset()
def __seek_to_offset(self):
# read pk cursor from cp file, no need to seek offset
if self._next_id is not None:
return
offset = self._kwargs.get(OFFSET, 0)
if offset > 0:
seek_params = self._kwargs.copy()
seek_params[OFFSET] = 0
seek_params[MILVUS_LIMIT] = offset
res = self._conn.query(
collection_name=self._collection_name,
expr=self._expr,
                output_fields=self._output_fields,
                partition_names=self._partition_names,
timeout=self._timeout,
**seek_params,
)
result_index = min(len(res), offset)
self.__update_cursor(res[:result_index])
self._kwargs[OFFSET] = 0
def __init_cp_file_handler(self) -> bool:
mode = "w"
if self._cp_file_path.exists():
mode = "r+"
try:
self._cp_file_handler = self._cp_file_path.open(mode)
except OSError as ose:
raise MilvusException(
message=f"Failed to open cp file for iterator:{self._cp_file_path_str}"
) from ose
return mode == "r+"
def __save_mvcc_ts(self):
assert_info(
self._cp_file_handler is not None,
"Must init cp file handler before saving session_ts",
)
self._cp_file_handler.writelines(str(self._session_ts) + "\n")
def __save_pk_cursor(self):
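        # cp file layout: line 1 stores the session mvccTs; each subsequent line
        # appends the latest pk cursor. Once more than 100 cursor lines
        # accumulate, the file is truncated and rewritten from the session ts.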
if self._need_save_cp and self._next_id is not None:
if not self._cp_file_path.exists():
self._cp_file_handler.close()
self._cp_file_handler = self._cp_file_path.open("w")
self._buffer_cursor_lines_number = 0
self.__save_mvcc_ts()
                log.warning(
                    "iterator cp file no longer exists; recreating it for this iteration, "
                    "do not remove this file manually!"
                )
            if self._buffer_cursor_lines_number >= 100:
                self._cp_file_handler.seek(0)
                self._cp_file_handler.truncate()
                log.info(
                    "cursor lines in cp file exceeded 100; truncating the file and rewriting"
                )
                self._buffer_cursor_lines_number = 0
                # rewrite the session ts so the cp file's first line stays valid after truncation
                self.__save_mvcc_ts()
self._cp_file_handler.writelines(str(self._next_id) + "\n")
self._cp_file_handler.flush()
self._buffer_cursor_lines_number += 1
def __check_set_reduce_stop_for_best(self):
if self._kwargs.get(REDUCE_STOP_FOR_BEST, True):
self._kwargs[REDUCE_STOP_FOR_BEST] = "True"
else:
self._kwargs[REDUCE_STOP_FOR_BEST] = "False"
def __check_set_batch_size(self, batch_size: int):
if batch_size < 0:
raise ParamError(message="batch size cannot be less than zero")
if batch_size > MAX_BATCH_SIZE:
raise ParamError(message=f"batch size cannot be larger than {MAX_BATCH_SIZE}")
self._kwargs[BATCH_SIZE] = batch_size
self._kwargs[MILVUS_LIMIT] = batch_size
    # relies on the pk prop, so this method must be called after __setup__pk_prop
def __set_up_expr(self, expr: str):
if expr is not None:
self._expr = expr
elif self._pk_str:
self._expr = self._pk_field_name + ' != ""'
else:
self._expr = self._pk_field_name + " < " + str(INT64_MAX)
def __setup_ts_by_request(self):
init_ts_kwargs = self._kwargs.copy()
init_ts_kwargs[OFFSET] = 0
init_ts_kwargs[MILVUS_LIMIT] = 1
        # only to set up mvccTs for the iterator; the limit need not be correct
res = self._conn.query(
collection_name=self._collection_name,
expr=self._expr,
            output_fields=self._output_fields,
            partition_names=self._partition_names,
timeout=self._timeout,
**init_ts_kwargs,
)
if res is None:
raise MilvusException(
message="failed to connect to milvus for setting up "
"mvccTs, check milvus servers' status"
)
if res.extra is not None:
self._session_ts = res.extra.get(ITERATOR_SESSION_TS_FIELD, 0)
if self._session_ts <= 0:
log.warning("failed to get mvccTs from milvus server, use client-side ts instead")
self._session_ts = fall_back_to_latest_session_ts()
self._kwargs[GUARANTEE_TIMESTAMP] = self._session_ts
def __set_up_ts_cp(self):
self._buffer_cursor_lines_number = 0
self._cp_file_path_str = self._kwargs.get(ITERATOR_SESSION_CP_FILE, None)
self._cp_file_path = None
# no input cp_file, set up mvccTs by query request
if self._cp_file_path_str is None:
self._need_save_cp = False
self.__setup_ts_by_request()
else:
self._need_save_cp = True
self._cp_file_path = Path(self._cp_file_path_str)
if not self.__init_cp_file_handler():
# input cp file is empty, set up mvccTs by query request
self.__setup_ts_by_request()
io_operation(self.__save_mvcc_ts, "Failed to save mvcc ts")
else:
try:
                    # input cp file is not empty, init mvccTs by reading cp file
lines = self._cp_file_handler.readlines()
line_count = len(lines)
if line_count < 2:
raise ParamError(
message=f"input cp file:{self._cp_file_path_str} should contain "
f"at least two lines, but only:{line_count} lines"
)
self._session_ts = int(lines[0])
self._kwargs[GUARANTEE_TIMESTAMP] = self._session_ts
                    # line_count >= 2 is guaranteed here; the remaining lines are
                    # buffered pk cursors, and the last one is the current cursor
                    self._buffer_cursor_lines_number = line_count - 1
                    self._next_id = lines[self._buffer_cursor_lines_number].strip()
except OSError as ose:
raise MilvusException(
message=f"Failed to read cp info from file:{self._cp_file_path_str}"
) from ose
except ValueError as e:
raise ParamError(message=f"cannot parse input cp session_ts:{lines[0]}") from e
def __maybe_cache(self, result: List):
if len(result) < 2 * self._kwargs[BATCH_SIZE]:
return
start = self._kwargs[BATCH_SIZE]
cache_result = result[start:]
cache_id = iterator_cache.cache(cache_result, NO_CACHE_ID)
self._cache_id_in_use = cache_id
def __is_res_sufficient(self, res: List):
return res is not None and len(res) >= self._kwargs[BATCH_SIZE]
def next(self):
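        """Return the next batch of query results.

        Serves from the local cache when it holds at least one full batch;
        otherwise queries the server with a pk-cursor expression, caches any
        surplus rows, and persists the new cursor to the cp file if enabled.
        """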
cached_res = iterator_cache.fetch_cache(self._cache_id_in_use)
ret = None
if self.__is_res_sufficient(cached_res):
ret = cached_res[0 : self._kwargs[BATCH_SIZE]]
res_to_cache = cached_res[self._kwargs[BATCH_SIZE] :]
iterator_cache.cache(res_to_cache, self._cache_id_in_use)
else:
iterator_cache.release_cache(self._cache_id_in_use)
current_expr = self.__setup_next_expr()
if self._print_iterator_cursor:
log.info(f"query_iterator_next_expr:{current_expr}")
res = self._conn.query(
collection_name=self._collection_name,
expr=current_expr,
output_fields=self._output_fields,
partition_names=self._partition_names,
timeout=self._timeout,
**self._kwargs,
)
self.__maybe_cache(res)
ret = res[0 : min(self._kwargs[BATCH_SIZE], len(res))]
ret = self.__check_reached_limit(ret)
self.__update_cursor(ret)
io_operation(self.__save_pk_cursor, "failed to save pk cursor")
self._returned_count += len(ret)
return ret
def __check_reached_limit(self, ret: List):
if self._limit == UNLIMITED:
return ret
left_count = self._limit - self._returned_count
if left_count >= len(ret):
return ret
        # returning everything would exceed the limit; truncate the result
return ret[0:left_count]
    def __setup__pk_prop(self):
        self._pk_field_name = None
        self._pk_str = False
        fields = self._schema[FIELDS]
        for field in fields:
            if field.get(IS_PRIMARY):
                self._pk_str = field["type"] == DataType.VARCHAR
                self._pk_field_name = field["name"]
                break
        if self._pk_field_name is None or self._pk_field_name == "":
            raise MilvusException(message="schema must contain a pk field")
def __setup_next_expr(self) -> str:
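        # paginate by primary key: fetch only entities whose pk is greater than
        # the last one returned, quoting the value for VARCHAR pks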
current_expr = self._expr
if self._next_id is None:
return current_expr
filtered_pk_str = ""
if self._pk_str:
filtered_pk_str = f'{self._pk_field_name} > "{self._next_id}"'
else:
filtered_pk_str = f"{self._pk_field_name} > {self._next_id}"
if current_expr is None or len(current_expr) == 0:
return filtered_pk_str
return "(" + current_expr + ")" + " and " + filtered_pk_str
def __update_cursor(self, res: List) -> None:
if len(res) == 0:
return
self._next_id = res[-1][self._pk_field_name]
def close(self) -> None:
# release cache in use
iterator_cache.release_cache(self._cache_id_in_use)
if self._cp_file_handler is not None:
def inner_close():
self._cp_file_handler.close()
self._cp_file_path.unlink()
log.info(f"removed cp file:{self._cp_file_path_str} for query iterator")
io_operation(
inner_close, f"failed to clear cp file:{self._cp_file_path_str} for query iterator"
)
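# "Positive related" metrics grow with dissimilarity (smaller is closer), e.g. L2;
# similarity metrics such as IP/COSINE/BM25 shrink instead, which flips the
# direction the search radius must move in between iterations.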
def metrics_positive_related(metrics: str) -> bool:
if metrics in [CALC_DIST_L2, CALC_DIST_JACCARD, CALC_DIST_HAMMING, CALC_DIST_TANIMOTO]:
return True
if metrics in [CALC_DIST_IP, CALC_DIST_COSINE, CALC_DIST_BM25]:
return False
raise MilvusException(message=f"unsupported metrics type for search iteration: {metrics}")
class SearchPage(LoopBase):
"""Since we only support nq=1 in search iteration, so search iteration response
should be different from raw response of search operation"""
def __init__(self, res: Hits, session_ts: Optional[int] = 0):
super().__init__()
self._session_ts = session_ts
self._results = []
if res is not None:
self._results.append(res)
def get_session_ts(self):
return self._session_ts
def get_res(self):
return self._results
def __len__(self):
length = 0
for res in self._results:
length += len(res)
return length
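    # get__item is the element-access hook that LoopBase's __getitem__ delegates to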
def get__item(self, idx: Any):
if len(self._results) == 0:
return None
if idx >= self.__len__():
msg = "Index out of range"
raise IndexError(msg)
index = 0
ret = None
for res in self._results:
if index + len(res) <= idx:
index += len(res)
else:
ret = res[idx - index]
break
return ret
def merge(self, others: List[Hits]):
if others is not None:
for other in others:
self._results.append(other)
def ids(self):
ids = []
for res in self._results:
for hit in res:
ids.append(hit.id)
return ids
def distances(self):
distances = []
for res in self._results:
for hit in res:
distances.append(hit.distance)
return distances
class SearchIterator:
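    """Iterate over search results batch by batch for a single query vector.

    Implemented on top of repeated range searches: each round moves a
    [range_filter, radius) ring outward from the previous tail distance, and
    pk-based filtering removes duplicates that share the boundary distance.
    """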
def __init__(
self,
connection: Connections,
collection_name: str,
data: Union[List, utils.SparseMatrixInputType],
ann_field: str,
param: Dict,
batch_size: Optional[int] = 1000,
limit: Optional[int] = UNLIMITED,
expr: Optional[str] = None,
partition_names: Optional[List[str]] = None,
output_fields: Optional[List[str]] = None,
timeout: Optional[float] = None,
round_decimal: int = -1,
schema: Optional[CollectionSchema] = None,
**kwargs,
    ) -> None:
rows = entity_helper.get_input_num_rows(data)
if rows > 1:
            raise ParamError(
                message="Search iteration over multiple vectors is not supported at present"
            )
if rows == 0:
raise ParamError(message="vector_data for search cannot be empty")
self._conn = connection
self._iterator_params = {
"collection_name": collection_name,
"data": data,
"ann_field": ann_field,
BATCH_SIZE: batch_size,
"output_fields": output_fields,
"partition_names": partition_names,
"timeout": timeout,
"round_decimal": round_decimal,
}
self._expr = expr
self.__check_set_params(param)
self.__check_for_special_index_param()
self._kwargs = kwargs
self._kwargs[ITERATOR_FIELD] = "True"
self._filtered_ids = []
self._filtered_distance = None
self._schema = schema
self._limit = limit
self._returned_count = 0
self.__check_metrics()
self.__check_offset()
self.__check_rm_range_search_parameters()
self.__setup__pk_prop()
check_set_flag(self, "_print_iterator_cursor", self._kwargs, PRINT_ITERATOR_CURSOR)
self.__init_search_iterator()
def __init_search_iterator(self):
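        # the first search fixes the session mvccTs and seeds the width/tail-band
        # range parameters from the distance spread of the initial page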
init_page = self.__execute_next_search(self._param, self._expr, False)
self._session_ts = init_page.get_session_ts()
if self._session_ts <= 0:
log.warning("failed to set up mvccTs from milvus server, use client-side ts instead")
self._session_ts = fall_back_to_latest_session_ts()
self._kwargs[GUARANTEE_TIMESTAMP] = self._session_ts
if len(init_page) == 0:
            message = (
                "Cannot init search iterator because the init page contains no matched rows; "
                "please check the radius and range_filter settings in the search params"
            )
LOGGER.error(message)
self._cache_id = NO_CACHE_ID
self._init_success = False
return
self._cache_id = iterator_cache.cache(init_page, NO_CACHE_ID)
self.__set_up_range_parameters(init_page)
self.__update_filtered_ids(init_page)
self._init_success = True
def __update_width(self, page: SearchPage):
first_hit, last_hit = page[0], page[-1]
if metrics_positive_related(self._param[METRIC_TYPE]):
self._width = last_hit.distance - first_hit.distance
else:
self._width = first_hit.distance - last_hit.distance
        if self._width == 0.0:
            # enforce a minimum width to avoid radius == range_filter errors
            self._width = 0.05
def __set_up_range_parameters(self, page: SearchPage):
self.__update_width(page)
self._tail_band = page[-1].distance
LOGGER.debug(
f"set up init parameter for searchIterator width:{self._width} tail_band:{self._tail_band}"
)
def __check_reached_limit(self) -> bool:
if self._limit == UNLIMITED or self._returned_count < self._limit:
return False
LOGGER.debug(
f"reached search limit:{self._limit}, returned_count:{self._returned_count}, directly return"
)
return True
def __check_set_params(self, param: Dict):
if param is None:
self._param = {}
else:
self._param = deepcopy(param)
if PARAMS not in self._param:
self._param[PARAMS] = {}
def __check_for_special_index_param(self):
if (
EF in self._param[PARAMS]
and self._param[PARAMS][EF] < self._iterator_params[BATCH_SIZE]
):
            raise MilvusException(
                message="When using an HNSW index, the provided ef must be "
                "greater than or equal to the batch size"
            )
    def __setup__pk_prop(self):
        self._pk_field_name = None
        self._pk_str = False
        fields = self._schema[FIELDS]
        for field in fields:
            if field.get(IS_PRIMARY):
                self._pk_str = field["type"] == DataType.VARCHAR
                self._pk_field_name = field["name"]
                break
        if self._pk_field_name is None or self._pk_field_name == "":
            raise ParamError(message="schema must contain a pk field")
    def __check_metrics(self):
        if self._param.get(METRIC_TYPE) is None or self._param[METRIC_TYPE] == "":
            raise ParamError(message="must specify metrics type for search iterator")
"""we use search && range search to implement search iterator,
so range search parameters are disabled to clients"""
def __check_rm_range_search_parameters(self):
if (
(PARAMS in self._param)
and (RADIUS in self._param[PARAMS])
and (RANGE_FILTER in self._param[PARAMS])
):
radius = self._param[PARAMS][RADIUS]
range_filter = self._param[PARAMS][RANGE_FILTER]
if metrics_positive_related(self._param[METRIC_TYPE]) and radius <= range_filter:
raise MilvusException(
message=f"for metrics:{self._param[METRIC_TYPE]}, radius must be "
f"larger than range_filter, please adjust your parameter"
)
if not metrics_positive_related(self._param[METRIC_TYPE]) and radius >= range_filter:
raise MilvusException(
message=f"for metrics:{self._param[METRIC_TYPE]}, radius must be "
f"smalled than range_filter, please adjust your parameter"
)
def __check_offset(self):
if self._kwargs.get(OFFSET, 0) != 0:
raise ParamError(message="Not support offset when searching iteration")
def __update_filtered_ids(self, res: SearchPage):
if len(res) == 0:
return
last_hit = res[-1]
if last_hit is None:
return
if last_hit.distance != self._filtered_distance:
self._filtered_ids = [] # distance has changed, clear filter_ids array
self._filtered_distance = last_hit.distance # renew the distance for filtering
for hit in res:
if hit.distance == last_hit.distance:
self._filtered_ids.append(hit.id)
if len(self._filtered_ids) > MAX_FILTERED_IDS_COUNT_ITERATION:
raise MilvusException(
message=f"filtered ids length has accumulated to more than "
f"{MAX_FILTERED_IDS_COUNT_ITERATION!s}, "
f"there is a danger of overly memory consumption"
)
def __is_cache_enough(self, count: int) -> bool:
cached_page = iterator_cache.fetch_cache(self._cache_id)
return cached_page is not None and len(cached_page) >= count
def __extract_page_from_cache(self, count: int) -> SearchPage:
cached_page = iterator_cache.fetch_cache(self._cache_id)
if cached_page is None or len(cached_page) < count:
raise ParamError(
message=f"Wrong, try to extract {count} result from cache, "
f"more than {len(cached_page)} there must be sth wrong with code"
)
ret_page_res = cached_page[0:count]
ret_page = SearchPage(ret_page_res)
left_cache_page = SearchPage(cached_page[count:])
iterator_cache.cache(left_cache_page, self._cache_id)
return ret_page
def __push_new_page_to_cache(self, page: SearchPage) -> int:
if page is None:
raise ParamError(message="Cannot push None page into cache")
cached_page: SearchPage = iterator_cache.fetch_cache(self._cache_id)
if cached_page is None:
iterator_cache.cache(page, self._cache_id)
cached_page = page
else:
cached_page.merge(page.get_res())
return len(cached_page)
def next(self):
# 0. check reached limit
if not self._init_success or self.__check_reached_limit():
return SearchPage(None)
ret_len = self._iterator_params[BATCH_SIZE]
        if self._limit != UNLIMITED:
left_len = self._limit - self._returned_count
ret_len = min(left_len, ret_len)
# 1. if cached page is sufficient, directly return
if self.__is_cache_enough(ret_len):
ret_page = self.__extract_page_from_cache(ret_len)
self._returned_count += len(ret_page)
return ret_page
# 2. if cached page not enough, try to fill the result by probing with constant width
# until finish filling or exceeding max trial time: 10
new_page = self.__try_search_fill()
cached_page_len = self.__push_new_page_to_cache(new_page)
ret_len = min(cached_page_len, ret_len)
ret_page = self.__extract_page_from_cache(ret_len)
if len(ret_page) == self._iterator_params[BATCH_SIZE]:
self.__update_width(ret_page)
        # 3. update the returned count (filtered ids are maintained inside __try_search_fill)
self._returned_count += ret_len
return ret_page
def __try_search_fill(self) -> SearchPage:
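        # widen the probe ring by an increasing coefficient until a full batch
        # is collected or MAX_TRY_TIME rounds have elapsed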
final_page = SearchPage(None)
try_time = 0
coefficient = 1
while True:
next_params = self.__next_params(coefficient)
next_expr = self.__filtered_duplicated_result_expr(self._expr)
new_page = self.__execute_next_search(next_params, next_expr, True)
self.__update_filtered_ids(new_page)
try_time += 1
if len(new_page) > 0:
final_page.merge(new_page.get_res())
self._tail_band = new_page[-1].distance
if len(final_page) >= self._iterator_params[BATCH_SIZE]:
break
if try_time > MAX_TRY_TIME:
LOGGER.warning(f"Search probe exceed max try times:{MAX_TRY_TIME} directly break")
break
            # if a ring contains no matched vectors, keep widening it
            # to avoid the empty-ring problem
coefficient += 1
return final_page
def __execute_next_search(
self, next_params: dict, next_expr: str, to_extend_batch: bool
) -> SearchPage:
if self._print_iterator_cursor:
log.info(f"search_iterator_next_expr:{next_expr}, next_params:{next_params}")
res = self._conn.search(
self._iterator_params["collection_name"],
self._iterator_params["data"],
self._iterator_params["ann_field"],
next_params,
extend_batch_size(self._iterator_params[BATCH_SIZE], next_params, to_extend_batch),
next_expr,
self._iterator_params["partition_names"],
self._iterator_params["output_fields"],
self._iterator_params["round_decimal"],
timeout=self._iterator_params["timeout"],
schema=self._schema,
**self._kwargs,
)
return SearchPage(res[0], res.get_session_ts())
    # At present, the range_filter parameter means 'larger/less than or equal',
    # so vectors with identical distances could be returned repeatedly across pages;
    # we need to filter out such duplicates before returning.
    def __filtered_duplicated_result_expr(self, expr: str):
        if len(self._filtered_ids) == 0:
            return expr
        if self._pk_str:
            filtered_ids_str = ",".join(f'"{fid}"' for fid in self._filtered_ids)
        else:
            filtered_ids_str = ",".join(str(fid) for fid in self._filtered_ids)
        not_in_expr = f"{self._pk_field_name} not in [{filtered_ids_str}]"
        if expr is not None and len(expr) > 0:
            return f"({expr}) and {not_in_expr}"
        return not_in_expr
def __next_params(self, coefficient: int):
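        # move the ring outward from the current tail distance by width * coefficient,
        # clamping at the user-supplied radius (if any); range_filter pins the inner edge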
coefficient = max(1, coefficient)
next_params = deepcopy(self._param)
if metrics_positive_related(self._param[METRIC_TYPE]):
next_radius = self._tail_band + self._width * coefficient
if RADIUS in self._param[PARAMS] and next_radius > self._param[PARAMS][RADIUS]:
next_params[PARAMS][RADIUS] = self._param[PARAMS][RADIUS]
else:
next_params[PARAMS][RADIUS] = next_radius
else:
next_radius = self._tail_band - self._width * coefficient
if RADIUS in self._param[PARAMS] and next_radius < self._param[PARAMS][RADIUS]:
next_params[PARAMS][RADIUS] = self._param[PARAMS][RADIUS]
else:
next_params[PARAMS][RADIUS] = next_radius
next_params[PARAMS][RANGE_FILTER] = self._tail_band
LOGGER.debug(
f"next round search iteration radius:{next_params[PARAMS][RADIUS]},"
f"range_filter:{next_params[PARAMS][RANGE_FILTER]},"
f"coefficient:{coefficient}"
)
return next_params
def close(self):
iterator_cache.release_cache(self._cache_id)
class IteratorCache:
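    """In-memory store for surplus results fetched ahead of what next() returns."""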
def __init__(self) -> None:
self._cache_id = 0
self._cache_map = {}
def cache(self, result: Any, cache_id: int):
if cache_id == NO_CACHE_ID:
self._cache_id += 1
cache_id = self._cache_id
self._cache_map[cache_id] = result
return cache_id
def fetch_cache(self, cache_id: int):
return self._cache_map.get(cache_id, None)
def release_cache(self, cache_id: int):
if self._cache_map.get(cache_id, None) is not None:
self._cache_map.pop(cache_id)
NO_CACHE_ID = -1
# module-level singleton cache shared by all iterators
iterator_cache = IteratorCache()