Rework Python callback functions. (#6199)
* Define a new callback interface for Python.
* Deprecate the old callbacks.
* Enable early stopping on dask.
trivialfis committed Oct 10, 2020
1 parent b5b2435 commit ab5b351
Showing 13 changed files with 1,180 additions and 275 deletions.
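
Of the three changes listed above, early stopping on Dask is the least visible in the diffs below. A minimal sketch of what it enables, assuming a running dask.distributed client (the data preparation is illustrative and not part of this commit):

    import xgboost as xgb
    from dask.distributed import Client

    client = Client()  # any running Dask cluster
    # X_train, y_train, X_valid, y_valid are assumed to be dask collections.
    dtrain = xgb.dask.DaskDMatrix(client, X_train, y_train)
    dvalid = xgb.dask.DaskDMatrix(client, X_valid, y_valid)

    output = xgb.dask.train(
        client,
        {'objective': 'binary:logistic'},
        dtrain,
        num_boost_round=1000,
        evals=[(dvalid, 'Valid')],
        # Previously unsupported on dask; now routed through the new callbacks.
        early_stopping_rounds=5)
    booster = output['booster']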
130 changes: 130 additions & 0 deletions demo/guide-python/callbacks.py
@@ -0,0 +1,130 @@
'''
Demo for using and defining callback functions.

.. versionadded:: 1.3.0
'''
import argparse
import os
import tempfile

import numpy as np
import xgboost as xgb
from matplotlib import pyplot as plt
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

class Plotting(xgb.callback.TrainingCallback):
    '''Plot evaluation results during training.  Only for demonstration
    purposes as it's quite slow to draw.
    '''
    def __init__(self, rounds):
        self.fig = plt.figure()
        self.ax = self.fig.add_subplot(111)
        self.rounds = rounds
        self.lines = {}
        self.fig.show()
        self.x = np.linspace(0, self.rounds, self.rounds)
        plt.ion()

    def _get_key(self, data, metric):
        return f'{data}-{metric}'

    def after_iteration(self, model, epoch, evals_log):
        '''Update the plot.'''
        if not self.lines:
            for data, metric in evals_log.items():
                for metric_name, log in metric.items():
                    key = self._get_key(data, metric_name)
                    # Pad the log so that it spans the whole x axis.
                    expanded = log + [0] * (self.rounds - len(log))
                    self.lines[key], = self.ax.plot(self.x, expanded, label=key)
                    self.ax.legend()
        else:
            # https://pythonspot.com/matplotlib-update-plot/
            for data, metric in evals_log.items():
                for metric_name, log in metric.items():
                    key = self._get_key(data, metric_name)
                    expanded = log + [0] * (self.rounds - len(log))
                    self.lines[key].set_ydata(expanded)
            self.fig.canvas.draw()
        # Return False to indicate training should not stop.
        return False


def custom_callback():
    '''Demo for defining a custom callback function that plots evaluation
    results during training.'''
    X, y = load_breast_cancer(return_X_y=True)
    X_train, X_valid, y_train, y_valid = train_test_split(X, y, random_state=0)

    D_train = xgb.DMatrix(X_train, y_train)
    D_valid = xgb.DMatrix(X_valid, y_valid)

    num_boost_round = 100
    plotting = Plotting(num_boost_round)

    # Pass it to the `callbacks` parameter as a list.
    xgb.train(
        {
            'objective': 'binary:logistic',
            'eval_metric': ['error', 'rmse'],
            'tree_method': 'gpu_hist'
        },
        D_train,
        evals=[(D_train, 'Train'), (D_valid, 'Valid')],
        num_boost_round=num_boost_round,
        callbacks=[plotting])


def check_point_callback():
    # Only for demo; use a larger value (like 100) in practice, as
    # checkpointing is quite slow.
    rounds = 2

    def check(as_pickle):
        for i in range(0, 10, rounds):
            if i == 0:
                continue
            if as_pickle:
                path = os.path.join(tmpdir, 'model_' + str(i) + '.pkl')
            else:
                path = os.path.join(tmpdir, 'model_' + str(i) + '.json')
            assert os.path.exists(path)

    X, y = load_breast_cancer(return_X_y=True)
    m = xgb.DMatrix(X, y)
    # Check point to a temporary directory for demo.
    with tempfile.TemporaryDirectory() as tmpdir:
        # Use the callback class from xgboost.callback.
        # Feel free to subclass/customize it to suit your needs.
        check_point = xgb.callback.TrainingCheckPoint(directory=tmpdir,
                                                      iterations=rounds,
                                                      name='model')
        xgb.train({'objective': 'binary:logistic'}, m,
                  num_boost_round=10,
                  verbose_eval=False,
                  callbacks=[check_point])
        check(False)

        # This version of checkpoint saves everything, including parameters
        # and model.  See: doc/tutorials/saving_model.rst
        check_point = xgb.callback.TrainingCheckPoint(directory=tmpdir,
                                                      iterations=rounds,
                                                      as_pickle=True,
                                                      name='model')
        xgb.train({'objective': 'binary:logistic'}, m,
                  num_boost_round=10,
                  verbose_eval=False,
                  callbacks=[check_point])
        check(True)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--plot', default=1, type=int)
    args = parser.parse_args()

    check_point_callback()

    if args.plot:
        custom_callback()
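
To try the demo on a machine without a display, skip the plotting part with `python callbacks.py --plot 0`.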
2 changes: 2 additions & 0 deletions demo/guide-python/data_iterator.py
@@ -1,5 +1,7 @@
'''A demo for defining a data iterator.

.. versionadded:: 1.2.0

The demo defines a customized iterator for passing batches of data into
`xgboost.DeviceQuantileDMatrix` and uses this `DeviceQuantileDMatrix` for
training.  The feature is primarily designed to reduce the required GPU …
59 changes: 59 additions & 0 deletions doc/python/callbacks.rst
@@ -0,0 +1,59 @@
##################
Callback Functions
##################

This document gives a basic walkthrough of the callback functions used in the XGBoost
Python package.  XGBoost 1.3 introduces a new callback interface for the Python package,
which provides the flexibility to design various extensions for training.  XGBoost also
ships a number of pre-defined callbacks that support early stopping, checkpointing, etc.

#######################
Using builtin callbacks
#######################

By default, training methods in XGBoost have parameters like ``early_stopping_rounds``
and ``verbose``/``verbose_eval``; when these are specified, the training procedure
defines the corresponding callbacks internally.  For example, when
``early_stopping_rounds`` is specified, the ``EarlyStopping`` callback is invoked inside
the iteration loop.  You can also pass this callback directly to XGBoost:

.. code-block:: python

    # `X_train`, `y_train`, `X_valid`, `y_valid` and `early_stopping_rounds`
    # are assumed to be prepared beforehand.
    D_train = xgb.DMatrix(X_train, y_train)
    D_valid = xgb.DMatrix(X_valid, y_valid)

    # Define a custom evaluation metric used for early stopping.
    def eval_error_metric(predt, dtrain: xgb.DMatrix):
        label = dtrain.get_label()
        r = np.zeros(predt.shape)
        gt = predt > 0.5
        r[gt] = 1 - label[gt]
        le = predt <= 0.5
        r[le] = label[le]
        return 'CustomErr', np.sum(r)

    # Specify which dataset and which metric should be used for early stopping.
    early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds,
                                            metric_name='CustomErr',
                                            data_name='Train')

    booster = xgb.train(
        {'objective': 'binary:logistic',
         'eval_metric': ['error', 'rmse'],
         'tree_method': 'hist'}, D_train,
        evals=[(D_train, 'Train'), (D_valid, 'Valid')],
        feval=eval_error_metric,
        num_boost_round=1000,
        callbacks=[early_stop],
        verbose_eval=False)

    dump = booster.get_dump(dump_format='json')
    assert len(early_stop.stopping_history['Train']['CustomErr']) == len(dump)
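
The ``stopping_history`` attribute on the callback records only the monitored series,
keyed by the ``data_name`` and ``metric_name`` passed above, which is why the assertion
indexes it with ``'Train'`` and ``'CustomErr'``.
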
##########################
Defining your own callback
##########################

XGBoost provides a callback interface class, ``xgboost.callback.TrainingCallback``;
user-defined callbacks should inherit from this class and override the corresponding
methods.  There's a working example in `demo/guide-python/callbacks.py
<https://github.com/dmlc/xgboost/tree/master/demo/guide-python/callbacks.py>`_.
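
As a minimal sketch of such a subclass (the method signature and the ``evals_log``
layout follow ``TrainingCallback``; the printing logic is only illustrative):

.. code-block:: python

    import xgboost as xgb

    class PrintEveryN(xgb.callback.TrainingCallback):
        '''Print the latest evaluation result every ``period`` rounds.'''
        def __init__(self, period=10):
            self.period = period

        def after_iteration(self, model, epoch, evals_log):
            if (epoch + 1) % self.period == 0:
                # evals_log maps: data name -> metric name -> list of scores.
                for data, metrics in evals_log.items():
                    for name, log in metrics.items():
                        print(f'[{epoch}] {data}-{name}: {log[-1]}')
            return False  # returning True would stop training early

An instance is passed through the same ``callbacks`` parameter used above, e.g.
``callbacks=[PrintEveryN(period=5)]``.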
1 change: 1 addition & 0 deletions doc/python/index.rst
@@ -11,4 +11,5 @@ Contents
.. toctree::
python_intro
python_api
callbacks
Python examples <https://github.com/dmlc/xgboost/tree/master/demo/guide-python>
