Skip to content

Commit

Permalink
[backport] Move metric configuration into booster. (#6504) (#6533)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Dec 20, 2020
1 parent bce7ca3 commit 7109c6c
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 18 deletions.
15 changes: 14 additions & 1 deletion python-package/xgboost/core.py
@@ -1,11 +1,12 @@
# coding: utf-8
# pylint: disable=too-many-arguments, too-many-branches, invalid-name
# pylint: disable=too-many-lines, too-many-locals
# pylint: disable=too-many-lines, too-many-locals, no-self-use
"""Core XGBoost Library."""
import collections
# pylint: disable=no-name-in-module,import-error
from collections.abc import Mapping
# pylint: enable=no-name-in-module,import-error
from typing import Dict, Union, List
import ctypes
import os
import re
Expand Down Expand Up @@ -1012,6 +1013,7 @@ def __init__(self, params=None, cache=(), model_file=None):
_check_call(_LIB.XGBoosterCreate(dmats, c_bst_ulong(len(cache)),
ctypes.byref(self.handle)))
params = params or {}
params = self._configure_metrics(params.copy())
if isinstance(params, list):
params.append(('validate_parameters', True))
else:
Expand Down Expand Up @@ -1041,6 +1043,17 @@ def __init__(self, params=None, cache=(), model_file=None):
else:
raise TypeError('Unknown type:', model_file)

def _configure_metrics(self, params: Union[Dict, List]) -> Union[Dict, List]:
if isinstance(params, dict) and 'eval_metric' in params \
and isinstance(params['eval_metric'], list):
params = dict((k, v) for k, v in params.items())
eval_metrics = params['eval_metric']
params.pop("eval_metric", None)
params = list(params.items())
for eval_metric in eval_metrics:
params += [('eval_metric', eval_metric)]
return params

def __del__(self):
if hasattr(self, 'handle') and self.handle is not None:
_check_call(_LIB.XGBoosterFree(self.handle))
Expand Down
13 changes: 0 additions & 13 deletions python-package/xgboost/training.py
Expand Up @@ -40,18 +40,6 @@ def _is_new_callback(callbacks):
for c in callbacks) or not callbacks


def _configure_metrics(params):
if isinstance(params, dict) and 'eval_metric' in params \
and isinstance(params['eval_metric'], list):
params = dict((k, v) for k, v in params.items())
eval_metrics = params['eval_metric']
params.pop("eval_metric", None)
params = list(params.items())
for eval_metric in eval_metrics:
params += [('eval_metric', eval_metric)]
return params


def _train_internal(params, dtrain,
num_boost_round=10, evals=(),
obj=None, feval=None,
Expand All @@ -61,7 +49,6 @@ def _train_internal(params, dtrain,
"""internal training function"""
callbacks = [] if callbacks is None else copy.copy(callbacks)
evals = list(evals)
params = _configure_metrics(params.copy())

bst = Booster(params, [dtrain] + [d[0] for d in evals])
nboost = 0
Expand Down
25 changes: 21 additions & 4 deletions tests/python/test_basic.py
Expand Up @@ -57,6 +57,25 @@ def test_basic(self):
# assert they are the same
assert np.sum(np.abs(preds2 - preds)) == 0

def test_metric_config(self):
    # The Booster itself must expand the list-valued ``eval_metric`` so
    # the string `['error', 'auc']` is never passed down to core.
    dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
    dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
    param = {
        'max_depth': 2,
        'eta': 1,
        'verbosity': 0,
        'objective': 'binary:logistic',
        'eval_metric': ['error', 'auc'],
    }
    watchlist = [(dtest, 'eval'), (dtrain, 'train')]
    booster = xgb.train(param, dtrain, 2, watchlist)
    predt_0 = booster.predict(dtrain)
    with tempfile.TemporaryDirectory() as tmpdir:
        path = os.path.join(tmpdir, 'model.json')
        booster.save_model(path)

        reloaded = xgb.Booster(params=param, model_file=path)
        predt_1 = reloaded.predict(dtrain)
        np.testing.assert_allclose(predt_0, predt_1)

def test_record_results(self):
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
Expand Down Expand Up @@ -124,8 +143,8 @@ def test_dump(self):

dump2 = bst.get_dump(with_stats=True)
assert dump2[0].count('\n') == 3, 'Expected 1 root and 2 leaves - 3 lines in dump.'
assert (dump2[0].find('\n') > dump1[0].find('\n'),
'Expected more info when with_stats=True is given.')
msg = 'Expected more info when with_stats=True is given.'
assert dump2[0].find('\n') > dump1[0].find('\n'), msg

dump3 = bst.get_dump(dump_format="json")
dump3j = json.loads(dump3[0])
Expand Down Expand Up @@ -248,13 +267,11 @@ def test_DMatrix_save_to_path(self):
assert binary_path.exists()
Path.unlink(binary_path)


def test_Booster_init_invalid_path(self):
    """Constructing a Booster from a nonexistent model path raises."""
    bad_path = Path("invalidpath")
    with pytest.raises(xgb.core.XGBoostError):
        xgb.Booster(model_file=bad_path)


def test_Booster_save_and_load(self):
"""Saving and loading model files from paths."""
save_path = Path("saveload.model")
Expand Down

0 comments on commit 7109c6c

Please sign in to comment.