diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py
index cf415b9e9afc..0735dcd48323 100644
--- a/python-package/xgboost/core.py
+++ b/python-package/xgboost/core.py
@@ -321,6 +321,7 @@ class DataIter:
     def __init__(self):
         self._handle = _ProxyDMatrix()
         self.exception = None
+        self.enable_categorical = False

     @property
     def proxy(self):
@@ -346,13 +347,12 @@ def data_handle(
            data,
            feature_names=None,
            feature_types=None,
-           enable_categorical=False,
            **kwargs
        ):
            from .data import dispatch_device_quantile_dmatrix_set_data
            from .data import _device_quantile_transform
            data, feature_names, feature_types = _device_quantile_transform(
-               data, feature_names, feature_types, enable_categorical,
+               data, feature_names, feature_types, self.enable_categorical,
            )
            dispatch_device_quantile_dmatrix_set_data(self.proxy, data)
            self.proxy.set_info(
@@ -1106,15 +1106,10 @@ def _init(self, data, enable_categorical, **meta):
             data = _transform_dlpack(data)
         if _is_iter(data):
             it = data
-            if enable_categorical:
-                raise NotImplementedError(
-                    "categorical support is not enabled on data iterator."
-                )
         else:
-            it = SingleBatchInternalIter(
-                data=data, enable_categorical=enable_categorical, **meta
-            )
+            it = SingleBatchInternalIter(data=data, **meta)
+        it.enable_categorical = enable_categorical
         reset_callback = ctypes.CFUNCTYPE(None, ctypes.c_void_p)(it.reset_wrapper)
         next_callback = ctypes.CFUNCTYPE(
             ctypes.c_int,
diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py
index a07a61224d34..a6b47906c1ec 100644
--- a/python-package/xgboost/dask.py
+++ b/python-package/xgboost/dask.py
@@ -182,7 +182,7 @@ def concat(value: Any) -> Any:  # pylint: disable=too-many-return-statements
             lazy_isinstance(value[0], 'cudf.core.series', 'Series'):
         from cudf import concat as CUDF_concat  # pylint: disable=import-error
         return CUDF_concat(value, axis=0)
-    if lazy_isinstance(value[0], 'cupy.core.core', 'ndarray'):
+    if lazy_isinstance(value[0], 'cupy._core.core', 'ndarray'):
         import cupy
         # pylint: disable=c-extension-no-member,no-member
         d = cupy.cuda.runtime.getDevice()
@@ -258,6 +258,7 @@
         self.feature_names = feature_names
         self.feature_types = feature_types
         self.missing = missing
+        self.enable_categorical = enable_categorical

         if qid is not None and weight is not None:
             raise NotImplementedError("per-group weight is not implemented.")
@@ -265,10 +266,6 @@
             raise NotImplementedError(
                 "group structure is not implemented, use qid instead."
             )
-        if enable_categorical:
-            raise NotImplementedError(
-                "categorical support is not enabled on `DaskDMatrix`."
-            )

         if len(data.shape) != 2:
             raise ValueError(
@@ -311,7 +308,7 @@ async def _map_local_data(
         qid: Optional[_DaskCollection] = None,
         feature_weights: Optional[_DaskCollection] = None,
         label_lower_bound: Optional[_DaskCollection] = None,
-        label_upper_bound: Optional[_DaskCollection] = None
+        label_upper_bound: Optional[_DaskCollection] = None,
     ) -> "DaskDMatrix":
         '''Obtain references to local data.'''
@@ -430,6 +427,7 @@ def _create_fn_args(self, worker_addr: str) -> Dict[str, Any]:
                 'feature_weights': self.feature_weights,
                 'meta_names': self.meta_names,
                 'missing': self.missing,
+                'enable_categorical': self.enable_categorical,
                 'parts': self.worker_map.get(worker_addr, None),
                 'is_quantile': self.is_quantile}
@@ -668,6 +666,7 @@ def _create_device_quantile_dmatrix(
     missing: float,
     parts: Optional[_DataParts],
     max_bin: int,
+    enable_categorical: bool,
 ) -> DeviceQuantileDMatrix:
     worker = distributed.get_worker()
     if parts is None:
@@ -680,6 +679,7 @@ def _create_device_quantile_dmatrix(
             feature_names=feature_names,
             feature_types=feature_types,
             max_bin=max_bin,
+            enable_categorical=enable_categorical,
         )
         return d
@@ -709,6 +709,7 @@ def _create_device_quantile_dmatrix(
         feature_types=feature_types,
         nthread=worker.nthreads,
         max_bin=max_bin,
+        enable_categorical=enable_categorical,
     )
     dmatrix.set_info(feature_weights=feature_weights)
     return dmatrix
@@ -720,6 +721,7 @@ def _create_dmatrix(
     feature_weights: Optional[Any],
     meta_names: List[str],
     missing: float,
+    enable_categorical: bool,
     parts: Optional[_DataParts]
 ) -> DMatrix:
     '''Get data that local to worker from DaskDMatrix.
@@ -734,9 +736,12 @@ def _create_dmatrix(
     if list_of_parts is None:
         msg = 'worker {address} has an empty DMatrix. '.format(address=worker.address)
         LOGGER.warning(msg)
-        d = DMatrix(numpy.empty((0, 0)),
-                    feature_names=feature_names,
-                    feature_types=feature_types)
+        d = DMatrix(
+            numpy.empty((0, 0)),
+            feature_names=feature_names,
+            feature_types=feature_types,
+            enable_categorical=enable_categorical,
+        )
         return d

     T = TypeVar('T')
@@ -764,6 +769,7 @@ def concat_or_none(data: Tuple[Optional[T], ...]) -> Optional[T]:
         feature_names=feature_names,
         feature_types=feature_types,
         nthread=worker.nthreads,
+        enable_categorical=enable_categorical,
     )
     dmatrix.set_info(
         base_margin=_base_margin,
diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh
index 271f4ac36149..6065f28e2147 100644
--- a/src/common/device_helpers.cuh
+++ b/src/common/device_helpers.cuh
@@ -1151,12 +1151,12 @@ struct SegmentedUniqueReduceOp {
  * \return Number of unique values in total.
  */
 template <typename DerivedPolicy, typename KeyInIt, typename KeyOutIt, typename ValInIt,
-          typename ValOutIt, typename Comp>
+          typename ValOutIt, typename CompValue, typename CompKey>
 size_t
 SegmentedUnique(const thrust::detail::execution_policy_base<DerivedPolicy> &exec,
                 KeyInIt key_segments_first, KeyInIt key_segments_last, ValInIt val_first,
                 ValInIt val_last, KeyOutIt key_segments_out, ValOutIt val_out,
-                Comp comp) {
+                CompValue comp, CompKey comp_key=thrust::equal_to<size_t>{}) {
   using Key = thrust::pair<size_t, typename thrust::iterator_traits<ValInIt>::value_type>;
   auto unique_key_it = dh::MakeTransformIterator<Key>(
       thrust::make_counting_iterator(static_cast<size_t>(0)),
@@ -1177,7 +1177,7 @@ SegmentedUnique(const thrust::detail::execution_policy_base<DerivedPolicy> &exec
       exec, unique_key_it, unique_key_it + n_inputs,
       val_first, reduce_it, val_out,
       [=] __device__(Key const &l, Key const &r) {
-        if (l.first == r.first) {
+        if (comp_key(l.first, r.first)) {
          // In the same segment.
          return comp(l.second, r.second);
        }
@@ -1195,7 +1195,9 @@
 template <typename... Inputs>
 size_t SegmentedUnique(Inputs &&...inputs) {
   dh::XGBCachingDeviceAllocator<char> alloc;
-  return SegmentedUnique(thrust::cuda::par(alloc), std::forward<Inputs>(inputs)...);
+  return SegmentedUnique(thrust::cuda::par(alloc),
+                         std::forward<Inputs>(inputs)...,
+                         thrust::equal_to<size_t>{});
 }

 /**
diff --git a/src/common/hist_util.cu b/src/common/hist_util.cu
index 469ce4499a2f..f85fea945275 100644
--- a/src/common/hist_util.cu
+++ b/src/common/hist_util.cu
@@ -129,60 +129,52 @@ void SortByWeight(dh::device_vector<float>* weights,
   });
 }

-struct IsCatOp {
-  XGBOOST_DEVICE bool operator()(FeatureType ft) { return ft == FeatureType::kCategorical; }
-};
-
 void RemoveDuplicatedCategories(
     int32_t device, MetaInfo const &info, Span<bst_row_t> d_cuts_ptr,
     dh::device_vector<Entry> *p_sorted_entries,
-    dh::caching_device_vector<size_t>* p_column_sizes_scan) {
+    dh::caching_device_vector<size_t> *p_column_sizes_scan) {
   auto d_feature_types = info.feature_types.ConstDeviceSpan();
-  auto& column_sizes_scan = *p_column_sizes_scan;
-  if (!info.feature_types.Empty() &&
-      thrust::any_of(dh::tbegin(d_feature_types), dh::tend(d_feature_types),
-                     IsCatOp{})) {
-    auto& sorted_entries = *p_sorted_entries;
-    // Removing duplicated entries in categorical features.
-    dh::caching_device_vector<size_t> new_column_scan(column_sizes_scan.size());
-    dh::SegmentedUnique(
-        column_sizes_scan.data().get(),
-        column_sizes_scan.data().get() + column_sizes_scan.size(),
-        sorted_entries.begin(), sorted_entries.end(),
-        new_column_scan.data().get(), sorted_entries.begin(),
-        [=] __device__(Entry const &l, Entry const &r) {
-          if (l.index == r.index) {
-            if (IsCat(d_feature_types, l.index)) {
-              return l.fvalue == r.fvalue;
-            }
-          }
-          return false;
-        });
+  CHECK(!d_feature_types.empty());
+  auto &column_sizes_scan = *p_column_sizes_scan;
+  auto &sorted_entries = *p_sorted_entries;
+  // Removing duplicated entries in categorical features.
+  dh::caching_device_vector<size_t> new_column_scan(column_sizes_scan.size());
+  dh::SegmentedUnique(column_sizes_scan.data().get(),
+                      column_sizes_scan.data().get() + column_sizes_scan.size(),
+                      sorted_entries.begin(), sorted_entries.end(),
+                      new_column_scan.data().get(), sorted_entries.begin(),
+                      [=] __device__(Entry const &l, Entry const &r) {
+                        if (l.index == r.index) {
+                          if (IsCat(d_feature_types, l.index)) {
+                            return l.fvalue == r.fvalue;
+                          }
+                        }
+                        return false;
+                      });

-    // Renew the column scan and cut scan based on categorical data.
-    auto d_old_column_sizes_scan = dh::ToSpan(column_sizes_scan);
-    dh::caching_device_vector<SketchContainer::OffsetT> new_cuts_size(
-        info.num_col_ + 1);
-    auto d_new_cuts_size = dh::ToSpan(new_cuts_size);
-    auto d_new_columns_ptr = dh::ToSpan(new_column_scan);
-    CHECK_EQ(new_column_scan.size(), new_cuts_size.size());
-    dh::LaunchN(device, new_column_scan.size(), [=] __device__(size_t idx) {
-      d_old_column_sizes_scan[idx] = d_new_columns_ptr[idx];
-      if (idx == d_new_columns_ptr.size() - 1) {
-        return;
-      }
-      if (IsCat(d_feature_types, idx)) {
-        // Cut size is the same as number of categories in input.
-        d_new_cuts_size[idx] =
-            d_new_columns_ptr[idx + 1] - d_new_columns_ptr[idx];
-      } else {
-        d_new_cuts_size[idx] = d_cuts_ptr[idx + 1] - d_cuts_ptr[idx];
-      }
-    });
-    // Turn size into ptr.
-    thrust::exclusive_scan(thrust::device, new_cuts_size.cbegin(),
-                           new_cuts_size.cend(), d_cuts_ptr.data());
-  }
+  // Renew the column scan and cut scan based on categorical data.
+  auto d_old_column_sizes_scan = dh::ToSpan(column_sizes_scan);
+  dh::caching_device_vector<SketchContainer::OffsetT> new_cuts_size(
+      info.num_col_ + 1);
+  auto d_new_cuts_size = dh::ToSpan(new_cuts_size);
+  auto d_new_columns_ptr = dh::ToSpan(new_column_scan);
+  CHECK_EQ(new_column_scan.size(), new_cuts_size.size());
+  dh::LaunchN(device, new_column_scan.size(), [=] __device__(size_t idx) {
+    d_old_column_sizes_scan[idx] = d_new_columns_ptr[idx];
+    if (idx == d_new_columns_ptr.size() - 1) {
+      return;
+    }
+    if (IsCat(d_feature_types, idx)) {
+      // Cut size is the same as number of categories in input.
+      d_new_cuts_size[idx] =
+          d_new_columns_ptr[idx + 1] - d_new_columns_ptr[idx];
+    } else {
+      d_new_cuts_size[idx] = d_cuts_ptr[idx + 1] - d_cuts_ptr[idx];
+    }
+  });
+  // Turn size into ptr.
+  thrust::exclusive_scan(thrust::device, new_cuts_size.cbegin(),
+                         new_cuts_size.cend(), d_cuts_ptr.data());
 }
 }  // namespace detail
@@ -215,8 +207,11 @@ void ProcessBatch(int device, MetaInfo const &info, const SparsePage &page,
                             0, sorted_entries.size(), &cuts_ptr,
                             &column_sizes_scan);
   auto d_cuts_ptr = cuts_ptr.DeviceSpan();
-  detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr, &sorted_entries,
-                                     &column_sizes_scan);
+
+  if (sketch_container->HasCategorical()) {
+    detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr,
+                                       &sorted_entries, &column_sizes_scan);
+  }

   auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
   CHECK_EQ(d_cuts_ptr.size(), column_sizes_scan.size());
@@ -281,8 +276,11 @@ void ProcessWeightedBatch(int device, const SparsePage& page,
                             0, sorted_entries.size(), &cuts_ptr,
                             &column_sizes_scan);
   auto d_cuts_ptr = cuts_ptr.DeviceSpan();
-  detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr, &sorted_entries,
-                                     &column_sizes_scan);
+  if (sketch_container->HasCategorical()) {
+    detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr,
+                                       &sorted_entries, &column_sizes_scan);
+  }
+
   auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();

   // Extract cuts
diff --git a/src/common/quantile.cu b/src/common/quantile.cu
index 33892d589bfc..607f8fa9fce0 100644
--- a/src/common/quantile.cu
+++ b/src/common/quantile.cu
@@ -210,6 +210,7 @@ void MergeImpl(int32_t device, Span<SketchEntry const> const &d_x,
                Span<bst_row_t const> const &x_ptr,
                Span<SketchEntry const> const &d_y,
                Span<bst_row_t const> const &y_ptr,
+               Span<FeatureType const> feature_types,
                Span<SketchEntry> out,
                Span<bst_row_t> out_ptr) {
   dh::safe_cuda(cudaSetDevice(device));
@@ -408,31 +409,6 @@ size_t SketchContainer::ScanInput(Span<Entry> entries, Span<OffsetT> d_columns_ptr_in) {
   return n_uniques;
 }

-size_t SketchContainer::Unique() {
-  timer_.Start(__func__);
-  dh::safe_cuda(cudaSetDevice(device_));
-  this->columns_ptr_.SetDevice(device_);
-  Span<OffsetT> d_column_scan = this->columns_ptr_.DeviceSpan();
-  CHECK_EQ(d_column_scan.size(), num_columns_ + 1);
-  Span<SketchEntry> entries = dh::ToSpan(this->Current());
-  HostDeviceVector<OffsetT> scan_out(d_column_scan.size());
-  scan_out.SetDevice(device_);
-  auto d_scan_out = scan_out.DeviceSpan();
-
-  d_column_scan = this->columns_ptr_.DeviceSpan();
-  size_t n_uniques = dh::SegmentedUnique(
-      d_column_scan.data(), d_column_scan.data() + d_column_scan.size(),
-      entries.data(), entries.data() + entries.size(), scan_out.DevicePointer(),
-      entries.data(),
-      detail::SketchUnique{});
-  this->columns_ptr_.Copy(scan_out);
-  CHECK(!this->columns_ptr_.HostCanRead());
-
-  this->Current().resize(n_uniques);
-  timer_.Stop(__func__);
-  return n_uniques;
-}
-
 void SketchContainer::Prune(size_t to) {
   timer_.Start(__func__);
   dh::safe_cuda(cudaSetDevice(device_));
@@ -490,13 +466,20 @@ void SketchContainer::Merge(Span<OffsetT const> d_that_columns_ptr,
   this->Other().resize(this->Current().size() + that.size());
   CHECK_EQ(d_that_columns_ptr.size(), this->columns_ptr_.Size());

-  MergeImpl(device_, this->Data(), this->ColumnsPtr(),
-            that, d_that_columns_ptr,
-            dh::ToSpan(this->Other()), columns_ptr_b_.DeviceSpan());
+  auto feature_types = this->FeatureTypes().ConstDeviceSpan();
+  MergeImpl(device_, this->Data(), this->ColumnsPtr(), that, d_that_columns_ptr,
+            feature_types, dh::ToSpan(this->Other()),
+            columns_ptr_b_.DeviceSpan());
   this->columns_ptr_.Copy(columns_ptr_b_);
   CHECK_EQ(this->columns_ptr_.Size(), num_columns_ + 1);
   this->Alternate();
+
+  if (this->HasCategorical()) {
+    auto d_feature_types = this->FeatureTypes().ConstDeviceSpan();
+    this->Unique([d_feature_types] __device__(size_t l_fidx, size_t r_fidx) {
+      return l_fidx == r_fidx && IsCat(d_feature_types, l_fidx);
+    });
+  }
   timer_.Stop(__func__);
 }
diff --git a/src/common/quantile.cuh b/src/common/quantile.cuh
index 29288be30a30..eab7290bf733 100644
--- a/src/common/quantile.cuh
+++ b/src/common/quantile.cuh
@@ -16,6 +16,19 @@ class HistogramCuts;
 using WQSketch = WQuantileSketch<bst_float, bst_float>;
 using SketchEntry = WQSketch::Entry;

+namespace detail {
+struct IsCatOp {
+  XGBOOST_DEVICE bool operator()(FeatureType ft) {
+    return ft == FeatureType::kCategorical;
+  }
+};
+struct SketchUnique {
+  XGBOOST_DEVICE bool operator()(SketchEntry const& a, SketchEntry const& b) const {
+    return a.value - b.value == 0;
+  }
+};
+}  // namespace detail
+
 /*!
  * \brief A container that holds the device sketches. Sketching is performed per-column,
  *        but fused into single operation for performance.
@@ -43,6 +56,8 @@ class SketchContainer {
   HostDeviceVector<OffsetT> columns_ptr_;
   HostDeviceVector<OffsetT> columns_ptr_b_;

+  bool has_categorical_{false};
+
   dh::device_vector<SketchEntry>& Current() {
     if (current_buffer_) {
       return entries_a_;
@@ -102,14 +117,21 @@
     this->feature_types_.SetDevice(device);
     this->feature_types_.ConstDeviceSpan();
     this->feature_types_.ConstHostSpan();
+
+    auto d_feature_types = feature_types_.ConstDeviceSpan();
+    has_categorical_ =
+        !d_feature_types.empty() &&
+        thrust::any_of(dh::tbegin(d_feature_types), dh::tend(d_feature_types),
+                       detail::IsCatOp{});
+
     timer_.Init(__func__);
   }

   /* \brief Return GPU ID for this container. */
   int32_t DeviceIdx() const { return device_; }
+  /* \brief Whether the predictor matrix contains categorical features. */
+  bool HasCategorical() const { return has_categorical_; }
   /* \brief Accumulate weights of duplicated entries in input. */
   size_t ScanInput(Span<Entry> entries, Span<OffsetT> d_columns_ptr_in);
-  /* \brief Removes all the duplicated elements in quantile structure. */
-  size_t Unique();
   /* Fix rounding error and re-establish invariance. The error is mostly generated by the
    * addition inside `RMinNext` and subtraction in `RMaxPrev`. */
   void FixError();
@@ -154,15 +176,35 @@
   SketchContainer(const SketchContainer&) = delete;
   SketchContainer& operator=(const SketchContainer&) = delete;
-};

-namespace detail {
-struct SketchUnique {
-  XGBOOST_DEVICE bool operator()(SketchEntry const& a, SketchEntry const& b) const {
-    return a.value - b.value == 0;
+  /* \brief Removes all the duplicated elements in quantile structure.
+   */
+  template <typename KeyComp = thrust::equal_to<size_t>>
+  size_t Unique(KeyComp key_comp = thrust::equal_to<size_t>{}) {
+    timer_.Start(__func__);
+    dh::safe_cuda(cudaSetDevice(device_));
+    this->columns_ptr_.SetDevice(device_);
+    Span<OffsetT> d_column_scan = this->columns_ptr_.DeviceSpan();
+    CHECK_EQ(d_column_scan.size(), num_columns_ + 1);
+    Span<SketchEntry> entries = dh::ToSpan(this->Current());
+    HostDeviceVector<OffsetT> scan_out(d_column_scan.size());
+    scan_out.SetDevice(device_);
+    auto d_scan_out = scan_out.DeviceSpan();
+    dh::XGBCachingDeviceAllocator<char> alloc;
+
+    d_column_scan = this->columns_ptr_.DeviceSpan();
+    size_t n_uniques = dh::SegmentedUnique(
+        thrust::cuda::par(alloc), d_column_scan.data(),
+        d_column_scan.data() + d_column_scan.size(), entries.data(),
+        entries.data() + entries.size(), scan_out.DevicePointer(),
+        entries.data(), detail::SketchUnique{}, key_comp);
+    this->columns_ptr_.Copy(scan_out);
+    CHECK(!this->columns_ptr_.HostCanRead());
+
+    this->Current().resize(n_uniques);
+    timer_.Stop(__func__);
+    return n_uniques;
   }
 };
-}  // namespace detail
 }  // namespace common
 }  // namespace xgboost
diff --git a/src/data/ellpack_page.cu b/src/data/ellpack_page.cu
index 59eee0e6ff99..ffae315cd026 100644
--- a/src/data/ellpack_page.cu
+++ b/src/data/ellpack_page.cu
@@ -134,17 +134,20 @@ struct WriteCompressedEllpackFunctor {
                                 const common::CompressedBufferWriter& writer,
                                 AdapterBatchT batch,
                                 EllpackDeviceAccessor accessor,
+                                common::Span<FeatureType const> feature_types,
                                 const data::IsValidFunctor& is_valid)
       : d_buffer(buffer),
         writer(writer),
         batch(std::move(batch)),
         accessor(std::move(accessor)),
+        feature_types(std::move(feature_types)),
         is_valid(is_valid) {}

   common::CompressedByteT* d_buffer;
   common::CompressedBufferWriter writer;
   AdapterBatchT batch;
   EllpackDeviceAccessor accessor;
+  common::Span<FeatureType const> feature_types;
   data::IsValidFunctor is_valid;

   using Tuple = thrust::tuple<size_t, size_t, size_t>;
@@ -154,7 +157,12 @@
       // -1 because the scan is inclusive
       size_t output_position =
           accessor.row_stride * e.row_idx + out.get<1>() - 1;
-      auto bin_idx = accessor.SearchBin(e.value, e.column_idx);
+      uint32_t bin_idx = 0;
+      if (common::IsCat(feature_types, e.column_idx)) {
+        bin_idx = accessor.SearchBin<true>(e.value, e.column_idx);
+      } else {
+        bin_idx = accessor.SearchBin<false>(e.value, e.column_idx);
+      }
       writer.AtomicWriteSymbol(d_buffer, bin_idx, output_position);
     }
     return 0;
@@ -184,8 +192,9 @@ class TypedDiscard : public thrust::discard_iterator<T> {
 // Here the data is already correctly ordered and simply needs to be compacted
 // to remove missing data
 template <typename AdapterBatchT>
-void CopyDataToEllpack(const AdapterBatchT& batch, EllpackPageImpl* dst,
-                       int device_idx, float missing) {
+void CopyDataToEllpack(const AdapterBatchT &batch,
+                       common::Span<FeatureType const> feature_types,
+                       EllpackPageImpl *dst, int device_idx, float missing) {
   // Some witchcraft happens here
   // The goal is to copy valid elements out of the input to an ELLPACK matrix
   // with a given row stride, using no extra working memory Standard stream
@@ -220,7 +229,8 @@
   // We redirect the scan output into this functor to do the actual writing
   WriteCompressedEllpackFunctor<AdapterBatchT> functor(
-      d_compressed_buffer, writer, batch, device_accessor, is_valid);
+      d_compressed_buffer, writer, batch, device_accessor, feature_types,
+      is_valid);
   TypedDiscard<Tuple> discard;
   thrust::transform_output_iterator<
       WriteCompressedEllpackFunctor<AdapterBatchT>, decltype(discard)>
@@ -263,22 +273,22 @@
 template <typename AdapterBatch>
 EllpackPageImpl::EllpackPageImpl(AdapterBatch batch, float missing, int device,
                                  bool is_dense, int nthread,
                                  common::Span<size_t> row_counts_span,
+                                 common::Span<FeatureType const> feature_types,
                                  size_t row_stride, size_t n_rows, size_t n_cols,
                                  common::HistogramCuts const& cuts) {
   dh::safe_cuda(cudaSetDevice(device));

   *this = EllpackPageImpl(device, cuts, is_dense, row_stride, n_rows);
-  CopyDataToEllpack(batch, this, device, missing);
+  CopyDataToEllpack(batch, feature_types, this, device, missing);
   WriteNullValues(this, device, row_counts_span);
 }

-#define ELLPACK_BATCH_SPECIALIZE(__BATCH_T)                             \
-  template EllpackPageImpl::EllpackPageImpl(                            \
-      __BATCH_T batch, float missing, int device,                       \
-      bool is_dense, int nthread,                                       \
-      common::Span<size_t> row_counts_span,                             \
-      size_t row_stride, size_t n_rows, size_t n_cols,                  \
-      common::HistogramCuts const& cuts);
+#define ELLPACK_BATCH_SPECIALIZE(__BATCH_T)                                   \
+  template EllpackPageImpl::EllpackPageImpl(                                  \
+      __BATCH_T batch, float missing, int device, bool is_dense, int nthread, \
+      common::Span<size_t> row_counts_span,                                   \
+      common::Span<FeatureType const> feature_types, size_t row_stride,      \
+      size_t n_rows, size_t n_cols, common::HistogramCuts const &cuts);

 ELLPACK_BATCH_SPECIALIZE(data::CudfAdapterBatch)
 ELLPACK_BATCH_SPECIALIZE(data::CupyAdapterBatch)
@@ -467,11 +477,17 @@ size_t EllpackPageImpl::MemCostBytes(size_t num_rows, size_t row_stride,
   return compressed_size_bytes;
 }

-EllpackDeviceAccessor EllpackPageImpl::GetDeviceAccessor(int device) const {
+EllpackDeviceAccessor EllpackPageImpl::GetDeviceAccessor(
+    int device, common::Span<FeatureType const> feature_types) const {
   gidx_buffer.SetDevice(device);
-  return EllpackDeviceAccessor(
-      device, cuts_, is_dense, row_stride, base_rowid, n_rows,
-      common::CompressedIterator<uint32_t>(gidx_buffer.ConstDevicePointer(),
-                                           NumSymbols()));
+  return {device,
+          cuts_,
+          is_dense,
+          row_stride,
+          base_rowid,
+          n_rows,
+          common::CompressedIterator<uint32_t>(gidx_buffer.ConstDevicePointer(),
+                                               NumSymbols()),
+          feature_types};
 }
 }  // namespace xgboost
diff --git a/src/data/ellpack_page.cuh b/src/data/ellpack_page.cuh
index ddee683ed51e..01861d141407 100644
--- a/src/data/ellpack_page.cuh
+++ b/src/data/ellpack_page.cuh
@@ -10,6 +10,7 @@
 #include "../common/compressed_iterator.h"
 #include "../common/device_helpers.cuh"
 #include "../common/hist_util.h"
+#include "../common/categorical.h"
 #include <thrust/binary_search.h>

 namespace xgboost {
@@ -31,13 +32,17 @@ struct EllpackDeviceAccessor {
   /*! \brief Histogram cut values. Size equals to (bins per feature * number of features).
    */
   common::Span<const float> gidx_fvalue_map;
+  common::Span<FeatureType const> feature_types;
+
   EllpackDeviceAccessor(int device, const common::HistogramCuts& cuts,
                         bool is_dense, size_t row_stride, size_t base_rowid,
-                        size_t n_rows,common::CompressedIterator<uint32_t> gidx_iter)
+                        size_t n_rows,common::CompressedIterator<uint32_t> gidx_iter,
+                        common::Span<FeatureType const> feature_types)
       : is_dense(is_dense),
         row_stride(row_stride),
         base_rowid(base_rowid),
-        n_rows(n_rows) ,gidx_iter(gidx_iter){
+        n_rows(n_rows) ,gidx_iter(gidx_iter),
+        feature_types{feature_types} {
     cuts.cut_values_.SetDevice(device);
     cuts.cut_ptrs_.SetDevice(device);
     cuts.min_vals_.SetDevice(device);
@@ -64,12 +69,23 @@ struct EllpackDeviceAccessor {
     return gidx;
   }

+  template <bool is_cat>
   __device__ uint32_t SearchBin(float value, size_t column_id) const {
     auto beg = feature_segments[column_id];
     auto end = feature_segments[column_id + 1];
-    auto it =
-        thrust::upper_bound(thrust::seq, gidx_fvalue_map.cbegin()+ beg, gidx_fvalue_map.cbegin() + end, value);
-    uint32_t idx = it - gidx_fvalue_map.cbegin();
+    uint32_t idx = 0;
+    if (is_cat) {
+      auto it = dh::MakeTransformIterator<bst_cat_t>(
+          gidx_fvalue_map.cbegin(), [](float v) { return common::AsCat(v); });
+      idx = thrust::lower_bound(thrust::seq, it + beg, it + end,
+                                common::AsCat(value)) -
+            it;
+    } else {
+      auto it = thrust::upper_bound(thrust::seq, gidx_fvalue_map.cbegin() + beg,
+                                    gidx_fvalue_map.cbegin() + end, value);
+      idx = it - gidx_fvalue_map.cbegin();
+    }
     if (idx == end) {
       idx -= 1;
     }
@@ -134,10 +150,12 @@ class EllpackPageImpl {
   explicit EllpackPageImpl(DMatrix* dmat, const BatchParam& parm);

   template <typename AdapterBatch>
-  explicit EllpackPageImpl(AdapterBatch batch, float missing, int device, bool is_dense, int nthread,
+  explicit EllpackPageImpl(AdapterBatch batch, float missing, int device,
+                           bool is_dense, int nthread,
                            common::Span<size_t> row_counts_span,
+                           common::Span<FeatureType const> feature_types,
                            size_t row_stride, size_t n_rows, size_t n_cols,
-                           common::HistogramCuts const& cuts);
+                           common::HistogramCuts const &cuts);

   /*! \brief Copy the elements of the given ELLPACK page into this page.
    *
@@ -176,7 +194,9 @@ class EllpackPageImpl {
    * not found).
   */
   size_t NumSymbols() const { return cuts_.TotalBins() + 1; }
-  EllpackDeviceAccessor GetDeviceAccessor(int device) const;
+  EllpackDeviceAccessor
+  GetDeviceAccessor(int device,
+                    common::Span<FeatureType const> feature_types = {}) const;

  private:
   /*!
diff --git a/src/data/iterative_device_dmatrix.cu b/src/data/iterative_device_dmatrix.cu
index d7569fe300cf..87fd4af93f12 100644
--- a/src/data/iterative_device_dmatrix.cu
+++ b/src/data/iterative_device_dmatrix.cu
@@ -148,9 +148,13 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missing,
     return GetRowCounts(value, row_counts_span, get_device(), missing);
   });
   auto is_dense = this->IsDense();
+
+  proxy->Info().feature_types.SetDevice(get_device());
+  auto d_feature_types = proxy->Info().feature_types.ConstDeviceSpan();
   auto new_impl = Dispatch(proxy, [&](auto const &value) {
-    return EllpackPageImpl(value, missing, get_device(), is_dense, nthread,
-                           row_counts_span, row_stride, rows, cols, cuts);
+    return EllpackPageImpl(value, missing, get_device(), is_dense, nthread,
+                           row_counts_span, d_feature_types, row_stride, rows,
+                           cols, cuts);
   });
   size_t num_elements = page_->Impl()->Copy(get_device(), &new_impl, offset);
   offset += num_elements;
diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu
index 002ee4c3dd90..2bbdf7cf3748 100644
--- a/src/predictor/gpu_predictor.cu
+++ b/src/predictor/gpu_predictor.cu
@@ -155,6 +155,9 @@ struct EllpackLoader {
     if (gidx == -1) {
       return nan("");
     }
+    if (common::IsCat(matrix.feature_types, fidx)) {
+      return matrix.gidx_fvalue_map[gidx];
+    }
     // The gradient index needs to be shifted by one as min values are not included in the
     // cuts.
     if (gidx == matrix.feature_segments[fidx]) {
@@ -592,8 +595,10 @@ class GPUPredictor : public xgboost::Predictor {
     } else {
       size_t batch_offset = 0;
       for (auto const& page : dmat->GetBatches<EllpackPage>()) {
+        dmat->Info().feature_types.SetDevice(generic_param_->gpu_id);
+        auto feature_types = dmat->Info().feature_types.ConstDeviceSpan();
         this->PredictInternal(
-            page.Impl()->GetDeviceAccessor(generic_param_->gpu_id),
+            page.Impl()->GetDeviceAccessor(generic_param_->gpu_id, feature_types),
             d_model,
             out_preds,
             batch_offset);
diff --git a/tests/ci_build/Dockerfile.gpu b/tests/ci_build/Dockerfile.gpu
index 386d2f609751..b86e2133705b 100644
--- a/tests/ci_build/Dockerfile.gpu
+++ b/tests/ci_build/Dockerfile.gpu
@@ -19,7 +19,7 @@ ENV PATH=/opt/python/bin:$PATH
 # Create new Conda environment with cuDF, Dask, and cuPy
 RUN \
     conda create -n gpu_test -c rapidsai-nightly -c rapidsai -c nvidia -c conda-forge -c defaults \
-        python=3.7 cudf=21.08* rmm=21.08* cudatoolkit=$CUDA_VERSION_ARG dask dask-cuda dask-cudf cupy \
+        python=3.7 cudf=21.08* rmm=21.08* cudatoolkit=$CUDA_VERSION_ARG dask dask-cuda dask-cudf cupy=9.1* \
         numpy pytest scipy scikit-learn pandas matplotlib wheel python-kubernetes urllib3 graphviz hypothesis

 ENV GOSU_VERSION 1.10
diff --git a/tests/cpp/data/test_iterative_device_dmatrix.cu b/tests/cpp/data/test_iterative_device_dmatrix.cu
index e440103365e4..3e318221e9a8 100644
--- a/tests/cpp/data/test_iterative_device_dmatrix.cu
+++ b/tests/cpp/data/test_iterative_device_dmatrix.cu
@@ -68,7 +68,16 @@ void TestEquivalent(float sparsity) {
     auto const& buffer_from_iter = page_concatenated->gidx_buffer;
     auto const& buffer_from_data = ellpack.Impl()->gidx_buffer;
     ASSERT_NE(buffer_from_data.Size(), 0);
-    ASSERT_EQ(buffer_from_data.ConstHostVector(), buffer_from_data.ConstHostVector());
+
+    common::CompressedIterator<uint32_t> data_buf{
+        buffer_from_data.ConstHostPointer(), from_data.NumSymbols()};
+    common::CompressedIterator<uint32_t> data_iter{
+        buffer_from_iter.ConstHostPointer(), from_iter.NumSymbols()};
+    CHECK_EQ(from_data.NumSymbols(), from_iter.NumSymbols());
+    CHECK_EQ(from_data.n_rows * from_data.row_stride,
+             from_data.n_rows * from_iter.row_stride);
+    for (size_t i = 0; i < from_data.n_rows * from_data.row_stride; ++i) {
+      CHECK_EQ(data_buf[i], data_iter[i]);
+    }
   }
 }
diff --git a/tests/cpp/predictor/test_predictor.cc b/tests/cpp/predictor/test_predictor.cc
index 0705dac40e3f..da5f23090421 100644
--- a/tests/cpp/predictor/test_predictor.cc
+++ b/tests/cpp/predictor/test_predictor.cc
@@ -225,6 +225,9 @@ void TestCategoricalPrediction(std::string name) {
   row[split_ind] = split_cat;
   auto m = GetDMatrixFromData(row, 1, kCols);

+  std::vector<FeatureType> types(10, FeatureType::kCategorical);
+  m->Info().feature_types.HostVector() = types;
+
   predictor->InitOutPredictions(m->Info(), &out_predictions.predictions, model);
   predictor->PredictBatch(m.get(), &out_predictions, model, 0);
   ASSERT_EQ(out_predictions.predictions.Size(), 1ul);
diff --git a/tests/python-gpu/test_from_cudf.py b/tests/python-gpu/test_from_cudf.py
index ca934bcdb6da..36bf6023071c 100644
--- a/tests/python-gpu/test_from_cudf.py
+++ b/tests/python-gpu/test_from_cudf.py
@@ -225,19 +225,28 @@ class IterForDMatrixTest(xgb.core.DataIter):
     ROWS_PER_BATCH = 100            # data is splited by rows
     BATCHES = 16

-    def __init__(self):
+    def __init__(self, categorical):
        '''Generate some random data for demostration.

        Actual data can be anything that is currently supported by XGBoost.
        '''
        import cudf
        self.rows = self.ROWS_PER_BATCH
+
+        if categorical:
+            self._data = []
+            self._labels = []
+            for i in range(self.BATCHES):
+                X, y = tm.make_categorical(self.ROWS_PER_BATCH, 4, 13, False)
+                self._data.append(cudf.from_pandas(X))
+                self._labels.append(y)
+        else:
-        rng = np.random.RandomState(1994)
-        self._data = [
-            cudf.DataFrame(
-                {'a': rng.randn(self.ROWS_PER_BATCH),
-                 'b': rng.randn(self.ROWS_PER_BATCH)})] * self.BATCHES
-        self._labels = [rng.randn(self.rows)] * self.BATCHES
+            rng = np.random.RandomState(1994)
+            self._data = [
+                cudf.DataFrame(
+                    {'a': rng.randn(self.ROWS_PER_BATCH),
+                     'b': rng.randn(self.ROWS_PER_BATCH)})] * self.BATCHES
+            self._labels = [rng.randn(self.rows)] * self.BATCHES
        self.it = 0             # set iterator to 0
        super().__init__()
@@ -272,24 +281,26 @@ def next(self, input_data):


 @pytest.mark.skipif(**tm.no_cudf())
-def test_from_cudf_iter():
+@pytest.mark.parametrize("enable_categorical", [True, False])
+def test_from_cudf_iter(enable_categorical):
     rounds = 100
-    it = IterForDMatrixTest()
+    it = IterForDMatrixTest(enable_categorical)
+    params = {"tree_method": "gpu_hist"}

     # Use iterator
-    m_it = xgb.DeviceQuantileDMatrix(it)
-    reg_with_it = xgb.train({'tree_method': 'gpu_hist'}, m_it,
-                            num_boost_round=rounds)
-    predict_with_it = reg_with_it.predict(m_it)
+    m_it = xgb.DeviceQuantileDMatrix(it, enable_categorical=enable_categorical)
+    reg_with_it = xgb.train(params, m_it, num_boost_round=rounds)
+
+    X = it.as_array()
+    y = it.as_array_labels()

-    # Without using iterator
-    m = xgb.DMatrix(it.as_array(), it.as_array_labels())
+    m = xgb.DMatrix(X, y, enable_categorical=enable_categorical)
     assert m_it.num_col() == m.num_col()
     assert m_it.num_row() == m.num_row()
-    reg = xgb.train({'tree_method': 'gpu_hist'}, m,
-                    num_boost_round=rounds)
-    predict = reg.predict(m)
+    reg = xgb.train(params, m, num_boost_round=rounds)

+    predict = reg.predict(m)
+    predict_with_it = reg_with_it.predict(m_it)
     np.testing.assert_allclose(predict_with_it, predict)
diff --git a/tests/python-gpu/test_gpu_with_dask.py b/tests/python-gpu/test_gpu_with_dask.py
index 64bea85975e0..a08f99079c62 100644
--- a/tests/python-gpu/test_gpu_with_dask.py
+++ b/tests/python-gpu/test_gpu_with_dask.py
@@ -1,11 +1,13 @@
 import sys
 import os
-from typing import Type, TypeVar, Any, Dict, List
+from typing import Type, TypeVar, Any, Dict, List, Tuple
 import pytest
 import numpy as np
 import asyncio
 import xgboost
 import subprocess
+import tempfile
+import json
 from collections import OrderedDict
 from inspect import signature
 from hypothesis import given, strategies, settings, note
@@ -41,6 +43,49 @@
     pass


+def make_categorical(
+    client: Client,
+    n_samples: int,
+    n_features: int,
+    n_categories: int,
+    onehot: bool = False,
+) -> Tuple[dd.DataFrame, dd.Series]:
+    workers = _get_client_workers(client)
+    n_workers = len(workers)
+    dfs = []
+
+    def pack(**kwargs: Any) -> dd.DataFrame:
+        X, y = tm.make_categorical(**kwargs)
+        X["label"] = y
+        return X
+
+    meta = pack(
+        n_samples=1, n_features=n_features, n_categories=n_categories, onehot=False
+    )
+
+    for i, worker in enumerate(workers):
+        l_n_samples = min(
+            n_samples // n_workers, n_samples - i * (n_samples // n_workers)
+        )
+        future = client.submit(
+            pack,
+            n_samples=l_n_samples,
+            n_features=n_features,
+            n_categories=n_categories,
+            onehot=False,
+            workers=[worker],
+        )
+        dfs.append(future)
+
+    df = dd.from_delayed(dfs, meta=meta)
+    y = df["label"]
+    X = df[df.columns.difference(["label"])]
+
+    if onehot:
+        return dd.get_dummies(X), y
+    return X, y
+
+
 def run_with_dask_dataframe(DMatrixT: Type, client: Client) -> None:
     import cupy as cp
     cp.cuda.runtime.setDevice(0)
@@ -126,6 +171,62 @@ def run_with_dask_array(DMatrixT: Type, client: Client) -> None:
         inplace_predictions)


+@pytest.mark.skipif(**tm.no_dask_cudf())
+def test_categorical(local_cuda_cluster: LocalCUDACluster) -> None:
+    with Client(local_cuda_cluster) as client:
+        import dask_cudf
+
+        rounds = 10
+        X, y = make_categorical(client, 10000, 30, 13)
+        X = dask_cudf.from_dask_dataframe(X)
+
+        X_onehot, _ = make_categorical(client, 10000, 30, 13, True)
+        X_onehot = dask_cudf.from_dask_dataframe(X_onehot)
+
+        parameters = {"tree_method": "gpu_hist"}
+
+        m = dxgb.DaskDMatrix(client, X_onehot, y, enable_categorical=True)
+        by_etl_results = dxgb.train(
+            client,
+            parameters,
+            m,
+            num_boost_round=rounds,
+            evals=[(m, "Train")],
+        )["history"]
+
+        m = dxgb.DaskDMatrix(client, X, y, enable_categorical=True)
+        output = dxgb.train(
+            client,
+            parameters,
+            m,
+            num_boost_round=rounds,
+            evals=[(m, "Train")],
+        )
+        by_builtin_results = output["history"]
+
+        np.testing.assert_allclose(
+            np.array(by_etl_results["Train"]["rmse"]),
+            np.array(by_builtin_results["Train"]["rmse"]),
+            rtol=1e-3,
+        )
+        assert tm.non_increasing(by_builtin_results["Train"]["rmse"])
+
+        model = output["booster"]
+        with tempfile.TemporaryDirectory() as tempdir:
+            path = os.path.join(tempdir, "model.json")
+            model.save_model(path)
+            with open(path, "r") as fd:
+                categorical = json.load(fd)
+
+            categories_sizes = np.array(
+                categorical["learner"]["gradient_booster"]["model"]["trees"][-1][
+                    "categories_sizes"
+                ]
+            )
+            assert categories_sizes.shape[0] != 0
+            np.testing.assert_allclose(categories_sizes, 1)
+
+
 def to_cp(x: Any, DMatrixT: Type) -> Any:
     import cupy
     if isinstance(x, np.ndarray) and \
diff --git a/tests/python/testing.py b/tests/python/testing.py
index 2feeaf0a0e40..947b303a0e39 100644
--- a/tests/python/testing.py
+++ b/tests/python/testing.py
@@ -236,7 +236,7 @@ def get_mq2008(dpath):

 @memory.cache
 def make_categorical(
-    n_samples: int, n_features: int, n_categories: int, onehot_enc: bool
+    n_samples: int, n_features: int, n_categories: int, onehot: bool
 ):
     import pandas as pd

@@ -244,7 +244,7 @@ def make_categorical(
     pd_dict = {}
     for i in range(n_features + 1):
-        c = rng.randint(low=0, high=n_categories + 1, size=n_samples)
+        c = rng.randint(low=0, high=n_categories, size=n_samples)
         pd_dict[str(i)] = pd.Series(c, dtype=np.int64)

     df = pd.DataFrame(pd_dict)
@@ -255,11 +255,13 @@ def make_categorical(
     label += 1
     df = df.astype("category")

-    if onehot_enc:
-        cat = pd.get_dummies(df)
-    else:
-        cat = df
-    return cat, label
+    categories = np.arange(0, n_categories)
+    for col in df.columns:
+        df[col] = df[col].cat.set_categories(categories)
+
+    if onehot:
+        return pd.get_dummies(df), label
+    return df, label


 _unweighted_datasets_strategy = strategies.sampled_from(
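
Usage sketch (not part of the patch): a minimal example of the workflow the hunks above enable, assuming a GPU build of XGBoost and pandas columns encoded with the category dtype, mirroring make_categorical in tests/python/testing.py. Variable names are illustrative only.

    import numpy as np
    import pandas as pd
    import xgboost as xgb

    rng = np.random.RandomState(1994)
    # Two predictors stored as pandas categoricals; no one-hot encoding.
    X = pd.DataFrame({
        "a": pd.Series(rng.randint(0, 13, 100)).astype("category"),
        "b": pd.Series(rng.randint(0, 13, 100)).astype("category"),
    })
    y = rng.randn(100)

    # enable_categorical keeps the columns as categorical features, which
    # this patch routes through the GPU sketching and prediction code.
    m = xgb.DMatrix(X, y, enable_categorical=True)
    booster = xgb.train({"tree_method": "gpu_hist"}, m, num_boost_round=10)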