Column sampling at individual nodes (splits). (#3971)
* Column sampling at individual nodes (splits).

* Documented colsample_bynode parameter.

- also updated documentation for colsample_by* parameters

* Updated documentation.

* GetFeatureSet() returns shared pointer to std::vector.

* Sync sampled columns across multiple processes.
canonizer authored and trivialfis committed Dec 14, 2018
1 parent e0a2791 commit 42bf90e
Showing 8 changed files with 142 additions and 82 deletions.
25 changes: 16 additions & 9 deletions doc/parameter.rst
@@ -82,15 +82,22 @@ Parameters for Tree Booster
   - Subsample ratio of the training instances. Setting it to 0.5 means that XGBoost would randomly sample half of the training data prior to growing trees, and this will prevent overfitting. Subsampling will occur once in every boosting iteration.
   - range: (0,1]
 
-* ``colsample_bytree`` [default=1]
-
-  - Subsample ratio of columns when constructing each tree. Subsampling will occur once in every boosting iteration.
-  - range: (0,1]
-
-* ``colsample_bylevel`` [default=1]
-
-  - Subsample ratio of columns for each split, in each level. Subsampling will occur each time a new split is made.
-  - range: (0,1]
+* ``colsample_bytree``, ``colsample_bylevel``, ``colsample_bynode`` [default=1]
+  - This is a family of parameters for subsampling of columns.
+  - All ``colsample_by*`` parameters have a range of (0, 1], the default value of 1, and
+    specify the fraction of columns to be subsampled.
+  - ``colsample_bytree`` is the subsample ratio of columns when constructing each
+    tree. Subsampling occurs once for every tree constructed.
+  - ``colsample_bylevel`` is the subsample ratio of columns for each level. Subsampling
+    occurs once for every new depth level reached in a tree. Columns are subsampled from
+    the set of columns chosen for the current tree.
+  - ``colsample_bynode`` is the subsample ratio of columns for each node
+    (split). Subsampling occurs once every time a new split is evaluated. Columns are
+    subsampled from the set of columns chosen for the current level.
+  - ``colsample_by*`` parameters work cumulatively. For instance,
+    the combination ``{'colsample_bytree':0.5, 'colsample_bylevel':0.5,
+    'colsample_bynode':0.5}`` with 64 features will leave 8 features to choose from at
+    each split.
 
 * ``lambda`` [default=1, alias: ``reg_lambda``]

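For intuition, here is a minimal sketch (not part of this commit) that reproduces the cumulative arithmetic above, using the same max(1, int(colsample * n)) rounding that ColumnSampler::ColSample applies at each stage:

#include <algorithm>
#include <iostream>

// Effective number of candidate features at a split when all three
// colsample_by* ratios act cumulatively; each stage keeps at least one
// feature, mirroring ColSample's max(1, static_cast<int>(colsample * n)).
int EffectiveFeatures(int num_col, float bytree, float bylevel, float bynode) {
  int n = num_col;
  n = std::max(1, static_cast<int>(bytree * n));   // sampled once per tree
  n = std::max(1, static_cast<int>(bylevel * n));  // once per depth level
  n = std::max(1, static_cast<int>(bynode * n));   // once per split
  return n;
}

int main() {
  // 64 features, all ratios 0.5: 64 -> 32 -> 16 -> 8 candidates per split.
  std::cout << EffectiveFeatures(64, 0.5f, 0.5f, 0.5f) << std::endl;
  return 0;
}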
86 changes: 56 additions & 30 deletions src/common/random.h
@@ -7,14 +7,15 @@
 #ifndef XGBOOST_COMMON_RANDOM_H_
 #define XGBOOST_COMMON_RANDOM_H_
 
+#include <rabit/rabit.h>
 #include <xgboost/logging.h>
 #include <algorithm>
 #include <vector>
 #include <limits>
 #include <map>
+#include <memory>
 #include <numeric>
 #include <random>
-#include "host_device_vector.h"
 
 namespace xgboost {
 namespace common {
@@ -75,72 +76,97 @@ GlobalRandomEngine& GlobalRandom(); // NOLINT(*)
 /**
  * \class ColumnSampler
 *
- * \brief Handles selection of columns due to colsample_bytree and
- * colsample_bylevel parameters. Should be initialised before tree
- * construction and to reset when tree construction is completed.
+ * \brief Handles selection of columns due to colsample_bytree, colsample_bylevel and
+ * colsample_bynode parameters. Should be initialised before tree construction and
+ * reset when tree construction is completed.
 */
 
 class ColumnSampler {
-  HostDeviceVector<int> feature_set_tree_;
-  std::map<int, HostDeviceVector<int>> feature_set_level_;
+  std::shared_ptr<std::vector<int>> feature_set_tree_;
+  std::map<int, std::shared_ptr<std::vector<int>>> feature_set_level_;
   float colsample_bylevel_{1.0f};
   float colsample_bytree_{1.0f};
+  float colsample_bynode_{1.0f};
 
-  std::vector<int> ColSample(std::vector<int> features, float colsample) const {
-    if (colsample == 1.0f) return features;
+  std::shared_ptr<std::vector<int>> ColSample
+    (std::shared_ptr<std::vector<int>> p_features, float colsample) const {
+    if (colsample == 1.0f) return p_features;
+    const auto& features = *p_features;
     CHECK_GT(features.size(), 0);
     int n = std::max(1, static_cast<int>(colsample * features.size()));
 
-    std::shuffle(features.begin(), features.end(), common::GlobalRandom());
-    features.resize(n);
-    std::sort(features.begin(), features.end());
-
-    return features;
+    auto p_new_features = std::make_shared<std::vector<int>>();
+    auto& new_features = *p_new_features;
+    new_features.resize(features.size());
+    std::copy(features.begin(), features.end(), new_features.begin());
+    std::shuffle(new_features.begin(), new_features.end(), common::GlobalRandom());
+    new_features.resize(n);
+    std::sort(new_features.begin(), new_features.end());
+
+    // ensure that new_features are the same across ranks
+    rabit::Broadcast(&new_features, 0);
+
+    return p_new_features;
   }
 
  public:
   /**
    * \brief Initialise this object before use.
    *
    * \param num_col
+   * \param colsample_bynode
    * \param colsample_bylevel
    * \param colsample_bytree
    * \param skip_index_0 (Optional) True to skip index 0.
    */
-  void Init(int64_t num_col, float colsample_bylevel, float colsample_bytree,
-            bool skip_index_0 = false) {
-    this->colsample_bylevel_ = colsample_bylevel;
-    this->colsample_bytree_ = colsample_bytree;
-    this->Reset();
+  void Init(int64_t num_col, float colsample_bynode, float colsample_bylevel,
+            float colsample_bytree, bool skip_index_0 = false) {
+    colsample_bylevel_ = colsample_bylevel;
+    colsample_bytree_ = colsample_bytree;
+    colsample_bynode_ = colsample_bynode;
+
+    if (feature_set_tree_ == nullptr) {
+      feature_set_tree_ = std::make_shared<std::vector<int>>();
+    }
+    Reset();
 
     int begin_idx = skip_index_0 ? 1 : 0;
-    auto& feature_set_h = feature_set_tree_.HostVector();
-    feature_set_h.resize(num_col - begin_idx);
-
-    std::iota(feature_set_h.begin(), feature_set_h.end(), begin_idx);
-    feature_set_h = ColSample(feature_set_h, this->colsample_bytree_);
+    feature_set_tree_->resize(num_col - begin_idx);
+    std::iota(feature_set_tree_->begin(), feature_set_tree_->end(), begin_idx);
+
+    feature_set_tree_ = ColSample(feature_set_tree_, colsample_bytree_);
   }
 
   /**
    * \brief Resets this object.
   */
   void Reset() {
-    feature_set_tree_.HostVector().clear();
+    feature_set_tree_->clear();
     feature_set_level_.clear();
   }
 
-  HostDeviceVector<int>& GetFeatureSet(int depth) {
-    if (this->colsample_bylevel_ == 1.0f) {
+  /**
+   * \brief Samples a feature set.
+   *
+   * \param depth The tree depth of the node at which to sample.
+   * \return The sampled feature set.
+   * \note If colsample_bynode_ < 1.0, this method creates a new feature set each time it
+   * is called. Therefore, it should be called only once per node.
+   */
+  std::shared_ptr<std::vector<int>> GetFeatureSet(int depth) {
+    if (colsample_bylevel_ == 1.0f && colsample_bynode_ == 1.0f) {
       return feature_set_tree_;
     }
 
     if (feature_set_level_.count(depth) == 0) {
       // Level sampling, level does not yet exist so generate it
-      auto& level = feature_set_level_[depth].HostVector();
-      level = ColSample(feature_set_tree_.HostVector(), this->colsample_bylevel_);
+      feature_set_level_[depth] = ColSample(feature_set_tree_, colsample_bylevel_);
     }
-    // Level sampling
-    return feature_set_level_[depth];
+    if (colsample_bynode_ == 1.0f) {
+      // Level sampling
+      return feature_set_level_[depth];
+    }
+    // Need to sample for the node individually
+    return ColSample(feature_set_level_[depth], colsample_bynode_);
  }
 };
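A usage sketch of the new interface (illustrative only, not part of the diff; the include path is assumed). Init() now takes colsample_bynode as its second argument, GetFeatureSet() returns a std::shared_ptr<std::vector<int>> instead of a HostDeviceVector reference, and because node-level sets are resampled on every call it should be called exactly once per node; the rabit::Broadcast inside ColSample keeps the sampled set identical across distributed workers:

#include "common/random.h"  // assumed include path relative to src/

void SketchBuildTree() {
  xgboost::common::ColumnSampler sampler;
  // Arguments: num_col, colsample_bynode, colsample_bylevel, colsample_bytree.
  sampler.Init(64, 0.5f, 0.5f, 0.5f);

  int depth = 0;  // depth of the node currently being expanded
  // Call once per node: with colsample_bynode < 1 every call resamples.
  auto p_feature_set = sampler.GetFeatureSet(depth);
  for (int fidx : *p_feature_set) {
    (void)fidx;  // ... evaluate split candidates for feature fidx ...
  }

  sampler.Reset();  // once tree construction is complete
}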
8 changes: 7 additions & 1 deletion src/tree/param.h
@@ -50,7 +50,9 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
   float max_delta_step;
   // whether we want to do subsample
   float subsample;
-  // whether to subsample columns each split, in each level
+  // whether to subsample columns in each split (node)
+  float colsample_bynode;
+  // whether to subsample columns in each level
   float colsample_bylevel;
   // whether to subsample columns during tree construction
   float colsample_bytree;
@@ -149,6 +151,10 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
         .set_range(0.0f, 1.0f)
         .set_default(1.0f)
         .describe("Row subsample ratio of training instance.");
+    DMLC_DECLARE_FIELD(colsample_bynode)
+        .set_range(0.0f, 1.0f)
+        .set_default(1.0f)
+        .describe("Subsample ratio of columns, resample on each node (split).");
     DMLC_DECLARE_FIELD(colsample_bylevel)
         .set_range(0.0f, 1.0f)
         .set_default(1.0f)
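The new field participates in the usual dmlc::Parameter parsing, so it can be set through key/value arguments like any other training parameter; a hypothetical configuration sketch (the argument list is invented for illustration):

#include <string>
#include <utility>
#include <vector>

// Hypothetical: parse colsample_bynode from key/value args the way the
// learner configures TrainParam; set_range(0.0f, 1.0f) rejects bad values.
void ConfigureSketch(xgboost::tree::TrainParam* param) {
  std::vector<std::pair<std::string, std::string>> args = {
      {"colsample_bynode", "0.5"}};
  param->InitAllowUnknown(args);
}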
7 changes: 4 additions & 3 deletions src/tree/updater_colmaker.cc
@@ -168,8 +168,8 @@ class ColMaker: public TreeUpdater {
       }
     }
     {
-      column_sampler_.Init(fmat.Info().num_col_, param_.colsample_bylevel,
-                           param_.colsample_bytree);
+      column_sampler_.Init(fmat.Info().num_col_, param_.colsample_bynode,
+                           param_.colsample_bylevel, param_.colsample_bytree);
     }
     {
       // setup temp space for each thread
@@ -625,7 +625,8 @@ class ColMaker: public TreeUpdater {
                     const std::vector<GradientPair> &gpair,
                     DMatrix *p_fmat,
                     RegTree *p_tree) {
-      const std::vector<int> &feat_set = column_sampler_.GetFeatureSet(depth).HostVector();
+      auto p_feature_set = column_sampler_.GetFeatureSet(depth);
+      const auto& feat_set = *p_feature_set;
       for (const auto &batch : p_fmat->GetSortedColumnBatches()) {
         this->UpdateSolution(batch, feat_set, gpair, p_fmat);
       }
36 changes: 21 additions & 15 deletions src/tree/updater_gpu_hist.cu
@@ -499,6 +499,8 @@ struct DeviceShard {
   dh::DVec<GradientPair> node_sum_gradients_d;
   /*! \brief row offset in SparsePage (the input data). */
   thrust::device_vector<size_t> row_ptrs;
+  /*! \brief On-device feature set, only actually used on one of the devices */
+  thrust::device_vector<int> feature_set_d;
   /*! The row offset for this shard. */
   bst_uint row_begin_idx;
   bst_uint row_end_idx;
@@ -579,28 +581,31 @@ struct DeviceShard {
   }
 
   DeviceSplitCandidate EvaluateSplit(int nidx,
-                                     const HostDeviceVector<int>& feature_set,
+                                     const std::vector<int>& feature_set,
                                      ValueConstraint value_constraint) {
     dh::safe_cuda(cudaSetDevice(device_id_));
-    auto d_split_candidates = temp_memory.GetSpan<DeviceSplitCandidate>(feature_set.Size());
+    auto d_split_candidates = temp_memory.GetSpan<DeviceSplitCandidate>(feature_set.size());
+    feature_set_d.resize(feature_set.size());
+    auto d_features = common::Span<int>(feature_set_d.data().get(),
+                                        feature_set_d.size());
+    dh::safe_cuda(cudaMemcpy(d_features.data(), feature_set.data(),
+                             d_features.size_bytes(), cudaMemcpyDefault));
     DeviceNodeStats node(node_sum_gradients[nidx], nidx, param);
-    feature_set.Reshard(GPUSet::Range(device_id_, 1));
 
     // One block for each feature
     int constexpr BLOCK_THREADS = 256;
     EvaluateSplitKernel<BLOCK_THREADS, GradientSumT>
-        <<<uint32_t(feature_set.Size()), BLOCK_THREADS, 0>>>(
-            hist.GetNodeHistogram(nidx), feature_set.DeviceSpan(device_id_), node,
-            cut_.feature_segments.GetSpan(), cut_.min_fvalue.GetSpan(),
-            cut_.gidx_fvalue_map.GetSpan(), GPUTrainingParam(param),
-            d_split_candidates, value_constraint, monotone_constraints.GetSpan());
+        <<<uint32_t(feature_set.size()), BLOCK_THREADS, 0>>>
+        (hist.GetNodeHistogram(nidx), d_features, node,
+         cut_.feature_segments.GetSpan(), cut_.min_fvalue.GetSpan(),
+         cut_.gidx_fvalue_map.GetSpan(), GPUTrainingParam(param),
+         d_split_candidates, value_constraint, monotone_constraints.GetSpan());
 
     dh::safe_cuda(cudaDeviceSynchronize());
-    std::vector<DeviceSplitCandidate> split_candidates(feature_set.Size());
-    dh::safe_cuda(
-        cudaMemcpy(split_candidates.data(), d_split_candidates.data(),
-                   split_candidates.size() * sizeof(DeviceSplitCandidate),
-                   cudaMemcpyDeviceToHost));
+    std::vector<DeviceSplitCandidate> split_candidates(feature_set.size());
+    dh::safe_cuda(cudaMemcpy(split_candidates.data(), d_split_candidates.data(),
+                             split_candidates.size() * sizeof(DeviceSplitCandidate),
+                             cudaMemcpyDeviceToHost));
     DeviceSplitCandidate best_split;
     for (auto candidate : split_candidates) {
       best_split.Update(candidate, param);
@@ -1009,7 +1014,8 @@ class GPUHistMakerSpecialised{
     }
     monitor_.Stop("InitDataOnce", dist_.Devices());
 
-    column_sampler_.Init(info_->num_col_, param_.colsample_bylevel, param_.colsample_bytree);
+    column_sampler_.Init(info_->num_col_, param_.colsample_bynode,
+                         param_.colsample_bylevel, param_.colsample_bytree);
 
     // Copy gpair & reset memory
     monitor_.Start("InitDataReset", dist_.Devices());
@@ -1100,7 +1106,7 @@ class GPUHistMakerSpecialised{
 
   DeviceSplitCandidate EvaluateSplit(int nidx, RegTree* p_tree) {
     return shards_.front()->EvaluateSplit(
-        nidx, column_sampler_.GetFeatureSet(p_tree->GetDepth(nidx)),
+        nidx, *column_sampler_.GetFeatureSet(p_tree->GetDepth(nidx)),
         node_value_constraints_[nidx]);
   }
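Since GetFeatureSet() now hands back a host-side std::vector, the GPU updater stages it into a device buffer before the kernel launch. A simplified, standalone sketch of that pattern (function name assumed):

#include <cuda_runtime.h>
#include <thrust/device_vector.h>
#include <vector>

// Copy the per-node feature set to the device; cudaMemcpyDefault lets the
// runtime infer the direction (host -> device here), as in the diff above.
void StageFeatureSet(const std::vector<int>& feature_set,
                     thrust::device_vector<int>* feature_set_d) {
  feature_set_d->resize(feature_set.size());
  cudaMemcpy(feature_set_d->data().get(), feature_set.data(),
             feature_set.size() * sizeof(int), cudaMemcpyDefault);
}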
12 changes: 6 additions & 6 deletions src/tree/updater_quantile_hist.cc
@@ -354,11 +354,11 @@ void QuantileHistMaker::Builder::InitData(const GHistIndexMatrix& gmat,
     p_last_fmat_ = &fmat;
     // initialize feature index
     if (data_layout_ == kDenseDataOneBased) {
-      column_sampler_.Init(info.num_col_, param_.colsample_bylevel,
-                           param_.colsample_bytree, true);
+      column_sampler_.Init(info.num_col_, param_.colsample_bynode,
+                           param_.colsample_bylevel, param_.colsample_bytree, true);
     } else {
-      column_sampler_.Init(info.num_col_, param_.colsample_bylevel,
-                           param_.colsample_bytree, false);
+      column_sampler_.Init(info.num_col_, param_.colsample_bynode,
+                           param_.colsample_bylevel, param_.colsample_bytree, false);
     }
   }
   if (data_layout_ == kDenseDataZeroBased || data_layout_ == kDenseDataOneBased) {
@@ -400,8 +400,8 @@ void QuantileHistMaker::Builder::EvaluateSplit(int nid,
                                                const RegTree& tree) {
   // start enumeration
   const MetaInfo& info = fmat.Info();
-  const auto& feature_set = column_sampler_.GetFeatureSet(
-      tree.GetDepth(nid)).HostVector();
+  auto p_feature_set = column_sampler_.GetFeatureSet(tree.GetDepth(nid));
+  const auto& feature_set = *p_feature_set;
   const auto nfeature = static_cast<bst_uint>(feature_set.size());
   const auto nthread = static_cast<bst_omp_uint>(this->nthread_);
   best_split_tloc_.resize(nthread);
48 changes: 30 additions & 18 deletions tests/cpp/common/test_random.cc
@@ -5,33 +5,45 @@
 namespace xgboost {
 namespace common {
 TEST(ColumnSampler, Test) {
-  int n = 100;
+  int n = 128;
   ColumnSampler cs;
-  cs.Init(n, 0.5f, 0.5f);
-  auto &set0 = cs.GetFeatureSet(0).HostVector();
-  ASSERT_EQ(set0.size(), 25);
 
-  auto &set1 = cs.GetFeatureSet(0).HostVector();
+  // No node sampling
+  cs.Init(n, 1.0f, 0.5f, 0.5f);
+  auto set0 = *cs.GetFeatureSet(0);
+  ASSERT_EQ(set0.size(), 32);
+
+  auto set1 = *cs.GetFeatureSet(0);
   ASSERT_EQ(set0, set1);
 
-  auto &set2 = cs.GetFeatureSet(1).HostVector();
+  auto set2 = *cs.GetFeatureSet(1);
   ASSERT_NE(set1, set2);
-  ASSERT_EQ(set2.size(), 25);
+  ASSERT_EQ(set2.size(), 32);
+
+  // Node sampling
+  cs.Init(n, 0.5f, 1.0f, 0.5f);
+  auto set3 = *cs.GetFeatureSet(0);
+  ASSERT_EQ(set3.size(), 32);
+
+  auto set4 = *cs.GetFeatureSet(0);
+  ASSERT_NE(set3, set4);
+  ASSERT_EQ(set4.size(), 32);
 
-  // No level sampling, should be the same at different depth
-  cs.Init(n, 1.0f, 0.5f);
-  ASSERT_EQ(cs.GetFeatureSet(0).HostVector(), cs.GetFeatureSet(1).HostVector());
+  // No level or node sampling, should be the same at different depth
+  cs.Init(n, 1.0f, 1.0f, 0.5f);
+  ASSERT_EQ(*cs.GetFeatureSet(0), *cs.GetFeatureSet(1));
 
-  cs.Init(n, 1.0f, 1.0f);
-  auto &set3 = cs.GetFeatureSet(0).HostVector();
-  ASSERT_EQ(set3.size(), n);
-  cs.Init(n, 1.0f, 1.0f);
-  auto &set4 = cs.GetFeatureSet(0).HostVector();
-  ASSERT_EQ(set3, set4);
+  cs.Init(n, 1.0f, 1.0f, 1.0f);
+  auto set5 = *cs.GetFeatureSet(0);
+  ASSERT_EQ(set5.size(), n);
+  cs.Init(n, 1.0f, 1.0f, 1.0f);
+  auto set6 = *cs.GetFeatureSet(0);
+  ASSERT_EQ(set5, set6);
 
   // Should always be a minimum of one feature
-  cs.Init(n, 1e-16f, 1e-16f);
-  ASSERT_EQ(cs.GetFeatureSet(0).HostVector().size(), 1);
+  cs.Init(n, 1e-16f, 1e-16f, 1e-16f);
+  ASSERT_EQ(cs.GetFeatureSet(0)->size(), 1);
 }
 }  // namespace common
 }  // namespace xgboost
2 changes: 2 additions & 0 deletions tests/cpp/tree/test_gpu_hist.cu
@@ -227,6 +227,7 @@ TEST(GpuHist, EvaluateSplits) {
   TrainParam param;
   param.max_depth = 1;
   param.n_gpus = 1;
+  param.colsample_bynode = 1;
   param.colsample_bylevel = 1;
   param.colsample_bytree = 1;
   param.min_child_weight = 0.01;
@@ -284,6 +285,7 @@ TEST(GpuHist, EvaluateSplits) {
   hist_maker.param_ = param;
   hist_maker.shards_.push_back(std::move(shard));
   hist_maker.column_sampler_.Init(n_cols,
+                                  param.colsample_bynode,
                                   param.colsample_bylevel,
                                   param.colsample_bytree,
                                   false);
