Commit
Rework Python callback functions. (#6199)
* Define a new callback interface for Python.
* Deprecate the old callbacks.
* Enable early stopping on dask.
1 parent b5b2435, commit ab5b351
Showing 13 changed files with 1,180 additions and 275 deletions.
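For the dask piece, here is a minimal sketch of how the new EarlyStopping callback plugs into the dask interface; the cluster setup and synthetic data below are illustrative, not taken from the diff:

# Early stopping through xgboost.dask, enabled by this commit.
import xgboost as xgb
from dask import array as da
from dask.distributed import Client, LocalCluster

with LocalCluster(n_workers=2) as cluster, Client(cluster) as client:
    # Synthetic binary classification data, partitioned across workers.
    X = da.random.random((10000, 20), chunks=(1000, 20))
    y = da.random.randint(0, 2, size=10000, chunks=1000)
    dtrain = xgb.dask.DaskDMatrix(client, X, y)

    # Stop when the metric fails to improve for 5 consecutive rounds.
    early_stop = xgb.callback.EarlyStopping(rounds=5)
    output = xgb.dask.train(
        client,
        {'objective': 'binary:logistic', 'eval_metric': 'error'},
        dtrain,
        num_boost_round=100,
        evals=[(dtrain, 'Train')],
        callbacks=[early_stop])
    booster = output['booster']  # output['history'] holds the evaluation log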
demo/guide-python/callbacks.py (new file, +130 lines):
'''
Demo for using and defining callback functions.

.. versionadded:: 1.3.0
'''
import xgboost as xgb
import tempfile
import os
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from matplotlib import pyplot as plt
import argparse


class Plotting(xgb.callback.TrainingCallback):
    '''Plot evaluation result during training.  Only for demonstration purposes as
    it's quite slow to draw.
    '''
    def __init__(self, rounds):
        self.fig = plt.figure()
        self.ax = self.fig.add_subplot(111)
        self.rounds = rounds
        self.lines = {}
        self.fig.show()
        self.x = np.linspace(0, self.rounds, self.rounds)
        plt.ion()

    def _get_key(self, data, metric):
        return f'{data}-{metric}'

    def after_iteration(self, model, epoch, evals_log):
        '''Update the plot.'''
        if not self.lines:
            for data, metric in evals_log.items():
                for metric_name, log in metric.items():
                    key = self._get_key(data, metric_name)
                    # Pad the log so the line spans all planned rounds.
                    expanded = log + [0] * (self.rounds - len(log))
                    self.lines[key], = self.ax.plot(self.x, expanded, label=key)
                    self.ax.legend()
        else:
            # https://pythonspot.com/matplotlib-update-plot/
            for data, metric in evals_log.items():
                for metric_name, log in metric.items():
                    key = self._get_key(data, metric_name)
                    expanded = log + [0] * (self.rounds - len(log))
                    self.lines[key].set_ydata(expanded)
            self.fig.canvas.draw()
        # False to indicate training should not stop.
        return False


def custom_callback():
    '''Demo for defining a custom callback function that plots evaluation result
    during training.'''
    X, y = load_breast_cancer(return_X_y=True)
    X_train, X_valid, y_train, y_valid = train_test_split(X, y, random_state=0)

    D_train = xgb.DMatrix(X_train, y_train)
    D_valid = xgb.DMatrix(X_valid, y_valid)

    num_boost_round = 100
    plotting = Plotting(num_boost_round)

    # Pass it to the `callbacks` parameter as a list.
    xgb.train(
        {
            'objective': 'binary:logistic',
            'eval_metric': ['error', 'rmse'],
            # Requires a CUDA-enabled build; use 'hist' to run on CPU.
            'tree_method': 'gpu_hist'
        },
        D_train,
        evals=[(D_train, 'Train'), (D_valid, 'Valid')],
        num_boost_round=num_boost_round,
        callbacks=[plotting])


def check_point_callback():
    # Only for the demo; set a larger value (like 100) in practice, as
    # checkpointing is quite slow.
    rounds = 2

    def check(as_pickle):
        for i in range(0, 10, rounds):
            if i == 0:
                continue
            if as_pickle:
                path = os.path.join(tmpdir, 'model_' + str(i) + '.pkl')
            else:
                path = os.path.join(tmpdir, 'model_' + str(i) + '.json')
            assert os.path.exists(path)

    X, y = load_breast_cancer(return_X_y=True)
    m = xgb.DMatrix(X, y)
    # Checkpoint into a temporary directory for the demo.
    with tempfile.TemporaryDirectory() as tmpdir:
        # Use the callback class from xgboost.callback.
        # Feel free to subclass/customize it to suit your needs.
        check_point = xgb.callback.TrainingCheckPoint(directory=tmpdir,
                                                      iterations=rounds,
                                                      name='model')
        xgb.train({'objective': 'binary:logistic'}, m,
                  num_boost_round=10,
                  verbose_eval=False,
                  callbacks=[check_point])
        check(False)

        # This version of checkpoint saves everything, including parameters and
        # model.  See: doc/tutorials/saving_model.rst
        check_point = xgb.callback.TrainingCheckPoint(directory=tmpdir,
                                                      iterations=rounds,
                                                      as_pickle=True,
                                                      name='model')
        xgb.train({'objective': 'binary:logistic'}, m,
                  num_boost_round=10,
                  verbose_eval=False,
                  callbacks=[check_point])
        check(True)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--plot', default=1, type=int)
    args = parser.parse_args()

    check_point_callback()

    if args.plot:
        custom_callback()
New documentation page (new file, +59 lines):
##################
Callback Functions
##################

This document gives a basic walkthrough of the callback functions used in the
XGBoost Python package.  XGBoost 1.3 introduces a new callback interface for the
Python package, which provides the flexibility to design various extensions for
training.  XGBoost also ships a number of pre-defined callbacks that support early
stopping, checkpointing, etc.
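The pre-defined callbacks live in ``xgboost.callback`` and are constructed like
ordinary Python objects before being handed to training; a brief sketch
(``EvaluationMonitor`` is assumed here in addition to the two callbacks covered by
this document):

.. code-block:: python

    import xgboost as xgb

    early_stop = xgb.callback.EarlyStopping(rounds=10)
    monitor = xgb.callback.EvaluationMonitor(period=1)
    check_point = xgb.callback.TrainingCheckPoint(directory='/tmp/checkpoints',
                                                  iterations=10,
                                                  name='model')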
#######################
Using builtin callbacks
#######################

By default, training methods in XGBoost have parameters like
``early_stopping_rounds`` and ``verbose``/``verbose_eval``; when these are
specified, the training procedure defines the corresponding callbacks internally.
For example, when ``early_stopping_rounds`` is specified, the ``EarlyStopping``
callback is invoked inside the iteration loop.  You can also pass this callback
directly to XGBoost:
.. code-block:: python

    import numpy as np
    import xgboost as xgb

    # X_train, y_train, X_valid and y_valid are assumed to be prepared by the caller.
    D_train = xgb.DMatrix(X_train, y_train)
    D_valid = xgb.DMatrix(X_valid, y_valid)

    # Define a custom evaluation metric used for early stopping.
    def eval_error_metric(predt, dtrain: xgb.DMatrix):
        label = dtrain.get_label()
        r = np.zeros(predt.shape)
        gt = predt > 0.5
        r[gt] = 1 - label[gt]
        le = predt <= 0.5
        r[le] = label[le]
        return 'CustomErr', np.sum(r)

    # Specify which dataset and which metric should be used for early stopping.
    early_stopping_rounds = 10
    early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds,
                                            metric_name='CustomErr',
                                            data_name='Valid')

    booster = xgb.train(
        {'objective': 'binary:logistic',
         'eval_metric': ['error', 'rmse'],
         'tree_method': 'hist'}, D_train,
        evals=[(D_train, 'Train'), (D_valid, 'Valid')],
        feval=eval_error_metric,
        num_boost_round=1000,
        callbacks=[early_stop],
        verbose_eval=False)

    # One tree is built per completed round, so the recorded history for the
    # monitored dataset/metric matches the number of trees in the model.
    dump = booster.get_dump(dump_format='json')
    assert len(early_stop.stopping_history['Valid']['CustomErr']) == len(dump)

##########################
Defining your own callback
##########################

XGBoost provides a callback interface class, ``xgboost.callback.TrainingCallback``;
user-defined callbacks should inherit from this class and override the
corresponding methods.  There's a working example in `demo/guide-python/callbacks.py
<https://github.com/dmlc/xgboost/tree/master/demo/guide-python/callbacks.py>`_.
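For instance, a minimal sketch of a callback that prints the newest evaluation
result each round (the class name is illustrative; the hook signature follows the
interface above):

.. code-block:: python

    import xgboost as xgb

    class EvalLogger(xgb.callback.TrainingCallback):
        '''Print the newest evaluation result after every iteration.'''
        def after_iteration(self, model, epoch, evals_log):
            for data, metrics in evals_log.items():
                for metric_name, log in metrics.items():
                    print(f'[{epoch}] {data}-{metric_name}: {log[-1]}')
            # Return True to request that training stop early.
            return False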