Split up the configuration.
trivialfis committed Sep 13, 2022
1 parent 48112fe commit c22010b
Showing 1 changed file with 30 additions and 14 deletions.
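In short, this commit splits the old ConfigureLearnerParam into two pieces: InitBaseScore, which estimates base_score from the training DMatrix at the start of training, and ConfigureModelParam, which converts mparam_ into the shared learner_model_param_ and is now also called from the generic configuration path.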
44 changes: 30 additions & 14 deletions src/learner.cc
@@ -202,9 +202,11 @@ linalg::TensorView<float const, 1> LearnerModelParam::BaseScore(int32_t device)
// multi-class is not supported yet.
CHECK_EQ(base_score_.Size(), 1);
if (device == Context::kCpuId) {
// Make sure that we won't run into a race condition.
CHECK(base_score_.Data()->HostCanRead());
return base_score_.HostView();
}
// Make sure that we won't run into a race condition.
CHECK(base_score_.Data()->DeviceCanRead());
auto v = base_score_.View(device);
CHECK(base_score_.Data()->HostCanRead()); // make sure read access is not removed.
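The two comments added in this hunk document a concurrency invariant: handing out a view must not trigger a lazy host/device copy, since a copy mutates the buffer's access state and would race with concurrent readers during prediction. Below is a minimal standalone sketch of that invariant; the Buffer and View names are hypothetical stand-ins, and XGBoost's real HostDeviceVector API is considerably richer.

#include <cassert>

// Hypothetical stand-in for a host/device buffer whose data may be
// resident on the host, the device, or both.
enum class Access { kHost, kDevice, kBoth };

struct Buffer {
  Access access{Access::kHost};
  bool HostCanRead() const { return access != Access::kDevice; }
  bool DeviceCanRead() const { return access != Access::kHost; }
};

// Serving a view must be read-only: if the data were not already
// resident on the requested side, fetching it would copy (a write)
// and race with other threads asking for views at the same time.
const float* View(const Buffer& buf, bool on_device,
                  const float* host_ptr, const float* device_ptr) {
  if (!on_device) {
    assert(buf.HostCanRead());   // mirrors CHECK(...->HostCanRead())
    return host_ptr;
  }
  assert(buf.DeviceCanRead());   // mirrors CHECK(...->DeviceCanRead())
  return device_ptr;
}

int main() {
  Buffer b;  // starts host-resident
  float x = 0.5f;
  return View(b, /*on_device=*/false, &x, nullptr) == &x ? 0 : 1;
}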
@@ -358,8 +360,10 @@ class LearnerConfiguration : public Learner {

/**
* \brief Calculate the `base_score` based on input data.
*
* \param p_fmat The training DMatrix used to estimate the base score.
*/
void ConfigureLearnerParam(DMatrix const* p_fmat) {
void InitBaseScore(DMatrix const* p_fmat) {
linalg::Tensor<float, 1> base_score;
// Before 1.0.0, we saved `base_score` into the binary as a value already transformed by the objective.
// After 1.0.0 we save the value provided by the user and keep it immutable instead. To
@@ -387,21 +391,31 @@ class LearnerConfiguration : public Learner {
base_score(0) = ObjFunction::DefaultBaseScore();
}

auto task = obj_->Task();
mparam_.base_score = base_score(0);
CHECK(!std::isnan(mparam_.base_score));
// Update the shared model parameter
this->ConfigureModelParam();
}

// transform to margin
linalg::Tensor<float, 1> copy(base_score.Shape(), ctx_.gpu_id);
auto in = base_score.HostView();
// Convert mparam to learner_model_param
void ConfigureModelParam() {
CHECK(obj_);
auto task = obj_->Task();
linalg::Tensor<float, 1> copy({1}, ctx_.gpu_id);
auto out = copy.HostView();
std::transform(linalg::cbegin(in), linalg::cend(in), linalg::begin(out),
[&](float v) { return obj_->ProbToMargin(v); });

// move it to model param, which is shared with all other components.
learner_model_param_ = LearnerModelParam(&ctx_, mparam_, std::move(copy), task);
CHECK(learner_model_param_.Initialized());
CHECK_NE(learner_model_param_.BaseScore(&ctx_).Size(), 0);
CHECK(!std::isnan(mparam_.base_score));
if (!std::isnan(mparam_.base_score)) {
// transform to margin
out(0) = obj_->ProbToMargin(mparam_.base_score);
// move it to model param, which is shared with all other components.
learner_model_param_ = LearnerModelParam(&ctx_, mparam_, std::move(copy), task);
CHECK(learner_model_param_.Initialized());
CHECK_NE(learner_model_param_.BaseScore(&ctx_).Size(), 0);
} else {
// Model is not yet fitted, use default base score.
out(0) = ObjFunction::DefaultBaseScore();
learner_model_param_ = LearnerModelParam(&ctx_, mparam_, std::move(copy), task);
}
}

public:
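For intuition about the transform above: mparam_.base_score is kept in probability space, and ConfigureModelParam applies the objective's ProbToMargin before sharing it with the other components. For binary:logistic that transform is, to the best of my reading, the logit, so the default base score of 0.5 maps to a raw margin of 0. A self-contained sketch with a simplified stand-in for the objective:

#include <cmath>
#include <cstdio>

// Simplified stand-in for the logistic objective's ProbToMargin:
// the logit, i.e. the inverse of the sigmoid link function.
float ProbToMargin(float p) { return -std::log(1.0f / p - 1.0f); }

int main() {
  std::printf("%f\n", ProbToMargin(0.5f));  // 0.000000: default base score
  std::printf("%f\n", ProbToMargin(0.9f));  // ~2.197225
  return 0;
}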
@@ -465,6 +479,7 @@ class LearnerConfiguration : public Learner {
learner_model_param_.task = obj_->Task(); // required by gbm configuration.
this->ConfigureGBM(old_tparam, args);
ctx_.ConfigureGpuId(this->gbm_->UseGPU());
this->ConfigureModelParam();

this->ConfigureMetrics(args);

@@ -479,6 +494,7 @@

void CheckModelInitialized() const {
CHECK(learner_model_param_.Initialized()) << "Model not fit.";
CHECK_NE(learner_model_param_.BaseScore(this->Ctx()).Size(), 0);
}

virtual PredictionContainer* GetPredictionCache() const {
@@ -1255,7 +1271,7 @@ class LearnerImpl : public LearnerIO {
monitor_.Start("UpdateOneIter");
TrainingObserver::Instance().Update(iter);
this->Configure();
this->ConfigureLearnerParam(train.get());
this->InitBaseScore(train.get());

if (ctx_.seed_per_iteration) {
common::GlobalRandom().seed(ctx_.seed * kRandSeedMagic + iter);
@@ -1285,7 +1301,7 @@ class LearnerImpl : public LearnerIO {
HostDeviceVector<GradientPair>* in_gpair) override {
monitor_.Start("BoostOneIter");
this->Configure();
this->ConfigureLearnerParam(train.get());
this->InitBaseScore(train.get());

if (ctx_.seed_per_iteration) {
common::GlobalRandom().seed(ctx_.seed * kRandSeedMagic + iter);
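Both training entry points now follow the same sequence: Configure() builds the shared model parameter (falling back to the objective's default base score if none has been estimated yet), then InitBaseScore() fills in the data-dependent estimate once and leaves it immutable afterwards. A toy model of that control flow, with every XGBoost type reduced to a float and the names reused only for orientation:

#include <cmath>
#include <cstdio>

// Toy model (not XGBoost's real classes) of the split introduced here.
struct Learner {
  float base_score = std::nanf("");  // mparam_.base_score starts unset
  float shared_score = 0.0f;         // stands in for learner_model_param_

  // Propagate mparam_ into the shared parameter; use a default (0.5,
  // as ObjFunction::DefaultBaseScore does) when nothing was estimated.
  void ConfigureModelParam() {
    shared_score = std::isnan(base_score) ? 0.5f : base_score;
  }
  void Configure() { ConfigureModelParam(); }

  // Estimate the base score from data once; later calls keep it.
  void InitBaseScore(float data_mean) {
    if (std::isnan(base_score)) { base_score = data_mean; }
    ConfigureModelParam();
  }

  void UpdateOneIter(float data_mean) {
    Configure();
    InitBaseScore(data_mean);
    std::printf("training with base score %f\n", shared_score);
  }
};

int main() {
  Learner l;
  l.UpdateOneIter(0.3f);  // first iteration estimates from data: 0.3
  l.UpdateOneIter(0.9f);  // the estimate is kept immutable:      0.3
  return 0;
}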
