diff --git a/doc/tutorials/categorical.rst b/doc/tutorials/categorical.rst index dd30a6ec4397..f302e5e47e79 100644 --- a/doc/tutorials/categorical.rst +++ b/doc/tutorials/categorical.rst @@ -114,11 +114,11 @@ Miscellaneous By default, XGBoost assumes input categories are integers starting from 0 till the number of categories :math:`[0, n_categories)`. However, user might provide inputs with invalid -values due to mistakes or missing values. It can be negative value, floating point value -that can not be represented by 32-bit integer, or values that are larger than actual -number of unique categories. During training this is validated but for prediction it's -treated as the same as missing value for performance reasons. Lastly, missing values are -treated as the same as numerical features. +values due to mistakes or missing values. It can be negative value, integer values that +can not be accurately represented by 32-bit floating point, or values that are larger than +actual number of unique categories. During training this is validated but for prediction +it's treated as the same as missing value for performance reasons. Lastly, missing values +are treated as the same as numerical features (using the learned split direction). ********** Next Steps diff --git a/src/common/categorical.h b/src/common/categorical.h index e1d4d2c2a44c..ba6313225025 100644 --- a/src/common/categorical.h +++ b/src/common/categorical.h @@ -1,5 +1,5 @@ /*! - * Copyright 2020-2021 by XGBoost Contributors + * Copyright 2020-2022 by XGBoost Contributors * \file categorical.h */ #ifndef XGBOOST_COMMON_CATEGORICAL_H_ @@ -32,9 +32,17 @@ inline XGBOOST_DEVICE bool IsCat(Span ft, bst_feature_t fidx) return !ft.empty() && ft[fidx] == FeatureType::kCategorical; } +constexpr inline bst_cat_t OutOfRangeCat() { + // See the round trip assert in `InvalidCat`. + return static_cast(16777217) - static_cast(1); +} inline XGBOOST_DEVICE bool InvalidCat(float cat) { - return cat < 0 || cat > static_cast(std::numeric_limits::max()); + constexpr auto kMaxCat = OutOfRangeCat(); + static_assert(static_cast(static_cast(kMaxCat)) == kMaxCat, ""); + static_assert(static_cast(static_cast(kMaxCat + 1)) != kMaxCat + 1, ""); + static_assert(static_cast(kMaxCat + 1) == kMaxCat, ""); + return cat < 0 || cat >= kMaxCat; } /* \brief Whether should it traverse to left branch of a tree. @@ -53,9 +61,13 @@ inline XGBOOST_DEVICE bool Decision(common::Span cats, float cat } inline void InvalidCategory() { - LOG(FATAL) << "Invalid categorical value detected. Categorical value " - "should be non-negative, less than maximum size of int32 and less than total " - "number of categories in training data."; + // OutOfRangeCat() can be accurately represented, but everything after it will be + // rounded toward it, so we use >= for comparison check. As a result, we require input + // values to be less than this last representable value. + auto str = std::to_string(OutOfRangeCat()); + LOG(FATAL) << "Invalid categorical value detected. Categorical value should be non-negative, " + "less than total umber of categories in training data and less than " + + str; } /*! diff --git a/src/common/quantile.cc b/src/common/quantile.cc index 8780c7539f3f..4712a6938527 100644 --- a/src/common/quantile.cc +++ b/src/common/quantile.cc @@ -1,5 +1,5 @@ /*! - * Copyright 2020-2021 by XGBoost Contributors + * Copyright 2020-2022 by XGBoost Contributors */ #include #include @@ -27,6 +27,7 @@ SketchContainerImpl::SketchContainerImpl(std::vector column sketches_.resize(columns_size_.size()); CHECK_GE(n_threads_, 1); categories_.resize(columns_size_.size()); + has_categorical_ = std::any_of(feature_types_.cbegin(), feature_types_.cend(), IsCatOp{}); } template @@ -187,7 +188,7 @@ void SketchContainerImpl::PushRowPage(SparsePage const &page, MetaInfo if (is_dense) { for (size_t ii = begin; ii < end; ii++) { if (IsCat(feature_types_, ii)) { - categories_[ii].emplace(AsCat(p_inst[ii].fvalue)); + categories_[ii].emplace(p_inst[ii].fvalue); } else { sketches_[ii].Push(p_inst[ii].fvalue, w); } @@ -197,7 +198,7 @@ void SketchContainerImpl::PushRowPage(SparsePage const &page, MetaInfo auto const& entry = p_inst[i]; if (entry.index >= begin && entry.index < end) { if (IsCat(feature_types_, entry.index)) { - categories_[entry.index].emplace(AsCat(entry.fvalue)); + categories_[entry.index].emplace(entry.fvalue); } else { sketches_[entry.index].Push(entry.fvalue, w); } @@ -352,10 +353,10 @@ void AddCutPoint(typename SketchType::SummaryContainer const &summary, int max_b } } -void AddCategories(std::set const &categories, HistogramCuts *cuts) { +void AddCategories(std::set const &categories, HistogramCuts *cuts) { auto &cut_values = cuts->cut_values_.HostVector(); for (auto const &v : categories) { - cut_values.push_back(v); + cut_values.push_back(AsCat(v)); } } @@ -410,6 +411,15 @@ void SketchContainerImpl::MakeCuts(HistogramCuts* cuts) { CHECK_GT(cut_size, cuts->cut_ptrs_.HostVector().back()); cuts->cut_ptrs_.HostVector().push_back(cut_size); } + + if (has_categorical_) { + for (auto const &feat : categories_) { + if (std::any_of(feat.cbegin(), feat.cend(), InvalidCat)) { + InvalidCategory(); + } + } + } + monitor_.Stop(__func__); } @@ -457,7 +467,7 @@ void SortedSketchContainer::PushColPage(SparsePage const &page, MetaInfo const & // second pass if (IsCat(feature_types_, fidx)) { for (auto c : column) { - categories_[fidx].emplace(AsCat(c.fvalue)); + categories_[fidx].emplace(c.fvalue); } } else { for (auto c : column) { diff --git a/src/common/quantile.h b/src/common/quantile.h index 37720694fd74..bd515d7afaff 100644 --- a/src/common/quantile.h +++ b/src/common/quantile.h @@ -1,5 +1,5 @@ /*! - * Copyright 2014-2021 by Contributors + * Copyright 2014-2022 by XGBoost Contributors * \file quantile.h * \brief util to compute quantiles * \author Tianqi Chen @@ -706,13 +706,14 @@ template class SketchContainerImpl { protected: std::vector sketches_; - std::vector> categories_; + std::vector> categories_; std::vector const feature_types_; std::vector columns_size_; int32_t max_bins_; bool use_group_ind_{false}; int32_t n_threads_; + bool has_categorical_{false}; Monitor monitor_; public: diff --git a/src/tree/hist/evaluate_splits.h b/src/tree/hist/evaluate_splits.h index c183ad910676..2d3a44226e54 100644 --- a/src/tree/hist/evaluate_splits.h +++ b/src/tree/hist/evaluate_splits.h @@ -1,5 +1,5 @@ /*! - * Copyright 2021 by XGBoost Contributors + * Copyright 2021-2022 by XGBoost Contributors */ #ifndef XGBOOST_TREE_HIST_EVALUATE_SPLITS_H_ #define XGBOOST_TREE_HIST_EVALUATE_SPLITS_H_ @@ -303,8 +303,9 @@ template class HistEvaluator { if (candidate.split.is_cat) { std::vector split_cats; if (candidate.split.cat_bits.empty()) { - CHECK_LT(candidate.split.split_value, std::numeric_limits::max()) - << "Categorical feature value too large."; + if (common::InvalidCat(candidate.split.split_value)) { + common::InvalidCategory(); + } auto cat = common::AsCat(candidate.split.split_value); split_cats.resize(LBitField32::ComputeStorageSize(std::max(cat + 1, 1)), 0); LBitField32 cat_bits; diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 199be0a4c5d1..9f92b8654307 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -1,5 +1,5 @@ /*! - * Copyright 2017-2021 XGBoost contributors + * Copyright 2017-2022 XGBoost contributors */ #include #include @@ -572,11 +572,11 @@ struct GPUHistMakerDevice { if (is_cat) { CHECK_LT(candidate.split.fvalue, std::numeric_limits::max()) << "Categorical feature value too large."; - auto cat = common::AsCat(candidate.split.fvalue); - if (common::InvalidCat(cat)) { + if (common::InvalidCat(candidate.split.fvalue)) { common::InvalidCategory(); } - std::vector split_cats(LBitField32::ComputeStorageSize(std::max(cat+1, 1)), 0); + auto cat = common::AsCat(candidate.split.fvalue); + std::vector split_cats(LBitField32::ComputeStorageSize(std::max(cat + 1, 1)), 0); LBitField32 cats_bits(split_cats); cats_bits.Set(cat); dh::CopyToD(split_cats, &node_categories); diff --git a/tests/python-gpu/test_gpu_updaters.py b/tests/python-gpu/test_gpu_updaters.py index 22799e533b0d..4aa3647e0f0f 100644 --- a/tests/python-gpu/test_gpu_updaters.py +++ b/tests/python-gpu/test_gpu_updaters.py @@ -60,20 +60,9 @@ def test_categorical_32_cat(self): rounds = 4 self.cputest.run_categorical_basic(rows, cols, rounds, cats, "gpu_hist") + @pytest.mark.skipif(**tm.no_cupy()) def test_invalid_categorical(self): - import cupy as cp - rng = np.random.default_rng() - X = rng.normal(loc=0, scale=1, size=1000).reshape(100, 10) - y = rng.normal(loc=0, scale=1, size=100) - - # Check is performe during sketching. - Xy = xgb.DMatrix(X, y, feature_types=["c"] * 10) - with pytest.raises(ValueError): - xgb.train({"tree_method": "gpu_hist"}, Xy) - - X, y = cp.array(X), cp.array(y) - with pytest.raises(ValueError): - Xy = xgb.DeviceQuantileDMatrix(X, y, feature_types=["c"] * 10) + self.cputest.run_invalid_category("gpu_hist") @pytest.mark.skipif(**tm.no_cupy()) @given(parameter_strategy, strategies.integers(1, 20), diff --git a/tests/python/test_updaters.py b/tests/python/test_updaters.py index 2af485676016..df8adcc424da 100644 --- a/tests/python/test_updaters.py +++ b/tests/python/test_updaters.py @@ -133,6 +133,41 @@ def test_hist_degenerate_case(self): w = [0, 0, 1, 0] model.fit(X, y, sample_weight=w) + def run_invalid_category(self, tree_method: str) -> None: + rng = np.random.default_rng() + # too large + X = rng.integers(low=0, high=4, size=1000).reshape(100, 10) + y = rng.normal(loc=0, scale=1, size=100) + X[13, 7] = np.iinfo(np.int32).max + 1 + + # Check is performed during sketching. + Xy = xgb.DMatrix(X, y, feature_types=["c"] * 10) + with pytest.raises(ValueError): + xgb.train({"tree_method": tree_method}, Xy) + + X[13, 7] = 16777216 + Xy = xgb.DMatrix(X, y, feature_types=["c"] * 10) + with pytest.raises(ValueError): + xgb.train({"tree_method": tree_method}, Xy) + + # mixed positive and negative values + X = rng.normal(loc=0, scale=1, size=1000).reshape(100, 10) + y = rng.normal(loc=0, scale=1, size=100) + + Xy = xgb.DMatrix(X, y, feature_types=["c"] * 10) + with pytest.raises(ValueError): + xgb.train({"tree_method": tree_method}, Xy) + + if tree_method == "gpu_hist": + import cupy as cp + + X, y = cp.array(X), cp.array(y) + with pytest.raises(ValueError): + Xy = xgb.DeviceQuantileDMatrix(X, y, feature_types=["c"] * 10) + + def test_invalid_category(self) -> None: + self.run_invalid_category("approx") + def run_categorical_basic(self, rows, cols, rounds, cats, tree_method): onehot, label = tm.make_categorical(rows, cols, cats, True) cat, _ = tm.make_categorical(rows, cols, cats, False)