
Rework Python callback functions. #6199

Merged
merged 53 commits on Oct 10, 2020
Commits (53)
0158987
Port previous impl.
trivialfis Oct 3, 2020
9cebcd2
Check point.
trivialfis Oct 3, 2020
358b227
Pass in the parameter.
trivialfis Oct 3, 2020
4a5879c
Distributed.
trivialfis Oct 3, 2020
39efda8
Port test.
trivialfis Oct 3, 2020
ba1b44a
Test checkpoint.
trivialfis Oct 3, 2020
6a19780
Start writing dask test.
trivialfis Oct 3, 2020
1b0a6c6
Pass dask test.
trivialfis Oct 4, 2020
e1022f5
Define custom metric.
trivialfis Oct 4, 2020
94995ca
More tests.
trivialfis Oct 4, 2020
dcc5262
Revise monitor.
trivialfis Oct 4, 2020
e1b6f68
Test for early stopping.
trivialfis Oct 4, 2020
888c9cc
Extract dask test.
trivialfis Oct 4, 2020
f6a8a2b
Allreduce.
trivialfis Oct 4, 2020
3129762
Add check point demo.
trivialfis Oct 4, 2020
88d5d9b
Add test for ES custom feval
trivialfis Oct 4, 2020
94f0413
ES custom eval skl.
trivialfis Oct 4, 2020
0a3cca5
Test learning rate scheduling.
trivialfis Oct 4, 2020
13332d7
Lint.
trivialfis Oct 4, 2020
5731802
[ES] Consider specified metric name and data name.
trivialfis Oct 4, 2020
b24fc13
Lint.
trivialfis Oct 4, 2020
5922f2a
Test for customization.
trivialfis Oct 4, 2020
d0e11d7
Basic doc.
trivialfis Oct 4, 2020
4f4ec6c
TODO.
trivialfis Oct 4, 2020
97f678a
Use evals_log instead of actual data.
trivialfis Oct 4, 2020
a2d99b4
Packed booster.
trivialfis Oct 4, 2020
c881417
todos.
trivialfis Oct 4, 2020
899add7
Auto config.
trivialfis Oct 4, 2020
268ab02
Use set.
trivialfis Oct 4, 2020
31403df
Minor cleaning.
trivialfis Oct 4, 2020
4fc2369
Remove redundant aggcv.
trivialfis Oct 4, 2020
c2fa6a1
Correct cv print.
trivialfis Oct 4, 2020
e2aa739
Demo.
trivialfis Oct 4, 2020
bfbe2cf
Test demo.
trivialfis Oct 4, 2020
a746b52
Support old callbacks.
trivialfis Oct 4, 2020
b93a2d7
Remove todo.
trivialfis Oct 4, 2020
c3ad97e
Initial batch of fixes.
trivialfis Oct 4, 2020
3d52ed9
Fix cv.
trivialfis Oct 4, 2020
9dfe80f
Use error.
trivialfis Oct 4, 2020
b41e0da
Fix weird metric name.
trivialfis Oct 4, 2020
8a9ad13
Fix attr.
trivialfis Oct 4, 2020
810be6b
Lint.
trivialfis Oct 4, 2020
14fe528
Legacy callback.
trivialfis Oct 5, 2020
bc77716
Fix dask parameter.
trivialfis Oct 5, 2020
0625e40
Cleanup.
trivialfis Oct 5, 2020
b991092
Fix moved test.
trivialfis Oct 5, 2020
60240f9
Reviewers' comment.
trivialfis Oct 6, 2020
2316f8a
Reviewers' comments.
trivialfis Oct 10, 2020
4a6bb5d
Redundant attribute in demo.
trivialfis Oct 10, 2020
10c705a
Small fixes in doc and parameter naming.
trivialfis Oct 10, 2020
e3ad79a
Fix naming in test.
trivialfis Oct 10, 2020
f8662cf
Pytest deprecated call.
trivialfis Oct 10, 2020
8e45139
Fix typo.
trivialfis Oct 10, 2020
130 changes: 130 additions & 0 deletions demo/guide-python/callbacks.py
@@ -0,0 +1,130 @@
'''
Demo for using and defining callback functions.

.. versionadded:: 1.3.0
'''
import argparse
import os
import tempfile

import numpy as np
from matplotlib import pyplot as plt
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

import xgboost as xgb


class Plotting(xgb.callback.TrainingCallback):
    '''Plot evaluation results during training. Only for demonstration purposes, as
    it's quite slow to draw.

'''
def __init__(self, rounds):
self.fig = plt.figure()
self.ax = self.fig.add_subplot(111)
self.rounds = rounds
self.lines = {}
self.fig.show()
self.x = np.linspace(0, self.rounds, self.rounds)
plt.ion()

def _get_key(self, data, metric):
return f'{data}-{metric}'

def after_iteration(self, model, epoch, evals_log):
'''Update the plot.'''
if not self.lines:
for data, metric in evals_log.items():
for metric_name, log in metric.items():
key = self._get_key(data, metric_name)
expanded = log + [0] * (self.rounds - len(log))
self.lines[key], = self.ax.plot(self.x, expanded, label=key)
self.ax.legend()
else:
# https://pythonspot.com/matplotlib-update-plot/
for data, metric in evals_log.items():
for metric_name, log in metric.items():
key = self._get_key(data, metric_name)
expanded = log + [0] * (self.rounds - len(log))
self.lines[key].set_ydata(expanded)
self.fig.canvas.draw()
# False to indicate training should not stop.
return False


def custom_callback():
    '''Demo for defining a custom callback function that plots evaluation results
    during training.'''
X, y = load_breast_cancer(return_X_y=True)
X_train, X_valid, y_train, y_valid = train_test_split(X, y, random_state=0)

D_train = xgb.DMatrix(X_train, y_train)
D_valid = xgb.DMatrix(X_valid, y_valid)

num_boost_round = 100
plotting = Plotting(num_boost_round)

# Pass it to the `callbacks` parameter as a list.
xgb.train(
{
'objective': 'binary:logistic',
'eval_metric': ['error', 'rmse'],
'tree_method': 'gpu_hist'
},
D_train,
evals=[(D_train, 'Train'), (D_valid, 'Valid')],
num_boost_round=num_boost_round,
callbacks=[plotting])


def check_point_callback():
    # Only for demo; set a larger value (like 100) in practice, as checkpointing
    # is quite slow.
rounds = 2

def check(as_pickle):
for i in range(0, 10, rounds):
if i == 0:
continue
if as_pickle:
path = os.path.join(tmpdir, 'model_' + str(i) + '.pkl')
else:
path = os.path.join(tmpdir, 'model_' + str(i) + '.json')
            assert os.path.exists(path)

X, y = load_breast_cancer(return_X_y=True)
m = xgb.DMatrix(X, y)
# Check point to a temporary directory for demo
with tempfile.TemporaryDirectory() as tmpdir:
        # Use the callback class from xgboost.callback.
        # Feel free to subclass/customize it to suit your needs.
check_point = xgb.callback.TrainingCheckPoint(directory=tmpdir,
iterations=rounds,
name='model')
xgb.train({'objective': 'binary:logistic'}, m,
num_boost_round=10,
verbose_eval=False,
callbacks=[check_point])
check(False)

# This version of checkpoint saves everything including parameters and
# model. See: doc/tutorials/saving_model.rst
check_point = xgb.callback.TrainingCheckPoint(directory=tmpdir,
iterations=rounds,
as_pickle=True,
name='model')
xgb.train({'objective': 'binary:logistic'}, m,
num_boost_round=10,
verbose_eval=False,
callbacks=[check_point])
check(True)
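        # For reference (an illustrative addition, not part of the original demo):
        # a JSON checkpoint written above can be loaded back with the Booster
        # constructor.
        loaded = xgb.Booster(model_file=os.path.join(tmpdir, 'model_8.json'))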


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--plot', default=1, type=int)
args = parser.parse_args()

check_point_callback()

if args.plot:
custom_callback()
2 changes: 2 additions & 0 deletions demo/guide-python/data_iterator.py
@@ -1,5 +1,7 @@
'''A demo for defining data iterator.

.. versionadded:: 1.2.0

A demo that defines a customized iterator for passing batches of data into
`xgboost.DeviceQuantileDMatrix` and uses this `DeviceQuantileDMatrix` for
training. The feature is primarily designed to reduce the required GPU memory.
59 changes: 59 additions & 0 deletions doc/python/callbacks.rst
@@ -0,0 +1,59 @@
##################
Callback Functions
##################

This document gives a basic walkthrough of the callback functions used in the
XGBoost Python package. XGBoost 1.3 introduces a new callback interface to the
Python package, which provides the flexibility to design various extensions for
training. XGBoost also has a number of pre-defined callbacks for supporting
early stopping, checkpointing, etc.
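
As a quick reference, the pre-defined callbacks live in the ``xgboost.callback``
module. A sketch of the ones exercised in this change (names taken from the demo
and tests in this PR):

.. code-block:: python

    import xgboost as xgb

    xgb.callback.TrainingCallback       # base class for user-defined callbacks
    xgb.callback.EvaluationMonitor      # print evaluation results during training
    xgb.callback.EarlyStopping          # stop training when a metric stops improving
    xgb.callback.LearningRateScheduler  # set the learning rate for each round
    xgb.callback.TrainingCheckPoint     # save the model at regular intervals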

#######################
Using builtin callbacks
#######################

By default, training methods in XGBoost have parameters like
``early_stopping_rounds`` and ``verbose``/``verbose_eval``; when these are
specified, the training procedure defines the corresponding callbacks
internally. For example, when ``early_stopping_rounds`` is specified, the
``EarlyStopping`` callback is invoked inside the iteration loop. You can also
pass this callback function directly into XGBoost:

.. code-block:: python

D_train = xgb.DMatrix(X_train, y_train)
D_valid = xgb.DMatrix(X_valid, y_valid)

# Define a custom evaluation metric used for early stopping.
def eval_error_metric(predt, dtrain: xgb.DMatrix):
label = dtrain.get_label()
r = np.zeros(predt.shape)
gt = predt > 0.5
r[gt] = 1 - label[gt]
le = predt <= 0.5
r[le] = label[le]
return 'CustomErr', np.sum(r)

# Specify which dataset and which metric should be used for early stopping.
early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds,
metric_name='CustomErr',
data_name='Train')

booster = xgb.train(
{'objective': 'binary:logistic',
'eval_metric': ['error', 'rmse'],
'tree_method': 'hist'}, D_train,
evals=[(D_train, 'Train'), (D_valid, 'Valid')],
feval=eval_error_metric,
num_boost_round=1000,
callbacks=[early_stop],
verbose_eval=False)

dump = booster.get_dump(dump_format='json')
    assert len(early_stop.stopping_history['Train']['CustomErr']) == len(dump)
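
Builtin callbacks can be combined by passing several of them in the same
``callbacks`` list. As a sketch (the decay schedule below is purely
illustrative), the ``LearningRateScheduler`` callback accepts either a list of
learning rates or a callable that maps the boosting round to a learning rate:

.. code-block:: python

    # Decay the learning rate from 0.3 by 1 percent each round.
    scheduler = xgb.callback.LearningRateScheduler(
        lambda epoch: 0.3 * (0.99 ** epoch))

    booster = xgb.train(
        {'objective': 'binary:logistic',
         'eval_metric': ['error', 'rmse'],
         'tree_method': 'hist'}, D_train,
        evals=[(D_train, 'Train'), (D_valid, 'Valid')],
        num_boost_round=100,
        callbacks=[early_stop, scheduler],
        verbose_eval=False)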

##########################
Defining your own callback
##########################

XGBoost provides a callback interface class, ``xgboost.callback.TrainingCallback``.
User-defined callbacks should inherit from this class and override the
corresponding methods. There's a working example in
`demo/guide-python/callbacks.py <https://github.com/dmlc/xgboost/tree/master/demo/guide-python/callbacks.py>`_.
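
As a minimal sketch (the class name and printing behavior are illustrative, not
part of XGBoost), a custom callback that reports the latest value of every
tracked metric could look like this:

.. code-block:: python

    class MetricPrinter(xgb.callback.TrainingCallback):
        '''Print the latest value of each metric after every iteration.'''
        def after_iteration(self, model, epoch, evals_log):
            for data, metrics in evals_log.items():
                for metric_name, log in metrics.items():
                    print(f'[{epoch}] {data}-{metric_name}: {log[-1]}')
            # Returning False indicates that training should continue.
            return False

    booster = xgb.train({'objective': 'binary:logistic'}, D_train,
                        evals=[(D_train, 'Train')],
                        num_boost_round=10,
                        callbacks=[MetricPrinter()],
                        verbose_eval=False)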
1 change: 1 addition & 0 deletions doc/python/index.rst
@@ -11,4 +11,5 @@ Contents
.. toctree::
python_intro
python_api
callbacks
Python examples <https://github.com/dmlc/xgboost/tree/master/demo/guide-python>