Calculate base_score based on input labels for mae. (#8107)
Fit an intercept as base score for abs loss.
trivialfis committed Sep 20, 2022
1 parent 4f42aa5 commit fffb1fc
Showing 42 changed files with 999 additions and 343 deletions.
13 changes: 7 additions & 6 deletions amalgamation/xgboost-all0.cc
@@ -75,19 +75,20 @@
#include "../src/collective/communicator.cc"

// common
#include "../src/common/common.cc"
#include "../src/common/column_matrix.cc"
#include "../src/common/random.cc"
#include "../src/common/charconv.cc"
#include "../src/common/timer.cc"
#include "../src/common/quantile.cc"
#include "../src/common/host_device_vector.cc"
#include "../src/common/column_matrix.cc"
#include "../src/common/common.cc"
#include "../src/common/hist_util.cc"
#include "../src/common/host_device_vector.cc"
#include "../src/common/io.cc"
#include "../src/common/json.cc"
#include "../src/common/numeric.cc"
#include "../src/common/pseudo_huber.cc"
#include "../src/common/quantile.cc"
#include "../src/common/random.cc"
#include "../src/common/survival_util.cc"
#include "../src/common/threading_utils.cc"
#include "../src/common/timer.cc"
#include "../src/common/version.cc"

// c_api
30 changes: 23 additions & 7 deletions include/xgboost/learner.h
@@ -8,10 +8,9 @@
#ifndef XGBOOST_LEARNER_H_
#define XGBOOST_LEARNER_H_

#include <dmlc/any.h>
#include <xgboost/base.h>
#include <xgboost/feature_map.h>
#include <xgboost/generic_parameters.h>
#include <xgboost/generic_parameters.h> // Context
#include <xgboost/host_device_vector.h>
#include <xgboost/model.h>
#include <xgboost/predictor.h>
@@ -274,7 +273,7 @@ class Learner : public Model, public Configurable, public dmlc::Serializable {
/**
* \brief Return the context object of this Booster.
*/
virtual GenericParameter const* Ctx() const = 0;
virtual Context const* Ctx() const = 0;
/*!
* \brief Get configuration arguments currently stored by the learner
* \return Key-value pairs representing configuration arguments
@@ -289,7 +288,7 @@ class Learner : public Model, public Configurable, public dmlc::Serializable {
/*! \brief The evaluation metrics used to evaluate the model. */
std::vector<std::unique_ptr<Metric> > metrics_;
/*! \brief Training parameter. */
GenericParameter generic_parameters_;
Context ctx_;
};

struct LearnerModelParamLegacy;
@@ -298,8 +297,14 @@ struct LearnerModelParamLegacy;
* \brief Basic Model Parameters, used to describe the booster.
*/
struct LearnerModelParam {
/* \brief global bias */
bst_float base_score { 0.5f };
private:
/**
* \brief Global bias. This is just a scalar value for now, but it can be extended
* to a vector when we support multi-class and multi-target.
*/
linalg::Tensor<float, 1> base_score_;

public:
/* \brief number of features */
uint32_t num_feature { 0 };
/* \brief number of classes, if it is multi-class classification */
@@ -310,7 +315,18 @@ struct LearnerModelParam {
LearnerModelParam() = default;
// As the old `LearnerModelParamLegacy` is still used by binary IO, we keep
// this one as an immutable copy.
LearnerModelParam(LearnerModelParamLegacy const& user_param, float base_margin, ObjInfo t);
LearnerModelParam(Context const* ctx, LearnerModelParamLegacy const& user_param,
linalg::Tensor<float, 1> base_margin, ObjInfo t);
LearnerModelParam(LearnerModelParamLegacy const& user_param, ObjInfo t);
LearnerModelParam(bst_feature_t n_features, linalg::Tensor<float, 1> base_margin,
uint32_t n_groups)
: base_score_{std::move(base_margin)}, num_feature{n_features}, num_output_group{n_groups} {}

linalg::TensorView<float const, 1> BaseScore(Context const* ctx) const;
linalg::TensorView<float const, 1> BaseScore(int32_t device) const;

void Copy(LearnerModelParam const& that);

/* \brief Whether this parameter is initialized with LearnerModelParamLegacy. */
bool Initialized() const { return num_feature != 0; }
};
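As a usage sketch (not part of this diff), the reworked parameter can be constructed with a rank-1 base-score tensor and queried through the typed view; the feature and group counts below are placeholders:

// Hypothetical usage sketch: build a LearnerModelParam with a rank-1
// base-score tensor and read it back through the typed view.
#include <utility>  // std::move

#include <xgboost/generic_parameters.h>  // Context
#include <xgboost/learner.h>
#include <xgboost/linalg.h>

void BaseScoreDemo() {
  xgboost::Context ctx;  // defaults to the CPU
  // base_score is now a rank-1 tensor rather than a single float, leaving
  // room for vector-valued intercepts (multi-class / multi-target).
  xgboost::linalg::Tensor<float, 1> base_margin{{0.5f}, {1}};
  xgboost::LearnerModelParam param{/*n_features=*/10, std::move(base_margin),
                                   /*n_groups=*/1};
  auto score = param.BaseScore(&ctx);  // linalg::TensorView<float const, 1>
  float intercept = score(0);          // 0.5f
  (void)intercept;
}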
56 changes: 48 additions & 8 deletions include/xgboost/linalg.h
@@ -8,6 +8,7 @@

#include <dmlc/endian.h>
#include <xgboost/base.h>
#include <xgboost/generic_parameters.h>
#include <xgboost/host_device_vector.h>
#include <xgboost/json.h>
#include <xgboost/span.h>
@@ -16,6 +17,7 @@
#include <cassert>
#include <limits>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>
@@ -213,6 +215,22 @@ LINALG_HD decltype(auto) constexpr Apply(Fn &&f, Tup &&t) {
constexpr auto kSize = std::tuple_size<Tup>::value;
return Apply(std::forward<Fn>(f), std::forward<Tup>(t), std::make_index_sequence<kSize>{});
}

/**
* C++17 std::conjunction
*/
template <class...>
struct Conjunction : std::true_type {};
template <class B1>
struct Conjunction<B1> : B1 {};
template <class B1, class... Bn>
struct Conjunction<B1, Bn...> : std::conditional_t<bool(B1::value), Conjunction<Bn...>, B1> {};

template <typename... Index>
using IsAllIntegral = Conjunction<std::is_integral<std::remove_reference_t<Index>>...>;

template <typename... Index>
using EnableIfIntegral = std::enable_if_t<IsAllIntegral<Index...>::value>;
} // namespace detail

/**
@@ -406,7 +424,7 @@ class TensorView {
*
* \endcode
*/
template <typename... Index>
template <typename... Index, detail::EnableIfIntegral<Index...> * = nullptr>
LINALG_HD T &operator()(Index &&...index) {
static_assert(sizeof...(index) <= kDim, "Invalid index.");
size_t offset = detail::Offset<0ul>(stride_, 0ul, std::forward<Index>(index)...);
@@ -416,7 +434,7 @@
/**
* \brief Index the tensor to obtain a scalar value.
*/
template <typename... Index>
template <typename... Index, detail::EnableIfIntegral<Index...> * = nullptr>
LINALG_HD T const &operator()(Index &&...index) const {
static_assert(sizeof...(index) <= kDim, "Invalid index.");
size_t offset = detail::Offset<0ul>(stride_, 0ul, std::forward<Index>(index)...);
@@ -656,7 +674,7 @@ class Tensor {
}
if (device >= 0) {
data_.SetDevice(device);
data_.DevicePointer(); // Pull to device;
data_.ConstDevicePointer(); // Pull to device;
}
CHECK_EQ(data_.Size(), detail::CalcSize(shape_));
}
@@ -702,12 +720,29 @@
}

template <typename I, int32_t D>
explicit Tensor(std::initializer_list<T> data, I const (&shape)[D], int32_t device) {
explicit Tensor(std::initializer_list<T> data, I const (&shape)[D],
int32_t device = Context::kCpuId) {
auto &h_vec = data_.HostVector();
h_vec = data;
// shape
this->Initialize(shape, device);
}
/**
* \brief Index operator. Not thread-safe; should not be used in performance-critical
* regions. For more efficient indexing, consider getting a view first.
*/
template <typename... Index>
T &operator()(Index &&...idx) {
return this->HostView()(std::forward<Index>(idx)...);
}
/**
* \brief Index operator. Not thread-safe; should not be used in performance-critical
* regions. For more efficient indexing, consider getting a view first.
*/
template <typename... Index>
T const &operator()(Index &&...idx) const {
return this->HostView()(std::forward<Index>(idx)...);
}

/**
* \brief Get a \ref TensorView for this tensor.
@@ -761,7 +796,7 @@ class Tensor {
*
* If the total size is changed, then data in this tensor is no longer valid.
*/
template <typename... S>
template <typename... S, detail::EnableIfIntegral<S...> * = nullptr>
void Reshape(S &&...s) {
static_assert(sizeof...(S) <= kDim, "Invalid shape.");
detail::ReshapeImpl<0>(shape_, std::forward<S>(s)...);
@@ -777,15 +812,20 @@
*
* If the total size is changed, then data in this tensor is no longer valid.
*/
template <int32_t D>
void Reshape(size_t (&shape)[D]) {
template <size_t D>
void Reshape(common::Span<size_t const, D> shape) {
static_assert(D <= kDim, "Invalid shape.");
std::copy(shape, shape + D, this->shape_);
std::copy(shape.data(), shape.data() + D, this->shape_);
std::fill(shape_ + D, shape_ + kDim, 1);
auto n = detail::CalcSize(shape_);
data_.Resize(n);
}

template <size_t D>
void Reshape(size_t (&shape)[D]) {
this->Reshape(common::Span<size_t const, D>{shape});
}

/**
* \brief Set device ordinal for this tensor.
*/
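A rough sketch of the Tensor additions in this file, with placeholder values and CPU-only execution assumed:

// Sketch: the new convenience Tensor::operator(), the integral-index guard
// on TensorView::operator(), and Reshape from a fixed-size array.
#include <cstddef>

#include <xgboost/linalg.h>

void TensorDemo() {
  // 2x3 tensor; the device argument is now defaulted to Context::kCpuId.
  xgboost::linalg::Tensor<float, 2> t{{1.f, 2.f, 3.f, 4.f, 5.f, 6.f}, {2, 3}};
  // Tensor::operator() indexes through HostView(), so per the comments above
  // it belongs in tests and cold paths, not inner loops.
  t(1, 2) = 42.f;
  float v = t(0, 1);  // 2.f
  // detail::EnableIfIntegral<Index...> removes operator() from overload
  // resolution unless every index is an integral type.
  std::size_t shape[2] = {3, 2};
  t.Reshape(shape);  // the array overload now forwards to the span overload
  (void)v;
}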
12 changes: 11 additions & 1 deletion include/xgboost/objective.h
@@ -27,7 +27,10 @@ class RegTree;
/*! \brief interface of objective function */
class ObjFunction : public Configurable {
protected:
GenericParameter const* ctx_;
Context const* ctx_;

public:
static constexpr float DefaultBaseScore() { return 0.5f; }

public:
/*! \brief virtual destructor */
@@ -75,6 +78,13 @@ class ObjFunction : public Configurable {
virtual bst_float ProbToMargin(bst_float base_score) const {
return base_score;
}
/**
* \brief Make an initial estimation of the prediction.
*
* \param info       MetaInfo that contains the labels.
* \param base_score Output estimation.
*/
virtual void InitEstimation(MetaInfo const& info, linalg::Tensor<float, 1>* base_score) const;
/*!
* \brief Return task of this objective.
*/
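The purpose of InitEstimation is to let an objective derive base_score from the training labels instead of the fixed DefaultBaseScore() of 0.5f. As a hypothetical sketch only (the commit's actual implementation differs, and MetaInfo::labels being a rank-2 tensor is assumed), an override for absolute loss could take the label median, which minimizes mean absolute error:

// Hypothetical sketch, not the commit's implementation: derive base_score
// for absolute loss from the labels. The median minimizes mean absolute
// error, hence a label-derived intercept beats a fixed 0.5f here.
#include <algorithm>  // std::nth_element
#include <cstddef>
#include <vector>

#include <xgboost/data.h>    // MetaInfo
#include <xgboost/linalg.h>  // Tensor

void InitEstimationSketch(xgboost::MetaInfo const& info,
                          xgboost::linalg::Tensor<float, 1>* base_score) {
  auto labels = info.labels.HostView();  // assumed shape: (n_samples, n_targets)
  std::vector<float> values(labels.Shape(0));
  for (std::size_t i = 0; i < values.size(); ++i) {
    values[i] = labels(i, 0);  // first target only, for simplicity
  }
  auto mid = values.begin() + values.size() / 2;
  std::nth_element(values.begin(), mid, values.end());
  base_score->Reshape(1);
  (*base_score)(0) = *mid;
}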
7 changes: 2 additions & 5 deletions include/xgboost/predictor.h
@@ -102,13 +102,10 @@ class PredictionContainer {
*/
class Predictor {
protected:
/*
* \brief Runtime parameters.
*/
GenericParameter const* ctx_;
Context const* ctx_;

public:
explicit Predictor(GenericParameter const* ctx) : ctx_{ctx} {}
explicit Predictor(Context const* ctx) : ctx_{ctx} {}

virtual ~Predictor() = default;

4 changes: 3 additions & 1 deletion src/common/algorithm.h
@@ -1,7 +1,8 @@
/*!
* Copyright 2022 by XGBoost Contributors
*/
#pragma once
#ifndef XGBOOST_COMMON_ALGORITHM_H_
#define XGBOOST_COMMON_ALGORITHM_H_
#include <algorithm> // std::upper_bound
#include <cinttypes> // std::size_t

@@ -14,3 +15,4 @@ auto SegmentId(It first, It last, Idx idx) {
}
} // namespace common
} // namespace xgboost
#endif // XGBOOST_COMMON_ALGORITHM_H_
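For illustration, SegmentId maps a flat element index back to its row given a CSR-style row pointer; the values below are made up:

// Illustration with made-up values: SegmentId finds the segment containing a
// flat index via binary search over the segment pointer.
#include <cstddef>
#include <vector>

#include "algorithm.h"  // xgboost::common::SegmentId (internal header)

void SegmentIdDemo() {
  std::vector<std::size_t> indptr{0, 3, 3, 7};  // row sizes: 3, 0, 4
  auto row = xgboost::common::SegmentId(indptr.cbegin(), indptr.cend(), 5);
  // Element 5 lies in [3, 7), so row == 2; the empty row 1 is skipped.
  (void)row;
}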
12 changes: 10 additions & 2 deletions src/common/common.h
@@ -265,6 +265,7 @@ struct OptionalWeights {
explicit OptionalWeights(float w) : dft{w} {}

XGBOOST_DEVICE float operator[](size_t i) const { return weights.empty() ? dft : weights[i]; }
auto Empty() const { return weights.empty(); }
};

/**
@@ -276,22 +277,29 @@ XGBOOST_DEVICE size_t LastOf(size_t group, Indexable const &indptr) {
}

/**
* @brief A CRTP (curiously recurring template pattern) helper function.
* \brief A CRTP (curiously recurring template pattern) helper function.
*
* https://www.fluentcpp.com/2017/05/19/crtp-helper/
*
* Does two things:
* 1. Makes "crtp" explicit in the inheritance structure of a CRTP base class.
* 2. Avoids having to `static_cast` in a lot of places.
*
* @tparam T The derived class in a CRTP hierarchy.
* \tparam T The derived class in a CRTP hierarchy.
*/
template <typename T>
struct Crtp {
T &Underlying() { return static_cast<T &>(*this); }
T const &Underlying() const { return static_cast<T const &>(*this); }
};

/**
* \brief C++17 std::as_const
*/
template <typename T>
typename std::add_const<T>::type &AsConst(T &v) noexcept { // NOLINT(runtime/references)
return v;
}
} // namespace common
} // namespace xgboost
#endif // XGBOOST_COMMON_COMMON_H_
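A short sketch of the two helpers added to this header; the Doubler and Steps types are hypothetical:

// Sketch: Crtp<T> spares the usual static_cast boilerplate in a CRTP base,
// and AsConst picks a const overload without const_cast (C++14-compatible).
#include <string>

#include "common.h"  // xgboost::common::Crtp, xgboost::common::AsConst (internal)

template <typename T>
struct Doubler : public xgboost::common::Crtp<T> {
  // Underlying() replaces static_cast<T const&>(*this).
  int Twice() const { return this->Underlying().Value() * 2; }
};

struct Steps : public Doubler<Steps> {
  int Value() const { return 21; }
};

void CrtpDemo() {
  Steps s;
  int v = s.Twice();  // 42, dispatched through the CRTP base
  std::string name{"xgboost"};
  auto const& cref = xgboost::common::AsConst(name);  // std::string const&
  (void)v;
  (void)cref;
}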
26 changes: 26 additions & 0 deletions src/common/linalg_op.h
@@ -4,6 +4,7 @@
#ifndef XGBOOST_COMMON_LINALG_OP_H_
#define XGBOOST_COMMON_LINALG_OP_H_
#include <type_traits>
#include <cstdint> // std::int32_t

#include "common.h"
#include "threading_utils.h"
@@ -59,6 +60,31 @@ void ElementWiseKernel(GenericParameter const* ctx, linalg::TensorView<T, D> t,
ElementWiseKernelHost(t, ctx->Threads(), fn);
}
#endif // !defined(XGBOOST_USE_CUDA)

template <typename T, std::int32_t kDim>
auto cbegin(TensorView<T, kDim> v) { // NOLINT
auto it = common::MakeIndexTransformIter([&](size_t i) -> std::remove_cv_t<T> const& {
return linalg::detail::Apply(v, linalg::UnravelIndex(i, v.Shape()));
});
return it;
}

template <typename T, std::int32_t kDim>
auto cend(TensorView<T, kDim> v) { // NOLINT
return cbegin(v) + v.Size();
}

template <typename T, std::int32_t kDim>
auto begin(TensorView<T, kDim> v) { // NOLINT
auto it = common::MakeIndexTransformIter(
[&](size_t i) -> T& { return linalg::detail::Apply(v, linalg::UnravelIndex(i, v.Shape())); });
return it;
}

template <typename T, std::int32_t kDim>
auto end(TensorView<T, kDim> v) { // NOLINT
return begin(v) + v.Size();
}
} // namespace linalg
} // namespace xgboost
#endif // XGBOOST_COMMON_LINALG_OP_H_
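These iterators unravel a flat index over the view's shape, so standard algorithms work on a TensorView of any rank. A usage sketch, with internal include paths assumed:

// Usage sketch for the new TensorView iterators.
#include <numeric>  // std::accumulate

#include <xgboost/linalg.h>

#include "linalg_op.h"  // xgboost::linalg::{begin, end, cbegin, cend} (internal)

void IterDemo() {
  xgboost::linalg::Tensor<double, 2> t{{1.0, 2.0, 3.0, 4.0}, {2, 2}};
  auto view = t.HostView();
  // Read-only traversal in row-major order, regardless of rank:
  double sum = std::accumulate(xgboost::linalg::cbegin(view),
                               xgboost::linalg::cend(view), 0.0);  // 10.0
  // The mutable variants permit in-place updates through the view:
  for (auto it = xgboost::linalg::begin(view); it != xgboost::linalg::end(view); ++it) {
    *it += 1.0;
  }
  (void)sum;
}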
28 changes: 28 additions & 0 deletions src/common/numeric.cc
@@ -0,0 +1,28 @@
/*!
* Copyright 2022 by XGBoost Contributors
*/
#include "numeric.h"

#include <numeric> // std::accumulate
#include <type_traits> // std::is_same

#include "threading_utils.h" // MemStackAllocator, ParallelFor, DefaultMaxThreads
#include "xgboost/generic_parameters.h" // Context
#include "xgboost/host_device_vector.h" // HostDeviceVector

namespace xgboost {
namespace common {
double Reduce(Context const* ctx, HostDeviceVector<float> const& values) {
if (ctx->IsCPU()) {
auto const& h_values = values.ConstHostVector();
MemStackAllocator<double, DefaultMaxThreads()> result_tloc(ctx->Threads(), 0);
ParallelFor(h_values.size(), ctx->Threads(),
[&](auto i) { result_tloc[omp_get_thread_num()] += h_values[i]; });
auto result = std::accumulate(result_tloc.cbegin(), result_tloc.cend(), 0.0);
static_assert(std::is_same<decltype(result), double>::value, "");
return result;
}
return cuda::Reduce(ctx, values);
}
} // namespace common
} // namespace xgboost
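A usage sketch for the new Reduce helper, with placeholder values:

// Usage sketch: sum a vector with common::Reduce.
#include <xgboost/generic_parameters.h>  // Context
#include <xgboost/host_device_vector.h>

#include "numeric.h"  // xgboost::common::Reduce (internal header)

double SumValues() {
  xgboost::Context ctx;  // CPU by default; a device context takes the CUDA path
  xgboost::HostDeviceVector<float> values{0.5f, 1.5f, 2.0f};
  // Thread-local partial sums are accumulated in double for stability.
  return xgboost::common::Reduce(&ctx, values);  // 4.0
}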
