From e5e47c3c998e9bb264ab8d690694f0371cbef459 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Thu, 13 Jan 2022 16:11:52 +0800 Subject: [PATCH] Clarify the behavior of invalid categorical value handling. (#7529) --- doc/tutorials/categorical.rst | 12 ++++++++ src/common/categorical.h | 28 +++++++++++------- src/common/quantile.cu | 10 +++---- src/predictor/predict_fn.h | 12 ++++---- src/tree/updater_approx.h | 2 +- src/tree/updater_gpu_hist.cu | 6 ++-- tests/cpp/common/test_categorical.cc | 43 ++++++++++++++++++++++++++++ 7 files changed, 88 insertions(+), 25 deletions(-) create mode 100644 tests/cpp/common/test_categorical.cc diff --git a/doc/tutorials/categorical.rst b/doc/tutorials/categorical.rst index a56b946476ba..dd30a6ec4397 100644 --- a/doc/tutorials/categorical.rst +++ b/doc/tutorials/categorical.rst @@ -108,6 +108,18 @@ feature it's specified as ``"c"``. The Dask module in XGBoost has the same inte :class:`dask.Array ` can also be used as categorical data. +************* +Miscellaneous +************* + +By default, XGBoost assumes input categories are integers starting from 0 till the number +of categories :math:`[0, n_categories)`. However, user might provide inputs with invalid +values due to mistakes or missing values. It can be negative value, floating point value +that can not be represented by 32-bit integer, or values that are larger than actual +number of unique categories. During training this is validated but for prediction it's +treated as the same as missing value for performance reasons. Lastly, missing values are +treated as the same as numerical features. + ********** Next Steps ********** diff --git a/src/common/categorical.h b/src/common/categorical.h index 4cbbbf72ba60..e1d4d2c2a44c 100644 --- a/src/common/categorical.h +++ b/src/common/categorical.h @@ -5,6 +5,8 @@ #ifndef XGBOOST_COMMON_CATEGORICAL_H_ #define XGBOOST_COMMON_CATEGORICAL_H_ +#include + #include "bitfield.h" #include "xgboost/base.h" #include "xgboost/data.h" @@ -30,22 +32,30 @@ inline XGBOOST_DEVICE bool IsCat(Span ft, bst_feature_t fidx) return !ft.empty() && ft[fidx] == FeatureType::kCategorical; } + +inline XGBOOST_DEVICE bool InvalidCat(float cat) { + return cat < 0 || cat > static_cast(std::numeric_limits::max()); +} + /* \brief Whether should it traverse to left branch of a tree. * * For one hot split, go to left if it's NOT the matching category. */ -inline XGBOOST_DEVICE bool Decision(common::Span cats, bst_cat_t cat) { - auto pos = CLBitField32::ToBitPos(cat); - if (pos.int_pos >= cats.size()) { - return true; - } +template +inline XGBOOST_DEVICE bool Decision(common::Span cats, float cat, bool dft_left) { CLBitField32 const s_cats(cats); - return !s_cats.Check(cat); + // FIXME: Size() is not accurate since it represents the size of bit set instead of + // actual number of categories. + if (XGBOOST_EXPECT(validate && (InvalidCat(cat) || cat >= s_cats.Size()), false)) { + return dft_left; + } + return !s_cats.Check(AsCat(cat)); } inline void InvalidCategory() { LOG(FATAL) << "Invalid categorical value detected. Categorical value " - "should be non-negative."; + "should be non-negative, less than maximum size of int32 and less than total " + "number of categories in training data."; } /*! @@ -58,9 +68,7 @@ inline bool UseOneHot(uint32_t n_cats, uint32_t max_cat_to_onehot, ObjInfo task) } struct IsCatOp { - XGBOOST_DEVICE bool operator()(FeatureType ft) { - return ft == FeatureType::kCategorical; - } + XGBOOST_DEVICE bool operator()(FeatureType ft) { return ft == FeatureType::kCategorical; } }; using CatBitField = LBitField32; diff --git a/src/common/quantile.cu b/src/common/quantile.cu index d89951915d4f..d15d310c0516 100644 --- a/src/common/quantile.cu +++ b/src/common/quantile.cu @@ -581,14 +581,14 @@ void SketchContainer::AllReduce() { } namespace { -struct InvalidCat { +struct InvalidCatOp { Span values; Span ptrs; Span ft; XGBOOST_DEVICE bool operator()(size_t i) { auto fidx = dh::SegmentId(ptrs, i); - return IsCat(ft, fidx) && values[i] < 0; + return IsCat(ft, fidx) && InvalidCat(values[i]); } }; } // anonymous namespace @@ -687,10 +687,10 @@ void SketchContainer::MakeCuts(HistogramCuts* p_cuts) { dh::XGBCachingDeviceAllocator alloc; auto ptrs = p_cuts->cut_ptrs_.ConstDeviceSpan(); auto it = thrust::make_counting_iterator(0ul); + CHECK_EQ(p_cuts->Ptrs().back(), out_cut_values.size()); - auto invalid = - thrust::any_of(thrust::cuda::par(alloc), it, it + out_cut_values.size(), - InvalidCat{out_cut_values, ptrs, d_ft}); + auto invalid = thrust::any_of(thrust::cuda::par(alloc), it, it + out_cut_values.size(), + InvalidCatOp{out_cut_values, ptrs, d_ft}); if (invalid) { InvalidCategory(); } diff --git a/src/predictor/predict_fn.h b/src/predictor/predict_fn.h index 1547d6e774ae..7ce474023e8a 100644 --- a/src/predictor/predict_fn.h +++ b/src/predictor/predict_fn.h @@ -9,16 +9,16 @@ namespace xgboost { namespace predictor { template -inline XGBOOST_DEVICE bst_node_t -GetNextNode(const RegTree::Node &node, const bst_node_t nid, float fvalue, - bool is_missing, RegTree::CategoricalSplitMatrix const &cats) { +inline XGBOOST_DEVICE bst_node_t GetNextNode(const RegTree::Node &node, const bst_node_t nid, + float fvalue, bool is_missing, + RegTree::CategoricalSplitMatrix const &cats) { if (has_missing && is_missing) { return node.DefaultChild(); } else { if (has_categorical && common::IsCat(cats.split_type, nid)) { - auto node_categories = cats.categories.subspan(cats.node_ptr[nid].beg, - cats.node_ptr[nid].size); - return Decision(node_categories, common::AsCat(fvalue)) + auto node_categories = + cats.categories.subspan(cats.node_ptr[nid].beg, cats.node_ptr[nid].size); + return common::Decision(node_categories, fvalue, node.DefaultLeft()) ? node.LeftChild() : node.RightChild(); } else { diff --git a/src/tree/updater_approx.h b/src/tree/updater_approx.h index 5e16f568f80b..158ab2b2c12a 100644 --- a/src/tree/updater_approx.h +++ b/src/tree/updater_approx.h @@ -95,7 +95,7 @@ class ApproxRowPartitioner { auto node_cats = categories.subspan(segment.beg, segment.size); bool go_left = true; if (is_cat) { - go_left = common::Decision(node_cats, common::AsCat(cut_value)); + go_left = common::Decision(node_cats, cut_value, candidate.split.DefaultLeft()); } else { go_left = cut_value <= candidate.split.split_value; } diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 48d58074ef19..199be0a4c5d1 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -396,7 +396,7 @@ struct GPUHistMakerDevice { } else { bool go_left = true; if (split_type == FeatureType::kCategorical) { - go_left = common::Decision(node_cats, common::AsCat(cut_value)); + go_left = common::Decision(node_cats, cut_value, split_node.DefaultLeft()); } else { go_left = cut_value <= split_node.SplitCond(); } @@ -474,7 +474,7 @@ struct GPUHistMakerDevice { auto node_cats = categories.subspan(categories_segments[position].beg, categories_segments[position].size); - go_left = common::Decision(node_cats, common::AsCat(element)); + go_left = common::Decision(node_cats, element, node.DefaultLeft()); } else { go_left = element <= node.SplitCond(); } @@ -573,7 +573,7 @@ struct GPUHistMakerDevice { CHECK_LT(candidate.split.fvalue, std::numeric_limits::max()) << "Categorical feature value too large."; auto cat = common::AsCat(candidate.split.fvalue); - if (cat < 0) { + if (common::InvalidCat(cat)) { common::InvalidCategory(); } std::vector split_cats(LBitField32::ComputeStorageSize(std::max(cat+1, 1)), 0); diff --git a/tests/cpp/common/test_categorical.cc b/tests/cpp/common/test_categorical.cc new file mode 100644 index 000000000000..cc8eb0f7e6c4 --- /dev/null +++ b/tests/cpp/common/test_categorical.cc @@ -0,0 +1,43 @@ +/*! + * Copyright 2021 by XGBoost Contributors + */ +#include + +#include + +#include "../../../src/common/categorical.h" + +namespace xgboost { +namespace common { +TEST(Categorical, Decision) { + // inf + float a = std::numeric_limits::infinity(); + + ASSERT_TRUE(common::InvalidCat(a)); + std::vector cats(256, 0); + ASSERT_TRUE(Decision(cats, a, true)); + + // larger than size + a = 256; + ASSERT_TRUE(Decision(cats, a, true)); + + // negative + a = -1; + ASSERT_TRUE(Decision(cats, a, true)); + + CatBitField bits{cats}; + bits.Set(0); + a = -0.5; + ASSERT_TRUE(Decision(cats, a, true)); + + // round toward 0 + a = 0.5; + ASSERT_FALSE(Decision(cats, a, true)); + + // valid + a = 13; + bits.Set(a); + ASSERT_FALSE(Decision(bits.Bits(), a, true)); +} +} // namespace common +} // namespace xgboost