diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index 9a9177ff3d00..cf415b9e9afc 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -231,17 +231,6 @@ def _numpy2ctypes_type(dtype): return _NUMPY_TO_CTYPES_MAPPING[dtype] -def _array_interface(data: np.ndarray) -> bytes: - assert ( - data.dtype.hasobject is False - ), "Input data contains `object` dtype. Expecting numeric data." - interface = data.__array_interface__ - if "mask" in interface: - interface["mask"] = interface["mask"].__array_interface__ - interface_str = bytes(json.dumps(interface), "utf-8") - return interface_str - - def _cuda_array_interface(data) -> bytes: assert ( data.dtype.hasobject is False @@ -353,11 +342,17 @@ def next_wrapper(self, this): # pylint: disable=unused-argument if self.exception is not None: return 0 - def data_handle(data, feature_names=None, feature_types=None, **kwargs): + def data_handle( + data, + feature_names=None, + feature_types=None, + enable_categorical=False, + **kwargs + ): from .data import dispatch_device_quantile_dmatrix_set_data from .data import _device_quantile_transform data, feature_names, feature_types = _device_quantile_transform( - data, feature_names, feature_types + data, feature_names, feature_types, enable_categorical, ) dispatch_device_quantile_dmatrix_set_data(self.proxy, data) self.proxy.set_info( @@ -1023,7 +1018,7 @@ def _set_data_from_cuda_interface(self, data): def _set_data_from_cuda_columnar(self, data): '''Set data from CUDA columnar format.1''' from .data import _cudf_array_interfaces - interfaces_str = _cudf_array_interfaces(data) + _, interfaces_str = _cudf_array_interfaces(data) _check_call( _LIB.XGDeviceQuantileDMatrixSetDataCudaColumnar( self.handle, @@ -1076,10 +1071,6 @@ def __init__( # pylint: disable=super-init-not-called self.handle = data return - if enable_categorical: - raise NotImplementedError( - 'categorical support is not enabled on DeviceQuantileDMatrix.' - ) if qid is not None and group is not None: raise ValueError( 'Only one of the eval_qid or eval_group for each evaluation ' @@ -1098,9 +1089,10 @@ def __init__( # pylint: disable=super-init-not-called feature_weights=feature_weights, feature_names=feature_names, feature_types=feature_types, + enable_categorical=enable_categorical, ) - def _init(self, data, feature_names, feature_types, **meta): + def _init(self, data, enable_categorical, **meta): from .data import ( _is_dlpack, _transform_dlpack, @@ -1114,9 +1106,13 @@ def _init(self, data, feature_names, feature_types, **meta): data = _transform_dlpack(data) if _is_iter(data): it = data + if enable_categorical: + raise NotImplementedError( + "categorical support is not enabled on data iterator." + ) else: it = SingleBatchInternalIter( - data, **meta, feature_names=feature_names, feature_types=feature_types + data=data, enable_categorical=enable_categorical, **meta ) reset_callback = ctypes.CFUNCTYPE(None, ctypes.c_void_p)(it.reset_wrapper) @@ -1920,6 +1916,7 @@ def inplace_predict( f"got {data.shape[1]}" ) + from .data import _array_interface if isinstance(data, np.ndarray): from .data import _ensure_np_dtype data, _ = _ensure_np_dtype(data, data.dtype) @@ -1974,7 +1971,7 @@ def inplace_predict( if lazy_isinstance(data, "cudf.core.dataframe", "DataFrame"): from .data import _cudf_array_interfaces - interfaces_str = _cudf_array_interfaces(data) + _, interfaces_str = _cudf_array_interfaces(data) _check_call( _LIB.XGBoosterPredictFromCudaColumnar( self.handle, diff --git a/python-package/xgboost/data.py b/python-package/xgboost/data.py index 2b0d0bb1f9ef..655051ad5c79 100644 --- a/python-package/xgboost/data.py +++ b/python-package/xgboost/data.py @@ -5,12 +5,12 @@ import json import warnings import os -from typing import Any +from typing import Any, Tuple import numpy as np from .core import c_array, _LIB, _check_call, c_str -from .core import _array_interface, _cuda_array_interface +from .core import _cuda_array_interface from .core import DataIter, _ProxyDMatrix, DMatrix from .compat import lazy_isinstance @@ -41,6 +41,17 @@ def _is_scipy_csr(data): return isinstance(data, scipy.sparse.csr_matrix) +def _array_interface(data: np.ndarray) -> bytes: + assert ( + data.dtype.hasobject is False + ), "Input data contains `object` dtype. Expecting numeric data." + interface = data.__array_interface__ + if "mask" in interface: + interface["mask"] = interface["mask"].__array_interface__ + interface_str = bytes(json.dumps(interface), "utf-8") + return interface_str + + def _from_scipy_csr(data, missing, nthread, feature_names, feature_types): """Initialize data from a CSR matrix.""" if len(data.indices) != len(data.data): @@ -179,7 +190,7 @@ def _is_modin_df(data): 'float16': 'float', 'float32': 'float', 'float64': 'float', - 'bool': 'i' + 'bool': 'i', } @@ -349,54 +360,73 @@ def _is_cudf_df(data): return hasattr(cudf, 'DataFrame') and isinstance(data, cudf.DataFrame) -def _cudf_array_interfaces(data): - '''Extract CuDF __cuda_array_interface__''' +def _cudf_array_interfaces(data) -> Tuple[list, list]: + """Extract CuDF __cuda_array_interface__. This is special as it returns a new list of + data and a list of array interfaces. The data is list of categorical codes that + caller can safely ignore, but have to keep their reference alive until usage of array + interface is finished. + + """ + from cudf.utils.dtypes import is_categorical_dtype + cat_codes = [] interfaces = [] if _is_cudf_ser(data): interfaces.append(data.__cuda_array_interface__) else: for col in data: - interface = data[col].__cuda_array_interface__ - if 'mask' in interface: - interface['mask'] = interface['mask'].__cuda_array_interface__ + if is_categorical_dtype(data[col].dtype): + codes = data[col].cat.codes + interface = codes.__cuda_array_interface__ + cat_codes.append(codes) + else: + interface = data[col].__cuda_array_interface__ + if "mask" in interface: + interface["mask"] = interface["mask"].__cuda_array_interface__ interfaces.append(interface) - interfaces_str = bytes(json.dumps(interfaces, indent=2), 'utf-8') - return interfaces_str + interfaces_str = bytes(json.dumps(interfaces, indent=2), "utf-8") + return cat_codes, interfaces_str + +def _transform_cudf_df(data, feature_names, feature_types, enable_categorical): + from cudf.utils.dtypes import is_categorical_dtype -def _transform_cudf_df(data, feature_names, feature_types): if feature_names is None: if _is_cudf_ser(data): feature_names = [data.name] - elif lazy_isinstance( - data.columns, 'cudf.core.multiindex', 'MultiIndex'): - feature_names = [ - ' '.join([str(x) for x in i]) - for i in data.columns - ] + elif lazy_isinstance(data.columns, "cudf.core.multiindex", "MultiIndex"): + feature_names = [" ".join([str(x) for x in i]) for i in data.columns] else: feature_names = data.columns.format() if feature_types is None: + feature_types = [] if _is_cudf_ser(data): dtypes = [data.dtype] else: dtypes = data.dtypes - feature_types = [_pandas_dtype_mapper[d.name] - for d in dtypes] + for dtype in dtypes: + if is_categorical_dtype(dtype) and enable_categorical: + feature_types.append("categorical") + else: + feature_types.append(_pandas_dtype_mapper[dtype.name]) return data, feature_names, feature_types -def _from_cudf_df(data, missing, nthread, feature_names, feature_types): +def _from_cudf_df( + data, missing, nthread, feature_names, feature_types, enable_categorical +): data, feature_names, feature_types = _transform_cudf_df( - data, feature_names, feature_types) - interfaces_str = _cudf_array_interfaces(data) + data, feature_names, feature_types, enable_categorical + ) + _, interfaces_str = _cudf_array_interfaces(data) handle = ctypes.c_void_p() _check_call( _LIB.XGDMatrixCreateFromArrayInterfaceColumns( interfaces_str, ctypes.c_float(missing), ctypes.c_int(nthread), - ctypes.byref(handle))) + ctypes.byref(handle), + ) + ) return handle, feature_names, feature_types @@ -554,12 +584,10 @@ def dispatch_data_backend(data, missing, threads, if _is_pandas_series(data): return _from_pandas_series(data, missing, threads, feature_names, feature_types) - if _is_cudf_df(data): - return _from_cudf_df(data, missing, threads, feature_names, - feature_types) - if _is_cudf_ser(data): - return _from_cudf_df(data, missing, threads, feature_names, - feature_types) + if _is_cudf_df(data) or _is_cudf_ser(data): + return _from_cudf_df( + data, missing, threads, feature_names, feature_types, enable_categorical + ) if _is_cupy_array(data): return _from_cupy_array(data, missing, threads, feature_names, feature_types) @@ -731,30 +759,8 @@ class SingleBatchInternalIter(DataIter): # pylint: disable=R0902 area for meta info. ''' - def __init__( - self, data, - label, - weight, - base_margin, - group, - qid, - label_lower_bound, - label_upper_bound, - feature_weights, - feature_names, - feature_types - ): - self.data = data - self.label = label - self.weight = weight - self.base_margin = base_margin - self.group = group - self.qid = qid - self.label_lower_bound = label_lower_bound - self.label_upper_bound = label_upper_bound - self.feature_weights = feature_weights - self.feature_names = feature_names - self.feature_types = feature_types + def __init__(self, **kwargs): + self.kwargs = kwargs self.it = 0 # pylint: disable=invalid-name super().__init__() @@ -762,33 +768,24 @@ def next(self, input_data): if self.it == 1: return 0 self.it += 1 - input_data(data=self.data, label=self.label, - weight=self.weight, base_margin=self.base_margin, - group=self.group, - qid=self.qid, - label_lower_bound=self.label_lower_bound, - label_upper_bound=self.label_upper_bound, - feature_weights=self.feature_weights, - feature_names=self.feature_names, - feature_types=self.feature_types) + input_data(**self.kwargs) return 1 def reset(self): self.it = 0 -def _device_quantile_transform(data, feature_names, feature_types): - if _is_cudf_df(data): - return _transform_cudf_df(data, feature_names, feature_types) - if _is_cudf_ser(data): - return _transform_cudf_df(data, feature_names, feature_types) +def _device_quantile_transform(data, feature_names, feature_types, enable_categorical): + if _is_cudf_df(data) or _is_cudf_ser(data): + return _transform_cudf_df( + data, feature_names, feature_types, enable_categorical + ) if _is_cupy_array(data): data = _transform_cupy_array(data) return data, feature_names, feature_types if _is_dlpack(data): return _transform_dlpack(data), feature_names, feature_types - raise TypeError('Value type is not supported for data iterator:' + - str(type(data))) + raise TypeError("Value type is not supported for data iterator:" + str(type(data))) def dispatch_device_quantile_dmatrix_set_data(proxy: _ProxyDMatrix, data: Any) -> None: diff --git a/tests/python-gpu/test_from_cudf.py b/tests/python-gpu/test_from_cudf.py index cfda0a8e0db1..ca934bcdb6da 100644 --- a/tests/python-gpu/test_from_cudf.py +++ b/tests/python-gpu/test_from_cudf.py @@ -171,6 +171,21 @@ def test_cudf_metainfo_simple_dmatrix(self): def test_cudf_metainfo_device_dmatrix(self): _test_cudf_metainfo(xgb.DeviceQuantileDMatrix) + @pytest.mark.skipif(**tm.no_cudf()) + def test_categorical(self): + import cudf + _X, _y = tm.make_categorical(100, 30, 17, False) + X = cudf.from_pandas(_X) + y = cudf.from_pandas(_y) + + Xy = xgb.DMatrix(X, y, enable_categorical=True) + assert len(Xy.feature_types) == X.shape[1] + assert all(t == "categorical" for t in Xy.feature_types) + + Xy = xgb.DeviceQuantileDMatrix(X, y, enable_categorical=True) + assert len(Xy.feature_types) == X.shape[1] + assert all(t == "categorical" for t in Xy.feature_types) + @pytest.mark.skipif(**tm.no_cudf()) @pytest.mark.skipif(**tm.no_cupy()) diff --git a/tests/python-gpu/test_gpu_updaters.py b/tests/python-gpu/test_gpu_updaters.py index 9e96b7c21d5f..dd2dd1973e9f 100644 --- a/tests/python-gpu/test_gpu_updaters.py +++ b/tests/python-gpu/test_gpu_updaters.py @@ -43,22 +43,8 @@ def test_gpu_hist(self, param, num_rounds, dataset): assert tm.non_increasing(result['train'][dataset.metric]) def run_categorical_basic(self, rows, cols, rounds, cats): - import pandas as pd - rng = np.random.RandomState(1994) - - pd_dict = {} - for i in range(cols): - c = rng.randint(low=0, high=cats+1, size=rows) - pd_dict[str(i)] = pd.Series(c, dtype=np.int64) - - df = pd.DataFrame(pd_dict) - label = df.iloc[:, 0] - for i in range(0, cols-1): - label += df.iloc[:, i] - label += 1 - df = df.astype('category') - onehot = pd.get_dummies(df) - cat = df + onehot, label = tm.make_categorical(rows, cols, cats, True) + cat, _ = tm.make_categorical(rows, cols, cats, False) by_etl_results = {} by_builtin_results = {} diff --git a/tests/python/testing.py b/tests/python/testing.py index 4b2b31e09aa0..2feeaf0a0e40 100644 --- a/tests/python/testing.py +++ b/tests/python/testing.py @@ -234,6 +234,34 @@ def get_mq2008(dpath): x_valid, y_valid, qid_valid) +@memory.cache +def make_categorical( + n_samples: int, n_features: int, n_categories: int, onehot_enc: bool +): + import pandas as pd + + rng = np.random.RandomState(1994) + + pd_dict = {} + for i in range(n_features + 1): + c = rng.randint(low=0, high=n_categories + 1, size=n_samples) + pd_dict[str(i)] = pd.Series(c, dtype=np.int64) + + df = pd.DataFrame(pd_dict) + label = df.iloc[:, 0] + df = df.iloc[:, 1:] + for i in range(0, n_features): + label += df.iloc[:, i] + label += 1 + + df = df.astype("category") + if onehot_enc: + cat = pd.get_dummies(df) + else: + cat = df + return cat, label + + _unweighted_datasets_strategy = strategies.sampled_from( [TestDataset('boston', get_boston, 'reg:squarederror', 'rmse'), TestDataset('digits', get_digits, 'multi:softmax', 'mlogloss'),