From 915d7280616e04f9e5b45a9e1ebc0589cf2aa07c Mon Sep 17 00:00:00 2001 From: jiamingy Date: Fri, 29 Apr 2022 23:03:04 +0800 Subject: [PATCH 01/12] Use max cat instead. --- src/common/quantile.cc | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/common/quantile.cc b/src/common/quantile.cc index 2e9abb8f20b4..c7d0e5564164 100644 --- a/src/common/quantile.cc +++ b/src/common/quantile.cc @@ -470,8 +470,10 @@ void AddCutPoint(typename SketchType::SummaryContainer const &summary, int max_b void AddCategories(std::set const &categories, HistogramCuts *cuts) { auto &cut_values = cuts->cut_values_.HostVector(); - for (auto const &v : categories) { - cut_values.push_back(AsCat(v)); + auto max_cat = std::accumulate(categories.cbegin(), categories.cend(), .0f, + [](auto l, auto r) { return std::max(l, r); }); + for (float i = 0; i < max_cat; ++i) { + cut_values.push_back(AsCat(i)); } } From 6bb12e2a9601c39e0f360c4b1f1923adb414e055 Mon Sep 17 00:00:00 2001 From: fis Date: Sat, 30 Apr 2022 00:05:22 +0800 Subject: [PATCH 02/12] GPU. --- src/common/quantile.cc | 1 + src/common/quantile.cu | 38 +++++++++++++++++++++++++++----------- 2 files changed, 28 insertions(+), 11 deletions(-) diff --git a/src/common/quantile.cc b/src/common/quantile.cc index c7d0e5564164..34970e325012 100644 --- a/src/common/quantile.cc +++ b/src/common/quantile.cc @@ -472,6 +472,7 @@ void AddCategories(std::set const &categories, HistogramCuts *cuts) { auto &cut_values = cuts->cut_values_.HostVector(); auto max_cat = std::accumulate(categories.cbegin(), categories.cend(), .0f, [](auto l, auto r) { return std::max(l, r); }); + CHECK_GE(max_cat + 1, categories.size()); for (float i = 0; i < max_cat; ++i) { cut_values.push_back(AsCat(i)); } diff --git a/src/common/quantile.cu b/src/common/quantile.cu index 1be6ea23bd30..160fab47834a 100644 --- a/src/common/quantile.cu +++ b/src/common/quantile.cu @@ -1,5 +1,5 @@ /*! 
- * Copyright 2020 by XGBoost Contributors + * Copyright 2020-2022 by XGBoost Contributors */ #include #include @@ -611,7 +611,7 @@ void SketchContainer::MakeCuts(HistogramCuts* p_cuts) { p_cuts->min_vals_.SetDevice(device_); auto d_min_values = p_cuts->min_vals_.DeviceSpan(); - auto in_cut_values = dh::ToSpan(this->Current()); + auto const in_cut_values = dh::ToSpan(this->Current()); // Set up output ptr p_cuts->cut_ptrs_.SetDevice(device_); @@ -619,21 +619,37 @@ void SketchContainer::MakeCuts(HistogramCuts* p_cuts) { h_out_columns_ptr.clear(); h_out_columns_ptr.push_back(0); auto const& h_feature_types = this->feature_types_.ConstHostSpan(); + + // Set up output cuts + std::vector max_values; + if (has_categorical_) { + dh::XGBCachingDeviceAllocator alloc; + auto it = dh::MakeTransformIterator( + thrust::make_counting_iterator(0ul), + [=] XGBOOST_DEVICE(size_t i) { return dh::SegmentId(d_in_columns_ptr, i); }); + CHECK_EQ(num_columns_, d_in_columns_ptr.size() - 1); + max_values.resize(d_in_columns_ptr.size() - 1); + dh::caching_device_vector d_max_values(d_in_columns_ptr.size() - 1); + thrust::reduce_by_key(thrust::cuda::par(alloc), it, it + in_cut_values.size(), + dh::tbegin(in_cut_values), dh::TypedDiscard{}, + d_max_values.begin(), thrust::equal_to{}, + [] __device__(auto l, auto r) { return l.value > r.value ? 
l : r; }); + dh::CopyDeviceSpanToVector(&max_values, dh::ToSpan(d_max_values)); + } + for (bst_feature_t i = 0; i < num_columns_; ++i) { - size_t column_size = std::max(static_cast(1ul), - this->Column(i).size()); + size_t column_size = std::max(static_cast(1ul), this->Column(i).size()); if (IsCat(h_feature_types, i)) { - h_out_columns_ptr.push_back(static_cast(column_size)); + CHECK_GE(max_values[i].value + 1, column_size); + h_out_columns_ptr.push_back(max_values[i].value); } else { - h_out_columns_ptr.push_back(std::min(static_cast(column_size), - static_cast(num_bins_))); + h_out_columns_ptr.push_back( + std::min(static_cast(column_size), static_cast(num_bins_))); } } - std::partial_sum(h_out_columns_ptr.begin(), h_out_columns_ptr.end(), - h_out_columns_ptr.begin()); + std::partial_sum(h_out_columns_ptr.begin(), h_out_columns_ptr.end(), h_out_columns_ptr.begin()); auto d_out_columns_ptr = p_cuts->cut_ptrs_.ConstDeviceSpan(); - // Set up output cuts size_t total_bins = h_out_columns_ptr.back(); p_cuts->cut_values_.SetDevice(device_); p_cuts->cut_values_.Resize(total_bins); @@ -668,7 +684,7 @@ void SketchContainer::MakeCuts(HistogramCuts* p_cuts) { if (IsCat(d_ft, column_id)) { assert(out_column.size() == in_column.size()); - out_column[idx] = in_column[idx].value; + out_column[idx] = idx; return; } From 1f40cbf8d1a81101b8d675242ca7e21e488380c8 Mon Sep 17 00:00:00 2001 From: jiamingy Date: Sat, 30 Apr 2022 00:15:19 +0800 Subject: [PATCH 03/12] Fixes. 
--- src/common/quantile.cc | 4 ++-- src/common/quantile.cu | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/common/quantile.cc b/src/common/quantile.cc index 34970e325012..59818965b33f 100644 --- a/src/common/quantile.cc +++ b/src/common/quantile.cc @@ -473,8 +473,8 @@ void AddCategories(std::set const &categories, HistogramCuts *cuts) { auto max_cat = std::accumulate(categories.cbegin(), categories.cend(), .0f, [](auto l, auto r) { return std::max(l, r); }); CHECK_GE(max_cat + 1, categories.size()); - for (float i = 0; i < max_cat; ++i) { - cut_values.push_back(AsCat(i)); + for (bst_cat_t i = 0; i <= AsCat(max_cat); ++i) { + cut_values.push_back(i); } } diff --git a/src/common/quantile.cu b/src/common/quantile.cu index 160fab47834a..5e43ff8a8e11 100644 --- a/src/common/quantile.cu +++ b/src/common/quantile.cu @@ -641,7 +641,7 @@ void SketchContainer::MakeCuts(HistogramCuts* p_cuts) { size_t column_size = std::max(static_cast(1ul), this->Column(i).size()); if (IsCat(h_feature_types, i)) { CHECK_GE(max_values[i].value + 1, column_size); - h_out_columns_ptr.push_back(max_values[i].value); + h_out_columns_ptr.push_back(max_values[i].value + 1); // includes both max_cat and 0. } else { h_out_columns_ptr.push_back( std::min(static_cast(column_size), static_cast(num_bins_))); From 4e82d5af95e0c1972874ee0fbbc98faae3b02001 Mon Sep 17 00:00:00 2001 From: fis Date: Sat, 30 Apr 2022 00:42:13 +0800 Subject: [PATCH 04/12] Test. 
--- tests/python-gpu/test_gpu_updaters.py | 3 +++ tests/python/test_updaters.py | 26 ++++++++++++++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/tests/python-gpu/test_gpu_updaters.py b/tests/python-gpu/test_gpu_updaters.py index e9d2bf06e229..435806f5d353 100644 --- a/tests/python-gpu/test_gpu_updaters.py +++ b/tests/python-gpu/test_gpu_updaters.py @@ -61,6 +61,9 @@ def test_gpu_hist(self, param, num_rounds, dataset): def test_categorical(self, rows, cols, rounds, cats): self.cputest.run_categorical_basic(rows, cols, rounds, cats, "gpu_hist") + def test_max_cat(self) -> None: + self.cputest.run_max_cat("gpu_hist") + def test_categorical_32_cat(self): '''32 hits the bound of integer bitset, so special test''' rows = 1000 diff --git a/tests/python/test_updaters.py b/tests/python/test_updaters.py index 4b56d37d4493..fa02b009a0f5 100644 --- a/tests/python/test_updaters.py +++ b/tests/python/test_updaters.py @@ -1,3 +1,5 @@ +from random import choice +from string import ascii_lowercase import testing as tm import pytest import xgboost as xgb @@ -169,6 +171,30 @@ def run_invalid_category(self, tree_method: str) -> None: def test_invalid_category(self) -> None: self.run_invalid_category("approx") + self.run_invalid_category("hist") + + def run_max_cat(self, tree_method: str) -> None: + """Test data with size smaller than number of categories.""" + import pandas as pd + n_cat = 100 + n = 5 + X = pd.Series( + ["".join(choice(ascii_lowercase) for i in range(3)) for i in range(n_cat)], + dtype="category", + )[:n].to_frame() + + reg = xgb.XGBRegressor( + enable_categorical=True, + tree_method=tree_method, + n_estimators=10, + ) + y = pd.Series(range(n)) + reg.fit(X=X, y=y, eval_set=[(X, y)]) + assert tm.non_increasing(reg.evals_result()["validation_0"]["rmse"]) + + @pytest.mark.parametrize("tree_method", ["hist", "approx"]) + def test_max_cat(self, tree_method) -> None: + self.run_max_cat(tree_method) def run_categorical_basic(self, rows, cols, rounds, cats, 
tree_method): onehot, label = tm.make_categorical(rows, cols, cats, True) From 759ef91fcc0ac29f96e45a560bfce70a8deb3195 Mon Sep 17 00:00:00 2001 From: fis Date: Sat, 30 Apr 2022 01:02:18 +0800 Subject: [PATCH 05/12] Use contructor. --- src/common/device_helpers.cuh | 1 + 1 file changed, 1 insertion(+) diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh index 334e3b4f89bf..ebaeabb3ae4e 100644 --- a/src/common/device_helpers.cuh +++ b/src/common/device_helpers.cuh @@ -725,6 +725,7 @@ using TypedDiscardCTK114 = thrust::discard_iterator; template class TypedDiscard : public thrust::discard_iterator { public: + using thrust::discard_iterator::discard_iterator; using value_type = T; // NOLINT }; } // namespace detail From f1971e185c39254876c5c0368191dc92397f16d5 Mon Sep 17 00:00:00 2001 From: fis Date: Sat, 30 Apr 2022 01:29:36 +0800 Subject: [PATCH 06/12] Use normal discard iter. --- src/common/quantile.cu | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/common/quantile.cu b/src/common/quantile.cu index 5e43ff8a8e11..7a572f7fa5ed 100644 --- a/src/common/quantile.cu +++ b/src/common/quantile.cu @@ -625,13 +625,14 @@ void SketchContainer::MakeCuts(HistogramCuts* p_cuts) { if (has_categorical_) { dh::XGBCachingDeviceAllocator alloc; auto it = dh::MakeTransformIterator( - thrust::make_counting_iterator(0ul), - [=] XGBOOST_DEVICE(size_t i) { return dh::SegmentId(d_in_columns_ptr, i); }); + thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(size_t i) -> bst_feature_t { + return dh::SegmentId(d_in_columns_ptr, i); + }); CHECK_EQ(num_columns_, d_in_columns_ptr.size() - 1); max_values.resize(d_in_columns_ptr.size() - 1); dh::caching_device_vector d_max_values(d_in_columns_ptr.size() - 1); thrust::reduce_by_key(thrust::cuda::par(alloc), it, it + in_cut_values.size(), - dh::tbegin(in_cut_values), dh::TypedDiscard{}, + dh::tbegin(in_cut_values), thrust::make_discard_iterator(), d_max_values.begin(), thrust::equal_to{}, [] 
__device__(auto l, auto r) { return l.value > r.value ? l : r; }); dh::CopyDeviceSpanToVector(&max_values, dh::ToSpan(d_max_values)); From 2550b62df960e73a27263d3dbcee15ab66e8f9ea Mon Sep 17 00:00:00 2001 From: fis Date: Sat, 30 Apr 2022 01:29:55 +0800 Subject: [PATCH 07/12] Revert "Use contructor." This reverts commit 759ef91fcc0ac29f96e45a560bfce70a8deb3195. --- src/common/device_helpers.cuh | 1 - 1 file changed, 1 deletion(-) diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh index ebaeabb3ae4e..334e3b4f89bf 100644 --- a/src/common/device_helpers.cuh +++ b/src/common/device_helpers.cuh @@ -725,7 +725,6 @@ using TypedDiscardCTK114 = thrust::discard_iterator; template class TypedDiscard : public thrust::discard_iterator { public: - using thrust::discard_iterator::discard_iterator; using value_type = T; // NOLINT }; } // namespace detail From a788dcfc9267f3edf7a8166f49d6135ce127cb40 Mon Sep 17 00:00:00 2001 From: fis Date: Sat, 30 Apr 2022 02:34:17 +0800 Subject: [PATCH 08/12] Fuse it into a single call. --- src/common/categorical.h | 2 +- src/common/common.h | 2 ++ src/common/quantile.cu | 74 +++++++++++++++++++--------------------- 3 files changed, 39 insertions(+), 39 deletions(-) diff --git a/src/common/categorical.h b/src/common/categorical.h index 5eff62264cf2..e3c8c61ae4fb 100644 --- a/src/common/categorical.h +++ b/src/common/categorical.h @@ -75,7 +75,7 @@ inline void InvalidCategory() { // values to be less than this last representable value. auto str = std::to_string(OutOfRangeCat()); LOG(FATAL) << "Invalid categorical value detected. 
Categorical value should be non-negative, " - "less than total umber of categories in training data and less than " + + "less than total number of categories in training data and less than " + str; } diff --git a/src/common/common.h b/src/common/common.h index aa2d8197b4a1..69475457a2f6 100644 --- a/src/common/common.h +++ b/src/common/common.h @@ -192,6 +192,8 @@ class IndexTransformIter { value_type operator*() const { return fn_(iter_); } auto operator-(IndexTransformIter const &that) const { return iter_ - that.iter_; } + bool operator==(IndexTransformIter const &that) const { return iter_ == that.iter_; } + bool operator!=(IndexTransformIter const &that) const { return !(*this == that); } IndexTransformIter &operator++() { iter_++; diff --git a/src/common/quantile.cu b/src/common/quantile.cu index 7a572f7fa5ed..d263753ceb71 100644 --- a/src/common/quantile.cu +++ b/src/common/quantile.cu @@ -583,13 +583,13 @@ void SketchContainer::AllReduce() { namespace { struct InvalidCatOp { - Span values; - Span ptrs; + Span values; + Span ptrs; Span ft; XGBOOST_DEVICE bool operator()(size_t i) const { auto fidx = dh::SegmentId(ptrs, i); - return IsCat(ft, fidx) && InvalidCat(values[i]); + return IsCat(ft, fidx) && InvalidCat(values[i].value); } }; } // anonymous namespace @@ -620,24 +620,52 @@ void SketchContainer::MakeCuts(HistogramCuts* p_cuts) { h_out_columns_ptr.push_back(0); auto const& h_feature_types = this->feature_types_.ConstHostSpan(); - // Set up output cuts + auto d_ft = feature_types_.ConstDeviceSpan(); + std::vector max_values; + float max_cat{0.0}; if (has_categorical_) { dh::XGBCachingDeviceAllocator alloc; - auto it = dh::MakeTransformIterator( + auto key_it = dh::MakeTransformIterator( thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(size_t i) -> bst_feature_t { return dh::SegmentId(d_in_columns_ptr, i); }); + auto invalid_op = InvalidCatOp{in_cut_values, d_in_columns_ptr, d_ft}; + auto val_it = dh::MakeTransformIterator( + 
thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(size_t i) { + auto fidx = dh::SegmentId(d_in_columns_ptr, i); + auto v = in_cut_values[i]; + if (IsCat(d_ft, fidx)) { + if (invalid_op(i)) { + // use inf to indicate invalid value, this way we can keep it as an + // indicator in the reduce operation as it's always the greatest value. + v.value = std::numeric_limits::infinity(); + } + } + return v; + }); CHECK_EQ(num_columns_, d_in_columns_ptr.size() - 1); max_values.resize(d_in_columns_ptr.size() - 1); dh::caching_device_vector d_max_values(d_in_columns_ptr.size() - 1); - thrust::reduce_by_key(thrust::cuda::par(alloc), it, it + in_cut_values.size(), - dh::tbegin(in_cut_values), thrust::make_discard_iterator(), - d_max_values.begin(), thrust::equal_to{}, + thrust::reduce_by_key(thrust::cuda::par(alloc), key_it, key_it + in_cut_values.size(), val_it, + thrust::make_discard_iterator(), d_max_values.begin(), + thrust::equal_to{}, [] __device__(auto l, auto r) { return l.value > r.value ? l : r; }); dh::CopyDeviceSpanToVector(&max_values, dh::ToSpan(d_max_values)); + auto max_it = common::MakeIndexTransformIter([&](auto i) { + if (IsCat(h_feature_types, i)) { + return max_values[i].value; + } + return .0f; + }); + max_cat = std::accumulate(max_it, max_it + max_values.size(), 0.f, + [](float l, float r) { return l > r ? 
l : r; }); + if (std::isinf(max_cat)) { + InvalidCategory(); + } } + // Set up output cuts for (bst_feature_t i = 0; i < num_columns_; ++i) { size_t column_size = std::max(static_cast(1ul), this->Column(i).size()); if (IsCat(h_feature_types, i)) { @@ -655,7 +683,6 @@ void SketchContainer::MakeCuts(HistogramCuts* p_cuts) { p_cuts->cut_values_.SetDevice(device_); p_cuts->cut_values_.Resize(total_bins); auto out_cut_values = p_cuts->cut_values_.DeviceSpan(); - auto d_ft = feature_types_.ConstDeviceSpan(); dh::LaunchN(total_bins, [=] __device__(size_t idx) { auto column_id = dh::SegmentId(d_out_columns_ptr, idx); @@ -701,36 +728,7 @@ void SketchContainer::MakeCuts(HistogramCuts* p_cuts) { out_column[idx] = in_column[idx+1].value; }); - float max_cat{-1.0f}; - if (has_categorical_) { - auto invalid_op = InvalidCatOp{out_cut_values, d_out_columns_ptr, d_ft}; - auto it = dh::MakeTransformIterator>( - thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(size_t i) { - auto fidx = dh::SegmentId(d_out_columns_ptr, i); - if (IsCat(d_ft, fidx)) { - auto invalid = invalid_op(i); - auto v = out_cut_values[i]; - return thrust::make_pair(invalid, v); - } - return thrust::make_pair(false, std::numeric_limits::min()); - }); - - bool invalid{false}; - dh::XGBCachingDeviceAllocator alloc; - thrust::tie(invalid, max_cat) = - thrust::reduce(thrust::cuda::par(alloc), it, it + out_cut_values.size(), - thrust::make_pair(false, std::numeric_limits::min()), - [=] XGBOOST_DEVICE(thrust::pair const &l, - thrust::pair const &r) { - return thrust::make_pair(l.first || r.first, std::max(l.second, r.second)); - }); - if (invalid) { - InvalidCategory(); - } - } - p_cuts->SetCategorical(this->has_categorical_, max_cat); - timer_.Stop(__func__); } } // namespace common From fe22620d918385de1ee7b08aa540363c484a8598 Mon Sep 17 00:00:00 2001 From: jiamingy Date: Sat, 30 Apr 2022 16:06:26 +0800 Subject: [PATCH 09/12] Simplify the code. 
--- src/common/quantile.cc | 37 +++++++++---------------------------- src/common/quantile.cu | 7 +++---- 2 files changed, 12 insertions(+), 32 deletions(-) diff --git a/src/common/quantile.cc b/src/common/quantile.cc index 59818965b33f..12cee0e5faa8 100644 --- a/src/common/quantile.cc +++ b/src/common/quantile.cc @@ -468,14 +468,17 @@ void AddCutPoint(typename SketchType::SummaryContainer const &summary, int max_b } } -void AddCategories(std::set const &categories, HistogramCuts *cuts) { +auto AddCategories(std::set const &categories, HistogramCuts *cuts) { + if (std::any_of(categories.cbegin(), categories.cend(), InvalidCat)) { + InvalidCategory(); + } auto &cut_values = cuts->cut_values_.HostVector(); - auto max_cat = std::accumulate(categories.cbegin(), categories.cend(), .0f, - [](auto l, auto r) { return std::max(l, r); }); + auto max_cat = *std::max_element(categories.cbegin(), categories.cend()); CHECK_GE(max_cat + 1, categories.size()); for (bst_cat_t i = 0; i <= AsCat(max_cat); ++i) { cut_values.push_back(i); } + return max_cat; } template @@ -508,11 +511,12 @@ void SketchContainerImpl::MakeCuts(HistogramCuts* cuts) { } }); + float max_cat{-1.f}; for (size_t fid = 0; fid < reduced.size(); ++fid) { size_t max_num_bins = std::min(num_cuts[fid], max_bins_); typename WQSketch::SummaryContainer const& a = final_summaries[fid]; if (IsCat(feature_types_, fid)) { - AddCategories(categories_.at(fid), cuts); + max_cat = std::max(max_cat, AddCategories(categories_.at(fid), cuts)); } else { AddCutPoint(a, max_num_bins, cuts); // push a value that is greater than anything @@ -530,30 +534,7 @@ void SketchContainerImpl::MakeCuts(HistogramCuts* cuts) { cuts->cut_ptrs_.HostVector().push_back(cut_size); } - if (has_categorical_) { - for (auto const &feat : categories_) { - if (std::any_of(feat.cbegin(), feat.cend(), InvalidCat)) { - InvalidCategory(); - } - } - auto const &ptrs = cuts->Ptrs(); - auto const &vals = cuts->Values(); - - float 
max_cat{-std::numeric_limits::infinity()}; - for (size_t i = 1; i < ptrs.size(); ++i) { - if (IsCat(feature_types_, i - 1)) { - auto beg = ptrs[i - 1]; - auto end = ptrs[i]; - auto feat = Span{vals}.subspan(beg, end - beg); - auto max_elem = *std::max_element(feat.cbegin(), feat.cend()); - if (max_elem > max_cat) { - max_cat = max_elem; - } - } - } - cuts->SetCategorical(true, max_cat); - } - + cuts->SetCategorical(this->has_categorical_, max_cat); monitor_.Stop(__func__); } diff --git a/src/common/quantile.cu b/src/common/quantile.cu index d263753ceb71..648065c6c0e1 100644 --- a/src/common/quantile.cu +++ b/src/common/quantile.cu @@ -623,7 +623,7 @@ void SketchContainer::MakeCuts(HistogramCuts* p_cuts) { auto d_ft = feature_types_.ConstDeviceSpan(); std::vector max_values; - float max_cat{0.0}; + float max_cat{-1.f}; if (has_categorical_) { dh::XGBCachingDeviceAllocator alloc; auto key_it = dh::MakeTransformIterator( @@ -656,10 +656,9 @@ void SketchContainer::MakeCuts(HistogramCuts* p_cuts) { if (IsCat(h_feature_types, i)) { return max_values[i].value; } - return .0f; + return -1.f; }); - max_cat = std::accumulate(max_it, max_it + max_values.size(), 0.f, - [](float l, float r) { return l > r ? l : r; }); + max_cat = *std::max_element(max_it, max_it + max_values.size()); if (std::isinf(max_cat)) { InvalidCategory(); } From d16aac5952f9ace8b4335c4e762cd94ac17c324d Mon Sep 17 00:00:00 2001 From: jiamingy Date: Sat, 30 Apr 2022 16:12:26 +0800 Subject: [PATCH 10/12] Check max cat. 
--- src/common/categorical.h | 5 +++++ src/common/quantile.cc | 2 +- src/common/quantile.cu | 3 ++- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/common/categorical.h b/src/common/categorical.h index e3c8c61ae4fb..67d44a6c456f 100644 --- a/src/common/categorical.h +++ b/src/common/categorical.h @@ -79,6 +79,11 @@ inline void InvalidCategory() { str; } +inline void CheckMaxCat(float max_cat, size_t n_categories) { + CHECK_GE(max_cat + 1, n_categories) + << "Maximum cateogry should not be lesser than the total number of categories."; +} + /*! * \brief Whether should we use onehot encoding for categorical data. */ diff --git a/src/common/quantile.cc b/src/common/quantile.cc index 12cee0e5faa8..13a5a15560ba 100644 --- a/src/common/quantile.cc +++ b/src/common/quantile.cc @@ -474,7 +474,7 @@ auto AddCategories(std::set const &categories, HistogramCuts *cuts) { } auto &cut_values = cuts->cut_values_.HostVector(); auto max_cat = *std::max_element(categories.cbegin(), categories.cend()); - CHECK_GE(max_cat + 1, categories.size()); + CheckMaxCat(max_cat, categories.size()); for (bst_cat_t i = 0; i <= AsCat(max_cat); ++i) { cut_values.push_back(i); } diff --git a/src/common/quantile.cu b/src/common/quantile.cu index 648065c6c0e1..a1abde029f50 100644 --- a/src/common/quantile.cu +++ b/src/common/quantile.cu @@ -668,7 +668,8 @@ void SketchContainer::MakeCuts(HistogramCuts* p_cuts) { for (bst_feature_t i = 0; i < num_columns_; ++i) { size_t column_size = std::max(static_cast(1ul), this->Column(i).size()); if (IsCat(h_feature_types, i)) { - CHECK_GE(max_values[i].value + 1, column_size); + // column_size is the number of unique values in that feature. + CheckMaxCat(max_values[i].value, column_size); h_out_columns_ptr.push_back(max_values[i].value + 1); // includes both max_cat and 0. 
} else { h_out_columns_ptr.push_back( From 4094515932d2601d578097594c12c8b60fecd32f Mon Sep 17 00:00:00 2001 From: fis Date: Sat, 30 Apr 2022 16:49:04 +0800 Subject: [PATCH 11/12] Assign. --- src/common/common.h | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/common/common.h b/src/common/common.h index 69475457a2f6..4949d61e4582 100644 --- a/src/common/common.h +++ b/src/common/common.h @@ -188,6 +188,11 @@ class IndexTransformIter { */ explicit IndexTransformIter(Fn &&op) : fn_{op} {} IndexTransformIter(IndexTransformIter const &) = default; + IndexTransformIter& operator=(IndexTransformIter&&) = default; + IndexTransformIter& operator=(IndexTransformIter const& that) { + iter_ = that.iter_; + return *this; + } value_type operator*() const { return fn_(iter_); } From 7510c349edf37d51c1e078354349924a277151f9 Mon Sep 17 00:00:00 2001 From: jiamingy Date: Sat, 30 Apr 2022 17:36:25 +0800 Subject: [PATCH 12/12] Remove assert --- src/common/quantile.cu | 1 - 1 file changed, 1 deletion(-) diff --git a/src/common/quantile.cu b/src/common/quantile.cu index a1abde029f50..33179551631b 100644 --- a/src/common/quantile.cu +++ b/src/common/quantile.cu @@ -711,7 +711,6 @@ void SketchContainer::MakeCuts(HistogramCuts* p_cuts) { } if (IsCat(d_ft, column_id)) { - assert(out_column.size() == in_column.size()); out_column[idx] = idx; return; }