-
-
Notifications
You must be signed in to change notification settings - Fork 385
/
early_stop.py
138 lines (112 loc) · 4.08 KB
/
early_stop.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
from typing import TYPE_CHECKING
from catalyst.core.callback import Callback, CallbackNode, CallbackOrder
if TYPE_CHECKING:
from catalyst.core.runner import IRunner
class CheckRunCallback(Callback):
    """Cut a run short after a limited number of batches and epochs.

    The callback only ever *raises* ``runner.need_early_stop``; it never
    clears the flag once it has been set.
    """

    def __init__(self, num_batch_steps: int = 3, num_epoch_steps: int = 2):
        """
        Args:
            num_batch_steps: batch-step threshold compared against
                ``runner.loader_batch_step`` at the end of every batch
            num_epoch_steps: epoch threshold compared against
                ``runner.epoch`` at the end of every epoch
        """
        super().__init__(order=CallbackOrder.external, node=CallbackNode.all)
        self.num_epoch_steps = num_epoch_steps
        self.num_batch_steps = num_batch_steps

    def on_epoch_end(self, runner: "IRunner"):
        """Request an early stop once the epoch threshold is reached.

        Args:
            runner: current runner
        """
        epoch_limit_reached = runner.epoch >= self.num_epoch_steps
        if epoch_limit_reached:
            runner.need_early_stop = True

    def on_batch_end(self, runner: "IRunner"):
        """Request an early stop once the batch threshold is reached.

        Args:
            runner: current runner
        """
        batch_limit_reached = runner.loader_batch_step >= self.num_batch_steps
        if batch_limit_reached:
            runner.need_early_stop = True
class EarlyStoppingCallback(Callback):
    """Stop training when a validation metric stops improving.

    Notebook API usage:

    .. code-block:: python

        runner = SupervisedRunner()
        runner.train(
            ...
            callbacks=[
                ...
                EarlyStoppingCallback(
                    patience=5,
                    metric="my_metric",
                    minimize=True,
                )
                ...
            ]
        )
        ...

    Config API usage:

    .. code-block:: yaml

        stages:
          ...
          stage_N:
            ...
            callbacks_params:
              ...
              early_stopping:
                callback: EarlyStoppingCallback
                # arguments for EarlyStoppingCallback
                patience: 5
                metric: my_metric
                minimize: true
          ...

    """

    def __init__(
        self,
        patience: int,
        metric: str = "loss",
        minimize: bool = True,
        min_delta: float = 1e-6,
    ):
        """
        Args:
            patience: how many epochs without improvement are tolerated
                before the runner is asked to stop.
            metric: key looked up in ``runner.valid_metrics`` at the end
                of each epoch, ``"loss"`` by default.
            minimize: when truthy, a score counts as better if it is at
                most ``best - min_delta``; otherwise it counts as better
                if it is at least ``best + min_delta``.
                Default value ``True``.
            min_delta: margin by which the score has to beat the best
                one seen so far to be treated as an improvement,
                default value is ``1e-6``.
        """
        super().__init__(order=CallbackOrder.external, node=CallbackNode.all)
        self.metric = metric
        self.patience = patience
        # Running state: best score seen so far and the current streak of
        # epochs that failed to improve on it.
        self.best_score = None
        self.num_bad_epochs = 0
        # Comparator is picked once, according to the optimisation direction.
        self.is_better = (
            (lambda score, best: score <= (best - min_delta))
            if minimize
            else (lambda score, best: score >= (best + min_delta))
        )

    def on_epoch_end(self, runner: "IRunner") -> None:
        """Update the no-improvement streak and stop the run if needed.

        Nothing happens when the runner's stage name starts with
        ``"infer"``.

        Args:
            runner: current runner
        """
        if runner.stage.startswith("infer"):
            return

        current = runner.valid_metrics[self.metric]
        previous_best = self.best_score
        # The very first observed score always becomes the reference.
        improved = previous_best is None or self.is_better(
            current, previous_best
        )

        if improved:
            self.best_score = current
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1

        # Checked after both branches, so the threshold is evaluated on
        # every non-inference epoch.
        if self.num_bad_epochs >= self.patience:
            print(f"Early stop at {runner.epoch} epoch")
            runner.need_early_stop = True
# Names exported by ``from <module> import *``.
__all__ = [
    "CheckRunCallback",
    "EarlyStoppingCallback",
]