-
-
Notifications
You must be signed in to change notification settings - Fork 8.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Using column_sampler for optimization of ColWiseBuildHist #8319
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,7 +7,9 @@ | |
#include <algorithm> | ||
#include <limits> | ||
#include <vector> | ||
#include <memory> | ||
|
||
#include "../../common/random.h" | ||
#include "../../collective/communicator-inl.h" | ||
#include "../../common/hist_util.h" | ||
#include "../../data/gradient_index.h" | ||
|
@@ -25,7 +27,10 @@ class HistogramBuilder { | |
common::GHistBuilder builder_; | ||
common::ParallelGHistBuilder buffer_; | ||
BatchParam param_; | ||
TrainParam train_param_; | ||
int32_t n_threads_{-1}; | ||
std::shared_ptr<common::ColumnSampler> column_sampler_; | ||
std::vector<int> fids_; | ||
size_t n_batches_{0}; | ||
// Whether XGBoost is running in distributed environment. | ||
bool is_distributed_{false}; | ||
|
@@ -39,12 +44,15 @@ class HistogramBuilder { | |
* \param is_distributed Mostly used for testing to allow injecting parameters instead | ||
* of using global rabit variable. | ||
*/ | ||
void Reset(uint32_t total_bins, BatchParam p, int32_t n_threads, size_t n_batches, | ||
bool is_distributed) { | ||
void Reset(uint32_t total_bins, BatchParam p, const TrainParam& train_param, | ||
std::shared_ptr<common::ColumnSampler> column_sampler, | ||
int32_t n_threads, size_t n_batches, bool is_distributed) { | ||
CHECK_GE(n_threads, 1); | ||
n_threads_ = n_threads; | ||
column_sampler_ = column_sampler; | ||
n_batches_ = n_batches; | ||
param_ = p; | ||
train_param_ = train_param; | ||
hist_.Init(total_bins); | ||
hist_local_worker_.Init(total_bins); | ||
buffer_.Init(total_bins); | ||
|
@@ -59,7 +67,7 @@ class HistogramBuilder { | |
GHistIndexMatrix const &gidx, | ||
std::vector<ExpandEntry> const &nodes_for_explicit_hist_build, | ||
common::RowSetCollection const &row_set_collection, | ||
const std::vector<GradientPair> &gpair_h, | ||
const std::vector<GradientPair> &gpair_h, int depth, | ||
bool force_read_by_column) { | ||
const size_t n_nodes = nodes_for_explicit_hist_build.size(); | ||
CHECK_GT(n_nodes, 0); | ||
|
@@ -76,6 +84,20 @@ class HistogramBuilder { | |
buffer_.Reset(this->n_threads_, n_nodes, space, target_hists); | ||
} | ||
|
||
constexpr float kColsampleTh = 0.1; | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why 0.1? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It is an ad-hoc threshold value. |
||
bool column_sampling = (column_sampler_ != nullptr) && | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. When is the column sampler nullptr? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I set it to nullptr in tests of the histogram builder without using this optimization. |
||
(train_param_.colsample_bytree < kColsampleTh || | ||
train_param_.colsample_bylevel < kColsampleTh); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. How about bynode? Is it used? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Not now; maybe later one can investigate these options more deeply. |
||
if (column_sampling) { | ||
const size_t n_sampled_features = column_sampler_->GetFeatureSet(depth)->Size(); | ||
fids_.resize(n_sampled_features); | ||
for (size_t i = 0; i < n_sampled_features; ++i) { | ||
fids_[i] = column_sampler_->GetFeatureSet(depth)->ConstHostVector()[i]; | ||
} | ||
} else { | ||
fids_.resize(0); | ||
} | ||
|
||
// Parallel processing by nodes and data in each node | ||
common::ParallelFor2d(space, this->n_threads_, [&](size_t nid_in_set, common::Range1d r) { | ||
const auto tid = static_cast<unsigned>(omp_get_thread_num()); | ||
|
@@ -87,7 +109,7 @@ class HistogramBuilder { | |
elem.begin + end_of_row_set, nid); | ||
auto hist = buffer_.GetInitializedHist(tid, nid_in_set); | ||
if (rid_set.Size() != 0) { | ||
builder_.template BuildHist<any_missing>(gpair_h, rid_set, gidx, hist, | ||
builder_.template BuildHist<any_missing>(gpair_h, rid_set, gidx, hist, fids_, | ||
force_read_by_column); | ||
} | ||
}); | ||
|
@@ -114,7 +136,7 @@ class HistogramBuilder { | |
RegTree *p_tree, common::RowSetCollection const &row_set_collection, | ||
std::vector<ExpandEntry> const &nodes_for_explicit_hist_build, | ||
std::vector<ExpandEntry> const &nodes_for_subtraction_trick, | ||
std::vector<GradientPair> const &gpair, | ||
std::vector<GradientPair> const &gpair, int depth, | ||
bool force_read_by_column = false) { | ||
int starting_index = std::numeric_limits<int>::max(); | ||
int sync_count = 0; | ||
|
@@ -126,12 +148,12 @@ class HistogramBuilder { | |
if (gidx.IsDense()) { | ||
this->BuildLocalHistograms<false>(page_id, space, gidx, | ||
nodes_for_explicit_hist_build, | ||
row_set_collection, gpair, | ||
row_set_collection, gpair, depth, | ||
force_read_by_column); | ||
} else { | ||
this->BuildLocalHistograms<true>(page_id, space, gidx, | ||
nodes_for_explicit_hist_build, | ||
row_set_collection, gpair, | ||
row_set_collection, gpair, depth, | ||
force_read_by_column); | ||
} | ||
|
||
|
@@ -153,7 +175,7 @@ class HistogramBuilder { | |
common::RowSetCollection const &row_set_collection, | ||
std::vector<ExpandEntry> const &nodes_for_explicit_hist_build, | ||
std::vector<ExpandEntry> const &nodes_for_subtraction_trick, | ||
std::vector<GradientPair> const &gpair, | ||
std::vector<GradientPair> const &gpair, int depth, | ||
bool force_read_by_column = false) { | ||
const size_t n_nodes = nodes_for_explicit_hist_build.size(); | ||
// create space of size (# rows in each node) | ||
|
@@ -166,7 +188,7 @@ class HistogramBuilder { | |
256); | ||
this->BuildHist(page_id, space, gidx, p_tree, row_set_collection, | ||
nodes_for_explicit_hist_build, nodes_for_subtraction_trick, | ||
gpair, force_read_by_column); | ||
gpair, depth, force_read_by_column); | ||
} | ||
|
||
void SyncHistogramDistributed( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is `fids` empty if there's no sampling?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`fids` is empty when the condition here is false.