From 46e0bce212f87a923d8a740288370f45291f818e Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Thu, 5 May 2022 19:56:49 +0800 Subject: [PATCH] Use maximum category in sketch. (#7853) --- src/common/categorical.h | 7 +- src/common/common.h | 7 ++ src/common/quantile.cc | 40 +++------- src/common/quantile.cu | 104 +++++++++++++++----------- tests/python-gpu/test_gpu_updaters.py | 3 + tests/python/test_updaters.py | 26 +++++++ 6 files changed, 113 insertions(+), 74 deletions(-) diff --git a/src/common/categorical.h b/src/common/categorical.h index 1e3f5c960971..a54d823d8267 100644 --- a/src/common/categorical.h +++ b/src/common/categorical.h @@ -74,10 +74,15 @@ inline void InvalidCategory() { // values to be less than this last representable value. auto str = std::to_string(OutOfRangeCat()); LOG(FATAL) << "Invalid categorical value detected. Categorical value should be non-negative, " - "less than total umber of categories in training data and less than " + + "less than total number of categories in training data and less than " + str; } +inline void CheckMaxCat(float max_cat, size_t n_categories) { + CHECK_GE(max_cat + 1, n_categories) + << "Maximum cateogry should not be lesser than the total number of categories."; +} + /*! * \brief Whether should we use onehot encoding for categorical data. */ diff --git a/src/common/common.h b/src/common/common.h index aa2d8197b4a1..4949d61e4582 100644 --- a/src/common/common.h +++ b/src/common/common.h @@ -188,10 +188,17 @@ class IndexTransformIter { */ explicit IndexTransformIter(Fn &&op) : fn_{op} {} IndexTransformIter(IndexTransformIter const &) = default; + IndexTransformIter& operator=(IndexTransformIter&&) = default; + IndexTransformIter& operator=(IndexTransformIter const& that) { + iter_ = that.iter_; + return *this; + } value_type operator*() const { return fn_(iter_); } auto operator-(IndexTransformIter const &that) const { return iter_ - that.iter_; } + bool operator==(IndexTransformIter const &that) const { return iter_ == that.iter_; } + bool operator!=(IndexTransformIter const &that) const { return !(*this == that); } IndexTransformIter &operator++() { iter_++; diff --git a/src/common/quantile.cc b/src/common/quantile.cc index 2e9abb8f20b4..13a5a15560ba 100644 --- a/src/common/quantile.cc +++ b/src/common/quantile.cc @@ -468,11 +468,17 @@ void AddCutPoint(typename SketchType::SummaryContainer const &summary, int max_b } } -void AddCategories(std::set const &categories, HistogramCuts *cuts) { +auto AddCategories(std::set const &categories, HistogramCuts *cuts) { + if (std::any_of(categories.cbegin(), categories.cend(), InvalidCat)) { + InvalidCategory(); + } auto &cut_values = cuts->cut_values_.HostVector(); - for (auto const &v : categories) { - cut_values.push_back(AsCat(v)); + auto max_cat = *std::max_element(categories.cbegin(), categories.cend()); + CheckMaxCat(max_cat, categories.size()); + for (bst_cat_t i = 0; i <= AsCat(max_cat); ++i) { + cut_values.push_back(i); } + return max_cat; } template @@ -505,11 +511,12 @@ void SketchContainerImpl::MakeCuts(HistogramCuts* cuts) { } }); + float max_cat{-1.f}; for (size_t fid = 0; fid < reduced.size(); ++fid) { size_t max_num_bins = std::min(num_cuts[fid], max_bins_); typename WQSketch::SummaryContainer const& a = final_summaries[fid]; if (IsCat(feature_types_, fid)) { - AddCategories(categories_.at(fid), cuts); + max_cat = std::max(max_cat, AddCategories(categories_.at(fid), cuts)); } else { AddCutPoint(a, max_num_bins, cuts); // push a value that is greater than anything @@ -527,30 +534,7 @@ void SketchContainerImpl::MakeCuts(HistogramCuts* cuts) { cuts->cut_ptrs_.HostVector().push_back(cut_size); } - if (has_categorical_) { - for (auto const &feat : categories_) { - if (std::any_of(feat.cbegin(), feat.cend(), InvalidCat)) { - InvalidCategory(); - } - } - auto const &ptrs = cuts->Ptrs(); - auto const &vals = cuts->Values(); - - float max_cat{-std::numeric_limits::infinity()}; - for (size_t i = 1; i < ptrs.size(); ++i) { - if (IsCat(feature_types_, i - 1)) { - auto beg = ptrs[i - 1]; - auto end = ptrs[i]; - auto feat = Span{vals}.subspan(beg, end - beg); - auto max_elem = *std::max_element(feat.cbegin(), feat.cend()); - if (max_elem > max_cat) { - max_cat = max_elem; - } - } - } - cuts->SetCategorical(true, max_cat); - } - + cuts->SetCategorical(this->has_categorical_, max_cat); monitor_.Stop(__func__); } diff --git a/src/common/quantile.cu b/src/common/quantile.cu index 1be6ea23bd30..33179551631b 100644 --- a/src/common/quantile.cu +++ b/src/common/quantile.cu @@ -1,5 +1,5 @@ /*! - * Copyright 2020 by XGBoost Contributors + * Copyright 2020-2022 by XGBoost Contributors */ #include #include @@ -583,13 +583,13 @@ void SketchContainer::AllReduce() { namespace { struct InvalidCatOp { - Span values; - Span ptrs; + Span values; + Span ptrs; Span ft; XGBOOST_DEVICE bool operator()(size_t i) const { auto fidx = dh::SegmentId(ptrs, i); - return IsCat(ft, fidx) && InvalidCat(values[i]); + return IsCat(ft, fidx) && InvalidCat(values[i].value); } }; } // anonymous namespace @@ -611,7 +611,7 @@ void SketchContainer::MakeCuts(HistogramCuts* p_cuts) { p_cuts->min_vals_.SetDevice(device_); auto d_min_values = p_cuts->min_vals_.DeviceSpan(); - auto in_cut_values = dh::ToSpan(this->Current()); + auto const in_cut_values = dh::ToSpan(this->Current()); // Set up output ptr p_cuts->cut_ptrs_.SetDevice(device_); @@ -619,26 +619,70 @@ void SketchContainer::MakeCuts(HistogramCuts* p_cuts) { h_out_columns_ptr.clear(); h_out_columns_ptr.push_back(0); auto const& h_feature_types = this->feature_types_.ConstHostSpan(); + + auto d_ft = feature_types_.ConstDeviceSpan(); + + std::vector max_values; + float max_cat{-1.f}; + if (has_categorical_) { + dh::XGBCachingDeviceAllocator alloc; + auto key_it = dh::MakeTransformIterator( + thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(size_t i) -> bst_feature_t { + return dh::SegmentId(d_in_columns_ptr, i); + }); + auto invalid_op = InvalidCatOp{in_cut_values, d_in_columns_ptr, d_ft}; + auto val_it = dh::MakeTransformIterator( + thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(size_t i) { + auto fidx = dh::SegmentId(d_in_columns_ptr, i); + auto v = in_cut_values[i]; + if (IsCat(d_ft, fidx)) { + if (invalid_op(i)) { + // use inf to indicate invalid value, this way we can keep it as in + // indicator in the reduce operation as it's always the greatest value. + v.value = std::numeric_limits::infinity(); + } + } + return v; + }); + CHECK_EQ(num_columns_, d_in_columns_ptr.size() - 1); + max_values.resize(d_in_columns_ptr.size() - 1); + dh::caching_device_vector d_max_values(d_in_columns_ptr.size() - 1); + thrust::reduce_by_key(thrust::cuda::par(alloc), key_it, key_it + in_cut_values.size(), val_it, + thrust::make_discard_iterator(), d_max_values.begin(), + thrust::equal_to{}, + [] __device__(auto l, auto r) { return l.value > r.value ? l : r; }); + dh::CopyDeviceSpanToVector(&max_values, dh::ToSpan(d_max_values)); + auto max_it = common::MakeIndexTransformIter([&](auto i) { + if (IsCat(h_feature_types, i)) { + return max_values[i].value; + } + return -1.f; + }); + max_cat = *std::max_element(max_it, max_it + max_values.size()); + if (std::isinf(max_cat)) { + InvalidCategory(); + } + } + + // Set up output cuts for (bst_feature_t i = 0; i < num_columns_; ++i) { - size_t column_size = std::max(static_cast(1ul), - this->Column(i).size()); + size_t column_size = std::max(static_cast(1ul), this->Column(i).size()); if (IsCat(h_feature_types, i)) { - h_out_columns_ptr.push_back(static_cast(column_size)); + // column_size is the number of unique values in that feature. + CheckMaxCat(max_values[i].value, column_size); + h_out_columns_ptr.push_back(max_values[i].value + 1); // includes both max_cat and 0. } else { - h_out_columns_ptr.push_back(std::min(static_cast(column_size), - static_cast(num_bins_))); + h_out_columns_ptr.push_back( + std::min(static_cast(column_size), static_cast(num_bins_))); } } - std::partial_sum(h_out_columns_ptr.begin(), h_out_columns_ptr.end(), - h_out_columns_ptr.begin()); + std::partial_sum(h_out_columns_ptr.begin(), h_out_columns_ptr.end(), h_out_columns_ptr.begin()); auto d_out_columns_ptr = p_cuts->cut_ptrs_.ConstDeviceSpan(); - // Set up output cuts size_t total_bins = h_out_columns_ptr.back(); p_cuts->cut_values_.SetDevice(device_); p_cuts->cut_values_.Resize(total_bins); auto out_cut_values = p_cuts->cut_values_.DeviceSpan(); - auto d_ft = feature_types_.ConstDeviceSpan(); dh::LaunchN(total_bins, [=] __device__(size_t idx) { auto column_id = dh::SegmentId(d_out_columns_ptr, idx); @@ -667,8 +711,7 @@ void SketchContainer::MakeCuts(HistogramCuts* p_cuts) { } if (IsCat(d_ft, column_id)) { - assert(out_column.size() == in_column.size()); - out_column[idx] = in_column[idx].value; + out_column[idx] = idx; return; } @@ -684,36 +727,7 @@ void SketchContainer::MakeCuts(HistogramCuts* p_cuts) { out_column[idx] = in_column[idx+1].value; }); - float max_cat{-1.0f}; - if (has_categorical_) { - auto invalid_op = InvalidCatOp{out_cut_values, d_out_columns_ptr, d_ft}; - auto it = dh::MakeTransformIterator>( - thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(size_t i) { - auto fidx = dh::SegmentId(d_out_columns_ptr, i); - if (IsCat(d_ft, fidx)) { - auto invalid = invalid_op(i); - auto v = out_cut_values[i]; - return thrust::make_pair(invalid, v); - } - return thrust::make_pair(false, std::numeric_limits::min()); - }); - - bool invalid{false}; - dh::XGBCachingDeviceAllocator alloc; - thrust::tie(invalid, max_cat) = - thrust::reduce(thrust::cuda::par(alloc), it, it + out_cut_values.size(), - thrust::make_pair(false, std::numeric_limits::min()), - [=] XGBOOST_DEVICE(thrust::pair const &l, - thrust::pair const &r) { - return thrust::make_pair(l.first || r.first, std::max(l.second, r.second)); - }); - if (invalid) { - InvalidCategory(); - } - } - p_cuts->SetCategorical(this->has_categorical_, max_cat); - timer_.Stop(__func__); } } // namespace common diff --git a/tests/python-gpu/test_gpu_updaters.py b/tests/python-gpu/test_gpu_updaters.py index cf5f726032e4..8748ddcbdf91 100644 --- a/tests/python-gpu/test_gpu_updaters.py +++ b/tests/python-gpu/test_gpu_updaters.py @@ -60,6 +60,9 @@ def test_gpu_hist(self, param, num_rounds, dataset): def test_categorical(self, rows, cols, rounds, cats): self.cputest.run_categorical_basic(rows, cols, rounds, cats, "gpu_hist") + def test_max_cat(self) -> None: + self.cputest.run_max_cat("gpu_hist") + def test_categorical_32_cat(self): '''32 hits the bound of integer bitset, so special test''' rows = 1000 diff --git a/tests/python/test_updaters.py b/tests/python/test_updaters.py index 4b56d37d4493..fa02b009a0f5 100644 --- a/tests/python/test_updaters.py +++ b/tests/python/test_updaters.py @@ -1,3 +1,5 @@ +from random import choice +from string import ascii_lowercase import testing as tm import pytest import xgboost as xgb @@ -169,6 +171,30 @@ def run_invalid_category(self, tree_method: str) -> None: def test_invalid_category(self) -> None: self.run_invalid_category("approx") + self.run_invalid_category("hist") + + def run_max_cat(self, tree_method: str) -> None: + """Test data with size smaller than number of categories.""" + import pandas as pd + n_cat = 100 + n = 5 + X = pd.Series( + ["".join(choice(ascii_lowercase) for i in range(3)) for i in range(n_cat)], + dtype="category", + )[:n].to_frame() + + reg = xgb.XGBRegressor( + enable_categorical=True, + tree_method=tree_method, + n_estimators=10, + ) + y = pd.Series(range(n)) + reg.fit(X=X, y=y, eval_set=[(X, y)]) + assert tm.non_increasing(reg.evals_result()["validation_0"]["rmse"]) + + @pytest.mark.parametrize("tree_method", ["hist", "approx"]) + def test_max_cat(self, tree_method) -> None: + self.run_max_cat(tree_method) def run_categorical_basic(self, rows, cols, rounds, cats, tree_method): onehot, label = tm.make_categorical(rows, cols, cats, True)