diff --git a/demo/guide-python/categorical.py b/demo/guide-python/categorical.py index 02ef7a8fc8c3..35d03e4f87b4 100644 --- a/demo/guide-python/categorical.py +++ b/demo/guide-python/categorical.py @@ -44,7 +44,8 @@ def make_categorical( def main() -> None: # Use builtin categorical data support - # Must be pandas DataFrame or cudf DataFrame with categorical data + # For scikit-learn interface, the input data must be pandas DataFrame or cudf + # DataFrame with categorical features X, y = make_categorical(100, 10, 4, False) # Specify `enable_categorical` to True. reg = xgb.XGBRegressor(tree_method="gpu_hist", enable_categorical=True) diff --git a/include/xgboost/feature_map.h b/include/xgboost/feature_map.h index 6083bb95d541..43c52ed227c8 100644 --- a/include/xgboost/feature_map.h +++ b/include/xgboost/feature_map.h @@ -83,7 +83,7 @@ class FeatureMap { if (!strcmp("q", tname)) return kQuantitive; if (!strcmp("int", tname)) return kInteger; if (!strcmp("float", tname)) return kFloat; - if (!strcmp("categorical", tname)) return kCategorical; + if (!strcmp("c", tname)) return kCategorical; LOG(FATAL) << "unknown feature type, use i for indicator and q for quantity"; return kIndicator; } diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index 0bbfbca62ba2..8f6b3b6fde38 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -518,8 +518,8 @@ def __init__( base_margin=None, missing: Optional[float] = None, silent=False, - feature_names=None, - feature_types=None, + feature_names: Optional[List[str]] = None, + feature_types: Optional[List[str]] = None, nthread: Optional[int] = None, group=None, qid=None, @@ -558,8 +558,11 @@ def __init__( Whether print messages during construction feature_names : list, optional Set names for features. - feature_types : list, optional - Set types for features. + feature_types : + + Set types for features. When `enable_categorical` is set to `True`, string + "c" represents categorical data type. + nthread : integer, optional Number of threads to use for loading data when parallelization is applicable. If -1, uses maximum threads available on the system. @@ -577,11 +580,10 @@ def __init__( .. versionadded:: 1.3.0 - Experimental support of specializing for categorical features. Do - not set to True unless you are interested in development. - Currently it's only available for `gpu_hist` tree method with 1 vs - rest (one hot) categorical split. Also, JSON serialization format, - `gpu_predictor` and pandas input are required. + Experimental support of specializing for categorical features. Do not set to + True unless you are interested in development. Currently it's only available + for `gpu_hist` tree method with 1 vs rest (one hot) categorical split. Also, + JSON serialization format is required. """ if group is not None and qid is not None: @@ -673,8 +675,8 @@ def set_info( qid=None, label_lower_bound=None, label_upper_bound=None, - feature_names=None, - feature_types=None, + feature_names: Optional[List[str]] = None, + feature_types: Optional[List[str]] = None, feature_weights=None ) -> None: """Set meta info for DMatrix. See doc string for :py:obj:`xgboost.DMatrix`.""" @@ -945,7 +947,7 @@ def slice( return res @property - def feature_names(self) -> List[str]: + def feature_names(self) -> Optional[List[str]]: """Get feature names (column labels). Returns @@ -1033,17 +1035,21 @@ def feature_types(self) -> Optional[List[str]]: return res @feature_types.setter - def feature_types(self, feature_types: Optional[Union[List[Any], Any]]) -> None: + def feature_types(self, feature_types: Optional[Union[List[str], str]]) -> None: """Set feature types (column types). - This is for displaying the results and unrelated - to the learning process. + This is for displaying the results and categorical data support. See doc string + of :py:obj:`xgboost.DMatrix` for details. Parameters ---------- feature_types : list or None Labels for features. None will reset existing feature names + """ + # For compatibility reason this function wraps single str input into a list. But + # we should not promote such usage since other than visualization, the field is + # also used for specifying categorical data type. if feature_types is not None: if not isinstance(feature_types, (list, str)): raise TypeError( @@ -2461,8 +2467,13 @@ def _validate_features(self, data: DMatrix): raise ValueError(msg.format(self.feature_names, data.feature_names)) - def get_split_value_histogram(self, feature, fmap='', bins=None, - as_pandas=True): + def get_split_value_histogram( + self, + feature: str, + fmap: Union[os.PathLike, str] = '', + bins: Optional[int] = None, + as_pandas: bool = True + ): """Get split value histogram of a feature Parameters @@ -2510,7 +2521,7 @@ def get_split_value_histogram(self, feature, fmap='', bins=None, except (ValueError, AttributeError, TypeError): # None.index: attr err, None[0]: type err, fn.index(-1): value err feature_t = None - if feature_t == "categorical": + if feature_t == "c": # categorical raise ValueError( "Split value historgam doesn't support categorical split." ) diff --git a/python-package/xgboost/data.py b/python-package/xgboost/data.py index 714e304c3e4a..8a446e287c5f 100644 --- a/python-package/xgboost/data.py +++ b/python-package/xgboost/data.py @@ -5,7 +5,7 @@ import json import warnings import os -from typing import Any, Tuple, Callable +from typing import Any, Tuple, Callable, Optional, List import numpy as np @@ -16,6 +16,8 @@ c_bst_ulong = ctypes.c_uint64 # pylint: disable=invalid-name +CAT_T = "c" + def _warn_unused_missing(data, missing): if (missing is not None) and (not np.isnan(missing)): @@ -57,7 +59,13 @@ def _array_interface(data: np.ndarray) -> bytes: return interface_str -def _from_scipy_csr(data, missing, nthread, feature_names, feature_types): +def _from_scipy_csr( + data, + missing, + nthread, + feature_names: Optional[List[str]], + feature_types: Optional[List[str]], +): """Initialize data from a CSR matrix.""" if len(data.indices) != len(data.data): raise ValueError( @@ -91,7 +99,12 @@ def _is_scipy_csc(data): return isinstance(data, scipy.sparse.csc_matrix) -def _from_scipy_csc(data, missing, feature_names, feature_types): +def _from_scipy_csc( + data, + missing, + feature_names: Optional[List[str]], + feature_types: Optional[List[str]], +): if len(data.indices) != len(data.data): raise ValueError('length mismatch: {} vs {}'.format( len(data.indices), len(data.data))) @@ -142,7 +155,13 @@ def _maybe_np_slice(data, dtype): return data -def _from_numpy_array(data, missing, nthread, feature_names, feature_types): +def _from_numpy_array( + data, + missing, + nthread, + feature_names: Optional[List[str]], + feature_types: Optional[List[str]], +): """Initialize data from a 2-D numpy matrix. """ @@ -199,9 +218,14 @@ def _is_modin_df(data): } -def _transform_pandas_df(data, enable_categorical, - feature_names=None, feature_types=None, - meta=None, meta_type=None): +def _transform_pandas_df( + data, + enable_categorical, + feature_names: Optional[List[str]] = None, + feature_types: Optional[List[str]] = None, + meta=None, + meta_type=None, +): from pandas import MultiIndex, Int64Index, RangeIndex from pandas.api.types import is_sparse, is_categorical_dtype @@ -236,7 +260,7 @@ def _transform_pandas_df(data, enable_categorical, feature_types.append(_pandas_dtype_mapper[ dtype.subtype.name]) elif is_categorical_dtype(dtype) and enable_categorical: - feature_types.append('categorical') + feature_types.append(CAT_T) else: feature_types.append(_pandas_dtype_mapper[dtype.name]) @@ -253,8 +277,14 @@ def _transform_pandas_df(data, enable_categorical, return data, feature_names, feature_types -def _from_pandas_df(data, enable_categorical, missing, nthread, - feature_names, feature_types): +def _from_pandas_df( + data, + enable_categorical: bool, + missing, + nthread, + feature_names: Optional[List[str]], + feature_types: Optional[List[str]], +): data, feature_names, feature_types = _transform_pandas_df( data, enable_categorical, feature_names, feature_types) return _from_numpy_array(data, missing, nthread, feature_names, @@ -277,9 +307,16 @@ def _is_modin_series(data): return isinstance(data, pd.Series) -def _from_pandas_series(data, missing, nthread, feature_types, feature_names): - return _from_numpy_array(data.values.astype('float'), missing, nthread, - feature_names, feature_types) +def _from_pandas_series( + data, + missing, + nthread, + feature_types: Optional[List[str]], + feature_names: Optional[List[str]], +): + return _from_numpy_array( + data.values.astype("float"), missing, nthread, feature_names, feature_types + ) def _is_dt_df(data): @@ -291,8 +328,13 @@ def _is_dt_df(data): _dt_type_mapper2 = {'bool': 'i', 'int': 'int', 'real': 'float'} -def _transform_dt_df(data, feature_names, feature_types, meta=None, - meta_type=None): +def _transform_dt_df( + data, + feature_names: Optional[List[str]], + feature_types: Optional[List[str]], + meta=None, + meta_type=None, +): """Validate feature names and types if data table""" if meta and data.shape[1] > 1: raise ValueError( @@ -325,7 +367,16 @@ def _transform_dt_df(data, feature_names, feature_types, meta=None, return data, feature_names, feature_types -def _from_dt_df(data, missing, nthread, feature_names, feature_types): +def _from_dt_df( + data, + missing, + nthread, + feature_names: Optional[List[str]], + feature_types: Optional[List[str]], + enable_categorical: bool, +) -> Tuple[ctypes.c_void_p, Optional[List[str]], Optional[List[str]]]: + if enable_categorical: + raise ValueError("categorical data in datatable is not supported yet.") data, feature_names, feature_types = _transform_dt_df( data, feature_names, feature_types, None, None) @@ -368,7 +419,7 @@ def _is_cudf_df(data): return hasattr(cudf, 'DataFrame') and isinstance(data, cudf.DataFrame) -def _cudf_array_interfaces(data) -> Tuple[list, list]: +def _cudf_array_interfaces(data) -> Tuple[list, bytes]: """Extract CuDF __cuda_array_interface__. This is special as it returns a new list of data and a list of array interfaces. The data is list of categorical codes that caller can safely ignore, but have to keep their reference alive until usage of array @@ -395,7 +446,12 @@ def _cudf_array_interfaces(data) -> Tuple[list, list]: return cat_codes, interfaces_str -def _transform_cudf_df(data, feature_names, feature_types, enable_categorical): +def _transform_cudf_df( + data, + feature_names: Optional[List[str]], + feature_types: Optional[List[str]], + enable_categorical: bool, +): from cudf.utils.dtypes import is_categorical_dtype if feature_names is None: @@ -413,14 +469,19 @@ def _transform_cudf_df(data, feature_names, feature_types, enable_categorical): dtypes = data.dtypes for dtype in dtypes: if is_categorical_dtype(dtype) and enable_categorical: - feature_types.append("categorical") + feature_types.append(CAT_T) else: feature_types.append(_pandas_dtype_mapper[dtype.name]) return data, feature_names, feature_types def _from_cudf_df( - data, missing, nthread, feature_names, feature_types, enable_categorical + data, + missing, + nthread, + feature_names: Optional[List[str]], + feature_types: Optional[List[str]], + enable_categorical: bool, ) -> Tuple[ctypes.c_void_p, Any, Any]: data, feature_names, feature_types = _transform_cudf_df( data, feature_names, feature_types, enable_categorical @@ -464,7 +525,13 @@ def _transform_cupy_array(data): return data -def _from_cupy_array(data, missing, nthread, feature_names, feature_types): +def _from_cupy_array( + data, + missing, + nthread, + feature_names: Optional[List[str]], + feature_types: Optional[List[str]], +): """Initialize DMatrix from cupy ndarray.""" data = _transform_cupy_array(data) interface_str = _cuda_array_interface(data) @@ -505,7 +572,13 @@ def _transform_dlpack(data): return data -def _from_dlpack(data, missing, nthread, feature_names, feature_types): +def _from_dlpack( + data, + missing, + nthread, + feature_names: Optional[List[str]], + feature_types: Optional[List[str]], +): data = _transform_dlpack(data) return _from_cupy_array(data, missing, nthread, feature_names, feature_types) @@ -515,7 +588,12 @@ def _is_uri(data): return isinstance(data, (str, os.PathLike)) -def _from_uri(data, missing, feature_names, feature_types): +def _from_uri( + data, + missing, + feature_names: Optional[List[str]], + feature_types: Optional[List[str]], +): _warn_unused_missing(data, missing) handle = ctypes.c_void_p() data = os.fspath(os.path.expanduser(data)) @@ -529,7 +607,13 @@ def _is_list(data): return isinstance(data, list) -def _from_list(data, missing, n_threads, feature_names, feature_types): +def _from_list( + data, + missing, + n_threads, + feature_names: Optional[List[str]], + feature_types: Optional[List[str]], +): array = np.array(data) _check_data_shape(data) return _from_numpy_array(array, missing, n_threads, feature_names, feature_types) @@ -539,7 +623,13 @@ def _is_tuple(data): return isinstance(data, tuple) -def _from_tuple(data, missing, n_threads, feature_names, feature_types): +def _from_tuple( + data, + missing, + n_threads, + feature_names: Optional[List[str]], + feature_types: Optional[List[str]], +): return _from_list(data, missing, n_threads, feature_names, feature_types) @@ -569,9 +659,14 @@ def _convert_unknown_data(data): return data -def dispatch_data_backend(data, missing, threads, - feature_names, feature_types, - enable_categorical=False): +def dispatch_data_backend( + data, + missing, + threads, + feature_names: Optional[List[str]], + feature_types: Optional[List[str]], + enable_categorical: bool = False, +): '''Dispatch data for DMatrix.''' if not _is_cudf_ser(data) and not _is_pandas_series(data): _check_data_shape(data) @@ -580,7 +675,9 @@ def dispatch_data_backend(data, missing, threads, if _is_scipy_csc(data): return _from_scipy_csc(data, missing, feature_names, feature_types) if _is_scipy_coo(data): - return _from_scipy_csr(data.tocsr(), missing, threads, feature_names, feature_types) + return _from_scipy_csr( + data.tocsr(), missing, threads, feature_names, feature_types + ) if _is_numpy_array(data): return _from_numpy_array(data, missing, threads, feature_names, feature_types) @@ -612,8 +709,9 @@ def dispatch_data_backend(data, missing, threads, feature_types) if _is_dt_df(data): _warn_unused_missing(data, missing) - return _from_dt_df(data, missing, threads, feature_names, - feature_types) + return _from_dt_df( + data, missing, threads, feature_names, feature_types, enable_categorical + ) if _is_modin_df(data): return _from_pandas_df(data, enable_categorical, missing, threads, feature_names, feature_types) @@ -791,7 +889,12 @@ def reset(self) -> None: self.it = 0 -def _proxy_transform(data, feature_names, feature_types, enable_categorical): +def _proxy_transform( + data, + feature_names: Optional[List[str]], + feature_types: Optional[List[str]], + enable_categorical: bool, +): if _is_cudf_df(data) or _is_cudf_ser(data): return _transform_cudf_df( data, feature_names, feature_types, enable_categorical diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py index ed91c4b77f9c..ed115200b0d7 100644 --- a/python-package/xgboost/sklearn.py +++ b/python-package/xgboost/sklearn.py @@ -174,8 +174,7 @@ def inner(preds: np.ndarray, dmatrix: DMatrix) -> Tuple[np.ndarray, np.ndarray]: .. versionadded:: 1.5.0 Experimental support for categorical data. Do not set to true unless you are - interested in development. Only valid when `gpu_hist` and pandas dataframe are - used. + interested in development. Only valid when `gpu_hist` and dataframe are used. kwargs : dict, optional Keyword arguments for XGBoost Booster object. Full documentation of diff --git a/src/data/data.cc b/src/data/data.cc index a0504f4d5c8b..2ef5e2a1dbdb 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -200,10 +200,10 @@ void LoadFeatureType(std::vectorconst& type_names, std::vectoremplace_back(FeatureType::kNumerical); } else if (elem == "q") { types->emplace_back(FeatureType::kNumerical); - } else if (elem == "categorical") { + } else if (elem == "c") { types->emplace_back(FeatureType::kCategorical); } else { - LOG(FATAL) << "All feature_types must be one of {int, float, i, q, categorical}."; + LOG(FATAL) << "All feature_types must be one of {int, float, i, q, c}."; } } } diff --git a/tests/cpp/tree/test_tree_model.cc b/tests/cpp/tree/test_tree_model.cc index 4f20ba69bf08..a1e64f542b20 100644 --- a/tests/cpp/tree/test_tree_model.cc +++ b/tests/cpp/tree/test_tree_model.cc @@ -285,7 +285,7 @@ void TestCategoricalTreeDump(std::string format, std::string sep) { pos = str.find(cond_str, pos + 1); ASSERT_NE(pos, std::string::npos); - fmap.PushBack(0, "feat_0", "categorical"); + fmap.PushBack(0, "feat_0", "c"); fmap.PushBack(1, "feat_1", "q"); fmap.PushBack(2, "feat_2", "int"); diff --git a/tests/python-gpu/test_from_cudf.py b/tests/python-gpu/test_from_cudf.py index 36bf6023071c..fc46d7502044 100644 --- a/tests/python-gpu/test_from_cudf.py +++ b/tests/python-gpu/test_from_cudf.py @@ -172,7 +172,7 @@ def test_cudf_metainfo_device_dmatrix(self): _test_cudf_metainfo(xgb.DeviceQuantileDMatrix) @pytest.mark.skipif(**tm.no_cudf()) - def test_categorical(self): + def test_cudf_categorical(self): import cudf _X, _y = tm.make_categorical(100, 30, 17, False) X = cudf.from_pandas(_X) @@ -180,11 +180,11 @@ def test_categorical(self): Xy = xgb.DMatrix(X, y, enable_categorical=True) assert len(Xy.feature_types) == X.shape[1] - assert all(t == "categorical" for t in Xy.feature_types) + assert all(t == "c" for t in Xy.feature_types) Xy = xgb.DeviceQuantileDMatrix(X, y, enable_categorical=True) assert len(Xy.feature_types) == X.shape[1] - assert all(t == "categorical" for t in Xy.feature_types) + assert all(t == "c" for t in Xy.feature_types) @pytest.mark.skipif(**tm.no_cudf()) diff --git a/tests/python-gpu/test_from_cupy.py b/tests/python-gpu/test_from_cupy.py index 60b73e67546c..d0504e575e8f 100644 --- a/tests/python-gpu/test_from_cupy.py +++ b/tests/python-gpu/test_from_cupy.py @@ -169,6 +169,19 @@ def test_dlpack_simple_dmat(self): X = cp.random.random((n, 2)) xgb.DMatrix(X.toDlpack()) + @pytest.mark.skipif(**tm.no_cupy()) + def test_cupy_categorical(self): + import cupy as cp + n_features = 10 + X, y = tm.make_categorical(10, n_features, n_categories=4, onehot=False) + X = cp.asarray(X.values.astype(cp.float32)) + y = cp.array(y) + feature_types = ['c'] * n_features + + assert isinstance(X, cp.ndarray) + Xy = xgb.DMatrix(X, y, feature_types=feature_types) + np.testing.assert_equal(np.array(Xy.feature_types), np.array(feature_types)) + @pytest.mark.skipif(**tm.no_cupy()) def test_dlpack_device_dmat(self): import cupy as cp diff --git a/tests/python/test_dmatrix.py b/tests/python/test_dmatrix.py index c3e6a0dadfff..313cf822fb98 100644 --- a/tests/python/test_dmatrix.py +++ b/tests/python/test_dmatrix.py @@ -339,3 +339,44 @@ class Data: Xy = xgb.DMatrix(X, y) assert Xy.num_row() == 10 assert Xy.num_col() == 10 + + @pytest.mark.skipif(**tm.no_pandas()) + def test_np_categorical(self): + n_features = 10 + X, y = tm.make_categorical(10, n_features, n_categories=4, onehot=False) + X = X.values.astype(np.float32) + feature_types = ['c'] * n_features + + assert isinstance(X, np.ndarray) + Xy = xgb.DMatrix(X, y, feature_types=feature_types) + np.testing.assert_equal(np.array(Xy.feature_types), np.array(feature_types)) + + def test_scipy_categorical(self): + from scipy import sparse + n_features = 10 + X, y = tm.make_categorical(10, n_features, n_categories=4, onehot=False) + X = X.values.astype(np.float32) + feature_types = ['c'] * n_features + + X[1, 3] = np.NAN + X[2, 4] = np.NAN + X = sparse.csr_matrix(X) + + Xy = xgb.DMatrix(X, y, feature_types=feature_types) + np.testing.assert_equal(np.array(Xy.feature_types), np.array(feature_types)) + + X = sparse.csc_matrix(X) + + Xy = xgb.DMatrix(X, y, feature_types=feature_types) + np.testing.assert_equal(np.array(Xy.feature_types), np.array(feature_types)) + + X = sparse.coo_matrix(X) + + Xy = xgb.DMatrix(X, y, feature_types=feature_types) + np.testing.assert_equal(np.array(Xy.feature_types), np.array(feature_types)) + + def test_uri_categorical(self): + path = os.path.join(dpath, 'agaricus.txt.train') + feature_types = ["q"] * 5 + ["c"] + ["q"] * 120 + Xy = xgb.DMatrix(path + "?indexing_mode=1", feature_types=feature_types) + np.testing.assert_equal(np.array(Xy.feature_types), np.array(feature_types)) diff --git a/tests/python/test_with_pandas.py b/tests/python/test_with_pandas.py index c9400bfa697f..dbd8ceb7e637 100644 --- a/tests/python/test_with_pandas.py +++ b/tests/python/test_with_pandas.py @@ -128,7 +128,7 @@ def test_pandas_categorical(self): X = pd.DataFrame({'f0': X}) y = rng.randn(rows) m = xgb.DMatrix(X, y, enable_categorical=True) - assert m.feature_types[0] == 'categorical' + assert m.feature_types[0] == 'c' def test_pandas_sparse(self): import pandas as pd