""" Implements a thread pool for parallel copying of files. """ from __future__ import unicode_literals import typing import threading from six.moves.queue import Queue from .copy import copy_file_internal, copy_modified_time from .errors import BulkCopyFailed from .tools import copy_file_data if typing.TYPE_CHECKING: from typing import IO, List, Optional, Text, Tuple, Type from types import TracebackType from .base import FS class _Worker(threading.Thread): """Worker thread that pulls tasks from a queue.""" def __init__(self, copier): # type (Copier) -> None self.copier = copier super(_Worker, self).__init__() self.daemon = True def run(self): # type () -> None queue = self.copier.queue while True: task = queue.get(block=True) try: if task is None: break # Sentinel to exit thread task() except Exception as error: self.copier.add_error(error) finally: queue.task_done() class _Task(object): """Base class for a task.""" def __call__(self): # type: () -> None """Task implementation.""" class _CopyTask(_Task): """A callable that copies from one file another.""" def __init__(self, src_file, dst_file): # type: (IO, IO) -> None self.src_file = src_file self.dst_file = dst_file def __call__(self): # type: () -> None try: copy_file_data(self.src_file, self.dst_file, chunk_size=1024 * 1024) finally: try: self.src_file.close() finally: self.dst_file.close() class Copier(object): """Copy files in worker threads.""" def __init__(self, num_workers=4, preserve_time=False): # type: (int, bool) -> None if num_workers < 0: raise ValueError("num_workers must be >= 0") self.num_workers = num_workers self.preserve_time = preserve_time self.all_tasks = [] # type: List[Tuple[FS, Text, FS, Text]] self.queue = None # type: Optional[Queue[_Task]] self.workers = [] # type: List[_Worker] self.errors = [] # type: List[Exception] self.running = False def start(self): """Start the workers.""" if self.num_workers: self.queue = Queue(maxsize=self.num_workers) self.workers = [_Worker(self) for _ in range(self.num_workers)] for worker in self.workers: worker.start() self.running = True def stop(self): """Stop the workers (will block until they are finished).""" if self.running and self.num_workers: # Notify the workers that all tasks have arrived # and wait for them to finish. for _worker in self.workers: self.queue.put(None) for worker in self.workers: worker.join() # If the "last modified" time is to be preserved, do it now. if self.preserve_time: for args in self.all_tasks: copy_modified_time(*args) # Free up references held by workers del self.workers[:] self.queue.join() self.running = False def add_error(self, error): """Add an exception raised by a task.""" self.errors.append(error) def __enter__(self): self.start() return self def __exit__( self, exc_type, # type: Optional[Type[BaseException]] exc_value, # type: Optional[BaseException] traceback, # type: Optional[TracebackType] ): self.stop() if traceback is None and self.errors: raise BulkCopyFailed(self.errors) def copy(self, src_fs, src_path, dst_fs, dst_path, preserve_time=False): # type: (FS, Text, FS, Text, bool) -> None """Copy a file from one fs to another.""" if self.queue is None: # This should be the most performant for a single-thread copy_file_internal( src_fs, src_path, dst_fs, dst_path, preserve_time=self.preserve_time ) else: self.all_tasks.append((src_fs, src_path, dst_fs, dst_path)) src_file = src_fs.openbin(src_path, "r") try: dst_file = dst_fs.openbin(dst_path, "w") except Exception: src_file.close() raise task = _CopyTask(src_file, dst_file) self.queue.put(task)
Memory