[backport] Move metric configuration into booster. (#6504) #6533

Merged 1 commit on Dec 20, 2020
15 changes: 14 additions & 1 deletion python-package/xgboost/core.py
@@ -1,11 +1,12 @@
 # coding: utf-8
 # pylint: disable=too-many-arguments, too-many-branches, invalid-name
-# pylint: disable=too-many-lines, too-many-locals
+# pylint: disable=too-many-lines, too-many-locals, no-self-use
 """Core XGBoost Library."""
 import collections
 # pylint: disable=no-name-in-module,import-error
 from collections.abc import Mapping
 # pylint: enable=no-name-in-module,import-error
+from typing import Dict, Union, List
 import ctypes
 import os
 import re
@@ -1012,6 +1013,7 @@ def __init__(self, params=None, cache=(), model_file=None):
         _check_call(_LIB.XGBoosterCreate(dmats, c_bst_ulong(len(cache)),
                                          ctypes.byref(self.handle)))
         params = params or {}
+        params = self._configure_metrics(params.copy())
         if isinstance(params, list):
             params.append(('validate_parameters', True))
         else:
@@ -1041,6 +1043,17 @@ def __init__(self, params=None, cache=(), model_file=None):
         else:
             raise TypeError('Unknown type:', model_file)

+    def _configure_metrics(self, params: Union[Dict, List]) -> Union[Dict, List]:
+        if isinstance(params, dict) and 'eval_metric' in params \
+           and isinstance(params['eval_metric'], list):
+            params = dict((k, v) for k, v in params.items())
+            eval_metrics = params['eval_metric']
+            params.pop("eval_metric", None)
+            params = list(params.items())
+            for eval_metric in eval_metrics:
+                params += [('eval_metric', eval_metric)]
+        return params
+
     def __del__(self):
         if hasattr(self, 'handle') and self.handle is not None:
             _check_call(_LIB.XGBoosterFree(self.handle))
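For reference, the flattening that the new Booster._configure_metrics performs can be shown in isolation. The sketch below is not part of the diff; it replays the method's logic on an example dict so the before/after shapes are visible:

# A minimal, standalone sketch of the metric flattening (logic taken from
# the _configure_metrics method added above). A dict whose 'eval_metric'
# value is a list becomes a list of (key, value) pairs with one
# ('eval_metric', m) entry per metric, so the list itself is never handed
# to the C core as a parameter value.
params = {'objective': 'binary:logistic', 'eval_metric': ['error', 'auc']}

if isinstance(params, dict) and isinstance(params.get('eval_metric'), list):
    eval_metrics = params.pop('eval_metric')
    params = list(params.items())
    for eval_metric in eval_metrics:
        params += [('eval_metric', eval_metric)]

print(params)
# [('objective', 'binary:logistic'), ('eval_metric', 'error'),
#  ('eval_metric', 'auc')]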
13 changes: 0 additions & 13 deletions python-package/xgboost/training.py
@@ -40,18 +40,6 @@ def _is_new_callback(callbacks):
                for c in callbacks) or not callbacks


-def _configure_metrics(params):
-    if isinstance(params, dict) and 'eval_metric' in params \
-       and isinstance(params['eval_metric'], list):
-        params = dict((k, v) for k, v in params.items())
-        eval_metrics = params['eval_metric']
-        params.pop("eval_metric", None)
-        params = list(params.items())
-        for eval_metric in eval_metrics:
-            params += [('eval_metric', eval_metric)]
-    return params
-
-
 def _train_internal(params, dtrain,
                     num_boost_round=10, evals=(),
                     obj=None, feval=None,
@@ -61,7 +49,6 @@ def _train_internal(params, dtrain,
"""internal training function"""
callbacks = [] if callbacks is None else copy.copy(callbacks)
evals = list(evals)
params = _configure_metrics(params.copy())

bst = Booster(params, [dtrain] + [d[0] for d in evals])
nboost = 0
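With the helper gone from training.py, the preprocessing no longer depends on going through xgb.train: any code that constructs a Booster directly now gets the same multi-metric handling. A small usage sketch (the training file name is a placeholder, not from this PR):

import xgboost as xgb

# Hypothetical input; any DMatrix works here.
dtrain = xgb.DMatrix('agaricus.txt.train')

params = {'objective': 'binary:logistic',
          'eval_metric': ['error', 'auc']}

# Both construction paths now route through Booster._configure_metrics:
bst_a = xgb.train(params, dtrain, num_boost_round=2)  # via _train_internal
bst_b = xgb.Booster(params, [dtrain])                 # direct construction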
25 changes: 21 additions & 4 deletions tests/python/test_basic.py
@@ -57,6 +57,25 @@ def test_basic(self):
         # assert they are the same
         assert np.sum(np.abs(preds2 - preds)) == 0

+    def test_metric_config(self):
+        # Make sure that the metric configuration happens in booster so the
+        # string `['error', 'auc']` doesn't get passed down to core.
+        dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
+        dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
+        param = {'max_depth': 2, 'eta': 1, 'verbosity': 0,
+                 'objective': 'binary:logistic', 'eval_metric': ['error', 'auc']}
+        watchlist = [(dtest, 'eval'), (dtrain, 'train')]
+        num_round = 2
+        booster = xgb.train(param, dtrain, num_round, watchlist)
+        predt_0 = booster.predict(dtrain)
+        with tempfile.TemporaryDirectory() as tmpdir:
+            path = os.path.join(tmpdir, 'model.json')
+            booster.save_model(path)
+
+            booster = xgb.Booster(params=param, model_file=path)
+            predt_1 = booster.predict(dtrain)
+            np.testing.assert_allclose(predt_0, predt_1)
+
     def test_record_results(self):
         dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
         dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
@@ -124,8 +143,8 @@ def test_dump(self):

         dump2 = bst.get_dump(with_stats=True)
         assert dump2[0].count('\n') == 3, 'Expected 1 root and 2 leaves - 3 lines in dump.'
-        assert (dump2[0].find('\n') > dump1[0].find('\n'),
-                'Expected more info when with_stats=True is given.')
+        msg = 'Expected more info when with_stats=True is given.'
+        assert dump2[0].find('\n') > dump1[0].find('\n'), msg

         dump3 = bst.get_dump(dump_format="json")
         dump3j = json.loads(dump3[0])
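The assertion rewritten above fixes a classic Python pitfall: wrapping an assert's condition and message in parentheses creates a two-element tuple, and a non-empty tuple is always truthy, so the old check could never fail. A minimal illustration (not from the diff):

# Buggy form: asserts a non-empty tuple, which is always true.
assert (1 > 2, 'this never fires')
# Fixed form: asserts the comparison and attaches the message.
assert 1 > 2, 'this raises AssertionError'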
@@ -248,13 +267,11 @@ def test_DMatrix_save_to_path(self):
         assert binary_path.exists()
         Path.unlink(binary_path)

-
     def test_Booster_init_invalid_path(self):
         """An invalid model_file path should raise XGBoostError."""
         with pytest.raises(xgb.core.XGBoostError):
             xgb.Booster(model_file=Path("invalidpath"))

-
     def test_Booster_save_and_load(self):
         """Saving and loading model files from paths."""
         save_path = Path("saveload.model")