diff --git a/doc/parameter.rst b/doc/parameter.rst
index eca5d46eebe9..3a35666338dd 100644
--- a/doc/parameter.rst
+++ b/doc/parameter.rst
@@ -235,14 +235,19 @@ Parameters for Tree Booster
     list is a group of indices of features that are allowed to interact with each other.
     See :doc:`/tutorials/feature_interaction_constraint` for more information.
 
-Additional parameters for ``hist``, ``gpu_hist`` and ``approx`` tree method
-===========================================================================
+.. _cat-param:
+
+Parameters for Categorical Feature
+==================================
+
+These parameters are only used for training with categorical data. See
+:doc:`/tutorials/categorical` for more information.
 
 * ``max_cat_to_onehot``
 
   .. versionadded:: 1.6
 
-  .. note:: The support for this parameter is experimental.
+  .. note:: This parameter is experimental. ``exact`` tree method is not supported yet.
 
   - A threshold for deciding whether XGBoost should use one-hot encoding based split for
     categorical data. When number of categories is lesser than the threshold then one-hot
@@ -250,6 +255,16 @@ Additional parameters for ``hist``, ``gpu_hist`` and ``approx`` tree method
     Only relevant for regression and binary classification. Also, ``exact`` tree method is
     not supported
 
+* ``max_cat_threshold``
+
+  .. versionadded:: 2.0
+
+  .. note:: This parameter is experimental. ``exact`` and ``gpu_hist`` tree methods are
+     not supported yet.
+
+  - Maximum number of categories considered for each split. Used only by partition-based
+    splits for preventing over-fitting.
+
 Additional parameters for Dart Booster (``booster=dart``)
 =========================================================
 
diff --git a/doc/tutorials/categorical.rst b/doc/tutorials/categorical.rst
index f5ae16bafc10..76b88e67dece 100644
--- a/doc/tutorials/categorical.rst
+++ b/doc/tutorials/categorical.rst
@@ -85,7 +85,7 @@ group the categories that output similar leaf values. During split finding, we f
 the gradient histogram to prepare the contiguous partitions then enumerate the splits
 according to these sorted values. One of the related parameters for XGBoost is
 ``max_cat_to_one_hot``, which controls whether one-hot encoding or partitioning should be
-used for each feature, see :doc:`/parameter` for details.
+used for each feature, see :ref:`cat-param` for details.
 
 
 **********************
diff --git a/src/common/categorical.h b/src/common/categorical.h
index a54d823d8267..ead5f570c44f 100644
--- a/src/common/categorical.h
+++ b/src/common/categorical.h
@@ -54,7 +54,7 @@ inline XGBOOST_DEVICE bool InvalidCat(float cat) {
  */
 template <bool validate = true>
 inline XGBOOST_DEVICE bool Decision(common::Span<uint32_t const> cats, float cat, bool dft_left) {
-  CLBitField32 const s_cats(cats);
+  KCatBitField const s_cats(cats);
   // FIXME: Size() is not accurate since it represents the size of bit set instead of
   // actual number of categories.
   if (XGBOOST_EXPECT(validate && (InvalidCat(cat) || cat >= s_cats.Size()), false)) {
diff --git a/src/tree/hist/evaluate_splits.h b/src/tree/hist/evaluate_splits.h
index 3120126829a8..9377ca0bd926 100644
--- a/src/tree/hist/evaluate_splits.h
+++ b/src/tree/hist/evaluate_splits.h
@@ -144,7 +144,8 @@ class HistEvaluator {
 
     auto const &cut_ptr = cut.Ptrs();
     auto const &parent = snode_[nidx];
-    bst_bin_t n_bins{static_cast<bst_bin_t>(cut_ptr[fidx + 1] - cut_ptr[fidx])};
+    bst_bin_t n_bins_feature{static_cast<bst_bin_t>(cut_ptr[fidx + 1] - cut_ptr[fidx])};
+    auto n_bins = std::min(param_.max_cat_threshold, n_bins_feature);
 
     // statistics on both sides of split
     GradStats left_sum;
@@ -152,7 +153,7 @@
 
     // best split so far
     SplitEntry best;
-    auto f_hist = hist.subspan(cut_ptr[fidx], n_bins);
+    auto f_hist = hist.subspan(cut_ptr[fidx], n_bins_feature);
     bst_bin_t ibegin, iend;
     bst_bin_t f_begin = cut_ptr[fidx];
     if (d_step > 0) {
@@ -160,7 +161,7 @@
       iend = ibegin + n_bins - 1;
     } else {
       ibegin = static_cast<bst_bin_t>(cut_ptr[fidx + 1]) - 1;
-      iend = f_begin;
+      iend = ibegin - n_bins + 1;
     }
 
     bst_bin_t best_thresh{-1};
@@ -177,7 +178,7 @@
       auto loss_chg = evaluator.CalcSplitGain(param_, nidx, fidx, GradStats{left_sum},
                                               GradStats{right_sum}) -
                       parent.root_gain;
-      // We don't have a numeric split point, nan hare is a dummy split.
+      // We don't have a numeric split point, nan here is a dummy split.
       if (best.Update(loss_chg, fidx, std::numeric_limits<float>::quiet_NaN(), d_step == 1,
                       true, left_sum, right_sum)) {
         best_thresh = i;
@@ -186,10 +187,11 @@
       }
 
     if (best_thresh != -1) {
-      auto n = common::CatBitField::ComputeStorageSize(n_bins + 1);
+      auto n = common::CatBitField::ComputeStorageSize(n_bins_feature + 1);
       best.cat_bits = decltype(best.cat_bits)(n, 0);
       common::CatBitField cat_bits{best.cat_bits};
-      bst_bin_t partition = d_step == 1 ? (best_thresh - ibegin + 1) : best_thresh - iend;
+      bst_bin_t partition = d_step == 1 ? (best_thresh - ibegin + 1) : (best_thresh - f_begin);
+      CHECK_GT(partition, 0);
       std::for_each(sorted_idx.begin(), sorted_idx.begin() + partition,
                     [&](size_t c) { cat_bits.Set(c); });
     }
diff --git a/src/tree/param.h b/src/tree/param.h
index ab9e230984aa..5c6c2e11375d 100644
--- a/src/tree/param.h
+++ b/src/tree/param.h
@@ -40,6 +40,8 @@ struct TrainParam : public XGBoostParameter<TrainParam> {
 
   uint32_t max_cat_to_onehot{4};
 
+  bst_bin_t max_cat_threshold{64};
+
   //----- the rest parameters are less important ----
   // minimum amount of hessian(weight) allowed in a child
   float min_child_weight;
@@ -113,6 +115,12 @@ struct TrainParam : public XGBoostParameter<TrainParam> {
         .set_default(4)
         .set_lower_bound(1)
         .describe("Maximum number of categories to use one-hot encoding based split.");
+    DMLC_DECLARE_FIELD(max_cat_threshold)
+        .set_default(64)
+        .set_lower_bound(1)
+        .describe(
+            "Maximum number of categories considered for split. Used only by partition-based "
+            "splits.");
     DMLC_DECLARE_FIELD(min_child_weight)
         .set_lower_bound(0.0f)
         .set_default(1.0f)
diff --git a/tests/python-gpu/test_gpu_updaters.py b/tests/python-gpu/test_gpu_updaters.py
index fd27e37711e0..043549263e04 100644
--- a/tests/python-gpu/test_gpu_updaters.py
+++ b/tests/python-gpu/test_gpu_updaters.py
@@ -74,8 +74,8 @@ def test_sparse(self, dataset):
            strategies.integers(1, 2), strategies.integers(4, 7))
     @settings(deadline=None, print_blob=True)
     @pytest.mark.skipif(**tm.no_pandas())
-    def test_categorical(self, rows, cols, rounds, cats):
-        self.cputest.run_categorical_basic(rows, cols, rounds, cats, "gpu_hist")
+    def test_categorical_ohe(self, rows, cols, rounds, cats):
+        self.cputest.run_categorical_ohe(rows, cols, rounds, cats, "gpu_hist")
 
     @given(
         strategies.integers(10, 400),
@@ -96,7 +96,7 @@ def test_categorical_32_cat(self):
         cols = 10
         cats = 32
         rounds = 4
-        self.cputest.run_categorical_basic(rows, cols, rounds, cats, "gpu_hist")
+        self.cputest.run_categorical_ohe(rows, cols, rounds, cats, "gpu_hist")
 
     @pytest.mark.skipif(**tm.no_cupy())
     def test_invalid_category(self):
diff --git a/tests/python/test_updaters.py b/tests/python/test_updaters.py
index 889a7c77f783..43550c864446 100644
--- a/tests/python/test_updaters.py
+++ b/tests/python/test_updaters.py
@@ -31,6 +31,14 @@
     x['max_depth'] > 0 or x['grow_policy'] == 'lossguide'))
 
 
+cat_parameter_strategy = strategies.fixed_dictionaries(
+    {
+        "max_cat_to_onehot": strategies.integers(1, 128),
+        "max_cat_threshold": strategies.integers(1, 128),
+    }
+)
+
+
 def train_result(param, dmat, num_rounds):
     result = {}
     xgb.train(param, dmat, num_rounds, [(dmat, 'train')], verbose_eval=False,
@@ -253,7 +261,7 @@ def run(max_cat_to_onehot: int):
         # Test with partition-based split
         run(self.USE_PART)
 
-    def run_categorical_basic(self, rows, cols, rounds, cats, tree_method):
+    def run_categorical_ohe(self, rows, cols, rounds, cats, tree_method):
         onehot, label = tm.make_categorical(rows, cols, cats, True)
         cat, _ = tm.make_categorical(rows, cols, cats, False)
 
@@ -328,9 +336,55 @@ def run_categorical_basic(self, rows, cols, rounds, cats, tree_method):
            strategies.integers(1, 2), strategies.integers(4, 7))
     @settings(deadline=None, print_blob=True)
     @pytest.mark.skipif(**tm.no_pandas())
-    def test_categorical(self, rows, cols, rounds, cats):
-        self.run_categorical_basic(rows, cols, rounds, cats, "approx")
-        self.run_categorical_basic(rows, cols, rounds, cats, "hist")
+    def test_categorical_ohe(self, rows, cols, rounds, cats):
+        self.run_categorical_ohe(rows, cols, rounds, cats, "approx")
+        self.run_categorical_ohe(rows, cols, rounds, cats, "hist")
+
+    @given(
+        tm.categorical_dataset_strategy,
+        exact_parameter_strategy,
+        hist_parameter_strategy,
+        cat_parameter_strategy,
+        strategies.integers(4, 32),
+        strategies.sampled_from(["hist", "approx"]),
+    )
+    @settings(deadline=None, print_blob=True)
+    @pytest.mark.skipif(**tm.no_pandas())
+    def test_categorical(
+        self,
+        dataset: tm.TestDataset,
+        exact_parameters: Dict[str, Any],
+        hist_parameters: Dict[str, Any],
+        cat_parameters: Dict[str, Any],
+        n_rounds: int,
+        tree_method: str,
+    ) -> None:
+        cat_parameters.update(exact_parameters)
+        cat_parameters.update(hist_parameters)
+        cat_parameters["tree_method"] = tree_method
+
+        results = train_result(cat_parameters, dataset.get_dmat(), n_rounds)
+        tm.non_increasing(results["train"]["rmse"])
+
+    @given(
+        hist_parameter_strategy,
+        cat_parameter_strategy,
+        strategies.sampled_from(["hist", "approx"]),
+    )
+    @settings(deadline=None, print_blob=True)
+    def test_categorical_ames_housing(
+        self,
+        hist_parameters: Dict[str, Any],
+        cat_parameters: Dict[str, Any],
+        tree_method: str,
+    ) -> None:
+        cat_parameters.update(hist_parameters)
+        dataset = tm.TestDataset(
+            "ames_housing", tm.get_ames_housing, "reg:squarederror", "rmse"
+        )
+        cat_parameters["tree_method"] = tree_method
+        results = train_result(cat_parameters, dataset.get_dmat(), 16)
+        tm.non_increasing(results["train"]["rmse"])
 
     @given(
         strategies.integers(10, 400),
diff --git a/tests/python/testing.py b/tests/python/testing.py
index cf723e12b0b6..8ff105e0c816 100644
--- a/tests/python/testing.py
+++ b/tests/python/testing.py
@@ -214,7 +214,9 @@ def set_params(self, params_in):
         return params_in
 
     def get_dmat(self):
-        return xgb.DMatrix(self.X, self.y, self.w, base_margin=self.margin)
+        return xgb.DMatrix(
+            self.X, self.y, self.w, base_margin=self.margin, enable_categorical=True
+        )
 
     def get_device_dmat(self):
         w = None if self.w is None else cp.array(self.w)
@@ -277,6 +279,48 @@ def get_sparse():
     return X, y
 
 
+@memory.cache
+def get_ames_housing():
+    """
+    Number of samples: 1460
+    Number of features: 20
+    Number of categorical features: 10
+    Number of numerical features: 10
+    """
+    from sklearn.datasets import fetch_openml
+    X, y = fetch_openml(data_id=42165, as_frame=True, return_X_y=True)
+
+    categorical_columns_subset: list[str] = [
+        "BldgType",       # 5 cats, no nan
+        "GarageFinish",   # 3 cats, nan
+        "LotConfig",      # 5 cats, no nan
+        "Functional",     # 7 cats, no nan
+        "MasVnrType",     # 4 cats, nan
+        "HouseStyle",     # 8 cats, no nan
+        "FireplaceQu",    # 5 cats, nan
+        "ExterCond",      # 5 cats, no nan
+        "ExterQual",      # 4 cats, no nan
+        "PoolQC",         # 3 cats, nan
+    ]
+
+    numerical_columns_subset: list[str] = [
+        "3SsnPorch",
+        "Fireplaces",
+        "BsmtHalfBath",
+        "HalfBath",
+        "GarageCars",
+        "TotRmsAbvGrd",
+        "BsmtFinSF1",
+        "BsmtFinSF2",
+        "GrLivArea",
+        "ScreenPorch",
+    ]
+
+    X = X[categorical_columns_subset + numerical_columns_subset]
+    X[categorical_columns_subset] = X[categorical_columns_subset].astype("category")
+    return X, y
+
+
 @memory.cache
 def get_mq2008(dpath):
     from sklearn.datasets import load_svmlight_files
@@ -329,7 +373,6 @@ def make_categorical(
     for i in range(n_features):
         index = rng.randint(low=0, high=n_samples-1, size=int(n_samples * sparsity))
         df.iloc[index, i] = np.NaN
-        assert df.iloc[:, i].isnull().values.any()
     assert n_categories == np.unique(df.dtypes[i].categories).size
 
     if onehot:
@@ -337,6 +380,41 @@ def make_categorical(
     return df, label
 
 
+def _cat_sampled_from():
+    @strategies.composite
+    def _make_cat(draw):
+        n_samples = draw(strategies.integers(2, 512))
+        n_features = draw(strategies.integers(1, 4))
+        n_cats = draw(strategies.integers(1, 128))
+        sparsity = draw(
+            strategies.floats(
+                min_value=0,
+                max_value=1,
+                allow_nan=False,
+                allow_infinity=False,
+                allow_subnormal=False,
+            )
+        )
+        return n_samples, n_features, n_cats, sparsity
+
+    def _build(args):
+        n_samples = args[0]
+        n_features = args[1]
+        n_cats = args[2]
+        sparsity = args[3]
+        return TestDataset(
+            f"{n_samples}x{n_features}-{n_cats}-{sparsity}",
+            lambda: make_categorical(n_samples, n_features, n_cats, False, sparsity),
+            "reg:squarederror",
+            "rmse",
+        )
+
+    return _make_cat().map(_build)
+
+
+categorical_dataset_strategy = _cat_sampled_from()
+
+
 @memory.cache
 def make_sparse_regression(
     n_samples: int, n_features: int, sparsity: float, as_dense: bool
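
A minimal usage sketch of the new parameter (not part of the patch). It assumes an XGBoost build that includes this change; the toy pandas DataFrame, labels, and hyperparameter values are purely illustrative. ``max_cat_threshold``, ``max_cat_to_onehot``, ``tree_method`` and ``enable_categorical`` are the knobs exercised by the tests above; everything else is standard XGBoost/pandas API:

    import numpy as np
    import pandas as pd
    import xgboost as xgb

    rng = np.random.default_rng(1994)
    n_samples, n_cats = 256, 32

    # One categorical feature plus one numerical feature; the label depends on the category.
    df = pd.DataFrame(
        {
            "cat": pd.Categorical(rng.integers(0, n_cats, size=n_samples)),
            "num": rng.normal(size=n_samples),
        }
    )
    y = df["cat"].cat.codes.to_numpy() % 3 + df["num"].to_numpy()

    # enable_categorical keeps the pandas "category" dtype inside the DMatrix.
    dtrain = xgb.DMatrix(df, y, enable_categorical=True)

    params = {
        "tree_method": "hist",    # "approx" also supports partition-based splits; "gpu_hist" not yet
        "max_cat_to_onehot": 4,   # features with fewer categories fall back to one-hot splits
        "max_cat_threshold": 16,  # cap on categories considered per partition-based split
        "objective": "reg:squarederror",
    }
    results = {}
    xgb.train(params, dtrain, num_boost_round=8, evals=[(dtrain, "train")], evals_result=results)
    print(results["train"]["rmse"])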