From 2d0549b7f371c5618496c3eeeca482a8012df8ad Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Mon, 28 Dec 2020 21:47:21 -0800 Subject: [PATCH] Fix handling of global bias for binary:logitraw objective of XGBoost --- src/frontend/xgboost.cc | 2 +- src/frontend/xgboost/xgboost.h | 2 +- src/frontend/xgboost_json.cc | 3 +-- src/frontend/xgboost_util.cc | 9 +-------- tests/python/test_xgboost_integration.py | 4 ++-- 5 files changed, 6 insertions(+), 14 deletions(-) diff --git a/src/frontend/xgboost.cc b/src/frontend/xgboost.cc index 691a1174..2d7e42fe 100644 --- a/src/frontend/xgboost.cc +++ b/src/frontend/xgboost.cc @@ -407,7 +407,7 @@ inline std::unique_ptr ParseStream(dmlc::Stream* fi) { // 1.0 it's the original value provided by user. const bool need_transform_to_margin = mparam_.major_version >= 1; if (need_transform_to_margin) { - treelite::details::xgboost::TransformGlobalBiasToMargin(name_obj_, &model->param); + treelite::details::xgboost::TransformGlobalBiasToMargin(&model->param); } // traverse trees diff --git a/src/frontend/xgboost/xgboost.h b/src/frontend/xgboost/xgboost.h index 9555fe95..6d914343 100644 --- a/src/frontend/xgboost/xgboost.h +++ b/src/frontend/xgboost/xgboost.h @@ -33,7 +33,7 @@ extern const std::vector exponential_objectives; void SetPredTransform(const std::string& objective_name, ModelParam* param); // Transform the global bias parameter from probability into margin score -void TransformGlobalBiasToMargin(const std::string& objective_name, ModelParam* param); +void TransformGlobalBiasToMargin(ModelParam* param); enum FeatureType { kNumerical = 0, diff --git a/src/frontend/xgboost_json.cc b/src/frontend/xgboost_json.cc index c061943d..0979fa9c 100644 --- a/src/frontend/xgboost_json.cc +++ b/src/frontend/xgboost_json.cc @@ -432,8 +432,7 @@ bool XGBoostModelHandler::EndObject(std::size_t memberCount) { // 1.0 it's the original value provided by user. const bool need_transform_to_margin = (version[0] >= 1); if (need_transform_to_margin) { - treelite::details::xgboost::TransformGlobalBiasToMargin( - output.objective_name, &output.model->param); + treelite::details::xgboost::TransformGlobalBiasToMargin(&output.model->param); } return pop_handler(); } diff --git a/src/frontend/xgboost_util.cc b/src/frontend/xgboost_util.cc index e855b4b5..11488595 100644 --- a/src/frontend/xgboost_util.cc +++ b/src/frontend/xgboost_util.cc @@ -54,15 +54,8 @@ void SetPredTransform(const std::string& objective_name, ModelParam* param) { } // Transform the global bias parameter from probability into margin score -void TransformGlobalBiasToMargin(const std::string& objective_name, ModelParam* param) { +void TransformGlobalBiasToMargin(ModelParam* param) { std::string bias_transform{param->pred_transform}; - if (objective_name == "binary:logitraw") { - // Special handling for 'logitraw', where the global bias is transformed with 'sigmoid', - // but the prediction is returned un-transformed. - CHECK_EQ(bias_transform, "identity"); - bias_transform = "sigmoid"; - } - if (bias_transform == "sigmoid") { param->global_bias = ProbToMargin::Sigmoid(param->global_bias); } else if (bias_transform == "exponential") { diff --git a/tests/python/test_xgboost_integration.py b/tests/python/test_xgboost_integration.py index 3b551a27..6406dac5 100644 --- a/tests/python/test_xgboost_integration.py +++ b/tests/python/test_xgboost_integration.py @@ -114,18 +114,18 @@ def test_xgb_iris(tmpdir, toolchain, objective, model_format, expected_pred_tran np.testing.assert_almost_equal(out_pred, expected_pred, decimal=5) -@pytest.mark.parametrize('toolchain', os_compatible_toolchains()) @pytest.mark.parametrize('model_format', ['binary', 'json']) @pytest.mark.parametrize('objective,max_label,expected_global_bias', [('binary:logistic', 2, 0), ('binary:hinge', 2, 0.5), - ('binary:logitraw', 2, 0), + ('binary:logitraw', 2, 0.5), ('count:poisson', 4, math.log(0.5)), ('rank:pairwise', 5, 0.5), ('rank:ndcg', 5, 0.5), ('rank:map', 5, 0.5)], ids=['binary:logistic', 'binary:hinge', 'binary:logitraw', 'count:poisson', 'rank:pairwise', 'rank:ndcg', 'rank:map']) +@pytest.mark.parametrize('toolchain', os_compatible_toolchains()) def test_nonlinear_objective(tmpdir, objective, max_label, expected_global_bias, toolchain, model_format): # pylint: disable=too-many-locals,too-many-arguments