New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【PaddlePaddle Hackathon 2】12、为 Paddle 新增 OneCycleLR 优化调度器 #41825
Changes from 10 commits
702f47c
3537536
f8011a0
9be4aeb
2df7465
8bb405a
8ad14b0
de82b7b
af0420f
6d48016
6c76ab5
36864a0
b7c4a4b
348932b
98f9a9e
70f97a8
ebb04e2
a58cab2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -321,6 +321,74 @@ def step_lr(epoch_num, learning_rate, step_size, gamma=0.1, verbose=False): | |
return learning_rate * math.pow(gamma, epoch_num // step_size) | ||
|
||
|
||
def one_cycle_lr(epoch_num, | ||
max_learning_rate, | ||
total_steps=None, | ||
epochs=None, | ||
steps_per_epoch=None, | ||
pct_start=0.3, | ||
anneal_strategy='cos', | ||
divide_factor=25., | ||
final_divide_factor=1e4, | ||
three_phase=False, | ||
verbose=False): | ||
total_steps = epochs * steps_per_epoch if total_steps is None else total_steps | ||
initial_lr = max_learning_rate / divide_factor | ||
min_lr = initial_lr / final_divide_factor | ||
if three_phase: | ||
_end_steps = [ | ||
float(pct_start * total_steps) - 1, | ||
float(2 * pct_start * total_steps) - 2, total_steps - 1 | ||
] | ||
_schedule_phases = [ | ||
{ | ||
'start_lr': initial_lr, | ||
'end_lr': max_learning_rate, | ||
}, | ||
{ | ||
'start_lr': max_learning_rate, | ||
'end_lr': initial_lr, | ||
}, | ||
{ | ||
'start_lr': initial_lr, | ||
'end_lr': min_lr, | ||
}, | ||
] | ||
else: | ||
_end_steps = [float(pct_start * total_steps) - 1, total_steps - 1] | ||
_schedule_phases = [ | ||
{ | ||
'start_lr': initial_lr, | ||
'end_lr': max_learning_rate, | ||
}, | ||
{ | ||
'start_lr': max_learning_rate, | ||
'end_lr': min_lr, | ||
}, | ||
] | ||
|
||
if anneal_strategy == 'cos': | ||
|
||
def anneal_func(start, end, pct): | ||
cos_out = math.cos(math.pi * pct) + 1 | ||
return end + (start - end) / 2.0 * cos_out | ||
else: | ||
|
||
def anneal_func(start, end, pct): | ||
return (end - start) * pct + start | ||
|
||
start_step = 0 | ||
for i, phase in enumerate(_schedule_phases): | ||
end_step = _end_steps[i] | ||
if epoch_num <= end_step or i == len(_schedule_phases) - 1: | ||
pct = (epoch_num - start_step) / (end_step - start_step) | ||
computed_lr = anneal_func(phase['start_lr'], phase['end_lr'], pct) | ||
break | ||
start_step = end_step | ||
|
||
return computed_lr | ||
|
||
|
||
class TestLRScheduler(unittest.TestCase): | ||
def _test_static(self, python_func, paddle_api, kwarg, place): | ||
scheduler = paddle_api(**kwarg) | ||
|
@@ -467,6 +535,25 @@ def test_scheduler(self): | |
with self.assertRaises(ValueError): | ||
paddle.optimizer.lr.MultiStepDecay( | ||
learning_rate=0.5, milestones=[1, 2, 3], gamma=2) | ||
with self.assertRaises(TypeError): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 加一下注释吧,每一项测试是针对什么异常输入情况 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 好的 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已修改 |
||
paddle.optimizer.lr.OneCycleLR( | ||
max_learning_rate='test', total_steps=20) | ||
with self.assertRaises(ValueError): | ||
paddle.optimizer.lr.OneCycleLR(max_learning_rate=0.1) | ||
with self.assertRaises(TypeError): | ||
paddle.optimizer.lr.OneCycleLR( | ||
max_learning_rate=0.1, total_steps='test') | ||
with self.assertRaises(ValueError): | ||
paddle.optimizer.lr.OneCycleLR( | ||
max_learning_rate=0.1, total_steps=-10) | ||
with self.assertRaises(TypeError): | ||
paddle.optimizer.lr.OneCycleLR(max_learning_rate=0.1, epochs='test') | ||
with self.assertRaises(TypeError): | ||
paddle.optimizer.lr.OneCycleLR( | ||
max_learning_rate=0.1, epochs=1, steps_per_epoch='t') | ||
with self.assertRaises(ValueError): | ||
paddle.optimizer.lr.OneCycleLR( | ||
max_learning_rate=0.1, total_steps=20, anneal_strategy='test') | ||
|
||
func_api_kwargs = [(noam_lr, paddle.optimizer.lr.NoamDecay, { | ||
"d_model": 0.01, | ||
|
@@ -527,6 +614,39 @@ def test_scheduler(self): | |
"learning_rate": 0.5, | ||
"T_max": 10, | ||
"verbose": False | ||
}), (one_cycle_lr, paddle.optimizer.lr.OneCycleLR, { | ||
"max_learning_rate": 0.5, | ||
"total_steps": 20, | ||
"pct_start": 0.3, | ||
"anneal_strategy": 'cos', | ||
"divide_factor": 25., | ||
"final_divide_factor": 1e4, | ||
"three_phase": False, | ||
}), (one_cycle_lr, paddle.optimizer.lr.OneCycleLR, { | ||
"max_learning_rate": 0.5, | ||
"epochs": 10, | ||
"steps_per_epoch": 2, | ||
"pct_start": 0.2, | ||
"anneal_strategy": 'linear', | ||
"divide_factor": 20., | ||
"final_divide_factor": 1000, | ||
"three_phase": False, | ||
}), (one_cycle_lr, paddle.optimizer.lr.OneCycleLR, { | ||
"max_learning_rate": 1, | ||
"total_steps": 20, | ||
"pct_start": 0.4, | ||
"anneal_strategy": 'cos', | ||
"divide_factor": 15., | ||
"final_divide_factor": 100, | ||
"three_phase": True, | ||
}), (one_cycle_lr, paddle.optimizer.lr.OneCycleLR, { | ||
"max_learning_rate": 0.5, | ||
"total_steps": 40, | ||
"pct_start": 0.5, | ||
"anneal_strategy": 'linear', | ||
"divide_factor": 5., | ||
"final_divide_factor": 50, | ||
"three_phase": True, | ||
})] | ||
|
||
for python_func, paddle_api, kwarg in func_api_kwargs: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -33,7 +33,8 @@ | |
'LambdaDecay', | ||
'ReduceOnPlateau', | ||
'CosineAnnealingDecay', | ||
'MultiplicativeDecay' | ||
'MultiplicativeDecay', | ||
'OneCycleLR' | ||
] | ||
|
||
|
||
|
@@ -1591,3 +1592,222 @@ def get_lr(self): | |
for epoch in range(1, self.last_epoch + 1): | ||
cur_lr = cur_lr * self.lr_lambda(epoch) | ||
return cur_lr | ||
|
||
|
||
class OneCycleLR(LRScheduler): | ||
r""" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 如果代码是参考别人的实现,需要遵循开源协议,添加说明引用来源 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 了解 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已修改 |
||
Sets the learning rate according to the one cycle learning rate scheduler. | ||
The scheduler adjusts the learning rate from an initial learning rate to the maximum learning rate and then | ||
from that maximum learning rate to the minimum learning rate, which is much less than the initial learning rate. | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里的描述没有什么可以修改的方案,其本身已经足够简洁明了。 |
||
It has been proposed in `Super-Convergence: Very Fast Training of Neural Networks Using Large Learning Rates <https://arxiv.org/abs/1708.07120>`_. | ||
|
||
Please note that the default behaviour of this scheduler follows the fastai implementation of one cycle, | ||
which claims that “unpublished work has shown even better results by using only two phases”. | ||
Set ``three_phase=True``, if you want the behaviour of this scheduler to be consistent with the paper. | ||
|
||
Also note that you should update learning rate each step. | ||
|
||
This implementation was adapted from `there <https://github.com/pytorch/pytorch/blob/e5ee6f5cf714812283ff4e49362fbdf37fbd8ea9/torch/optim/lr_scheduler.py#L1346>`_. | ||
|
||
Args: | ||
max_learning_rate (float): Upper boundary of learning rate during training. | ||
Functionally, it defines the initial learning rate and the minimum learning rate by ``divide_factor`` and | ||
``final_divide_factor`` respectively. | ||
total_steps (int, optional): Number of total training steps. | ||
Note that one of total_steps and (epochs, steps_per_epoch) must be specified. | ||
If ``total_steps`` is not specified, it will be determined by ``epochs`` and ``steps_per_epoch``. Default: None. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Default: None, means xxx. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
这部分应该表明了默认的情况,后面就不再说明了。 |
||
epochs (int, optional): Number of total training epochs. Default: None. | ||
steps_per_epoch (int, optional): Number of training steps for each epoch. Default: None. | ||
pct_start (float): The percentage of total steps, which used to increasing learning rate. Default: 0.3. | ||
anneal_strategy (str, optional): Strategy of adjusting learning rate.'cos' for cosine annealing, | ||
'linear' for linear annealing. Default: 'cos'. | ||
divide_factor (float, optional): Initial learning rate will be determined by initial_lr = max_lr/div_factor. Default: 25. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. max_lr -> max_learning_rate There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已修改 |
||
final_divide_factor (float, optional): Minimum learning rate will be determined by minimum = max_lr/final_divide_factor. Default: 1e4. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 中英文公式一致 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已修改 |
||
three_phase (bool, optional): Whether to use three phase. | ||
If ``True``: | ||
1. The learning rate will first increase from initial learning rate to maximum learning rate. | ||
2. Then it will decrease to initial learning rate. Number of step in this phase is the same as the one in first phase. | ||
3. Finally, it will decrease to minimum learning rate which is much less than initial learning rate. | ||
If ``False``: | ||
1. The learning rate will increase to maximum learning rate. | ||
2. Then it will directly decrease to minimum learning rate. | ||
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate. | ||
verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` . | ||
|
||
Returns: | ||
``OneCycleLR`` instance to schedule learning rate. | ||
|
||
Examples: | ||
.. code-block:: python | ||
import paddle | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. import paddle 必须要和上面空一行,否则会有格式问题; There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已修改 |
||
import numpy as np | ||
# train on default dynamic graph mode | ||
linear = paddle.nn.Linear(10, 10) | ||
scheduler = paddle.optimizer.lr.OneCycleLR(max_learning_rate=1.0, total_steps=100, verbose=True) | ||
sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters()) | ||
for epoch in range(5): | ||
for batch_id in range(20): | ||
x = paddle.uniform([10, 10]) | ||
out = linear(x) | ||
loss = paddle.mean(out) | ||
loss.backward() | ||
sgd.step() | ||
sgd.clear_gradients() | ||
scheduler.step() # You should update learning rate each step | ||
# train on static graph mode | ||
paddle.enable_static() | ||
main_prog = paddle.static.Program() | ||
start_prog = paddle.static.Program() | ||
with paddle.static.program_guard(main_prog, start_prog): | ||
x = paddle.static.data(name='x', shape=[None, 4, 5]) | ||
y = paddle.static.data(name='y', shape=[None, 4, 5]) | ||
z = paddle.static.nn.fc(x, 100) | ||
loss = paddle.mean(z) | ||
scheduler = paddle.optimizer.lr.OneCycleLR(max_learning_rate=1.0, total_steps=100, verbose=True) | ||
sgd = paddle.optimizer.SGD(learning_rate=scheduler) | ||
sgd.minimize(loss) | ||
exe = paddle.static.Executor() | ||
exe.run(start_prog) | ||
for epoch in range(5): | ||
for batch_id in range(20): | ||
out = exe.run( | ||
main_prog, | ||
feed={ | ||
'x': np.random.randn(3, 4, 5).astype('float32'), | ||
'y': np.random.randn(3, 4, 5).astype('float32') | ||
}, | ||
fetch_list=loss.name) | ||
scheduler.step() # You should update learning rate each step | ||
""" | ||
|
||
def __init__(self, | ||
max_learning_rate, | ||
total_steps=None, | ||
epochs=None, | ||
steps_per_epoch=None, | ||
pct_start=0.3, | ||
anneal_strategy='cos', | ||
divide_factor=25., | ||
final_divide_factor=1e4, | ||
three_phase=False, | ||
last_epoch=-1, | ||
verbose=False): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. there are total 11 parameters of OneCycleLR API in RFC, but only 9 parameters here, which is right? RFC and code must be consistency. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. here I commit a new pull request to modify RFC file. |
||
# Check type of max_learning_rate | ||
if not isinstance(max_learning_rate, (float, int)): | ||
raise TypeError( | ||
"The type of learning rate must be float, but received {}". | ||
format(type(max_learning_rate))) | ||
# Check type and value of total_steps | ||
if total_steps is None and epochs is None and steps_per_epoch is None: | ||
raise ValueError( | ||
"either total_steps or (epochs, steps_per_epoch) must be specified" | ||
) | ||
elif total_steps is not None: | ||
if not isinstance(total_steps, int): | ||
raise TypeError("'total_step' must be 'int', but received {}". | ||
format(type(total_steps))) | ||
if total_steps <= 0: | ||
raise ValueError("'total_step' must be a positive integer.") | ||
self.total_steps = total_steps | ||
else: | ||
# Check type and value of epochs and steps_per_epochs | ||
if not isinstance(epochs, int): | ||
raise TypeError("'epochs' must be 'int', but received {}". | ||
format(type(epochs))) | ||
if not isinstance(steps_per_epoch, int): | ||
raise TypeError( | ||
"'steps_per_epoch', must be 'int', but received {}".format( | ||
type(steps_per_epoch))) | ||
if epochs < 0: | ||
raise ValueError("'epochs' must be a positive integer.") | ||
if steps_per_epoch < 0: | ||
raise ValueError( | ||
"'steps_per_epoch' must be a positive integer.") | ||
self.total_steps = epochs * steps_per_epoch | ||
# Check type and value of pac_start | ||
if not isinstance(pct_start, float): | ||
raise TypeError("'pct_start' must be 'float', but received {}". | ||
format(type(pct_start))) | ||
if pct_start < 0 or pct_start > 1: | ||
raise ValueError( | ||
"'pct_start' must be between 0 and 1, but received {}".format( | ||
pct_start)) | ||
|
||
max_lr = max_learning_rate | ||
initial_lr = max_lr / divide_factor | ||
min_lr = initial_lr / final_divide_factor | ||
|
||
if three_phase: | ||
self._end_steps = [ | ||
float(pct_start * self.total_steps) - 1, | ||
float(2 * pct_start * self.total_steps) - 2, | ||
self.total_steps - 1 | ||
] | ||
self._schedule_phases = [ | ||
{ | ||
'start_lr': initial_lr, | ||
'end_lr': max_lr, | ||
}, | ||
{ | ||
'start_lr': max_lr, | ||
'end_lr': initial_lr, | ||
}, | ||
{ | ||
'start_lr': initial_lr, | ||
'end_lr': min_lr, | ||
}, | ||
] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里使用self._start_steps的元素进行相减是因为写计算表达式会因为浮点运算不精确而造成误差 |
||
else: | ||
self._end_steps = [ | ||
float(pct_start * self.total_steps) - 1, self.total_steps - 1 | ||
] | ||
self._schedule_phases = [ | ||
{ | ||
'start_lr': initial_lr, | ||
'end_lr': max_lr, | ||
}, | ||
{ | ||
'start_lr': max_lr, | ||
'end_lr': min_lr, | ||
}, | ||
] | ||
|
||
# Validate anneal_strategy | ||
if anneal_strategy not in ['cos', 'linear']: | ||
raise ValueError( | ||
"'anneal_strategy' must by one of 'cos' or 'linear', but received {}". | ||
format(anneal_strategy)) | ||
elif anneal_strategy == 'cos': | ||
self.anneal_func = self._annealing_cos | ||
elif anneal_strategy == 'linear': | ||
self.anneal_func = self._annealing_linear | ||
|
||
super(OneCycleLR, self).__init__(initial_lr, last_epoch, verbose) | ||
|
||
def _annealing_cos(self, start, end, pct): | ||
cos_out = math.cos(math.pi * pct) + 1 | ||
return end + (start - end) / 2.0 * cos_out | ||
|
||
def _annealing_linear(self, start, end, pct): | ||
return (end - start) * pct + start | ||
|
||
def get_lr(self): | ||
step_num = self.last_epoch | ||
|
||
if step_num > self.total_steps: | ||
raise ValueError( | ||
"Tried to step {} times. The specified number of total steps is {}" | ||
.format(step_num + 1, self.total_steps)) | ||
|
||
start_step = 0 | ||
for i, phase in enumerate(self._schedule_phases): | ||
end_step = self._end_steps[i] | ||
if step_num <= end_step or i == len(self._schedule_phases) - 1: | ||
pct = (step_num - start_step) / (end_step - start_step) | ||
computed_lr = self.anneal_func(phase['start_lr'], | ||
phase['end_lr'], pct) | ||
break | ||
start_step = end_step | ||
|
||
return computed_lr |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
单测代码部分没有作过多修改,仍是之前的逻辑