Pass information about objective to metrics. #7386

Draft · wants to merge 3 commits into master

Changes from all commits
9 changes: 6 additions & 3 deletions include/xgboost/learner.h
@@ -11,15 +11,16 @@
#include <dmlc/any.h>
#include <xgboost/base.h>
#include <xgboost/feature_map.h>
#include <xgboost/predictor.h>
#include <xgboost/generic_parameters.h>
#include <xgboost/host_device_vector.h>
#include <xgboost/model.h>
#include <xgboost/predictor.h>
#include <xgboost/task.h>

#include <utility>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

namespace xgboost {
@@ -307,11 +308,13 @@ struct LearnerModelParam {
uint32_t num_feature { 0 };
/* \brief number of classes, if it is multi-class classification */
uint32_t num_output_group { 0 };
/* \brief Current task, determined by objective. */
ObjInfo task{ObjInfo::kRegression};

LearnerModelParam() = default;
// As the old `LearnerModelParamLegacy` is still used by binary IO, we keep
// this one as an immutable copy.
LearnerModelParam(LearnerModelParamLegacy const& user_param, float base_margin);
LearnerModelParam(LearnerModelParamLegacy const& user_param, float base_margin, ObjInfo t);
/* \brief Whether this parameter is initialized with LearnerModelParamLegacy. */
bool Initialized() const { return num_feature != 0; }
};
5 changes: 3 additions & 2 deletions include/xgboost/metric.h
@@ -13,6 +13,7 @@
#include <xgboost/data.h>
#include <xgboost/base.h>
#include <xgboost/host_device_vector.h>
#include <xgboost/task.h>

#include <vector>
#include <string>
@@ -73,7 +74,7 @@ class Metric : public Configurable {
* \param tparam A global generic parameter
* \return the created metric.
*/
static Metric* Create(const std::string& name, GenericParameter const* tparam);
static Metric* Create(const std::string& name, GenericParameter const* tparam, ObjInfo task);
};

/*!
@@ -83,7 +84,7 @@ class Metric : public Configurable {
*/
struct MetricReg
: public dmlc::FunctionRegEntryBase<MetricReg,
std::function<Metric* (const char*)> > {
std::function<Metric* (const char*, ObjInfo task)> > {
};

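For orientation, a sketch (not from this PR) of what a registration looks like against the widened entry: `MyMetric` is a hypothetical Metric subclass, while the macro and the two-argument factory come from the diff above. The concrete updated registrations appear in the src/metric/auc.cc hunks below.

XGBOOST_REGISTER_METRIC(MyMetric, "my-metric")
    .describe("Example metric constructed with the objective's task.")
    .set_body([](char const* param, ObjInfo task) {
      // `param` carries the optional value after '@' in names like "ndcg@4";
      // `task` is the new piece of information threaded through by this PR.
      return new MyMetric{task};
    });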
6 changes: 6 additions & 0 deletions include/xgboost/objective.h
@@ -13,6 +13,7 @@
#include <xgboost/model.h>
#include <xgboost/generic_parameters.h>
#include <xgboost/host_device_vector.h>
#include <xgboost/task.h>

#include <vector>
#include <utility>
@@ -72,6 +73,11 @@ class ObjFunction : public Configurable {
virtual bst_float ProbToMargin(bst_float base_score) const {
return base_score;
}
/*!
* \brief Return the task of this objective.
*/
virtual struct ObjInfo Task() const = 0;

/*!
* \brief Create an objective function according to name.
* \param tparam Generic parameters.
57 changes: 57 additions & 0 deletions include/xgboost/task.h
@@ -0,0 +1,57 @@
/*!
* Copyright 2021 by XGBoost Contributors
*/
#ifndef XGBOOST_TASK_H_
#define XGBOOST_TASK_H_

#include <cinttypes>
#include <exception>

namespace xgboost {
/*!
* \brief A struct returned by the objective, which determines the task at hand. The
* struct is not used by any algorithm yet; it exists for future development like
* categorical splits.
*
* The task field is useful for tree split finding, and also for some metrics like
* AUC, while const_hess is useful for algorithms like the adaptive tree, where one
* needs to update the leaf values after building the tree. Lastly, knowing whether
* the hessian is constant allows some optimizations, like skipping the quantile
* sketching.
*
* This struct should not be serialized, since it can be recovered from the
* objective function; hence it doesn't need to be stable.
*/
struct ObjInfo {
// What kind of problem are we trying to solve
enum : uint8_t {
kRegression = 0,
kBinary = 1,
kClassification = 2,
kSurvival = 3,
kRanking = 4,
kOther = 5,
} task;
// Does the objective have a constant hessian value?
bool const_hess{false};

char const* TaskStr() const {
switch (task) {
case kRegression:
return "regression";
case kBinary:
return "binary classification";
case kClassification:
return "multi-class classification";
case kSurvival:
return "survival";
case kRanking:
return "learning to rank";
case kOther:
return "unknown";
default:
std::terminate();
}
}
};
} // namespace xgboost
#endif // XGBOOST_TASK_H_
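As a minimal sketch of the intended use (assuming only the header above; not code from this PR): consumers aggregate-initialize an ObjInfo, branch on its task, and can query const_hess for optimizations.

#include <xgboost/task.h>

#include <iostream>

int main() {
  // Aggregate initialization, the same form the objectives in this PR use.
  xgboost::ObjInfo info{xgboost::ObjInfo::kRanking, false};
  std::cout << info.TaskStr() << "\n";  // prints "learning to rank"
  if (info.const_hess) {
    // A constant hessian would allow skipping the quantile sketching.
  }
  return 0;
}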
14 changes: 7 additions & 7 deletions include/xgboost/tree_updater.h
@@ -11,16 +11,17 @@
#include <dmlc/registry.h>
#include <xgboost/base.h>
#include <xgboost/data.h>
#include <xgboost/tree_model.h>
#include <xgboost/generic_parameters.h>
#include <xgboost/host_device_vector.h>
#include <xgboost/model.h>
#include <xgboost/linalg.h>
#include <xgboost/model.h>
#include <xgboost/task.h>
#include <xgboost/tree_model.h>

#include <functional>
#include <vector>
#include <utility>
#include <string>
#include <utility>
#include <vector>

namespace xgboost {

@@ -83,16 +84,15 @@ class TreeUpdater : public Configurable {
* \param name Name of the tree updater.
* \param tparam A global runtime parameter
*/
static TreeUpdater* Create(const std::string& name, GenericParameter const* tparam);
static TreeUpdater* Create(const std::string& name, GenericParameter const* tparam, ObjInfo task);
};

/*!
* \brief Registry entry for tree updater.
*/
struct TreeUpdaterReg
: public dmlc::FunctionRegEntryBase<TreeUpdaterReg,
std::function<TreeUpdater* ()> > {
};
std::function<TreeUpdater*(ObjInfo task)> > {};

/*!
* \brief Macro to register tree updater.
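To illustrate the widened registry entry, a sketch under the assumption that MyUpdater is a TreeUpdater subclass defined elsewhere; only the macro and the ObjInfo parameter come from this diff.

XGBOOST_REGISTER_TREE_UPDATER(MyUpdater, "my_updater")
    .describe("Example updater that receives the objective's task.")
    .set_body([](ObjInfo task) { return new MyUpdater(task); });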
5 changes: 5 additions & 0 deletions plugin/example/custom_obj.cc
@@ -34,6 +34,11 @@ class MyLogistic : public ObjFunction {
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
param_.UpdateAllowUnknown(args);
}

struct ObjInfo Task() const override {
return {ObjInfo::kRegression, false};
}

void GetGradient(const HostDeviceVector<bst_float> &preds,
const MetaInfo &info,
int iter,
6 changes: 4 additions & 2 deletions src/gbm/gbtree.cc
@@ -306,7 +306,8 @@ void GBTree::InitUpdater(Args const& cfg) {

// create new updaters
for (const std::string& pstr : ups) {
std::unique_ptr<TreeUpdater> up(TreeUpdater::Create(pstr.c_str(), generic_param_));
std::unique_ptr<TreeUpdater> up(
TreeUpdater::Create(pstr.c_str(), generic_param_, model_.learner_model_param->task));
up->Configure(cfg);
updaters_.push_back(std::move(up));
}
@@ -391,7 +392,8 @@ void GBTree::LoadConfig(Json const& in) {
auto const& j_updaters = get<Object const>(in["updater"]);
updaters_.clear();
for (auto const& kv : j_updaters) {
std::unique_ptr<TreeUpdater> up(TreeUpdater::Create(kv.first, generic_param_));
std::unique_ptr<TreeUpdater> up(
TreeUpdater::Create(kv.first, generic_param_, model_.learner_model_param->task));
up->LoadConfig(kv.second);
updaters_.push_back(std::move(up));
}
27 changes: 14 additions & 13 deletions src/learner.cc
@@ -159,13 +159,12 @@ struct LearnerModelParamLegacy : public dmlc::Parameter<LearnerModelParamLegacy>
}
};

LearnerModelParam::LearnerModelParam(
LearnerModelParamLegacy const &user_param, float base_margin)
: base_score{base_margin}, num_feature{user_param.num_feature},
num_output_group{user_param.num_class == 0
? 1
: static_cast<uint32_t>(user_param.num_class)}
{}
LearnerModelParam::LearnerModelParam(LearnerModelParamLegacy const& user_param, float base_margin,
ObjInfo t)
: base_score{base_margin},
num_feature{user_param.num_feature},
num_output_group{user_param.num_class == 0 ? 1 : static_cast<uint32_t>(user_param.num_class)},
task{t} {}

struct LearnerTrainParam : public XGBoostParameter<LearnerTrainParam> {
// data split mode, can be row, col, or none.
@@ -339,8 +338,8 @@ class LearnerConfiguration : public Learner {
// - model is created from scratch.
// - model is configured second time due to change of parameter
if (!learner_model_param_.Initialized() || mparam_.base_score != mparam_backup.base_score) {
learner_model_param_ = LearnerModelParam(mparam_,
obj_->ProbToMargin(mparam_.base_score));
learner_model_param_ =
LearnerModelParam(mparam_, obj_->ProbToMargin(mparam_.base_score), obj_->Task());
}

this->ConfigureGBM(old_tparam, args);
@@ -390,7 +389,7 @@ class LearnerConfiguration : public Learner {
for (size_t i = 0; i < n_metrics; ++i) {
metric_names_[i]= get<String>(j_metrics[i]);
metrics_[i] = std::unique_ptr<Metric>(
Metric::Create(metric_names_[i], &generic_parameters_));
Metric::Create(metric_names_[i], &generic_parameters_, obj_->Task()));
}

FromJson(learner_parameters.at("generic_param"), &generic_parameters_);
@@ -645,7 +644,8 @@ class LearnerConfiguration : public Learner {
return m->Name() != name;
};
if (std::all_of(metrics_.begin(), metrics_.end(), DupCheck)) {
metrics_.emplace_back(std::unique_ptr<Metric>(Metric::Create(name, &generic_parameters_)));
metrics_.emplace_back(
std::unique_ptr<Metric>(Metric::Create(name, &generic_parameters_, obj_->Task())));
mparam_.contain_eval_metrics = 1;
}
}
@@ -832,7 +832,7 @@ class LearnerIO : public LearnerConfiguration {
}

learner_model_param_ =
LearnerModelParam(mparam_, obj_->ProbToMargin(mparam_.base_score));
LearnerModelParam(mparam_, obj_->ProbToMargin(mparam_.base_score), obj_->Task());
if (attributes_.find("objective") != attributes_.cend()) {
auto obj_str = attributes_.at("objective");
auto j_obj = Json::Load({obj_str.c_str(), obj_str.size()});
@@ -1122,7 +1122,8 @@ class LearnerImpl : public LearnerIO {
if (tparam_.objective == "binary:logitraw") {
warn_default_eval_metric(tparam_.objective, "auc", "logloss", "1.4.0");
}
metrics_.emplace_back(Metric::Create(obj_->DefaultEvalMetric(), &generic_parameters_));
metrics_.emplace_back(
Metric::Create(obj_->DefaultEvalMetric(), &generic_parameters_, obj_->Task()));
metrics_.back()->Configure({cfg_.begin(), cfg_.end()});
}

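The recurring pattern in the learner.cc hunks above, condensed into a free-standing sketch; `MakeMetric` is illustrative, not part of the PR.

#include <xgboost/generic_parameters.h>
#include <xgboost/metric.h>
#include <xgboost/objective.h>

#include <memory>
#include <string>

// Every metric is now created with the task reported by the configured
// objective, so metrics no longer need to infer the problem type from data.
std::unique_ptr<xgboost::Metric> MakeMetric(std::string const& name,
                                            xgboost::GenericParameter const* tparam,
                                            xgboost::ObjFunction const* obj) {
  return std::unique_ptr<xgboost::Metric>{
      xgboost::Metric::Create(name, tparam, obj->Task())};
}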
20 changes: 15 additions & 5 deletions src/metric/auc.cc
@@ -243,6 +243,10 @@ std::pair<float, uint32_t> RankingAUC(std::vector<float> const &predts,

template <typename Curve>
class EvalAUC : public Metric {
ObjInfo task_;

public:
explicit EvalAUC(ObjInfo task) : task_{task} {}
float Eval(const HostDeviceVector<bst_float> &preds, const MetaInfo &info,
bool distributed) override {
float auc {0};
@@ -257,10 +261,11 @@
if (meta[0] == 0) {
// Empty across all workers, which is not supported.
auc = std::numeric_limits<float>::quiet_NaN();
} else if (!info.group_ptr_.empty()) {
} else if (task_.task == ObjInfo::kRanking) {
/**
* learning to rank
*/
CHECK(!info.group_ptr_.empty()) << "Group or QID is required for ranking.";
if (!info.weights_.Empty()) {
CHECK_EQ(info.weights_.Size(), info.group_ptr_.size() - 1);
}
@@ -286,14 +291,14 @@
CHECK_LE(auc, 1) << "Total AUC across groups: " << auc * valid_groups
<< ", valid groups: " << valid_groups;
}
} else if (meta[0] != meta[1] && meta[1] % meta[0] == 0) {
} else if (task_.task == ObjInfo::kClassification) {
/**
* multi class
*/
size_t n_classes = meta[1] / meta[0];
CHECK_NE(n_classes, 0);
auc = static_cast<Curve *>(this)->EvalMultiClass(preds, info, n_classes);
} else {
} else if (task_.task == ObjInfo::kBinary) {
/**
* binary classification
*/
@@ -314,6 +319,8 @@
// normalization
auc = auc / local_area;
}
} else {
LOG(FATAL) << "Can not calculate AUC for " << task_.TaskStr() << " task.";
}
if (std::isnan(auc)) {
LOG(WARNING) << "Dataset is empty, or contains only positive or negative samples.";
Expand All @@ -326,6 +333,7 @@ class EvalROCAUC : public EvalAUC<EvalROCAUC> {
std::shared_ptr<DeviceAUCCache> d_cache_;

public:
using EvalAUC<EvalROCAUC>::EvalAUC;
std::pair<float, uint32_t> EvalRanking(HostDeviceVector<float> const &predts,
MetaInfo const &info) {
float auc{0};
@@ -378,7 +386,7 @@ class EvalROCAUC : public EvalAUC<EvalROCAUC> {

XGBOOST_REGISTER_METRIC(EvalAUC, "auc")
.describe("Receiver Operating Characteristic Area Under the Curve.")
.set_body([](const char*) { return new EvalROCAUC(); });
.set_body([](const char*, ObjInfo task) { return new EvalROCAUC(task); });

#if !defined(XGBOOST_USE_CUDA)
std::tuple<float, float, float>
Expand Down Expand Up @@ -409,6 +417,8 @@ class EvalAUCPR : public EvalAUC<EvalAUCPR> {
std::shared_ptr<DeviceAUCCache> d_cache_;

public:
using EvalAUC<EvalAUCPR>::EvalAUC;

std::tuple<float, float, float>
EvalBinary(HostDeviceVector<float> const &predts, MetaInfo const &info) {
float pr, re, auc;
@@ -460,7 +470,7 @@ class EvalAUCPR : public EvalAUC<EvalAUCPR> {

XGBOOST_REGISTER_METRIC(AUCPR, "aucpr")
.describe("Area under PR curve for both classification and rank.")
.set_body([](char const *) { return new EvalAUCPR{}; });
.set_body([](char const *, ObjInfo task) { return new EvalAUCPR{task}; });

#if !defined(XGBOOST_USE_CUDA)
std::tuple<float, float, float>
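The user-visible effect of the new dispatch, sketched as a hypothetical snippet; the names and flow are illustrative, and only Metric::Create and the error messages come from this diff.

#include <xgboost/metric.h>
#include <xgboost/task.h>

#include <memory>

void SketchAUCDispatch(xgboost::GenericParameter const* tparam) {
  // A ranking objective now routes AUC to the per-group implementation
  // unconditionally, instead of relying on whether group_ptr_ happens to be
  // empty; evaluating without group/QID information CHECK-fails with
  // "Group or QID is required for ranking."
  xgboost::ObjInfo ranking{xgboost::ObjInfo::kRanking, false};
  std::unique_ptr<xgboost::Metric> auc{
      xgboost::Metric::Create("auc", tparam, ranking)};
  // Unsupported tasks (e.g. survival) now fail with an explicit
  // "Can not calculate AUC for survival task." instead of silently taking
  // the binary branch.
}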