Skip to content

Commit

Permalink
[POC] Experimental support for l1 error. (#7812)
Browse files Browse the repository at this point in the history
Support adaptive trees, a feature supported by both sklearn and lightgbm. After a tree is constructed, its leaf values are recomputed based on the residuals between labels and predictions.

For L1 error, the optimal value is the median (50th percentile).

This is marked as experimental support for the following reasons:
- The leaf value is not well defined for distributed training, where a leaf might be empty on some local workers. For now, such a worker uses its original leaf value when computing the average with other workers, which might cause significant errors.
- Some follow-ups are required: support for exact and pruner, and optimization of the quantile function. We also need to calculate the initial estimation.
  • Loading branch information
trivialfis committed Apr 26, 2022
1 parent ad06172 commit fdf533f
Show file tree
Hide file tree
Showing 64 changed files with 1,724 additions and 333 deletions.
1 change: 1 addition & 0 deletions amalgamation/xgboost-all0.cc
Expand Up @@ -24,6 +24,7 @@
#include "../src/objective/rank_obj.cc"
#include "../src/objective/hinge.cc"
#include "../src/objective/aft_obj.cc"
#include "../src/objective/adaptive.cc"

// gbms
#include "../src/gbm/gbm.cc"
Expand Down
9 changes: 8 additions & 1 deletion doc/model.schema
Expand Up @@ -400,7 +400,6 @@
"reg_loss_param"
]
},

{
"type": "object",
"properties": {
Expand Down Expand Up @@ -433,6 +432,14 @@
"tweedie_regression_param"
]
},
{
"properties": {
"name": {
"const": "reg:absoluteerror"
}
},
"type": "object"
},
{
"type": "object",
"properties": {
Expand Down
1 change: 1 addition & 0 deletions doc/parameter.rst
Expand Up @@ -349,6 +349,7 @@ Specify the learning task and the corresponding learning objective. The objectiv
- ``reg:squaredlogerror``: regression with squared log loss :math:`\frac{1}{2}[log(pred + 1) - log(label + 1)]^2`. All input labels are required to be greater than -1. Also, see metric ``rmsle`` for possible issue with this objective.
- ``reg:logistic``: logistic regression.
- ``reg:pseudohubererror``: regression with Pseudo Huber loss, a twice differentiable alternative to absolute loss.
- ``reg:absoluteerror``: Regression with L1 error. When a tree model is used, the leaf values are refreshed after tree construction.
- ``binary:logistic``: logistic regression for binary classification, output probability
- ``binary:logitraw``: logistic regression for binary classification, output score before logistic transformation
- ``binary:hinge``: hinge loss for binary classification. This makes predictions of 0 or 1, rather than producing probabilities.
Expand Down
5 changes: 2 additions & 3 deletions include/xgboost/gbm.h
Expand Up @@ -90,9 +90,8 @@ class GradientBooster : public Model, public Configurable {
* \param prediction The output prediction cache entry that needs to be updated.
* the booster may change content of gpair
*/
virtual void DoBoost(DMatrix* p_fmat,
HostDeviceVector<GradientPair>* in_gpair,
PredictionCacheEntry*) = 0;
virtual void DoBoost(DMatrix* p_fmat, HostDeviceVector<GradientPair>* in_gpair,
PredictionCacheEntry*, ObjFunction const* obj) = 0;

/*!
* \brief generate predictions for given feature matrix
Expand Down
8 changes: 6 additions & 2 deletions include/xgboost/linalg.h
Expand Up @@ -670,9 +670,13 @@ class Tensor {
* See \ref TensorView for parameters of this constructor.
*/
template <typename I, int32_t D>
explicit Tensor(I const (&shape)[D], int32_t device) {
explicit Tensor(I const (&shape)[D], int32_t device)
: Tensor{common::Span<I const, D>{shape}, device} {}

template <typename I, size_t D>
explicit Tensor(common::Span<I const, D> shape, int32_t device) {
// No device unroll as this is a host only function.
std::copy(shape, shape + D, shape_);
std::copy(shape.data(), shape.data() + D, shape_);
for (auto i = D; i < kDim; ++i) {
shape_[i] = 1;
}
Expand Down
20 changes: 19 additions & 1 deletion include/xgboost/objective.h
@@ -1,5 +1,5 @@
/*!
* Copyright 2014-2019 by Contributors
* Copyright 2014-2022 by Contributors
* \file objective.h
* \brief interface of objective function used by xgboost.
* \author Tianqi Chen, Kailong Chen
Expand All @@ -22,6 +22,8 @@

namespace xgboost {

class RegTree;

/*! \brief interface of objective function */
class ObjFunction : public Configurable {
protected:
Expand Down Expand Up @@ -88,6 +90,22 @@ class ObjFunction : public Configurable {
return 1;
}

/**
* \brief Update the leaf values after a tree is built. Needed for objectives with 0
* hessian.
*
* Note that the leaf update is not well defined for distributed training as XGBoost
* computes only an average of quantiles between workers. This breaks when some leaves
* have no samples assigned in a local worker.
*
* \param position The leaf index for each row.
* \param info MetaInfo providing labels and weights.
* \param prediction Model prediction after transformation.
* \param p_tree Tree that needs to be updated.
*/
virtual void UpdateTreeLeaf(HostDeviceVector<bst_node_t> const& position, MetaInfo const& info,
HostDeviceVector<float> const& prediction, RegTree* p_tree) const {}  // default is a no-op: the tree is left unchanged unless a derived objective overrides this

/*!
* \brief Create an objective function according to name.
* \param tparam Generic parameters.
Expand Down
9 changes: 7 additions & 2 deletions include/xgboost/task.h
Expand Up @@ -33,13 +33,18 @@ struct ObjInfo {
} task;
// Does the objective have constant hessian value?
bool const_hess{false};
bool zero_hess{false};

explicit ObjInfo(Task t) : task{t} {}
ObjInfo(Task t, bool khess) : task{t}, const_hess{khess} {}
ObjInfo(Task t) : task{t} {} // NOLINT
ObjInfo(Task t, bool khess, bool zhess) : task{t}, const_hess{khess}, zero_hess(zhess) {}

XGBOOST_DEVICE bool UseOneHot() const {
  // True for every task other than kRegression and kBinary.
  bool const is_regression_or_binary =
      (task == ObjInfo::kRegression) || (task == ObjInfo::kBinary);
  return !is_regression_or_binary;
}
/**
* \brief Use adaptive tree if the objective doesn't have valid hessian value.
*/
XGBOOST_DEVICE bool UpdateTreeLeaf() const { return zero_hess; }
};
} // namespace xgboost
#endif // XGBOOST_TASK_H_
15 changes: 11 additions & 4 deletions include/xgboost/tree_updater.h
Expand Up @@ -49,18 +49,25 @@ class TreeUpdater : public Configurable {
* existing trees.
*/
virtual bool CanModifyTree() const { return false; }
/*!
* \brief Whether the out_position in `Update` is valid. This determines whether adaptive
* tree can be used.
*/
virtual bool HasNodePosition() const { return false; }
/*!
* \brief perform update to the tree models
* \param gpair the gradient pair statistics of the data
* \param data The data matrix passed to the updater.
* \param trees references the trees to be updated, updater will change the content of trees
* \param out_position The leaf index for each row. The index is negated if that row is
* removed during sampling. So the 3rd node is ~3.
* \param out_trees references the trees to be updated, updater will change the content of trees
* note: all the trees in the vector are updated, with the same statistics,
* but maybe different random seeds, usually one tree is passed in at a time,
* there can be multiple trees when we train random forest style model
*/
virtual void Update(HostDeviceVector<GradientPair>* gpair,
DMatrix* data,
const std::vector<RegTree*>& trees) = 0;
virtual void Update(HostDeviceVector<GradientPair>* gpair, DMatrix* data,
common::Span<HostDeviceVector<bst_node_t>> out_position,
const std::vector<RegTree*>& out_trees) = 0;

/*!
* \brief determines whether updater has enough knowledge about a given dataset
Expand Down
10 changes: 3 additions & 7 deletions plugin/example/custom_obj.cc
@@ -1,5 +1,5 @@
/*!
* Copyright 2015-2019 by Contributors
* Copyright 2015-2022 by Contributors
* \file custom_metric.cc
* \brief This is an example to define plugin of xgboost.
* This plugin defines the additional metric function.
Expand Down Expand Up @@ -31,13 +31,9 @@ DMLC_REGISTER_PARAMETER(MyLogisticParam);
// Implement the interface.
class MyLogistic : public ObjFunction {
public:
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
param_.UpdateAllowUnknown(args);
}
void Configure(const Args& args) override { param_.UpdateAllowUnknown(args); }

struct ObjInfo Task() const override {
return {ObjInfo::kRegression, false};
}
ObjInfo Task() const override { return ObjInfo::kRegression; }

void GetGradient(const HostDeviceVector<bst_float> &preds,
const MetaInfo &info,
Expand Down
99 changes: 93 additions & 6 deletions src/common/common.h
@@ -1,5 +1,5 @@
/*!
* Copyright 2015-2018 by Contributors
* Copyright 2015-2022 by XGBoost Contributors
* \file common.h
* \brief Common utilities
*/
Expand All @@ -14,12 +14,12 @@
#include <exception>
#include <functional>
#include <limits>
#include <type_traits>
#include <vector>
#include <string>
#include <sstream>
#include <numeric>
#include <sstream>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#if defined(__CUDACC__)
#include <thrust/system/cuda/error.h>
Expand Down Expand Up @@ -164,6 +164,67 @@ class Range {
Iterator end_;
};

/**
 * \brief Transform iterator that takes an index and calls transform operator.
 *
 * This is CPU-only right now as taking host device function as operator complicates the
 * code. For device side one can use `thrust::transform_iterator` instead.
 *
 * The iterator only stores the current index and the operator. Dereferencing invokes the
 * operator with the current index and returns the result by value.
 *
 * \tparam Fn Callable taking a `size_t` index. It must be invocable through a const
 *            object, since dereferencing is a const operation.
 */
template <typename Fn>
class IndexTransformIter {
  size_t iter_{0};
  Fn fn_;

 public:
  using iterator_category = std::random_access_iterator_tag;  // NOLINT
  using value_type = std::result_of_t<Fn(size_t)>;            // NOLINT
  using difference_type = detail::ptrdiff_t;                  // NOLINT
  using reference = std::add_lvalue_reference_t<value_type>;  // NOLINT
  using pointer = std::add_pointer_t<value_type>;             // NOLINT

 public:
  /**
   * \param op Transform operator, takes a size_t index as input.
   */
  // Forward instead of copy so that an rvalue operator is moved into the iterator.
  explicit IndexTransformIter(Fn &&op) : fn_{std::forward<Fn>(op)} {}
  IndexTransformIter(IndexTransformIter const &) = default;

  /** \brief Result of the operator applied to the current index. */
  value_type operator*() const { return fn_(iter_); }
  /** \brief Result of the operator applied to the index `n` positions away. */
  value_type operator[](difference_type n) const { return fn_(iter_ + n); }

  /** \brief Distance between two iterators, computed from their indices. */
  auto operator-(IndexTransformIter const &that) const { return iter_ - that.iter_; }

  // Comparisons look only at the index; both iterators are expected to wrap the same
  // operator. They are needed by any loop or algorithm that compares iterators.
  bool operator==(IndexTransformIter const &that) const { return iter_ == that.iter_; }
  bool operator!=(IndexTransformIter const &that) const { return !(*this == that); }
  bool operator<(IndexTransformIter const &that) const { return iter_ < that.iter_; }
  bool operator>(IndexTransformIter const &that) const { return that < *this; }
  bool operator<=(IndexTransformIter const &that) const { return !(that < *this); }
  bool operator>=(IndexTransformIter const &that) const { return !(*this < that); }

  IndexTransformIter &operator++() {
    iter_++;
    return *this;
  }
  IndexTransformIter operator++(int) {
    auto ret = *this;
    ++(*this);
    return ret;
  }
  IndexTransformIter &operator+=(difference_type n) {
    // A negative n wraps around in unsigned arithmetic, which moves the index backwards.
    iter_ += n;
    return *this;
  }
  IndexTransformIter &operator-=(difference_type n) {
    (*this) += -n;
    return *this;
  }
  IndexTransformIter operator+(difference_type n) const {
    auto ret = *this;
    return ret += n;
  }
  IndexTransformIter operator-(difference_type n) const {
    auto ret = *this;
    return ret -= n;
  }
};

template <typename Fn>
auto MakeIndexTransformIter(Fn&& fn) {
return IndexTransformIter<Fn>(std::forward<Fn>(fn));
}

int AllVisibleGPUs();

inline void AssertGPUSupport() {
Expand Down Expand Up @@ -191,13 +252,39 @@ std::vector<Idx> ArgSort(Container const &array, Comp comp = std::less<V>{}) {

struct OptionalWeights {
Span<float const> weights;
float dft{1.0f};
float dft{1.0f}; // fixme: make this compile time constant

explicit OptionalWeights(Span<float const> w) : weights{w} {}
explicit OptionalWeights(float w) : dft{w} {}

XGBOOST_DEVICE float operator[](size_t i) const {
  // Without per-element weights every index maps to the default weight.
  if (weights.empty()) {
    return dft;
  }
  return weights[i];
}
};

/**
 * \brief Index of the last element of a group in a CSR style of index pointer.
 *
 * \param group  Group number.
 * \param indptr Index pointer; the entry at `group + 1` marks the end of `group`.
 *
 * \return The entry at `group + 1` minus one.
 */
template <typename Indexable>
XGBOOST_DEVICE size_t LastOf(size_t group, Indexable const &indptr) {
  auto const next_group_begin = indptr[group + 1];
  return next_group_begin - 1;
}

/**
 * \brief Run length encode on CPU, input must be sorted.
 *
 * Writes offsets into `p_out`: the first entry is 0, every following entry is the index
 * at which the value differs from its predecessor, and the number of elements is
 * appended last. Previous content of `p_out` is discarded. An empty input yields `{0}`.
 *
 * \param begin Random access iterator to the first element.
 * \param end   Iterator past the last element.
 * \param p_out Output vector receiving the offsets.
 */
template <typename Iter, typename Idx>
void RunLengthEncode(Iter begin, Iter end, std::vector<Idx> *p_out) {
  std::vector<Idx> &offsets = *p_out;
  offsets.clear();
  offsets.push_back(Idx{0});

  size_t const n_elements = std::distance(begin, end);
  size_t cursor = 1;
  while (cursor < n_elements) {
    // A change of value marks the beginning of a new run.
    if (begin[cursor] != begin[cursor - 1]) {
      offsets.push_back(cursor);
    }
    ++cursor;
  }

  // Close the last run, an empty input already ends with 0 == n_elements.
  if (offsets.back() != n_elements) {
    offsets.push_back(n_elements);
  }
}
} // namespace common
} // namespace xgboost
#endif // XGBOOST_COMMON_COMMON_H_
40 changes: 39 additions & 1 deletion src/common/device_helpers.cuh
@@ -1,5 +1,5 @@
/*!
* Copyright 2017-2021 XGBoost contributors
* Copyright 2017-2022 XGBoost contributors
*/
#pragma once
#include <thrust/device_ptr.h>
Expand Down Expand Up @@ -1537,6 +1537,43 @@ void SegmentedArgSort(xgboost::common::Span<U> values,
sorted_idx.size_bytes(), cudaMemcpyDeviceToDevice));
}

/**
* \brief Different from the above one, this one can handle cases where segment doesn't
* start from 0, but as a result it uses comparison sort.
*
* \param seg_begin    Beginning of the segment offsets. It is dereferenced inside the
*                     device lambda, so it presumably must be device accessible - confirm.
* \param seg_end      End of the segment offsets.
* \param val_begin    Beginning of the values used as the secondary sort key.
* \param val_end      End of the values.
* \param p_sorted_idx Output, resized to the number of values and reordered together
*                     with the sort keys.
*/
template <typename SegIt, typename ValIt>
void SegmentedArgSort(SegIt seg_begin, SegIt seg_end, ValIt val_begin, ValIt val_end,
dh::device_vector<size_t> *p_sorted_idx) {
// Sort key: (segment index, value). The value is stored as float.
using Tup = thrust::tuple<int32_t, float>;
auto &sorted_idx = *p_sorted_idx;
size_t n = std::distance(val_begin, val_end);
sorted_idx.resize(n);
// presumably fills sorted_idx with 0, 1, ..., n - 1 - confirm against dh::Iota
dh::Iota(dh::ToSpan(sorted_idx));
dh::device_vector<Tup> keys(sorted_idx.size());
// Lazily computes the key of the i-th element.
auto key_it = dh::MakeTransformIterator<Tup>(thrust::make_counting_iterator(0ul),
[=] XGBOOST_DEVICE(size_t i) -> Tup {
int32_t leaf_idx;
// Elements located before the first segment offset belong to no segment and get
// index -1.
if (i < *seg_begin) {
leaf_idx = -1;
} else {
leaf_idx = dh::SegmentId(seg_begin, seg_end, i);
}
auto residue = val_begin[i];
return thrust::make_tuple(leaf_idx, residue);
});
dh::XGBCachingDeviceAllocator<char> caching;
// Materialize the keys into `keys` so that they can be sorted in place.
thrust::copy(thrust::cuda::par(caching), key_it, key_it + keys.size(), keys.begin());

dh::XGBDeviceAllocator<char> alloc;
// Sort the keys and apply the same reordering to sorted_idx: ordered by segment index
// first, ties are broken by value.
thrust::stable_sort_by_key(thrust::cuda::par(alloc), keys.begin(), keys.end(), sorted_idx.begin(),
[=] XGBOOST_DEVICE(Tup const &l, Tup const &r) {
if (thrust::get<0>(l) != thrust::get<0>(r)) {
return thrust::get<0>(l) < thrust::get<0>(r); // segment index
}
return thrust::get<1>(l) < thrust::get<1>(r); // residue
});
}

class CUDAStreamView;

class CUDAEvent {
Expand Down Expand Up @@ -1600,5 +1637,6 @@ class CUDAStream {
}

CUDAStreamView View() const { return CUDAStreamView{stream_}; }
void Sync() { this->View().Sync(); }
};
} // namespace dh
3 changes: 2 additions & 1 deletion src/common/linalg_op.cuh
Expand Up @@ -13,6 +13,7 @@ namespace xgboost {
namespace linalg {
template <typename T, int32_t D, typename Fn>
void ElementWiseKernelDevice(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_t s = nullptr) {
dh::safe_cuda(cudaSetDevice(t.DeviceIdx()));
static_assert(std::is_void<std::result_of_t<Fn(size_t, T&)>>::value,
"For function with return, use transform instead.");
if (t.Contiguous()) {
Expand Down Expand Up @@ -40,7 +41,7 @@ void ElementWiseTransformDevice(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_
}

template <typename T, int32_t D, typename Fn>
void ElementWiseKernel(GenericParameter const* ctx, linalg::TensorView<T, D> t, Fn&& fn) {
void ElementWiseKernel(Context const* ctx, linalg::TensorView<T, D> t, Fn&& fn) {
ctx->IsCPU() ? ElementWiseKernelHost(t, ctx->Threads(), fn) : ElementWiseKernelDevice(t, fn);
}
} // namespace linalg
Expand Down

0 comments on commit fdf533f

Please sign in to comment.