Skip to content

Commit

Permalink
Reviewers' comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Oct 10, 2020
1 parent 60240f9 commit 2316f8a
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 9 deletions.
2 changes: 1 addition & 1 deletion demo/guide-python/callbacks.py
Expand Up @@ -97,7 +97,7 @@ def check(as_pickle):
# Check point to a temporary directory for demo
with tempfile.TemporaryDirectory() as tmpdir:
# Use callback class from xgboost.callback
# Feel free to subclass/customize it to suite your need.
# Feel free to subclass/customize it to suit your need.
check_point = xgb.callback.TrainingCheckPoint(directory=tmpdir,
rounds=rounds,
name='model')
Expand Down
8 changes: 4 additions & 4 deletions doc/python/callbacks.rst
Expand Up @@ -4,7 +4,7 @@ Callback Functions

This document gives a basic walkthrough of callback function used in XGBoost Python
package. In XGBoost 1.3, a new callback interface is designed for Python package, which
provides the flexiablity of designing various extension for training. Also, XGBoost has a
provides the flexiblity of designing various extension for training. Also, XGBoost has a
number of pre-defined callbacks for supporting early stopping, checkpoints etc.

#######################
Expand All @@ -30,11 +30,11 @@ this callback function directly into XGBoost:
r[gt] = 1 - label[gt]
le = predt <= 0.5
r[le] = label[le]
return 'PyError', np.sum(r)
return 'CustomErr', np.sum(r)
# Specify which dataset and which metric should be used for early stopping.
early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds,
metric_name='PyError',
metric_name='CustomErr',
data_name='Train')
booster = xgb.train(
Expand All @@ -48,7 +48,7 @@ this callback function directly into XGBoost:
verbose_eval=False)
dump = booster.get_dump(dump_format='json')
assert len(early_stop.stopping_history['Valid']['PyError']) == len(dump)
assert len(early_stop.stopping_history['Valid']['CustomErr']) == len(dump)
##########################
Defining your own callback
Expand Down
6 changes: 4 additions & 2 deletions python-package/xgboost/training.py
Expand Up @@ -12,6 +12,8 @@
def _configure_deprecated_callbacks(
verbose_eval, early_stopping_rounds, maximize, start_iteration,
num_boost_round, feval, evals_result, callbacks, show_stdv, cvfolds):
link = 'https://xgboost.readthedocs.io/en/latest/python/callbacks.html'
raise DeprecationWarning(f'Old style callback is deprecated. See: {link}')
# Most of legacy advanced options becomes callbacks
if early_stopping_rounds is not None:
callbacks.append(callback.early_stop(early_stopping_rounds,
Expand Down Expand Up @@ -85,7 +87,7 @@ def _train_internal(params, dtrain,
is_new_callback = _is_new_callback(callbacks)
if is_new_callback:
assert all(isinstance(c, callback.TrainingCallback)
for c in callbacks), "You can't mix two styles of callbacks."
for c in callbacks), "You can't mix new and old callback styles."
if verbose_eval:
callbacks.append(callback.EvaluationMonitor())
if early_stopping_rounds:
Expand Down Expand Up @@ -478,7 +480,7 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None
is_new_callback = _is_new_callback(callbacks)
if is_new_callback:
assert all(isinstance(c, callback.TrainingCallback)
for c in callbacks), "You can't mix two styles of callbacks."
for c in callbacks), "You can't mix new and old callback styles."
if isinstance(verbose_eval, bool) and verbose_eval:
callbacks.append(callback.EvaluationMonitor(show_stdv=show_stdv))
if early_stopping_rounds:
Expand Down
4 changes: 2 additions & 2 deletions tests/python/test_callback.py
Expand Up @@ -74,7 +74,7 @@ def test_early_stopping_customize(self):
D_valid = xgb.DMatrix(self.X_valid, self.y_valid)
early_stopping_rounds = 5
early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds,
metric_name='PyError',
metric_name='CustomErr',
data_name='Train')
# Specify which dataset and which metric should be used for early stopping.
booster = xgb.train(
Expand All @@ -88,7 +88,7 @@ def test_early_stopping_customize(self):
verbose_eval=False)
dump = booster.get_dump(dump_format='json')
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
assert len(early_stop.stopping_history['Train']['PyError']) == len(dump)
assert len(early_stop.stopping_history['Train']['CustomErr']) == len(dump)

def test_early_stopping_skl(self):
from sklearn.datasets import load_breast_cancer
Expand Down

0 comments on commit 2316f8a

Please sign in to comment.