Skip to content

Commit

Permalink
Update documents and tests. (#7659)
Browse files Browse the repository at this point in the history

* Revise documents after recent refactoring and cat support.
* Add tests for behavior of max_depth and max_leaves.
  • Loading branch information
trivialfis committed Feb 25, 2022
1 parent 5eed299 commit 18a4af6
Show file tree
Hide file tree
Showing 7 changed files with 142 additions and 44 deletions.
14 changes: 6 additions & 8 deletions doc/parameter.rst
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ Parameters for Tree Booster

* ``max_depth`` [default=6]

- Maximum depth of a tree. Increasing this value will make the model more complex and more likely to overfit. 0 is only accepted in ``lossguide`` growing policy when ``tree_method`` is set as ``hist`` or ``gpu_hist`` and it indicates no limit on depth. Beware that XGBoost aggressively consumes memory when training a deep tree.
- range: [0,∞] (0 is only accepted in ``lossguide`` growing policy when ``tree_method`` is set as ``hist`` or ``gpu_hist``)
- Maximum depth of a tree. Increasing this value will make the model more complex and more likely to overfit. 0 indicates no limit on depth. Beware that XGBoost aggressively consumes memory when training a deep tree. ``exact`` tree method requires non-zero value.
- range: [0,∞]

* ``min_child_weight`` [default=1]

Expand Down Expand Up @@ -164,7 +164,7 @@ Parameters for Tree Booster

- Control the balance of positive and negative weights, useful for unbalanced classes. A typical value to consider: ``sum(negative instances) / sum(positive instances)``. See :doc:`Parameters Tuning </tutorials/param_tuning>` for more discussion. Also, see Higgs Kaggle competition demo for examples: `R <https://github.com/dmlc/xgboost/blob/master/demo/kaggle-higgs/higgs-train.R>`_, `py1 <https://github.com/dmlc/xgboost/blob/master/demo/kaggle-higgs/higgs-numpy.py>`_, `py2 <https://github.com/dmlc/xgboost/blob/master/demo/kaggle-higgs/higgs-cv.py>`_, `py3 <https://github.com/dmlc/xgboost/blob/master/demo/guide-python/cross_validation.py>`_.

* ``updater`` [default= ``grow_colmaker,prune``]
* ``updater``

- A comma separated string defining the sequence of tree updaters to run, providing a modular way to construct and to modify the trees. This is an advanced parameter that is usually set automatically, depending on some other parameters. However, it could be also set explicitly by a user. The following updaters exist:

Expand All @@ -177,8 +177,6 @@ Parameters for Tree Booster
- ``refresh``: refreshes tree's statistics and/or leaf values based on the current data. Note that no random subsampling of data rows is performed.
- ``prune``: prunes the splits where loss < min_split_loss (or gamma) and nodes that have depth greater than ``max_depth``.

- In a distributed setting, the implicit updater sequence value would be adjusted to ``grow_histmaker,prune`` by default, and you can set ``tree_method`` as ``hist`` to use ``grow_histmaker``.

* ``refresh_leaf`` [default=1]

- This is a parameter of the ``refresh`` updater. When this flag is 1, tree leafs as well as tree nodes' stats are updated. When it is 0, only node stats are updated.
Expand All @@ -194,19 +192,19 @@ Parameters for Tree Booster
* ``grow_policy`` [default= ``depthwise``]

- Controls a way new nodes are added to the tree.
- Currently supported only if ``tree_method`` is set to ``hist`` or ``gpu_hist``.
- Currently supported only if ``tree_method`` is set to ``hist``, ``approx`` or ``gpu_hist``.
- Choices: ``depthwise``, ``lossguide``

- ``depthwise``: split at nodes closest to the root.
- ``lossguide``: split at nodes with highest loss change.

* ``max_leaves`` [default=0]

- Maximum number of nodes to be added. Only relevant when ``grow_policy=lossguide`` is set.
- Maximum number of nodes to be added. Not used by ``exact`` tree method.

* ``max_bin``, [default=256]

- Only used if ``tree_method`` is set to ``hist`` or ``gpu_hist``.
- Only used if ``tree_method`` is set to ``hist``, ``approx`` or ``gpu_hist``.
- Maximum number of discrete bins to bucket continuous features.
- Increasing this number improves the optimality of splits at the cost of higher computation time.

Expand Down
29 changes: 29 additions & 0 deletions doc/treemethod.rst
Original file line number Diff line number Diff line change
Expand Up @@ -114,3 +114,32 @@ was never tested and contained some unknown bugs, we decided to remove it and fo
resources on more promising algorithms instead. For accuracy, most of the time
``approx``, ``hist`` and ``gpu_hist`` are enough with some parameters tuning, so removing
them don't have any real practical impact.


**************
Feature Matrix
**************

Following table summarizes some differences in supported features between 4 tree methods,
`T` means supported while `F` means unsupported.

+------------------+-----------+---------------------+---------------------+------------------------+
| | Exact | Approx | Hist | GPU Hist |
+==================+===========+=====================+=====================+========================+
| grow_policy | Depthwise | depthwise/lossguide | depthwise/lossguide | depthwise/lossguide |
+------------------+-----------+---------------------+---------------------+------------------------+
| max_leaves | F | T | T | T |
+------------------+-----------+---------------------+---------------------+------------------------+
| sampling method | uniform | uniform | uniform | gradient_based/uniform |
+------------------+-----------+---------------------+---------------------+------------------------+
| categorical data | F | T | T | T |
+------------------+-----------+---------------------+---------------------+------------------------+
| External memory | F | T | P | P |
+------------------+-----------+---------------------+---------------------+------------------------+
| Distributed | F | T | T | T |
+------------------+-----------+---------------------+---------------------+------------------------+

Features/parameters that are not mentioned here are universally supported for all 4 tree
methods (for instance, column sampling and constraints). The `P` in external memory means
partially supported. Please note that both categorical data and external memory are
experimental.
7 changes: 1 addition & 6 deletions doc/tutorials/feature_interaction_constraint.rst
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ first and second constraints (``[0, 1]``, ``[2, 3, 4]``).

.. |fig1| image:: ../_static/feature_interaction_illustration2.svg
:scale: 7%
:align: middle
:align: middle

.. |fig2| image:: ../_static/feature_interaction_illustration3.svg
:scale: 7%
Expand Down Expand Up @@ -174,11 +174,6 @@ parameter:
num_boost_round = 1000, evals = evallist,
early_stopping_rounds = 10)
**Choice of tree construction algorithm**. To use feature interaction constraints, be sure
to set the ``tree_method`` parameter to one of the following: ``exact``, ``hist``,
``approx`` or ``gpu_hist``. Support for ``gpu_hist`` and ``approx`` is added only in
1.0.0.

**************
Advanced topic
**************
Expand Down
27 changes: 12 additions & 15 deletions doc/tutorials/monotonic.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,23 @@
Monotonic Constraints
#####################

It is often the case in a modeling problem or project that the functional form of an acceptable model is constrained in some way. This may happen due to business considerations, or because of the type of scientific question being investigated. In some cases, where there is a very strong prior belief that the true relationship has some quality, constraints can be used to improve the predictive performance of the model.
It is often the case in a modeling problem or project that the functional form of an acceptable model is constrained in some way. This may happen due to business considerations, or because of the type of scientific question being investigated. In some cases, where there is a very strong prior belief that the true relationship has some quality, constraints can be used to improve the predictive performance of the model.

A common type of constraint in this situation is that certain features bear a **monotonic** relationship to the predicted response:

.. math::
f(x_1, x_2, \ldots, x, \ldots, x_{n-1}, x_n) \leq f(x_1, x_2, \ldots, x', \ldots, x_{n-1}, x_n)
whenever :math:`x \leq x'` is an **increasing constraint**; or
whenever :math:`x \leq x'` is an **increasing constraint**; or

.. math::
f(x_1, x_2, \ldots, x, \ldots, x_{n-1}, x_n) \geq f(x_1, x_2, \ldots, x', \ldots, x_{n-1}, x_n)
whenever :math:`x \leq x'` is a **decreasing constraint**.

XGBoost has the ability to enforce monotonicity constraints on any features used in a boosted model.
XGBoost has the ability to enforce monotonicity constraints on any features used in a boosted model.

****************
A Simple Example
Expand Down Expand Up @@ -60,8 +60,8 @@ Suppose the following code fits your model without monotonicity constraints

.. code-block:: python
model_no_constraints = xgb.train(params, dtrain,
num_boost_round = 1000, evals = evallist,
model_no_constraints = xgb.train(params, dtrain,
num_boost_round = 1000, evals = evallist,
early_stopping_rounds = 10)
Then fitting with monotonicity constraints only requires adding a single parameter
Expand All @@ -71,8 +71,8 @@ Then fitting with monotonicity constraints only requires adding a single paramet
params_constrained = params.copy()
params_constrained['monotone_constraints'] = "(1,-1)"
model_with_constraints = xgb.train(params_constrained, dtrain,
num_boost_round = 1000, evals = evallist,
model_with_constraints = xgb.train(params_constrained, dtrain,
num_boost_round = 1000, evals = evallist,
early_stopping_rounds = 10)
In this example the training data ``X`` has two columns, and by using the parameter values ``(1,-1)`` we are telling XGBoost to impose an increasing constraint on the first predictor and a decreasing constraint on the second.
Expand All @@ -82,14 +82,11 @@ Some other examples:
- ``(1,0)``: An increasing constraint on the first predictor and no constraint on the second.
- ``(0,-1)``: No constraint on the first predictor and a decreasing constraint on the second.

**Choice of tree construction algorithm**. To use monotonic constraints, be
sure to set the ``tree_method`` parameter to one of ``exact``, ``hist``, and
``gpu_hist``.

**Note for the 'hist' tree construction algorithm**.
If ``tree_method`` is set to either ``hist`` or ``gpu_hist``, enabling monotonic
constraints may produce unnecessarily shallow trees. This is because the
If ``tree_method`` is set to either ``hist``, ``approx`` or ``gpu_hist``, enabling
monotonic constraints may produce unnecessarily shallow trees. This is because the
``hist`` method reduces the number of candidate splits to be considered at each
split. Monotonic constraints may wipe out all available split candidates, in
which case no split is made. To reduce the effect, you may want to increase
the ``max_bin`` parameter to consider more split candidates.
split. Monotonic constraints may wipe out all available split candidates, in which case no
split is made. To reduce the effect, you may want to increase the ``max_bin`` parameter to
consider more split candidates.
2 changes: 2 additions & 0 deletions src/tree/updater_colmaker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,8 @@ class ColMaker: public TreeUpdater {
std::vector<int> newnodes;
this->InitData(gpair, *p_fmat);
this->InitNewNode(qexpand_, gpair, *p_fmat, *p_tree);
// We can check max_leaves too, but might break some grid searching pipelines.
CHECK_GT(param_.max_depth, 0) << "exact tree method doesn't support unlimited depth.";
for (int depth = 0; depth < param_.max_depth; ++depth) {
this->FindSplit(depth, qexpand_, gpair, p_fmat, p_tree);
this->ResetPosition(qexpand_, p_fmat, *p_tree);
Expand Down
105 changes: 91 additions & 14 deletions tests/cpp/tree/test_tree_policy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,89 @@ class TestGrowPolicy : public ::testing::Test {
true);
}

void TestTreeGrowPolicy(std::string tree_method, std::string policy) {
{
std::unique_ptr<Learner> learner{Learner::Create({this->Xy_})};
learner->SetParam("tree_method", tree_method);
learner->SetParam("max_leaves", "16");
learner->SetParam("grow_policy", policy);
learner->Configure();
std::unique_ptr<Learner> TrainOneIter(std::string tree_method, std::string policy,
int32_t max_leaves, int32_t max_depth) {
std::unique_ptr<Learner> learner{Learner::Create({this->Xy_})};
learner->SetParam("tree_method", tree_method);
if (max_leaves >= 0) {
learner->SetParam("max_leaves", std::to_string(max_leaves));
}
if (max_depth >= 0) {
learner->SetParam("max_depth", std::to_string(max_depth));
}
learner->SetParam("grow_policy", policy);

auto check_max_leave = [&]() {
Json model{Object{}};
learner->SaveModel(&model);
auto j_tree = model["learner"]["gradient_booster"]["model"]["trees"][0];
RegTree tree;
tree.LoadModel(j_tree);
CHECK_LE(tree.GetNumLeaves(), max_leaves);
};

auto check_max_depth = [&](int32_t sol) {
Json model{Object{}};
learner->SaveModel(&model);

auto j_tree = model["learner"]["gradient_booster"]["model"]["trees"][0];
RegTree tree;
tree.LoadModel(j_tree);
bst_node_t depth = 0;
tree.WalkTree([&](bst_node_t nidx) {
depth = std::max(tree.GetDepth(nidx), depth);
return true;
});
if (sol > -1) {
CHECK_EQ(depth, sol);
} else {
CHECK_EQ(depth, max_depth) << "tree method: " << tree_method << " policy: " << policy
<< " leaves:" << max_leaves << ", depth:" << max_depth;
}
};

if (max_leaves == 0 && max_depth == 0) {
// unconstrainted
if (tree_method != "gpu_hist") {
// GPU pre-allocates for all nodes.
learner->UpdateOneIter(0, Xy_);
}
} else if (max_leaves > 0 && max_depth == 0) {
learner->UpdateOneIter(0, Xy_);
check_max_leave();
} else if (max_leaves == 0 && max_depth > 0) {
learner->UpdateOneIter(0, Xy_);
check_max_depth(-1);
} else if (max_leaves > 0 && max_depth > 0) {
learner->UpdateOneIter(0, Xy_);
check_max_leave();
check_max_depth(2);
} else if (max_leaves == -1 && max_depth == 0) {
// default max_leaves is 0, so both of them are now 0
} else {
// default parameters
learner->UpdateOneIter(0, Xy_);
}
return learner;
}

void TestCombination(std::string tree_method) {
for (auto policy : {"depthwise", "lossguide"}) {
// -1 means default
for (auto leaves : {-1, 0, 3}) {
for (auto depth : {-1, 0, 3}) {
this->TrainOneIter(tree_method, policy, leaves, depth);
}
}
}
}

void TestTreeGrowPolicy(std::string tree_method, std::string policy) {
{
/**
* max_leaves
*/
auto learner = this->TrainOneIter(tree_method, policy, 16, -1);
Json model{Object{}};
learner->SaveModel(&model);

Expand All @@ -38,13 +112,10 @@ class TestGrowPolicy : public ::testing::Test {
ASSERT_EQ(tree.GetNumLeaves(), 16);
}
{
std::unique_ptr<Learner> learner{Learner::Create({this->Xy_})};
learner->SetParam("tree_method", tree_method);
learner->SetParam("max_depth", "3");
learner->SetParam("grow_policy", policy);
learner->Configure();

learner->UpdateOneIter(0, Xy_);
/**
* max_depth
*/
auto learner = this->TrainOneIter(tree_method, policy, -1, 3);
Json model{Object{}};
learner->SaveModel(&model);

Expand All @@ -64,17 +135,23 @@ class TestGrowPolicy : public ::testing::Test {
TEST_F(TestGrowPolicy, Approx) {
this->TestTreeGrowPolicy("approx", "depthwise");
this->TestTreeGrowPolicy("approx", "lossguide");

this->TestCombination("approx");
}

TEST_F(TestGrowPolicy, Hist) {
this->TestTreeGrowPolicy("hist", "depthwise");
this->TestTreeGrowPolicy("hist", "lossguide");

this->TestCombination("hist");
}

#if defined(XGBOOST_USE_CUDA)
TEST_F(TestGrowPolicy, GpuHist) {
this->TestTreeGrowPolicy("gpu_hist", "depthwise");
this->TestTreeGrowPolicy("gpu_hist", "lossguide");

this->TestCombination("gpu_hist");
}
#endif // defined(XGBOOST_USE_CUDA)
} // namespace xgboost
2 changes: 1 addition & 1 deletion tests/python-gpu/test_gpu_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def noop(*args, **kwargs):
rng = np.random.RandomState(1994)

shap_parameter_strategy = strategies.fixed_dictionaries({
'max_depth': strategies.integers(0, 11),
'max_depth': strategies.integers(1, 11),
'max_leaves': strategies.integers(0, 256),
'num_parallel_tree': strategies.sampled_from([1, 10]),
}).filter(lambda x: x['max_depth'] > 0 or x['max_leaves'] > 0)
Expand Down

0 comments on commit 18a4af6

Please sign in to comment.