import itertools
from dataclasses import dataclass
from typing import Any, Callable, Optional, Union
import pandas as pd
import pyarrow as pa
import datasets
import datasets.config
from datasets.features.features import require_storage_cast
from datasets.table import table_cast
from datasets.utils.py_utils import Literal
# Module-level logger used by the CSV builder.
logger = datasets.utils.logging.get_logger(__name__)
# pandas.read_csv parameters that are removed from the forwarded kwargs when they still hold their
# CsvConfig default value (they "must not be passed if they don't have a default value").
_PANDAS_READ_CSV_NO_DEFAULT_PARAMETERS = ["names", "prefix"]
# Deprecated pandas.read_csv parameters; likewise removed when left at their CsvConfig default value.
_PANDAS_READ_CSV_DEPRECATED_PARAMETERS = ["warn_bad_lines", "error_bad_lines", "mangle_dupe_cols"]
# Parameters introduced in pandas 1.3; removed from the kwargs for older pandas versions.
_PANDAS_READ_CSV_NEW_1_3_0_PARAMETERS = ["encoding_errors", "on_bad_lines"]
# Parameters introduced in pandas 2.0; removed from the kwargs for pandas < 2.0.
_PANDAS_READ_CSV_NEW_2_0_0_PARAMETERS = ["date_format"]
# Parameters deprecated in pandas 2.2; on pandas >= 2.2 they are removed when left at their CsvConfig default.
_PANDAS_READ_CSV_DEPRECATED_2_2_0_PARAMETERS = ["verbose"]
@dataclass
class CsvConfig(datasets.BuilderConfig):
    """BuilderConfig for CSV.

    Most fields mirror keyword arguments of ``pandas.read_csv`` and are forwarded
    to it through :attr:`pd_read_csv_kwargs`. Two aliases are resolved in
    ``__post_init__``: ``delimiter`` overrides ``sep`` and ``column_names``
    overrides ``names`` whenever they are set.
    """

    sep: str = ","
    delimiter: Optional[str] = None
    header: Optional[Union[int, list[int], str]] = "infer"
    names: Optional[list[str]] = None
    column_names: Optional[list[str]] = None
    index_col: Optional[Union[int, str, list[int], list[str]]] = None
    usecols: Optional[Union[list[int], list[str]]] = None
    prefix: Optional[str] = None
    mangle_dupe_cols: bool = True
    engine: Optional[Literal["c", "python", "pyarrow"]] = None
    # Defaults to None, so the annotation must be Optional.
    converters: Optional[dict[Union[int, str], Callable[[Any], Any]]] = None
    true_values: Optional[list] = None
    false_values: Optional[list] = None
    skipinitialspace: bool = False
    skiprows: Optional[Union[int, list[int]]] = None
    nrows: Optional[int] = None
    na_values: Optional[Union[str, list[str]]] = None
    keep_default_na: bool = True
    na_filter: bool = True
    verbose: bool = False
    skip_blank_lines: bool = True
    thousands: Optional[str] = None
    decimal: str = "."
    lineterminator: Optional[str] = None
    quotechar: str = '"'
    quoting: int = 0
    escapechar: Optional[str] = None
    comment: Optional[str] = None
    encoding: Optional[str] = None
    dialect: Optional[str] = None
    error_bad_lines: bool = True
    warn_bad_lines: bool = True
    skipfooter: int = 0
    doublequote: bool = True
    memory_map: bool = False
    float_precision: Optional[str] = None
    chunksize: int = 10_000
    features: Optional[datasets.Features] = None
    encoding_errors: Optional[str] = "strict"
    on_bad_lines: Literal["error", "warn", "skip"] = "error"
    date_format: Optional[str] = None

    def __post_init__(self):
        """Resolve the ``delimiter`` / ``column_names`` aliases onto ``sep`` / ``names``."""
        super().__post_init__()
        if self.delimiter is not None:
            self.sep = self.delimiter
        if self.column_names is not None:
            self.names = self.column_names

    @property
    def pd_read_csv_kwargs(self):
        """Keyword arguments to pass to ``pandas.read_csv``.

        The following parameters are dropped from the returned dict:

        * Parameters listed as having no default, or as deprecated, when they
          still equal the ``CsvConfig`` default.
        * Parameters that the installed pandas version does not support yet.
        * On pandas >= 2.2, parameters deprecated in 2.2 when they still equal
          their default.

        Returns:
            dict: mapping of ``read_csv`` parameter names to values.
        """
        pd_read_csv_kwargs = {
            "sep": self.sep,
            "header": self.header,
            "names": self.names,
            "index_col": self.index_col,
            "usecols": self.usecols,
            "prefix": self.prefix,
            "mangle_dupe_cols": self.mangle_dupe_cols,
            "engine": self.engine,
            "converters": self.converters,
            "true_values": self.true_values,
            "false_values": self.false_values,
            "skipinitialspace": self.skipinitialspace,
            "skiprows": self.skiprows,
            "nrows": self.nrows,
            "na_values": self.na_values,
            "keep_default_na": self.keep_default_na,
            "na_filter": self.na_filter,
            "verbose": self.verbose,
            "skip_blank_lines": self.skip_blank_lines,
            "thousands": self.thousands,
            "decimal": self.decimal,
            "lineterminator": self.lineterminator,
            "quotechar": self.quotechar,
            "quoting": self.quoting,
            "escapechar": self.escapechar,
            "comment": self.comment,
            "encoding": self.encoding,
            "dialect": self.dialect,
            "error_bad_lines": self.error_bad_lines,
            "warn_bad_lines": self.warn_bad_lines,
            "skipfooter": self.skipfooter,
            "doublequote": self.doublequote,
            "memory_map": self.memory_map,
            "float_precision": self.float_precision,
            "chunksize": self.chunksize,
            "encoding_errors": self.encoding_errors,
            "on_bad_lines": self.on_bad_lines,
            "date_format": self.date_format,
        }
        pandas_release = datasets.config.PANDAS_VERSION.release
        # Reference defaults, built once instead of once per checked parameter.
        default_config = CsvConfig()

        # Some kwargs must not be passed if they don't have a default value.
        # Some others are deprecated, and we also must not pass them when they hold the default value.
        for param in _PANDAS_READ_CSV_NO_DEFAULT_PARAMETERS + _PANDAS_READ_CSV_DEPRECATED_PARAMETERS:
            if pd_read_csv_kwargs[param] == getattr(default_config, param):
                del pd_read_csv_kwargs[param]

        # Remove arguments introduced in pandas 1.3.
        # Compare release tuples: the former `major >= 1 and minor >= 3` test
        # treated pandas 2.0-2.2 (minor < 3) as older than 1.3 and silently
        # dropped user-supplied `encoding_errors` / `on_bad_lines`.
        if pandas_release < (1, 3):
            for param in _PANDAS_READ_CSV_NEW_1_3_0_PARAMETERS:
                del pd_read_csv_kwargs[param]

        # Remove arguments introduced in pandas 2.0.
        if pandas_release < (2, 0):
            for param in _PANDAS_READ_CSV_NEW_2_0_0_PARAMETERS:
                del pd_read_csv_kwargs[param]

        # Remove arguments deprecated in pandas 2.2, unless the user changed them.
        if pandas_release >= (2, 2):
            for param in _PANDAS_READ_CSV_DEPRECATED_2_2_0_PARAMETERS:
                if pd_read_csv_kwargs[param] == getattr(default_config, param):
                    del pd_read_csv_kwargs[param]

        return pd_read_csv_kwargs
class Csv(datasets.ArrowBasedBuilder):
    """Arrow-based builder that reads CSV files in chunks with ``pandas.read_csv``."""

    BUILDER_CONFIG_CLASS = CsvConfig

    def _info(self):
        return datasets.DatasetInfo(features=self.config.features)

    def _split_generators(self, dl_manager):
        """Create one SplitGenerator per split of ``config.data_files``.

        A split whose files are a single path string is wrapped into a list.

        Raises:
            ValueError: if no data files are configured.
        """
        if not self.config.data_files:
            raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}")
        dl_manager.download_config.extract_on_the_fly = True
        data_files = dl_manager.download_and_extract(self.config.data_files)
        splits = []
        for split_name, files in data_files.items():
            if isinstance(files, str):
                files = [files]
            files = [dl_manager.iter_files(file) for file in files]
            splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": files}))
        return splits

    def _cast_table(self, pa_table: pa.Table) -> pa.Table:
        """Cast ``pa_table`` to the configured features' Arrow schema, if any."""
        if self.config.features is not None:
            schema = self.config.features.arrow_schema
            if all(not require_storage_cast(feature) for feature in self.config.features.values()):
                # cheaper cast
                pa_table = pa.Table.from_arrays([pa_table[field.name] for field in schema], schema=schema)
            else:
                # more expensive cast; allows str <-> int/float or str to Audio for example
                pa_table = table_cast(pa_table, schema)
        return pa_table

    def _generate_tables(self, files):
        """Yield ``((file_idx, batch_idx), pa.Table)`` for every chunk of every CSV file.

        ``files`` is an iterable of iterables of file paths (as built by
        ``_split_generators``). Each pandas reader is closed when its file is
        exhausted, when an error occurs, or when the consumer stops iterating.
        """
        schema = self.config.features.arrow_schema if self.config.features else None
        # dtype allows reading an int column as str
        dtype = (
            {
                name: pa_type.to_pandas_dtype() if not require_storage_cast(feature) else object
                for name, pa_type, feature in zip(schema.names, schema.types, self.config.features.values())
            }
            if schema is not None
            else None
        )
        for file_idx, file in enumerate(itertools.chain.from_iterable(files)):
            csv_file_reader = None
            try:
                # Created inside the try so that construction errors (e.g. bad
                # kwargs or usecols) are logged with the file name as well.
                csv_file_reader = pd.read_csv(file, iterator=True, dtype=dtype, **self.config.pd_read_csv_kwargs)
                for batch_idx, df in enumerate(csv_file_reader):
                    pa_table = pa.Table.from_pandas(df)
                    yield (file_idx, batch_idx), self._cast_table(pa_table)
            except ValueError as e:
                logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}")
                raise
            finally:
                # Release the underlying file handle even if the generator is
                # closed early or an exception propagates.
                if csv_file_reader is not None:
                    csv_file_reader.close()