
[PaddlePaddle Hackathon 2] Task 12: Add the OneCycleLR learning rate scheduler to Paddle #41825

Merged: 18 commits, merged on May 16, 2022
Changes from 15 commits
118 changes: 118 additions & 0 deletions python/paddle/fluid/tests/unittests/test_lr_scheduler.py
@@ -321,6 +321,70 @@ def step_lr(epoch_num, learning_rate, step_size, gamma=0.1, verbose=False):
return learning_rate * math.pow(gamma, epoch_num // step_size)


def one_cycle_lr(epoch_num,
learning_rate,
total_steps,
scale_factor=25,
end_lr=0.0001,
phase_pct=0.3,
anneal_strategy='cos',
three_phase=False,
verbose=False):
max_lr = learning_rate * scale_factor
if three_phase:
_end_steps = [
float(phase_pct * total_steps) - 1,
float(2 * phase_pct * total_steps) - 2, total_steps - 1
]
_schedule_phases = [
{
'start_lr': learning_rate,
'end_lr': max_lr,
},
{
'start_lr': max_lr,
'end_lr': learning_rate,
},
{
'start_lr': learning_rate,
'end_lr': end_lr,
},
]
else:
_end_steps = [float(phase_pct * total_steps) - 1, total_steps - 1]
_schedule_phases = [
{
'start_lr': learning_rate,
'end_lr': max_lr,
},
{
'start_lr': max_lr,
'end_lr': end_lr,
},
]

if anneal_strategy == 'cos':

def anneal_func(start, end, pct):
cos_out = math.cos(math.pi * pct) + 1
return end + (start - end) / 2.0 * cos_out
else:

def anneal_func(start, end, pct):
return (end - start) * pct + start

start_step = 0
for i, phase in enumerate(_schedule_phases):
end_step = _end_steps[i]
if epoch_num <= end_step or i == len(_schedule_phases) - 1:
pct = (epoch_num - start_step) / (end_step - start_step)
computed_lr = anneal_func(phase['start_lr'], phase['end_lr'], pct)
break
start_step = end_step

return computed_lr
Contributor Author:
The unit-test code was not changed much; it keeps the previous logic.
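A minimal sketch (not part of the diff) of the per-step check the test class below performs in dynamic-graph mode, assuming the OneCycleLR signature used in this revision of the PR and the one_cycle_lr reference above: the scheduler and the pure-Python reference must agree at every step.

import numpy as np
import paddle

kwarg = {"learning_rate": 0.1, "total_steps": 20, "scale_factor": 5}
scheduler = paddle.optimizer.lr.OneCycleLR(**kwarg)
for epoch in range(kwarg["total_steps"]):
    # scheduler() returns the current learning rate; compare it with the reference
    np.testing.assert_allclose(scheduler(), one_cycle_lr(epoch, **kwarg), rtol=1e-6)
    scheduler.step()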

class TestLRScheduler(unittest.TestCase):
def _test_static(self, python_func, paddle_api, kwarg, place):
scheduler = paddle_api(**kwarg)
@@ -467,6 +531,28 @@ def test_scheduler(self):
with self.assertRaises(ValueError):
paddle.optimizer.lr.MultiStepDecay(
learning_rate=0.5, milestones=[1, 2, 3], gamma=2)
with self.assertRaises(TypeError):
Contributor:
Please add a comment for each case noting which kind of invalid input it tests.

Contributor Author:
OK.

Contributor Author:
Done.

# learning_rate must be a number, not a string
paddle.optimizer.lr.OneCycleLR(learning_rate='test', total_steps=20)
# end_lr must be a float or int, not a string
with self.assertRaises(TypeError):
paddle.optimizer.lr.OneCycleLR(
learning_rate=0.1, total_steps=20, end_lr='test')
# end_lr must not be negative
with self.assertRaises(ValueError):
paddle.optimizer.lr.OneCycleLR(
learning_rate=0.1, total_steps=20, end_lr=-1)
# total_steps must be an int, not a string
with self.assertRaises(TypeError):
paddle.optimizer.lr.OneCycleLR(
learning_rate=0.1, total_steps='test')
# total_steps must be positive
with self.assertRaises(ValueError):
paddle.optimizer.lr.OneCycleLR(learning_rate=0.1, total_steps=-10)
# anneal_strategy must be either 'cos' or 'linear'
with self.assertRaises(ValueError):
paddle.optimizer.lr.OneCycleLR(
learning_rate=0.1, total_steps=20, anneal_strategy='test')
# phase_pct must be smaller than 0.5 when three_phase is True
with self.assertRaises(ValueError):
paddle.optimizer.lr.OneCycleLR(
learning_rate=0.1,
total_steps=20,
phase_pct=0.6,
three_phase=True)

func_api_kwargs = [(noam_lr, paddle.optimizer.lr.NoamDecay, {
"d_model": 0.01,
@@ -527,6 +613,38 @@ def test_scheduler(self):
"learning_rate": 0.5,
"T_max": 10,
"verbose": False
}), (one_cycle_lr, paddle.optimizer.lr.OneCycleLR, {
"learning_rate": 0.1,
"total_steps": 20,
"scale_factor": 5,
"end_lr": 0.0001,
"anneal_strategy": 'cos',
"phase_pct": 0.3,
"three_phase": False,
}), (one_cycle_lr, paddle.optimizer.lr.OneCycleLR, {
"learning_rate": 0.5,
"total_steps": 20,
"scale_factor": 10,
"end_lr": 0.001,
"anneal_strategy": 'linear',
"phase_pct": 0.4,
"three_phase": False,
}), (one_cycle_lr, paddle.optimizer.lr.OneCycleLR, {
"learning_rate": 0.1,
"total_steps": 20,
"scale_factor": 9,
"end_lr": 0.0001,
"anneal_strategy": 'cos',
"phase_pct": 0.3,
"three_phase": True,
}), (one_cycle_lr, paddle.optimizer.lr.OneCycleLR, {
"learning_rate": 0.3,
"total_steps": 20,
"scale_factor": 25,
"end_lr": 0.0005,
"anneal_strategy": 'linear',
"phase_pct": 0.2,
"three_phase": True,
})]

for python_func, paddle_api, kwarg in func_api_kwargs:
190 changes: 189 additions & 1 deletion python/paddle/optimizer/lr.py
@@ -33,7 +33,8 @@
'LambdaDecay',
'ReduceOnPlateau',
'CosineAnnealingDecay',
'MultiplicativeDecay'
'MultiplicativeDecay',
'OneCycleLR'
]


@@ -1591,3 +1592,190 @@ def get_lr(self):
for epoch in range(1, self.last_epoch + 1):
cur_lr = cur_lr * self.lr_lambda(epoch)
return cur_lr


class OneCycleLR(LRScheduler):
r"""
Contributor:
If the code is adapted from someone else's implementation, you need to comply with its open-source license and credit the source.

Contributor Author:
Understood.

Contributor Author:
Updated.

Sets the learning rate according to the one cycle learning rate scheduler.
The scheduler adjusts the learning rate from an initial learning rate to the maximum learning rate and then
from that maximum learning rate to the minimum learning rate, which is much less than the initial learning rate.

Contributor Author:
There is nothing I would change in this description; it is already concise and clear.

It has been proposed in `Super-Convergence: Very Fast Training of Neural Networks Using Large Learning Rates <https://arxiv.org/abs/1708.07120>`_.

Please note that the default behaviour of this scheduler follows the fastai implementation of one cycle,
which claims that “unpublished work has shown even better results by using only two phases”.
If you want the behaviour of this scheduler to be consistent with the paper, please set ``three_phase=True`` .

Also note that the learning rate should be updated at each step.

Args:
learning_rate (float): The initial learning rate. It is a python float number.
Together with ``scale_factor``, it determines the maximum learning rate.
total_steps (int): Number of total training steps.
scale_factor (float): The maximum learning rate is determined by max_lr = learning_rate * scale_factor. Default: 25.
end_lr (float, optional): The minimum learning rate of the schedule; it should be much less than the initial learning rate. Default: 0.0001.
phase_pct (float): The fraction of total steps used to increase the learning rate. Default: 0.3.
anneal_strategy (str, optional): Strategy for adjusting the learning rate. 'cos' for cosine annealing,
'linear' for linear annealing. Default: 'cos'.
three_phase (bool, optional): Whether to use a three-phase schedule. Default: ``False``.
If ``True``:
1. The learning rate first increases from the initial learning rate to the maximum learning rate.
2. Then it decreases back to the initial learning rate; this phase takes the same number of steps as the first phase.
3. Finally, it decreases to the minimum learning rate, which is much less than the initial learning rate.
If ``False``:
1. The learning rate increases to the maximum learning rate.
2. Then it decreases directly to the minimum learning rate.
last_epoch (int, optional): The index of the last epoch. Can be set to resume training. Default: -1, meaning the initial learning rate.
verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False``.

Returns:
``OneCycleLR`` instance to schedule learning rate.

Examples:
.. code-block:: python

import paddle
Contributor:
"import paddle" must be preceded by a blank line, otherwise the docs will have formatting issues; also add some blank lines throughout the example code to keep it readable.

Contributor Author:
Updated.

import numpy as np

# train on default dynamic graph mode
linear = paddle.nn.Linear(10, 10)
scheduler = paddle.optimizer.lr.OneCycleLR(learning_rate=1.0, total_steps=100, scale_factor=5, verbose=True)
sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
for epoch in range(5):
for batch_id in range(20):
x = paddle.uniform([10, 10])
out = linear(x)
loss = paddle.mean(out)
loss.backward()
sgd.step()
sgd.clear_gradients()
scheduler.step() # You should update learning rate each step

# train on static graph mode
paddle.enable_static()
main_prog = paddle.static.Program()
start_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, start_prog):
x = paddle.static.data(name='x', shape=[None, 4, 5])
y = paddle.static.data(name='y', shape=[None, 4, 5])
z = paddle.static.nn.fc(x, 100)
loss = paddle.mean(z)
scheduler = paddle.optimizer.lr.OneCycleLR(learning_rate=1.0, total_steps=100, scale_factor=5, verbose=True)
sgd = paddle.optimizer.SGD(learning_rate=scheduler)
sgd.minimize(loss)

exe = paddle.static.Executor()
exe.run(start_prog)
for epoch in range(5):
for batch_id in range(20):
out = exe.run(
main_prog,
feed={
'x': np.random.randn(3, 4, 5).astype('float32'),
'y': np.random.randn(3, 4, 5).astype('float32')
},
fetch_list=loss.name)
scheduler.step() # You should update learning rate each step
"""

def __init__(self,
learning_rate,
total_steps,
scale_factor=25,
end_lr=0.0001,
Contributor:
Suggest using the full name here for consistency, since max_learning_rate above is spelled out: end_lr -> end_learning_rate.

Contributor Author:
Updated.
end_lr was used here mainly because earlier schedulers such as PolynomialDecay also use end_lr.

phase_pct=0.3,
anneal_strategy='cos',
three_phase=False,
last_epoch=-1,
verbose=False):
Contributor:
The RFC lists 11 parameters for the OneCycleLR API, but there are only 9 here. Which is correct? The RFC and the code must be consistent.

Contributor Author:
I have submitted a new pull request to update the RFC file.

# Check type and value of end_lr
Contributor Author:
In the parameter list, max_learning_rate was changed to learning_rate to align with the other schedulers in Paddle.
The extra epochs and steps_per_epoch parameters were removed; users can simply call
OneCycleLR(learning_rate=0.1, total_steps=epochs*steps_per_epoch).
max_lr is now derived from the scale_factor parameter, and min_lr is taken directly as a float.

Contributor:
The learning_rate parameter should not be changed; keep the original. This comes down to what users actually care about: with OneCycleLR, the key setting is max_lr, which directly determines training quality and convergence and has a clear meaning. The initial and final learning rates are usually just very small values without a precise meaning. So I still recommend keeping max_learning_rate as the parameter name. Please change this part back; the rest of the changes look fine to me.

Contributor:
It is fine to borrow good design from other implementations; you do not have to reject everything. Just make sure the main implementation code and the description are written independently.

Contributor Author:
OK, updated.

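For illustration only (not part of the PR diff), a minimal sketch of how the parameters discussed above fit together, assuming the signature used in this revision of the patch; epochs and steps_per_epoch are hypothetical user-side variables:

import paddle

epochs, steps_per_epoch = 5, 100
scheduler = paddle.optimizer.lr.OneCycleLR(
    learning_rate=0.1,                     # initial learning rate
    total_steps=epochs * steps_per_epoch,  # replaces separate epochs / steps_per_epoch arguments
    scale_factor=25,                       # max_lr = learning_rate * scale_factor = 2.5
    end_lr=0.0001)                         # minimum learning rate reached at the end of the cycle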
if not isinstance(end_lr, (float, int)):
raise TypeError(
"'end_lr' must be 'float' or 'int', but received {}".format(
type(end_lr)))
if end_lr < 0:
raise ValueError("'end_lr' must be a non-negative float.")

# Check type and value of total_steps
if not isinstance(total_steps, int):
raise TypeError("'total_step' must be 'int', but received {}".
format(type(total_steps)))
if total_steps <= 0:
raise ValueError("'total_step' must be a positive integer.")
self.total_steps = total_steps

# Check type and value of phase_pct
if not isinstance(phase_pct, float):
raise TypeError("'phase_pct' must be 'float', but received {}".
format(type(phase_pct)))
if phase_pct < 0 or phase_pct > 1:
raise ValueError(
"'phase_pct' must be between 0 and 1, but received {}".format(
phase_pct))

max_lr = learning_rate * scale_factor
min_lr = float(end_lr)

if three_phase:
if phase_pct >= 0.5:
Contributor Author:
Added a check that phase_pct must be smaller than 0.5 when three_phase is True.

raise ValueError(
"When three_phase is True, 'phase_pct' must be smaller than 0.5"
)
self._start_steps = [
0,
phase_pct * self.total_steps - 1,
2 * phase_pct * self.total_steps - 2,
self.total_steps - 1,
]
self._steps_size = [
self._start_steps[1] - self._start_steps[0],
self._start_steps[2] - self._start_steps[1],
self._start_steps[3] - self._start_steps[2],
self._start_steps[3] - self._start_steps[2],
]
Contributor Author:
The phase sizes are computed by subtracting elements of self._start_steps rather than writing out the full expressions, because recomputing them would introduce small discrepancies from floating-point rounding.

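A minimal illustration of the point above (not part of the PR diff; the phase_pct and total_steps values are arbitrary):

phase_pct, total_steps = 0.3, 10
starts = [0, phase_pct * total_steps - 1,
          2 * phase_pct * total_steps - 2, total_steps - 1]
# The re-derived length phase_pct * total_steps - 1 is not guaranteed to equal
# starts[2] - starts[1] bit-for-bit, so the scheduler takes the sizes from the
# stored boundaries to keep them consistent with the comparisons in get_lr.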
self._lr_config = [learning_rate, max_lr, learning_rate, min_lr]
else:
self._start_steps = [
0, phase_pct * self.total_steps - 1, self.total_steps - 1
]
self._steps_size = [
self._start_steps[1] - self._start_steps[0],
self._start_steps[2] - self._start_steps[1],
self._start_steps[2] - self._start_steps[1],
]
self._lr_config = [learning_rate, max_lr, min_lr]

# Check anneal_strategy
if anneal_strategy not in ['cos', 'linear']:
raise ValueError(
"'anneal_strategy' must by one of 'cos' or 'linear', but received {}".
format(anneal_strategy))
elif anneal_strategy == 'cos':
self.anneal_func = self._cos_annealing
elif anneal_strategy == 'linear':
self.anneal_func = self._linear_annealing

super(OneCycleLR, self).__init__(learning_rate, last_epoch, verbose)

def _cos_annealing(self, start_lr, end_lr, pct):
cos_out = math.cos(math.pi * pct) + 1
return end_lr + (start_lr - end_lr) / 2.0 * cos_out

def _linear_annealing(self, start_lr, end_lr, pct):
return (end_lr - start_lr) * pct + start_lr

def get_lr(self):
current_step = self.last_epoch

if current_step > self.total_steps:
raise ValueError(
"Tried to step {} times. The specified number of total steps is {}"
.format(current_step, self.total_steps))

for (i, (start_step, step_size)
) in enumerate(zip(self._start_steps, self._steps_size)):
if current_step <= (start_step + step_size
) or i == len(self._lr_config) - 2:
percentage = (current_step - start_step) / step_size
return self.anneal_func(self._lr_config[i],
self._lr_config[i + 1], percentage)
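As a closing illustration (not part of the diff), tracing the pure-Python one_cycle_lr reference from the test file above shows the expected shape of the schedule:

lrs = [one_cycle_lr(step, learning_rate=0.1, total_steps=10, scale_factor=5)
       for step in range(10)]
# With the default phase_pct of 0.3, the values rise from 0.1 towards
# max_lr = 0.5 over roughly the first three steps, then anneal back down
# to end_lr = 0.0001 at the final step.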