import os
import torch
import __main__

from dataclasses import dataclass

from colbert.utils.utils import timestamp

from .core_config import DefaultVal


@dataclass
class RunSettings:
    """
    The defaults here have a special status in Run(), which initially calls assign_defaults(),
    so these aren't soft defaults in that specific context.
    """

    overwrite: bool = DefaultVal(False)

    # Base output directory; os.getcwd() is evaluated once, when this class body runs.
    root: str = DefaultVal(os.path.join(os.getcwd(), "experiments"))
    experiment: str = DefaultVal("default")

    # When falsy, `index_root_` falls back to <root>/<experiment>/indexes/.
    index_root: str = DefaultVal(None)
    # timestamp(daydir=True) is evaluated once, at class-definition (import) time.
    name: str = DefaultVal(timestamp(daydir=True))

    rank: int = DefaultVal(0)
    nranks: int = DefaultVal(1)
    amp: bool = DefaultVal(True)

    # No annotation, so this is a plain class attribute (not a dataclass field),
    # computed once at import time.
    total_visible_gpus = torch.cuda.device_count()
    # May be an int (device count), a comma-separated string, or an iterable of
    # indices -- see `gpus_` for how each form is normalized.
    gpus: int = DefaultVal(total_visible_gpus)

    avoid_fork_if_possible: bool = DefaultVal(False)

    @property
    def gpus_(self):
        """Normalize `gpus` into a sorted, de-duplicated list of device indices.

        An int N means devices 0..N-1; a string is parsed as comma-separated
        indices (empty pieces are skipped); anything else is treated as an
        iterable of values convertible with int().

        Raises:
            AssertionError: if any index is outside [0, total_visible_gpus).
        """
        value = self.gpus

        if isinstance(value, int):
            value = list(range(value))

        if isinstance(value, str):
            # Skip empty pieces so "" or a trailing comma ("0,1,") does not make int() fail.
            value = [piece for piece in value.split(",") if piece.strip()]

        value = sorted(set(map(int, value)))

        # NOTE(review): `assert` is stripped under `python -O`; it is kept so the
        # exception type seen by existing callers does not change.
        assert all(
            0 <= device_idx < self.total_visible_gpus for device_idx in value
        ), value

        return value

    @property
    def index_root_(self):
        """Return `index_root`, or <root>/<experiment>/indexes/ when it is unset/falsy."""
        return self.index_root or os.path.join(self.root, self.experiment, "indexes/")

    @property
    def script_name_(self):
        """Derive a dotted, module-like name for the running __main__ script.

        The script path is made relative to the current working directory when
        it lies inside it; otherwise relative to the common ancestor it shares
        with `root`. Returns "none" when __main__ has no __file__ attribute.
        """
        if not hasattr(__main__, "__file__"):
            return "none"

        cwd = os.path.abspath(os.getcwd())
        script_path = os.path.abspath(__main__.__file__)
        root_path = os.path.abspath(self.root)

        # Compare against cwd *plus a trailing separator*: a plain prefix test would
        # wrongly treat a sibling such as /a/bc/x.py as living inside cwd=/a/b.
        if script_path.startswith(os.path.join(cwd, "")):
            script_path = script_path[len(cwd) :]
        else:
            try:
                commonpath = os.path.commonpath([script_path, root_path])
                script_path = script_path[len(commonpath) :]
            except ValueError:
                # commonpath raises ValueError when the paths share no common root
                # (e.g. different drives on Windows); keep the absolute path then.
                pass

        assert script_path.endswith(".py")
        # Normalize platform separators so the dotted name is built the same way everywhere.
        script_path = script_path.replace(os.sep, "/")
        script_name = script_path.replace("/", ".").strip(".")[:-3]

        assert len(script_name) > 0, (script_name, script_path, cwd)

        return script_name

    @property
    def path_(self):
        """Return <root>/<experiment>/<script_name_>/<name>."""
        return os.path.join(self.root, self.experiment, self.script_name_, self.name)

    @property
    def device_(self):
        """Return the device index assigned to this rank.

        Indexes `gpus_` with `rank % nranks`; this raises IndexError if `gpus_`
        has fewer entries than that position (e.g. no visible GPUs).
        """
        return self.gpus_[self.rank % self.nranks]


@dataclass
class TokenizerSettings:
    """Tokenizer marker settings: placeholder token ids and the query/document marker strings."""

    # NOTE(review): presumably the vocabulary entries reused for the [Q]/[D] markers -- confirm.
    query_token_id: str = DefaultVal("[unused0]")
    doc_token_id: str = DefaultVal("[unused1]")
    query_token: str = DefaultVal("[Q]")
    doc_token: str = DefaultVal("[D]")


@dataclass
class ResourceSettings:
    """Paths/names of external resources; every field defaults to None."""

    checkpoint: str = DefaultVal(None)
    triples: str = DefaultVal(None)
    collection: str = DefaultVal(None)
    queries: str = DefaultVal(None)
    index_name: str = DefaultVal(None)


@dataclass
class DocSettings:
    """Document-side encoding settings."""

    dim: int = DefaultVal(128)
    doc_maxlen: int = DefaultVal(220)
    mask_punctuation: bool = DefaultVal(True)


@dataclass
class QuerySettings:
    """Query-side encoding settings."""

    query_maxlen: int = DefaultVal(32)
    attend_to_mask_tokens: bool = DefaultVal(False)
    interaction: str = DefaultVal("colbert")


@dataclass
class TrainingSettings:
    """Training-related settings."""

    similarity: str = DefaultVal("cosine")

    bsize: int = DefaultVal(32)

    accumsteps: int = DefaultVal(1)

    lr: float = DefaultVal(3e-06)

    maxsteps: int = DefaultVal(500_000)

    save_every: int = DefaultVal(None)

    # NOTE(review): IndexingSettings also declares a `resume` field; if both are
    # mixed into one config class, one will shadow the other -- confirm intent.
    resume: bool = DefaultVal(False)

    ## NEW:
    warmup: int = DefaultVal(None)

    warmup_bert: int = DefaultVal(None)

    relu: bool = DefaultVal(False)

    nway: int = DefaultVal(2)

    use_ib_negatives: bool = DefaultVal(False)

    reranker: bool = DefaultVal(False)

    distillation_alpha: float = DefaultVal(1.0)

    ignore_scores: bool = DefaultVal(False)

    model_name: str = DefaultVal(None)  # DefaultVal('bert-base-uncased')


@dataclass
class IndexingSettings:
    """Index-construction settings."""

    index_path: str = DefaultVal(None)

    index_bsize: int = DefaultVal(64)

    nbits: int = DefaultVal(1)

    kmeans_niters: int = DefaultVal(4)

    resume: bool = DefaultVal(False)

    pool_factor: int = DefaultVal(1)

    clustering_mode: str = DefaultVal("hierarchical")

    protected_tokens: int = DefaultVal(0)

    @property
    def index_path_(self):
        """Return `index_path`, or <index_root_>/<index_name> when it is unset/falsy.

        NOTE(review): `index_root_` and `index_name` are not defined on this class
        (they appear on RunSettings and ResourceSettings); this presumably relies on
        a config class that mixes those in -- confirm.
        """
        return self.index_path or os.path.join(self.index_root_, self.index_name)


@dataclass
class SearchSettings:
    """Search-time settings; all default to None/False."""

    ncells: int = DefaultVal(None)
    centroid_score_threshold: float = DefaultVal(None)
    ndocs: int = DefaultVal(None)
    load_index_with_mmap: bool = DefaultVal(False)
Memory