From c91bc5c823c5514f8d4886901baebb320f0b7cc0 Mon Sep 17 00:00:00 2001 From: fis Date: Sun, 17 Oct 2021 22:12:03 +0800 Subject: [PATCH 01/16] Handle missing values in dataframe with category dtype. * Replace -1 in pandas/cudf initializer. * Unify `IsValid` functor. * Mimic pandas data handling in cuDF glue code. * Check invalid categories. --- python-package/xgboost/core.py | 21 ++++--- python-package/xgboost/data.py | 91 ++++++++++++++++++++---------- src/common/categorical.h | 7 ++- src/data/adapter.h | 19 +++++++ src/data/data.cc | 7 ++- src/data/device_adapter.cuh | 23 -------- src/tree/updater_gpu_hist.cu | 1 + tests/python-gpu/test_from_cudf.py | 6 ++ tests/python/test_with_pandas.py | 6 ++ 9 files changed, 116 insertions(+), 65 deletions(-) diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index 05b354e21930..8dd9f9dd4144 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -397,7 +397,7 @@ def data_handle( feature_names: Optional[List[str]] = None, feature_types: Optional[List[str]] = None, **kwargs: Any, - ): + ) -> None: from .data import dispatch_proxy_set_data from .data import _proxy_transform @@ -409,7 +409,9 @@ def data_handle( ) # Stage the data, meta info are copied inside C++ MetaInfo. self._temporary_data = transformed - dispatch_proxy_set_data(self.proxy, transformed, self._allow_host) + dispatch_proxy_set_data( + self.proxy, transformed, self._allow_host, self._enable_categorical + ) self.proxy.set_info( feature_names=feature_names, feature_types=feature_types, @@ -1090,7 +1092,7 @@ def __init__(self): # pylint: disable=super-init-not-called self.handle = ctypes.c_void_p() _check_call(_LIB.XGProxyDMatrixCreate(ctypes.byref(self.handle))) - def _set_data_from_cuda_interface(self, data): + def _set_data_from_cuda_interface(self, data) -> None: """Set data from CUDA array interface.""" interface = data.__cuda_array_interface__ interface_str = bytes(json.dumps(interface, indent=2), "utf-8") @@ -1098,11 +1100,11 @@ def _set_data_from_cuda_interface(self, data): _LIB.XGProxyDMatrixSetDataCudaArrayInterface(self.handle, interface_str) ) - def _set_data_from_cuda_columnar(self, data): + def _set_data_from_cuda_columnar(self, data, enable_categorical: bool) -> None: """Set data from CUDA columnar format.""" from .data import _cudf_array_interfaces - _, interfaces_str = _cudf_array_interfaces(data) + _, interfaces_str = _cudf_array_interfaces(data, enable_categorical) _check_call(_LIB.XGProxyDMatrixSetDataCudaColumnar(self.handle, interfaces_str)) def _set_data_from_array(self, data: np.ndarray): @@ -2009,13 +2011,18 @@ def inplace_predict( from .data import _is_pandas_df, _transform_pandas_df from .data import _array_interface - if _is_pandas_df(data): + if ( + _is_pandas_df(data) + or lazy_isinstance(data, "cudf.core.dataframe", "DataFrame") + ): ft = self.feature_types if ft is None: enable_categorical = False else: enable_categorical = any(f == "c" for f in ft) + if _is_pandas_df(data): data, _, _ = _transform_pandas_df(data, enable_categorical) + if isinstance(data, np.ndarray): from .data import _ensure_np_dtype data, _ = _ensure_np_dtype(data, data.dtype) @@ -2069,7 +2076,7 @@ def inplace_predict( return _prediction_output(shape, dims, preds, True) if lazy_isinstance(data, "cudf.core.dataframe", "DataFrame"): from .data import _cudf_array_interfaces - _, interfaces_str = _cudf_array_interfaces(data) + _, interfaces_str = _cudf_array_interfaces(data, enable_categorical) _check_call( _LIB.XGBoosterPredictFromCudaColumnar( self.handle, diff --git a/python-package/xgboost/data.py b/python-package/xgboost/data.py index 8908a3a58b62..f3f18fb85e30 100644 --- a/python-package/xgboost/data.py +++ b/python-package/xgboost/data.py @@ -12,7 +12,7 @@ from .core import c_array, _LIB, _check_call, c_str from .core import _cuda_array_interface from .core import DataIter, _ProxyDMatrix, DMatrix -from .compat import lazy_isinstance +from .compat import lazy_isinstance, DataFrame c_bst_ulong = ctypes.c_uint64 # pylint: disable=invalid-name @@ -217,36 +217,42 @@ def _is_modin_df(data): } +def _invalid_dataframe_dtype(data) -> None: + bad_fields = [ + str(data.columns[i]) + for i, dtype in enumerate(data.dtypes) + if dtype.name not in _pandas_dtype_mapper + ] + + msg = """DataFrame.dtypes for data must be int, float, bool or category. When + categorical type is supplied, DMatrix parameter `enable_categorical` must + be set to `True`.""" + raise ValueError(msg + ", ".join(bad_fields)) + + def _transform_pandas_df( - data, + data: DataFrame, enable_categorical: bool, feature_names: Optional[List[str]] = None, feature_types: Optional[List[str]] = None, - meta=None, - meta_type=None, -): + meta: Optional[str] = None, + meta_type: Optional[str] = None, +) -> Tuple[np.ndarray, Optional[List[str]], Optional[List[str]]]: import pandas as pd from pandas.api.types import is_sparse, is_categorical_dtype - if not all(dtype.name in _pandas_dtype_mapper or is_sparse(dtype) or - (is_categorical_dtype(dtype) and enable_categorical) - for dtype in data.dtypes): - bad_fields = [ - str(data.columns[i]) for i, dtype in enumerate(data.dtypes) - if dtype.name not in _pandas_dtype_mapper - ] - - msg = """DataFrame.dtypes for data must be int, float, bool or category. When - categorical type is supplied, DMatrix parameter `enable_categorical` must - be set to `True`.""" - raise ValueError(msg + ', '.join(bad_fields)) + if not all( + dtype.name in _pandas_dtype_mapper + or is_sparse(dtype) + or (is_categorical_dtype(dtype) and enable_categorical) + for dtype in data.dtypes + ): + _invalid_dataframe_dtype(data) # handle feature names if feature_names is None and meta is None: if isinstance(data.columns, pd.MultiIndex): - feature_names = [ - ' '.join([str(x) for x in i]) for i in data.columns - ] + feature_names = [" ".join([str(x) for x in i]) for i in data.columns] elif isinstance(data.columns, (pd.Int64Index, pd.RangeIndex)): feature_names = list(map(str, data.columns)) else: @@ -263,21 +269,22 @@ def _transform_pandas_df( else: feature_types.append(_pandas_dtype_mapper[dtype.name]) - # handle categorical codes. + # handle category codes. transformed = pd.DataFrame() if enable_categorical: for i, dtype in enumerate(data.dtypes): if is_categorical_dtype(dtype): - transformed[data.columns[i]] = data[data.columns[i]].cat.codes + # pandas uses -1 as default missing value for categorical data + transformed[data.columns[i]] = data[data.columns[i]].cat.codes.replace( + -1, np.NaN + ) else: transformed[data.columns[i]] = data[data.columns[i]] else: transformed = data if meta and len(data.columns) > 1: - raise ValueError( - f"DataFrame for {meta} cannot have multiple columns" - ) + raise ValueError(f"DataFrame for {meta} cannot have multiple columns") dtype = meta_type if meta_type else np.float32 arr = transformed.values @@ -287,7 +294,7 @@ def _transform_pandas_df( def _from_pandas_df( - data, + data: DataFrame, enable_categorical: bool, missing: float, nthread: int, @@ -299,6 +306,7 @@ def _from_pandas_df( ) return _from_numpy_array(data, missing, nthread, feature_names, feature_types) + def _is_pandas_series(data): try: import pandas as pd @@ -427,7 +435,7 @@ def _is_cudf_df(data): return hasattr(cudf, 'DataFrame') and isinstance(data, cudf.DataFrame) -def _cudf_array_interfaces(data) -> Tuple[list, bytes]: +def _cudf_array_interfaces(data, enable_categorical: bool) -> Tuple[list, bytes]: """Extract CuDF __cuda_array_interface__. This is special as it returns a new list of data and a list of array interfaces. The data is list of categorical codes that caller can safely ignore, but have to keep their reference alive until usage of array @@ -445,8 +453,8 @@ def _cudf_array_interfaces(data) -> Tuple[list, bytes]: interfaces.append(data.__cuda_array_interface__) else: for col in data: - if is_categorical_dtype(data[col].dtype): - codes = data[col].cat.codes + if is_categorical_dtype(data[col].dtype) and enable_categorical: + codes = data[col].cat.codes.replace(-1, np.NaN) interface = codes.__cuda_array_interface__ cat_codes.append(codes) else: @@ -469,13 +477,31 @@ def _transform_cudf_df( except ImportError: from cudf.utils.dtypes import is_categorical_dtype + # FIXME(jiamingy): Handle `is_sparse` once we have support for sparse DF. + if not all( + dtype.name in _pandas_dtype_mapper + or (is_categorical_dtype(dtype) and enable_categorical) + for dtype in data.dtypes + ): + _invalid_dataframe_dtype(data) + + # handle feature names if feature_names is None: if _is_cudf_ser(data): feature_names = [data.name] elif lazy_isinstance(data.columns, "cudf.core.multiindex", "MultiIndex"): feature_names = [" ".join([str(x) for x in i]) for i in data.columns] + elif ( + lazy_isinstance(data.columns, "cudf.core.index", "RangeIndex") + or lazy_isinstance(data.columns, "cudf.core.index", "Int64Index") + # Unique to cuDF, no equivalence in pandas 1.3.3 + or lazy_isinstance(data.columns, "cudf.core.index", "Int32Index") + ): + feature_names = list(map(str, data.columns)) else: feature_names = data.columns.format() + + # handle feature types if feature_types is None: feature_types = [] if _is_cudf_ser(data): @@ -487,6 +513,7 @@ def _transform_cudf_df( feature_types.append(CAT_T) else: feature_types.append(_pandas_dtype_mapper[dtype.name]) + return data, feature_names, feature_types @@ -501,7 +528,7 @@ def _from_cudf_df( data, feature_names, feature_types = _transform_cudf_df( data, feature_names, feature_types, enable_categorical ) - _, interfaces_str = _cudf_array_interfaces(data) + _, interfaces_str = _cudf_array_interfaces(data, enable_categorical) handle = ctypes.c_void_p() config = bytes(json.dumps({"missing": missing, "nthread": nthread}), "utf-8") _check_call( @@ -931,7 +958,9 @@ def _proxy_transform( raise TypeError("Value type is not supported for data iterator:" + str(type(data))) -def dispatch_proxy_set_data(proxy: _ProxyDMatrix, data: Any, allow_host: bool) -> None: +def dispatch_proxy_set_data( + proxy: _ProxyDMatrix, data: Any, allow_host: bool, enable_categorical: bool +) -> None: """Dispatch for DeviceQuantileDMatrix.""" if not _is_cudf_ser(data) and not _is_pandas_series(data): _check_data_shape(data) diff --git a/src/common/categorical.h b/src/common/categorical.h index 371ae1bd6d8d..3706c4f2370d 100644 --- a/src/common/categorical.h +++ b/src/common/categorical.h @@ -1,5 +1,5 @@ /*! - * Copyright 2020 by XGBoost Contributors + * Copyright 2020-2021 by XGBoost Contributors * \file categorical.h */ #ifndef XGBOOST_COMMON_CATEGORICAL_H_ @@ -42,6 +42,11 @@ inline XGBOOST_DEVICE bool Decision(common::Span cats, bst_cat_t return !s_cats.Check(cat); } +inline void CheckCat(bst_cat_t cat) { + CHECK_GE(cat, 0) << "Invalid categorical value detected. Categorical value " + "should be non-negative."; +} + struct IsCatOp { XGBOOST_DEVICE bool operator()(FeatureType ft) { return ft == FeatureType::kCategorical; diff --git a/src/data/adapter.h b/src/data/adapter.h index 8502ebd3432f..27da8c6e3b36 100644 --- a/src/data/adapter.h +++ b/src/data/adapter.h @@ -21,6 +21,7 @@ #include "array_interface.h" #include "../c_api/c_api_error.h" +#include "../common/math.h" namespace xgboost { namespace data { @@ -80,6 +81,24 @@ struct COOTuple { float value{0}; }; +struct IsValidFunctor { + float missing; + + XGBOOST_DEVICE explicit IsValidFunctor(float missing) : missing(missing) {} + + XGBOOST_DEVICE bool operator()(float value) const { + return !(common::CheckNAN(value) || value == missing); + } + + XGBOOST_DEVICE bool operator()(const data::COOTuple& e) const { + return !(common::CheckNAN(e.value) || e.value == missing); + } + + XGBOOST_DEVICE bool operator()(const Entry& e) const { + return !(common::CheckNAN(e.fvalue) || e.fvalue == missing); + } +}; + namespace detail { /** diff --git a/src/data/data.cc b/src/data/data.cc index 2ef5e2a1dbdb..7011d558294c 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -987,18 +987,19 @@ uint64_t SparsePage::Push(const AdapterBatchT& batch, float missing, int nthread // Second pass over batch, placing elements in correct position + auto is_valid = data::IsValidFunctor{missing}; #pragma omp parallel num_threads(nthread) { exec.Run([&]() { int tid = omp_get_thread_num(); - size_t begin = tid*thread_size; - size_t end = tid != (nthread-1) ? (tid+1)*thread_size : batch_size; + size_t begin = tid * thread_size; + size_t end = tid != (nthread - 1) ? (tid + 1) * thread_size : batch_size; for (size_t i = begin; i < end; ++i) { auto line = batch.GetLine(i); for (auto j = 0ull; j < line.Size(); j++) { auto element = line.GetElement(j); const size_t key = (element.row_idx - base_rowid); - if (!common::CheckNAN(element.value) && element.value != missing) { + if (is_valid(element)) { builder.Push(key, Entry(element.column_idx, element.value), tid); } } diff --git a/src/data/device_adapter.cuh b/src/data/device_adapter.cuh index 2da786969fcf..628878f319f1 100644 --- a/src/data/device_adapter.cuh +++ b/src/data/device_adapter.cuh @@ -15,29 +15,6 @@ namespace xgboost { namespace data { -struct IsValidFunctor : public thrust::unary_function { - float missing; - - XGBOOST_DEVICE explicit IsValidFunctor(float missing) : missing(missing) {} - - __device__ bool operator()(float value) const { - return !(common::CheckNAN(value) || value == missing); - } - - __device__ bool operator()(const data::COOTuple& e) const { - if (common::CheckNAN(e.value) || e.value == missing) { - return false; - } - return true; - } - __device__ bool operator()(const Entry& e) const { - if (common::CheckNAN(e.fvalue) || e.fvalue == missing) { - return false; - } - return true; - } -}; - class CudfAdapterBatch : public detail::NoMetaInfo { friend class CudfAdapter; diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 3824691948cf..cbe63d243da4 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -580,6 +580,7 @@ struct GPUHistMakerDevice { CHECK_LT(candidate.split.fvalue, std::numeric_limits::max()) << "Categorical feature value too large."; auto cat = common::AsCat(candidate.split.fvalue); + common::CheckCat(cat); std::vector split_cats(LBitField32::ComputeStorageSize(std::max(cat+1, 1)), 0); LBitField32 cats_bits(split_cats); cats_bits.Set(cat); diff --git a/tests/python-gpu/test_from_cudf.py b/tests/python-gpu/test_from_cudf.py index fc46d7502044..4dcb4a330f9a 100644 --- a/tests/python-gpu/test_from_cudf.py +++ b/tests/python-gpu/test_from_cudf.py @@ -186,6 +186,12 @@ def test_cudf_categorical(self): assert len(Xy.feature_types) == X.shape[1] assert all(t == "c" for t in Xy.feature_types) + # test missing value + X = cudf.DataFrame({"f0": ["a", "b", np.NaN]}) + X["f0"] = X["f0"].astype("category") + arr, _, _ = xgb.data._cudf_array_interfaces(X, enable_categorical=True) + assert not np.any(arr == -1) + @pytest.mark.skipif(**tm.no_cudf()) @pytest.mark.skipif(**tm.no_cupy()) diff --git a/tests/python/test_with_pandas.py b/tests/python/test_with_pandas.py index 67d1d66f65ba..678556281ef6 100644 --- a/tests/python/test_with_pandas.py +++ b/tests/python/test_with_pandas.py @@ -141,6 +141,12 @@ def test_pandas_categorical(self): assert np.issubdtype(transformed[:, 0].dtype, np.integer) assert transformed[:, 0].min() == 0 + # test missing value + X = pd.DataFrame({"f0": ["a", "b", np.NaN]}) + X["f0"] = X["f0"].astype("category") + arr, _, _ = xgb.data._transform_pandas_df(X, enable_categorical=True) + assert not np.any(arr == -1) + def test_pandas_sparse(self): import pandas as pd rows = 100 From 147f58248548f9388fbeae635ab84aa9f8aade72 Mon Sep 17 00:00:00 2001 From: fis Date: Mon, 18 Oct 2021 02:38:03 +0800 Subject: [PATCH 02/16] Use float type. --- python-package/xgboost/data.py | 12 +++++++----- tests/python/test_with_pandas.py | 3 +-- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/python-package/xgboost/data.py b/python-package/xgboost/data.py index f3f18fb85e30..b1af41192962 100644 --- a/python-package/xgboost/data.py +++ b/python-package/xgboost/data.py @@ -225,8 +225,8 @@ def _invalid_dataframe_dtype(data) -> None: ] msg = """DataFrame.dtypes for data must be int, float, bool or category. When - categorical type is supplied, DMatrix parameter `enable_categorical` must - be set to `True`.""" +categorical type is supplied, DMatrix parameter `enable_categorical` must +be set to `True`.""" raise ValueError(msg + ", ".join(bad_fields)) @@ -275,8 +275,10 @@ def _transform_pandas_df( for i, dtype in enumerate(data.dtypes): if is_categorical_dtype(dtype): # pandas uses -1 as default missing value for categorical data - transformed[data.columns[i]] = data[data.columns[i]].cat.codes.replace( - -1, np.NaN + transformed[data.columns[i]] = ( + data[data.columns[i]] + .cat.codes.astype(np.float32) + .replace(-1.0, np.NaN) ) else: transformed[data.columns[i]] = data[data.columns[i]] @@ -454,7 +456,7 @@ def _cudf_array_interfaces(data, enable_categorical: bool) -> Tuple[list, bytes] else: for col in data: if is_categorical_dtype(data[col].dtype) and enable_categorical: - codes = data[col].cat.codes.replace(-1, np.NaN) + codes = data[col].cat.codes.astype(np.float32).replace(-1.0, np.NaN) interface = codes.__cuda_array_interface__ cat_codes.append(codes) else: diff --git a/tests/python/test_with_pandas.py b/tests/python/test_with_pandas.py index 678556281ef6..730801af83d3 100644 --- a/tests/python/test_with_pandas.py +++ b/tests/python/test_with_pandas.py @@ -138,14 +138,13 @@ def test_pandas_categorical(self): X, enable_categorical=True ) - assert np.issubdtype(transformed[:, 0].dtype, np.integer) assert transformed[:, 0].min() == 0 # test missing value X = pd.DataFrame({"f0": ["a", "b", np.NaN]}) X["f0"] = X["f0"].astype("category") arr, _, _ = xgb.data._transform_pandas_df(X, enable_categorical=True) - assert not np.any(arr == -1) + assert not np.any(arr == -1.0) def test_pandas_sparse(self): import pandas as pd From b6286bde97b3d5111bfd8490b9b3c9bf62cc4d8d Mon Sep 17 00:00:00 2001 From: fis Date: Mon, 18 Oct 2021 02:45:53 +0800 Subject: [PATCH 03/16] Lint. --- python-package/xgboost/data.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python-package/xgboost/data.py b/python-package/xgboost/data.py index b1af41192962..cd58c5ba36b1 100644 --- a/python-package/xgboost/data.py +++ b/python-package/xgboost/data.py @@ -479,7 +479,6 @@ def _transform_cudf_df( except ImportError: from cudf.utils.dtypes import is_categorical_dtype - # FIXME(jiamingy): Handle `is_sparse` once we have support for sparse DF. if not all( dtype.name in _pandas_dtype_mapper or (is_categorical_dtype(dtype) and enable_categorical) @@ -967,10 +966,12 @@ def dispatch_proxy_set_data( if not _is_cudf_ser(data) and not _is_pandas_series(data): _check_data_shape(data) if _is_cudf_df(data): - proxy._set_data_from_cuda_columnar(data) # pylint: disable=W0212 + # pylint: disable=W0212 + proxy._set_data_from_cuda_columnar(data, enable_categorical) return if _is_cudf_ser(data): - proxy._set_data_from_cuda_columnar(data) # pylint: disable=W0212 + # pylint: disable=W0212 + proxy._set_data_from_cuda_columnar(data, enable_categorical) return if _is_cupy_array(data): proxy._set_data_from_cuda_interface(data) # pylint: disable=W0212 From 6c329efa8e259c96ecc41f1def7aa6f07ba813b1 Mon Sep 17 00:00:00 2001 From: fis Date: Thu, 21 Oct 2021 18:34:54 +0800 Subject: [PATCH 04/16] Fix dtypes with cuDF series. --- python-package/xgboost/data.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/python-package/xgboost/data.py b/python-package/xgboost/data.py index cd58c5ba36b1..7ccf11515525 100644 --- a/python-package/xgboost/data.py +++ b/python-package/xgboost/data.py @@ -479,10 +479,15 @@ def _transform_cudf_df( except ImportError: from cudf.utils.dtypes import is_categorical_dtype + if _is_cudf_ser(data): + dtypes = [data.dtype] + else: + dtypes = data.dtypes + if not all( dtype.name in _pandas_dtype_mapper or (is_categorical_dtype(dtype) and enable_categorical) - for dtype in data.dtypes + for dtype in dtypes ): _invalid_dataframe_dtype(data) @@ -505,10 +510,6 @@ def _transform_cudf_df( # handle feature types if feature_types is None: feature_types = [] - if _is_cudf_ser(data): - dtypes = [data.dtype] - else: - dtypes = data.dtypes for dtype in dtypes: if is_categorical_dtype(dtype) and enable_categorical: feature_types.append(CAT_T) From 30d2e17e347182ffaa74414ca01607f094f7d9a8 Mon Sep 17 00:00:00 2001 From: fis Date: Thu, 21 Oct 2021 18:50:02 +0800 Subject: [PATCH 05/16] Handle series. --- python-package/xgboost/core.py | 5 +++++ python-package/xgboost/data.py | 11 ++++++++++- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index 8dd9f9dd4144..68003365df6e 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -400,6 +400,11 @@ def data_handle( ) -> None: from .data import dispatch_proxy_set_data from .data import _proxy_transform + ec = kwargs.get("enable_categorical", None) + if ec is not None and ec != self._enable_categorical: + raise ValueError( + "`enable_categorical` should be specifed in DMatrix constructor" + ) transformed, feature_names, feature_types = _proxy_transform( data, diff --git a/python-package/xgboost/data.py b/python-package/xgboost/data.py index 7ccf11515525..11c10d1e6d81 100644 --- a/python-package/xgboost/data.py +++ b/python-package/xgboost/data.py @@ -452,11 +452,20 @@ def _cudf_array_interfaces(data, enable_categorical: bool) -> Tuple[list, bytes] cat_codes = [] interfaces = [] if _is_cudf_ser(data): - interfaces.append(data.__cuda_array_interface__) + if is_categorical_dtype(data.dtype) and enable_categorical: + codes = data.cat.codes.astype(np.float32).replace(-1.0, np.NaN) + cat_codes.append(codes) + interface = codes.__cuda_array_interface__ + else: + interface = data.__cuda_array_interface__ + if "mask" in interface: + interface["mask"] = interface["mask"].__cuda_array_interface__ + interfaces.append(interface) else: for col in data: if is_categorical_dtype(data[col].dtype) and enable_categorical: codes = data[col].cat.codes.astype(np.float32).replace(-1.0, np.NaN) + print(codes, codes.shape) interface = codes.__cuda_array_interface__ cat_codes.append(codes) else: From 435a3ab567fe8be6cb78f637f448d6492edb9e47 Mon Sep 17 00:00:00 2001 From: fis Date: Thu, 21 Oct 2021 20:35:03 +0800 Subject: [PATCH 06/16] Fix unweighted data. --- src/common/hist_util.cu | 1 + src/common/hist_util.cuh | 25 +++++++++++++++++++------ 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/src/common/hist_util.cu b/src/common/hist_util.cu index 33a7c268a6f5..2d3dff0545bf 100644 --- a/src/common/hist_util.cu +++ b/src/common/hist_util.cu @@ -133,6 +133,7 @@ void RemoveDuplicatedCategories( int32_t device, MetaInfo const &info, Span d_cuts_ptr, dh::device_vector *p_sorted_entries, dh::caching_device_vector *p_column_sizes_scan) { + info.feature_types.SetDevice(device); auto d_feature_types = info.feature_types.ConstDeviceSpan(); CHECK(!d_feature_types.empty()); auto &column_sizes_scan = *p_column_sizes_scan; diff --git a/src/common/hist_util.cuh b/src/common/hist_util.cuh index 5f2e2add6bdc..c6ca000d095d 100644 --- a/src/common/hist_util.cuh +++ b/src/common/hist_util.cuh @@ -124,6 +124,11 @@ void MakeEntriesFromAdapter(AdapterBatch const& batch, BatchIter batch_iter, void SortByWeight(dh::device_vector* weights, dh::device_vector* sorted_entries); + +void RemoveDuplicatedCategories( + int32_t device, MetaInfo const &info, Span d_cuts_ptr, + dh::device_vector *p_sorted_entries, + dh::caching_device_vector *p_column_sizes_scan); } // namespace detail // Compute sketch on DMatrix. @@ -132,9 +137,10 @@ HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins, size_t sketch_batch_num_elements = 0); template -void ProcessSlidingWindow(AdapterBatch const& batch, int device, size_t columns, - size_t begin, size_t end, float missing, - SketchContainer* sketch_container, int num_cuts) { +void ProcessSlidingWindow(AdapterBatch const &batch, MetaInfo const &info, + int device, size_t columns, size_t begin, size_t end, + float missing, SketchContainer *sketch_container, + int num_cuts) { // Copy current subset of valid elements into temporary storage and sort dh::device_vector sorted_entries; dh::caching_device_vector column_sizes_scan; @@ -142,6 +148,7 @@ void ProcessSlidingWindow(AdapterBatch const& batch, int device, size_t columns, thrust::make_counting_iterator(0llu), [=] __device__(size_t idx) { return batch.GetElement(idx); }); HostDeviceVector cuts_ptr; + cuts_ptr.SetDevice(device); detail::MakeEntriesFromAdapter(batch, batch_iter, {begin, end}, missing, columns, num_cuts, device, &cuts_ptr, @@ -151,8 +158,14 @@ void ProcessSlidingWindow(AdapterBatch const& batch, int device, size_t columns, thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(), sorted_entries.end(), detail::EntryCompareOp()); - auto const& h_cuts_ptr = cuts_ptr.ConstHostVector(); + if (sketch_container->HasCategorical()) { + auto d_cuts_ptr = cuts_ptr.DeviceSpan(); + detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr, + &sorted_entries, &column_sizes_scan); + } + auto d_cuts_ptr = cuts_ptr.DeviceSpan(); + auto const &h_cuts_ptr = cuts_ptr.HostVector(); // Extract the cuts from all columns concurrently sketch_container->Push(dh::ToSpan(sorted_entries), dh::ToSpan(column_sizes_scan), d_cuts_ptr, @@ -274,8 +287,8 @@ void AdapterDeviceSketch(Batch batch, int num_bins, device, num_cuts_per_feature, false); for (auto begin = 0ull; begin < batch.Size(); begin += sketch_batch_num_elements) { size_t end = std::min(batch.Size(), size_t(begin + sketch_batch_num_elements)); - ProcessSlidingWindow(batch, device, num_cols, - begin, end, missing, sketch_container, num_cuts_per_feature); + ProcessSlidingWindow(batch, info, device, num_cols, begin, end, missing, + sketch_container, num_cuts_per_feature); } } } From 01b53927891231b6cc4965797fc75f598ccb6b8f Mon Sep 17 00:00:00 2001 From: fis Date: Thu, 21 Oct 2021 20:36:01 +0800 Subject: [PATCH 07/16] debug workaround. TODO need to move the transform out of array interface getter. --- python-package/xgboost/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index 68003365df6e..4f189a85361a 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -1109,7 +1109,7 @@ def _set_data_from_cuda_columnar(self, data, enable_categorical: bool) -> None: """Set data from CUDA columnar format.""" from .data import _cudf_array_interfaces - _, interfaces_str = _cudf_array_interfaces(data, enable_categorical) + self.codes, interfaces_str = _cudf_array_interfaces(data, enable_categorical) _check_call(_LIB.XGProxyDMatrixSetDataCudaColumnar(self.handle, interfaces_str)) def _set_data_from_array(self, data: np.ndarray): From b9042f2ce1561dd54664fd0c6bb0208caa91d820 Mon Sep 17 00:00:00 2001 From: fis Date: Thu, 21 Oct 2021 20:47:59 +0800 Subject: [PATCH 08/16] Start working on the data staging for cat codes. --- python-package/xgboost/data.py | 42 ++++++++++++++++++++++------------ 1 file changed, 27 insertions(+), 15 deletions(-) diff --git a/python-package/xgboost/data.py b/python-package/xgboost/data.py index 11c10d1e6d81..8e36ecd08878 100644 --- a/python-package/xgboost/data.py +++ b/python-package/xgboost/data.py @@ -437,7 +437,7 @@ def _is_cudf_df(data): return hasattr(cudf, 'DataFrame') and isinstance(data, cudf.DataFrame) -def _cudf_array_interfaces(data, enable_categorical: bool) -> Tuple[list, bytes]: +def _cudf_array_interfaces(data, cat_codes: list) -> Tuple[list, bytes]: """Extract CuDF __cuda_array_interface__. This is special as it returns a new list of data and a list of array interfaces. The data is list of categorical codes that caller can safely ignore, but have to keep their reference alive until usage of array @@ -449,25 +449,20 @@ def _cudf_array_interfaces(data, enable_categorical: bool) -> Tuple[list, bytes] except ImportError: from cudf.utils.dtypes import is_categorical_dtype - cat_codes = [] interfaces = [] if _is_cudf_ser(data): - if is_categorical_dtype(data.dtype) and enable_categorical: - codes = data.cat.codes.astype(np.float32).replace(-1.0, np.NaN) - cat_codes.append(codes) - interface = codes.__cuda_array_interface__ + if is_categorical_dtype(data.dtype): + interface = cat_codes[0].__cuda_array_interface__ else: interface = data.__cuda_array_interface__ if "mask" in interface: interface["mask"] = interface["mask"].__cuda_array_interface__ interfaces.append(interface) else: - for col in data: - if is_categorical_dtype(data[col].dtype) and enable_categorical: - codes = data[col].cat.codes.astype(np.float32).replace(-1.0, np.NaN) - print(codes, codes.shape) + for i, col in enumerate(data): + if is_categorical_dtype(data[col].dtype): + codes = cat_codes[i] interface = codes.__cuda_array_interface__ - cat_codes.append(codes) else: interface = data[col].__cuda_array_interface__ if "mask" in interface: @@ -525,7 +520,19 @@ def _transform_cudf_df( else: feature_types.append(_pandas_dtype_mapper[dtype.name]) - return data, feature_names, feature_types + # handle categorical data + cat_codes = [] + if _is_cudf_ser(data): + if is_categorical_dtype(data.dtype) and enable_categorical: + codes = data.cat.codes.astype(np.float32).replace(-1.0, np.NaN) + cat_codes.append(codes) + else: + for col in data: + if is_categorical_dtype(data[col].dtype) and enable_categorical: + codes = data[col].cat.codes.astype(np.float32).replace(-1.0, np.NaN) + cat_codes.append(codes) + + return (data, cat_codes), feature_names, feature_types def _from_cudf_df( @@ -536,10 +543,10 @@ def _from_cudf_df( feature_types: Optional[List[str]], enable_categorical: bool, ) -> Tuple[ctypes.c_void_p, Any, Any]: - data, feature_names, feature_types = _transform_cudf_df( + (data, cat_codes), feature_names, feature_types = _transform_cudf_df( data, feature_names, feature_types, enable_categorical ) - _, interfaces_str = _cudf_array_interfaces(data, enable_categorical) + _, interfaces_str = _cudf_array_interfaces(data, cat_codes) handle = ctypes.c_void_p() config = bytes(json.dumps({"missing": missing, "nthread": nthread}), "utf-8") _check_call( @@ -970,11 +977,16 @@ def _proxy_transform( def dispatch_proxy_set_data( - proxy: _ProxyDMatrix, data: Any, allow_host: bool, enable_categorical: bool + proxy: _ProxyDMatrix, + data: Any, + cat_codes: list, + allow_host: bool, + enable_categorical: bool ) -> None: """Dispatch for DeviceQuantileDMatrix.""" if not _is_cudf_ser(data) and not _is_pandas_series(data): _check_data_shape(data) + if _is_cudf_df(data): # pylint: disable=W0212 proxy._set_data_from_cuda_columnar(data, enable_categorical) From 6ecb3b5250ef52d2ea1388cbf0b6eedf9a454848 Mon Sep 17 00:00:00 2001 From: fis Date: Fri, 22 Oct 2021 00:52:38 +0800 Subject: [PATCH 09/16] Stage the categorical codes. --- python-package/xgboost/core.py | 21 +++++++++++---------- python-package/xgboost/data.py | 30 ++++++++++++++---------------- tests/python-gpu/test_from_cudf.py | 17 +++++++++++++++-- 3 files changed, 40 insertions(+), 28 deletions(-) diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index 4f189a85361a..ec89026bf148 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -373,7 +373,7 @@ def _reraise(self) -> None: raise exc # pylint: disable=raising-bad-type def __del__(self) -> None: - assert self._temporary_data is None, self._temporary_data + assert self._temporary_data is None assert self._exception is None def _reset_wrapper(self, this: None) -> None: # pylint: disable=unused-argument @@ -406,17 +406,15 @@ def data_handle( "`enable_categorical` should be specifed in DMatrix constructor" ) - transformed, feature_names, feature_types = _proxy_transform( + new, cat_codes, feature_names, feature_types = _proxy_transform( data, feature_names, feature_types, self._enable_categorical, ) # Stage the data, meta info are copied inside C++ MetaInfo. - self._temporary_data = transformed - dispatch_proxy_set_data( - self.proxy, transformed, self._allow_host, self._enable_categorical - ) + self._temporary_data = (new, cat_codes) + dispatch_proxy_set_data(self.proxy, new, cat_codes, self._allow_host) self.proxy.set_info( feature_names=feature_names, feature_types=feature_types, @@ -1105,11 +1103,11 @@ def _set_data_from_cuda_interface(self, data) -> None: _LIB.XGProxyDMatrixSetDataCudaArrayInterface(self.handle, interface_str) ) - def _set_data_from_cuda_columnar(self, data, enable_categorical: bool) -> None: + def _set_data_from_cuda_columnar(self, data, cat_codes: list) -> None: """Set data from CUDA columnar format.""" from .data import _cudf_array_interfaces - self.codes, interfaces_str = _cudf_array_interfaces(data, enable_categorical) + interfaces_str = _cudf_array_interfaces(data, cat_codes) _check_call(_LIB.XGProxyDMatrixSetDataCudaColumnar(self.handle, interfaces_str)) def _set_data_from_array(self, data: np.ndarray): @@ -2080,8 +2078,11 @@ def inplace_predict( ) return _prediction_output(shape, dims, preds, True) if lazy_isinstance(data, "cudf.core.dataframe", "DataFrame"): - from .data import _cudf_array_interfaces - _, interfaces_str = _cudf_array_interfaces(data, enable_categorical) + from .data import _cudf_array_interfaces, _transform_cudf_df + _data, cat_codes, _, _ = _transform_cudf_df( + data, None, None, enable_categorical + ) + interfaces_str = _cudf_array_interfaces(data, cat_codes) _check_call( _LIB.XGBoosterPredictFromCudaColumnar( self.handle, diff --git a/python-package/xgboost/data.py b/python-package/xgboost/data.py index 8e36ecd08878..75471f233e8d 100644 --- a/python-package/xgboost/data.py +++ b/python-package/xgboost/data.py @@ -226,7 +226,7 @@ def _invalid_dataframe_dtype(data) -> None: msg = """DataFrame.dtypes for data must be int, float, bool or category. When categorical type is supplied, DMatrix parameter `enable_categorical` must -be set to `True`.""" +be set to `True`. Bad fields: """ raise ValueError(msg + ", ".join(bad_fields)) @@ -469,7 +469,7 @@ def _cudf_array_interfaces(data, cat_codes: list) -> Tuple[list, bytes]: interface["mask"] = interface["mask"].__cuda_array_interface__ interfaces.append(interface) interfaces_str = bytes(json.dumps(interfaces, indent=2), "utf-8") - return cat_codes, interfaces_str + return interfaces_str def _transform_cudf_df( @@ -529,10 +529,10 @@ def _transform_cudf_df( else: for col in data: if is_categorical_dtype(data[col].dtype) and enable_categorical: - codes = data[col].cat.codes.astype(np.float32).replace(-1.0, np.NaN) + codes = data[col].cat.codes cat_codes.append(codes) - return (data, cat_codes), feature_names, feature_types + return data, cat_codes, feature_names, feature_types def _from_cudf_df( @@ -543,10 +543,10 @@ def _from_cudf_df( feature_types: Optional[List[str]], enable_categorical: bool, ) -> Tuple[ctypes.c_void_p, Any, Any]: - (data, cat_codes), feature_names, feature_types = _transform_cudf_df( + data, cat_codes, feature_names, feature_types = _transform_cudf_df( data, feature_names, feature_types, enable_categorical ) - _, interfaces_str = _cudf_array_interfaces(data, cat_codes) + interfaces_str = _cudf_array_interfaces(data, cat_codes) handle = ctypes.c_void_p() config = bytes(json.dumps({"missing": missing, "nthread": nthread}), "utf-8") _check_call( @@ -910,8 +910,7 @@ def dispatch_meta_backend(matrix: DMatrix, data, name: str, dtype: str = None): _meta_from_dt(data, name, dtype, handle) return if _is_modin_df(data): - data, _, _ = _transform_pandas_df( - data, False, meta=name, meta_type=dtype) + data, _, _ = _transform_pandas_df(data, False, meta=name, meta_type=dtype) _meta_from_numpy(data, name, dtype, handle) return if _is_modin_series(data): @@ -961,27 +960,26 @@ def _proxy_transform( ) if _is_cupy_array(data): data = _transform_cupy_array(data) - return data, feature_names, feature_types + return data, None, feature_names, feature_types if _is_dlpack(data): - return _transform_dlpack(data), feature_names, feature_types + return _transform_dlpack(data), None, feature_names, feature_types if _is_numpy_array(data): - return data, feature_names, feature_types + return data, None, feature_names, feature_types if _is_scipy_csr(data): return data, feature_names, feature_types if _is_pandas_df(data): arr, feature_names, feature_types = _transform_pandas_df( data, enable_categorical, feature_names, feature_types ) - return arr, feature_names, feature_types + return arr, None, feature_names, feature_types raise TypeError("Value type is not supported for data iterator:" + str(type(data))) def dispatch_proxy_set_data( proxy: _ProxyDMatrix, data: Any, - cat_codes: list, + cat_codes: Optional[list], allow_host: bool, - enable_categorical: bool ) -> None: """Dispatch for DeviceQuantileDMatrix.""" if not _is_cudf_ser(data) and not _is_pandas_series(data): @@ -989,11 +987,11 @@ def dispatch_proxy_set_data( if _is_cudf_df(data): # pylint: disable=W0212 - proxy._set_data_from_cuda_columnar(data, enable_categorical) + proxy._set_data_from_cuda_columnar(data, cat_codes) return if _is_cudf_ser(data): # pylint: disable=W0212 - proxy._set_data_from_cuda_columnar(data, enable_categorical) + proxy._set_data_from_cuda_columnar(data, cat_codes) return if _is_cupy_array(data): proxy._set_data_from_cuda_interface(data) # pylint: disable=W0212 diff --git a/tests/python-gpu/test_from_cudf.py b/tests/python-gpu/test_from_cudf.py index 4dcb4a330f9a..396085705b34 100644 --- a/tests/python-gpu/test_from_cudf.py +++ b/tests/python-gpu/test_from_cudf.py @@ -189,8 +189,21 @@ def test_cudf_categorical(self): # test missing value X = cudf.DataFrame({"f0": ["a", "b", np.NaN]}) X["f0"] = X["f0"].astype("category") - arr, _, _ = xgb.data._cudf_array_interfaces(X, enable_categorical=True) - assert not np.any(arr == -1) + df, cat_codes, _, _ = xgb.data._transform_cudf_df( + X, None, None, enable_categorical=True + ) + for col in cat_codes: + assert col.has_nulls + + y = [0, 1, 2] + with pytest.raises(ValueError): + xgb.DMatrix(X, y) + Xy = xgb.DMatrix(X, y, enable_categorical=True) + assert Xy.num_row() == 3 + assert Xy.num_col() == 1 + + with pytest.raises(ValueError): + xgb.DeviceQuantileDMatrix(X, y) @pytest.mark.skipif(**tm.no_cudf()) From 5b3648513e1b929f4453487300242258f873acd1 Mon Sep 17 00:00:00 2001 From: fis Date: Fri, 22 Oct 2021 01:06:54 +0800 Subject: [PATCH 10/16] Fix. --- src/data/iterative_device_dmatrix.cu | 1 + tests/python-gpu/test_from_cudf.py | 4 ++++ 2 files changed, 5 insertions(+) diff --git a/src/data/iterative_device_dmatrix.cu b/src/data/iterative_device_dmatrix.cu index 00e502dfa767..d3869eff1c45 100644 --- a/src/data/iterative_device_dmatrix.cu +++ b/src/data/iterative_device_dmatrix.cu @@ -152,6 +152,7 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin if (batches == 1) { this->info_ = std::move(proxy->Info()); + this->info_.num_nonzero_ = nnz; CHECK_EQ(proxy->Info().labels_.Size(), 0); } diff --git a/tests/python-gpu/test_from_cudf.py b/tests/python-gpu/test_from_cudf.py index 396085705b34..9caacd257f0b 100644 --- a/tests/python-gpu/test_from_cudf.py +++ b/tests/python-gpu/test_from_cudf.py @@ -205,6 +205,10 @@ def test_cudf_categorical(self): with pytest.raises(ValueError): xgb.DeviceQuantileDMatrix(X, y) + Xy = xgb.DeviceQuantileDMatrix(X, y, enable_categorical=True) + assert Xy.num_row() == 3 + assert Xy.num_col() == 1 + @pytest.mark.skipif(**tm.no_cudf()) @pytest.mark.skipif(**tm.no_cupy()) From 4f3f500fe89034ec2f9d4a4ddbb288da0e134b96 Mon Sep 17 00:00:00 2001 From: fis Date: Fri, 22 Oct 2021 01:10:47 +0800 Subject: [PATCH 11/16] Linter. --- python-package/xgboost/core.py | 5 ----- python-package/xgboost/data.py | 2 +- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index ec89026bf148..b0a0530d8770 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -400,11 +400,6 @@ def data_handle( ) -> None: from .data import dispatch_proxy_set_data from .data import _proxy_transform - ec = kwargs.get("enable_categorical", None) - if ec is not None and ec != self._enable_categorical: - raise ValueError( - "`enable_categorical` should be specifed in DMatrix constructor" - ) new, cat_codes, feature_names, feature_types = _proxy_transform( data, diff --git a/python-package/xgboost/data.py b/python-package/xgboost/data.py index 75471f233e8d..54cbe55cad64 100644 --- a/python-package/xgboost/data.py +++ b/python-package/xgboost/data.py @@ -1,4 +1,4 @@ -# pylint: disable=too-many-arguments, too-many-branches +# pylint: disable=too-many-arguments, too-many-branches, too-many-lines # pylint: disable=too-many-return-statements, import-error '''Data dispatching for DMatrix.''' import ctypes From 3dd9df638feb09e3c4567fc1424597f83d8a6c19 Mon Sep 17 00:00:00 2001 From: fis Date: Fri, 22 Oct 2021 01:34:52 +0800 Subject: [PATCH 12/16] Handle series. --- python-package/xgboost/data.py | 52 +++++++++++++++++++++--------- tests/python-gpu/test_from_cudf.py | 8 +++++ tests/python/test_with_pandas.py | 8 +++++ 3 files changed, 53 insertions(+), 15 deletions(-) diff --git a/python-package/xgboost/data.py b/python-package/xgboost/data.py index 54cbe55cad64..6f650c0a551d 100644 --- a/python-package/xgboost/data.py +++ b/python-package/xgboost/data.py @@ -218,16 +218,22 @@ def _is_modin_df(data): def _invalid_dataframe_dtype(data) -> None: - bad_fields = [ - str(data.columns[i]) - for i, dtype in enumerate(data.dtypes) - if dtype.name not in _pandas_dtype_mapper - ] + # pandas series has `dtypes` but it's just a single object + # cudf series doesn't have `dtypes`. + if hasattr(data, "dtypes") and hasattr(data.dtypes, "__iter__"): + bad_fields = [ + str(data.columns[i]) + for i, dtype in enumerate(data.dtypes) + if dtype.name not in _pandas_dtype_mapper + ] + err = " Invalid columns:" + ", ".join(bad_fields) + else: + err = "" msg = """DataFrame.dtypes for data must be int, float, bool or category. When categorical type is supplied, DMatrix parameter `enable_categorical` must -be set to `True`. Bad fields: """ - raise ValueError(msg + ", ".join(bad_fields)) +be set to `True`.""" + err + raise ValueError(msg) def _transform_pandas_df( @@ -327,13 +333,26 @@ def _is_modin_series(data): def _from_pandas_series( data, - missing, - nthread, + missing: float, + nthread: int, + enable_categorical: bool, feature_names: Optional[List[str]], feature_types: Optional[List[str]], ): + from pandas.api.types import is_categorical_dtype + + if (data.dtype.name not in _pandas_dtype_mapper) and not ( + is_categorical_dtype(data.dtype) and enable_categorical + ): + _invalid_dataframe_dtype(data) + if enable_categorical and is_categorical_dtype(data.dtype): + data = data.cat.codes return _from_numpy_array( - data.values.astype("float"), missing, nthread, feature_names, feature_types + data.values.reshape(data.shape[0], 1).astype("float"), + missing, + nthread, + feature_names, + feature_types, ) @@ -523,8 +542,9 @@ def _transform_cudf_df( # handle categorical data cat_codes = [] if _is_cudf_ser(data): + # unlike pandas, cuDF uses NA for missing data. if is_categorical_dtype(data.dtype) and enable_categorical: - codes = data.cat.codes.astype(np.float32).replace(-1.0, np.NaN) + codes = data.cat.codes cat_codes.append(codes) else: for col in data: @@ -751,8 +771,9 @@ def dispatch_data_backend( return _from_pandas_df(data, enable_categorical, missing, threads, feature_names, feature_types) if _is_pandas_series(data): - return _from_pandas_series(data, missing, threads, feature_names, - feature_types) + return _from_pandas_series( + data, missing, threads, enable_categorical, feature_names, feature_types + ) if _is_cudf_df(data) or _is_cudf_ser(data): return _from_cudf_df( data, missing, threads, feature_names, feature_types, enable_categorical @@ -776,8 +797,9 @@ def dispatch_data_backend( return _from_pandas_df(data, enable_categorical, missing, threads, feature_names, feature_types) if _is_modin_series(data): - return _from_pandas_series(data, missing, threads, feature_names, - feature_types) + return _from_pandas_series( + data, missing, threads, enable_categorical, feature_names, feature_types + ) if _has_array_protocol(data): array = np.asarray(data) return _from_numpy_array(array, missing, threads, feature_names, feature_types) diff --git a/tests/python-gpu/test_from_cudf.py b/tests/python-gpu/test_from_cudf.py index 9caacd257f0b..904dbf0934db 100644 --- a/tests/python-gpu/test_from_cudf.py +++ b/tests/python-gpu/test_from_cudf.py @@ -209,6 +209,14 @@ def test_cudf_categorical(self): assert Xy.num_row() == 3 assert Xy.num_col() == 1 + X = X["f0"] + with pytest.raises(ValueError): + xgb.DMatrix(X, y) + + Xy = xgb.DMatrix(X, y, enable_categorical=True) + assert Xy.num_row() == 3 + assert Xy.num_col() == 1 + @pytest.mark.skipif(**tm.no_cudf()) @pytest.mark.skipif(**tm.no_cupy()) diff --git a/tests/python/test_with_pandas.py b/tests/python/test_with_pandas.py index 730801af83d3..0b25993a5ee5 100644 --- a/tests/python/test_with_pandas.py +++ b/tests/python/test_with_pandas.py @@ -146,6 +146,14 @@ def test_pandas_categorical(self): arr, _, _ = xgb.data._transform_pandas_df(X, enable_categorical=True) assert not np.any(arr == -1.0) + X = X["f0"] + with pytest.raises(ValueError): + xgb.DMatrix(X, y) + + Xy = xgb.DMatrix(X, y, enable_categorical=True) + assert Xy.num_row() == 3 + assert Xy.num_col() == 1 + def test_pandas_sparse(self): import pandas as pd rows = 100 From ad5ff3115d1cc417a9c9b72e047c485a5bf96d98 Mon Sep 17 00:00:00 2001 From: fis Date: Fri, 22 Oct 2021 01:36:51 +0800 Subject: [PATCH 13/16] Fix scipy. --- python-package/xgboost/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python-package/xgboost/data.py b/python-package/xgboost/data.py index 6f650c0a551d..da335b350cbd 100644 --- a/python-package/xgboost/data.py +++ b/python-package/xgboost/data.py @@ -988,7 +988,7 @@ def _proxy_transform( if _is_numpy_array(data): return data, None, feature_names, feature_types if _is_scipy_csr(data): - return data, feature_names, feature_types + return data, None, feature_names, feature_types if _is_pandas_df(data): arr, feature_names, feature_types = _transform_pandas_df( data, enable_categorical, feature_names, feature_types From fd162390a7d9dd5f582bd615b3ce021ad43ea92f Mon Sep 17 00:00:00 2001 From: fis Date: Fri, 22 Oct 2021 01:39:51 +0800 Subject: [PATCH 14/16] Weighted data. --- src/common/hist_util.cuh | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/common/hist_util.cuh b/src/common/hist_util.cuh index c6ca000d095d..419febc22bc9 100644 --- a/src/common/hist_util.cuh +++ b/src/common/hist_util.cuh @@ -235,6 +235,12 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info, detail::SortByWeight(&temp_weights, &sorted_entries); + if (sketch_container->HasCategorical()) { + auto d_cuts_ptr = cuts_ptr.DeviceSpan(); + detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr, + &sorted_entries, &column_sizes_scan); + } + auto const& h_cuts_ptr = cuts_ptr.ConstHostVector(); auto d_cuts_ptr = cuts_ptr.DeviceSpan(); From aa97cb88b53f370cbfb0c31ff601fb9ca3d230c2 Mon Sep 17 00:00:00 2001 From: fis Date: Fri, 22 Oct 2021 01:55:01 +0800 Subject: [PATCH 15/16] Add c++ test. --- tests/cpp/common/test_hist_util.cu | 46 ++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/tests/cpp/common/test_hist_util.cu b/tests/cpp/common/test_hist_util.cu index 0be7450e9bfd..9cebcfffe673 100644 --- a/tests/cpp/common/test_hist_util.cu +++ b/tests/cpp/common/test_hist_util.cu @@ -392,6 +392,50 @@ TEST(HistUtil, AdapterSketchSlidingWindowWeightedMemory) { EXPECT_GE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required); } +void TestCategoricalSketchAdapter(size_t n, size_t num_categories, + int32_t num_bins, bool weighted) { + auto h_x = GenerateRandomCategoricalSingleColumn(n, num_categories); + thrust::device_vector x(h_x); + auto adapter = AdapterFromData(x, n, 1); + MetaInfo info; + info.num_row_ = n; + info.num_col_ = 1; + info.feature_types.HostVector().push_back(FeatureType::kCategorical); + + if (weighted) { + std::vector weights(n, 0); + SimpleLCG lcg; + SimpleRealUniformDistribution dist(0, 1); + for (auto& v : weights) { + v = dist(&lcg); + } + info.weights_.HostVector() = weights; + } + + ASSERT_EQ(info.feature_types.Size(), 1); + SketchContainer container(info.feature_types, num_bins, 1, n, 0); + AdapterDeviceSketch(adapter.Value(), num_bins, info, + std::numeric_limits::quiet_NaN(), &container); + HistogramCuts cuts; + container.MakeCuts(&cuts); + + std::sort(x.begin(), x.end()); + auto n_uniques = std::unique(x.begin(), x.end()) - x.begin(); + ASSERT_NE(n_uniques, x.size()); + ASSERT_EQ(cuts.TotalBins(), n_uniques); + ASSERT_EQ(n_uniques, num_categories); + + auto& values = cuts.cut_values_.HostVector(); + ASSERT_TRUE(std::is_sorted(values.cbegin(), values.cend())); + auto is_unique = (std::unique(values.begin(), values.end()) - values.begin()) == n_uniques; + ASSERT_TRUE(is_unique); + + x.resize(n_uniques); + for (decltype(n_uniques) i = 0; i < n_uniques; ++i) { + ASSERT_EQ(x[i], values[i]); + } +} + TEST(HistUtil, AdapterDeviceSketchCategorical) { int categorical_sizes[] = {2, 6, 8, 12}; int num_bins = 256; @@ -404,6 +448,8 @@ TEST(HistUtil, AdapterDeviceSketchCategorical) { auto adapter = AdapterFromData(x_device, n, 1); ValidateBatchedCuts(adapter, num_bins, adapter.NumColumns(), adapter.NumRows(), dmat.get()); + TestCategoricalSketchAdapter(n, num_categories, num_bins, true); + TestCategoricalSketchAdapter(n, num_categories, num_bins, false); } } } From 1e9fd1c1ac4a1760d52b5cf8c51a9ce0c52ae1a8 Mon Sep 17 00:00:00 2001 From: fis Date: Fri, 22 Oct 2021 17:03:11 +0800 Subject: [PATCH 16/16] Address reviewer's comment. --- python-package/xgboost/core.py | 2 +- python-package/xgboost/data.py | 2 +- tests/cpp/common/test_hist_util.cu | 8 +++++--- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index b0a0530d8770..a362a560c7d0 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -2074,7 +2074,7 @@ def inplace_predict( return _prediction_output(shape, dims, preds, True) if lazy_isinstance(data, "cudf.core.dataframe", "DataFrame"): from .data import _cudf_array_interfaces, _transform_cudf_df - _data, cat_codes, _, _ = _transform_cudf_df( + data, cat_codes, _, _ = _transform_cudf_df( data, None, None, enable_categorical ) interfaces_str = _cudf_array_interfaces(data, cat_codes) diff --git a/python-package/xgboost/data.py b/python-package/xgboost/data.py index da335b350cbd..664911a5474b 100644 --- a/python-package/xgboost/data.py +++ b/python-package/xgboost/data.py @@ -456,7 +456,7 @@ def _is_cudf_df(data): return hasattr(cudf, 'DataFrame') and isinstance(data, cudf.DataFrame) -def _cudf_array_interfaces(data, cat_codes: list) -> Tuple[list, bytes]: +def _cudf_array_interfaces(data, cat_codes: list) -> bytes: """Extract CuDF __cuda_array_interface__. This is special as it returns a new list of data and a list of array interfaces. The data is list of categorical codes that caller can safely ignore, but have to keep their reference alive until usage of array diff --git a/tests/cpp/common/test_hist_util.cu b/tests/cpp/common/test_hist_util.cu index 9cebcfffe673..eb1b04cd55fd 100644 --- a/tests/cpp/common/test_hist_util.cu +++ b/tests/cpp/common/test_hist_util.cu @@ -419,8 +419,8 @@ void TestCategoricalSketchAdapter(size_t n, size_t num_categories, HistogramCuts cuts; container.MakeCuts(&cuts); - std::sort(x.begin(), x.end()); - auto n_uniques = std::unique(x.begin(), x.end()) - x.begin(); + thrust::sort(x.begin(), x.end()); + auto n_uniques = thrust::unique(x.begin(), x.end()) - x.begin(); ASSERT_NE(n_uniques, x.size()); ASSERT_EQ(cuts.TotalBins(), n_uniques); ASSERT_EQ(n_uniques, num_categories); @@ -431,8 +431,10 @@ void TestCategoricalSketchAdapter(size_t n, size_t num_categories, ASSERT_TRUE(is_unique); x.resize(n_uniques); + h_x.resize(n_uniques); + thrust::copy(x.begin(), x.end(), h_x.begin()); for (decltype(n_uniques) i = 0; i < n_uniques; ++i) { - ASSERT_EQ(x[i], values[i]); + ASSERT_EQ(h_x[i], values[i]); } }