From a940f550693510df95303aa90c10bf1de30c1755 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Wed, 16 Dec 2020 05:35:04 +0800 Subject: [PATCH] Move metric configuration into booster. (#6504) --- python-package/xgboost/core.py | 15 ++++++++++++++- python-package/xgboost/training.py | 13 ------------- tests/python/test_basic.py | 25 +++++++++++++++++++++---- 3 files changed, 35 insertions(+), 18 deletions(-) diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index 30d169a5a93d..cea0b340e5c1 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -1,11 +1,12 @@ # coding: utf-8 # pylint: disable=too-many-arguments, too-many-branches, invalid-name -# pylint: disable=too-many-lines, too-many-locals +# pylint: disable=too-many-lines, too-many-locals, no-self-use """Core XGBoost Library.""" import collections # pylint: disable=no-name-in-module,import-error from collections.abc import Mapping # pylint: enable=no-name-in-module,import-error +from typing import Dict, Union, List import ctypes import os import re @@ -1012,6 +1013,7 @@ def __init__(self, params=None, cache=(), model_file=None): _check_call(_LIB.XGBoosterCreate(dmats, c_bst_ulong(len(cache)), ctypes.byref(self.handle))) params = params or {} + params = self._configure_metrics(params.copy()) if isinstance(params, list): params.append(('validate_parameters', True)) else: @@ -1041,6 +1043,17 @@ def __init__(self, params=None, cache=(), model_file=None): else: raise TypeError('Unknown type:', model_file) + def _configure_metrics(self, params: Union[Dict, List]) -> Union[Dict, List]: + if isinstance(params, dict) and 'eval_metric' in params \ + and isinstance(params['eval_metric'], list): + params = dict((k, v) for k, v in params.items()) + eval_metrics = params['eval_metric'] + params.pop("eval_metric", None) + params = list(params.items()) + for eval_metric in eval_metrics: + params += [('eval_metric', eval_metric)] + return params + def __del__(self): if hasattr(self, 'handle') and self.handle is not None: _check_call(_LIB.XGBoosterFree(self.handle)) diff --git a/python-package/xgboost/training.py b/python-package/xgboost/training.py index c13c8dc70776..8db3a9798a9e 100644 --- a/python-package/xgboost/training.py +++ b/python-package/xgboost/training.py @@ -40,18 +40,6 @@ def _is_new_callback(callbacks): for c in callbacks) or not callbacks -def _configure_metrics(params): - if isinstance(params, dict) and 'eval_metric' in params \ - and isinstance(params['eval_metric'], list): - params = dict((k, v) for k, v in params.items()) - eval_metrics = params['eval_metric'] - params.pop("eval_metric", None) - params = list(params.items()) - for eval_metric in eval_metrics: - params += [('eval_metric', eval_metric)] - return params - - def _train_internal(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None, @@ -61,7 +49,6 @@ def _train_internal(params, dtrain, """internal training function""" callbacks = [] if callbacks is None else copy.copy(callbacks) evals = list(evals) - params = _configure_metrics(params.copy()) bst = Booster(params, [dtrain] + [d[0] for d in evals]) nboost = 0 diff --git a/tests/python/test_basic.py b/tests/python/test_basic.py index c0c6a6a8fdcc..7ce87b208642 100644 --- a/tests/python/test_basic.py +++ b/tests/python/test_basic.py @@ -57,6 +57,25 @@ def test_basic(self): # assert they are the same assert np.sum(np.abs(preds2 - preds)) == 0 + def test_metric_config(self): + # Make sure that the metric configuration happens in booster so the + # string 
`['error', 'auc']` doesn't get passed down to the C++ core.
+        dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
+        dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
+        param = {'max_depth': 2, 'eta': 1, 'verbosity': 0,
+                 'objective': 'binary:logistic', 'eval_metric': ['error', 'auc']}
+        watchlist = [(dtest, 'eval'), (dtrain, 'train')]
+        num_round = 2
+        booster = xgb.train(param, dtrain, num_round, watchlist)
+        predt_0 = booster.predict(dtrain)
+        with tempfile.TemporaryDirectory() as tmpdir:
+            path = os.path.join(tmpdir, 'model.json')
+            booster.save_model(path)
+
+            booster = xgb.Booster(params=param, model_file=path)
+            predt_1 = booster.predict(dtrain)
+            np.testing.assert_allclose(predt_0, predt_1)
+
     def test_record_results(self):
         dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
         dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
@@ -124,8 +143,8 @@ def test_dump(self):
         dump2 = bst.get_dump(with_stats=True)
         assert dump2[0].count('\n') == 3, \
             'Expected 1 root and 2 leaves - 3 lines in dump.'
-        assert (dump2[0].find('\n') > dump1[0].find('\n'),
-                'Expected more info when with_stats=True is given.')
+        msg = 'Expected more info when with_stats=True is given.'
+        assert dump2[0].find('\n') > dump1[0].find('\n'), msg
 
         dump3 = bst.get_dump(dump_format="json")
         dump3j = json.loads(dump3[0])
@@ -248,13 +267,11 @@ def test_DMatrix_save_to_path(self):
         assert binary_path.exists()
         Path.unlink(binary_path)
 
-
     def test_Booster_init_invalid_path(self):
         """An invalid model_file path should raise XGBoostError."""
         with pytest.raises(xgb.core.XGBoostError):
             xgb.Booster(model_file=Path("invalidpath"))
 
-
     def test_Booster_save_and_load(self):
         """Saving and loading model files from paths."""
         save_path = Path("saveload.model")
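
Note (editorial, not part of the patch): the `_configure_metrics` helper moved
into `Booster.__init__` above rewrites a params dict whose 'eval_metric' value
is a list into a list of (key, value) pairs, one ('eval_metric', m) entry per
metric, since a raw Python list would otherwise reach the C++ core as its
string representation. A minimal standalone sketch of that expansion, with
illustrative variable names:

    # Hypothetical sketch of the expansion _configure_metrics performs when
    # 'eval_metric' holds a list; insertion order of the dict is preserved.
    params = {'objective': 'binary:logistic',
              'eval_metric': ['error', 'auc']}
    eval_metrics = params.pop('eval_metric')  # ['error', 'auc']
    expanded = list(params.items()) + [('eval_metric', m)
                                       for m in eval_metrics]
    # expanded == [('objective', 'binary:logistic'),
    #              ('eval_metric', 'error'),
    #              ('eval_metric', 'auc')]

Because the expansion now happens in the Booster constructor rather than in
`_train_internal`, constructing `xgb.Booster(params=param, ...)` directly with
a list-valued 'eval_metric' behaves the same as going through `xgb.train`,
which is what the save/load round-trip in the new `test_metric_config` above
verifies.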