# Copyright (C) 2019-2023 Zilliz. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. import logging import threading import time import uuid from pathlib import Path from threading import Lock, Thread from typing import Callable, Optional from pymilvus.orm.schema import CollectionSchema from .bulk_writer import BulkWriter from .constants import ( MB, BulkFileType, ) logger = logging.getLogger("local_bulk_writer") logger.setLevel(logging.DEBUG) class LocalBulkWriter(BulkWriter): def __init__( self, schema: CollectionSchema, local_path: str, chunk_size: int = 128 * MB, file_type: BulkFileType = BulkFileType.PARQUET, config: Optional[dict] = None, **kwargs, ): super().__init__(schema, chunk_size, file_type, config, **kwargs) self._local_path = local_path self._uuid = str(uuid.uuid4()) self._flush_count = 0 self._working_thread = {} self._working_thread_lock = Lock() self._local_files = [] self._make_dir() @property def uuid(self): return self._uuid def __enter__(self): return self def __exit__(self, exc_type: object, exc_val: object, exc_tb: object): self._exit() def __del__(self): self._exit() def _exit(self): # wait flush thread if len(self._working_thread) > 0: for k, th in self._working_thread.items(): logger.info(f"Wait flush thread '{k}' to finish") th.join() self._rm_dir() def _make_dir(self): Path(self._local_path).mkdir(exist_ok=True) logger.info(f"Data path created: {self._local_path}") uidir = Path(self._local_path).joinpath(self._uuid) self._local_path = uidir Path(uidir).mkdir(exist_ok=True) logger.info(f"Data path created: {uidir}") def _rm_dir(self): # remove the uuid folder if it is empty if Path(self._local_path).exists() and not any(Path(self._local_path).iterdir()): Path(self._local_path).rmdir() logger.info(f"Delete local directory '{self._local_path}'") def append_row(self, row: dict, **kwargs): super().append_row(row, **kwargs) # only one thread can enter this section to persist data, # in the _flush() method, the buffer will be swapped to a new one. # in anync mode, the flush thread is asynchronously, other threads can # continue to append if the new buffer size is less than target size with self._working_thread_lock: if super().buffer_size > super().chunk_size: self.commit(_async=True) def commit(self, **kwargs): # _async=True, the flush thread is asynchronously while len(self._working_thread) > 0: logger.info( f"Previous flush action is not finished, {threading.current_thread().name} is waiting..." ) time.sleep(1.0) logger.info( f"Prepare to flush buffer, row_count: {super().buffer_row_count}, size: {super().buffer_size}" ) _async = kwargs.get("_async", False) call_back = kwargs.get("call_back") x = Thread(target=self._flush, args=(call_back,)) logger.info(f"Flush thread begin, name: {x.name}") self._working_thread[x.name] = x x.start() if not _async: logger.info("Wait flush to finish") x.join() super().commit() # reset the buffer size logger.info(f"Commit done with async={_async}") def _flush(self, call_back: Optional[Callable] = None): try: self._flush_count = self._flush_count + 1 target_path = Path.joinpath(self._local_path, str(self._flush_count)) old_buffer = super()._new_buffer() if old_buffer.row_count > 0: file_list = old_buffer.persist( local_path=str(target_path), buffer_size=self.buffer_size, buffer_row_count=self.buffer_row_count, ) self._local_files.append(file_list) if call_back: call_back(file_list) except Exception as e: logger.error(f"Failed to fulsh, error: {e}") raise e from e finally: del self._working_thread[threading.current_thread().name] logger.info(f"Flush thread finished, name: {threading.current_thread().name}") @property def data_path(self): return self._local_path @property def batch_files(self): return self._local_files
Memory