diff --git a/amalgamation/xgboost-all0.cc b/amalgamation/xgboost-all0.cc
index 45eb5e72593d..c684e6309de8 100644
--- a/amalgamation/xgboost-all0.cc
+++ b/amalgamation/xgboost-all0.cc
@@ -24,6 +24,7 @@
 #include "../src/objective/rank_obj.cc"
 #include "../src/objective/hinge.cc"
 #include "../src/objective/aft_obj.cc"
+#include "../src/objective/adaptive.cc"
 
 // gbms
 #include "../src/gbm/gbm.cc"
diff --git a/doc/model.schema b/doc/model.schema
index b192cabc6864..02725cb36d31 100644
--- a/doc/model.schema
+++ b/doc/model.schema
@@ -400,7 +400,6 @@
         "reg_loss_param"
       ]
     },
-
     {
       "type": "object",
       "properties": {
@@ -433,6 +432,14 @@
         "tweedie_regression_param"
       ]
     },
+    {
+      "properties": {
+        "name": {
+          "const": "reg:absoluteerror"
+        }
+      },
+      "type": "object"
+    },
     {
       "type": "object",
       "properties": {
diff --git a/doc/parameter.rst b/doc/parameter.rst
index 781150490082..b361b01d4d9f 100644
--- a/doc/parameter.rst
+++ b/doc/parameter.rst
@@ -349,6 +349,7 @@ Specify the learning task and the corresponding learning objective. The objectiv
   - ``reg:squaredlogerror``: regression with squared log loss :math:`\frac{1}{2}[\log(pred + 1) - \log(label + 1)]^2`.  All input labels are required to be greater than -1. Also, see metric ``rmsle`` for a possible issue with this objective.
   - ``reg:logistic``: logistic regression.
   - ``reg:pseudohubererror``: regression with Pseudo Huber loss, a twice differentiable alternative to absolute loss.
+  - ``reg:absoluteerror``: regression with L1 error. When a tree model is used, the leaf values are refreshed after tree construction.
   - ``binary:logistic``: logistic regression for binary classification, output probability
   - ``binary:logitraw``: logistic regression for binary classification, output score before logistic transformation
   - ``binary:hinge``: hinge loss for binary classification. This makes predictions of 0 or 1, rather than producing probabilities.
diff --git a/include/xgboost/gbm.h b/include/xgboost/gbm.h
index d24057e255a7..cce92d3679f4 100644
--- a/include/xgboost/gbm.h
+++ b/include/xgboost/gbm.h
@@ -90,9 +90,8 @@ class GradientBooster : public Model, public Configurable {
    * \param prediction The output prediction cache entry that needs to be updated.
    *                   the booster may change content of gpair
    */
-  virtual void DoBoost(DMatrix* p_fmat,
-                       HostDeviceVector<GradientPair>* in_gpair,
-                       PredictionCacheEntry*) = 0;
+  virtual void DoBoost(DMatrix* p_fmat, HostDeviceVector<GradientPair>* in_gpair,
+                       PredictionCacheEntry*, ObjFunction const* obj) = 0;
 
   /*!
    * \brief generate predictions for given feature matrix
diff --git a/include/xgboost/linalg.h b/include/xgboost/linalg.h
index 32d0f9fb9f9c..015121560039 100644
--- a/include/xgboost/linalg.h
+++ b/include/xgboost/linalg.h
@@ -670,9 +670,13 @@ class Tensor {
    * See \ref TensorView for parameters of this constructor.
    */
   template <typename I, int32_t D>
-  explicit Tensor(I const (&shape)[D], int32_t device) {
+  explicit Tensor(I const (&shape)[D], int32_t device)
+      : Tensor{common::Span<I const, D>{shape}, device} {}
+
+  template <typename I, size_t D>
+  explicit Tensor(common::Span<I const, D> shape, int32_t device) {
     // No device unroll as this is a host only function.
-    std::copy(shape, shape + D, shape_);
+    std::copy(shape.data(), shape.data() + D, shape_);
     for (auto i = D; i < kDim; ++i) {
       shape_[i] = 1;
     }
diff --git a/include/xgboost/objective.h b/include/xgboost/objective.h
index 44dc46ddc8da..cb0fe7741dc9 100644
--- a/include/xgboost/objective.h
+++ b/include/xgboost/objective.h
@@ -1,5 +1,5 @@
 /*!
- * Copyright 2014-2019 by Contributors
+ * Copyright 2014-2022 by Contributors
  * \file objective.h
  * \brief interface of objective function used by xgboost.
  * \author Tianqi Chen, Kailong Chen
@@ -22,6 +22,8 @@
 
 namespace xgboost {
 
+class RegTree;
+
 /*! \brief interface of objective function */
 class ObjFunction : public Configurable {
  protected:
@@ -88,6 +90,22 @@ class ObjFunction : public Configurable {
     return 1;
   }
 
+  /**
+   * \brief Update the leaf values after a tree is built. Needed for objectives with 0
+   *        hessian.
+   *
+   * Note that the leaf update is not well defined for distributed training as XGBoost
+   * computes only an average of quantiles across workers. This breaks when some leaves
+   * have no sample assigned on a local worker.
+   *
+   * \param position   The leaf index for each row.
+   * \param info       MetaInfo providing labels and weights.
+   * \param prediction Model prediction after transformation.
+   * \param p_tree     Tree that needs to be updated.
+   */
+  virtual void UpdateTreeLeaf(HostDeviceVector<bst_node_t> const& position, MetaInfo const& info,
+                              HostDeviceVector<float> const& prediction, RegTree* p_tree) const {}
+
   /*!
    * \brief Create an objective function according to name.
    * \param tparam Generic parameters.
diff --git a/include/xgboost/task.h b/include/xgboost/task.h
index 537320657544..739207a309d8 100644
--- a/include/xgboost/task.h
+++ b/include/xgboost/task.h
@@ -33,13 +33,18 @@ struct ObjInfo {
   } task;
   // Does the objective have constant hessian value?
   bool const_hess{false};
+  bool zero_hess{false};
 
-  explicit ObjInfo(Task t) : task{t} {}
-  ObjInfo(Task t, bool khess) : task{t}, const_hess{khess} {}
+  ObjInfo(Task t) : task{t} {}  // NOLINT
+  ObjInfo(Task t, bool khess, bool zhess) : task{t}, const_hess{khess}, zero_hess{zhess} {}
 
   XGBOOST_DEVICE bool UseOneHot() const {
     return (task != ObjInfo::kRegression && task != ObjInfo::kBinary);
   }
+  /**
+   * \brief Use an adaptive tree if the objective doesn't have a valid hessian value.
+   */
+  XGBOOST_DEVICE bool UpdateTreeLeaf() const { return zero_hess; }
 };
 }  // namespace xgboost
 #endif  // XGBOOST_TASK_H_
diff --git a/include/xgboost/tree_updater.h b/include/xgboost/tree_updater.h
index 6189221dc0bf..f0fabb26d9a0 100644
--- a/include/xgboost/tree_updater.h
+++ b/include/xgboost/tree_updater.h
@@ -49,18 +49,25 @@ class TreeUpdater : public Configurable {
    *        existing trees.
    */
   virtual bool CanModifyTree() const { return false; }
+  /*!
+   * \brief Whether the out_position in `Update` is valid. This determines whether an
+   *        adaptive tree can be used.
+   */
+  virtual bool HasNodePosition() const { return false; }
   /*!
    * \brief perform update to the tree models
    * \param gpair the gradient pair statistics of the data
    * \param data The data matrix passed to the updater.
-   * \param trees references the trees to be updated, updater will change the content of trees
+   * \param out_position The leaf index for each row. The index is negated if that row is
+   *                     removed during sampling. So the 3rd node is ~3.
+   * \param out_trees references the trees to be updated, updater will change the content of trees
    *   note: all the trees in the vector are updated, with the same statistics,
    *         but maybe different random seeds, usually one tree is passed in at a time,
    *         there can be multiple trees when we train random forest style model
    */
-  virtual void Update(HostDeviceVector<GradientPair>* gpair,
-                      DMatrix* data,
-                      const std::vector<RegTree*>& trees) = 0;
+  virtual void Update(HostDeviceVector<GradientPair>* gpair, DMatrix* data,
+                      common::Span<HostDeviceVector<bst_node_t>> out_position,
+                      const std::vector<RegTree*>& out_trees) = 0;
 
  /*!
   * \brief determines whether updater has enough knowledge about a given dataset
diff --git a/plugin/example/custom_obj.cc b/plugin/example/custom_obj.cc
index b61073360e00..e220e4497141 100644
--- a/plugin/example/custom_obj.cc
+++ b/plugin/example/custom_obj.cc
@@ -1,5 +1,5 @@
 /*!
- * Copyright 2015-2019 by Contributors
+ * Copyright 2015-2022 by Contributors
  * \file custom_metric.cc
  * \brief This is an example to define plugin of xgboost.
  *  This plugin defines the additional metric function.
@@ -31,13 +31,9 @@ DMLC_REGISTER_PARAMETER(MyLogisticParam);
 // Implement the interface.
 class MyLogistic : public ObjFunction {
  public:
-  void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
-    param_.UpdateAllowUnknown(args);
-  }
+  void Configure(const Args& args) override { param_.UpdateAllowUnknown(args); }
 
-  struct ObjInfo Task() const override {
-    return {ObjInfo::kRegression, false};
-  }
+  ObjInfo Task() const override { return ObjInfo::kRegression; }
 
   void GetGradient(const HostDeviceVector<bst_float> &preds,
                    const MetaInfo &info,
diff --git a/src/common/common.h b/src/common/common.h
index fb7e7fee55da..aa2d8197b4a1 100644
--- a/src/common/common.h
+++ b/src/common/common.h
@@ -1,5 +1,5 @@
 /*!
- * Copyright 2015-2018 by Contributors
+ * Copyright 2015-2022 by XGBoost Contributors
  * \file common.h
  * \brief Common utilities
  */
@@ -14,12 +14,12 @@
 #include
 #include
 #include
-#include
-#include
-#include
-#include
 #include
+#include
+#include
+#include
 #include
+#include
 
 #if defined(__CUDACC__)
 #include
@@ -164,6 +164,67 @@ class Range {
   Iterator end_;
 };
 
+/**
+ * \brief Transform iterator that takes an index and calls the transform operator.
+ *
+ * This is CPU-only right now as taking a host-device function as the operator
+ * complicates the code. For the device side one can use `thrust::transform_iterator`
+ * instead.
+ */
+template <typename Fn>
+class IndexTransformIter {
+  size_t iter_{0};
+  Fn fn_;
+
+ public:
+  using iterator_category = std::random_access_iterator_tag;  // NOLINT
+  using value_type = std::result_of_t<Fn(size_t)>;            // NOLINT
+  using difference_type = detail::ptrdiff_t;                  // NOLINT
+  using reference = std::add_lvalue_reference_t<value_type>;  // NOLINT
+  using pointer = std::add_pointer_t<value_type>;             // NOLINT
+
+ public:
+  /**
+   * \param op Transform operator, takes a size_t index as input.
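+   *
+   * A hypothetical usage sketch (the vector `data` and its values are illustrative,
+   * not part of this patch):
+   *
+   *   std::vector<float> data{1.f, 2.f, 3.f};
+   *   auto it = MakeIndexTransformIter([&](size_t i) { return data[i] * 2.f; });
+   *   float v = *(it + 1);  // yields 4.f, i.e. data[1] * 2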
+   */
+  explicit IndexTransformIter(Fn &&op) : fn_{op} {}
+  IndexTransformIter(IndexTransformIter const &) = default;
+
+  value_type operator*() const { return fn_(iter_); }
+
+  auto operator-(IndexTransformIter const &that) const { return iter_ - that.iter_; }
+
+  IndexTransformIter &operator++() {
+    iter_++;
+    return *this;
+  }
+  IndexTransformIter operator++(int) {
+    auto ret = *this;
+    ++(*this);
+    return ret;
+  }
+  IndexTransformIter &operator+=(difference_type n) {
+    iter_ += n;
+    return *this;
+  }
+  IndexTransformIter &operator-=(difference_type n) {
+    (*this) += -n;
+    return *this;
+  }
+  IndexTransformIter operator+(difference_type n) const {
+    auto ret = *this;
+    return ret += n;
+  }
+  IndexTransformIter operator-(difference_type n) const {
+    auto ret = *this;
+    return ret -= n;
+  }
+};
+
+template <typename Fn>
+auto MakeIndexTransformIter(Fn&& fn) {
+  return IndexTransformIter<Fn>(std::forward<Fn>(fn));
+}
+
 int AllVisibleGPUs();
 
 inline void AssertGPUSupport() {
@@ -191,13 +252,39 @@ std::vector<Idx> ArgSort(Container const &array, Comp comp = std::less<V>{}) {
 
 struct OptionalWeights {
   Span<float const> weights;
-  float dft{1.0f};
+  float dft{1.0f};  // fixme: make this compile time constant
 
   explicit OptionalWeights(Span<float const> w) : weights{w} {}
   explicit OptionalWeights(float w) : dft{w} {}
 
   XGBOOST_DEVICE float operator[](size_t i) const {
     return weights.empty() ? dft : weights[i];
   }
 };
+
+/**
+ * Last index of a group in a CSR style of index pointer.
+ */
+template <typename Indexable>
+XGBOOST_DEVICE size_t LastOf(size_t group, Indexable const &indptr) {
+  return indptr[group + 1] - 1;
+}
+
+/**
+ * \brief Run length encode on CPU, input must be sorted.
+ */
+template <typename Iter, typename Idx>
+void RunLengthEncode(Iter begin, Iter end, std::vector<Idx> *p_out) {
+  auto &out = *p_out;
+  out = std::vector<Idx>{0};
+  size_t n = std::distance(begin, end);
+  for (size_t i = 1; i < n; ++i) {
+    if (begin[i] != begin[i - 1]) {
+      out.push_back(i);
+    }
+  }
+  if (out.back() != n) {
+    out.push_back(n);
+  }
+}
 }  // namespace common
 }  // namespace xgboost
 #endif  // XGBOOST_COMMON_COMMON_H_
diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh
index 9adf866fece9..334e3b4f89bf 100644
--- a/src/common/device_helpers.cuh
+++ b/src/common/device_helpers.cuh
@@ -1,5 +1,5 @@
 /*!
- * Copyright 2017-2021 XGBoost contributors
+ * Copyright 2017-2022 XGBoost contributors
  */
 #pragma once
 #include
@@ -1537,6 +1537,43 @@ void SegmentedArgSort(xgboost::common::Span<T> values,
                       sorted_idx.size_bytes(), cudaMemcpyDeviceToDevice));
 }
 
+/**
+ * \brief Different from the above one, this one can handle cases where segments don't
+ *        start from 0, but as a result it uses comparison sort.
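+ *
+ * A sketch of the intended call (names hypothetical): with segment pointers {2, 5, 8},
+ * rows 0-1 fall before the first segment, receive key -1 and sort to the front, while
+ * rows 2-4 and 5-7 are argsorted by value within their own segment:
+ *
+ *   dh::device_vector<size_t> sorted_idx;
+ *   SegmentedArgSort(segs.begin(), segs.end(), values.begin(), values.end(),
+ *                    &sorted_idx);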
+ */
+template <typename SegIt, typename ValIt>
+void SegmentedArgSort(SegIt seg_begin, SegIt seg_end, ValIt val_begin, ValIt val_end,
+                      dh::device_vector<size_t> *p_sorted_idx) {
+  using Tup = thrust::tuple<int32_t, float>;
+  auto &sorted_idx = *p_sorted_idx;
+  size_t n = std::distance(val_begin, val_end);
+  sorted_idx.resize(n);
+  dh::Iota(dh::ToSpan(sorted_idx));
+  dh::device_vector<Tup> keys(sorted_idx.size());
+  auto key_it = dh::MakeTransformIterator<Tup>(thrust::make_counting_iterator(0ul),
+                                               [=] XGBOOST_DEVICE(size_t i) -> Tup {
+                                                 int32_t leaf_idx;
+                                                 if (i < *seg_begin) {
+                                                   leaf_idx = -1;
+                                                 } else {
+                                                   leaf_idx = dh::SegmentId(seg_begin, seg_end, i);
+                                                 }
+                                                 auto residue = val_begin[i];
+                                                 return thrust::make_tuple(leaf_idx, residue);
+                                               });
+  dh::XGBCachingDeviceAllocator<char> caching;
+  thrust::copy(thrust::cuda::par(caching), key_it, key_it + keys.size(), keys.begin());
+
+  dh::XGBDeviceAllocator<char> alloc;
+  thrust::stable_sort_by_key(thrust::cuda::par(alloc), keys.begin(), keys.end(), sorted_idx.begin(),
+                             [=] XGBOOST_DEVICE(Tup const &l, Tup const &r) {
+                               if (thrust::get<0>(l) != thrust::get<0>(r)) {
+                                 return thrust::get<0>(l) < thrust::get<0>(r);  // segment index
+                               }
+                               return thrust::get<1>(l) < thrust::get<1>(r);  // residue
+                             });
+}
+
 class CUDAStreamView;
 
 class CUDAEvent {
@@ -1600,5 +1637,6 @@ class CUDAStream {
   }
 
   CUDAStreamView View() const { return CUDAStreamView{stream_}; }
+  void Sync() { this->View().Sync(); }
 };
 }  // namespace dh
diff --git a/src/common/linalg_op.cuh b/src/common/linalg_op.cuh
index f0f89df8ab26..558a09ca6acb 100644
--- a/src/common/linalg_op.cuh
+++ b/src/common/linalg_op.cuh
@@ -13,6 +13,7 @@ namespace xgboost {
 namespace linalg {
 template <typename T, int32_t D, typename Fn>
 void ElementWiseKernelDevice(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_t s = nullptr) {
+  dh::safe_cuda(cudaSetDevice(t.DeviceIdx()));
   static_assert(std::is_void<std::result_of_t<Fn(size_t, T&)>>::value,
                 "For function with return, use transform instead.");
   if (t.Contiguous()) {
@@ -40,7 +41,7 @@ void ElementWiseTransformDevice(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_
 }
 
 template <typename T, int32_t D, typename Fn>
-void ElementWiseKernel(GenericParameter const* ctx, linalg::TensorView<T, D> t, Fn&& fn) {
+void ElementWiseKernel(Context const* ctx, linalg::TensorView<T, D> t, Fn&& fn) {
   ctx->IsCPU() ? ElementWiseKernelHost(t, ctx->Threads(), fn) : ElementWiseKernelDevice(t, fn);
 }
 }  // namespace linalg
diff --git a/src/common/partition_builder.h b/src/common/partition_builder.h
index 3250b9d2bf25..648cbe61a3a3 100644
--- a/src/common/partition_builder.h
+++ b/src/common/partition_builder.h
@@ -12,10 +12,12 @@
 #include
 #include
 #include
+#include
 #include
 
 #include "categorical.h"
 #include "column_matrix.h"
+#include "xgboost/generic_parameters.h"
 #include "xgboost/tree_model.h"
 
 namespace xgboost {
@@ -254,7 +256,7 @@
       n_left += mem_blocks_[j]->n_left;
     }
     size_t n_right = 0;
-    for (size_t j = blocks_offsets_[i]; j < blocks_offsets_[i+1]; ++j) {
+    for (size_t j = blocks_offsets_[i]; j < blocks_offsets_[i + 1]; ++j) {
       mem_blocks_[j]->n_offset_right = n_left + n_right;
       n_right += mem_blocks_[j]->n_right;
     }
@@ -279,6 +281,30 @@ class PartitionBuilder {
     return blocks_offsets_[nid] + begin / BlockSize;
   }
 
+  // Copy row partitions into global cache for reuse in objective
+  template <typename Sampledp>
+  void LeafPartition(Context const* ctx, RegTree const& tree, RowSetCollection const& row_set,
+                     std::vector<bst_node_t>* p_position, Sampledp sampledp) const {
+    auto& h_pos = *p_position;
+    h_pos.resize(row_set.Data()->size(), std::numeric_limits<bst_node_t>::max());
+
+    auto p_begin = row_set.Data()->data();
+    ParallelFor(row_set.Size(), ctx->Threads(), [&](size_t i) {
+      auto const& node = row_set[i];
+      if (node.node_id < 0) {
+        return;
+      }
+      CHECK(tree[node.node_id].IsLeaf());
+      if (node.begin) {  // guard for empty node.
+        size_t ptr_offset = node.end - p_begin;
+        CHECK_LE(ptr_offset, row_set.Data()->size()) << node.node_id;
+        for (auto idx = node.begin; idx != node.end; ++idx) {
+          h_pos[*idx] = sampledp(*idx) ? ~node.node_id : node.node_id;
+        }
+      }
+    });
+  }
+
  protected:
   struct BlockInfo{
     size_t n_left;
diff --git a/src/common/row_set.h b/src/common/row_set.h
index dc61d5f5d877..87d5f52874f2 100644
--- a/src/common/row_set.h
+++ b/src/common/row_set.h
@@ -1,5 +1,5 @@
 /*!
- * Copyright 2017 by Contributors
+ * Copyright 2017-2022 by Contributors
  * \file row_set.h
  * \brief Quick Utility to compute subset of rows
  * \author Philip Cho, Tianqi Chen
@@ -15,10 +15,15 @@
 
 namespace xgboost {
 namespace common {
-
 /*! \brief collection of rowset */
 class RowSetCollection {
  public:
+  RowSetCollection() = default;
+  RowSetCollection(RowSetCollection const&) = delete;
+  RowSetCollection(RowSetCollection&&) = default;
+  RowSetCollection& operator=(RowSetCollection const&) = delete;
+  RowSetCollection& operator=(RowSetCollection&&) = default;
+
   /*! \brief data structure to store an instance set, a subset of
    *  rows (instances) associated with a particular node in a decision
    *  tree. */
@@ -38,20 +43,17 @@ class RowSetCollection {
       return end - begin;
     }
   };
-  /* \brief specifies how to split a rowset into two */
-  struct Split {
-    std::vector<size_t> left;
-    std::vector<size_t> right;
-  };
 
-  inline std::vector<Elem>::const_iterator begin() const {  // NOLINT
+  std::vector<Elem>::const_iterator begin() const {  // NOLINT
     return elem_of_each_node_.begin();
  }
 
-  inline std::vector<Elem>::const_iterator end() const {  // NOLINT
+  std::vector<Elem>::const_iterator end() const {  // NOLINT
    return elem_of_each_node_.end();
   }
 
+  size_t Size() const { return std::distance(begin(), end()); }
+
   /*! \brief return corresponding element set given the node_id */
   inline const Elem& operator[](unsigned node_id) const {
     const Elem& e = elem_of_each_node_[node_id];
@@ -86,6 +88,8 @@ class RowSetCollection {
   }
 
   std::vector<size_t>* Data() { return &row_indices_; }
+  std::vector<size_t> const* Data() const { return &row_indices_; }
+
   // split rowset into two
   inline void AddSplit(unsigned node_id, unsigned left_node_id, unsigned right_node_id,
                       size_t n_left, size_t n_right) {
@@ -123,7 +127,6 @@ class RowSetCollection {
   // vector: node_id -> elements
   std::vector<Elem> elem_of_each_node_;
 };
-
 }  // namespace common
 }  // namespace xgboost
diff --git a/src/common/stats.cuh b/src/common/stats.cuh
new file mode 100644
index 000000000000..9d9e526a8576
--- /dev/null
+++ b/src/common/stats.cuh
@@ -0,0 +1,127 @@
+/*!
+ * Copyright 2022 by XGBoost Contributors
+ */
+#ifndef XGBOOST_COMMON_STATS_CUH_
+#define XGBOOST_COMMON_STATS_CUH_
+
+#include
+#include
+
+#include <iterator>  // std::distance
+
+#include "device_helpers.cuh"
+#include "linalg_op.cuh"
+#include "xgboost/generic_parameters.h"
+#include "xgboost/linalg.h"
+#include "xgboost/tree_model.h"
+
+namespace xgboost {
+namespace common {
+/**
+ * \brief Compute segmented quantile on GPU.
+ *
+ * \tparam SegIt Iterator for CSR style segments indptr
+ * \tparam ValIt Iterator for values
+ *
+ * \param alpha The quantile we want to compute, in range [0, 1].
+ *
+ * std::distance(seg_begin, seg_end) should be equal to n_segments + 1
+ */
+template <typename SegIt, typename ValIt>
+void SegmentedQuantile(Context const* ctx, double alpha, SegIt seg_begin, SegIt seg_end,
+                       ValIt val_begin, ValIt val_end, HostDeviceVector<float>* quantiles) {
+  CHECK(alpha >= 0 && alpha <= 1);
+
+  dh::device_vector<size_t> sorted_idx;
+  using Tup = thrust::tuple<int32_t, float>;
+  dh::SegmentedArgSort(seg_begin, seg_end, val_begin, val_end, &sorted_idx);
+  auto n_segments = std::distance(seg_begin, seg_end) - 1;
+  if (n_segments <= 0) {
+    return;
+  }
+
+  quantiles->SetDevice(ctx->gpu_id);
+  quantiles->Resize(n_segments);
+  auto d_results = quantiles->DeviceSpan();
+  auto d_sorted_idx = dh::ToSpan(sorted_idx);
+
+  auto val = thrust::make_permutation_iterator(val_begin, dh::tcbegin(d_sorted_idx));
+
+  dh::LaunchN(n_segments, [=] XGBOOST_DEVICE(size_t i) {
+    // each segment is the index of a leaf.
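+    // The interpolation below follows the (n + 1)-basis quantile definition used by
+    // the CPU version in stats.h: with x = alpha * (n + 1), k = floor(x) - 1 and
+    // d = x - 1 - k, the result is v[k] + d * (v[k + 1] - v[k]), clamped to the
+    // first/last value when alpha falls outside (1/(n + 1), n/(n + 1)).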
+    size_t seg_idx = i;
+    size_t begin = seg_begin[seg_idx];
+    auto n = static_cast<double>(seg_begin[seg_idx + 1] - begin);
+    if (n == 0) {
+      d_results[i] = std::numeric_limits<float>::quiet_NaN();
+      return;
+    }
+
+    if (alpha <= (1 / (n + 1))) {
+      d_results[i] = val[begin];
+      return;
+    }
+    if (alpha >= (n / (n + 1))) {
+      d_results[i] = val[common::LastOf(seg_idx, seg_begin)];
+      return;
+    }
+
+    double x = alpha * static_cast<double>(n + 1);
+    double k = std::floor(x) - 1;
+    double d = (x - 1) - k;
+
+    auto v0 = val[begin + static_cast<size_t>(k)];
+    auto v1 = val[begin + static_cast<size_t>(k) + 1];
+    d_results[seg_idx] = v0 + d * (v1 - v0);
+  });
+}
+
+template <typename SegIt, typename ValIt, typename WIter>
+void SegmentedWeightedQuantile(Context const* ctx, double alpha, SegIt seg_beg, SegIt seg_end,
+                               ValIt val_begin, ValIt val_end, WIter w_begin, WIter w_end,
+                               HostDeviceVector<float>* quantiles) {
+  CHECK(alpha >= 0 && alpha <= 1);
+  dh::device_vector<size_t> sorted_idx;
+  dh::SegmentedArgSort(seg_beg, seg_end, val_begin, val_end, &sorted_idx);
+  auto d_sorted_idx = dh::ToSpan(sorted_idx);
+  size_t n_weights = std::distance(w_begin, w_end);
+  dh::device_vector<float> weights_cdf(n_weights);
+
+  dh::XGBCachingDeviceAllocator<char> caching;
+  auto scan_key = dh::MakeTransformIterator<size_t>(
+      thrust::make_counting_iterator(0ul),
+      [=] XGBOOST_DEVICE(size_t i) { return dh::SegmentId(seg_beg, seg_end, i); });
+  auto scan_val = dh::MakeTransformIterator<float>(
+      thrust::make_counting_iterator(0ul),
+      [=] XGBOOST_DEVICE(size_t i) { return w_begin[d_sorted_idx[i]]; });
+  thrust::inclusive_scan_by_key(thrust::cuda::par(caching), scan_key, scan_key + n_weights,
+                                scan_val, weights_cdf.begin());
+
+  auto n_segments = std::distance(seg_beg, seg_end) - 1;
+  quantiles->SetDevice(ctx->gpu_id);
+  quantiles->Resize(n_segments);
+  auto d_results = quantiles->DeviceSpan();
+  auto d_weight_cdf = dh::ToSpan(weights_cdf);
+
+  dh::LaunchN(n_segments, [=] XGBOOST_DEVICE(size_t i) {
+    size_t seg_idx = i;
+    size_t begin = seg_beg[seg_idx];
+    auto n = static_cast<double>(seg_beg[seg_idx + 1] - begin);
+    if (n == 0) {
+      d_results[i] = std::numeric_limits<float>::quiet_NaN();
+      return;
+    }
+    auto leaf_cdf = d_weight_cdf.subspan(begin, static_cast<size_t>(n));
+    auto leaf_sorted_idx = d_sorted_idx.subspan(begin, static_cast<size_t>(n));
+    float thresh = leaf_cdf.back() * alpha;
+
+    size_t idx = thrust::lower_bound(thrust::seq, leaf_cdf.data(),
+                                     leaf_cdf.data() + leaf_cdf.size(), thresh) -
+                 leaf_cdf.data();
+    idx = std::min(idx, static_cast<size_t>(n - 1));
+    d_results[i] = val_begin[leaf_sorted_idx[idx]];
+  });
+}
+}  // namespace common
+}  // namespace xgboost
+#endif  // XGBOOST_COMMON_STATS_CUH_
diff --git a/src/common/stats.h b/src/common/stats.h
new file mode 100644
index 000000000000..4ad9e4aa770a
--- /dev/null
+++ b/src/common/stats.h
@@ -0,0 +1,95 @@
+/*!
+ * Copyright 2022 by XGBoost Contributors
+ */
+#ifndef XGBOOST_COMMON_STATS_H_
+#define XGBOOST_COMMON_STATS_H_
+#include
+#include
+#include
+#include
+
+#include "common.h"
+#include "xgboost/linalg.h"
+
+namespace xgboost {
+namespace common {
+
+/**
+ * \brief Quantile using linear interpolation.
+ *
+ *   https://www.itl.nist.gov/div898/handbook/prc/section2/prc262.htm
+ *
+ * \param alpha Quantile, must be in range [0, 1].
+ * \param begin Iterator begin for input array.
+ * \param end   Iterator end for input array.
+ *
+ * \return The result of interpolation.
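+ *
+ * A worked example (illustrative): for values {1, 2, 3, 4} and alpha = 0.5,
+ * x = 0.5 * 5 = 2.5, k = 1, d = 0.5, giving 2 + 0.5 * (3 - 2) = 2.5 -- the
+ * familiar median.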
+ */
+template <typename Iter>
+float Quantile(double alpha, Iter const& begin, Iter const& end) {
+  CHECK(alpha >= 0 && alpha <= 1);
+  auto n = static_cast<double>(std::distance(begin, end));
+  if (n == 0) {
+    return std::numeric_limits<float>::quiet_NaN();
+  }
+
+  std::vector<size_t> sorted_idx(n);
+  std::iota(sorted_idx.begin(), sorted_idx.end(), 0);
+  std::stable_sort(sorted_idx.begin(), sorted_idx.end(),
+                   [&](size_t l, size_t r) { return *(begin + l) < *(begin + r); });
+
+  auto val = [&](size_t i) { return *(begin + sorted_idx[i]); };
+  static_assert(std::is_same<decltype(val(0)), float>::value, "");
+
+  if (alpha <= (1 / (n + 1))) {
+    return val(0);
+  }
+  if (alpha >= (n / (n + 1))) {
+    return val(sorted_idx.size() - 1);
+  }
+  assert(n != 0 && "The number of rows in a leaf can not be zero.");
+  double x = alpha * static_cast<double>(n + 1);
+  double k = std::floor(x) - 1;
+  CHECK_GE(k, 0);
+  double d = (x - 1) - k;
+
+  auto v0 = val(static_cast<size_t>(k));
+  auto v1 = val(static_cast<size_t>(k) + 1);
+  return v0 + d * (v1 - v0);
+}
+
+/**
+ * \brief Calculate the weighted quantile with step function. Unlike the unweighted
+ *        version, no interpolation is used.
+ *
+ * See https://aakinshin.net/posts/weighted-quantiles/ for some discussion on computing
+ * weighted quantiles with interpolation.
+ */
+template <typename Iter, typename WeightIter>
+float WeightedQuantile(double alpha, Iter begin, Iter end, WeightIter weights) {
+  auto n = static_cast<double>(std::distance(begin, end));
+  if (n == 0) {
+    return std::numeric_limits<float>::quiet_NaN();
+  }
+  std::vector<size_t> sorted_idx(n);
+  std::iota(sorted_idx.begin(), sorted_idx.end(), 0);
+  std::stable_sort(sorted_idx.begin(), sorted_idx.end(),
+                   [&](size_t l, size_t r) { return *(begin + l) < *(begin + r); });
+
+  auto val = [&](size_t i) { return *(begin + sorted_idx[i]); };
+
+  std::vector<float> weight_cdf(n);  // S_n
+  // weighted cdf is sorted during construction
+  weight_cdf[0] = *(weights + sorted_idx[0]);
+  for (size_t i = 1; i < n; ++i) {
+    weight_cdf[i] = weight_cdf[i - 1] + *(weights + sorted_idx[i]);
+  }
+  float thresh = weight_cdf.back() * alpha;
+  size_t idx =
+      std::lower_bound(weight_cdf.cbegin(), weight_cdf.cend(), thresh) - weight_cdf.cbegin();
+  idx = std::min(idx, static_cast<size_t>(n - 1));
+  return val(idx);
+}
+}  // namespace common
+}  // namespace xgboost
+#endif  // XGBOOST_COMMON_STATS_H_
diff --git a/src/data/data.cc b/src/data/data.cc
index 86f73523a39d..c297527c6bae 100644
--- a/src/data/data.cc
+++ b/src/data/data.cc
@@ -512,16 +512,7 @@ void MetaInfo::SetInfoFromHost(Context const& ctx, StringView key, Json arr) {
       }
     }
     CHECK(non_dec) << "`qid` must be sorted in non-decreasing order along with data.";
-    group_ptr_.clear();
-    group_ptr_.push_back(0);
-    for (size_t i = 1; i < query_ids.size(); ++i) {
-      if (query_ids[i] != query_ids[i - 1]) {
-        group_ptr_.push_back(i);
-      }
-    }
-    if (group_ptr_.back() != query_ids.size()) {
-      group_ptr_.push_back(query_ids.size());
-    }
+    common::RunLengthEncode(query_ids.cbegin(), query_ids.cend(), &group_ptr_);
     data::ValidateQueryGroup(group_ptr_);
     return;
   }
diff --git a/src/data/iterative_device_dmatrix.h b/src/data/iterative_device_dmatrix.h
index ba2d4a92f9da..031b289f2760 100644
--- a/src/data/iterative_device_dmatrix.h
+++ b/src/data/iterative_device_dmatrix.h
@@ -68,7 +68,7 @@ class IterativeDeviceDMatrix : public DMatrix {
 
   BatchSet<EllpackPage> GetEllpackBatches(const BatchParam& param) override;
 
-  bool SingleColBlock() const override { return false; }
+  bool SingleColBlock() const override { return true; }
 
   MetaInfo &Info() override { return info_; }
   MetaInfo const &Info() const override { return info_; }
diff --git a/src/gbm/gblinear.cc b/src/gbm/gblinear.cc
index cbf6ffebfca5..0e983fe4b37f 100644
--- a/src/gbm/gblinear.cc
+++ b/src/gbm/gblinear.cc
@@ -134,9 +134,8 @@ class GBLinear : public GradientBooster {
     this->updater_->SaveConfig(&j_updater);
   }
 
-  void DoBoost(DMatrix *p_fmat,
-               HostDeviceVector<GradientPair> *in_gpair,
-               PredictionCacheEntry*) override {
+  void DoBoost(DMatrix* p_fmat, HostDeviceVector<GradientPair>* in_gpair, PredictionCacheEntry*,
+               ObjFunction const*) override {
     monitor_.Start("DoBoost");
     model_.LazyInitModel();
diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc
index ec611ee95a68..bb7c341f8beb 100644
--- a/src/gbm/gbtree.cc
+++ b/src/gbm/gbtree.cc
@@ -1,33 +1,34 @@
 /*!
- * Copyright 2014-2021 by Contributors
+ * Copyright 2014-2022 by Contributors
  * \file gbtree.cc
  * \brief gradient boosted tree implementation.
  * \author Tianqi Chen
 */
+#include "gbtree.h"
+
 #include
 #include
-#include
+#include
+#include
 #include
-#include
 #include
-#include
-#include
+#include
+#include
 
+#include "../common/common.h"
+#include "../common/random.h"
+#include "../common/threading_utils.h"
+#include "../common/timer.h"
+#include "gbtree_model.h"
 #include "xgboost/data.h"
 #include "xgboost/gbm.h"
-#include "xgboost/logging.h"
+#include "xgboost/host_device_vector.h"
 #include "xgboost/json.h"
+#include "xgboost/logging.h"
+#include "xgboost/objective.h"
 #include "xgboost/predictor.h"
 #include "xgboost/tree_updater.h"
-#include "xgboost/host_device_vector.h"
-
-#include "gbtree.h"
-#include "gbtree_model.h"
-#include "../common/common.h"
-#include "../common/random.h"
-#include "../common/timer.h"
-#include "../common/threading_utils.h"
 
 namespace xgboost {
 namespace gbm {
@@ -216,53 +217,68 @@ void CopyGradient(HostDeviceVector<GradientPair> const* in_gpair, int32_t n_thre
   }
 }
 
-void GBTree::DoBoost(DMatrix* p_fmat,
-                     HostDeviceVector<GradientPair>* in_gpair,
-                     PredictionCacheEntry* predt) {
-  std::vector<std::vector<std::unique_ptr<RegTree> > > new_trees;
+void GBTree::UpdateTreeLeaf(DMatrix const* p_fmat, HostDeviceVector<float> const& predictions,
+                            ObjFunction const* obj, size_t gidx,
+                            std::vector<std::unique_ptr<RegTree>>* p_trees) {
+  CHECK(!updaters_.empty());
+  if (!updaters_.back()->HasNodePosition()) {
+    return;
+  }
+  if (!obj || !obj->Task().UpdateTreeLeaf()) {
+    return;
+  }
+  auto& trees = *p_trees;
+  for (size_t tree_idx = 0; tree_idx < trees.size(); ++tree_idx) {
+    auto const& position = this->node_position_.at(tree_idx);
+    obj->UpdateTreeLeaf(position, p_fmat->Info(), predictions, trees[tree_idx].get());
+  }
+}
+
+void GBTree::DoBoost(DMatrix* p_fmat, HostDeviceVector<GradientPair>* in_gpair,
+                     PredictionCacheEntry* predt, ObjFunction const* obj) {
+  std::vector<std::vector<std::unique_ptr<RegTree>>> new_trees;
   const int ngroup = model_.learner_model_param->num_output_group;
   ConfigureWithKnownData(this->cfg_, p_fmat);
   monitor_.Start("BoostNewTrees");
   // Weird case that tree method is cpu-based but gpu_id is set. Ideally we should let
   // `gpu_id` be the single source of determining what algorithms to run, but that will
   // break a lot of existing code.
-  auto device = tparam_.tree_method != TreeMethod::kGPUHist
-                    ? GenericParameter::kCpuId
-                    : ctx_->gpu_id;
+  auto device = tparam_.tree_method != TreeMethod::kGPUHist ? Context::kCpuId : ctx_->gpu_id;
   auto out = linalg::TensorView<float, 2>{
-      device == GenericParameter::kCpuId ? predt->predictions.HostSpan()
-                                         : predt->predictions.DeviceSpan(),
-      {static_cast<size_t>(p_fmat->Info().num_row_),
-       static_cast<size_t>(ngroup)},
+      device == Context::kCpuId ?
predt->predictions.HostSpan() : predt->predictions.DeviceSpan(), + {static_cast(p_fmat->Info().num_row_), static_cast(ngroup)}, device}; CHECK_NE(ngroup, 0); + + if (!p_fmat->SingleColBlock() && obj->Task().UpdateTreeLeaf()) { + LOG(FATAL) << "Current objective doesn't support external memory."; + } + if (ngroup == 1) { std::vector> ret; BoostNewTrees(in_gpair, p_fmat, 0, &ret); + UpdateTreeLeaf(p_fmat, predt->predictions, obj, 0, &ret); const size_t num_new_trees = ret.size(); new_trees.push_back(std::move(ret)); auto v_predt = out.Slice(linalg::All(), 0); - if (updaters_.size() > 0 && num_new_trees == 1 && - predt->predictions.Size() > 0 && + if (updaters_.size() > 0 && num_new_trees == 1 && predt->predictions.Size() > 0 && updaters_.back()->UpdatePredictionCache(p_fmat, v_predt)) { predt->Update(1); } } else { - CHECK_EQ(in_gpair->Size() % ngroup, 0U) - << "must have exactly ngroup * nrow gpairs"; - HostDeviceVector tmp(in_gpair->Size() / ngroup, - GradientPair(), + CHECK_EQ(in_gpair->Size() % ngroup, 0U) << "must have exactly ngroup * nrow gpairs"; + HostDeviceVector tmp(in_gpair->Size() / ngroup, GradientPair(), in_gpair->DeviceIdx()); bool update_predict = true; for (int gid = 0; gid < ngroup; ++gid) { CopyGradient(in_gpair, ctx_->Threads(), ngroup, gid, &tmp); - std::vector > ret; + std::vector> ret; BoostNewTrees(&tmp, p_fmat, gid, &ret); + UpdateTreeLeaf(p_fmat, predt->predictions, obj, gid, &ret); const size_t num_new_trees = ret.size(); new_trees.push_back(std::move(ret)); auto v_predt = out.Slice(linalg::All(), gid); - if (!(updaters_.size() > 0 && predt->predictions.Size() > 0 && - num_new_trees == 1 && + if (!(updaters_.size() > 0 && predt->predictions.Size() > 0 && num_new_trees == 1 && updaters_.back()->UpdatePredictionCache(p_fmat, v_predt))) { update_predict = false; } @@ -271,6 +287,7 @@ void GBTree::DoBoost(DMatrix* p_fmat, predt->Update(1); } } + monitor_.Stop("BoostNewTrees"); this->CommitModel(std::move(new_trees), p_fmat, predt); } @@ -316,10 +333,8 @@ void GBTree::InitUpdater(Args const& cfg) { } } -void GBTree::BoostNewTrees(HostDeviceVector* gpair, - DMatrix *p_fmat, - int bst_group, - std::vector >* ret) { +void GBTree::BoostNewTrees(HostDeviceVector* gpair, DMatrix* p_fmat, int bst_group, + std::vector>* ret) { std::vector new_trees; ret->clear(); // create the trees @@ -338,9 +353,9 @@ void GBTree::BoostNewTrees(HostDeviceVector* gpair, } else if (tparam_.process_type == TreeProcessType::kUpdate) { for (auto const& up : updaters_) { CHECK(up->CanModifyTree()) - << "Updater: `" << up->Name() << "` " - << "can not be used to modify existing trees. " - << "Set `process_type` to `default` if you want to build new trees."; + << "Updater: `" << up->Name() << "` " + << "can not be used to modify existing trees. " + << "Set `process_type` to `default` if you want to build new trees."; } CHECK_LT(model_.trees.size(), model_.trees_to_update.size()) << "No more tree left for updating. For updating existing trees, " @@ -356,8 +371,10 @@ void GBTree::BoostNewTrees(HostDeviceVector* gpair, CHECK_EQ(gpair->Size(), p_fmat->Info().num_row_) << "Mismatching size between number of rows from input data and size of " "gradient vector."; + node_position_.resize(new_trees.size()); for (auto& up : updaters_) { - up->Update(gpair, p_fmat, new_trees); + up->Update(gpair, p_fmat, common::Span>{node_position_}, + new_trees); } } diff --git a/src/gbm/gbtree.h b/src/gbm/gbtree.h index 67d9e212888a..020b7d0cb9c0 100644 --- a/src/gbm/gbtree.h +++ b/src/gbm/gbtree.h @@ -1,5 +1,5 @@ /*! 
- * Copyright 2014-2021 by Contributors + * Copyright 2014-2022 by Contributors * \file gbtree.cc * \brief gradient boosted tree implementation. * \author Tianqi Chen @@ -202,10 +202,16 @@ class GBTree : public GradientBooster { void ConfigureUpdaters(); void ConfigureWithKnownData(Args const& cfg, DMatrix* fmat); + /** + * \brief Optionally update the leaf value. + */ + void UpdateTreeLeaf(DMatrix const* p_fmat, HostDeviceVector const& predictions, + ObjFunction const* obj, size_t gidx, + std::vector>* p_trees); + /*! \brief Carry out one iteration of boosting */ - void DoBoost(DMatrix* p_fmat, - HostDeviceVector* in_gpair, - PredictionCacheEntry* predt) override; + void DoBoost(DMatrix* p_fmat, HostDeviceVector* in_gpair, + PredictionCacheEntry* predt, ObjFunction const* obj) override; bool UseGPU() const override { return @@ -435,6 +441,9 @@ class GBTree : public GradientBooster { Args cfg_; // the updaters that can be applied to each of tree std::vector> updaters_; + // The node position for each row, 1 HDV for each tree in the forest. Note that the + // position is negated if the row is sampled out. + std::vector> node_position_; // Predictors std::unique_ptr cpu_predictor_; #if defined(XGBOOST_USE_CUDA) diff --git a/src/learner.cc b/src/learner.cc index 73447cf2ef1a..1fc987d65427 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -419,6 +419,7 @@ class LearnerConfiguration : public Learner { obj_.reset(ObjFunction::Create(tparam_.objective, &generic_parameters_)); } obj_->LoadConfig(objective_fn); + learner_model_param_.task = obj_->Task(); tparam_.booster = get(gradient_booster["name"]); if (!gbm_) { @@ -1168,7 +1169,7 @@ class LearnerImpl : public LearnerIO { monitor_.Stop("GetGradient"); TrainingObserver::Instance().Observe(gpair_, "Gradients"); - gbm_->DoBoost(train.get(), &gpair_, &predt); + gbm_->DoBoost(train.get(), &gpair_, &predt, obj_.get()); monitor_.Stop("UpdateOneIter"); } @@ -1185,7 +1186,7 @@ class LearnerImpl : public LearnerIO { auto local_cache = this->GetPredictionCache(); local_cache->Cache(train, generic_parameters_.gpu_id); - gbm_->DoBoost(train.get(), in_gpair, &local_cache->Entry(train.get())); + gbm_->DoBoost(train.get(), in_gpair, &local_cache->Entry(train.get()), obj_.get()); monitor_.Stop("BoostOneIter"); } diff --git a/src/metric/auc.cu b/src/metric/auc.cu index be89c015c93d..5faa116c8561 100644 --- a/src/metric/auc.cu +++ b/src/metric/auc.cu @@ -1,5 +1,5 @@ /*! - * Copyright 2021 by XGBoost Contributors + * Copyright 2021-2022 by XGBoost Contributors */ #include #include @@ -201,14 +201,6 @@ void Transpose(common::Span in, common::Span out, size_t m, }); } -/** - * Last index of a group in a CSR style of index pointer. 
- */ -template -XGBOOST_DEVICE size_t LastOf(size_t group, common::Span indptr) { - return indptr[group + 1] - 1; -} - double ScaleClasses(common::Span results, common::Span local_area, common::Span fp, common::Span tp, common::Span auc, @@ -300,9 +292,9 @@ void SegmentedReduceAUC(common::Span d_unique_idx, double fp, tp, fp_prev, tp_prev; if (i == d_unique_class_ptr[class_id]) { // first item is ignored, we use this thread to calculate the last item - thrust::tie(fp, tp) = d_fptp[LastOf(class_id, d_class_ptr)]; + thrust::tie(fp, tp) = d_fptp[common::LastOf(class_id, d_class_ptr)]; thrust::tie(fp_prev, tp_prev) = - d_neg_pos[d_unique_idx[LastOf(class_id, d_unique_class_ptr)]]; + d_neg_pos[d_unique_idx[common::LastOf(class_id, d_unique_class_ptr)]]; } else { thrust::tie(fp, tp) = d_fptp[d_unique_idx[i] - 1]; thrust::tie(fp_prev, tp_prev) = d_neg_pos[d_unique_idx[i - 1]]; @@ -413,10 +405,10 @@ double GPUMultiClassAUCOVR(common::Span predts, } uint32_t class_id = d_unique_idx[i] / n_samples; d_neg_pos[d_unique_idx[i]] = d_fptp[d_unique_idx[i] - 1]; - if (i == LastOf(class_id, d_unique_class_ptr)) { + if (i == common::LastOf(class_id, d_unique_class_ptr)) { // last one needs to be included. - size_t last = d_unique_idx[LastOf(class_id, d_unique_class_ptr)]; - d_neg_pos[LastOf(class_id, d_class_ptr)] = d_fptp[last - 1]; + size_t last = d_unique_idx[common::LastOf(class_id, d_unique_class_ptr)]; + d_neg_pos[common::LastOf(class_id, d_class_ptr)] = d_fptp[last - 1]; return; } }); @@ -592,7 +584,7 @@ GPURankingAUC(common::Span predts, MetaInfo const &info, auto data_group_begin = d_group_ptr[group_id]; size_t n_samples = d_group_ptr[group_id + 1] - data_group_begin; // last item of current group - if (item.idx == LastOf(group_id, d_threads_group_ptr)) { + if (item.idx == common::LastOf(group_id, d_threads_group_ptr)) { if (item.w > 0) { s_d_auc[group_id] = item.predt / item.w; } else { @@ -797,10 +789,10 @@ GPURankingPRAUCImpl(common::Span predts, MetaInfo const &info, } auto group_idx = dh::SegmentId(d_group_ptr, d_unique_idx[i]); d_neg_pos[d_unique_idx[i]] = d_fptp[d_unique_idx[i] - 1]; - if (i == LastOf(group_idx, d_unique_class_ptr)) { + if (i == common::LastOf(group_idx, d_unique_class_ptr)) { // last one needs to be included. - size_t last = d_unique_idx[LastOf(group_idx, d_unique_class_ptr)]; - d_neg_pos[LastOf(group_idx, d_group_ptr)] = d_fptp[last - 1]; + size_t last = d_unique_idx[common::LastOf(group_idx, d_unique_class_ptr)]; + d_neg_pos[common::LastOf(group_idx, d_group_ptr)] = d_fptp[last - 1]; return; } }); @@ -821,7 +813,7 @@ GPURankingPRAUCImpl(common::Span predts, MetaInfo const &info, auto it = dh::MakeTransformIterator>( thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(size_t g) { double fp, tp; - thrust::tie(fp, tp) = d_fptp[LastOf(g, d_group_ptr)]; + thrust::tie(fp, tp) = d_fptp[common::LastOf(g, d_group_ptr)]; double area = fp * tp; auto n_documents = d_group_ptr[g + 1] - d_group_ptr[g]; if (area > 0 && n_documents >= 2) { diff --git a/src/objective/adaptive.cc b/src/objective/adaptive.cc new file mode 100644 index 000000000000..f2675d918bdf --- /dev/null +++ b/src/objective/adaptive.cc @@ -0,0 +1,126 @@ +/*! 
+ * Copyright 2022 by XGBoost Contributors
+ */
+#include "adaptive.h"
+
+#include
+#include
+
+#include "../common/common.h"
+#include "../common/stats.h"
+#include "../common/threading_utils.h"
+#include "xgboost/tree_model.h"
+
+namespace xgboost {
+namespace obj {
+namespace detail {
+void EncodeTreeLeafHost(RegTree const& tree, std::vector<bst_node_t> const& position,
+                        std::vector<size_t>* p_nptr, std::vector<bst_node_t>* p_nidx,
+                        std::vector<size_t>* p_ridx) {
+  auto& nptr = *p_nptr;
+  auto& nidx = *p_nidx;
+  auto& ridx = *p_ridx;
+  ridx = common::ArgSort<size_t>(position);
+  std::vector<bst_node_t> sorted_pos(position);
+  // permutation
+  for (size_t i = 0; i < position.size(); ++i) {
+    sorted_pos[i] = position[ridx[i]];
+  }
+  // find the first non-sampled row
+  auto begin_pos =
+      std::distance(sorted_pos.cbegin(), std::find_if(sorted_pos.cbegin(), sorted_pos.cend(),
+                                                      [](bst_node_t nidx) { return nidx >= 0; }));
+  CHECK_LE(begin_pos, sorted_pos.size());
+
+  std::vector<bst_node_t> leaf;
+  tree.WalkTree([&](bst_node_t nidx) {
+    if (tree[nidx].IsLeaf()) {
+      leaf.push_back(nidx);
+    }
+    return true;
+  });
+
+  if (begin_pos == sorted_pos.size()) {
+    nidx = leaf;
+    return;
+  }
+
+  auto beg_it = sorted_pos.begin() + begin_pos;
+  common::RunLengthEncode(beg_it, sorted_pos.end(), &nptr);
+  CHECK_GT(nptr.size(), 0);
+  // skip the sampled rows in indptr
+  std::transform(nptr.begin(), nptr.end(), nptr.begin(),
+                 [begin_pos](size_t ptr) { return ptr + begin_pos; });
+
+  size_t n_leaf = nptr.size() - 1;
+  auto n_unique = std::unique(beg_it, sorted_pos.end()) - beg_it;
+  CHECK_EQ(n_unique, n_leaf);
+  nidx.resize(n_leaf);
+  std::copy(beg_it, beg_it + n_unique, nidx.begin());
+
+  if (n_leaf != leaf.size()) {
+    FillMissingLeaf(leaf, &nidx, &nptr);
+  }
+}
+
+void UpdateTreeLeafHost(Context const* ctx, std::vector<bst_node_t> const& position,
+                        MetaInfo const& info, HostDeviceVector<float> const& predt, float alpha,
+                        RegTree* p_tree) {
+  auto& tree = *p_tree;
+
+  std::vector<bst_node_t> nidx;
+  std::vector<size_t> nptr;
+  std::vector<size_t> ridx;
+  EncodeTreeLeafHost(*p_tree, position, &nptr, &nidx, &ridx);
+  size_t n_leaf = nidx.size();
+  if (nptr.empty()) {
+    std::vector<float> quantiles;
+    UpdateLeafValues(&quantiles, nidx, p_tree);
+    return;
+  }
+
+  CHECK(!position.empty());
+  std::vector<float> quantiles(n_leaf, 0);
+  std::vector<int32_t> n_valids(n_leaf, 0);
+
+  auto const& h_node_idx = nidx;
+  auto const& h_node_ptr = nptr;
+  CHECK_LE(h_node_ptr.back(), info.num_row_);
+  // loop over each leaf
+  common::ParallelFor(quantiles.size(), ctx->Threads(), [&](size_t k) {
+    auto nidx = h_node_idx[k];
+    CHECK(tree[nidx].IsLeaf());
+    CHECK_LT(k + 1, h_node_ptr.size());
+    size_t n = h_node_ptr[k + 1] - h_node_ptr[k];
+    auto h_row_set = common::Span<size_t const>{ridx}.subspan(h_node_ptr[k], n);
+    // multi-target not yet supported.
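+    // For each leaf, the residuals label - prediction over the rows that landed in it
+    // are fed to the (weighted) quantile below; alpha = 0.5 yields the median, the
+    // optimal constant under absolute error.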
+    auto h_labels = info.labels.HostView().Slice(linalg::All(), 0);
+    auto const& h_predt = predt.ConstHostVector();
+    auto h_weights = linalg::MakeVec(&info.weights_);
+
+    auto iter = common::MakeIndexTransformIter([&](size_t i) -> float {
+      auto row_idx = h_row_set[i];
+      return h_labels(row_idx) - h_predt[row_idx];
+    });
+    auto w_it = common::MakeIndexTransformIter([&](size_t i) -> float {
+      auto row_idx = h_row_set[i];
+      return h_weights(row_idx);
+    });
+
+    float q{0};
+    if (info.weights_.Empty()) {
+      q = common::Quantile(alpha, iter, iter + h_row_set.size());
+    } else {
+      q = common::WeightedQuantile(alpha, iter, iter + h_row_set.size(), w_it);
+    }
+    if (std::isnan(q)) {
+      CHECK(h_row_set.empty());
+    }
+    quantiles.at(k) = q;
+  });
+
+  UpdateLeafValues(&quantiles, nidx, p_tree);
+}
+}  // namespace detail
+}  // namespace obj
+}  // namespace xgboost
diff --git a/src/objective/adaptive.cu b/src/objective/adaptive.cu
new file mode 100644
index 000000000000..42d239acd977
--- /dev/null
+++ b/src/objective/adaptive.cu
@@ -0,0 +1,182 @@
+/*!
+ * Copyright 2022 by XGBoost Contributors
+ */
+#include
+
+#include
+
+#include "../common/device_helpers.cuh"
+#include "../common/stats.cuh"
+#include "adaptive.h"
+
+namespace xgboost {
+namespace obj {
+namespace detail {
+void EncodeTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> position,
+                          dh::device_vector<size_t>* p_ridx, HostDeviceVector<size_t>* p_nptr,
+                          HostDeviceVector<bst_node_t>* p_nidx, RegTree const& tree) {
+  // copy position to buffer
+  dh::safe_cuda(cudaSetDevice(ctx->gpu_id));
+  size_t n_samples = position.size();
+  dh::XGBDeviceAllocator<char> alloc;
+  dh::device_vector<bst_node_t> sorted_position(position.size());
+  dh::safe_cuda(cudaMemcpyAsync(sorted_position.data().get(), position.data(),
+                                position.size_bytes(), cudaMemcpyDeviceToDevice));
+
+  p_ridx->resize(position.size());
+  dh::Iota(dh::ToSpan(*p_ridx));
+  // sort row index according to node index
+  thrust::stable_sort_by_key(thrust::cuda::par(alloc), sorted_position.begin(),
+                             sorted_position.begin() + n_samples, p_ridx->begin());
+  dh::XGBCachingDeviceAllocator<char> caching;
+  auto beg_pos =
+      thrust::find_if(thrust::cuda::par(caching), sorted_position.cbegin(), sorted_position.cend(),
+                      [] XGBOOST_DEVICE(bst_node_t nidx) { return nidx >= 0; }) -
+      sorted_position.cbegin();
+  if (beg_pos == sorted_position.size()) {
+    auto& leaf = p_nidx->HostVector();
+    tree.WalkTree([&](bst_node_t nidx) {
+      if (tree[nidx].IsLeaf()) {
+        leaf.push_back(nidx);
+      }
+      return true;
+    });
+    return;
+  }
+
+  size_t n_leaf = tree.GetNumLeaves();
+  size_t max_n_unique = n_leaf;
+
+  dh::caching_device_vector<size_t> counts_out(max_n_unique + 1, 0);
+  auto d_counts_out = dh::ToSpan(counts_out).subspan(0, max_n_unique);
+  auto d_num_runs_out = dh::ToSpan(counts_out).subspan(max_n_unique, 1);
+  dh::caching_device_vector<bst_node_t> unique_out(max_n_unique, 0);
+  auto d_unique_out = dh::ToSpan(unique_out);
+
+  size_t nbytes;
+  auto begin_it = sorted_position.begin() + beg_pos;
+  cub::DeviceRunLengthEncode::Encode(nullptr, nbytes, begin_it, unique_out.data().get(),
+                                     counts_out.data().get(), d_num_runs_out.data(),
+                                     n_samples - beg_pos);
+  dh::TemporaryArray<char> temp(nbytes);
+  cub::DeviceRunLengthEncode::Encode(temp.data().get(), nbytes, begin_it, unique_out.data().get(),
+                                     counts_out.data().get(), d_num_runs_out.data(),
+                                     n_samples - beg_pos);
+
+  dh::PinnedMemory pinned_pool;
+  auto pinned = pinned_pool.GetSpan<char>(sizeof(size_t) + sizeof(bst_node_t));
+  dh::CUDAStream copy_stream;
+  size_t* h_num_runs = reinterpret_cast<size_t*>(pinned.subspan(0, sizeof(size_t)).data());
+  // flag for whether there are ignored positions
+  bst_node_t* h_first_unique =
+      reinterpret_cast<bst_node_t*>(pinned.subspan(sizeof(size_t), sizeof(bst_node_t)).data());
+  dh::safe_cuda(cudaMemcpyAsync(h_num_runs, d_num_runs_out.data(), sizeof(size_t),
+                                cudaMemcpyDeviceToHost, copy_stream.View()));
+  dh::safe_cuda(cudaMemcpyAsync(h_first_unique, d_unique_out.data(), sizeof(bst_node_t),
+                                cudaMemcpyDeviceToHost, copy_stream.View()));
+
+  /**
+   * copy node index (leaf index)
+   */
+  auto& nidx = *p_nidx;
+  auto& nptr = *p_nptr;
+  nidx.SetDevice(ctx->gpu_id);
+  nidx.Resize(n_leaf);
+  auto d_node_idx = nidx.DeviceSpan();
+
+  nptr.SetDevice(ctx->gpu_id);
+  nptr.Resize(n_leaf + 1, 0);
+  auto d_node_ptr = nptr.DeviceSpan();
+
+  dh::LaunchN(n_leaf, [=] XGBOOST_DEVICE(size_t i) {
+    if (i >= d_num_runs_out[0]) {
+      // d_num_runs_out <= max_n_unique
+      // this omits all the leaves that are empty. A leaf can be empty when there's
+      // missing data, which can be caused by sparse input and distributed training.
+      return;
+    }
+    d_node_idx[i] = d_unique_out[i];
+    d_node_ptr[i + 1] = d_counts_out[i];
+    if (i == 0) {
+      d_node_ptr[0] = beg_pos;
+    }
+  });
+  thrust::inclusive_scan(thrust::cuda::par(caching), dh::tbegin(d_node_ptr), dh::tend(d_node_ptr),
+                         dh::tbegin(d_node_ptr));
+  copy_stream.View().Sync();
+  CHECK_GT(*h_num_runs, 0);
+  CHECK_LE(*h_num_runs, n_leaf);
+
+  if (*h_num_runs < n_leaf) {
+    // shrink to the number of leaves that actually appear in position.
+    nptr.Resize(*h_num_runs + 1);
+    nidx.Resize(*h_num_runs);
+
+    std::vector<bst_node_t> leaves;
+    tree.WalkTree([&](bst_node_t nidx) {
+      if (tree[nidx].IsLeaf()) {
+        leaves.push_back(nidx);
+      }
+      return true;
+    });
+    CHECK_EQ(leaves.size(), n_leaf);
+    // Fill all the leaves that don't have any sample. This is hacky and inefficient. An
+    // alternative is to let the objective handle missing leaves, which is messier as we
+    // need to take other distributed workers into account.
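+    // For example (hypothetical indices): if the tree has leaves {1, 4, 5} but only
+    // {1, 5} received rows, FillMissingLeaf inserts node 4 with an empty row range so
+    // that nidx/nptr stay aligned across workers.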
+    auto& h_nidx = nidx.HostVector();
+    auto& h_nptr = nptr.HostVector();
+    FillMissingLeaf(leaves, &h_nidx, &h_nptr);
+    nidx.DevicePointer();
+    nptr.DevicePointer();
+  }
+  CHECK_EQ(nidx.Size(), n_leaf);
+  CHECK_EQ(nptr.Size(), n_leaf + 1);
+}
+
+void UpdateTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> position,
+                          MetaInfo const& info, HostDeviceVector<float> const& predt, float alpha,
+                          RegTree* p_tree) {
+  dh::safe_cuda(cudaSetDevice(ctx->gpu_id));
+  dh::device_vector<size_t> ridx;
+  HostDeviceVector<size_t> nptr;
+  HostDeviceVector<bst_node_t> nidx;
+
+  EncodeTreeLeafDevice(ctx, position, &ridx, &nptr, &nidx, *p_tree);
+
+  if (nptr.Empty()) {
+    std::vector<float> quantiles;
+    UpdateLeafValues(&quantiles, nidx.ConstHostVector(), p_tree);
+    return;
+  }
+
+  HostDeviceVector<float> quantiles;
+  predt.SetDevice(ctx->gpu_id);
+  auto d_predt = predt.ConstDeviceSpan();
+  auto d_labels = info.labels.View(ctx->gpu_id);
+
+  auto d_row_index = dh::ToSpan(ridx);
+  auto seg_beg = nptr.DevicePointer();
+  auto seg_end = seg_beg + nptr.Size();
+  auto val_beg = dh::MakeTransformIterator<float>(thrust::make_counting_iterator(0ul),
+                                                  [=] XGBOOST_DEVICE(size_t i) {
+                                                    auto predt = d_predt[d_row_index[i]];
+                                                    auto y = d_labels(d_row_index[i]);
+                                                    return y - predt;
+                                                  });
+  auto val_end = val_beg + d_labels.Size();
+  CHECK_EQ(nidx.Size() + 1, nptr.Size());
+  if (info.weights_.Empty()) {
+    common::SegmentedQuantile(ctx, alpha, seg_beg, seg_end, val_beg, val_end, &quantiles);
+  } else {
+    info.weights_.SetDevice(ctx->gpu_id);
+    auto d_weights = info.weights_.ConstDeviceSpan();
+    CHECK_EQ(d_weights.size(), d_row_index.size());
+    auto w_it = thrust::make_permutation_iterator(dh::tcbegin(d_weights), dh::tcbegin(d_row_index));
+    common::SegmentedWeightedQuantile(ctx, alpha, seg_beg, seg_end, val_beg, val_end, w_it,
+                                      w_it + d_weights.size(), &quantiles);
+  }
+
+  UpdateLeafValues(&quantiles.HostVector(), nidx.ConstHostVector(), p_tree);
+}
+}  // namespace detail
+}  // namespace obj
+}  // namespace xgboost
diff --git a/src/objective/adaptive.h b/src/objective/adaptive.h
new file mode 100644
index 000000000000..85c041347cb9
--- /dev/null
+++ b/src/objective/adaptive.h
@@ -0,0 +1,83 @@
+/*!
+ * Copyright 2022 by XGBoost Contributors
+ */
+#pragma once
+
+#include
+#include
+#include
+
+#include "rabit/rabit.h"
+#include "xgboost/generic_parameters.h"
+#include "xgboost/host_device_vector.h"
+#include "xgboost/tree_model.h"
+
+namespace xgboost {
+namespace obj {
+namespace detail {
+inline void FillMissingLeaf(std::vector<bst_node_t> const& maybe_missing,
+                            std::vector<bst_node_t>* p_nidx, std::vector<size_t>* p_nptr) {
+  auto& h_node_idx = *p_nidx;
+  auto& h_node_ptr = *p_nptr;
+
+  for (auto leaf : maybe_missing) {
+    if (std::binary_search(h_node_idx.cbegin(), h_node_idx.cend(), leaf)) {
+      continue;
+    }
+    auto it = std::upper_bound(h_node_idx.cbegin(), h_node_idx.cend(), leaf);
+    auto pos = it - h_node_idx.cbegin();
+    h_node_idx.insert(h_node_idx.cbegin() + pos, leaf);
+    h_node_ptr.insert(h_node_ptr.cbegin() + pos, h_node_ptr[pos]);
+  }
+}
+
+inline void UpdateLeafValues(std::vector<float>* p_quantiles, std::vector<bst_node_t> const& nidx,
+                             RegTree* p_tree) {
+  auto& tree = *p_tree;
+  auto& quantiles = *p_quantiles;
+  auto const& h_node_idx = nidx;
+
+  size_t n_leaf{h_node_idx.size()};
+  rabit::Allreduce<rabit::op::Max>(&n_leaf, 1);
+  CHECK(quantiles.empty() || quantiles.size() == n_leaf);
+  if (quantiles.empty()) {
+    quantiles.resize(n_leaf, std::numeric_limits<float>::quiet_NaN());
+  }
+
+  // number of workers that have valid quantiles for each leaf
+  std::vector<int32_t> n_valids(quantiles.size());
+  std::transform(quantiles.cbegin(), quantiles.cend(), n_valids.begin(),
+                 [](float q) { return static_cast<int32_t>(!std::isnan(q)); });
+  rabit::Allreduce<rabit::op::Sum>(n_valids.data(), n_valids.size());
+  // convert NaN to 0 for the all reduce
+  std::replace_if(
+      quantiles.begin(), quantiles.end(), [](float q) { return std::isnan(q); }, 0.f);
+  // use the mean value across workers
+  rabit::Allreduce<rabit::op::Sum>(quantiles.data(), quantiles.size());
+  for (size_t i = 0; i < n_leaf; ++i) {
+    if (n_valids[i] > 0) {
+      quantiles[i] /= static_cast<float>(n_valids[i]);
+    } else {
+      // Use the original leaf value if no worker can provide the quantile.
+ quantiles[i] = tree[h_node_idx[i]].LeafValue(); + } + } + + for (size_t i = 0; i < nidx.size(); ++i) { + auto nidx = h_node_idx[i]; + auto q = quantiles[i]; + CHECK(tree[nidx].IsLeaf()); + tree[nidx].SetLeaf(q); + } +} + +void UpdateTreeLeafDevice(Context const* ctx, common::Span position, + MetaInfo const& info, HostDeviceVector const& predt, float alpha, + RegTree* p_tree); + +void UpdateTreeLeafHost(Context const* ctx, std::vector const& position, + MetaInfo const& info, HostDeviceVector const& predt, float alpha, + RegTree* p_tree); +} // namespace detail +} // namespace obj +} // namespace xgboost diff --git a/src/objective/aft_obj.cu b/src/objective/aft_obj.cu index 0e2d9290f95c..5f2306dee082 100644 --- a/src/objective/aft_obj.cu +++ b/src/objective/aft_obj.cu @@ -34,11 +34,11 @@ DMLC_REGISTRY_FILE_TAG(aft_obj_gpu); class AFTObj : public ObjFunction { public: - void Configure(const std::vector >& args) override { + void Configure(Args const& args) override { param_.UpdateAllowUnknown(args); } - ObjInfo Task() const override { return {ObjInfo::kSurvival, false}; } + ObjInfo Task() const override { return ObjInfo::kSurvival; } template void GetGradientImpl(const HostDeviceVector &preds, diff --git a/src/objective/hinge.cu b/src/objective/hinge.cu index e1f0df74d4e1..e062b2b48e3c 100644 --- a/src/objective/hinge.cu +++ b/src/objective/hinge.cu @@ -24,10 +24,8 @@ class HingeObj : public ObjFunction { public: HingeObj() = default; - void Configure( - const std::vector > &args) override {} - - ObjInfo Task() const override { return {ObjInfo::kRegression, false}; } + void Configure(Args const&) override {} + ObjInfo Task() const override { return ObjInfo::kRegression; } void GetGradient(const HostDeviceVector &preds, const MetaInfo &info, diff --git a/src/objective/multiclass_obj.cu b/src/objective/multiclass_obj.cu index 4b912a81710d..312992ec59f2 100644 --- a/src/objective/multiclass_obj.cu +++ b/src/objective/multiclass_obj.cu @@ -46,7 +46,7 @@ class SoftmaxMultiClassObj : public ObjFunction { param_.UpdateAllowUnknown(args); } - ObjInfo Task() const override { return {ObjInfo::kClassification, false}; } + ObjInfo Task() const override { return ObjInfo::kClassification; } void GetGradient(const HostDeviceVector& preds, const MetaInfo& info, diff --git a/src/objective/rank_obj.cu b/src/objective/rank_obj.cu index 0bbf6f6df26b..f1c8702102df 100644 --- a/src/objective/rank_obj.cu +++ b/src/objective/rank_obj.cu @@ -1,5 +1,5 @@ /*! - * Copyright 2015-2019 XGBoost contributors + * Copyright 2015-2022 XGBoost contributors */ #include #include @@ -750,11 +750,8 @@ class SortedLabelList : dh::SegmentSorter { template class LambdaRankObj : public ObjFunction { public: - void Configure(const std::vector >& args) override { - param_.UpdateAllowUnknown(args); - } - - ObjInfo Task() const override { return {ObjInfo::kRanking, false}; } + void Configure(Args const &args) override { param_.UpdateAllowUnknown(args); } + ObjInfo Task() const override { return ObjInfo::kRanking; } void GetGradient(const HostDeviceVector& preds, const MetaInfo& info, diff --git a/src/objective/regression_loss.h b/src/objective/regression_loss.h index f92dfe2d47d7..f394432a8f28 100644 --- a/src/objective/regression_loss.h +++ b/src/objective/regression_loss.h @@ -1,5 +1,5 @@ /*! 
- * Copyright 2017-2019 XGBoost contributors + * Copyright 2017-2022 XGBoost contributors */ #ifndef XGBOOST_OBJECTIVE_REGRESSION_LOSS_H_ #define XGBOOST_OBJECTIVE_REGRESSION_LOSS_H_ @@ -38,7 +38,7 @@ struct LinearSquareLoss { static const char* DefaultEvalMetric() { return "rmse"; } static const char* Name() { return "reg:squarederror"; } - static ObjInfo Info() { return {ObjInfo::kRegression, true}; } + static ObjInfo Info() { return {ObjInfo::kRegression, true, false}; } }; struct SquaredLogError { @@ -65,7 +65,7 @@ struct SquaredLogError { static const char* Name() { return "reg:squaredlogerror"; } - static ObjInfo Info() { return {ObjInfo::kRegression, false}; } + static ObjInfo Info() { return ObjInfo::kRegression; } }; // logistic loss for probability regression task @@ -102,14 +102,14 @@ struct LogisticRegression { static const char* Name() { return "reg:logistic"; } - static ObjInfo Info() { return {ObjInfo::kRegression, false}; } + static ObjInfo Info() { return ObjInfo::kRegression; } }; // logistic loss for binary classification task struct LogisticClassification : public LogisticRegression { static const char* DefaultEvalMetric() { return "logloss"; } static const char* Name() { return "binary:logistic"; } - static ObjInfo Info() { return {ObjInfo::kBinary, false}; } + static ObjInfo Info() { return ObjInfo::kBinary; } }; // logistic loss, but predict un-transformed margin @@ -146,7 +146,7 @@ struct LogisticRaw : public LogisticRegression { static const char* Name() { return "binary:logitraw"; } - static ObjInfo Info() { return {ObjInfo::kRegression, false}; } + static ObjInfo Info() { return ObjInfo::kRegression; } }; } // namespace obj diff --git a/src/objective/regression_obj.cu b/src/objective/regression_obj.cu index fa294a5a5773..3dc4a7b82316 100644 --- a/src/objective/regression_obj.cu +++ b/src/objective/regression_obj.cu @@ -4,10 +4,10 @@ * \brief Definition of single-value regression and classification objectives. * \author Tianqi Chen, Kailong Chen */ - #include #include #include +#include #include #include @@ -19,12 +19,18 @@ #include "../common/threading_utils.h" #include "../common/transform.h" #include "./regression_loss.h" +#include "adaptive.h" +#include "xgboost/base.h" +#include "xgboost/data.h" +#include "xgboost/generic_parameters.h" #include "xgboost/host_device_vector.h" #include "xgboost/json.h" +#include "xgboost/linalg.h" #include "xgboost/parameter.h" #include "xgboost/span.h" #if defined(XGBOOST_USE_CUDA) +#include "../common/device_helpers.cuh" #include "../common/linalg_op.cuh" #endif // defined(XGBOOST_USE_CUDA) @@ -67,9 +73,7 @@ class RegLossObj : public ObjFunction { param_.UpdateAllowUnknown(args); } - struct ObjInfo Task() const override { - return Loss::Info(); - } + ObjInfo Task() const override { return Loss::Info(); } uint32_t Targets(MetaInfo const& info) const override { // Multi-target regression. 
@@ -209,7 +213,7 @@ class PseudoHuberRegression : public ObjFunction { public: void Configure(Args const& args) override { param_.UpdateAllowUnknown(args); } - struct ObjInfo Task() const override { return {ObjInfo::kRegression, false}; } + ObjInfo Task() const override { return ObjInfo::kRegression; } uint32_t Targets(MetaInfo const& info) const override { return std::max(static_cast(1), info.labels.Shape(1)); } @@ -286,9 +290,7 @@ class PoissonRegression : public ObjFunction { param_.UpdateAllowUnknown(args); } - struct ObjInfo Task() const override { - return {ObjInfo::kRegression, false}; - } + ObjInfo Task() const override { return ObjInfo::kRegression; } void GetGradient(const HostDeviceVector& preds, const MetaInfo &info, int, @@ -378,12 +380,8 @@ XGBOOST_REGISTER_OBJECTIVE(PoissonRegression, "count:poisson") // cox regression for survival data (negative values mean they are censored) class CoxRegression : public ObjFunction { public: - void Configure( - const std::vector >&) override {} - - struct ObjInfo Task() const override { - return {ObjInfo::kRegression, false}; - } + void Configure(Args const&) override {} + ObjInfo Task() const override { return ObjInfo::kRegression; } void GetGradient(const HostDeviceVector& preds, const MetaInfo &info, int, @@ -479,12 +477,8 @@ XGBOOST_REGISTER_OBJECTIVE(CoxRegression, "survival:cox") // gamma regression class GammaRegression : public ObjFunction { public: - void Configure( - const std::vector >&) override {} - - struct ObjInfo Task() const override { - return {ObjInfo::kRegression, false}; - } + void Configure(Args const&) override {} + ObjInfo Task() const override { return ObjInfo::kRegression; } void GetGradient(const HostDeviceVector &preds, const MetaInfo &info, int, @@ -582,9 +576,7 @@ class TweedieRegression : public ObjFunction { metric_ = os.str(); } - struct ObjInfo Task() const override { - return {ObjInfo::kRegression, false}; - } + ObjInfo Task() const override { return ObjInfo::kRegression; } void GetGradient(const HostDeviceVector& preds, const MetaInfo &info, int, @@ -675,5 +667,65 @@ XGBOOST_REGISTER_OBJECTIVE(TweedieRegression, "reg:tweedie") .describe("Tweedie regression for insurance data.") .set_body([]() { return new TweedieRegression(); }); +class MeanAbsoluteError : public ObjFunction { + public: + void Configure(Args const&) override {} + ObjInfo Task() const override { return {ObjInfo::kRegression, true, true}; } + + void GetGradient(HostDeviceVector const& preds, const MetaInfo& info, int iter, + HostDeviceVector* out_gpair) override { + CheckRegInputs(info, preds); + auto labels = info.labels.View(ctx_->gpu_id); + + out_gpair->SetDevice(ctx_->gpu_id); + out_gpair->Resize(info.labels.Size()); + auto gpair = linalg::MakeVec(out_gpair); + + preds.SetDevice(ctx_->gpu_id); + auto predt = linalg::MakeVec(&preds); + info.weights_.SetDevice(ctx_->gpu_id); + common::OptionalWeights weight{ctx_->IsCPU() ? 
info.weights_.ConstHostSpan() + : info.weights_.ConstDeviceSpan()}; + + linalg::ElementWiseKernel(ctx_, labels, [=] XGBOOST_DEVICE(size_t i, float const y) mutable { + auto sign = [](auto x) { + return (x > static_cast(0)) - (x < static_cast(0)); + }; + auto sample_id = std::get<0>(linalg::UnravelIndex(i, labels.Shape())); + auto grad = sign(predt(i) - y) * weight[sample_id]; + auto hess = weight[sample_id]; + gpair(i) = GradientPair{grad, hess}; + }); + } + + void UpdateTreeLeaf(HostDeviceVector const& position, MetaInfo const& info, + HostDeviceVector const& prediction, RegTree* p_tree) const override { + if (ctx_->IsCPU()) { + auto const& h_position = position.ConstHostVector(); + detail::UpdateTreeLeafHost(ctx_, h_position, info, prediction, 0.5, p_tree); + } else { +#if defined(XGBOOST_USE_CUDA) + position.SetDevice(ctx_->gpu_id); + auto d_position = position.ConstDeviceSpan(); + detail::UpdateTreeLeafDevice(ctx_, d_position, info, prediction, 0.5, p_tree); +#else + common::AssertGPUSupport(); +#endif // defined(XGBOOST_USE_CUDA) + } + } + + const char* DefaultEvalMetric() const override { return "mae"; } + + void SaveConfig(Json* p_out) const override { + auto& out = *p_out; + out["name"] = String("reg:absoluteerror"); + } + + void LoadConfig(Json const& in) override {} +}; + +XGBOOST_REGISTER_OBJECTIVE(MeanAbsoluteError, "reg:absoluteerror") + .describe("Mean absolute error.") + .set_body([]() { return new MeanAbsoluteError(); }); } // namespace obj } // namespace xgboost diff --git a/src/tree/gpu_hist/row_partitioner.cuh b/src/tree/gpu_hist/row_partitioner.cuh index 1b5a5222229e..9470b6447512 100644 --- a/src/tree/gpu_hist/row_partitioner.cuh +++ b/src/tree/gpu_hist/row_partitioner.cuh @@ -1,13 +1,17 @@ /*! - * Copyright 2017-2019 XGBoost contributors + * Copyright 2017-2022 XGBoost contributors */ #pragma once +#include +#include #include "xgboost/base.h" #include "../../common/device_helpers.cuh" +#include "xgboost/generic_parameters.h" +#include "xgboost/task.h" +#include "xgboost/tree_model.h" namespace xgboost { namespace tree { - /*! \brief Count how many rows are assigned to left node. */ __forceinline__ __device__ void AtomicIncrement(int64_t* d_count, bool increment) { #if __CUDACC_VER_MAJOR__ > 8 @@ -149,23 +153,48 @@ class RowPartitioner { } /** - * \brief Finalise the position of all training instances after tree - * construction is complete. Does not update any other meta information in - * this data structure, so should only be used at the end of training. + * \brief Finalise the position of all training instances after tree construction is + * complete. Does not update any other meta information in this data structure, so + * should only be used at the end of training. + + * + * When the task requires updating the leaf values, this function copies the node index + * of each row into p_out_position. The index is negated if the row is removed by + * sampling in the current iteration. * - * \param op Device lambda. Should provide the row index and current - * position as an argument and return the new position for this training - * instance. + * \param p_out_position Node index for each row. + * \param op Device lambda. Should provide the row index and current position as an + * argument and return the new position for this training instance. + * \param sampledp A device lambda to inform the partitioner whether a row is sampled.
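The kernel above is the entire training-time story for `reg:absoluteerror`: the gradient is the sign of the residual scaled by the sample weight, and the reported "hessian" is just the weight, which is why `Task()` returns `{kRegression, true, true}` (constant, effectively zero curvature) and the leaves have to be refreshed by `UpdateTreeLeaf` afterwards. A single-target CPU sketch of the same arithmetic (hypothetical helper, not code from the patch):

```cpp
#include <cstddef>
#include <vector>

struct GradPairSketch {
  float grad, hess;
};

// Sketch only (single target, CPU): L1 gradient = sign(pred - label) * weight,
// and the reported "hessian" is the bare weight -- it carries no curvature,
// which is exactly why the leaf values are recomputed after construction.
std::vector<GradPairSketch> AbsoluteErrorGradient(std::vector<float> const& predt,
                                                  std::vector<float> const& labels,
                                                  std::vector<float> const& weights) {
  std::vector<GradPairSketch> gpair(predt.size());
  for (std::size_t i = 0; i < predt.size(); ++i) {
    float r = predt[i] - labels[i];
    float sign = static_cast<float>((r > 0.f) - (r < 0.f));
    gpair[i] = GradPairSketch{sign * weights[i], weights[i]};
  }
  return gpair;
}
```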
*/ - template - void FinalisePosition(FinalisePositionOpT op) { + template + void FinalisePosition(Context const* ctx, ObjInfo task, + HostDeviceVector* p_out_position, FinalisePositionOpT op, + Sampledp sampledp) { auto d_position = position_.Current(); const auto d_ridx = ridx_.Current(); + if (!task.UpdateTreeLeaf()) { + dh::LaunchN(position_.Size(), [=] __device__(size_t idx) { + auto position = d_position[idx]; + RowIndexT ridx = d_ridx[idx]; + bst_node_t new_position = op(ridx, position); + if (new_position == kIgnoredTreePosition) { + return; + } + d_position[idx] = new_position; + }); + return; + } + + p_out_position->SetDevice(ctx->gpu_id); + p_out_position->Resize(position_.Size()); + auto sorted_position = p_out_position->DevicePointer(); dh::LaunchN(position_.Size(), [=] __device__(size_t idx) { auto position = d_position[idx]; RowIndexT ridx = d_ridx[idx]; bst_node_t new_position = op(ridx, position); - if (new_position == kIgnoredTreePosition) return; + sorted_position[ridx] = sampledp(ridx) ? ~new_position : new_position; + if (new_position == kIgnoredTreePosition) { + return; + } d_position[idx] = new_position; }); } diff --git a/src/tree/hist/evaluate_splits.h b/src/tree/hist/evaluate_splits.h index 053b485012bd..4e445a0680e5 100644 --- a/src/tree/hist/evaluate_splits.h +++ b/src/tree/hist/evaluate_splits.h @@ -390,7 +390,6 @@ void UpdatePredictionCacheImpl(GenericParameter const *ctx, RegTree const *p_las CHECK(p_last_tree); auto const &tree = *p_last_tree; - auto const &snode = hist_evaluator.Stats(); auto evaluator = hist_evaluator.Evaluator(); CHECK_EQ(out_preds.DeviceIdx(), GenericParameter::kCpuId); size_t n_nodes = p_last_tree->GetNodes().size(); @@ -401,9 +400,7 @@ void UpdatePredictionCacheImpl(GenericParameter const *ctx, RegTree const *p_las common::ParallelFor2d(space, ctx->Threads(), [&](size_t nidx, common::Range1d r) { if (!tree[nidx].IsDeleted() && tree[nidx].IsLeaf()) { auto const &rowset = part[nidx]; - auto const &stats = snode[nidx]; - auto leaf_value = - evaluator.CalcWeight(nidx, param, GradStats{stats.stats}) * param.learning_rate; + auto leaf_value = tree[nidx].LeafValue(); for (const size_t *it = rowset.begin + r.begin(); it < rowset.begin + r.end(); ++it) { out_preds(*it) += leaf_value; } diff --git a/src/tree/updater_approx.cc b/src/tree/updater_approx.cc index 3bad6f7da4cc..4222cddb1ee9 100644 --- a/src/tree/updater_approx.cc +++ b/src/tree/updater_approx.cc @@ -19,6 +19,7 @@ #include "param.h" #include "xgboost/base.h" #include "xgboost/json.h" +#include "xgboost/tree_model.h" #include "xgboost/tree_updater.h" namespace xgboost { @@ -154,6 +155,18 @@ class GloablApproxBuilder { monitor_->Stop(__func__); } + void LeafPartition(RegTree const &tree, common::Span hess, + std::vector *p_out_position) { + monitor_->Start(__func__); + if (!evaluator_.Task().UpdateTreeLeaf()) { + return; + } + for (auto const &part : partitioner_) { + part.LeafPartition(ctx_, tree, hess, p_out_position); + } + monitor_->Stop(__func__); + } + public: explicit GloablApproxBuilder(TrainParam param, MetaInfo const &info, GenericParameter const *ctx, std::shared_ptr column_sampler, ObjInfo task, @@ -164,8 +177,8 @@ class GloablApproxBuilder { ctx_{ctx}, monitor_{monitor} {} - void UpdateTree(RegTree *p_tree, std::vector const &gpair, common::Span hess, - DMatrix *p_fmat) { + void UpdateTree(DMatrix *p_fmat, std::vector const &gpair, common::Span hess, + RegTree *p_tree, HostDeviceVector *p_out_position) { p_last_tree_ = p_tree; this->InitData(p_fmat, hess); @@ -231,6 +244,9 
@@ class GloablApproxBuilder { driver.Push(best_splits.begin(), best_splits.end()); expand_set = driver.Pop(); } + + auto &h_position = p_out_position->HostVector(); + this->LeafPartition(tree, hess, &h_position); } }; @@ -275,6 +291,7 @@ class GlobalApproxUpdater : public TreeUpdater { sampled->resize(h_gpair.size()); std::copy(h_gpair.cbegin(), h_gpair.cend(), sampled->begin()); auto &rnd = common::GlobalRandom(); + if (param.subsample != 1.0) { CHECK(param.sampling_method != TrainParam::kGradientBased) << "Gradient based sampling is not supported for approx tree method."; @@ -292,6 +309,7 @@ class GlobalApproxUpdater : public TreeUpdater { char const *Name() const override { return "grow_histmaker"; } void Update(HostDeviceVector *gpair, DMatrix *m, + common::Span> out_position, const std::vector &trees) override { float lr = param_.learning_rate; param_.learning_rate = lr / trees.size(); @@ -313,12 +331,14 @@ class GlobalApproxUpdater : public TreeUpdater { cached_ = m; + size_t t_idx = 0; for (auto p_tree : trees) { if (hist_param_.single_precision_histogram) { - this->f32_impl_->UpdateTree(p_tree, h_gpair, hess, m); + this->f32_impl_->UpdateTree(m, h_gpair, hess, p_tree, &out_position[t_idx]); } else { - this->f64_impl_->UpdateTree(p_tree, h_gpair, hess, m); + this->f64_impl_->UpdateTree(m, h_gpair, hess, p_tree, &out_position[t_idx]); } + ++t_idx; } param_.learning_rate = lr; } @@ -335,6 +355,8 @@ class GlobalApproxUpdater : public TreeUpdater { } return true; } + + bool HasNodePosition() const override { return true; } }; DMLC_REGISTRY_FILE_TAG(grow_histmaker); diff --git a/src/tree/updater_approx.h b/src/tree/updater_approx.h index ec54da19e5b0..bb37f99ec61d 100644 --- a/src/tree/updater_approx.h +++ b/src/tree/updater_approx.h @@ -1,5 +1,5 @@ /*! - * Copyright 2021 XGBoost contributors + * Copyright 2021-2022 XGBoost contributors * * \brief Implementation for the approx tree method. 
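With `GloablApproxBuilder::UpdateTree` now filling `p_out_position` and `GlobalApproxUpdater` forwarding one position vector per tree, an iteration ends with the objective rewriting the leaf values. Stripped of device code and distributed handling, that refresh amounts to the following sketch (plain containers, hypothetical names; the real logic lives in src/objective/adaptive.cc and is not shown in full in this diff):

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <map>
#include <vector>

// Sketch only: group surviving rows by the leaf they landed in (a negative
// position means the row was removed by sampling, so skip it), then set each
// leaf to the median of label - prediction over its rows. The real code uses
// common::Quantile/WeightedQuantile and averages quantiles across workers.
std::map<std::int32_t, float> RefreshLeavesSketch(std::vector<std::int32_t> const& position,
                                                  std::vector<float> const& labels,
                                                  std::vector<float> const& predt) {
  std::map<std::int32_t, std::vector<float>> residuals;
  for (std::size_t i = 0; i < position.size(); ++i) {
    if (position[i] < 0) continue;  // sampled out in this iteration
    residuals[position[i]].push_back(labels[i] - predt[i]);
  }
  std::map<std::int32_t, float> leaf_values;
  for (auto& [nidx, r] : residuals) {
    auto mid = r.begin() + r.size() / 2;
    std::nth_element(r.begin(), mid, r.end());
    leaf_values[nidx] = *mid;  // median (upper median for even-sized leaves)
  }
  return leaf_values;
}
```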
*/ @@ -18,6 +18,7 @@ #include "hist/expand_entry.h" #include "hist/param.h" #include "param.h" +#include "xgboost/generic_parameters.h" #include "xgboost/json.h" #include "xgboost/tree_updater.h" @@ -122,6 +123,12 @@ class ApproxRowPartitioner { auto const &Partitions() const { return row_set_collection_; } + void LeafPartition(Context const *ctx, RegTree const &tree, common::Span hess, + std::vector *p_out_position) const { + partition_builder_.LeafPartition(ctx, tree, this->Partitions(), p_out_position, + [&](size_t idx) -> bool { return hess[idx] - .0f == .0f; }); + } + auto operator[](bst_node_t nidx) { return row_set_collection_[nidx]; } auto const &operator[](bst_node_t nidx) const { return row_set_collection_[nidx]; } diff --git a/src/tree/updater_colmaker.cc b/src/tree/updater_colmaker.cc index e3d716f2cba8..6d63a00a139a 100644 --- a/src/tree/updater_colmaker.cc +++ b/src/tree/updater_colmaker.cc @@ -96,9 +96,9 @@ class ColMaker: public TreeUpdater { } } - void Update(HostDeviceVector *gpair, - DMatrix* dmat, - const std::vector &trees) override { + void Update(HostDeviceVector *gpair, DMatrix *dmat, + common::Span> out_position, + const std::vector &trees) override { if (rabit::IsDistributed()) { LOG(FATAL) << "Updater `grow_colmaker` or `exact` tree method doesn't " "support distributed training."; diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index cb7dd9b7e8e4..569188fd5374 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -11,6 +11,9 @@ #include #include +#include "xgboost/base.h" +#include "xgboost/data.h" +#include "xgboost/generic_parameters.h" #include "xgboost/host_device_vector.h" #include "xgboost/parameter.h" #include "xgboost/span.h" @@ -35,6 +38,8 @@ #include "gpu_hist/histogram.cuh" #include "gpu_hist/evaluate_splits.cuh" #include "gpu_hist/expand_entry.cuh" +#include "xgboost/task.h" +#include "xgboost/tree_model.h" namespace xgboost { namespace tree { @@ -161,9 +166,9 @@ template struct GPUHistMakerDevice { private: GPUHistEvaluator evaluator_; + Context const* ctx_; public: - int device_id; EllpackPageImpl const* page; common::Span feature_types; BatchParam batch_param; @@ -195,12 +200,12 @@ struct GPUHistMakerDevice { // Storing split categories for last node. 
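The `hess[idx] - .0f == .0f` predicate just above is how the approx partitioner detects rows removed by sampling; the hist and GPU partitioners apply the same test to the gradient pairs. For such rows the reported leaf index is bit-negated, so the sign bit doubles as a "was this row used" flag. A tiny sketch of that encoding convention:

```cpp
#include <cstdint>

using bst_node_t = std::int32_t;

// out_position convention used throughout this patch: a row removed by
// sampling keeps its leaf index but bit-negated (e.g. leaf 3 becomes ~3 == -4).
bst_node_t EncodePosition(bst_node_t leaf, bool sampled_out) {
  return sampled_out ? ~leaf : leaf;
}

bst_node_t DecodePosition(bst_node_t encoded) {
  return encoded < 0 ? ~encoded : encoded;
}

bool SampledOut(bst_node_t encoded) { return encoded < 0; }
```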
dh::caching_device_vector node_categories; - GPUHistMakerDevice(int _device_id, EllpackPageImpl const* _page, + GPUHistMakerDevice(Context const* ctx, EllpackPageImpl const* _page, common::Span _feature_types, bst_uint _n_rows, TrainParam _param, uint32_t column_sampler_seed, uint32_t n_features, BatchParam _batch_param) - : evaluator_{_param, n_features, _device_id}, - device_id(_device_id), + : evaluator_{_param, n_features, ctx->gpu_id}, + ctx_(ctx), page(_page), feature_types{_feature_types}, param(std::move(_param)), @@ -216,14 +221,15 @@ struct GPUHistMakerDevice { node_sum_gradients.resize(param.MaxNodes()); // Init histogram - hist.Init(device_id, page->Cuts().TotalBins()); - monitor.Init(std::string("GPUHistMakerDevice") + std::to_string(device_id)); - feature_groups.reset(new FeatureGroups( - page->Cuts(), page->is_dense, dh::MaxSharedMemoryOptin(device_id), sizeof(GradientSumT))); + hist.Init(ctx_->gpu_id, page->Cuts().TotalBins()); + monitor.Init(std::string("GPUHistMakerDevice") + std::to_string(ctx_->gpu_id)); + feature_groups.reset(new FeatureGroups(page->Cuts(), page->is_dense, + dh::MaxSharedMemoryOptin(ctx_->gpu_id), + sizeof(GradientSumT))); } ~GPUHistMakerDevice() { // NOLINT - dh::safe_cuda(cudaSetDevice(device_id)); + dh::safe_cuda(cudaSetDevice(ctx_->gpu_id)); } // Reset values for each update iteration @@ -235,10 +241,10 @@ struct GPUHistMakerDevice { this->column_sampler.Init(num_columns, info.feature_weights.HostVector(), param.colsample_bynode, param.colsample_bylevel, param.colsample_bytree); - dh::safe_cuda(cudaSetDevice(device_id)); + dh::safe_cuda(cudaSetDevice(ctx_->gpu_id)); this->evaluator_.Reset(page->Cuts(), feature_types, task, dmat->Info().num_col_, param, - device_id); + ctx_->gpu_id); this->interaction_constraints.Reset(); std::fill(node_sum_gradients.begin(), node_sum_gradients.end(), GradientPairPrecise{}); @@ -256,7 +262,7 @@ struct GPUHistMakerDevice { histogram_rounding = CreateRoundingFactor(this->gpair); row_partitioner.reset(); // Release the device memory first before reallocating - row_partitioner.reset(new RowPartitioner(device_id, sample.sample_rows)); + row_partitioner.reset(new RowPartitioner(ctx_->gpu_id, sample.sample_rows)); hist.Reset(); } @@ -264,10 +270,10 @@ struct GPUHistMakerDevice { int nidx = RegTree::kRoot; GPUTrainingParam gpu_param(param); auto sampled_features = column_sampler.GetFeatureSet(0); - sampled_features->SetDevice(device_id); + sampled_features->SetDevice(ctx_->gpu_id); common::Span feature_set = interaction_constraints.Query(sampled_features->DeviceSpan(), nidx); - auto matrix = page->GetDeviceAccessor(device_id); + auto matrix = page->GetDeviceAccessor(ctx_->gpu_id); EvaluateSplitInputs inputs{nidx, root_sum, gpu_param, @@ -287,14 +293,14 @@ struct GPUHistMakerDevice { dh::TemporaryArray splits_out(2); GPUTrainingParam gpu_param(param); auto left_sampled_features = column_sampler.GetFeatureSet(tree.GetDepth(left_nidx)); - left_sampled_features->SetDevice(device_id); + left_sampled_features->SetDevice(ctx_->gpu_id); common::Span left_feature_set = interaction_constraints.Query(left_sampled_features->DeviceSpan(), left_nidx); auto right_sampled_features = column_sampler.GetFeatureSet(tree.GetDepth(right_nidx)); - right_sampled_features->SetDevice(device_id); + right_sampled_features->SetDevice(ctx_->gpu_id); common::Span right_feature_set = interaction_constraints.Query(right_sampled_features->DeviceSpan(), left_nidx); - auto matrix = page->GetDeviceAccessor(device_id); + auto matrix = 
page->GetDeviceAccessor(ctx_->gpu_id); EvaluateSplitInputs left{left_nidx, candidate.split.left_sum, @@ -325,8 +331,8 @@ struct GPUHistMakerDevice { hist.AllocateHistogram(nidx); auto d_node_hist = hist.GetNodeHistogram(nidx); auto d_ridx = row_partitioner->GetRows(nidx); - BuildGradientHistogram(page->GetDeviceAccessor(device_id), - feature_groups->DeviceAccessor(device_id), gpair, + BuildGradientHistogram(page->GetDeviceAccessor(ctx_->gpu_id), + feature_groups->DeviceAccessor(ctx_->gpu_id), gpair, d_ridx, d_node_hist, histogram_rounding); } @@ -351,7 +357,7 @@ struct GPUHistMakerDevice { void UpdatePosition(int nidx, RegTree* p_tree) { RegTree::Node split_node = (*p_tree)[nidx]; auto split_type = p_tree->NodeSplitType(nidx); - auto d_matrix = page->GetDeviceAccessor(device_id); + auto d_matrix = page->GetDeviceAccessor(ctx_->gpu_id); auto node_cats = dh::ToSpan(node_categories); row_partitioner->UpdatePosition( @@ -384,7 +390,8 @@ struct GPUHistMakerDevice { // After tree update is finished, update the position of all training // instances to their final leaf. This information is used later to update the // prediction cache - void FinalisePosition(RegTree const* p_tree, DMatrix* p_fmat) { + void FinalisePosition(RegTree const* p_tree, DMatrix* p_fmat, ObjInfo task, + HostDeviceVector* p_out_position) { dh::TemporaryArray d_nodes(p_tree->GetNodes().size()); dh::safe_cuda(cudaMemcpyAsync(d_nodes.data().get(), p_tree->GetNodes().data(), d_nodes.size() * sizeof(RegTree::Node), @@ -405,17 +412,21 @@ struct GPUHistMakerDevice { if (row_partitioner->GetRows().size() != p_fmat->Info().num_row_) { row_partitioner.reset(); // Release the device memory first before reallocating - row_partitioner.reset(new RowPartitioner(device_id, p_fmat->Info().num_row_)); + row_partitioner.reset(new RowPartitioner(ctx_->gpu_id, p_fmat->Info().num_row_)); + } + if (task.UpdateTreeLeaf() && !p_fmat->SingleColBlock() && param.subsample != 1.0) { + // See the comment in `FinalisePositionInPage`. + LOG(FATAL) << "Current objective function cannot be used with subsampled external memory."; } if (page->n_rows == p_fmat->Info().num_row_) { - FinalisePositionInPage(page, dh::ToSpan(d_nodes), - dh::ToSpan(d_split_types), dh::ToSpan(d_categories), - dh::ToSpan(d_categories_segments)); + FinalisePositionInPage(page, dh::ToSpan(d_nodes), dh::ToSpan(d_split_types), + dh::ToSpan(d_categories), dh::ToSpan(d_categories_segments), task, + p_out_position); } else { - for (auto& batch : p_fmat->GetBatches(batch_param)) { - FinalisePositionInPage(batch.Impl(), dh::ToSpan(d_nodes), - dh::ToSpan(d_split_types), dh::ToSpan(d_categories), - dh::ToSpan(d_categories_segments)); + for (auto const& batch : p_fmat->GetBatches(batch_param)) { + FinalisePositionInPage(batch.Impl(), dh::ToSpan(d_nodes), dh::ToSpan(d_split_types), + dh::ToSpan(d_categories), dh::ToSpan(d_categories_segments), task, + p_out_position); } } } @@ -424,9 +435,13 @@ struct GPUHistMakerDevice { const common::Span d_nodes, common::Span d_feature_types, common::Span categories, - common::Span categories_segments) { - auto d_matrix = page->GetDeviceAccessor(device_id); + common::Span categories_segments, + ObjInfo task, + HostDeviceVector* p_out_position) { + auto d_matrix = page->GetDeviceAccessor(ctx_->gpu_id); + auto d_gpair = this->gpair; row_partitioner->FinalisePosition( + ctx_, task, p_out_position, [=] __device__(size_t row_id, int position) { // What happens if user prune the tree?
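The kernel body that follows answers that question: after pruning, or when the tree grew past what a row saw during construction, the recorded `position` may no longer be a leaf, so the lambda descends from that node until it reaches one. A host-side analogue, assuming dense rows and numerical splits only (`NodeSketch` is a made-up stand-in for `RegTree::Node`; the real kernel also handles missing values and categorical splits):

```cpp
#include <vector>

// Sketch only, assuming dense rows and numerical splits.
struct NodeSketch {
  bool is_leaf;
  int left, right;       // child indices, valid when !is_leaf
  unsigned split_index;  // feature used at this node
  float split_cond;      // go left iff fvalue < split_cond
};

int WalkToLeaf(std::vector<NodeSketch> const& nodes, std::vector<float> const& row,
               int position) {
  // Starting from the node recorded during construction, keep descending
  // until a leaf is reached, so every row ends up with a valid leaf index.
  while (!nodes[position].is_leaf) {
    auto const& n = nodes[position];
    position = row[n.split_index] < n.split_cond ? n.left : n.right;
  }
  return position;
}
```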
if (!d_matrix.IsInRange(row_id)) { @@ -457,13 +472,20 @@ struct GPUHistMakerDevice { } node = d_nodes[position]; } + return position; + }, + [d_gpair] __device__(size_t ridx) { + // FIXME(jiamingy): Doesn't work when sampling is used with external memory as + // the sampler compacts the gradient vector. + return d_gpair[ridx].GetHess() - .0f == 0.f; }); } - void UpdatePredictionCache(linalg::VectorView out_preds_d) { - dh::safe_cuda(cudaSetDevice(device_id)); - CHECK_EQ(out_preds_d.DeviceIdx(), device_id); + void UpdatePredictionCache(linalg::VectorView out_preds_d, RegTree const* p_tree) { + CHECK(p_tree); + dh::safe_cuda(cudaSetDevice(ctx_->gpu_id)); + CHECK_EQ(out_preds_d.DeviceIdx(), ctx_->gpu_id); auto d_ridx = row_partitioner->GetRows(); GPUTrainingParam param_d(param); @@ -476,12 +498,15 @@ struct GPUHistMakerDevice { auto d_node_sum_gradients = device_node_sum_gradients.data().get(); auto tree_evaluator = evaluator_.GetEvaluator(); - dh::LaunchN(d_ridx.size(), [=, out_preds_d = out_preds_d] __device__(int local_idx) mutable { - int pos = d_position[local_idx]; - bst_float weight = - tree_evaluator.CalcWeight(pos, param_d, GradStats{d_node_sum_gradients[pos]}); - static_assert(!std::is_const::value, ""); - out_preds_d(d_ridx[local_idx]) += weight * param_d.learning_rate; + auto const& h_nodes = p_tree->GetNodes(); + dh::caching_device_vector nodes(h_nodes.size()); + dh::safe_cuda(cudaMemcpyAsync(nodes.data().get(), h_nodes.data(), + h_nodes.size() * sizeof(RegTree::Node), cudaMemcpyHostToDevice)); + auto d_nodes = dh::ToSpan(nodes); + dh::LaunchN(d_ridx.size(), [=] XGBOOST_DEVICE(size_t idx) mutable { + bst_node_t nidx = d_position[idx]; + auto weight = d_nodes[nidx].LeafValue(); + out_preds_d(d_ridx[idx]) += weight; }); row_partitioner.reset(); } @@ -610,7 +635,8 @@ struct GPUHistMakerDevice { } void UpdateTree(HostDeviceVector* gpair_all, DMatrix* p_fmat, ObjInfo task, - RegTree* p_tree, dh::AllReducer* reducer) { + RegTree* p_tree, dh::AllReducer* reducer, + HostDeviceVector* p_out_position) { auto& tree = *p_tree; Driver driver(static_cast(param.grow_policy)); @@ -641,7 +667,7 @@ struct GPUHistMakerDevice { int left_child_nidx = tree[candidate.nid].LeftChild(); int right_child_nidx = tree[candidate.nid].RightChild(); // Only create child entries if needed if (GPUExpandEntry::ChildIsValid(param, tree.GetDepth(left_child_nidx), num_leaves)) { monitor.Start("UpdatePosition"); @@ -671,7 +697,7 @@ struct GPUHistMakerDevice { } monitor.Start("FinalisePosition"); - this->FinalisePosition(p_tree, p_fmat); + this->FinalisePosition(p_tree, p_fmat, task, p_out_position); monitor.Stop("FinalisePosition"); } }; @@ -682,7 +708,7 @@ class GPUHistMakerSpecialised { public: explicit GPUHistMakerSpecialised(ObjInfo task) : task_{task} {}; void Configure(const Args& args, GenericParameter const* generic_param) { param_.UpdateAllowUnknown(args); - generic_param_ = generic_param; + ctx_ = generic_param; hist_maker_param_.UpdateAllowUnknown(args); dh::CheckComputeCapability(); @@ -694,20 +720,24 @@ class GPUHistMakerSpecialised { } void Update(HostDeviceVector* gpair, DMatrix* dmat, + common::Span> out_position, const std::vector& trees) { monitor_.Start("Update"); // rescale learning rate according to size of trees float lr = param_.learning_rate; param_.learning_rate = lr / trees.size(); + // build tree try { + size_t t_idx{0}; for (xgboost::RegTree* tree : trees) { - this->UpdateTree(gpair, dmat, tree); + this->UpdateTree(gpair, dmat, tree,
&out_position[t_idx]); if (hist_maker_param_.debug_synchronize) { this->CheckTreesSynchronized(tree); } + ++t_idx; } dh::safe_cuda(cudaGetLastError()); } catch (const std::exception& e) { @@ -719,41 +749,36 @@ class GPUHistMakerSpecialised { } void InitDataOnce(DMatrix* dmat) { - device_ = generic_param_->gpu_id; - CHECK_GE(device_, 0) << "Must have at least one device"; + CHECK_GE(ctx_->gpu_id, 0) << "Must have at least one device"; info_ = &dmat->Info(); - reducer_.Init({device_}); // NOLINT + reducer_.Init({ctx_->gpu_id}); // NOLINT // Synchronise the column sampling seed uint32_t column_sampling_seed = common::GlobalRandom()(); rabit::Broadcast(&column_sampling_seed, sizeof(column_sampling_seed), 0); BatchParam batch_param{ - device_, + ctx_->gpu_id, param_.max_bin, }; auto page = (*dmat->GetBatches(batch_param).begin()).Impl(); - dh::safe_cuda(cudaSetDevice(device_)); - info_->feature_types.SetDevice(device_); - maker.reset(new GPUHistMakerDevice(device_, - page, - info_->feature_types.ConstDeviceSpan(), - info_->num_row_, - param_, - column_sampling_seed, - info_->num_col_, - batch_param)); + dh::safe_cuda(cudaSetDevice(ctx_->gpu_id)); + info_->feature_types.SetDevice(ctx_->gpu_id); + maker.reset(new GPUHistMakerDevice( + ctx_, page, info_->feature_types.ConstDeviceSpan(), info_->num_row_, param_, + column_sampling_seed, info_->num_col_, batch_param)); p_last_fmat_ = dmat; initialised_ = true; } - void InitData(DMatrix* dmat) { + void InitData(DMatrix* dmat, RegTree const* p_tree) { if (!initialised_) { monitor_.Start("InitDataOnce"); this->InitDataOnce(dmat); monitor_.Stop("InitDataOnce"); } + p_last_tree_ = p_tree; } // Only call this method for testing @@ -771,13 +796,14 @@ class GPUHistMakerSpecialised { CHECK(*local_tree == reference_tree); } - void UpdateTree(HostDeviceVector* gpair, DMatrix* p_fmat, RegTree* p_tree) { + void UpdateTree(HostDeviceVector* gpair, DMatrix* p_fmat, RegTree* p_tree, + HostDeviceVector* p_out_position) { monitor_.Start("InitData"); - this->InitData(p_fmat); + this->InitData(p_fmat, p_tree); monitor_.Stop("InitData"); - gpair->SetDevice(device_); - maker->UpdateTree(gpair, p_fmat, task_, p_tree, &reducer_); + gpair->SetDevice(ctx_->gpu_id); + maker->UpdateTree(gpair, p_fmat, task_, p_tree, &reducer_, p_out_position); } bool UpdatePredictionCache(const DMatrix *data, @@ -786,7 +812,7 @@ class GPUHistMakerSpecialised { return false; } monitor_.Start("UpdatePredictionCache"); - maker->UpdatePredictionCache(p_out_preds); + maker->UpdatePredictionCache(p_out_preds, p_last_tree_); monitor_.Stop("UpdatePredictionCache"); return true; } @@ -800,12 +826,12 @@ class GPUHistMakerSpecialised { bool initialised_ { false }; GPUHistMakerTrainParam hist_maker_param_; - GenericParameter const* generic_param_; + Context const* ctx_; dh::AllReducer reducer_; DMatrix* p_last_fmat_ { nullptr }; - int device_{-1}; + RegTree const* p_last_tree_{nullptr}; ObjInfo task_; common::Monitor monitor_; @@ -859,17 +885,17 @@ class GPUHistMaker : public TreeUpdater { } void Update(HostDeviceVector* gpair, DMatrix* dmat, + common::Span> out_position, const std::vector& trees) override { if (hist_maker_param_.single_precision_histogram) { - float_maker_->Update(gpair, dmat, trees); + float_maker_->Update(gpair, dmat, out_position, trees); } else { - double_maker_->Update(gpair, dmat, trees); + double_maker_->Update(gpair, dmat, out_position, trees); } } - bool - UpdatePredictionCache(const DMatrix *data, - linalg::VectorView p_out_preds) override { + bool UpdatePredictionCache(const 
DMatrix* data, + linalg::VectorView p_out_preds) override { if (hist_maker_param_.single_precision_histogram) { return float_maker_->UpdatePredictionCache(data, p_out_preds); } else { @@ -881,6 +907,8 @@ class GPUHistMaker : public TreeUpdater { return "grow_gpu_hist"; } + bool HasNodePosition() const override { return true; } + private: GPUHistMakerTrainParam hist_maker_param_; ObjInfo task_; diff --git a/src/tree/updater_histmaker.cc b/src/tree/updater_histmaker.cc index 0a85d2d73832..27fc42455d2c 100644 --- a/src/tree/updater_histmaker.cc +++ b/src/tree/updater_histmaker.cc @@ -24,9 +24,9 @@ DMLC_REGISTRY_FILE_TAG(updater_histmaker); class HistMaker: public BaseMaker { public: - void Update(HostDeviceVector *gpair, - DMatrix *p_fmat, - const std::vector &trees) override { + void Update(HostDeviceVector *gpair, DMatrix *p_fmat, + common::Span> out_position, + const std::vector &trees) override { interaction_constraints_.Configure(param_, p_fmat->Info().num_col_); // rescale learning rate according to size of trees float lr = param_.learning_rate; diff --git a/src/tree/updater_prune.cc b/src/tree/updater_prune.cc index f71f1c698cb9..dcda4a3b34a2 100644 --- a/src/tree/updater_prune.cc +++ b/src/tree/updater_prune.cc @@ -50,9 +50,9 @@ class TreePruner: public TreeUpdater { } // update the tree, do pruning - void Update(HostDeviceVector *gpair, - DMatrix *p_fmat, - const std::vector &trees) override { + void Update(HostDeviceVector* gpair, DMatrix* p_fmat, + common::Span> out_position, + const std::vector& trees) override { pruner_monitor_.Start("PrunerUpdate"); // rescale learning rate according to size of trees float lr = param_.learning_rate; @@ -61,7 +61,7 @@ class TreePruner: public TreeUpdater { this->DoPrune(tree); } param_.learning_rate = lr; - syncher_->Update(gpair, p_fmat, trees); + syncher_->Update(gpair, p_fmat, out_position, trees); pruner_monitor_.Stop("PrunerUpdate"); } diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc index 0e1b6db47691..011733b4582a 100644 --- a/src/tree/updater_quantile_hist.cc +++ b/src/tree/updater_quantile_hist.cc @@ -36,6 +36,7 @@ void QuantileHistMaker::Configure(const Args &args) { } void QuantileHistMaker::Update(HostDeviceVector *gpair, DMatrix *dmat, + common::Span> out_position, const std::vector &trees) { // rescale learning rate according to size of trees float lr = param_.learning_rate; @@ -53,12 +54,15 @@ void QuantileHistMaker::Update(HostDeviceVector *gpair, DMatrix *d } } + size_t t_idx{0}; for (auto p_tree : trees) { + auto &t_row_position = out_position[t_idx]; if (hist_maker_param_.single_precision_histogram) { - this->float_builder_->UpdateTree(gpair, dmat, p_tree); + this->float_builder_->UpdateTree(gpair, dmat, p_tree, &t_row_position); } else { - this->double_builder_->UpdateTree(gpair, dmat, p_tree); + this->double_builder_->UpdateTree(gpair, dmat, p_tree, &t_row_position); } + ++t_idx; } param_.learning_rate = lr; @@ -169,13 +173,29 @@ void QuantileHistMaker::Builder::BuildHistogram( } } +template +void QuantileHistMaker::Builder::LeafPartition( + RegTree const &tree, common::Span gpair, + std::vector *p_out_position) { + monitor_->Start(__func__); + if (!evaluator_->Task().UpdateTreeLeaf()) { + return; + } + for (auto const &part : partitioner_) { + part.LeafPartition(ctx_, tree, gpair, p_out_position); + } + monitor_->Stop(__func__); +} + template void QuantileHistMaker::Builder::ExpandTree( - DMatrix *p_fmat, RegTree *p_tree, const std::vector &gpair_h) { + DMatrix *p_fmat, RegTree *p_tree, 
const std::vector &gpair_h, + HostDeviceVector *p_out_position) { monitor_->Start(__func__); Driver driver(static_cast(param_.grow_policy)); driver.Push(this->InitRoot(p_fmat, p_tree, gpair_h)); + auto const &tree = *p_tree; bst_node_t num_leaves{1}; auto expand_set = driver.Pop(); @@ -208,7 +228,6 @@ void QuantileHistMaker::Builder::ExpandTree( std::vector best_splits; if (!valid_candidates.empty()) { this->BuildHistogram(p_fmat, p_tree, valid_candidates, gpair_h); - auto const &tree = *p_tree; for (auto const &candidate : valid_candidates) { int left_child_nidx = tree[candidate.nid].LeftChild(); int right_child_nidx = tree[candidate.nid].RightChild(); @@ -228,12 +247,15 @@ void QuantileHistMaker::Builder::ExpandTree( expand_set = driver.Pop(); } + auto &h_out_position = p_out_position->HostVector(); + this->LeafPartition(tree, gpair_h, &h_out_position); monitor_->Stop(__func__); } template -void QuantileHistMaker::Builder::UpdateTree(HostDeviceVector *gpair, - DMatrix *p_fmat, RegTree *p_tree) { +void QuantileHistMaker::Builder::UpdateTree( + HostDeviceVector *gpair, DMatrix *p_fmat, RegTree *p_tree, + HostDeviceVector *p_out_position) { monitor_->Start(__func__); std::vector *gpair_ptr = &(gpair->HostVector()); @@ -246,8 +268,7 @@ void QuantileHistMaker::Builder::UpdateTree(HostDeviceVectorInitData(p_fmat, *p_tree, gpair_ptr); - ExpandTree(p_fmat, p_tree, *gpair_ptr); - + ExpandTree(p_fmat, p_tree, *gpair_ptr, p_out_position); monitor_->Stop(__func__); } diff --git a/src/tree/updater_quantile_hist.h b/src/tree/updater_quantile_hist.h index 3c03a371ebfb..6d5919abb75f 100644 --- a/src/tree/updater_quantile_hist.h +++ b/src/tree/updater_quantile_hist.h @@ -17,6 +17,7 @@ #include #include +#include "xgboost/base.h" #include "xgboost/data.h" #include "xgboost/json.h" @@ -214,6 +215,15 @@ class HistRowPartitioner { size_t Size() const { return std::distance(row_set_collection_.begin(), row_set_collection_.end()); } + + void LeafPartition(Context const* ctx, RegTree const& tree, + common::Span gpair, + std::vector* p_out_position) const { + partition_builder_.LeafPartition( + ctx, tree, this->Partitions(), p_out_position, + [&](size_t idx) -> bool { return gpair[idx].GetHess() - .0f == .0f; }); + } + auto& operator[](bst_node_t nidx) { return row_set_collection_[nidx]; } auto const& operator[](bst_node_t nidx) const { return row_set_collection_[nidx]; } }; @@ -228,8 +238,8 @@ class QuantileHistMaker: public TreeUpdater { explicit QuantileHistMaker(ObjInfo task) : task_{task} {} void Configure(const Args& args) override; - void Update(HostDeviceVector* gpair, - DMatrix* dmat, + void Update(HostDeviceVector* gpair, DMatrix* dmat, + common::Span> out_position, const std::vector& trees) override; bool UpdatePredictionCache(const DMatrix *data, @@ -266,6 +276,8 @@ class QuantileHistMaker: public TreeUpdater { return "grow_quantile_histmaker"; } + bool HasNodePosition() const override { return true; } + protected: CPUHistMakerTrainParam hist_maker_param_; // training parameter @@ -289,7 +301,8 @@ class QuantileHistMaker: public TreeUpdater { monitor_->Init("Quantile::Builder"); } // update one tree, growing - void UpdateTree(HostDeviceVector* gpair, DMatrix* p_fmat, RegTree* p_tree); + void UpdateTree(HostDeviceVector* gpair, DMatrix* p_fmat, RegTree* p_tree, + HostDeviceVector* p_out_position); bool UpdatePredictionCache(DMatrix const* data, linalg::VectorView out_preds) const; @@ -308,7 +321,11 @@ class QuantileHistMaker: public TreeUpdater { std::vector const& valid_candidates, std::vector const& 
gpair); - void ExpandTree(DMatrix* p_fmat, RegTree* p_tree, const std::vector& gpair_h); + void LeafPartition(RegTree const& tree, common::Span gpair, + std::vector* p_out_position); + + void ExpandTree(DMatrix* p_fmat, RegTree* p_tree, const std::vector& gpair_h, + HostDeviceVector* p_out_position); private: const size_t n_trees_; diff --git a/src/tree/updater_refresh.cc b/src/tree/updater_refresh.cc index d17c1e1444f7..8e82ae9f914c 100644 --- a/src/tree/updater_refresh.cc +++ b/src/tree/updater_refresh.cc @@ -42,9 +42,9 @@ class TreeRefresher: public TreeUpdater { return true; } // update the tree, do pruning - void Update(HostDeviceVector *gpair, - DMatrix *p_fmat, - const std::vector &trees) override { + void Update(HostDeviceVector *gpair, DMatrix *p_fmat, + common::Span> out_position, + const std::vector &trees) override { if (trees.size() == 0) return; const std::vector &gpair_h = gpair->ConstHostVector(); // thread temporal space diff --git a/src/tree/updater_sync.cc b/src/tree/updater_sync.cc index 4f7c7a1a85a6..a4c1486fbf90 100644 --- a/src/tree/updater_sync.cc +++ b/src/tree/updater_sync.cc @@ -31,9 +31,9 @@ class TreeSyncher: public TreeUpdater { return "prune"; } - void Update(HostDeviceVector* , - DMatrix*, - const std::vector &trees) override { + void Update(HostDeviceVector*, DMatrix*, + common::Span> out_position, + const std::vector& trees) override { if (rabit::GetWorldSize() == 1) return; std::string s_model; common::MemoryBufferStream fs(&s_model); diff --git a/tests/cpp/common/test_stats.cc b/tests/cpp/common/test_stats.cc new file mode 100644 index 000000000000..2a1e375c0f20 --- /dev/null +++ b/tests/cpp/common/test_stats.cc @@ -0,0 +1,58 @@ +/*! + * Copyright 2022 by XGBoost Contributors + */ +#include +#include + +#include "../../../src/common/stats.h" + +namespace xgboost { +namespace common { +TEST(Stats, Quantile) { + { + linalg::Tensor arr({20.f, 0.f, 15.f, 50.f, 40.f, 0.f, 35.f}, {7}, Context::kCpuId); + std::vector index{0, 2, 3, 4, 6}; + auto h_arr = arr.HostView(); + auto beg = MakeIndexTransformIter([&](size_t i) { return h_arr(index[i]); }); + auto end = beg + index.size(); + auto q = Quantile(0.40f, beg, end); + ASSERT_EQ(q, 26.0); + + q = Quantile(0.20f, beg, end); + ASSERT_EQ(q, 16.0); + + q = Quantile(0.10f, beg, end); + ASSERT_EQ(q, 15.0); + } + + { + std::vector vec{1., 2., 3., 4., 5.}; + auto beg = MakeIndexTransformIter([&](size_t i) { return vec[i]; }); + auto end = beg + vec.size(); + auto q = Quantile(0.5f, beg, end); + ASSERT_EQ(q, 3.); + } +} + +TEST(Stats, WeightedQuantile) { + linalg::Tensor arr({1.f, 2.f, 3.f, 4.f, 5.f}, {5}, Context::kCpuId); + linalg::Tensor weight({1.f, 1.f, 1.f, 1.f, 1.f}, {5}, Context::kCpuId); + + auto h_arr = arr.HostView(); + auto h_weight = weight.HostView(); + + auto beg = MakeIndexTransformIter([&](size_t i) { return h_arr(i); }); + auto end = beg + arr.Size(); + auto w = MakeIndexTransformIter([&](size_t i) { return h_weight(i); }); + + auto q = WeightedQuantile(0.50f, beg, end, w); + ASSERT_EQ(q, 3); + + q = WeightedQuantile(0.0, beg, end, w); + ASSERT_EQ(q, 1); + + q = WeightedQuantile(1.0, beg, end, w); + ASSERT_EQ(q, 5); +} +} // namespace common +} // namespace xgboost diff --git a/tests/cpp/common/test_stats.cu b/tests/cpp/common/test_stats.cu new file mode 100644 index 000000000000..eee92921d931 --- /dev/null +++ b/tests/cpp/common/test_stats.cu @@ -0,0 +1,77 @@ +/*! 
+ * Copyright 2022 by XGBoost Contributors + */ +#include +#include +#include + +#include "../../../src/common/stats.cuh" +#include "xgboost/base.h" +#include "xgboost/generic_parameters.h" +#include "xgboost/host_device_vector.h" +#include "xgboost/linalg.h" + +namespace xgboost { +namespace common { +namespace { +class StatsGPU : public ::testing::Test { + private: + linalg::Tensor arr_{ + {1.f, 2.f, 3.f, 4.f, 5.f, + 2.f, 4.f, 5.f, 3.f, 1.f}, + {10}, 0}; + linalg::Tensor indptr_{{0, 5, 10}, {3}, 0}; + HostDeviceVector results_; + using TestSet = std::vector>; + Context ctx_; + + void Check(float expected) { + auto const& h_results = results_.HostVector(); + ASSERT_EQ(h_results.size(), indptr_.Size() - 1); + ASSERT_EQ(h_results.front(), expected); + EXPECT_EQ(h_results.back(), expected); + } + + public: + void SetUp() override { ctx_.gpu_id = 0; } + void Weighted() { + auto d_arr = arr_.View(0); + auto d_key = indptr_.View(0); + + auto key_it = dh::MakeTransformIterator(thrust::make_counting_iterator(0ul), + [=] __device__(size_t i) { return d_key(i); }); + auto val_it = dh::MakeTransformIterator( + thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(size_t i) { return d_arr(i); }); + linalg::Tensor weights{{10}, 0}; + linalg::ElementWiseTransformDevice(weights.View(0), + [=] XGBOOST_DEVICE(size_t, float) { return 1.0; }); + auto w_it = weights.Data()->ConstDevicePointer(); + for (auto const& pair : TestSet{{0.0f, 1.0f}, {0.5f, 3.0f}, {1.0f, 5.0f}}) { + SegmentedWeightedQuantile(&ctx_, pair.first, key_it, key_it + indptr_.Size(), val_it, + val_it + arr_.Size(), w_it, w_it + weights.Size(), &results_); + this->Check(pair.second); + } + } + + void NonWeighted() { + auto d_arr = arr_.View(0); + auto d_key = indptr_.View(0); + + auto key_it = dh::MakeTransformIterator(thrust::make_counting_iterator(0ul), + [=] __device__(size_t i) { return d_key(i); }); + auto val_it = dh::MakeTransformIterator( + thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(size_t i) { return d_arr(i); }); + + for (auto const& pair : TestSet{{0.0f, 1.0f}, {0.5f, 3.0f}, {1.0f, 5.0f}}) { + SegmentedQuantile(&ctx_, pair.first, key_it, key_it + indptr_.Size(), val_it, + val_it + arr_.Size(), &results_); + this->Check(pair.second); + } + } +}; +} // anonymous namespace + +TEST_F(StatsGPU, Quantile) { this->NonWeighted(); } +TEST_F(StatsGPU, WeightedQuantile) { this->Weighted(); } +} // namespace common +} // namespace xgboost diff --git a/tests/cpp/gbm/test_gbtree.cc b/tests/cpp/gbm/test_gbtree.cc index c416d134307c..f9fe7d38660d 100644 --- a/tests/cpp/gbm/test_gbtree.cc +++ b/tests/cpp/gbm/test_gbtree.cc @@ -1,5 +1,5 @@ /*!
- * Copyright 2019-2021 XGBoost contributors + * Copyright 2019-2022 XGBoost contributors */ #include #include @@ -69,13 +69,13 @@ TEST(GBTree, PredictionCache) { auto p_m = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(); auto gpair = GenerateRandomGradients(kRows); PredictionCacheEntry out_predictions; - gbtree.DoBoost(p_m.get(), &gpair, &out_predictions); + gbtree.DoBoost(p_m.get(), &gpair, &out_predictions, nullptr); gbtree.PredictBatch(p_m.get(), &out_predictions, false, 0, 0); ASSERT_EQ(1, out_predictions.version); std::vector first_iter = out_predictions.predictions.HostVector(); // Add 1 more boosted round - gbtree.DoBoost(p_m.get(), &gpair, &out_predictions); + gbtree.DoBoost(p_m.get(), &gpair, &out_predictions, nullptr); gbtree.PredictBatch(p_m.get(), &out_predictions, false, 0, 0); ASSERT_EQ(2, out_predictions.version); // Update the cache for all rounds @@ -83,7 +83,7 @@ TEST(GBTree, PredictionCache) { gbtree.PredictBatch(p_m.get(), &out_predictions, false, 0, 0); ASSERT_EQ(2, out_predictions.version); - gbtree.DoBoost(p_m.get(), &gpair, &out_predictions); + gbtree.DoBoost(p_m.get(), &gpair, &out_predictions, nullptr); // drop the cache. gbtree.PredictBatch(p_m.get(), &out_predictions, false, 1, 2); ASSERT_EQ(0, out_predictions.version); diff --git a/tests/cpp/helpers.cc b/tests/cpp/helpers.cc index 05c138781e0d..68faa09642ed 100644 --- a/tests/cpp/helpers.cc +++ b/tests/cpp/helpers.cc @@ -548,7 +548,7 @@ std::unique_ptr CreateTrainedGBM( PredictionCacheEntry predts; - gbm->DoBoost(p_dmat.get(), &gpair, &predts); + gbm->DoBoost(p_dmat.get(), &gpair, &predts, nullptr); return gbm; } diff --git a/tests/cpp/objective/test_regression_obj.cc b/tests/cpp/objective/test_regression_obj.cc index ef4529934337..a26f69476152 100644 --- a/tests/cpp/objective/test_regression_obj.cc +++ b/tests/cpp/objective/test_regression_obj.cc @@ -1,11 +1,14 @@ /*! 
- * Copyright 2017-2021 XGBoost contributors + * Copyright 2017-2022 XGBoost contributors */ #include -#include #include #include +#include + +#include "../../../src/objective/adaptive.h" #include "../helpers.h" + namespace xgboost { TEST(Objective, DeclareUnifiedTest(LinearRegressionGPair)) { @@ -378,4 +381,113 @@ TEST(Objective, CoxRegressionGPair) { { 0, 0, 0, 0.160f, 0.186f, 0.348f, 0.610f, 0.639f}); } #endif + +TEST(Objective, DeclareUnifiedTest(AbsoluteError)) { + Context ctx = CreateEmptyGenericParam(GPUIDX); + std::unique_ptr obj{ObjFunction::Create("reg:absoluteerror", &ctx)}; + obj->Configure({}); + CheckConfigReload(obj, "reg:absoluteerror"); + + MetaInfo info; + std::vector labels{0.f, 3.f, 2.f, 5.f, 4.f, 7.f}; + info.labels.Reshape(6, 1); + info.labels.Data()->HostVector() = labels; + info.num_row_ = labels.size(); + HostDeviceVector predt{1.f, 2.f, 3.f, 4.f, 5.f, 6.f}; + info.weights_.HostVector() = {1.f, 1.f, 1.f, 1.f, 1.f, 1.f}; + + CheckObjFunction(obj, predt.HostVector(), labels, info.weights_.HostVector(), + {1.f, -1.f, 1.f, -1.f, 1.f, -1.f}, info.weights_.HostVector()); + + RegTree tree; + tree.ExpandNode(0, /*split_index=*/1, 2, true, 0.0f, 2.f, 3.f, 4.f, 2.f, 1.f, 1.f); + + HostDeviceVector position(labels.size(), 0); + auto& h_position = position.HostVector(); + for (size_t i = 0; i < labels.size(); ++i) { + if (i < labels.size() / 2) { + h_position[i] = 1; // left + } else { + h_position[i] = 2; // right + } + } + + auto& h_predt = predt.HostVector(); + for (size_t i = 0; i < h_predt.size(); ++i) { + h_predt[i] = labels[i] + i; + } + + obj->UpdateTreeLeaf(position, info, predt, &tree); + ASSERT_EQ(tree[1].LeafValue(), -1); + ASSERT_EQ(tree[2].LeafValue(), -4); +} + +TEST(Objective, DeclareUnifiedTest(AbsoluteErrorLeaf)) { + Context ctx = CreateEmptyGenericParam(GPUIDX); + std::unique_ptr obj{ObjFunction::Create("reg:absoluteerror", &ctx)}; + obj->Configure({}); + + MetaInfo info; + info.labels.Reshape(16, 1); + info.num_row_ = info.labels.Size(); + CHECK_EQ(info.num_row_, 16); + auto h_labels = info.labels.HostView().Values(); + std::iota(h_labels.begin(), h_labels.end(), 0); + HostDeviceVector predt(h_labels.size()); + auto& h_predt = predt.HostVector(); + for (size_t i = 0; i < h_predt.size(); ++i) { + h_predt[i] = h_labels[i] + i; + } + + HostDeviceVector position(info.labels.Size(), 0); + auto& h_position = position.HostVector(); + for (int32_t i = 0; i < 3; ++i) { + h_position[i] = ~i; // negation for sampled nodes. 
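For the `AbsoluteError` test above, the expected leaf values follow directly from the median: `h_predt[i] = labels[i] + i`, so the residual `label - predt` is simply `-i`. Rows 0..2 map to node 1 (residuals {0, -1, -2}, median -1) and rows 3..5 to node 2 (residuals {-3, -4, -5}, median -4), matching the two `ASSERT_EQ` calls. A self-contained check of that arithmetic:

```cpp
#include <cassert>

// Worked check of the numbers asserted above: predt[i] = label[i] + i,
// hence residual[i] = label[i] - predt[i] = -i for every row.
int main() {
  float label[] = {0.f, 3.f, 2.f, 5.f, 4.f, 7.f};
  float residual[6];
  for (int i = 0; i < 6; ++i) {
    residual[i] = label[i] - (label[i] + static_cast<float>(i));  // == -i
  }
  assert(residual[1] == -1.f);  // median of {0, -1, -2}, left leaf
  assert(residual[4] == -4.f);  // median of {-3, -4, -5}, right leaf
  return 0;
}
```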
+ } + for (size_t i = 3; i < 8; ++i) { + h_position[i] = 3; + } + // empty leaf for node 4 + for (size_t i = 8; i < 13; ++i) { + h_position[i] = 5; + } + for (size_t i = 13; i < h_labels.size(); ++i) { + h_position[i] = 6; + } + + RegTree tree; + tree.ExpandNode(0, /*split_index=*/1, 2, true, 0.0f, 2.f, 3.f, 4.f, 2.f, 1.f, 1.f); + tree.ExpandNode(1, /*split_index=*/1, 2, true, 0.0f, 2.f, 3.f, 4.f, 2.f, 1.f, 1.f); + tree.ExpandNode(2, /*split_index=*/1, 2, true, 0.0f, 2.f, 3.f, 4.f, 2.f, 1.f, 1.f); + ASSERT_EQ(tree.GetNumLeaves(), 4); + + auto empty_leaf = tree[4].LeafValue(); + obj->UpdateTreeLeaf(position, info, predt, &tree); + ASSERT_EQ(tree[3].LeafValue(), -5); + ASSERT_EQ(tree[4].LeafValue(), empty_leaf); + ASSERT_EQ(tree[5].LeafValue(), -10); + ASSERT_EQ(tree[6].LeafValue(), -14); +} + +TEST(Adaptive, DeclareUnifiedTest(MissingLeaf)) { + std::vector missing{1, 3}; + + std::vector h_nidx = {2, 4, 5}; + std::vector h_nptr = {0, 4, 8, 16}; + + obj::detail::FillMissingLeaf(missing, &h_nidx, &h_nptr); + + ASSERT_EQ(h_nidx[0], missing[0]); + ASSERT_EQ(h_nidx[2], missing[1]); + ASSERT_EQ(h_nidx[1], 2); + ASSERT_EQ(h_nidx[3], 4); + ASSERT_EQ(h_nidx[4], 5); + + ASSERT_EQ(h_nptr[0], 0); + ASSERT_EQ(h_nptr[1], 0); // empty + ASSERT_EQ(h_nptr[2], 4); + ASSERT_EQ(h_nptr[3], 4); // empty + ASSERT_EQ(h_nptr[4], 8); + ASSERT_EQ(h_nptr[5], 16); +} } // namespace xgboost diff --git a/tests/cpp/predictor/test_cpu_predictor.cc b/tests/cpp/predictor/test_cpu_predictor.cc index 1a466ed3ff10..f43747abdd9e 100644 --- a/tests/cpp/predictor/test_cpu_predictor.cc +++ b/tests/cpp/predictor/test_cpu_predictor.cc @@ -1,5 +1,5 @@ /*! - * Copyright 2017-2020 XGBoost contributors + * Copyright 2017-2022 XGBoost contributors */ #include #include @@ -222,7 +222,7 @@ void TestUpdatePredictionCache(bool use_subsampling) { PredictionCacheEntry predtion_cache; predtion_cache.predictions.Resize(kRows*kClasses, 0); // after one training iteration predtion_cache is filled with cached in QuantileHistMaker::Builder prediction values - gbm->DoBoost(dmat.get(), &gpair, &predtion_cache); + gbm->DoBoost(dmat.get(), &gpair, &predtion_cache, nullptr); PredictionCacheEntry out_predictions; // perform fair prediction on the same input data, should be equal to cached result diff --git a/tests/cpp/tree/gpu_hist/test_row_partitioner.cu b/tests/cpp/tree/gpu_hist/test_row_partitioner.cu index 9b16cca5362d..c8aaf82dcb3e 100644 --- a/tests/cpp/tree/gpu_hist/test_row_partitioner.cu +++ b/tests/cpp/tree/gpu_hist/test_row_partitioner.cu @@ -1,7 +1,8 @@ /*! 
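`detail::FillMissingLeaf` itself is not shown in this diff (it belongs with the adaptive helpers in src/objective/adaptive.h), but the `MissingLeaf` test above pins down its behaviour: leaves that received no local rows are inserted into the sorted node-index list with an empty `[nptr[i], nptr[i+1])` segment, so distributed quantile averaging sees every leaf. A sketch consistent with those assertions:

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

using bst_node_t = std::int32_t;

// Sketch consistent with the MissingLeaf test: insert each missing leaf into
// the sorted node list and give it an empty row segment by repeating the
// boundary. E.g. nidx {2, 4, 5} / nptr {0, 4, 8, 16} with missing {1, 3}
// becomes nidx {1, 2, 3, 4, 5} / nptr {0, 0, 4, 4, 8, 16}.
void FillMissingLeafSketch(std::vector<bst_node_t> const& missing,
                           std::vector<bst_node_t>* p_nidx,
                           std::vector<std::size_t>* p_nptr) {
  auto& nidx = *p_nidx;
  auto& nptr = *p_nptr;
  for (auto leaf : missing) {
    if (std::binary_search(nidx.cbegin(), nidx.cend(), leaf)) {
      continue;  // leaf already has local rows
    }
    auto it = std::upper_bound(nidx.cbegin(), nidx.cend(), leaf);
    auto pos = std::distance(nidx.cbegin(), it);
    nidx.insert(nidx.cbegin() + pos, leaf);
    nptr.insert(nptr.cbegin() + pos, nptr[pos]);  // empty segment
  }
}
```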
- * Copyright 2019-2021 by XGBoost Contributors + * Copyright 2019-2022 by XGBoost Contributors */ #include +#include #include #include @@ -10,6 +11,10 @@ #include "../../../../src/tree/gpu_hist/row_partitioner.cuh" #include "../../helpers.h" +#include "xgboost/base.h" +#include "xgboost/generic_parameters.h" +#include "xgboost/task.h" +#include "xgboost/tree_model.h" namespace xgboost { namespace tree { @@ -103,17 +108,58 @@ TEST(RowPartitioner, Basic) { TestUpdatePosition(); } void TestFinalise() { const int kNumRows = 10; - RowPartitioner rp(0, kNumRows); - rp.FinalisePosition([=]__device__(RowPartitioner::RowIndexT ridx, int position) - { - return 7; - }); - auto position = rp.GetPositionHost(); - for(auto p:position) + + ObjInfo task{ObjInfo::kRegression, false, false}; + HostDeviceVector position; + Context ctx; + ctx.gpu_id = 0; + { - EXPECT_EQ(p, 7); + RowPartitioner rp(0, kNumRows); + rp.FinalisePosition( + &ctx, task, &position, + [=] __device__(RowPartitioner::RowIndexT ridx, int position) { return 7; }, + [] XGBOOST_DEVICE(size_t idx) { return false; }); + + auto position = rp.GetPositionHost(); + for (auto p : position) { + EXPECT_EQ(p, 7); + } + } + + /** + * Test for sampling. + */ + dh::device_vector hess(kNumRows); + for (size_t i = 0; i < hess.size(); ++i) { + // removed rows, 0, 3, 6, 9 + if (i % 3 == 0) { + hess[i] = 0; + } else { + hess[i] = i; + } + } + + auto d_hess = dh::ToSpan(hess); + + RowPartitioner rp(0, kNumRows); + rp.FinalisePosition( + &ctx, task, &position, + [] __device__(RowPartitioner::RowIndexT ridx, bst_node_t position) { + return ridx % 2 == 0 ? 1 : 2; + }, + [d_hess] __device__(size_t ridx) { return d_hess[ridx] - 0.f == 0.f; }); + + auto const& h_position = position.ConstHostVector(); + for (size_t ridx = 0; ridx < h_position.size(); ++ridx) { + if (ridx % 3 == 0) { + ASSERT_LT(h_position[ridx], 0); + } else { + ASSERT_EQ(h_position[ridx], ridx % 2 == 0 ? 
1 : 2); + } } } + TEST(RowPartitioner, Finalise) { TestFinalise(); } void TestIncorrectRow() { diff --git a/tests/cpp/tree/test_approx.cc b/tests/cpp/tree/test_approx.cc index a37c0973627e..2e2fd4a0b3d7 100644 --- a/tests/cpp/tree/test_approx.cc +++ b/tests/cpp/tree/test_approx.cc @@ -26,7 +26,7 @@ TEST(Approx, Partitioner) { std::transform(grad.HostVector().cbegin(), grad.HostVector().cend(), hess.begin(), [](auto gpair) { return gpair.GetHess(); }); - for (auto const &page : Xy->GetBatches({64, hess, true})) { + for (auto const& page : Xy->GetBatches({64, hess, true})) { bst_feature_t const split_ind = 0; { auto min_value = page.cut.MinValues()[split_ind]; @@ -44,9 +44,9 @@ TEST(Approx, Partitioner) { float split_value = page.cut.Values().at(ptr / 2); RegTree tree; GetSplit(&tree, split_value, &candidates); - auto left_nidx = tree[RegTree::kRoot].LeftChild(); partitioner.UpdatePosition(&ctx, page, candidates, &tree); + auto left_nidx = tree[RegTree::kRoot].LeftChild(); auto elem = partitioner[left_nidx]; ASSERT_LT(elem.Size(), n_samples); ASSERT_GT(elem.Size(), 1); @@ -54,6 +54,7 @@ TEST(Approx, Partitioner) { auto value = page.cut.Values().at(page.index[*it]); ASSERT_LE(value, split_value); } + auto right_nidx = tree[RegTree::kRoot].RightChild(); elem = partitioner[right_nidx]; for (auto it = elem.begin; it != elem.end; ++it) { @@ -63,5 +64,78 @@ TEST(Approx, Partitioner) { } } } +namespace { +void TestLeafPartition(size_t n_samples) { + size_t const n_features = 2, base_rowid = 0; + common::RowSetCollection row_set; + ApproxRowPartitioner partitioner{n_samples, base_rowid}; + + auto Xy = RandomDataGenerator{n_samples, n_features, 0}.GenerateDMatrix(true); + GenericParameter ctx; + std::vector candidates{{0, 0, 0.4}}; + RegTree tree; + std::vector hess(n_samples, 0); + // emulate sampling + auto not_sampled = [](size_t i) { + size_t const kSampleFactor{3}; + return i % kSampleFactor != 0; + }; + size_t n{0}; + for (size_t i = 0; i < hess.size(); ++i) { + if (not_sampled(i)) { + hess[i] = 1.0f; + ++n; + } + } + + std::vector h_nptr; + float split_value{0}; + for (auto const& page : Xy->GetBatches({Context::kCpuId, 64})) { + bst_feature_t const split_ind = 0; + auto ptr = page.cut.Ptrs()[split_ind + 1]; + split_value = page.cut.Values().at(ptr / 2); + GetSplit(&tree, split_value, &candidates); + partitioner.UpdatePosition(&ctx, page, candidates, &tree); + std::vector position; + partitioner.LeafPartition(&ctx, tree, hess, &position); + std::sort(position.begin(), position.end()); + size_t beg = std::distance( + position.begin(), + std::find_if(position.begin(), position.end(), [&](bst_node_t nidx) { return nidx >= 0; })); + std::vector nptr; + common::RunLengthEncode(position.cbegin() + beg, position.cend(), &nptr); + std::transform(nptr.begin(), nptr.end(), nptr.begin(), [&](size_t x) { return x + beg; }); + auto n_uniques = std::unique(position.begin() + beg, position.end()) - (position.begin() + beg); + ASSERT_EQ(nptr.size(), n_uniques + 1); + ASSERT_EQ(nptr[0], beg); + ASSERT_EQ(nptr.back(), n_samples); + + h_nptr = nptr; + } + + if (h_nptr.front() == n_samples) { + return; + } + + ASSERT_GE(h_nptr.size(), 2); + + for (auto const& page : Xy->GetBatches()) { + auto batch = page.GetView(); + size_t left{0}; + for (size_t i = 0; i < batch.Size(); ++i) { + if (not_sampled(i) && batch[i].front().fvalue < split_value) { + left++; + } + } + ASSERT_EQ(left, h_nptr[1] - h_nptr[0]); // equal to number of sampled assigned to left + } +} +} // anonymous namespace + +TEST(Approx, LeafPartition) { 
+ for (auto n_samples : {0ul, 1ul, 128ul, 256ul}) { + TestLeafPartition(n_samples); + } +} } // namespace tree } // namespace xgboost diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu index 82f40465deb2..3c93c283917a 100644 --- a/tests/cpp/tree/test_gpu_hist.cu +++ b/tests/cpp/tree/test_gpu_hist.cu @@ -1,5 +1,5 @@ /*! - * Copyright 2017-2021 XGBoost contributors + * Copyright 2017-2022 XGBoost contributors */ #include #include @@ -13,6 +13,7 @@ #include "../helpers.h" #include "../histogram_helpers.h" +#include "xgboost/generic_parameters.h" #include "xgboost/json.h" #include "../../../src/data/sparse_page_source.h" #include "../../../src/tree/updater_gpu_hist.cu" @@ -22,7 +23,6 @@ namespace xgboost { namespace tree { - TEST(GpuHist, DeviceHistogram) { // Ensures that node allocates correctly after reaching `kStopGrowingSize`. dh::safe_cuda(cudaSetDevice(0)); @@ -81,8 +81,9 @@ void TestBuildHist(bool use_shared_memory_histograms) { param.Init(args); auto page = BuildEllpackPage(kNRows, kNCols); BatchParam batch_param{}; - GPUHistMakerDevice maker(0, page.get(), {}, kNRows, param, - kNCols, kNCols, batch_param); + Context ctx{CreateEmptyGenericParam(0)}; + GPUHistMakerDevice maker(&ctx, page.get(), {}, kNRows, param, kNCols, kNCols, + batch_param); xgboost::SimpleLCG gen; xgboost::SimpleRealUniformDistribution dist(0.0f, 1.0f); HostDeviceVector gpair(kNRows); @@ -158,14 +159,14 @@ TEST(GpuHist, ApplySplit) { BatchParam bparam; bparam.gpu_id = 0; bparam.max_bin = 3; + Context ctx{CreateEmptyGenericParam(0)}; for (auto& ellpack : m->GetBatches(bparam)){ auto impl = ellpack.Impl(); HostDeviceVector feature_types(10, FeatureType::kCategorical); feature_types.SetDevice(bparam.gpu_id); tree::GPUHistMakerDevice updater( - 0, impl, feature_types.ConstDeviceSpan(), n_rows, tparam, 0, n_cols, - bparam); + &ctx, impl, feature_types.ConstDeviceSpan(), n_rows, tparam, 0, n_cols, bparam); updater.ApplySplit(candidate, &tree); ASSERT_EQ(tree.GetSplitTypes().size(), 3); @@ -224,8 +225,9 @@ TEST(GpuHist, EvaluateRootSplit) { // Initialize GPUHistMakerDevice auto page = BuildEllpackPage(kNRows, kNCols); BatchParam batch_param{}; - GPUHistMakerDevice maker( - 0, page.get(), {}, kNRows, param, kNCols, kNCols, batch_param); + Context ctx{CreateEmptyGenericParam(0)}; + GPUHistMakerDevice maker(&ctx, page.get(), {}, kNRows, param, kNCols, kNCols, + batch_param); // Initialize GPUHistMakerDevice::node_sum_gradients maker.node_sum_gradients = {}; @@ -348,7 +350,8 @@ void UpdateTree(HostDeviceVector* gpair, DMatrix* dmat, GenericParameter generic_param(CreateEmptyGenericParam(0)); hist_maker.Configure(args, &generic_param); - hist_maker.Update(gpair, dmat, {tree}); + std::vector> position(1); + hist_maker.Update(gpair, dmat, common::Span>{position}, {tree}); auto cache = linalg::VectorView{preds->DeviceSpan(), {preds->Size()}, 0}; hist_maker.UpdatePredictionCache(dmat, cache); } @@ -483,7 +486,7 @@ TEST(GpuHist, ExternalMemoryWithSampling) { auto preds_h = preds.ConstHostVector(); auto preds_ext_h = preds_ext.ConstHostVector(); for (int i = 0; i < kRows; i++) { - EXPECT_NEAR(preds_h[i], preds_ext_h[i], 1e-3); + ASSERT_NEAR(preds_h[i], preds_ext_h[i], 1e-3); } } diff --git a/tests/cpp/tree/test_histmaker.cc b/tests/cpp/tree/test_histmaker.cc index 56878b159d4b..90dc0a411294 100644 --- a/tests/cpp/tree/test_histmaker.cc +++ b/tests/cpp/tree/test_histmaker.cc @@ -39,7 +39,8 @@ TEST(GrowHistMaker, InteractionConstraint) { updater->Configure(Args{ {"interaction_constraints", "[[0, 1]]"}, 
{"num_feature", std::to_string(kCols)}}); - updater->Update(&gradients, p_dmat.get(), {&tree}); + std::vector> position(1); + updater->Update(&gradients, p_dmat.get(), position, {&tree}); ASSERT_EQ(tree.NumExtraNodes(), 4); ASSERT_EQ(tree[0].SplitIndex(), 1); @@ -55,7 +56,8 @@ TEST(GrowHistMaker, InteractionConstraint) { std::unique_ptr updater{ TreeUpdater::Create("grow_histmaker", ¶m, ObjInfo{ObjInfo::kRegression})}; updater->Configure(Args{{"num_feature", std::to_string(kCols)}}); - updater->Update(&gradients, p_dmat.get(), {&tree}); + std::vector> position(1); + updater->Update(&gradients, p_dmat.get(), position, {&tree}); ASSERT_EQ(tree.NumExtraNodes(), 10); ASSERT_EQ(tree[0].SplitIndex(), 1); diff --git a/tests/cpp/tree/test_prediction_cache.cc b/tests/cpp/tree/test_prediction_cache.cc index ebe66cf575b3..3e30e0699358 100644 --- a/tests/cpp/tree/test_prediction_cache.cc +++ b/tests/cpp/tree/test_prediction_cache.cc @@ -77,7 +77,8 @@ class TestPredictionCache : public ::testing::Test { std::vector trees{&tree}; auto gpair = GenerateRandomGradients(n_samples_); updater->Configure(Args{{"max_bin", "64"}}); - updater->Update(&gpair, Xy_.get(), trees); + std::vector> position(1); + updater->Update(&gpair, Xy_.get(), position, trees); HostDeviceVector out_prediction_cached; out_prediction_cached.SetDevice(ctx.gpu_id); out_prediction_cached.Resize(n_samples_); diff --git a/tests/cpp/tree/test_prune.cc b/tests/cpp/tree/test_prune.cc index dc6a8da21d72..77f78b1399d9 100644 --- a/tests/cpp/tree/test_prune.cc +++ b/tests/cpp/tree/test_prune.cc @@ -43,22 +43,23 @@ TEST(Updater, Prune) { pruner->Configure(cfg); // loss_chg < min_split_loss; + std::vector> position(trees.size()); tree.ExpandNode(0, 0, 0, true, 0.0f, 0.3f, 0.4f, 0.0f, 0.0f, /*left_sum=*/0.0f, /*right_sum=*/0.0f); - pruner->Update(&gpair, p_dmat.get(), trees); + pruner->Update(&gpair, p_dmat.get(), position, trees); ASSERT_EQ(tree.NumExtraNodes(), 0); // loss_chg > min_split_loss; tree.ExpandNode(0, 0, 0, true, 0.0f, 0.3f, 0.4f, 11.0f, 0.0f, /*left_sum=*/0.0f, /*right_sum=*/0.0f); - pruner->Update(&gpair, p_dmat.get(), trees); + pruner->Update(&gpair, p_dmat.get(), position, trees); ASSERT_EQ(tree.NumExtraNodes(), 2); // loss_chg == min_split_loss; tree.Stat(0).loss_chg = 10; - pruner->Update(&gpair, p_dmat.get(), trees); + pruner->Update(&gpair, p_dmat.get(), position, trees); ASSERT_EQ(tree.NumExtraNodes(), 2); @@ -74,7 +75,7 @@ TEST(Updater, Prune) { /*left_sum=*/0.0f, /*right_sum=*/0.0f); cfg.emplace_back(std::make_pair("max_depth", "1")); pruner->Configure(cfg); - pruner->Update(&gpair, p_dmat.get(), trees); + pruner->Update(&gpair, p_dmat.get(), position, trees); ASSERT_EQ(tree.NumExtraNodes(), 2); @@ -84,7 +85,7 @@ TEST(Updater, Prune) { /*left_sum=*/0.0f, /*right_sum=*/0.0f); cfg.emplace_back(std::make_pair("min_split_loss", "0")); pruner->Configure(cfg); - pruner->Update(&gpair, p_dmat.get(), trees); + pruner->Update(&gpair, p_dmat.get(), position, trees); ASSERT_EQ(tree.NumExtraNodes(), 2); } } // namespace tree diff --git a/tests/cpp/tree/test_refresh.cc b/tests/cpp/tree/test_refresh.cc index 5b71f0841e19..f0abd0a871aa 100644 --- a/tests/cpp/tree/test_refresh.cc +++ b/tests/cpp/tree/test_refresh.cc @@ -44,7 +44,8 @@ TEST(Updater, Refresh) { tree.Stat(cright).base_weight = 1.3; refresher->Configure(cfg); - refresher->Update(&gpair, p_dmat.get(), trees); + std::vector> position; + refresher->Update(&gpair, p_dmat.get(), position, trees); bst_float constexpr kEps = 1e-6; ASSERT_NEAR(-0.183392, tree[cright].LeafValue(), 
diff --git a/tests/cpp/tree/test_tree_stat.cc b/tests/cpp/tree/test_tree_stat.cc
index 772420ce0f23..723ca34ebc93 100644
--- a/tests/cpp/tree/test_tree_stat.cc
+++ b/tests/cpp/tree/test_tree_stat.cc
@@ -27,7 +27,8 @@ class UpdaterTreeStatTest : public ::testing::Test {
     up->Configure(Args{});
     RegTree tree;
     tree.param.num_feature = kCols;
-    up->Update(&gpairs_, p_dmat_.get(), {&tree});
+    std::vector<HostDeviceVector<bst_node_t>> position(1);
+    up->Update(&gpairs_, p_dmat_.get(), position, {&tree});
 
     tree.WalkTree([&tree](bst_node_t nidx) {
       if (tree[nidx].IsLeaf()) {
@@ -87,13 +88,15 @@ class UpdaterEtaTest : public ::testing::Test {
     RegTree tree_0;
     {
       tree_0.param.num_feature = kCols;
-      up_0->Update(&gpairs_, p_dmat_.get(), {&tree_0});
+      std::vector<HostDeviceVector<bst_node_t>> position(1);
+      up_0->Update(&gpairs_, p_dmat_.get(), position, {&tree_0});
     }
 
     RegTree tree_1;
     {
       tree_1.param.num_feature = kCols;
-      up_1->Update(&gpairs_, p_dmat_.get(), {&tree_1});
+      std::vector<HostDeviceVector<bst_node_t>> position(1);
+      up_1->Update(&gpairs_, p_dmat_.get(), position, {&tree_1});
     }
     tree_0.WalkTree([&](bst_node_t nidx) {
       if (tree_0[nidx].IsLeaf()) {
@@ -149,7 +152,8 @@ class TestMinSplitLoss : public ::testing::Test {
     up->Configure(args);
 
     RegTree tree;
-    up->Update(&gpair_, dmat_.get(), {&tree});
+    std::vector<HostDeviceVector<bst_node_t>> position(1);
+    up->Update(&gpair_, dmat_.get(), position, {&tree});
     auto n_nodes = tree.NumExtraNodes();
     return n_nodes;
diff --git a/tests/python-gpu/test_gpu_prediction.py b/tests/python-gpu/test_gpu_prediction.py
index 38f4db07d366..4e41e637f7de 100644
--- a/tests/python-gpu/test_gpu_prediction.py
+++ b/tests/python-gpu/test_gpu_prediction.py
@@ -249,6 +249,8 @@ def predict_df(x):
                tm.dataset_strategy, shap_parameter_strategy)
     @settings(deadline=None, print_blob=True)
     def test_shap(self, num_rounds, dataset, param):
+        if dataset.name.endswith("-l1"):  # not supported by the exact tree method
+            return
         param.update({"predictor": "gpu_predictor", "gpu_id": 0})
         param = dataset.set_params(param)
         dmat = dataset.get_dmat()
@@ -263,6 +265,8 @@ def test_shap(self, num_rounds, dataset, param):
                tm.dataset_strategy, shap_parameter_strategy)
     @settings(deadline=None, max_examples=20, print_blob=True)
     def test_shap_interactions(self, num_rounds, dataset, param):
+        if dataset.name.endswith("-l1"):  # not supported by the exact tree method
+            return
         param.update({"predictor": "gpu_predictor", "gpu_id": 0})
         param = dataset.set_params(param)
         dmat = dataset.get_dmat()
diff --git a/tests/python-gpu/test_gpu_updaters.py b/tests/python-gpu/test_gpu_updaters.py
index a3427b566360..e9d2bf06e229 100644
--- a/tests/python-gpu/test_gpu_updaters.py
+++ b/tests/python-gpu/test_gpu_updaters.py
@@ -90,6 +90,8 @@ def test_gpu_hist_device_dmatrix(self, param, num_rounds, dataset):
                tm.dataset_strategy)
     @settings(deadline=None, print_blob=True)
     def test_external_memory(self, param, num_rounds, dataset):
+        if dataset.name.endswith("-l1"):
+            return
         # We cannot handle empty dataset yet
         assume(len(dataset.y) > 0)
         param['tree_method'] = 'gpu_hist'
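For reference, the new objective exercised by the `-l1` datasets is used like any other regression objective; only the `hist`-family tree methods support the post-construction leaf refresh, which is why the `exact` and SHAP tests skip those datasets. A minimal, self-contained example (synthetic data; parameter choices are illustrative, not prescribed by the change):

import numpy as np
import xgboost as xgb

rng = np.random.default_rng(0)
X = rng.normal(size=(256, 4))
y = X[:, 0] + rng.normal(scale=0.1, size=256)

dtrain = xgb.DMatrix(X, label=y)
booster = xgb.train(
    {"objective": "reg:absoluteerror", "tree_method": "hist", "eval_metric": "mae"},
    dtrain,
    num_boost_round=8,
    evals=[(dtrain, "train")],
)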
+ )["history"]["train"][dataset.metric] note(history) - assert tm.non_increasing(history["train"][dataset.metric]) + + # See note on `ObjFunction::UpdateTreeLeaf`. + update_leaf = dataset.name.endswith("-l1") + if update_leaf and len(history) == 2: + assert history[0] + 1e-2 >= history[-1] + return + if update_leaf and len(history) > 2: + assert history[0] >= history[-1] + return + else: + assert tm.non_increasing(history) @pytest.mark.skipif(**tm.no_cudf()) @@ -305,8 +315,7 @@ def test_dask_classifier( def test_empty_dmatrix(self, local_cuda_cluster: LocalCUDACluster) -> None: with Client(local_cuda_cluster) as client: - parameters = {'tree_method': 'gpu_hist', - 'debug_synchronize': True} + parameters = {'tree_method': 'gpu_hist', 'debug_synchronize': True} run_empty_dmatrix_reg(client, parameters) run_empty_dmatrix_cls(client, parameters) diff --git a/tests/python/test_updaters.py b/tests/python/test_updaters.py index cdf40d843b1a..4b56d37d4493 100644 --- a/tests/python/test_updaters.py +++ b/tests/python/test_updaters.py @@ -40,6 +40,8 @@ class TestTreeMethod: tm.dataset_strategy) @settings(deadline=None, print_blob=True) def test_exact(self, param, num_rounds, dataset): + if dataset.name.endswith("-l1"): + return param['tree_method'] = 'exact' param = dataset.set_params(param) result = train_result(param, dataset.get_dmat(), num_rounds) diff --git a/tests/python/test_with_dask.py b/tests/python/test_with_dask.py index 4e80409d4764..cbee3d72b254 100644 --- a/tests/python/test_with_dask.py +++ b/tests/python/test_with_dask.py @@ -35,6 +35,7 @@ import dask.array as da from xgboost.dask import DaskDMatrix +dask.config.set({"distributed.scheduler.allowed-failures": False}) if hasattr(HealthCheck, 'function_scoped_fixture'): suppress = [HealthCheck.function_scoped_fixture] @@ -666,7 +667,8 @@ def test_empty_dmatrix_training_continuation(client: "Client") -> None: def run_empty_dmatrix_reg(client: "Client", parameters: dict) -> None: def _check_outputs(out: xgb.dask.TrainReturnT, predictions: np.ndarray) -> None: assert isinstance(out['booster'], xgb.dask.Booster) - assert len(out['history']['validation']['rmse']) == 2 + for _, v in out['history']['validation'].items(): + assert len(v) == 2 assert isinstance(predictions, np.ndarray) assert predictions.shape[0] == 1 @@ -867,6 +869,8 @@ def test_empty_dmatrix(tree_method) -> None: parameters = {'tree_method': tree_method} run_empty_dmatrix_reg(client, parameters) run_empty_dmatrix_cls(client, parameters) + parameters = {'tree_method': tree_method, "objective": "reg:absoluteerror"} + run_empty_dmatrix_reg(client, parameters) async def run_from_dask_array_asyncio(scheduler_address: str) -> xgb.dask.TrainReturnT: @@ -1285,7 +1289,12 @@ def is_stump(): def minimum_bin(): return "max_bin" in params and params["max_bin"] == 2 - if minimum_bin() and is_stump(): + # See note on `ObjFunction::UpdateTreeLeaf`. 
diff --git a/tests/python/test_with_dask.py b/tests/python/test_with_dask.py
index 4e80409d4764..cbee3d72b254 100644
--- a/tests/python/test_with_dask.py
+++ b/tests/python/test_with_dask.py
@@ -35,6 +35,7 @@
 import dask.array as da
 from xgboost.dask import DaskDMatrix
 
+dask.config.set({"distributed.scheduler.allowed-failures": False})
 
 if hasattr(HealthCheck, 'function_scoped_fixture'):
     suppress = [HealthCheck.function_scoped_fixture]
@@ -666,7 +667,8 @@ def test_empty_dmatrix_training_continuation(client: "Client") -> None:
 def run_empty_dmatrix_reg(client: "Client", parameters: dict) -> None:
     def _check_outputs(out: xgb.dask.TrainReturnT, predictions: np.ndarray) -> None:
         assert isinstance(out['booster'], xgb.dask.Booster)
-        assert len(out['history']['validation']['rmse']) == 2
+        for _, v in out['history']['validation'].items():
+            assert len(v) == 2
         assert isinstance(predictions, np.ndarray)
         assert predictions.shape[0] == 1
@@ -867,6 +869,8 @@ def test_empty_dmatrix(tree_method) -> None:
         parameters = {'tree_method': tree_method}
         run_empty_dmatrix_reg(client, parameters)
         run_empty_dmatrix_cls(client, parameters)
+        parameters = {'tree_method': tree_method, "objective": "reg:absoluteerror"}
+        run_empty_dmatrix_reg(client, parameters)
@@ -1285,7 +1289,12 @@ def is_stump():
         def minimum_bin():
             return "max_bin" in params and params["max_bin"] == 2
 
-        if minimum_bin() and is_stump():
+        # See note on `ObjFunction::UpdateTreeLeaf`.
+        update_leaf = dataset.name.endswith("-l1")
+        if update_leaf and len(history) >= 2:
+            assert history[0] >= history[-1]
+            return
+        elif minimum_bin() and is_stump():
             assert tm.non_increasing(history, tolerance=1e-3)
         else:
             assert tm.non_increasing(history)
@@ -1305,7 +1314,7 @@ def test_hist(
         dataset=tm.dataset_strategy)
     @settings(deadline=None, suppress_health_check=suppress, print_blob=True)
     def test_approx(
-            self, client: "Client", params: Dict, dataset: tm.TestDataset
+        self, client: "Client", params: Dict, dataset: tm.TestDataset
     ) -> None:
         num_rounds = 30
         self.run_updater_test(client, params, num_rounds, dataset, 'approx')
diff --git a/tests/python/testing.py b/tests/python/testing.py
index 64417af42ab9..8633e4caa52d 100644
--- a/tests/python/testing.py
+++ b/tests/python/testing.py
@@ -327,6 +327,9 @@ def make_categorical(
     TestDataset(
         "calif_housing", get_california_housing, "reg:squarederror", "rmse"
     ),
+    TestDataset(
+        "calif_housing-l1", get_california_housing, "reg:absoluteerror", "mae"
+    ),
     TestDataset("digits", get_digits, "multi:softmax", "mlogloss"),
     TestDataset("cancer", get_cancer, "binary:logistic", "logloss"),
     TestDataset(
@@ -336,6 +339,7 @@
         "rmse",
     ),
     TestDataset("sparse", get_sparse, "reg:squarederror", "rmse"),
+    TestDataset("sparse-l1", get_sparse, "reg:absoluteerror", "mae"),
     TestDataset(
         "empty",
         lambda: (np.empty((0, 100)), np.empty(0)),