import glob
import json
import os
import shutil
import time
from pathlib import Path
from typing import Optional, Union

import pyarrow as pa

import datasets
import datasets.config
import datasets.data_files
from datasets.naming import camelcase_to_snakecase, filenames_for_dataset_split


logger = datasets.utils.logging.get_logger(__name__)


def _get_modification_time(cached_directory_path):
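    """Return the last modification time of a cached dataset directory (used to pick the most recent entry)."""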
    return Path(cached_directory_path).stat().st_mtime


def _find_hash_in_cache(
dataset_name: str,
config_name: Optional[str],
cache_dir: Optional[str],
config_kwargs: dict,
custom_features: Optional[datasets.Features],
) -> tuple[str, str, str]:
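    """Locate an already prepared dataset in the local cache and return ``(config_name, version, hash)``.

    Cached datasets live under ``<cache_dir>/<namespace>___<dataset_name>/<config_id>/<version>/<hash>``.
    If extra ``config_kwargs`` or ``custom_features`` are passed, a config id is derived from them;
    otherwise any cached config whose ``dataset_info.json`` matches its directory name is considered,
    and the most recently modified matching directory wins.

    Raises:
        ValueError: if no matching cache entry exists, or if several configs match and none was specified.
    """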
if config_name or config_kwargs or custom_features:
config_id = datasets.BuilderConfig(config_name or "default").create_config_id(
config_kwargs=config_kwargs, custom_features=custom_features
)
else:
config_id = None
cache_dir = os.path.expanduser(str(cache_dir or datasets.config.HF_DATASETS_CACHE))
namespace_and_dataset_name = dataset_name.split("/")
namespace_and_dataset_name[-1] = camelcase_to_snakecase(namespace_and_dataset_name[-1])
cached_relative_path = "___".join(namespace_and_dataset_name)
cached_datasets_directory_path_root = os.path.join(cache_dir, cached_relative_path)
cached_directory_paths = [
cached_directory_path
for cached_directory_path in glob.glob(
os.path.join(cached_datasets_directory_path_root, config_id or "*", "*", "*")
)
if os.path.isdir(cached_directory_path)
and (
config_kwargs
or custom_features
or json.loads(Path(cached_directory_path, "dataset_info.json").read_text(encoding="utf-8"))["config_name"]
== Path(cached_directory_path).parts[-3] # no extra params => config_id == config_name
)
]
if not cached_directory_paths:
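        # Nothing matched: list whatever configs exist in the cache so the error message can suggest them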
cached_directory_paths = [
cached_directory_path
for cached_directory_path in glob.glob(os.path.join(cached_datasets_directory_path_root, "*", "*", "*"))
if os.path.isdir(cached_directory_path)
]
available_configs = sorted(
{Path(cached_directory_path).parts[-3] for cached_directory_path in cached_directory_paths}
)
raise ValueError(
f"Couldn't find cache for {dataset_name}"
+ (f" for config '{config_id}'" if config_id else "")
+ (f"\nAvailable configs in the cache: {available_configs}" if available_configs else "")
)
    # Reload the most recently modified matching cache entry
cached_directory_path = Path(sorted(cached_directory_paths, key=_get_modification_time)[-1])
version, hash = cached_directory_path.parts[-2:]
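    # Collect every config that shares this version/hash so an ambiguous reload can be reported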
other_configs = [
Path(_cached_directory_path).parts[-3]
for _cached_directory_path in glob.glob(os.path.join(cached_datasets_directory_path_root, "*", version, hash))
if os.path.isdir(_cached_directory_path)
and (
config_kwargs
or custom_features
or json.loads(Path(_cached_directory_path, "dataset_info.json").read_text(encoding="utf-8"))["config_name"]
== Path(_cached_directory_path).parts[-3] # no extra params => config_id == config_name
)
]
if not config_id and len(other_configs) > 1:
raise ValueError(
f"There are multiple '{dataset_name}' configurations in the cache: {', '.join(other_configs)}"
f"\nPlease specify which configuration to reload from the cache, e.g."
f"\n\tload_dataset('{dataset_name}', '{other_configs[0]}')"
)
config_name = cached_directory_path.parts[-3]
warning_msg = (
f"Found the latest cached dataset configuration '{config_name}' at {cached_directory_path} "
f"(last modified on {time.ctime(_get_modification_time(cached_directory_path))})."
)
logger.warning(warning_msg)
    return config_name, version, hash


class Cache(datasets.ArrowBasedBuilder):
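    """Dataset builder that reloads an already prepared dataset from the local cache.

    Nothing is downloaded or generated: the builder resolves the cached Arrow files under
    ``<cache_dir>/<namespace>___<dataset_name>/<config_name>/<version>/<hash>`` and streams them back.

    Illustrative sketch (the builder is normally created internally by ``load_dataset``; the
    dataset name below is a placeholder for one already present in your cache)::

        builder = Cache(dataset_name="my_dataset", version="auto", hash="auto")
        builder.download_and_prepare()
        ds = builder.as_dataset()
    """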
def __init__(
self,
cache_dir: Optional[str] = None,
dataset_name: Optional[str] = None,
config_name: Optional[str] = None,
version: Optional[str] = "0.0.0",
hash: Optional[str] = None,
base_path: Optional[str] = None,
info: Optional[datasets.DatasetInfo] = None,
features: Optional[datasets.Features] = None,
token: Optional[Union[bool, str]] = None,
repo_id: Optional[str] = None,
data_files: Optional[Union[str, list, dict, datasets.data_files.DataFilesDict]] = None,
data_dir: Optional[str] = None,
storage_options: Optional[dict] = None,
writer_batch_size: Optional[int] = None,
**config_kwargs,
):
if repo_id is None and dataset_name is None:
raise ValueError("repo_id or dataset_name is required for the Cache dataset builder")
if data_files is not None:
config_kwargs["data_files"] = data_files
if data_dir is not None:
config_kwargs["data_dir"] = data_dir
if hash == "auto" and version == "auto":
config_name, version, hash = _find_hash_in_cache(
dataset_name=repo_id or dataset_name,
config_name=config_name,
cache_dir=cache_dir,
config_kwargs=config_kwargs,
custom_features=features,
)
elif hash == "auto" or version == "auto":
raise NotImplementedError("Pass both hash='auto' and version='auto' instead")
super().__init__(
cache_dir=cache_dir,
dataset_name=dataset_name,
config_name=config_name,
version=version,
hash=hash,
base_path=base_path,
info=info,
token=token,
repo_id=repo_id,
storage_options=storage_options,
writer_batch_size=writer_batch_size,
        )

    def _info(self) -> datasets.DatasetInfo:
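        # Return an empty placeholder: the real DatasetInfo is restored from the dataset_info.json
        # saved in the cache directory (handled by the base DatasetBuilder).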
        return datasets.DatasetInfo()

    def download_and_prepare(self, output_dir: Optional[str] = None, *args, **kwargs):
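        """Nothing to download or prepare: the data is already in the cache; optionally copy it to ``output_dir``."""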
if not os.path.exists(self.cache_dir):
raise ValueError(f"Cache directory for {self.dataset_name} doesn't exist at {self.cache_dir}")
if output_dir is not None and output_dir != self.cache_dir:
            shutil.copytree(self.cache_dir, output_dir)

    def _split_generators(self, dl_manager):
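        """Create one split generator per cached split, pointing at its Arrow shard files."""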
# used to stream from cache
if isinstance(self.info.splits, datasets.SplitDict):
split_infos: list[datasets.SplitInfo] = list(self.info.splits.values())
else:
raise ValueError(f"Missing splits info for {self.dataset_name} in cache directory {self.cache_dir}")
return [
datasets.SplitGenerator(
name=split_info.name,
gen_kwargs={
"files": filenames_for_dataset_split(
self.cache_dir,
dataset_name=self.dataset_name,
split=split_info.name,
filetype_suffix="arrow",
shard_lengths=split_info.shard_lengths,
)
},
)
for split_info in split_infos
        ]

    def _generate_tables(self, files):
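        """Yield pyarrow Tables read from the cached Arrow stream files, one table per record batch."""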
# used to stream from cache
for file_idx, file in enumerate(files):
with open(file, "rb") as f:
try:
for batch_idx, record_batch in enumerate(pa.ipc.open_stream(f)):
pa_table = pa.Table.from_batches([record_batch])
# Uncomment for debugging (will print the Arrow table size and elements)
# logger.warning(f"pa_table: {pa_table} num rows: {pa_table.num_rows}")
# logger.warning('\n'.join(str(pa_table.slice(i, 1).to_pydict()) for i in range(pa_table.num_rows)))
yield f"{file_idx}_{batch_idx}", pa_table
except ValueError as e:
logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}")
raise