# Copyright 2020 The HuggingFace Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
from functools import partial
from typing import TYPE_CHECKING, Optional

import pyarrow as pa

from .. import config
from ..features import Features
from ..features.features import decode_nested_example
from ..utils.py_utils import no_op_if_value_is_null
from .formatting import BaseArrowExtractor, TableFormatter


if TYPE_CHECKING:
    import polars as pl


def _load_polars():
    """Return the ``polars`` module, importing it only if not already loaded.

    Raises:
        ValueError: If ``config.POLARS_AVAILABLE`` is falsy.
    """
    if not config.POLARS_AVAILABLE:
        raise ValueError("Polars needs to be installed to be able to return Polars dataframes.")
    # Reuse the already-imported module when there is one; otherwise import lazily.
    if "polars" in sys.modules:
        return sys.modules["polars"]
    import polars

    return polars


class PolarsArrowExtractor(BaseArrowExtractor["pl.DataFrame", "pl.Series", "pl.DataFrame"]):
    """Convert a PyArrow table into Polars objects (row, column or batch)."""

    def extract_row(self, pa_table: pa.Table) -> "pl.DataFrame":
        """Return the first row of ``pa_table`` as a one-row Polars DataFrame.

        Raises:
            ValueError: If Polars is not available.
        """
        polars = _load_polars()
        return polars.from_arrow(pa_table.slice(length=1))

    def extract_column(self, pa_table: pa.Table) -> "pl.Series":
        """Return the first column of ``pa_table`` as a Polars Series.

        Raises:
            ValueError: If Polars is not available.
        """
        polars = _load_polars()
        first_column_name = pa_table.column_names[0]
        return polars.from_arrow(pa_table.select([0]))[first_column_name]

    def extract_batch(self, pa_table: pa.Table) -> "pl.DataFrame":
        """Return the whole of ``pa_table`` as a Polars DataFrame.

        Raises:
            ValueError: If Polars is not available.
        """
        polars = _load_polars()
        return polars.from_arrow(pa_table)


class PolarsFeaturesDecoder:
    """Apply feature decoding to Polars rows, columns and batches."""

    def __init__(self, features: Optional[Features]):
        """Store ``features`` and import Polars so a missing install fails here.

        Args:
            features: Features describing the columns, or ``None`` for no decoding.
        """
        self.features = features
        import polars as pl  # noqa: F401 - import pl at initialization

    def decode_row(self, row: "pl.DataFrame") -> "pl.DataFrame":
        """Decode every column of ``row`` whose feature requires decoding.

        Each such column is mapped element-wise through ``decode_nested_example``
        (nulls are left untouched). Features whose column is not present in
        ``row`` are skipped.

        Args:
            row: DataFrame to decode (a single row or a whole batch).

        Returns:
            A DataFrame with the decoded columns replaced, or ``row`` itself when
            nothing needs decoding.
        """
        if not self.features:
            return row

        decoders = {
            column_name: no_op_if_value_is_null(partial(decode_nested_example, feature))
            for column_name, feature in self.features.items()
            if self.features._column_requires_decoding[column_name] and column_name in row.columns
        }
        if not decoders:
            return row

        # ``DataFrame.map_rows`` expects a single callable over row tuples, not a
        # dict of per-column callables, so decode column by column instead (the
        # same way ``decode_column`` does) and rebind the results by name.
        decoded_columns = [
            row[column_name].map_elements(decoder).alias(column_name) for column_name, decoder in decoders.items()
        ]
        return row.with_columns(decoded_columns)

    def decode_column(self, column: "pl.Series", column_name: str) -> "pl.Series":
        """Decode ``column`` if the feature named ``column_name`` requires it.

        Args:
            column: Series to decode.
            column_name: Name used to look up the feature in ``self.features``.

        Returns:
            The element-wise decoded Series, or ``column`` unchanged when there are
            no features, the name is unknown, or no decoding is required.
        """
        needs_decoding = (
            self.features
            and column_name in self.features
            and self.features._column_requires_decoding[column_name]
        )
        if not needs_decoding:
            return column
        decoder = no_op_if_value_is_null(partial(decode_nested_example, self.features[column_name]))
        return column.map_elements(decoder)

    def decode_batch(self, batch: "pl.DataFrame") -> "pl.DataFrame":
        """Decode a batch; a batch is decoded exactly like a row DataFrame."""
        return self.decode_row(batch)


class PolarsFormatter(TableFormatter["pl.DataFrame", "pl.Series", "pl.DataFrame"]):
    """Format PyArrow tables as Polars DataFrames / Series with decoded features."""

    table_type = "polars dataframe"
    column_type = "polars series"

    def __init__(self, features=None, **np_array_kwargs):
        """Set up the extractor and decoder; fails early if Polars cannot be imported.

        Args:
            features: Optional features used to decode columns.
            **np_array_kwargs: Stored on the instance as ``np_array_kwargs``; not
                otherwise used in this module.
        """
        super().__init__(features=features)
        self.np_array_kwargs = np_array_kwargs
        # The extractor class (not an instance) is stored; it is instantiated per call.
        self.polars_arrow_extractor = PolarsArrowExtractor
        self.polars_features_decoder = PolarsFeaturesDecoder(features)
        import polars as pl  # noqa: F401 - import pl at initialization

    def format_row(self, pa_table: pa.Table) -> "pl.DataFrame":
        """Extract the first row of ``pa_table`` and decode it."""
        row = self.polars_arrow_extractor().extract_row(pa_table)
        return self.polars_features_decoder.decode_row(row)

    def format_column(self, pa_table: pa.Table) -> "pl.Series":
        """Extract the first column of ``pa_table`` and decode it."""
        column = self.polars_arrow_extractor().extract_column(pa_table)
        return self.polars_features_decoder.decode_column(column, pa_table.column_names[0])

    def format_batch(self, pa_table: pa.Table) -> "pl.DataFrame":
        """Extract ``pa_table`` as a DataFrame and decode it as a batch."""
        batch = self.polars_arrow_extractor().extract_batch(pa_table)
        return self.polars_features_decoder.decode_batch(batch)
# Memory