# Copyright (C) 2019-2023 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under
# the License.
import json
import logging
from threading import Lock
from typing import Optional
import numpy as np
from pymilvus.client.types import DataType
from pymilvus.exceptions import MilvusException
from pymilvus.orm.schema import CollectionSchema, FieldSchema
from .buffer import (
Buffer,
)
from .constants import (
TYPE_SIZE,
TYPE_VALIDATOR,
BulkFileType,
)
logger = logging.getLogger("bulk_writer")
logger.setLevel(logging.DEBUG)
class BulkWriter:
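    """Base writer that validates rows against a CollectionSchema and accumulates them
    into an in-memory Buffer for Milvus bulk import.

    Minimal usage sketch (illustrative only; in practice the concrete subclasses such as
    LocalBulkWriter/RemoteBulkWriter are used, and their exact constructor arguments may
    differ between pymilvus versions):

        writer = LocalBulkWriter(
            schema=schema,                    # a CollectionSchema with a primary field
            local_path="/tmp/bulk_data",      # hypothetical output directory
            chunk_size=512 * 1024 * 1024,     # flush threshold in bytes
            file_type=BulkFileType.PARQUET,
        )
        writer.append_row({"id": 1, "vector": [0.1] * 128})
        writer.commit()
    """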
def __init__(
self,
schema: CollectionSchema,
chunk_size: int,
file_type: BulkFileType,
config: Optional[dict] = None,
**kwargs,
):
self._schema = schema
self._buffer_size = 0
self._buffer_row_count = 0
self._total_row_count = 0
self._file_type = file_type
self._buffer_lock = Lock()
self._config = config
        # the old parameter segment_size was renamed to chunk_size; keep accepting it for backward compatibility
self._chunk_size = chunk_size
segment_size = kwargs.get("segment_size", 0)
if segment_size > 0:
self._chunk_size = segment_size
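        # e.g. (illustrative) a legacy caller passing segment_size still works:
        #   BulkWriter(schema=schema, chunk_size=0, file_type=BulkFileType.NUMPY, segment_size=512 * 1024 * 1024)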
if len(self._schema.fields) == 0:
self._throw("collection schema fields list is empty")
if self._schema.primary_field is None:
self._throw("primary field is null")
self._buffer = None
self._new_buffer()
@property
def buffer_size(self):
return self._buffer_size
@property
def buffer_row_count(self):
return self._buffer_row_count
@property
def total_row_count(self):
return self._total_row_count
@property
def chunk_size(self):
return self._chunk_size
def _new_buffer(self):
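        """Swap in a fresh Buffer under the lock and return the old one so the caller can flush it."""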
old_buffer = self._buffer
with self._buffer_lock:
self._buffer = Buffer(self._schema, self._file_type, self._config)
return old_buffer
def append_row(self, row: dict, **kwargs):
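        """Validate a single row dict against the schema and add it to the in-memory buffer.

        Example (illustrative; field names depend on your schema):
            writer.append_row({"id": 1, "vector": [0.1] * 128})
        """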
self._verify_row(row)
with self._buffer_lock:
self._buffer.append_row(row)
def commit(self, **kwargs):
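        """Reset the in-memory counters; concrete writers override this to flush the buffer to files."""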
with self._buffer_lock:
self._buffer_size = 0
self._buffer_row_count = 0
@property
def data_path(self):
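        """Location of the generated data files; the base class returns an empty string and concrete writers override this."""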
return ""
def _try_convert_json(self, field_name: str, obj: object):
if isinstance(obj, str):
try:
return json.loads(obj)
except Exception as e:
self._throw(
f"Illegal JSON value for field '{field_name}', type mismatch or illegal format, error: {e}"
)
return obj
def _throw(self, msg: str):
logger.error(msg)
raise MilvusException(message=msg)
def _verify_vector(self, x: object, field: FieldSchema):
dtype = DataType(field.dtype)
validator = TYPE_VALIDATOR[dtype.name]
if dtype != DataType.SPARSE_FLOAT_VECTOR:
dim = field.params["dim"]
try:
origin_list = validator(x, dim)
                if dtype == DataType.FLOAT_VECTOR:
                    return origin_list, dim * 4  # for float vectors, each dimension occupies 4 bytes
                if dtype == DataType.BINARY_VECTOR:
                    return origin_list, dim // 8  # for binary vectors, every 8 dimensions occupy 1 byte
                return origin_list, dim * 2  # for float16/bfloat16 vectors, each dimension occupies 2 bytes
except MilvusException as e:
self._throw(f"Illegal vector data for vector field: '{field.name}': {e.message}")
else:
try:
validator(x)
                return x, len(x) * 12  # rough estimate for sparse vectors: each entry is an int index plus a float value, ~12 bytes
except MilvusException as e:
self._throw(f"Illegal vector data for vector field: '{field.name}': {e.message}")
def _verify_json(self, x: object, field: FieldSchema):
size = 0
validator = TYPE_VALIDATOR[DataType.JSON.name]
if isinstance(x, str):
size = len(x)
x = self._try_convert_json(field.name, x)
elif validator(x):
size = len(json.dumps(x))
else:
self._throw(f"Illegal JSON value for field '{field.name}', type mismatch")
return x, size
def _verify_varchar(self, x: object, field: FieldSchema):
max_len = field.params["max_length"]
validator = TYPE_VALIDATOR[DataType.VARCHAR.name]
if not validator(x, max_len):
            self._throw(
                f"Illegal varchar value for field '{field.name}',"
                f" length exceeds {max_len} or type mismatch"
            )
return len(x)
def _verify_array(self, x: object, field: FieldSchema):
max_capacity = field.params["max_capacity"]
element_type = field.element_type
validator = TYPE_VALIDATOR[DataType.ARRAY.name]
if not validator(x, max_capacity):
self._throw(
f"Illegal array value for field '{field.name}', length exceeds capacity or type mismatch"
)
row_size = 0
if element_type.name in TYPE_SIZE:
row_size = TYPE_SIZE[element_type.name] * len(x)
elif element_type == DataType.VARCHAR:
for ele in x:
row_size = row_size + self._verify_varchar(ele, field)
else:
self._throw(f"Unsupported element type for array field '{field.name}'")
return row_size
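    # e.g. (illustrative) for _verify_array above, an ARRAY of INT64 with 5 elements
    # contributes 5 * 8 = 40 bytes to the estimated row size.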
def _verify_row(self, row: dict):
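        """Validate and normalize a row in place against the schema, then update the
        buffer size / row count statistics with an estimated row size in bytes."""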
if not isinstance(row, dict):
self._throw("The input row must be a dict object")
row_size = 0
for field in self._schema.fields:
if field.is_primary and field.auto_id:
if field.name in row:
self._throw(
f"The primary key field '{field.name}' is auto-id, no need to provide"
)
else:
continue
if field.is_function_output:
if field.name in row:
self._throw(f"Field '{field.name}' is function output, no need to provide")
else:
continue
if field.name not in row:
self._throw(f"The field '{field.name}' is missed in the row")
dtype = DataType(field.dtype)
# deal with null (None)
if field.nullable and row[field.name] is None:
if (
field.default_value is not None
and field.default_value.WhichOneof("data") is not None
):
# set default value
data_type = field.default_value.WhichOneof("data")
row[field.name] = getattr(field.default_value, data_type)
else:
# skip field check if the field is null
continue
if dtype in {
DataType.BINARY_VECTOR,
DataType.FLOAT_VECTOR,
DataType.FLOAT16_VECTOR,
DataType.BFLOAT16_VECTOR,
DataType.SPARSE_FLOAT_VECTOR,
}:
origin_list, byte_len = self._verify_vector(row[field.name], field)
row[field.name] = origin_list
row_size = row_size + byte_len
elif dtype == DataType.VARCHAR:
row_size = row_size + self._verify_varchar(row[field.name], field)
elif dtype == DataType.JSON:
row[field.name], size = self._verify_json(row[field.name], field)
row_size = row_size + size
elif dtype == DataType.ARRAY:
if isinstance(row[field.name], np.ndarray):
row[field.name] = row[field.name].tolist()
row_size = row_size + self._verify_array(row[field.name], field)
else:
if isinstance(row[field.name], np.generic):
row[field.name] = row[field.name].item()
validator = TYPE_VALIDATOR[dtype.name]
if not validator(row[field.name]):
self._throw(
f"Illegal scalar value for field '{field.name}', value overflow or type mismatch"
)
row_size = row_size + TYPE_SIZE[dtype.name]
with self._buffer_lock:
self._buffer_size = self._buffer_size + row_size
self._buffer_row_count = self._buffer_row_count + 1
self._total_row_count = self._total_row_count + 1