Configuration for init estimation. (#8343)
* Configuration for init estimation.

* Check whether the model needs configuration based on the const attribute `ModelFitted`
  instead of a mutable state.
* Add parameter `boost_from_average` to indicate whether the user has specified a base score.
* Add tests.
trivialfis committed Oct 17, 2022
1 parent 2176e51 commit 031d66e
Showing 10 changed files with 242 additions and 106 deletions.
4 changes: 3 additions & 1 deletion doc/parameter.rst
@@ -370,9 +370,11 @@ Specify the learning task and the corresponding learning objective. The objectiv
- ``reg:gamma``: gamma regression with log-link. Output is a mean of gamma distribution. It might be useful, e.g., for modeling insurance claims severity, or for any outcome that might be `gamma-distributed <https://en.wikipedia.org/wiki/Gamma_distribution#Occurrence_and_applications>`_.
- ``reg:tweedie``: Tweedie regression with log-link. It might be useful, e.g., for modeling total loss in insurance, or for any outcome that might be `Tweedie-distributed <https://en.wikipedia.org/wiki/Tweedie_distribution#Occurrence_and_applications>`_.

* ``base_score`` [default=0.5]
* ``base_score``

- The initial prediction score of all instances, global bias.
- The parameter is automatically estimated for selected objectives before training. To
  disable the estimation, specify a real number argument.
- With a sufficient number of iterations, changing this value will not have much effect.

* ``eval_metric`` [default according to objective]
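The ``base_score`` entry above can be exercised from the Python package. The following is a minimal sketch, not part of this commit: it assumes the post-change behaviour (an unspecified ``base_score`` is estimated from the training labels, while an explicit value disables the estimation) and assumes the key layout of the JSON returned by ``save_config()``; the synthetic data is illustrative only.

import json
import numpy as np
import xgboost as xgb

# Small synthetic regression problem; any dataset works for this illustration.
rng = np.random.default_rng(0)
X = rng.normal(size=(256, 4))
y = 2.0 * X[:, 0] + 1.0

dtrain = xgb.DMatrix(X, label=y)

# base_score left unspecified: it is estimated from the training data.
auto = xgb.train({"objective": "reg:squarederror"}, dtrain, num_boost_round=4)

# base_score given explicitly: estimation is disabled and 0.7 is used as-is.
fixed = xgb.train({"objective": "reg:squarederror", "base_score": 0.7}, dtrain,
                  num_boost_round=4)

# Read back the value actually used from the saved configuration.
for name, booster in [("estimated", auto), ("user-specified", fixed)]:
    conf = json.loads(booster.save_config())
    print(name, conf["learner"]["learner_model_param"]["base_score"])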
5 changes: 5 additions & 0 deletions include/xgboost/gbm.h
@@ -75,6 +75,11 @@ class GradientBooster : public Model, public Configurable {
/*! \brief Return number of boosted rounds.
*/
virtual int32_t BoostedRounds() const = 0;
/**
* \brief Whether the model has already been trained. When the tree booster is chosen, this
* returns true if there are existing trees.
*/
virtual bool ModelFitted() const = 0;
/*!
* \brief perform update to the model(boosting)
* \param p_fmat feature matrix that provide access to features
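The ``ModelFitted()`` hook declared above is what lets the learner skip the estimation once a model already contains trees (or boosted rounds for gblinear). A hedged Python sketch of the user-visible consequence, assuming that continued training via ``xgb_model`` keeps the base score of the loaded model instead of re-estimating it from the new data:

import numpy as np
import xgboost as xgb

rng = np.random.default_rng(1)
X, y = rng.normal(size=(128, 3)), rng.normal(size=128)
dtrain = xgb.DMatrix(X, label=y)

# First training run: no trees exist yet, so ModelFitted() is false and the
# base score is estimated from dtrain (boost_from_average is still enabled).
booster = xgb.train({"objective": "reg:squarederror"}, dtrain, num_boost_round=2)

# Continued training: trees now exist, so ModelFitted() is true and the
# previously estimated base score is reused rather than re-estimated.
booster = xgb.train({"objective": "reg:squarederror"}, dtrain,
                    num_boost_round=2, xgb_model=booster)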
2 changes: 1 addition & 1 deletion include/xgboost/learner.h
@@ -328,7 +328,7 @@ struct LearnerModelParam {
void Copy(LearnerModelParam const& that);

/* \brief Whether this parameter is initialized with LearnerModelParamLegacy. */
bool Initialized() const { return num_feature != 0; }
bool Initialized() const { return num_feature != 0 && num_output_group != 0; }
};

} // namespace xgboost
4 changes: 4 additions & 0 deletions src/common/host_device_vector.cu
@@ -162,6 +162,10 @@ class HostDeviceVectorImpl {
if (device_ >= 0) {
LazySyncHost(GPUAccess::kNone);
}

if (device_ >= 0 && device >= 0) {
CHECK_EQ(device_, device) << "The new device ordinal is different from the previous one.";
}
device_ = device;
if (device_ >= 0) {
LazyResizeDevice(data_h_.size());
6 changes: 3 additions & 3 deletions src/common/linalg_op.h
@@ -3,8 +3,8 @@
*/
#ifndef XGBOOST_COMMON_LINALG_OP_H_
#define XGBOOST_COMMON_LINALG_OP_H_
#include <type_traits>
#include <cstdint> // std::int32_t
#include <type_traits>

#include "common.h"
#include "threading_utils.h"
@@ -43,12 +43,12 @@ void ElementWiseKernelHost(linalg::TensorView<T, D> t, int32_t n_threads, Fn&& f

#if !defined(XGBOOST_USE_CUDA)
template <typename T, int32_t D, typename Fn>
void ElementWiseKernelDevice(linalg::TensorView<T, D> t, Fn&& fn, void* s = nullptr) {
void ElementWiseKernelDevice(linalg::TensorView<T, D>, Fn&&, void* = nullptr) {
common::AssertGPUSupport();
}

template <typename T, int32_t D, typename Fn>
void ElementWiseTransformDevice(linalg::TensorView<T, D> t, Fn&& fn, void* s = nullptr) {
void ElementWiseTransformDevice(linalg::TensorView<T, D>, Fn&&, void* = nullptr) {
common::AssertGPUSupport();
}

2 changes: 2 additions & 0 deletions src/gbm/gblinear.cc
@@ -95,6 +95,8 @@ class GBLinear : public GradientBooster {
return model_.num_boosted_rounds;
}

bool ModelFitted() const override { return BoostedRounds() != 0; }

void Load(dmlc::Stream* fi) override {
model_.Load(fi);
}
4 changes: 4 additions & 0 deletions src/gbm/gbtree.h
@@ -252,6 +252,10 @@ class GBTree : public GradientBooster {
return model_.trees.size() / this->LayerTrees();
}

bool ModelFitted() const override {
return !model_.trees.empty() || !model_.trees_to_update.empty();
}

void PredictBatch(DMatrix *p_fmat, PredictionCacheEntry *out_preds,
bool training, unsigned layer_begin, unsigned layer_end) override;

137 changes: 81 additions & 56 deletions src/learner.cc
@@ -12,6 +12,7 @@
#include <dmlc/thread_local.h>

#include <algorithm>
#include <array>
#include <atomic>
#include <iomanip>
#include <limits> // std::numeric_limits
@@ -27,7 +28,6 @@
#include "common/charconv.h"
#include "common/common.h"
#include "common/io.h"
#include "common/linalg_op.h"
#include "common/observer.h"
#include "common/random.h"
#include "common/threading_utils.h"
@@ -64,6 +64,15 @@ DECLARE_FIELD_ENUM_CLASS(xgboost::DataSplitMode);

namespace xgboost {
Learner::~Learner() = default;
namespace {
StringView ModelNotFitted() { return "Model is not yet initialized (not fitted)."; }

template <typename T>
T& UsePtr(T& ptr) { // NOLINT
CHECK(ptr);
return ptr;
}
} // anonymous namespace

/*! \brief training parameter for regression
*
@@ -75,20 +84,28 @@ struct LearnerModelParamLegacy : public dmlc::Parameter<LearnerModelParamLegacy>
/* \brief global bias */
bst_float base_score;
/* \brief number of features */
uint32_t num_feature;
bst_feature_t num_feature;
/* \brief number of classes, if it is multi-class classification */
int32_t num_class;
std::int32_t num_class;
/*! \brief Model contain additional properties */
int32_t contain_extra_attrs;
/*! \brief Model contain eval metrics */
int32_t contain_eval_metrics;
/*! \brief the version of XGBoost. */
uint32_t major_version;
uint32_t minor_version;
std::uint32_t major_version;
std::uint32_t minor_version;

uint32_t num_target{1};

int32_t base_score_estimated{0};
/**
* \brief Whether we should calculate the base score from training data.
*
* This is a private parameter as we can't expose it as a boolean due to the binary model
* format. Exposing it as an integer creates inconsistency with other parameters.
*
* Automatically disabled when base_score is specified by the user. int32 is used instead
* of bool for ease of serialization.
*/
std::int32_t boost_from_average{true};
/*! \brief reserved field */
int reserved[25];
/*! \brief constructor */
@@ -98,14 +115,14 @@ struct LearnerModelParamLegacy : public dmlc::Parameter<LearnerModelParamLegacy>
num_target = 1;
major_version = std::get<0>(Version::Self());
minor_version = std::get<1>(Version::Self());
base_score_estimated = 0;
boost_from_average = true;
static_assert(sizeof(LearnerModelParamLegacy) == 136,
"Do not change the size of this struct, as it will break binary IO.");
}

// Skip other legacy fields.
Json ToJson() const {
Object obj;
Json obj{Object{}};
char floats[NumericLimits<float>::kToCharsSize];
auto ret = to_chars(floats, floats + NumericLimits<float>::kToCharsSize, base_score);
CHECK(ret.ec == std::errc{});
@@ -120,15 +137,19 @@ struct LearnerModelParamLegacy : public dmlc::Parameter<LearnerModelParamLegacy>
ret = to_chars(integers, integers + NumericLimits<int64_t>::kToCharsSize,
static_cast<int64_t>(num_class));
CHECK(ret.ec == std::errc());
obj["num_class"] =
std::string{integers, static_cast<size_t>(std::distance(integers, ret.ptr))};
obj["num_class"] = std::string{integers, static_cast<size_t>(std::distance(integers, ret.ptr))};

ret = to_chars(integers, integers + NumericLimits<int64_t>::kToCharsSize,
static_cast<int64_t>(num_target));
obj["num_target"] =
std::string{integers, static_cast<size_t>(std::distance(integers, ret.ptr))};

return Json(std::move(obj));
ret = to_chars(integers, integers + NumericLimits<std::int64_t>::kToCharsSize,
static_cast<std::int64_t>(boost_from_average));
obj["boost_from_average"] =
std::string{integers, static_cast<std::size_t>(std::distance(integers, ret.ptr))};

return obj;
}
void FromJson(Json const& obj) {
auto const& j_param = get<Object const>(obj);
Expand All @@ -139,13 +160,15 @@ struct LearnerModelParamLegacy : public dmlc::Parameter<LearnerModelParamLegacy>
if (n_targets_it != j_param.cend()) {
m["num_target"] = get<String const>(n_targets_it->second);
}
auto bse_it = j_param.find("boost_from_average");
if (bse_it != j_param.cend()) {
m["boost_from_average"] = get<String const>(bse_it->second);
}

this->Init(m);

std::string str = get<String const>(j_param.at("base_score"));
from_chars(str.c_str(), str.c_str() + str.size(), base_score);
// It can only be estimated during the first training, we consider it estimated afterward
base_score_estimated = 1;
}

LearnerModelParamLegacy ByteSwap() const {
@@ -158,22 +181,21 @@ struct LearnerModelParamLegacy : public dmlc::Parameter<LearnerModelParamLegacy>
dmlc::ByteSwap(&x.major_version, sizeof(x.major_version), 1);
dmlc::ByteSwap(&x.minor_version, sizeof(x.minor_version), 1);
dmlc::ByteSwap(&x.num_target, sizeof(x.num_target), 1);
dmlc::ByteSwap(&x.base_score_estimated, sizeof(x.base_score_estimated), 1);
dmlc::ByteSwap(&x.boost_from_average, sizeof(x.boost_from_average), 1);
dmlc::ByteSwap(x.reserved, sizeof(x.reserved[0]), sizeof(x.reserved) / sizeof(x.reserved[0]));
return x;
}

template <typename Container>
Args UpdateAllowUnknown(Container const& kwargs) {
// Detect whether user has made their own base score.
if (std::find_if(kwargs.cbegin(), kwargs.cend(),
[](auto const& kv) { return kv.first == "base_score"; }) != kwargs.cend()) {
base_score_estimated = true;
}
if (std::find_if(kwargs.cbegin(), kwargs.cend(), [](auto const& kv) {
return kv.first == "base_score_estimated";
}) != kwargs.cend()) {
LOG(FATAL) << "`base_score_estimated` cannot be specified as hyper-parameter.";
auto find_key = [&kwargs](char const* key) {
return std::find_if(kwargs.cbegin(), kwargs.cend(),
[key](auto const& kv) { return kv.first == key; });
};
auto it = find_key("base_score");
if (it != kwargs.cend()) {
boost_from_average = false;
}
return dmlc::Parameter<LearnerModelParamLegacy>::UpdateAllowUnknown(kwargs);
}
@@ -195,7 +217,9 @@ struct LearnerModelParamLegacy : public dmlc::Parameter<LearnerModelParamLegacy>
.set_default(1)
.set_lower_bound(1)
.describe("Number of target for multi-target regression.");
DMLC_DECLARE_FIELD(base_score_estimated).set_default(0);
DMLC_DECLARE_FIELD(boost_from_average)
.set_default(true)
.describe("Whether we should calculate the base score from training data.");
}
};

@@ -224,7 +248,7 @@ LearnerModelParam::LearnerModelParam(Context const* ctx, LearnerModelParamLegacy

linalg::TensorView<float const, 1> LearnerModelParam::BaseScore(int32_t device) const {
// multi-class is not yet supported.
CHECK_EQ(base_score_.Size(), 1);
CHECK_EQ(base_score_.Size(), 1) << ModelNotFitted();
if (device == Context::kCpuId) {
// Make sure that we won't run into race condition.
CHECK(base_score_.Data()->HostCanRead());
@@ -385,6 +409,21 @@ class LearnerConfiguration : public Learner {
// Initial prediction.
std::vector<std::string> metric_names_;

void ConfigureModelParamWithoutBaseScore() {
// Convert mparam to learner_model_param
this->ConfigureTargets();

auto task = UsePtr(obj_)->Task();
linalg::Tensor<float, 1> base_score({1}, Ctx()->gpu_id);
auto h_base_score = base_score.HostView();

// transform to margin
h_base_score(0) = obj_->ProbToMargin(mparam_.base_score);
// move it to model param, which is shared with all other components.
learner_model_param_ = LearnerModelParam(Ctx(), mparam_, std::move(base_score), task);
CHECK(learner_model_param_.Initialized());
CHECK_NE(learner_model_param_.BaseScore(Ctx()).Size(), 0);
}
/**
* \brief Calculate the `base_score` based on input data.
*
@@ -403,38 +442,24 @@ class LearnerConfiguration : public Learner {
// - model loaded from new binary or JSON.
// - model is created from scratch.
// - model is configured a second time due to a change of parameter
CHECK(obj_);
if (!mparam_.base_score_estimated) {
if (!learner_model_param_.Initialized()) {
this->ConfigureModelParamWithoutBaseScore();
}
if (mparam_.boost_from_average && !UsePtr(gbm_)->ModelFitted()) {
if (p_fmat) {
auto const& info = p_fmat->Info();
info.Validate(Ctx()->gpu_id);
// We estimate it from input data.
linalg::Tensor<float, 1> base_score;
obj_->InitEstimation(p_fmat->Info(), &base_score);
UsePtr(obj_)->InitEstimation(info, &base_score);
mparam_.base_score = base_score(0);
CHECK(!std::isnan(mparam_.base_score));
} else {
mparam_.base_score = ObjFunction::DefaultBaseScore();
}
mparam_.base_score_estimated = true;
// Update the shared model parameter
this->ConfigureModelParam();
this->ConfigureModelParamWithoutBaseScore();
}
}

// Convert mparam to learner_model_param
void ConfigureModelParam() {
this->ConfigureTargets();

CHECK(obj_);
auto task = obj_->Task();
linalg::Tensor<float, 1> base_score({1}, Ctx()->gpu_id);
auto h_base_score = base_score.HostView();

// transform to margin
h_base_score(0) = obj_->ProbToMargin(mparam_.base_score);
// move it to model param, which is shared with all other components.
learner_model_param_ = LearnerModelParam(Ctx(), mparam_, std::move(base_score), task);
CHECK(learner_model_param_.Initialized());
CHECK_NE(learner_model_param_.BaseScore(Ctx()).Size(), 0);
CHECK(!std::isnan(mparam_.base_score));
CHECK(!std::isinf(mparam_.base_score));
}

public:
@@ -496,7 +521,8 @@ class LearnerConfiguration : public Learner {
learner_model_param_.task = obj_->Task(); // required by gbm configuration.
this->ConfigureGBM(old_tparam, args);
ctx_.ConfigureGpuId(this->gbm_->UseGPU());
this->ConfigureModelParam();

this->ConfigureModelParamWithoutBaseScore();

this->ConfigureMetrics(args);

Expand All @@ -510,8 +536,8 @@ class LearnerConfiguration : public Learner {
}

void CheckModelInitialized() const {
CHECK(learner_model_param_.Initialized()) << "Model not yet initialized.";
CHECK_NE(learner_model_param_.BaseScore(this->Ctx()).Size(), 0);
CHECK(learner_model_param_.Initialized()) << ModelNotFitted();
CHECK_NE(learner_model_param_.BaseScore(this->Ctx()).Size(), 0) << ModelNotFitted();
}

virtual PredictionContainer* GetPredictionCache() const {
@@ -1318,8 +1344,6 @@ class LearnerImpl : public LearnerIO {
HostDeviceVector<GradientPair>* in_gpair) override {
monitor_.Start("BoostOneIter");
this->Configure();
// Should have been set to default in the first prediction.
CHECK(mparam_.base_score_estimated);

if (ctx_.seed_per_iteration) {
common::GlobalRandom().seed(ctx_.seed * kRandSeedMagic + iter);
Expand Down Expand Up @@ -1380,7 +1404,9 @@ class LearnerImpl : public LearnerIO {
static_cast<int>(pred_interactions) +
static_cast<int>(pred_contribs);
this->Configure();
this->InitBaseScore(nullptr);
if (training) {
this->InitBaseScore(nullptr);
}
this->CheckModelInitialized();

CHECK_LE(multiple_predictions, 1) << "Perform one kind of prediction at a time.";
@@ -1425,7 +1451,6 @@ class LearnerImpl : public LearnerIO {
HostDeviceVector<bst_float>** out_preds, uint32_t iteration_begin,
uint32_t iteration_end) override {
this->Configure();
this->InitBaseScore(nullptr);
this->CheckModelInitialized();

auto& out_predictions = this->GetThreadLocal().prediction_entry;
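As a worked check of the estimation performed in ``InitBaseScore``, the sketch below assumes that for ``reg:squarederror`` the estimated intercept is (approximately) the mean of the training labels, and again assumes the config key layout of ``save_config()``; neither detail is stated in this commit, so treat it as an illustration rather than a specification.

import json
import numpy as np
import xgboost as xgb

rng = np.random.default_rng(2)
X = rng.normal(size=(512, 5))
y = rng.normal(loc=3.0, scale=0.5, size=512)  # labels centred far from the old 0.5 default

dtrain = xgb.DMatrix(X, label=y)
booster = xgb.train({"objective": "reg:squarederror"}, dtrain, num_boost_round=1)

conf = json.loads(booster.save_config())
base_score = float(conf["learner"]["learner_model_param"]["base_score"])

# For squared error the optimal constant prediction is the label mean, so the
# estimate should sit close to y.mean() rather than at 0.5.
print(base_score, y.mean())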
