Typehint for subset of core API. (#7348)
trivialfis committed Oct 28, 2021
1 parent 45aef75 commit c676948
Showing 7 changed files with 400 additions and 412 deletions.
197 changes: 116 additions & 81 deletions python-package/xgboost/callback.py
@@ -6,17 +6,18 @@
import collections
import os
import pickle
from typing import Callable, List, Optional, Union, Dict, Tuple
from typing import Callable, List, Optional, Union, Dict, Tuple, TypeVar, cast
from typing import Sequence
import numpy

from . import rabit
from .core import Booster, XGBoostError
from .core import Booster, DMatrix, XGBoostError
from .compat import STRING_TYPES


# The new implementation of callback functions.
# Breaking:
# - reset learning rate no longer accepts total boosting rounds
_Score = Union[float, Tuple[float, float]]
_ScoreList = Union[List[float], List[Tuple[float, float]]]


# pylint: disable=unused-argument
class TrainingCallback(ABC):
@@ -26,9 +27,9 @@ class TrainingCallback(ABC):
'''

EvalsLog = Dict[str, Dict[str, Union[List[float], List[Tuple[float, float]]]]]
EvalsLog = Dict[str, Dict[str, _ScoreList]]

def __init__(self):
def __init__(self) -> None:
pass

def before_training(self, model):
@@ -48,35 +49,39 @@ def after_iteration(self, model, epoch: int, evals_log: EvalsLog) -> bool:
return False
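
As a rough illustration of how the typed `EvalsLog` alias is consumed, here is a minimal custom-callback sketch; the callback name and the printed format are invented for the example.

```python
from xgboost.callback import TrainingCallback


class PrintLastScore(TrainingCallback):
    """Toy callback: print the newest score for every dataset/metric pair."""

    def after_iteration(
        self, model, epoch: int, evals_log: TrainingCallback.EvalsLog
    ) -> bool:
        for data_name, metrics in evals_log.items():
            for metric_name, scores in metrics.items():
                # scores is List[float] under train(), List[Tuple[float, float]] under cv()
                print(f"[{epoch}] {data_name}-{metric_name}: {scores[-1]}")
        return False  # returning True requests early termination of training
```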


def _aggcv(rlist):
# pylint: disable=invalid-name
def _aggcv(rlist: List[str]) -> List[Tuple[str, float, float]]:
# pylint: disable=invalid-name, too-many-locals
"""Aggregate cross-validation results.
"""
cvmap = {}
cvmap: Dict[Tuple[int, str], List[float]] = {}
idx = rlist[0].split()[0]
for line in rlist:
arr = line.split()
arr: List[str] = line.split()
assert idx == arr[0]
for metric_idx, it in enumerate(arr[1:]):
if not isinstance(it, STRING_TYPES):
if not isinstance(it, str):
it = it.decode()
k, v = it.split(':')
if (metric_idx, k) not in cvmap:
cvmap[(metric_idx, k)] = []
cvmap[(metric_idx, k)].append(float(v))
msg = idx
results = []
for (metric_idx, k), v in sorted(cvmap.items(), key=lambda x: x[0][0]):
v = numpy.array(v)
for (_, name), s in sorted(cvmap.items(), key=lambda x: x[0][0]):
as_arr = numpy.array(s)
if not isinstance(msg, STRING_TYPES):
msg = msg.decode()
mean, std = numpy.mean(v), numpy.std(v)
results.extend([(k, mean, std)])
mean, std = numpy.mean(as_arr), numpy.std(as_arr)
results.extend([(name, mean, std)])
return results
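
A rough sketch of the data `_aggcv` operates on: each element of `rlist` is one fold's evaluation string for the current round, and the output pairs each metric with its cross-fold mean and standard deviation. The numbers below are invented.

```python
# Hypothetical per-fold evaluation strings for one boosting round.
rlist = [
    "[0]\ttrain-rmse:0.50\ttest-rmse:0.60",
    "[0]\ttrain-rmse:0.52\ttest-rmse:0.58",
]
# _aggcv(rlist) would return something like:
#   [("train-rmse", 0.51, 0.01), ("test-rmse", 0.59, 0.01)]
```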


def _allreduce_metric(score):
# allreduce type
_ART = TypeVar("_ART")


def _allreduce_metric(score: _ART) -> _ART:
'''Helper function for computing customized metric in distributed
environment. Not strictly correct as many functions don't use mean value
as final result.
@@ -89,13 +94,13 @@ def _allreduce_metric(score):
if isinstance(score, tuple): # has mean and stdv
raise ValueError(
'xgboost.cv function should not be used in distributed environment.')
score = numpy.array([score])
score = rabit.allreduce(score, rabit.Op.SUM) / world
return score[0]
arr = numpy.array([score])
arr = rabit.allreduce(arr, rabit.Op.SUM) / world
return arr[0]


class CallbackContainer:
'''A special callback for invoking a list of other callbacks.
'''A special internal callback for invoking a list of other callbacks.
.. versionadded:: 1.3.0
@@ -105,7 +110,7 @@ class CallbackContainer:

def __init__(
self,
callbacks: List[TrainingCallback],
callbacks: Sequence[TrainingCallback],
metric: Callable = None,
output_margin: bool = True,
is_cv: bool = False
@@ -146,33 +151,50 @@ def after_training(self, model):
assert isinstance(model, Booster), msg
return model

def before_iteration(self, model, epoch, dtrain, evals) -> bool:
def before_iteration(
self, model, epoch: int, dtrain: DMatrix, evals: List[Tuple[DMatrix, str]]
) -> bool:
'''Function called before training iteration.'''
return any(c.before_iteration(model, epoch, self.history)
for c in self.callbacks)

def _update_history(self, score, epoch):
def _update_history(
self,
score: Union[List[Tuple[str, float]], List[Tuple[str, float, float]]],
epoch: int
) -> None:
for d in score:
name, s = d[0], float(d[1])
name: str = d[0]
s: float = d[1]
if self.is_cv:
std = float(d[2])
s = (s, std)
std = float(cast(Tuple[str, float, float], d)[2])
x: _Score = (s, std)
else:
x = s
splited_names = name.split('-')
data_name = splited_names[0]
metric_name = '-'.join(splited_names[1:])
s = _allreduce_metric(s)
if data_name in self.history:
data_history = self.history[data_name]
if metric_name in data_history:
data_history[metric_name].append(s)
else:
data_history[metric_name] = [s]
else:
x = _allreduce_metric(x)
if data_name not in self.history:
self.history[data_name] = collections.OrderedDict()
self.history[data_name][metric_name] = [s]
return False
data_history = self.history[data_name]
if metric_name not in data_history:
data_history[metric_name] = cast(_ScoreList, [])
metric_history = data_history[metric_name]
if self.is_cv:
cast(List[Tuple[float, float]], metric_history).append(
cast(Tuple[float, float], x)
)
else:
cast(List[float], metric_history).append(cast(float, x))

def after_iteration(self, model, epoch, dtrain, evals) -> bool:
def after_iteration(
self,
model,
epoch: int,
dtrain: DMatrix,
evals: Optional[List[Tuple[DMatrix, str]]],
) -> bool:
'''Function called after training iteration.'''
if self.is_cv:
scores = model.eval(epoch, self.metric, self._output_margin)
@@ -183,18 +205,20 @@ def after_iteration(self, model, epoch, dtrain, evals) -> bool:
evals = [] if evals is None else evals
for _, name in evals:
assert name.find('-') == -1, 'Dataset name should not contain `-`'
score = model.eval_set(evals, epoch, self.metric, self._output_margin)
score = score.split()[1:] # into datasets
score: str = model.eval_set(evals, epoch, self.metric, self._output_margin)
splited = score.split()[1:] # into datasets
# split up `test-error:0.1234`
score = [tuple(s.split(':')) for s in score]
self._update_history(score, epoch)
metric_score_str = [tuple(s.split(':')) for s in splited]
# convert to float
metric_score = [(n, float(s)) for n, s in metric_score_str]
self._update_history(metric_score, epoch)
ret = any(c.after_iteration(model, epoch, self.history)
for c in self.callbacks)
return ret
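
After a few iterations the container's `history` is keyed by dataset name and then metric name; a sketch of the two shapes it can take, with invented values:

```python
# Shape built up when driving xgboost.train (is_cv=False): plain floats.
history_train = {"train": {"rmse": [0.61, 0.55, 0.50]}}

# Shape built up when driving xgboost.cv (is_cv=True): (mean, std) tuples.
history_cv = {"train": {"rmse": [(0.61, 0.02), (0.55, 0.03)]}}
```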


class LearningRateScheduler(TrainingCallback):
'''Callback function for scheduling learning rate.
"""Callback function for scheduling learning rate.
.. versionadded:: 1.3.0
@@ -207,18 +231,24 @@ class LearningRateScheduler(TrainingCallback):
should be a sequence like list or tuple with the same size of boosting
rounds.
'''
def __init__(self, learning_rates) -> None:
assert callable(learning_rates) or \
isinstance(learning_rates, collections.abc.Sequence)
"""

def __init__(
self, learning_rates: Union[Callable[[int], float], Sequence[float]]
) -> None:
assert callable(learning_rates) or isinstance(
learning_rates, collections.abc.Sequence
)
if callable(learning_rates):
self.learning_rates = learning_rates
else:
self.learning_rates = lambda epoch: learning_rates[epoch]
self.learning_rates = lambda epoch: cast(Sequence, learning_rates)[epoch]
super().__init__()

def after_iteration(self, model, epoch, evals_log) -> bool:
model.set_param('learning_rate', self.learning_rates(epoch))
def after_iteration(
self, model, epoch: int, evals_log: TrainingCallback.EvalsLog
) -> bool:
model.set_param("learning_rate", self.learning_rates(epoch))
return False
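
A small usage sketch for the scheduler; `params` and `dtrain` are placeholders for ordinary training inputs.

```python
import xgboost as xgb
from xgboost.callback import LearningRateScheduler

# Either a per-round sequence of learning rates...
scheduler = LearningRateScheduler([0.3, 0.2, 0.1])
# ...or a callable mapping the epoch to a learning rate.
scheduler = LearningRateScheduler(lambda epoch: 0.3 * (0.9 ** epoch))

# booster = xgb.train(params, dtrain, num_boost_round=3, callbacks=[scheduler])
```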


@@ -230,17 +260,17 @@ class EarlyStopping(TrainingCallback):
Parameters
----------
rounds
rounds :
Early stopping rounds.
metric_name
metric_name :
Name of metric that is used for early stopping.
data_name
data_name :
Name of dataset that is used for early stopping.
maximize
maximize :
Whether to maximize evaluation metric. None means auto (discouraged).
save_best
save_best :
Whether training should return the best model or the last model.
min_delta
min_delta :
Minimum absolute change in score to be qualified as an improvement.
.. versionadded:: 1.5.0
@@ -279,8 +309,6 @@ def __init__(
if self._min_delta < 0:
raise ValueError("min_delta must be greater or equal to 0.")

self.improve_op = None

self.current_rounds: int = 0
self.best_scores: dict = {}
self.starting_round: int = 0
@@ -290,16 +318,18 @@ def before_training(self, model):
self.starting_round = model.num_boosted_rounds()
return model

def _update_rounds(self, score, name, metric, model, epoch) -> bool:
def get_s(x):
def _update_rounds(
self, score: _Score, name: str, metric: str, model, epoch: int
) -> bool:
def get_s(x: _Score) -> float:
"""get score if it's cross validation history."""
return x[0] if isinstance(x, tuple) else x

def maximize(new, best):
def maximize(new: _Score, best: _Score) -> bool:
"""New score should be greater than the old one."""
return numpy.greater(get_s(new) - self._min_delta, get_s(best))

def minimize(new, best):
def minimize(new: _Score, best: _Score) -> bool:
"""New score should be smaller than the old one."""
return numpy.greater(get_s(best) - self._min_delta, get_s(new))

@@ -314,25 +344,25 @@ def minimize(new, best):
self.maximize = False

if self.maximize:
self.improve_op = maximize
improve_op = maximize
else:
self.improve_op = minimize
improve_op = minimize

assert self.improve_op
assert improve_op

if not self.stopping_history: # First round
self.current_rounds = 0
self.stopping_history[name] = {}
self.stopping_history[name][metric] = [score]
self.stopping_history[name][metric] = cast(_ScoreList, [score])
self.best_scores[name] = {}
self.best_scores[name][metric] = [score]
model.set_attr(best_score=str(score), best_iteration=str(epoch))
elif not self.improve_op(score, self.best_scores[name][metric][-1]):
elif not improve_op(score, self.best_scores[name][metric][-1]):
# Not improved
self.stopping_history[name][metric].append(score)
self.stopping_history[name][metric].append(score) # type: ignore
self.current_rounds += 1
else: # Improved
self.stopping_history[name][metric].append(score)
self.stopping_history[name][metric].append(score) # type: ignore
self.best_scores[name][metric].append(score)
record = self.stopping_history[name][metric][-1]
model.set_attr(best_score=str(record), best_iteration=str(epoch))
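
For reference, a minimal sketch of wiring the callback into training; `params`, `dtrain`, and `dvalid` are placeholders, and the dataset name must match the one given in `evals`.

```python
import xgboost as xgb
from xgboost.callback import EarlyStopping

early_stop = EarlyStopping(
    rounds=10,               # stop after 10 rounds without improvement
    metric_name="rmse",      # must match a metric reported for the dataset
    data_name="validation",  # must match the name used in `evals`
    save_best=True,          # return the best model rather than the last one
)
# booster = xgb.train(
#     params, dtrain, num_boost_round=1000,
#     evals=[(dvalid, "validation")], callbacks=[early_stop],
# )
```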
@@ -390,16 +420,16 @@ class EvaluationMonitor(TrainingCallback):
Parameters
----------
metric : callable
metric :
Extra user defined metric.
rank : int
rank :
Which worker should be used for printing the result.
period : int
period :
How many epochs between printing.
show_stdv : bool
show_stdv :
Used in cv to show standard deviation. Users should not specify it.
'''
def __init__(self, rank=0, period=1, show_stdv=False) -> None:
def __init__(self, rank: int = 0, period: int = 1, show_stdv: bool = False) -> None:
self.printer_rank = rank
self.show_stdv = show_stdv
self.period = period
@@ -457,22 +487,27 @@ class TrainingCheckPoint(TrainingCallback):
Parameters
----------
directory : os.PathLike
directory :
Output model directory.
name : str
name :
pattern of output model file. Models will be saved as name_0.json, name_1.json,
name_2.json ....
as_pickle : boolean
as_pickle :
When set to True, all training parameters will be saved in pickle format, instead
of saving only the model.
iterations : int
iterations :
Interval of checkpointing. Checkpointing is slow so setting a larger number can
reduce performance hit.
'''
def __init__(self, directory: os.PathLike, name: str = 'model',
as_pickle=False, iterations: int = 100):
self._path = directory
def __init__(
self,
directory: Union[str, os.PathLike],
name: str = 'model',
as_pickle: bool = False,
iterations: int = 100
) -> None:
self._path = os.fspath(directory)
self._name = name
self._as_pickle = as_pickle
self._iterations = iterations
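
A quick usage sketch for the checkpoint callback; the directory, interval, and training arguments are example values.

```python
import xgboost as xgb
from xgboost.callback import TrainingCheckPoint

# Snapshot the model every 50 boosting rounds as model_0.json, model_1.json, ...
checkpoint = TrainingCheckPoint(directory="./checkpoints", name="model", iterations=50)
# booster = xgb.train(params, dtrain, num_boost_round=500, callbacks=[checkpoint])
```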
