Skip to content

Commit

Permalink
Revert min value change.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Aug 9, 2020
1 parent d8ac122 commit a4795d1
Show file tree
Hide file tree
Showing 21 changed files with 103 additions and 26 deletions.
6 changes: 0 additions & 6 deletions include/xgboost/span.h
Expand Up @@ -82,7 +82,6 @@ namespace common {
"\tBlock: [%d, %d, %d], Thread: [%d, %d, %d]\n\n", \
__FILE__, __LINE__, __PRETTY_FUNCTION__, #cond, blockIdx.x, \
blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, threadIdx.z); \
assert(false); \
asm("trap;"); \
} \
} while (0);
Expand Down Expand Up @@ -662,11 +661,6 @@ XGBOOST_DEVICE auto as_writable_bytes(Span<T, E> s) __span_noexcept -> // NOLIN
return {reinterpret_cast<byte*>(s.data()), s.size_bytes()};
}

template <typename T, template <class, class...> class Container, typename... Types,
std::size_t Extent = dynamic_extent>
auto MakeSpan(Container<T, Types...> const &container) {
return Span<T, Extent>(container);
}
} // namespace common
} // namespace xgboost

Expand Down
2 changes: 0 additions & 2 deletions src/common/common.h
Expand Up @@ -82,8 +82,6 @@ XGBOOST_DEVICE T1 DivRoundUp(const T1 a, const T2 b) {
return static_cast<T1>(std::ceil(static_cast<double>(a) / b));
}

constexpr float kTrivialSplit = -std::numeric_limits<float>::infinity();

/*
* Range iterator
*/
Expand Down
28 changes: 24 additions & 4 deletions src/common/hist_util.cc
Expand Up @@ -161,8 +161,10 @@ void SparseCuts::SingleThreadBuild(SparsePage const& page, MetaInfo const& info,

// Data groups, used in ranking.
std::vector<bst_uint> const& group_ptr = info.group_ptr_;
auto &local_min_vals = p_cuts_->min_vals_.HostVector();
auto &local_cuts = p_cuts_->cut_values_.HostVector();
auto &local_ptrs = p_cuts_->cut_ptrs_.HostVector();
local_min_vals.resize(end_col - beg_col, 0);

for (uint32_t col_id = beg_col; col_id < page.Size() && col_id < end_col; ++col_id) {
// Using a local variable makes things easier, but at the cost of memory trashing.
Expand Down Expand Up @@ -197,11 +199,16 @@ void SparseCuts::SingleThreadBuild(SparsePage const& page, MetaInfo const& info,
summary.Reserve(n_bins + 1);
summary.SetPrune(out_summary, n_bins + 1);

// Can be use data[1] as the min values so that we don't need to
// store another array?
float mval = summary.data[0].value;
local_min_vals[col_id - beg_col] = mval - (fabs(mval) + 1e-5);

this->AddCutPoint(summary, max_num_bins);

bst_float cpt = (summary.size > 0) ?
summary.data[summary.size - 1].value :
kTrivialSplit;
local_min_vals[col_id - beg_col];
cpt += fabs(cpt) + 1e-5;
local_cuts.emplace_back(cpt);

Expand Down Expand Up @@ -279,10 +286,14 @@ void SparseCuts::Concat(
std::vector<std::unique_ptr<SparseCuts>> const& cuts, uint32_t n_cols) {
monitor_.Start(__FUNCTION__);
uint32_t nthreads = omp_get_max_threads();
auto &local_min_vals = p_cuts_->min_vals_.HostVector();
auto &local_cuts = p_cuts_->cut_values_.HostVector();
auto &local_ptrs = p_cuts_->cut_ptrs_.HostVector();
local_min_vals.resize(n_cols, std::numeric_limits<float>::max());
size_t min_vals_tail = 0;

for (uint32_t t = 0; t < nthreads; ++t) {
auto& thread_min_vals = cuts[t]->p_cuts_->min_vals_.HostVector();
auto& thread_cuts = cuts[t]->p_cuts_->cut_values_.HostVector();
auto& thread_ptrs = cuts[t]->p_cuts_->cut_ptrs_.HostVector();

Expand All @@ -303,6 +314,12 @@ void SparseCuts::Concat(
for (size_t j = old_iv_size; j < new_iv_size; ++j) {
local_cuts[j] = thread_cuts[j-old_iv_size];
}
// merge min values
for (size_t j = 0; j < thread_min_vals.size(); ++j) {
local_min_vals.at(min_vals_tail + j) =
std::min(local_min_vals.at(min_vals_tail + j), thread_min_vals.at(j));
}
min_vals_tail += thread_min_vals.size();
}
monitor_.Stop(__FUNCTION__);
}
Expand Down Expand Up @@ -409,15 +426,18 @@ void DenseCuts::Init
// TODO(chenqin): rabit failure recovery assumes no boostrap onetime call after loadcheckpoint
// we need to move this allreduce before loadcheckpoint call in future
sreducer.Allreduce(dmlc::BeginPtr(summary_array), nbytes, summary_array.size());
p_cuts_->min_vals_.HostVector().resize(sketchs.size());

for (auto const& summary : summary_array) {
for (size_t fid = 0; fid < summary_array.size(); ++fid) {
WQSketch::SummaryContainer a;
a.Reserve(max_num_bins + 1);
a.SetPrune(summary, max_num_bins + 1);
a.SetPrune(summary_array[fid], max_num_bins + 1);
const bst_float mval = a.data[0].value;
p_cuts_->min_vals_.HostVector()[fid] = mval - (fabs(mval) + 1e-5);
AddCutPoint(a, max_num_bins);
// push a value that is greater than anything
const bst_float cpt
= (a.size > 0) ? a.data[a.size - 1].value : kTrivialSplit;
= (a.size > 0) ? a.data[a.size - 1].value : p_cuts_->min_vals_.HostVector()[fid];
// this must be bigger than last value in a scale
const bst_float last = cpt + (fabs(cpt) + 1e-5);
p_cuts_->cut_values_.HostVector().push_back(last);
Expand Down
3 changes: 3 additions & 0 deletions src/common/hist_util.cu
Expand Up @@ -158,6 +158,9 @@ size_t RequiredMemory(bst_row_t num_rows, bst_feature_t num_columns, size_t nnz,
// 9. Allocate final cut values, min values, cut ptrs: std::min(rows, bins + 1) *
// n_columns + n_columns + n_columns + 1
total += std::min(num_rows, num_bins) * num_columns * sizeof(float);
total += num_columns *
sizeof(std::remove_reference_t<decltype(
std::declval<HistogramCuts>().MinValues())>::value_type);
total += (num_columns + 1) *
sizeof(std::remove_reference_t<decltype(
std::declval<HistogramCuts>().Ptrs())>::value_type);
Expand Down
10 changes: 9 additions & 1 deletion src/common/hist_util.h
Expand Up @@ -47,13 +47,17 @@ class HistogramCuts {
public:
HostDeviceVector<bst_float> cut_values_; // NOLINT
HostDeviceVector<uint32_t> cut_ptrs_; // NOLINT
// storing minimum value in a sketch set.
HostDeviceVector<float> min_vals_; // NOLINT

HistogramCuts();
HistogramCuts(HistogramCuts const& that) {
cut_values_.Resize(that.cut_values_.Size());
cut_ptrs_.Resize(that.cut_ptrs_.Size());
min_vals_.Resize(that.min_vals_.Size());
cut_values_.Copy(that.cut_values_);
cut_ptrs_.Copy(that.cut_ptrs_);
min_vals_.Copy(that.min_vals_);
}

HistogramCuts(HistogramCuts&& that) noexcept(true) {
Expand All @@ -63,15 +67,18 @@ class HistogramCuts {
HistogramCuts& operator=(HistogramCuts const& that) {
cut_values_.Resize(that.cut_values_.Size());
cut_ptrs_.Resize(that.cut_ptrs_.Size());
min_vals_.Resize(that.min_vals_.Size());
cut_values_.Copy(that.cut_values_);
cut_ptrs_.Copy(that.cut_ptrs_);
min_vals_.Copy(that.min_vals_);
return *this;
}

HistogramCuts& operator=(HistogramCuts&& that) noexcept(true) {
monitor_ = std::move(that.monitor_);
cut_ptrs_ = std::move(that.cut_ptrs_);
cut_values_ = std::move(that.cut_values_);
min_vals_ = std::move(that.min_vals_);
return *this;
}

Expand All @@ -88,14 +95,15 @@ class HistogramCuts {
// these for now.
std::vector<uint32_t> const& Ptrs() const { return cut_ptrs_.ConstHostVector(); }
std::vector<float> const& Values() const { return cut_values_.ConstHostVector(); }
std::vector<float> const& MinValues() const { return min_vals_.ConstHostVector(); }

size_t TotalBins() const { return cut_ptrs_.ConstHostVector().back(); }

// Return the index of a cut point that is strictly greater than the input
// value, or the last available index if none exists
BinIdx SearchBin(float value, uint32_t column_id) const {
auto beg = cut_ptrs_.ConstHostVector().at(column_id);
auto end = cut_ptrs_.ConstHostVector().at(column_id + 1);
auto beg = cut_ptrs_.ConstHostVector()[column_id];
const auto &values = cut_values_.ConstHostVector();
auto it = std::upper_bound(values.cbegin() + beg, values.cbegin() + end, value);
BinIdx idx = it - values.cbegin();
Expand Down
10 changes: 10 additions & 0 deletions src/common/quantile.cu
Expand Up @@ -499,6 +499,7 @@ void SketchContainer::AllReduce() {
void SketchContainer::MakeCuts(HistogramCuts* p_cuts) {
timer_.Start(__func__);
dh::safe_cuda(cudaSetDevice(device_));
p_cuts->min_vals_.Resize(num_columns_);

// Sync between workers.
this->AllReduce();
Expand All @@ -510,6 +511,9 @@ void SketchContainer::MakeCuts(HistogramCuts* p_cuts) {

// Set up inputs
auto d_in_columns_ptr = this->columns_ptr_.ConstDeviceSpan();

p_cuts->min_vals_.SetDevice(device_);
auto d_min_values = p_cuts->min_vals_.DeviceSpan();
auto in_cut_values = dh::ToSpan(this->Current());

// Set up output ptr
Expand Down Expand Up @@ -553,12 +557,18 @@ void SketchContainer::MakeCuts(HistogramCuts* p_cuts) {
// column is empty, trees cannot split on it. This is just to be consistent with
// rest of the library.
if (idx == 0) {
d_min_values[column_id] = kRtEps;
out_column[0] = kRtEps;
assert(out_column.size() == 1);
}
return;
}

if (idx == 0 && !IsCat(d_ft, column_id)) {
auto mval = in_column[idx].value;
d_min_values[column_id] = mval - (fabs(mval) + 1e-5);
}

if (IsCat(d_ft, column_id)) {
assert(out_column.size() == in_column.size());
out_column[idx] = in_column[idx].value;
Expand Down
6 changes: 5 additions & 1 deletion src/data/ellpack_page.cuh
Expand Up @@ -52,6 +52,8 @@ struct EllpackDeviceAccessor {
size_t base_rowid{};
size_t n_rows{};
common::CompressedIterator<uint32_t> gidx_iter;
/*! \brief Minimum value for each feature. Size equals to number of features. */
common::Span<const bst_float> min_fvalue;
/*! \brief Histogram cut pointers. Size equals to (number of features + 1). */
common::Span<const uint32_t> feature_segments;
/*! \brief Histogram cut values. Size equals to (bins per feature * number of features). */
Expand All @@ -66,8 +68,10 @@ struct EllpackDeviceAccessor {
n_rows(n_rows) ,gidx_iter(gidx_iter){
cuts.cut_values_.SetDevice(device);
cuts.cut_ptrs_.SetDevice(device);
cuts.min_vals_.SetDevice(device);
gidx_fvalue_map = cuts.cut_values_.ConstDeviceSpan();
feature_segments = cuts.cut_ptrs_.ConstDeviceSpan();
min_fvalue = cuts.min_vals_.ConstDeviceSpan();
}
// Get a matrix element, uses binary search for look up Return NaN if missing
// Given a row index and a feature index, returns the corresponding cut value
Expand Down Expand Up @@ -120,7 +124,7 @@ struct EllpackDeviceAccessor {

XGBOOST_DEVICE size_t NumBins() const { return gidx_fvalue_map.size(); }

XGBOOST_DEVICE size_t NumFeatures() const { return feature_segments.size() - 1; }
XGBOOST_DEVICE size_t NumFeatures() const { return min_fvalue.size(); }
};


Expand Down
2 changes: 2 additions & 0 deletions src/data/ellpack_page_raw_format.cu
Expand Up @@ -19,6 +19,7 @@ class EllpackPageRawFormat : public SparsePageFormat<EllpackPage> {
auto* impl = page->Impl();
fi->Read(&impl->Cuts().cut_values_.HostVector());
fi->Read(&impl->Cuts().cut_ptrs_.HostVector());
fi->Read(&impl->Cuts().min_vals_.HostVector());
fi->Read(&impl->n_rows);
fi->Read(&impl->is_dense);
fi->Read(&impl->row_stride);
Expand All @@ -39,6 +40,7 @@ class EllpackPageRawFormat : public SparsePageFormat<EllpackPage> {
auto* impl = page.Impl();
fo->Write(impl->Cuts().cut_values_.ConstHostVector());
fo->Write(impl->Cuts().cut_ptrs_.ConstHostVector());
fo->Write(impl->Cuts().min_vals_.ConstHostVector());
fo->Write(impl->n_rows);
fo->Write(impl->is_dense);
fo->Write(impl->row_stride);
Expand Down
10 changes: 5 additions & 5 deletions src/predictor/gpu_predictor.cu
Expand Up @@ -57,7 +57,7 @@ struct SparsePageLoader {
bst_uint global_idx = blockDim.x * blockIdx.x + threadIdx.x;
int shared_elements = blockDim.x * num_features;
dh::BlockFill(smem, shared_elements, nanf(""));
cub::CTA_SYNC();
__syncthreads();
if (global_idx < num_rows) {
bst_uint elem_begin = d_row_ptr[global_idx];
bst_uint elem_end = d_row_ptr[global_idx + 1];
Expand All @@ -66,7 +66,7 @@ struct SparsePageLoader {
smem[threadIdx.x * num_features + elem.index] = elem.fvalue;
}
}
cub::CTA_SYNC();
__syncthreads();
}
}
__device__ float GetFvalue(int ridx, int fidx) const {
Expand Down Expand Up @@ -113,7 +113,7 @@ struct EllpackLoader {
// The gradient index needs to be shifted by one as min values are not included in the
// cuts.
if (gidx == matrix.feature_segments[fidx]) {
return common::kTrivialSplit;
return matrix.min_fvalue[fidx];
}
return matrix.gidx_fvalue_map[gidx - 1];
}
Expand All @@ -140,7 +140,7 @@ struct DeviceAdapterLoader {
uint32_t global_idx = blockDim.x * blockIdx.x + threadIdx.x;
size_t shared_elements = blockDim.x * num_features;
dh::BlockFill(smem, shared_elements, nanf(""));
cub::CTA_SYNC();
__syncthreads();
if (global_idx < num_rows) {
auto beg = global_idx * columns;
auto end = (global_idx + 1) * columns;
Expand All @@ -149,7 +149,7 @@ struct DeviceAdapterLoader {
}
}
}
cub::CTA_SYNC();
__syncthreads();
}

DEV_INLINE float GetFvalue(bst_row_t ridx, bst_feature_t fidx) const {
Expand Down
6 changes: 3 additions & 3 deletions src/tree/gpu_hist/evaluate_splits.cu
Expand Up @@ -128,7 +128,7 @@ struct UpdateNumeric {
int split_gidx = (scan_begin + threadIdx.x) - 1;
float fvalue;
if (split_gidx < static_cast<int>(gidx_begin)) {
fvalue = common::kTrivialSplit;
fvalue = inputs.min_fvalue[fidx];
} else {
fvalue = inputs.feature_values[split_gidx];
}
Expand Down Expand Up @@ -180,7 +180,7 @@ __device__ void EvaluateFeature(
inputs.value_constraint, missing_left);
}

cub::CTA_SYNC();
__syncthreads();

// Find thread with best gain
cub::KeyValuePair<int, float> tuple(threadIdx.x, gain);
Expand Down Expand Up @@ -231,7 +231,7 @@ __global__ void EvaluateSplitsKernel(
best_split = DeviceSplitCandidate();
}

cub::CTA_SYNC();
__syncthreads();

// If this block is working on the left or right node
bool is_left = blockIdx.x < left.feature_set.size();
Expand Down
1 change: 1 addition & 0 deletions src/tree/gpu_hist/evaluate_splits.cuh
Expand Up @@ -20,6 +20,7 @@ struct EvaluateSplitInputs {
common::Span<FeatureType const> feature_types;
common::Span<const uint32_t> feature_segments;
common::Span<const float> feature_values;
common::Span<const float> min_fvalue;
common::Span<const GradientSumT> gradient_histogram;
ValueConstraint value_constraint;
common::Span<const int> monotonic_constraints;
Expand Down
3 changes: 3 additions & 0 deletions src/tree/updater_gpu_hist.cu
Expand Up @@ -318,6 +318,7 @@ struct GPUHistMakerDevice {
feature_types,
matrix.feature_segments,
matrix.gidx_fvalue_map,
matrix.min_fvalue,
hist.GetNodeHistogram(nidx),
node_value_constraints[nidx],
dh::ToSpan(monotone_constraints)};
Expand Down Expand Up @@ -356,6 +357,7 @@ struct GPUHistMakerDevice {
feature_types,
matrix.feature_segments,
matrix.gidx_fvalue_map,
matrix.min_fvalue,
hist.GetNodeHistogram(left_nidx),
node_value_constraints[left_nidx],
dh::ToSpan(monotone_constraints)};
Expand All @@ -368,6 +370,7 @@ struct GPUHistMakerDevice {
feature_types,
matrix.feature_segments,
matrix.gidx_fvalue_map,
matrix.min_fvalue,
hist.GetNodeHistogram(right_nidx),
node_value_constraints[right_nidx],
dh::ToSpan(monotone_constraints)};
Expand Down
3 changes: 1 addition & 2 deletions src/tree/updater_quantile_hist.cc
Expand Up @@ -25,7 +25,6 @@
#include "./updater_quantile_hist.h"
#include "./split_evaluator.h"
#include "../common/random.h"
#include "../common/common.h"
#include "../common/hist_util.h"
#include "../common/row_set.h"
#include "../common/column_matrix.h"
Expand Down Expand Up @@ -1334,7 +1333,7 @@ GradStats QuantileHistMaker::Builder<GradientSumT>::EnumerateSplit(
snode.root_gain);
if (i == imin) {
// for leftmost bin, left bound is the smallest feature value
split_pt = common::kTrivialSplit;
split_pt = gmat.cut.MinValues()[fid];
} else {
split_pt = cut_val[i - 1];
}
Expand Down

0 comments on commit a4795d1

Please sign in to comment.