# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause

import pickle
import re
import warnings

import numpy as np
import pytest
import scipy.sparse as sp
from numpy.testing import assert_allclose

import sklearn
from sklearn import config_context, datasets
from sklearn.base import (
    BaseEstimator,
    OutlierMixin,
    TransformerMixin,
    clone,
    is_classifier,
    is_clusterer,
    is_outlier_detector,
    is_regressor,
)
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.ensemble import IsolationForest
from sklearn.exceptions import InconsistentVersionWarning
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC, SVR
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from sklearn.utils._mocking import MockDataFrame
from sklearn.utils._set_output import _get_output_config
from sklearn.utils._testing import (
    _convert_container,
    assert_array_equal,
)
from sklearn.utils.validation import _check_n_features, validate_data


#############################################################################
# A few test classes


class MyEstimator(BaseEstimator):
    """Minimal estimator with two plain hyper-parameters."""

    def __init__(self, l1=0, empty=None):
        self.l1 = l1
        self.empty = empty


class K(BaseEstimator):
    """Simple estimator, used as a nested parameter of ``T``."""

    def __init__(self, c=None, d=None):
        self.c = c
        self.d = d


class T(BaseEstimator):
    """Simple estimator whose parameters may themselves be estimators."""

    def __init__(self, a=None, b=None):
        self.a = a
        self.b = b


class NaNTag(BaseEstimator):
    """Estimator whose tags declare ``allow_nan=True``."""

    def __sklearn_tags__(self):
        tags = super().__sklearn_tags__()
        tags.input_tags.allow_nan = True
        return tags


class NoNaNTag(BaseEstimator):
    """Estimator whose tags declare ``allow_nan=False``."""

    def __sklearn_tags__(self):
        tags = super().__sklearn_tags__()
        tags.input_tags.allow_nan = False
        return tags


class OverrideTag(NaNTag):
    """Subclass of ``NaNTag`` that overrides ``allow_nan`` back to False."""

    def __sklearn_tags__(self):
        tags = super().__sklearn_tags__()
        tags.input_tags.allow_nan = False
        return tags


class DiamondOverwriteTag(NaNTag, NoNaNTag):
    """Diamond inheritance: ``NaNTag`` comes first in the MRO."""

    pass


class InheritDiamondOverwriteTag(DiamondOverwriteTag):
    """Plain subclass of the diamond; inherits its resolved tags."""

    pass


class ModifyInitParams(BaseEstimator):
    """Estimator whose ``__init__`` stores a copy of its parameter.

    The stored ``a`` is equal to, but not the same object as, the value that
    was passed in (``self.a is a`` does not hold). This deprecated behavior
    makes ``clone`` raise a ``RuntimeError``.
    """

    def __init__(self, a=np.array([0])):
        self.a = a.copy()


class Buggy(BaseEstimator):
    "A buggy estimator that does not set its parameters right."

    def __init__(self, a=None):
        self.a = 1


class NoEstimator:
    """Object with ``fit``/``predict`` but not derived from ``BaseEstimator``."""

    def __init__(self):
        pass

    def fit(self, X=None, y=None):
        return self

    def predict(self, X=None):
        return None


class VargEstimator(BaseEstimator):
    """scikit-learn estimators shouldn't have vargs."""

    def __init__(self, *vargs):
        pass


#############################################################################
# The tests


def test_clone():
    # Tests that clone creates a correct deep copy.
    # We create an estimator, make a copy of its original state
    # (which, in this case, is the current state of the estimator),
    # and check that the obtained copy is a correct deep copy.

    from sklearn.feature_selection import SelectFpr, f_classif

    selector = SelectFpr(f_classif, alpha=0.1)
    new_selector = clone(selector)
    assert selector is not new_selector
    assert selector.get_params() == new_selector.get_params()

    selector = SelectFpr(f_classif, alpha=np.zeros((10, 2)))
    new_selector = clone(selector)
    assert selector is not new_selector


def test_clone_2():
    # Tests that clone doesn't copy everything.
    # We first create an estimator, give it an own attribute, and
    # make a copy of its original state. Then we check that the copy doesn't
    # have the specific attribute we manually added to the initial estimator.

    from sklearn.feature_selection import SelectFpr, f_classif

    selector = SelectFpr(f_classif, alpha=0.1)
    selector.own_attribute = "test"
    new_selector = clone(selector)
    assert not hasattr(new_selector, "own_attribute")


def test_clone_buggy():
    # Check that clone raises an error on buggy estimators.
    buggy = Buggy()
    buggy.a = 2
    with pytest.raises(RuntimeError):
        clone(buggy)

    no_estimator = NoEstimator()
    with pytest.raises(TypeError):
        clone(no_estimator)

    varg_est = VargEstimator()
    with pytest.raises(RuntimeError):
        clone(varg_est)

    est = ModifyInitParams()
    with pytest.raises(RuntimeError):
        clone(est)


def test_clone_empty_array():
    # Regression test for cloning estimators with empty arrays
    clf = MyEstimator(empty=np.array([]))
    clf2 = clone(clf)
    assert_array_equal(clf.empty, clf2.empty)

    clf = MyEstimator(empty=sp.csr_matrix(np.array([[0]])))
    clf2 = clone(clf)
    assert_array_equal(clf.empty.data, clf2.empty.data)


def test_clone_nan():
    # Regression test for cloning estimators with default parameter as np.nan
    clf = MyEstimator(empty=np.nan)
    clf2 = clone(clf)

    assert clf.empty is clf2.empty


def test_clone_dict():
    # test that clone creates a clone of a dict
    orig = {"a": MyEstimator()}
    cloned = clone(orig)
    assert orig["a"] is not cloned["a"]


def test_clone_sparse_matrices():
    # Every ``*_matrix`` class exposed by scipy.sparse must survive cloning
    # with its concrete class and values preserved.
    sparse_matrix_classes = [
        cls
        for name in dir(sp)
        if name.endswith("_matrix") and type(cls := getattr(sp, name)) is type
    ]

    for cls in sparse_matrix_classes:
        sparse_matrix = cls(np.eye(5))
        clf = MyEstimator(empty=sparse_matrix)
        clf_cloned = clone(clf)
        assert clf.empty.__class__ is clf_cloned.empty.__class__
        assert_array_equal(clf.empty.toarray(), clf_cloned.empty.toarray())


def test_clone_estimator_types():
    # Check that clone works for parameters that are types rather than
    # instances
    clf = MyEstimator(empty=MyEstimator)
    clf2 = clone(clf)

    assert clf.empty is clf2.empty


def test_clone_class_rather_than_instance():
    # Check that clone raises expected error message when
    # cloning class rather than instance
    msg = "You should provide an instance of scikit-learn estimator"
    with pytest.raises(TypeError, match=msg):
        clone(MyEstimator)


def test_repr():
    # Smoke test the repr of the base estimator.
    my_estimator = MyEstimator()
    repr(my_estimator)
    test = T(K(), K())
    assert repr(test) == "T(a=K(), b=K())"

    # Long parameter values are expected to be elided to a fixed length.
    some_est = T(a=["long_params"] * 1000)
    assert len(repr(some_est)) == 485


def test_str():
    # Smoke test the str of the base estimator
    my_estimator = MyEstimator()
    str(my_estimator)


def test_get_params():
    test = T(K(), K)

    # Nested parameters (``a__d``) are only reported when deep=True.
    assert "a__d" in test.get_params(deep=True)
    assert "a__d" not in test.get_params(deep=False)

    test.set_params(a__d=2)
    assert test.a.d == 2

    with pytest.raises(ValueError):
        test.set_params(a__a=2)


# TODO(1.8): Remove this test when the deprecation is removed
def test_is_estimator_type_class():
    with pytest.warns(FutureWarning, match="passing a class to.*is deprecated"):
        assert is_classifier(SVC)

    with pytest.warns(FutureWarning, match="passing a class to.*is deprecated"):
        assert is_regressor(SVR)

    with pytest.warns(FutureWarning, match="passing a class to.*is deprecated"):
        assert is_clusterer(KMeans)

    with pytest.warns(FutureWarning, match="passing a class to.*is deprecated"):
        assert is_outlier_detector(IsolationForest)


@pytest.mark.parametrize(
    "estimator, expected_result",
    [
        (SVC(), True),
        (GridSearchCV(SVC(), {"C": [0.1, 1]}), True),
        (Pipeline([("svc", SVC())]), True),
        (Pipeline([("svc_cv", GridSearchCV(SVC(), {"C": [0.1, 1]}))]), True),
        (SVR(), False),
        (GridSearchCV(SVR(), {"C": [0.1, 1]}), False),
        (Pipeline([("svr", SVR())]), False),
        (Pipeline([("svr_cv", GridSearchCV(SVR(), {"C": [0.1, 1]}))]), False),
    ],
)
def test_is_classifier(estimator, expected_result):
    assert is_classifier(estimator) == expected_result


@pytest.mark.parametrize(
    "estimator, expected_result",
    [
        (SVR(), True),
        (GridSearchCV(SVR(), {"C": [0.1, 1]}), True),
        (Pipeline([("svr", SVR())]), True),
        (Pipeline([("svr_cv", GridSearchCV(SVR(), {"C": [0.1, 1]}))]), True),
        (SVC(), False),
        (GridSearchCV(SVC(), {"C": [0.1, 1]}), False),
        (Pipeline([("svc", SVC())]), False),
        (Pipeline([("svc_cv", GridSearchCV(SVC(), {"C": [0.1, 1]}))]), False),
    ],
)
def test_is_regressor(estimator, expected_result):
    assert is_regressor(estimator) == expected_result


@pytest.mark.parametrize(
    "estimator, expected_result",
    [
        (KMeans(), True),
        (GridSearchCV(KMeans(), {"n_clusters": [3, 8]}), True),
        (Pipeline([("km", KMeans())]), True),
        (Pipeline([("km_cv", GridSearchCV(KMeans(), {"n_clusters": [3, 8]}))]), True),
        (SVC(), False),
        (GridSearchCV(SVC(), {"C": [0.1, 1]}), False),
        (Pipeline([("svc", SVC())]), False),
        (Pipeline([("svc_cv", GridSearchCV(SVC(), {"C": [0.1, 1]}))]), False),
    ],
)
def test_is_clusterer(estimator, expected_result):
    assert is_clusterer(estimator) == expected_result


def test_set_params():
    # test nested estimator parameter setting
    clf = Pipeline([("svc", SVC())])

    # non-existing parameter in svc
    with pytest.raises(ValueError):
        clf.set_params(svc__stupid_param=True)

    # non-existing parameter of pipeline
    with pytest.raises(ValueError):
        clf.set_params(svm__stupid_param=True)

    # we don't currently catch if the things in pipeline are estimators
    # bad_pipeline = Pipeline([("bad", NoEstimator())])
    # with pytest.raises(AttributeError):
    #     bad_pipeline.set_params(bad__stupid_param=True)


def test_set_params_passes_all_parameters():
    # Make sure all parameters are passed together to set_params
    # of nested estimator. Regression test for #9944

    class TestDecisionTree(DecisionTreeClassifier):
        def set_params(self, **kwargs):
            super().set_params(**kwargs)
            # expected_kwargs is in test scope
            assert kwargs == expected_kwargs
            return self

    expected_kwargs = {"max_depth": 5, "min_samples_leaf": 2}
    for est in [
        Pipeline([("estimator", TestDecisionTree())]),
        GridSearchCV(TestDecisionTree(), {}),
    ]:
        est.set_params(estimator__max_depth=5, estimator__min_samples_leaf=2)


def test_set_params_updates_valid_params():
    # Check that set_params tries to set SVC().C, not
    # DecisionTreeClassifier().C
    gscv = GridSearchCV(DecisionTreeClassifier(), {})
    gscv.set_params(estimator=SVC(), estimator__C=42.0)
    assert gscv.estimator.C == 42.0


@pytest.mark.parametrize(
    "tree,dataset",
    [
        (
            DecisionTreeClassifier(max_depth=2, random_state=0),
            datasets.make_classification(random_state=0),
        ),
        (
            DecisionTreeRegressor(max_depth=2, random_state=0),
            datasets.make_regression(random_state=0),
        ),
    ],
)
def test_score_sample_weight(tree, dataset):
    rng = np.random.RandomState(0)
    # check that the score with and without sample weights are different
    X, y = dataset

    tree.fit(X, y)
    # generate random sample weights
    sample_weight = rng.randint(1, 10, size=len(y))
    score_unweighted = tree.score(X, y)
    score_weighted = tree.score(X, y, sample_weight=sample_weight)
    msg = "Unweighted and weighted scores are unexpectedly equal"
    assert score_unweighted != score_weighted, msg


def test_clone_pandas_dataframe():
    class DummyEstimator(TransformerMixin, BaseEstimator):
        """Dummy transformer holding a data-frame-like parameter.

        Parameters
        ----------
        df : data frame
            Data-frame-like parameter; this test passes a ``MockDataFrame``.

        scalar_param : int, default=1
            Plain scalar parameter.
        """

        def __init__(self, df=None, scalar_param=1):
            self.df = df
            self.scalar_param = scalar_param

        def fit(self, X, y=None):
            pass

        def transform(self, X):
            pass

    # build and clone estimator
    d = np.arange(10)
    df = MockDataFrame(d)
    e = DummyEstimator(df, scalar_param=1)
    cloned_e = clone(e)

    # the test
    assert (e.df == cloned_e.df).values.all()
    assert e.scalar_param == cloned_e.scalar_param


def test_clone_protocol():
    """Checks that clone works with `__sklearn_clone__` protocol."""

    class FrozenEstimator(BaseEstimator):
        def __init__(self, fitted_estimator):
            self.fitted_estimator = fitted_estimator

        def __getattr__(self, name):
            return getattr(self.fitted_estimator, name)

        def __sklearn_clone__(self):
            return self

        def fit(self, *args, **kwargs):
            return self

        def fit_transform(self, *args, **kwargs):
            return self.fitted_estimator.transform(*args, **kwargs)

    X = np.array([[-1, -1], [-2, -1], [-3, -2]])
    pca = PCA().fit(X)
    components = pca.components_

    frozen_pca = FrozenEstimator(pca)
    assert_allclose(frozen_pca.components_, components)

    # Calling PCA methods such as `get_feature_names_out` still works
    assert_array_equal(frozen_pca.get_feature_names_out(), pca.get_feature_names_out())

    # Fitting on a new data does not alter `components_`
    X_new = np.asarray([[-1, 2], [3, 4], [1, 2]])
    frozen_pca.fit(X_new)
    assert_allclose(frozen_pca.components_, components)

    # `fit_transform` does not alter state
    frozen_pca.fit_transform(X_new)
    assert_allclose(frozen_pca.components_, components)

    # Cloning estimator is a no-op
    clone_frozen_pca = clone(frozen_pca)
    assert clone_frozen_pca is frozen_pca
    assert_allclose(clone_frozen_pca.components_, components)


def test_pickle_version_warning_is_not_raised_with_matching_version():
    iris = datasets.load_iris()
    tree = DecisionTreeClassifier().fit(iris.data, iris.target)
    tree_pickle = pickle.dumps(tree)
    assert b"_sklearn_version" in tree_pickle

    # Any warning raised while unpickling would fail the test.
    with warnings.catch_warnings():
        warnings.simplefilter("error")
        tree_restored = pickle.loads(tree_pickle)

    # test that we can predict with the restored decision tree classifier
    score_of_original = tree.score(iris.data, iris.target)
    score_of_restored = tree_restored.score(iris.data, iris.target)
    assert score_of_original == score_of_restored


class TreeBadVersion(DecisionTreeClassifier):
    """Tree whose pickled state claims a bogus scikit-learn version."""

    def __getstate__(self):
        return dict(self.__dict__.items(), _sklearn_version="something")


pickle_error_message = (
    "Trying to unpickle estimator {estimator} from "
    "version {old_version} when using version "
    "{current_version}. This might "
    "lead to breaking code or invalid results. "
    "Use at your own risk."
)


def test_pickle_version_warning_is_issued_upon_different_version():
    iris = datasets.load_iris()
    tree = TreeBadVersion().fit(iris.data, iris.target)
    tree_pickle_other = pickle.dumps(tree)
    message = pickle_error_message.format(
        estimator="TreeBadVersion",
        old_version="something",
        current_version=sklearn.__version__,
    )
    with pytest.warns(UserWarning, match=message) as warning_record:
        pickle.loads(tree_pickle_other)

    message = warning_record.list[0].message
    assert isinstance(message, InconsistentVersionWarning)
    assert message.estimator_name == "TreeBadVersion"
    assert message.original_sklearn_version == "something"
    assert message.current_sklearn_version == sklearn.__version__


class TreeNoVersion(DecisionTreeClassifier):
    """Tree whose pickled state omits the ``_sklearn_version`` key."""

    def __getstate__(self):
        return self.__dict__


def test_pickle_version_warning_is_issued_when_no_version_info_in_pickle():
    iris = datasets.load_iris()
    # TreeNoVersion has no getstate, like pre-0.18
    tree = TreeNoVersion().fit(iris.data, iris.target)

    tree_pickle_noversion = pickle.dumps(tree)
    assert b"_sklearn_version" not in tree_pickle_noversion
    message = pickle_error_message.format(
        estimator="TreeNoVersion",
        old_version="pre-0.18",
        current_version=sklearn.__version__,
    )
    # check we got the warning about using pre-0.18 pickle
    with pytest.warns(UserWarning, match=message):
        pickle.loads(tree_pickle_noversion)


def test_pickle_version_no_warning_is_issued_with_non_sklearn_estimator():
    iris = datasets.load_iris()
    tree = TreeNoVersion().fit(iris.data, iris.target)
    tree_pickle_noversion = pickle.dumps(tree)
    # Take the backup before entering ``try`` so that ``finally`` can always
    # restore it (otherwise an early failure would surface as a NameError).
    module_backup = TreeNoVersion.__module__
    try:
        TreeNoVersion.__module__ = "notsklearn"

        with warnings.catch_warnings():
            warnings.simplefilter("error")

            pickle.loads(tree_pickle_noversion)
    finally:
        TreeNoVersion.__module__ = module_backup


class DontPickleAttributeMixin:
    """Mixin that blanks ``_attribute_not_pickled`` and flags restored objects."""

    def __getstate__(self):
        data = self.__dict__.copy()
        data["_attribute_not_pickled"] = None
        return data

    def __setstate__(self, state):
        state["_restored"] = True
        self.__dict__.update(state)


class MultiInheritanceEstimator(DontPickleAttributeMixin, BaseEstimator):
    def __init__(self, attribute_pickled=5):
        self.attribute_pickled = attribute_pickled
        self._attribute_not_pickled = None


def test_pickling_when_getstate_is_overwritten_by_mixin():
    estimator = MultiInheritanceEstimator()
    estimator._attribute_not_pickled = "this attribute should not be pickled"

    serialized = pickle.dumps(estimator)
    estimator_restored = pickle.loads(serialized)
    assert estimator_restored.attribute_pickled == 5
    assert estimator_restored._attribute_not_pickled is None
    assert estimator_restored._restored


def test_pickling_when_getstate_is_overwritten_by_mixin_outside_of_sklearn():
    estimator = MultiInheritanceEstimator()
    text = "this attribute should not be pickled"
    estimator._attribute_not_pickled = text
    # Save the class' module before ``try`` so ``finally`` can always restore
    # it; the assignment below mutates the class, not just this instance.
    old_mod = type(estimator).__module__
    try:
        type(estimator).__module__ = "notsklearn"

        serialized = estimator.__getstate__()
        assert serialized == {"_attribute_not_pickled": None, "attribute_pickled": 5}

        serialized["attribute_pickled"] = 4
        estimator.__setstate__(serialized)
        assert estimator.attribute_pickled == 4
        assert estimator._restored
    finally:
        type(estimator).__module__ = old_mod


class SingleInheritanceEstimator(BaseEstimator):
    def __init__(self, attribute_pickled=5):
        self.attribute_pickled = attribute_pickled
        self._attribute_not_pickled = None

    def __getstate__(self):
        state = super().__getstate__()
        state["_attribute_not_pickled"] = None
        return state


def test_pickling_works_when_getstate_is_overwritten_in_the_child_class():
    estimator = SingleInheritanceEstimator()
    estimator._attribute_not_pickled = "this attribute should not be pickled"

    serialized = pickle.dumps(estimator)
    estimator_restored = pickle.loads(serialized)
    assert estimator_restored.attribute_pickled == 5
    assert estimator_restored._attribute_not_pickled is None


def test_tag_inheritance():
    # Check that tags are resolved through the MRO: a subclass can override
    # a parent's tag, and with diamond inheritance the first base in the MRO
    # (here ``NaNTag``) determines the final value.

    nan_tag_est = NaNTag()
    no_nan_tag_est = NoNaNTag()
    assert nan_tag_est.__sklearn_tags__().input_tags.allow_nan
    assert not no_nan_tag_est.__sklearn_tags__().input_tags.allow_nan

    redefine_tags_est = OverrideTag()
    assert not redefine_tags_est.__sklearn_tags__().input_tags.allow_nan

    diamond_tag_est = DiamondOverwriteTag()
    assert diamond_tag_est.__sklearn_tags__().input_tags.allow_nan

    inherit_diamond_tag_est = InheritDiamondOverwriteTag()
    assert inherit_diamond_tag_est.__sklearn_tags__().input_tags.allow_nan


def test_raises_on_get_params_non_attribute():
    class MyEstimator(BaseEstimator):
        def __init__(self, param=5):
            # Deliberately does not store ``param``.
            pass

        def fit(self, X, y=None):
            return self

    est = MyEstimator()
    msg = "'MyEstimator' object has no attribute 'param'"

    with pytest.raises(AttributeError, match=msg):
        est.get_params()


def test_repr_mimebundle_():
    # Checks the display configuration flag controls the json output
    tree = DecisionTreeClassifier()
    output = tree._repr_mimebundle_()
    assert "text/plain" in output
    assert "text/html" in output

    with config_context(display="text"):
        output = tree._repr_mimebundle_()
        assert "text/plain" in output
        assert "text/html" not in output


def test_repr_html_wraps():
    # Checks the display configuration flag controls the html output
    tree = DecisionTreeClassifier()

    output = tree._repr_html_()
    assert "<style>" in output

    with config_context(display="text"):
        msg = "_repr_html_ is only defined when"
        with pytest.raises(AttributeError, match=msg):
            output = tree._repr_html_()
def test_n_features_in_validation():
    """Check that `_check_n_features` validates data when reset=False"""
    est = MyEstimator()
    X_train = [[1, 2, 3], [4, 5, 6]]
    _check_n_features(est, X_train, reset=True)

    assert est.n_features_in_ == 3

    msg = "X does not contain any features, but MyEstimator is expecting 3 features"
    with pytest.raises(ValueError, match=msg):
        _check_n_features(est, "invalid X", reset=False)


def test_n_features_in_no_validation():
    """Check that `_check_n_features` does not validate data when
    n_features_in_ is not defined."""
    est = MyEstimator()
    _check_n_features(est, "invalid X", reset=True)

    assert not hasattr(est, "n_features_in_")

    # does not raise
    _check_n_features(est, "invalid X", reset=False)


def test_feature_names_in():
    """Check that `feature_names_in_` is recorded by `validate_data`."""
    pd = pytest.importorskip("pandas")
    iris = datasets.load_iris()
    X_np = iris.data
    df = pd.DataFrame(X_np, columns=iris.feature_names)

    class NoOpTransformer(TransformerMixin, BaseEstimator):
        def fit(self, X, y=None):
            validate_data(self, X)
            return self

        def transform(self, X):
            validate_data(self, X, reset=False)
            return X

    # fit on dataframe saves the feature names
    trans = NoOpTransformer().fit(df)
    assert_array_equal(trans.feature_names_in_, df.columns)

    # fit again but on ndarray does not keep the previous feature names (see #21383)
    trans.fit(X_np)
    assert not hasattr(trans, "feature_names_in_")

    trans.fit(df)
    msg = "The feature names should match those that were passed"
    df_bad = pd.DataFrame(X_np, columns=iris.feature_names[::-1])
    with pytest.raises(ValueError, match=msg):
        trans.transform(df_bad)

    # warns when fitted on dataframe and transforming a ndarray
    msg = (
        "X does not have valid feature names, but NoOpTransformer was "
        "fitted with feature names"
    )
    with pytest.warns(UserWarning, match=msg):
        trans.transform(X_np)

    # warns when fitted on a ndarray and transforming dataframe
    msg = "X has feature names, but NoOpTransformer was fitted without feature names"
    trans = NoOpTransformer().fit(X_np)
    with pytest.warns(UserWarning, match=msg):
        trans.transform(df)

    # fit on dataframe with all integer feature names works without warning
    df_int_names = pd.DataFrame(X_np)
    trans = NoOpTransformer()
    with warnings.catch_warnings():
        warnings.simplefilter("error", UserWarning)
        trans.fit(df_int_names)

    # fit on dataframe with no feature names or all integer feature names
    # -> do not warn on transform
    Xs = [X_np, df_int_names]
    for X in Xs:
        with warnings.catch_warnings():
            warnings.simplefilter("error", UserWarning)
            trans.transform(X)

    # fit on dataframe with feature names that are mixed raises an error:
    df_mixed = pd.DataFrame(X_np, columns=["a", "b", 1, 2])
    trans = NoOpTransformer()
    msg = re.escape(
        "Feature names are only supported if all input features have string names, "
        "but your input has ['int', 'str'] as feature name / column name types. "
        "If you want feature names to be stored and validated, you must convert "
        "them all to strings, by using X.columns = X.columns.astype(str) for "
        "example. Otherwise you can remove feature / column names from your input "
        "data, or convert them all to a non-string data type."
    )
    with pytest.raises(TypeError, match=msg):
        trans.fit(df_mixed)

    # transform on feature names that are mixed also raises:
    with pytest.raises(TypeError, match=msg):
        trans.transform(df_mixed)


def test_validate_data_skip_check_array():
    """Check skip_check_array option of `validate_data`."""

    pd = pytest.importorskip("pandas")
    iris = datasets.load_iris()
    df = pd.DataFrame(iris.data, columns=iris.feature_names)
    y = pd.Series(iris.target)

    class NoOpTransformer(TransformerMixin, BaseEstimator):
        pass

    no_op = NoOpTransformer()
    # X only: converted to ndarray unless skip_check_array=True
    X_np_out = validate_data(no_op, df, skip_check_array=False)
    assert isinstance(X_np_out, np.ndarray)
    assert_allclose(X_np_out, df.to_numpy())

    X_df_out = validate_data(no_op, df, skip_check_array=True)
    assert X_df_out is df

    # y only
    y_np_out = validate_data(no_op, y=y, skip_check_array=False)
    assert isinstance(y_np_out, np.ndarray)
    assert_allclose(y_np_out, y.to_numpy())

    y_series_out = validate_data(no_op, y=y, skip_check_array=True)
    assert y_series_out is y

    # X and y together
    X_np_out, y_np_out = validate_data(no_op, df, y, skip_check_array=False)
    assert isinstance(X_np_out, np.ndarray)
    assert_allclose(X_np_out, df.to_numpy())
    assert isinstance(y_np_out, np.ndarray)
    assert_allclose(y_np_out, y.to_numpy())

    X_df_out, y_series_out = validate_data(no_op, df, y, skip_check_array=True)
    assert X_df_out is df
    assert y_series_out is y

    # neither X nor y
    msg = "Validation should be done on X, y or both."
    with pytest.raises(ValueError, match=msg):
        validate_data(no_op)


def test_clone_keeps_output_config():
    """Check that clone keeps the set_output config."""

    ss = StandardScaler().set_output(transform="pandas")
    config = _get_output_config("transform", ss)

    ss_clone = clone(ss)
    config_clone = _get_output_config("transform", ss_clone)
    assert config == config_clone


class _Empty:
    pass


class EmptyEstimator(_Empty, BaseEstimator):
    pass


@pytest.mark.parametrize("estimator", [BaseEstimator(), EmptyEstimator()])
def test_estimator_empty_instance_dict(estimator):
    """Check that ``__getstate__`` of an instance without attributes only
    contains the ``_sklearn_version`` key, and that the instance can be
    pickled.

    Python 3.11+ changed behaviour by returning ``None`` instead of raising an
    ``AttributeError``. Non-regression test for gh-25188.
    """
    state = estimator.__getstate__()
    expected = {"_sklearn_version": sklearn.__version__}
    assert state == expected

    # Round-trip the parametrized instance (rather than a fresh
    # ``BaseEstimator()``) so both cases are exercised; this should not raise.
    pickle.loads(pickle.dumps(estimator))


def test_estimator_getstate_using_slots_error_message():
    """Using a `BaseEstimator` with `__slots__` is not supported."""

    class WithSlots:
        __slots__ = ("x",)

    class Estimator(BaseEstimator, WithSlots):
        pass

    msg = (
        "You cannot use `__slots__` in objects inheriting from "
        "`sklearn.base.BaseEstimator`"
    )

    with pytest.raises(TypeError, match=msg):
        Estimator().__getstate__()

    with pytest.raises(TypeError, match=msg):
        pickle.dumps(Estimator())


@pytest.mark.parametrize(
    "constructor_name, minversion",
    [
        ("dataframe", "1.5.0"),
        ("pyarrow", "12.0.0"),
        ("polars", "0.20.23"),
    ],
)
def test_dataframe_protocol(constructor_name, minversion):
    """Uses the dataframe exchange protocol to get feature names."""
    data = [[1, 4, 2], [3, 3, 6]]
    columns = ["col_0", "col_1", "col_2"]
    df = _convert_container(
        data, constructor_name, columns_name=columns, minversion=minversion
    )

    class NoOpTransformer(TransformerMixin, BaseEstimator):
        def fit(self, X, y=None):
            validate_data(self, X)
            return self

        def transform(self, X):
            return validate_data(self, X, reset=False)

    no_op = NoOpTransformer()
    no_op.fit(df)
    assert_array_equal(no_op.feature_names_in_, columns)
    X_out = no_op.transform(df)

    if constructor_name != "pyarrow":
        # pyarrow does not work with `np.asarray`
        # https://github.com/apache/arrow/issues/34886
        assert_allclose(df, X_out)

    bad_names = ["a", "b", "c"]
    df_bad = _convert_container(data, constructor_name, columns_name=bad_names)
    with pytest.raises(ValueError, match="The feature names should match"):
        no_op.transform(df_bad)


@config_context(enable_metadata_routing=True)
def test_transformer_fit_transform_with_metadata_in_transform():
    """Test that having a transformer with metadata for transform emits a
    warning when calling fit_transform."""

    class CustomTransformer(BaseEstimator, TransformerMixin):
        def fit(self, X, y=None, prop=None):
            return self

        def transform(self, X, prop=None):
            return X

    # passing the metadata to `fit_transform` should raise a warning since it
    # could potentially be consumed by `transform`
    with pytest.warns(UserWarning, match="`transform` method which consumes metadata"):
        CustomTransformer().set_transform_request(prop=True).fit_transform(
            [[1]], [1], prop=1
        )

    # not passing a metadata which can potentially be consumed by `transform` should
    # not raise a warning
    with warnings.catch_warnings(record=True) as record:
        CustomTransformer().set_transform_request(prop=True).fit_transform([[1]], [1])
        assert len(record) == 0


@config_context(enable_metadata_routing=True)
def test_outlier_mixin_fit_predict_with_metadata_in_predict():
    """Test that having an OutlierMixin with metadata for predict emits a
    warning when calling fit_predict."""

    class CustomOutlierDetector(BaseEstimator, OutlierMixin):
        def fit(self, X, y=None, prop=None):
            return self

        def predict(self, X, prop=None):
            return X

    # passing the metadata to `fit_predict` should raise a warning since it
    # could potentially be consumed by `predict`
    with pytest.warns(UserWarning, match="`predict` method which consumes metadata"):
        CustomOutlierDetector().set_predict_request(prop=True).fit_predict(
            [[1]], [1], prop=1
        )

    # not passing a metadata which can potentially be consumed by `predict` should
    # not raise a warning
    with warnings.catch_warnings(record=True) as record:
        CustomOutlierDetector().set_predict_request(prop=True).fit_predict([[1]], [1])
        assert len(record) == 0
Memory