import os
import torch
import ujson

from huggingface_hub import hf_hub_download
from huggingface_hub.utils import RepositoryNotFoundError

import dataclasses

from typing import Any
from collections import defaultdict
from dataclasses import dataclass, fields
from colbert.utils.utils import timestamp, torch_load_dnn

from utility.utils.save_metadata import get_metadata_only
from .core_config import *


@dataclass
class BaseConfig(CoreConfig):
    """Config with constructors for loading from, and helpers for saving to, disk.

    Extends ``CoreConfig`` (from ``.core_config``). Adds alternate constructors
    for merging existing configs, legacy argument dicts, JSON files, model
    checkpoints and index directories. Also adds helpers for writing the config
    back to disk.
    """

    @classmethod
    def from_existing(cls, *sources):
        """Build a new config by merging the assigned fields of ``sources``.

        ``None`` entries are skipped. For each remaining source, only the keys
        listed in ``source.assigned`` are taken (presumably the fields that were
        explicitly set on it; see ``CoreConfig``). When several sources set the
        same key, the later source wins.
        """
        merged = {}

        for source in sources:
            if source is None:
                continue

            # asdict() deep-copies every field; keep only the assigned ones.
            all_values = dataclasses.asdict(source)
            assigned_values = {k: all_values[k] for k in source.assigned}
            merged = {**merged, **assigned_values}

        return cls(**merged)

    @classmethod
    def from_deprecated_args(cls, args):
        """Create a default config and apply the legacy ``args`` mapping to it.

        Returns:
            ``(config, ignored)``. ``ignored`` is whatever
            ``configure(ignore_unrecognized=True, ...)`` returns, presumably the
            keys it did not recognise (TODO confirm against ``CoreConfig``).
        """
        obj = cls()
        ignored = obj.configure(ignore_unrecognized=True, **args)

        return obj, ignored

    @classmethod
    def from_path(cls, name):
        """Load a config from the JSON file at ``name``.

        If the top-level object has a ``"config"`` key, that sub-object is used
        instead of the whole document.

        Returns:
            ``(config, ignored)``, as returned by ``from_deprecated_args``.
        """
        with open(name, encoding="utf-8") as f:
            args = ujson.load(f)

            if "config" in args:
                args = args["config"]

        return cls.from_deprecated_args(
            args
        )  # the new, non-deprecated version functions the same at this level.

    @classmethod
    def load_from_checkpoint(cls, checkpoint_path):
        """Load the config stored alongside a model checkpoint.

        ``checkpoint_path`` may be any of the following:

        * a legacy ``*.dnn`` file, read via ``torch_load_dnn``;
        * a Hugging Face Hub repo id, from which ``artifact.metadata`` is
          downloaded (best effort);
        * a local directory containing ``artifact.metadata``.

        Returns:
            The loaded config, with its ``checkpoint`` set to the resolved path.
            Returns ``None`` when no ``artifact.metadata`` can be found (e.g. a
            plain model name such as ``'bert-base-uncased'``).
        """
        if checkpoint_path.endswith(".dnn"):
            dnn = torch_load_dnn(checkpoint_path)
            config, _ = cls.from_deprecated_args(dnn.get("arguments", {}))

            # TODO: FIXME: Decide if the line below will have any unintended consequences. We don't want to overwrite those!
            config.set("checkpoint", checkpoint_path)

            return config

        try:
            metadata_file = hf_hub_download(
                repo_id=checkpoint_path, filename="artifact.metadata"
            )
            # Use the directory holding the downloaded file. Keep the trailing
            # separator so the result matches what the former
            # split("artifact")-based logic produced. That logic truncated the
            # path at the first "artifact" anywhere in it (repo or cache dir).
            checkpoint_path = os.path.join(os.path.dirname(metadata_file), "")
        except Exception:
            # Deliberate best effort: local paths and non-hub names are not
            # valid repo ids, so fall through and treat the input as local.
            pass

        loaded_config_path = os.path.join(checkpoint_path, "artifact.metadata")
        if os.path.exists(loaded_config_path):
            loaded_config, _ = cls.from_path(loaded_config_path)
            loaded_config.set("checkpoint", checkpoint_path)

            return loaded_config

        return (
            None  # can happen if checkpoint_path is something like 'bert-base-uncased'
        )

    @classmethod
    def load_from_index(cls, index_path):
        """Load the config saved in an index directory.

        Tries ``metadata.json`` first. If reading or parsing it fails for any
        reason, falls back to ``plan.json``. An error from the fallback
        propagates to the caller.
        """
        # FIXME: We should start here with initial_config = ColBERTConfig(config, Run().config).
        # This should allow us to say initial_config.index_root. Then, below, set config = Config(..., initial_c)

        # default_index_root = os.path.join(Run().root, Run().experiment, 'indexes/')
        # index_path = os.path.join(default_index_root, index_path)

        # CONSIDER: No more plan/metadata.json. Only metadata.json to avoid weird issues when loading an index.

        try:
            metadata_path = os.path.join(index_path, "metadata.json")
            loaded_config, _ = cls.from_path(metadata_path)
        except Exception:
            # Narrowed from a bare ``except:`` so Ctrl-C / SystemExit are not swallowed.
            metadata_path = os.path.join(index_path, "plan.json")
            loaded_config, _ = cls.from_path(metadata_path)

        return loaded_config

    def save(self, path, overwrite=False):
        """Write this config, plus a ``meta`` section, as indented JSON to ``path``.

        Raises:
            AssertionError: if ``path`` already exists and ``overwrite`` is false.
        """
        assert overwrite or not os.path.exists(path), path

        # Build and serialize the payload before opening the file. Opening with
        # "w" truncates immediately, so a failure in export(),
        # get_metadata_only() or dumps() would otherwise leave an empty file.
        args = self.export()  # dict(self.__config)
        args["meta"] = get_metadata_only()
        args["meta"]["version"] = "colbert-v0.4"
        # TODO: Add git_status details.. It can't be too large! It should be a path that Runs() saves on exit, maybe!
        payload = ujson.dumps(args, indent=4) + "\n"

        with open(path, "w", encoding="utf-8") as f:
            f.write(payload)

    def save_for_checkpoint(self, checkpoint_path):
        """Save this config as ``artifact.metadata`` inside ``checkpoint_path``.

        Any existing file is overwritten. ``*.dnn`` paths are rejected (via
        ``assert``) because that name is reserved for the deprecated
        single-file checkpoint format.
        """
        assert not checkpoint_path.endswith(
            ".dnn"
        ), f"{checkpoint_path}: We reserve *.dnn names for the deprecated checkpoint format."

        output_config_path = os.path.join(checkpoint_path, "artifact.metadata")
        self.save(output_config_path, overwrite=True)
Memory