From 9318d6dae4c2a5c8a984347bb9d6b061d26d0139 Mon Sep 17 00:00:00 2001 From: jiamingy Date: Wed, 20 Apr 2022 18:30:28 +0800 Subject: [PATCH 1/5] Expose `feature_types` to sklearn interface. --- python-package/xgboost/_typing.py | 5 +-- python-package/xgboost/core.py | 25 ++++++------- python-package/xgboost/dask.py | 18 +++++++--- python-package/xgboost/data.py | 59 +++++++++++++++++-------------- python-package/xgboost/sklearn.py | 30 ++++++++++++++-- tests/python/test_with_dask.py | 7 ++++ tests/python/test_with_sklearn.py | 26 ++++++++++++++ 7 files changed, 123 insertions(+), 47 deletions(-) diff --git a/python-package/xgboost/_typing.py b/python-package/xgboost/_typing.py index d21de6f0ed8c..64ea9a0a2993 100644 --- a/python-package/xgboost/_typing.py +++ b/python-package/xgboost/_typing.py @@ -1,7 +1,7 @@ """Shared typing definition.""" import ctypes import os -from typing import Optional, List, Any, TypeVar, Union +from typing import Optional, Any, TypeVar, Union, Sequence # os.PathLike/string/numpy.array/scipy.sparse/pd.DataFrame/dt.Frame/ # cudf.DataFrame/cupy.array/dlpack @@ -9,7 +9,8 @@ # xgboost accepts some other possible types in practice due to historical reason, which is # lesser tested. For now we encourage users to pass a simple list of string. -FeatureNames = Optional[List[str]] +FeatureNames = Optional[Sequence[str]] +FeatureTypes = Optional[Sequence[str]] ArrayLike = Any PathLike = Union[str, os.PathLike] diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index 3321e2f0819f..046322613bc3 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -31,6 +31,7 @@ CFloatPtr, NumpyOrCupy, FeatureNames, + FeatureTypes, _T, CupyT, ) @@ -553,7 +554,7 @@ def __init__( missing: Optional[float] = None, silent: bool = False, feature_names: FeatureNames = None, - feature_types: Optional[List[str]] = None, + feature_types: FeatureTypes = None, nthread: Optional[int] = None, group: Optional[ArrayLike] = None, qid: Optional[ArrayLike] = None, @@ -594,10 +595,15 @@ def __init__( Whether print messages during construction feature_names : list, optional Set names for features. - feature_types : + feature_types : FeatureTypes Set types for features. When `enable_categorical` is set to `True`, string - "c" represents categorical data type. + "c" represents categorical data type while "q" represents numerical feature + type. For categorical features, the input is assumed to be preprocessed and + encoded by the users. The encoding can be done via + :py:class:`sklearn.preprocessing.OrdinalEncoder` or pandas dataframe + `.cat.codes` method. This is useful when users want to specify categorical + features without having to construct a dataframe as input. nthread : integer, optional Number of threads to use for loading data when parallelization is @@ -1062,12 +1068,7 @@ def feature_names(self, feature_names: FeatureNames) -> None: @property def feature_types(self) -> Optional[List[str]]: - """Get feature types (column types). - - Returns - ------- - feature_types : list or None - """ + """Get feature types. See :py:class:`DMatrix` for details.""" length = c_bst_ulong() sarr = ctypes.POINTER(ctypes.c_char_p)() _check_call(_LIB.XGDMatrixGetStrFeatureInfo(self.handle, @@ -1083,8 +1084,8 @@ def feature_types(self) -> Optional[List[str]]: def feature_types(self, feature_types: Optional[Union[List[str], str]]) -> None: """Set feature types (column types). - This is for displaying the results and categorical data support. See doc string - of :py:obj:`xgboost.DMatrix` for details. + This is for displaying the results and categorical data support. See + :py:class:`DMatrix` for details. Parameters ---------- @@ -1667,7 +1668,7 @@ def _set_feature_info(self, features: Optional[List[str]], field: str) -> None: @property def feature_types(self) -> Optional[List[str]]: """Feature types for this booster. Can be directly set by input data or by - assignment. + assignment. See :py:class:`DMatrix` for details. """ return self._get_feature_info("feature_type") diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py index a09eeefa0840..942893f0a32d 100644 --- a/python-package/xgboost/dask.py +++ b/python-package/xgboost/dask.py @@ -54,10 +54,11 @@ from .compat import PANDAS_INSTALLED, DataFrame, Series, pandas_concat from .compat import lazy_isinstance +from ._typing import FeatureNames, FeatureTypes + from .core import DMatrix, DeviceQuantileDMatrix, Booster, _expect, DataIter from .core import Objective, Metric from .core import _deprecate_positional_args, _has_categorical -from .data import FeatureNames from .training import train as worker_train from .tracker import RabitTracker, get_host_ip from .sklearn import XGBModel, XGBClassifier, XGBRegressorBase, XGBClassifierBase @@ -327,7 +328,7 @@ def __init__( missing: float = None, silent: bool = False, # pylint: disable=unused-argument feature_names: FeatureNames = None, - feature_types: Optional[List[str]] = None, + feature_types: FeatureTypes = None, group: Optional[_DaskCollection] = None, qid: Optional[_DaskCollection] = None, label_lower_bound: Optional[_DaskCollection] = None, @@ -1601,7 +1602,11 @@ async def _predict_async( predts = predts.to_dask_array() else: test_dmatrix = await DaskDMatrix( - self.client, data=data, base_margin=base_margin, missing=self.missing + self.client, + data=data, + base_margin=base_margin, + missing=self.missing, + feature_types=self.feature_types ) predts = await predict( self.client, @@ -1640,7 +1645,9 @@ async def _apply_async( iteration_range: Optional[Tuple[int, int]] = None, ) -> Any: iteration_range = self._get_iteration_range(iteration_range) - test_dmatrix = await DaskDMatrix(self.client, data=X, missing=self.missing) + test_dmatrix = await DaskDMatrix( + self.client, data=X, missing=self.missing, feature_types=self.feature_types, + ) predts = await predict( self.client, model=self.get_booster(), @@ -1755,6 +1762,7 @@ async def _fit_async( eval_qid=None, missing=self.missing, enable_categorical=self.enable_categorical, + feature_types=self.feature_types, ) if callable(self.objective): @@ -1849,6 +1857,7 @@ async def _fit_async( eval_qid=None, missing=self.missing, enable_categorical=self.enable_categorical, + feature_types=self.feature_types, ) # pylint: disable=attribute-defined-outside-init @@ -2054,6 +2063,7 @@ async def _fit_async( eval_qid=eval_qid, missing=self.missing, enable_categorical=self.enable_categorical, + feature_types=self.feature_types, ) if eval_metric is not None: if callable(eval_metric): diff --git a/python-package/xgboost/data.py b/python-package/xgboost/data.py index 119b354fc6dd..f5a935e09f8e 100644 --- a/python-package/xgboost/data.py +++ b/python-package/xgboost/data.py @@ -13,6 +13,7 @@ from .core import c_array, _LIB, _check_call, c_str from .core import _cuda_array_interface from .core import DataIter, _ProxyDMatrix, DMatrix, FeatureNames +from ._typing import FeatureTypes from .compat import lazy_isinstance, DataFrame c_bst_ulong = ctypes.c_uint64 # pylint: disable=invalid-name @@ -70,7 +71,7 @@ def _from_scipy_csr( missing, nthread, feature_names: FeatureNames, - feature_types: Optional[List[str]], + feature_types: FeatureTypes, ): """Initialize data from a CSR matrix.""" if len(data.indices) != len(data.data): @@ -109,7 +110,7 @@ def _from_scipy_csc( data, missing, feature_names: FeatureNames, - feature_types: Optional[List[str]], + feature_types: FeatureTypes, ): if len(data.indices) != len(data.data): raise ValueError(f"length mismatch: {len(data.indices)} vs {len(data.data)}") @@ -165,7 +166,7 @@ def _from_numpy_array( missing, nthread, feature_names: FeatureNames, - feature_types: Optional[List[str]], + feature_types: FeatureTypes, ): """Initialize data from a 2-D numpy matrix. @@ -228,6 +229,12 @@ def _is_modin_df(data): } +_enable_cat_err = ( + "When categorical type is supplied, DMatrix parameter " + "`enable_categorical` must be set to `True`." +) + + def _invalid_dataframe_dtype(data: Any) -> None: # pandas series has `dtypes` but it's just a single object # cudf series doesn't have `dtypes`. @@ -241,9 +248,8 @@ def _invalid_dataframe_dtype(data: Any) -> None: else: err = "" - msg = """DataFrame.dtypes for data must be int, float, bool or category. When -categorical type is supplied, DMatrix parameter `enable_categorical` must -be set to `True`.""" + err + type_err = "DataFrame.dtypes for data must be int, float, bool or category." + msg = f"""{type_err} {_enable_cat_err} {err}""" raise ValueError(msg) @@ -340,8 +346,8 @@ def _from_pandas_df( missing: float, nthread: int, feature_names: FeatureNames, - feature_types: Optional[List[str]], -) -> Tuple[ctypes.c_void_p, FeatureNames, Optional[List[str]]]: + feature_types: FeatureTypes, +) -> Tuple[ctypes.c_void_p, FeatureNames, FeatureTypes]: data, feature_names, feature_types = _transform_pandas_df( data, enable_categorical, feature_names, feature_types ) @@ -382,7 +388,7 @@ def _from_pandas_series( nthread: int, enable_categorical: bool, feature_names: FeatureNames, - feature_types: Optional[List[str]], + feature_types: FeatureTypes, ): from pandas.api.types import is_categorical_dtype @@ -413,7 +419,7 @@ def _is_dt_df(data): def _transform_dt_df( data, feature_names: FeatureNames, - feature_types: Optional[List[str]], + feature_types: FeatureTypes, meta=None, meta_type=None, ): @@ -454,9 +460,9 @@ def _from_dt_df( missing, nthread, feature_names: FeatureNames, - feature_types: Optional[List[str]], + feature_types: FeatureTypes, enable_categorical: bool, -) -> Tuple[ctypes.c_void_p, FeatureNames, Optional[List[str]]]: +) -> Tuple[ctypes.c_void_p, FeatureNames, FeatureTypes]: if enable_categorical: raise ValueError("categorical data in datatable is not supported yet.") data, feature_names, feature_types = _transform_dt_df( @@ -542,10 +548,10 @@ def _from_arrow( data, missing: float, nthread: int, - feature_names: Optional[List[str]], - feature_types: Optional[List[str]], + feature_names: FeatureNames, + feature_types: FeatureTypes, enable_categorical: bool, -) -> Tuple[ctypes.c_void_p, Optional[List[str]], Optional[List[str]]]: +) -> Tuple[ctypes.c_void_p, FeatureNames, FeatureTypes]: import pyarrow as pa if not all( @@ -621,7 +627,7 @@ def _cudf_array_interfaces(data, cat_codes: list) -> bytes: def _transform_cudf_df( data, feature_names: FeatureNames, - feature_types: Optional[List[str]], + feature_types: FeatureTypes, enable_categorical: bool, ): try: @@ -687,7 +693,7 @@ def _from_cudf_df( missing, nthread, feature_names: FeatureNames, - feature_types: Optional[List[str]], + feature_types: FeatureTypes, enable_categorical: bool, ) -> Tuple[ctypes.c_void_p, Any, Any]: data, cat_codes, feature_names, feature_types = _transform_cudf_df( @@ -735,7 +741,7 @@ def _from_cupy_array( missing, nthread, feature_names: FeatureNames, - feature_types: Optional[List[str]], + feature_types: FeatureTypes, ): """Initialize DMatrix from cupy ndarray.""" data = _transform_cupy_array(data) @@ -782,7 +788,7 @@ def _from_dlpack( missing, nthread, feature_names: FeatureNames, - feature_types: Optional[List[str]], + feature_types: FeatureTypes, ): data = _transform_dlpack(data) return _from_cupy_array(data, missing, nthread, feature_names, @@ -797,7 +803,7 @@ def _from_uri( data, missing, feature_names: FeatureNames, - feature_types: Optional[List[str]], + feature_types: FeatureTypes, ): _warn_unused_missing(data, missing) handle = ctypes.c_void_p() @@ -817,7 +823,7 @@ def _from_list( missing, n_threads, feature_names: FeatureNames, - feature_types: Optional[List[str]], + feature_types: FeatureTypes, ): array = np.array(data) _check_data_shape(data) @@ -833,7 +839,7 @@ def _from_tuple( missing, n_threads, feature_names: FeatureNames, - feature_types: Optional[List[str]], + feature_types: FeatureTypes, ): return _from_list(data, missing, n_threads, feature_names, feature_types) @@ -869,10 +875,12 @@ def dispatch_data_backend( missing, threads, feature_names: FeatureNames, - feature_types: Optional[List[str]], + feature_types: FeatureTypes, enable_categorical: bool = False, ): '''Dispatch data for DMatrix.''' + if feature_types is not None and not enable_categorical and "c" in feature_types: + raise ValueError(_enable_cat_err) if not _is_cudf_ser(data) and not _is_pandas_series(data): _check_data_shape(data) if _is_scipy_csr(data): @@ -884,8 +892,7 @@ def dispatch_data_backend( data.tocsr(), missing, threads, feature_names, feature_types ) if _is_numpy_array(data): - return _from_numpy_array(data, missing, threads, feature_names, - feature_types) + return _from_numpy_array(data, missing, threads, feature_names, feature_types) if _is_uri(data): return _from_uri(data, missing, feature_names, feature_types) if _is_list(data): @@ -1101,7 +1108,7 @@ def reset(self) -> None: def _proxy_transform( data, feature_names: FeatureNames, - feature_types: Optional[List[str]], + feature_types: FeatureTypes, enable_categorical: bool, ): if _is_cudf_df(data) or _is_cudf_ser(data): diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py index d27cc6354641..98e4b43713e8 100644 --- a/python-package/xgboost/sklearn.py +++ b/python-package/xgboost/sklearn.py @@ -14,7 +14,7 @@ from .training import train from .callback import TrainingCallback from .data import _is_cudf_df, _is_cudf_ser, _is_cupy_array -from ._typing import ArrayLike +from ._typing import ArrayLike, FeatureTypes # Do not use class names on scikit-learn directly. Re-define the classes on # .compat to guarantee the behavior without scikit-learn @@ -211,6 +211,13 @@ def inner(y_score: np.ndarray, dmatrix: DMatrix) -> Tuple[str, float]: should be used to specify categorical data type. Also, JSON/UBJSON serialization format is required. + feature_types : FeatureTypes + + .. versionadded:: 2.0.0 + + Used for specifying feature types without constructing a dataframe. See + :py:class:`DMatrix` for details. + max_cat_to_onehot : Optional[int] .. versionadded:: 1.6.0 @@ -394,6 +401,7 @@ def _wrap_evaluation_matrices( eval_qid: Optional[Sequence[Any]], create_dmatrix: Callable, enable_categorical: bool, + feature_types: FeatureTypes, ) -> Tuple[Any, List[Tuple[Any, str]]]: """Convert array_like evaluation matrices into DMatrix. Perform validation on the way. @@ -408,6 +416,7 @@ def _wrap_evaluation_matrices( feature_weights=feature_weights, missing=missing, enable_categorical=enable_categorical, + feature_types=feature_types, ) n_validation = 0 if eval_set is None else len(eval_set) @@ -455,6 +464,7 @@ def validate_or_none(meta: Optional[Sequence], name: str) -> Sequence: base_margin=base_margin_eval_set[i], missing=missing, enable_categorical=enable_categorical, + feature_types=feature_types, ) evals.append(m) nevals = len(evals) @@ -518,6 +528,7 @@ def __init__( validate_parameters: Optional[bool] = None, predictor: Optional[str] = None, enable_categorical: bool = False, + feature_types: FeatureTypes = None, max_cat_to_onehot: Optional[int] = None, eval_metric: Optional[Union[str, List[str], Callable]] = None, early_stopping_rounds: Optional[int] = None, @@ -562,6 +573,7 @@ def __init__( self.validate_parameters = validate_parameters self.predictor = predictor self.enable_categorical = enable_categorical + self.feature_types = feature_types self.max_cat_to_onehot = max_cat_to_onehot self.eval_metric = eval_metric self.early_stopping_rounds = early_stopping_rounds @@ -684,6 +696,7 @@ def get_xgb_params(self) -> Dict[str, Any]: "enable_categorical", "early_stopping_rounds", "callbacks", + "feature_types", } filtered = {} for k, v in params.items(): @@ -715,6 +728,8 @@ def save_model(self, fname: Union[str, os.PathLike]) -> None: # numpy array is not JSON serializable meta['classes_'] = self.classes_.tolist() continue + if k == "feature_types": + continue try: json.dumps({k: v}) meta[k] = v @@ -754,6 +769,8 @@ def load_model(self, fname: Union[str, bytearray, os.PathLike]) -> None: if k == 'classes_': self.classes_ = np.array(v) continue + if k == "feature_types": + self.feature_types = self.get_booster().feature_types if k == "_estimator_type": if self._get_type() != v: raise TypeError( @@ -944,6 +961,7 @@ def fit( eval_qid=None, create_dmatrix=lambda **kwargs: DMatrix(nthread=self.n_jobs, **kwargs), enable_categorical=self.enable_categorical, + feature_types=self.feature_types ) params = self.get_xgb_params() @@ -1063,9 +1081,11 @@ def predict( pass test = DMatrix( - X, base_margin=base_margin, + X, + base_margin=base_margin, missing=self.missing, nthread=self.n_jobs, + feature_types=self.feature_types, enable_categorical=self.enable_categorical ) return self.get_booster().predict( @@ -1106,7 +1126,9 @@ def apply( self.get_booster(), ntree_limit, iteration_range ) iteration_range = self._get_iteration_range(iteration_range) - test_dmatrix = DMatrix(X, missing=self.missing, nthread=self.n_jobs) + test_dmatrix = DMatrix( + X, missing=self.missing, feature_types=self.feature_types, nthread=self.n_jobs + ) return self.get_booster().predict( test_dmatrix, pred_leaf=True, @@ -1395,6 +1417,7 @@ def fit( eval_qid=None, create_dmatrix=lambda **kwargs: DMatrix(nthread=self.n_jobs, **kwargs), enable_categorical=self.enable_categorical, + feature_types=self.feature_types, ) self._Booster = train( @@ -1826,6 +1849,7 @@ def fit( eval_qid=eval_qid, create_dmatrix=lambda **kwargs: DMatrix(nthread=self.n_jobs, **kwargs), enable_categorical=self.enable_categorical, + feature_types=self.feature_types, ) evals_result: TrainingCallback.EvalsLog = {} diff --git a/tests/python/test_with_dask.py b/tests/python/test_with_dask.py index 4e80409d4764..031b2fab7da0 100644 --- a/tests/python/test_with_dask.py +++ b/tests/python/test_with_dask.py @@ -306,6 +306,13 @@ def test_categorical(client: "Client") -> None: run_categorical(client, "approx", X, X_onehot, y) run_categorical(client, "hist", X, X_onehot, y) + ft = ["c"] * X.shape[1] + reg = xgb.dask.DaskXGBRegressor( + tree_method="hist", feature_types=ft, enable_categorical=True + ) + reg.fit(X, y) + assert reg.get_booster().feature_types == ft + def test_dask_predict_shape_infer(client: "Client") -> None: X, y = make_classification(n_samples=1000, n_informative=5, n_classes=3) diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py index a2e70ae6de2e..41554d461c93 100644 --- a/tests/python/test_with_sklearn.py +++ b/tests/python/test_with_sklearn.py @@ -1277,6 +1277,32 @@ def test_estimator_reg(estimator, check): check(estimator) +def test_categorical(): + X, y = tm.make_categorical(n_samples=32, n_features=2, n_categories=3, onehot=False) + ft = ["c"] * X.shape[1] + reg = xgb.XGBRegressor( + tree_method="hist", + feature_types=ft, + max_cat_to_onehot=1, + enable_categorical=True, + ) + reg.fit(X.values, y, eval_set=[(X.values, y)]) + from_cat = reg.evals_result()["validation_0"]["rmse"] + predt_cat = reg.predict(X.values) + assert reg.get_booster().feature_types == ft + + onehot, y = tm.make_categorical( + n_samples=32, n_features=2, n_categories=3, onehot=True + ) + reg = xgb.XGBRegressor(tree_method="hist") + reg.fit(onehot, y, eval_set=[(onehot, y)]) + from_enc = reg.evals_result()["validation_0"]["rmse"] + predt_enc = reg.predict(onehot) + + np.testing.assert_allclose(from_cat, from_enc) + np.testing.assert_allclose(predt_cat, predt_enc) + + def test_prediction_config(): reg = xgb.XGBRegressor() assert reg._can_use_inplace_predict() is True From 9e88510634a3c296ff6a1fd6e680bea8ffe53d2f Mon Sep 17 00:00:00 2001 From: jiamingy Date: Wed, 20 Apr 2022 19:00:59 +0800 Subject: [PATCH 2/5] lint. --- python-package/xgboost/data.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python-package/xgboost/data.py b/python-package/xgboost/data.py index f5a935e09f8e..fbe4b0863af6 100644 --- a/python-package/xgboost/data.py +++ b/python-package/xgboost/data.py @@ -229,7 +229,7 @@ def _is_modin_df(data): } -_enable_cat_err = ( +_ENABLE_CAT_ERR = ( "When categorical type is supplied, DMatrix parameter " "`enable_categorical` must be set to `True`." ) @@ -249,7 +249,7 @@ def _invalid_dataframe_dtype(data: Any) -> None: err = "" type_err = "DataFrame.dtypes for data must be int, float, bool or category." - msg = f"""{type_err} {_enable_cat_err} {err}""" + msg = f"""{type_err} {_ENABLE_CAT_ERR} {err}""" raise ValueError(msg) @@ -880,7 +880,7 @@ def dispatch_data_backend( ): '''Dispatch data for DMatrix.''' if feature_types is not None and not enable_categorical and "c" in feature_types: - raise ValueError(_enable_cat_err) + raise ValueError(_ENABLE_CAT_ERR) if not _is_cudf_ser(data) and not _is_pandas_series(data): _check_data_shape(data) if _is_scipy_csr(data): From d52833d9ece32d1acf159bba1a2e5f0f94652f17 Mon Sep 17 00:00:00 2001 From: jiamingy Date: Wed, 20 Apr 2022 19:01:35 +0800 Subject: [PATCH 3/5] mypy. --- python-package/xgboost/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index 046322613bc3..1c537d365a73 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -1648,7 +1648,7 @@ def _get_feature_info(self, field: str) -> Optional[List[str]]: feature_info = from_cstr_to_pystr(sarr, length) return feature_info if feature_info else None - def _set_feature_info(self, features: Optional[List[str]], field: str) -> None: + def _set_feature_info(self, features: Optional[Sequence[str]], field: str) -> None: if features is not None: assert isinstance(features, list) feature_info_bytes = [bytes(f, encoding="utf-8") for f in features] From 5985dc5e6d71d6985d54fef5184075ebb19ad4b8 Mon Sep 17 00:00:00 2001 From: jiamingy Date: Wed, 20 Apr 2022 19:15:22 +0800 Subject: [PATCH 4/5] Test for IO. --- python-package/xgboost/sklearn.py | 3 +++ tests/python/test_with_sklearn.py | 6 ++++++ 2 files changed, 9 insertions(+) diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py index 98e4b43713e8..f9545e62de87 100644 --- a/python-package/xgboost/sklearn.py +++ b/python-package/xgboost/sklearn.py @@ -729,6 +729,8 @@ def save_model(self, fname: Union[str, os.PathLike]) -> None: meta['classes_'] = self.classes_.tolist() continue if k == "feature_types": + # Use the `feature_types` attribute from booster instead. + meta["feature_types"] = None continue try: json.dumps({k: v}) @@ -771,6 +773,7 @@ def load_model(self, fname: Union[str, bytearray, os.PathLike]) -> None: continue if k == "feature_types": self.feature_types = self.get_booster().feature_types + continue if k == "_estimator_type": if self._get_type() != v: raise TypeError( diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py index 41554d461c93..0228385a6556 100644 --- a/tests/python/test_with_sklearn.py +++ b/tests/python/test_with_sklearn.py @@ -1290,6 +1290,12 @@ def test_categorical(): from_cat = reg.evals_result()["validation_0"]["rmse"] predt_cat = reg.predict(X.values) assert reg.get_booster().feature_types == ft + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "model.json") + reg.save_model(path) + reg = xgb.XGBRegressor() + reg.load_model(path) + assert reg.feature_types == ft onehot, y = tm.make_categorical( n_samples=32, n_features=2, n_categories=3, onehot=True From b43f1dbc4ae909e4884aa23727ea4c40de2d3de9 Mon Sep 17 00:00:00 2001 From: jiamingy Date: Wed, 20 Apr 2022 20:43:08 +0800 Subject: [PATCH 5/5] Fix. --- python-package/xgboost/data.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/python-package/xgboost/data.py b/python-package/xgboost/data.py index fbe4b0863af6..00d47599fe73 100644 --- a/python-package/xgboost/data.py +++ b/python-package/xgboost/data.py @@ -230,8 +230,8 @@ def _is_modin_df(data): _ENABLE_CAT_ERR = ( - "When categorical type is supplied, DMatrix parameter " - "`enable_categorical` must be set to `True`." + "When categorical type is supplied, DMatrix parameter `enable_categorical` must " + "be set to `True`." ) @@ -879,8 +879,6 @@ def dispatch_data_backend( enable_categorical: bool = False, ): '''Dispatch data for DMatrix.''' - if feature_types is not None and not enable_categorical and "c" in feature_types: - raise ValueError(_ENABLE_CAT_ERR) if not _is_cudf_ser(data) and not _is_pandas_series(data): _check_data_shape(data) if _is_scipy_csr(data):