Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【PaddlePaddle Hackathon 2】12、为 Paddle 新增 OneCycleLR 优化调度器 #41825

Merged
merged 18 commits into from May 16, 2022
Merged

【PaddlePaddle Hackathon 2】12、为 Paddle 新增 OneCycleLR 优化调度器 #41825

merged 18 commits into from May 16, 2022

Conversation

Asthestarsfalll
Copy link
Contributor

@Asthestarsfalll Asthestarsfalll commented Apr 14, 2022

PR types

Others

PR changes

APIs

Describe

解决了issue:#40322
增加了API: paddle.optimizer.lr.OneCycleLR,该优化调度器在训练过程中调整学习率从初始学习率到最大学习率,再到最小学习率。
设计文档:PaddlePaddle/community#29
中文文档: PaddlePaddle/docs#4713

@paddle-bot-old paddle-bot-old bot added contributor External developers status: proposed labels Apr 14, 2022
@paddle-bot-old
Copy link

paddle-bot-old bot commented Apr 14, 2022

✅ This PR's description meets the template requirements!
Please wait for other CI results.

@paddle-bot-old
Copy link

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@Asthestarsfalll Asthestarsfalll changed the title Onecyclelr 【PaddlePaddle Hackathon 2】12、为 Paddle 新增 OneCycleLR 优化调度器 Apr 14, 2022
@paddle-bot-old
Copy link

PR格式检查通过,你的PR将接受Paddle专家以及开源社区的review,请及时关注PR动态。
The format inspection passed. Your PR will be reviewed by experts of Paddle and developers from the open-source community. Stay tuned.

@@ -467,6 +535,25 @@ def test_scheduler(self):
with self.assertRaises(ValueError):
paddle.optimizer.lr.MultiStepDecay(
learning_rate=0.5, milestones=[1, 2, 3], gamma=2)
with self.assertRaises(TypeError):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

加一下注释吧,每一项测试是针对什么异常输入情况

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改


class OneCycleLR(LRScheduler):
r"""
Sets the learning rate according to the 1cycle learning rate scheduler.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

文档是不能大面积拷贝别人的,自己组织优化一下

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

API描述部分参考了pytorch与其他LRScheduler的,参数部分由自己理解完成,稍后将会再进行优化。



class OneCycleLR(LRScheduler):
r"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

如果代码是参考别人的实现,需要遵循开源协议,添加说明引用来源

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

了解

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

which claims that “unpublished work has shown even better results by using only two phases”.
Set ``three_phase=True``, If you want the behaviour of this scheduler to be consistent with the paper.

Also note that you should update learning rate each step.

This implementation was adapted from PyTorch.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

参考需要具体到文件或函数行,adapted from [文件网址]这样吧

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

zhiboniu
zhiboniu previously approved these changes Apr 26, 2022
@TCChenlong
Copy link
Contributor

请添加中文文档,并将链接填写在 Describe 中

@Asthestarsfalll
Copy link
Contributor Author

Asthestarsfalll commented Apr 29, 2022

请添加中文文档,并将链接填写在 Describe 中

已添加~

``final_divide_factor`` respectively.
total_steps (int, optional): Number of total training steps.
Note that one of total_steps and (epochs, steps_per_epoch) must be specified.
If ``total_steps`` is not specified, it will be determined by ``epochs`` and ``steps_per_epoch``. Default: None.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Default: None, means xxx.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If total_steps is not specified, it will be determined by epochs and steps_per_epoch .

这部分应该表明了默认的情况,后面就不再说明了。

pct_start (float): The percentage of total steps, which used to increasing learning rate. Default: 0.3.
anneal_strategy (str, optional): Strategy of adjusting learning rate.'cos' for cosine annealing,
'linear' for linear annealing. Default: 'cos'.
divide_factor (float, optional): Initial learning rate will be determined by initial_lr = max_lr/div_factor. Default: 25.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

max_lr -> max_learning_rate
div_factor -> divide_factor
保持一致,其他地方同理

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

anneal_strategy (str, optional): Strategy of adjusting learning rate.'cos' for cosine annealing,
'linear' for linear annealing. Default: 'cos'.
divide_factor (float, optional): Initial learning rate will be determined by initial_lr = max_lr/div_factor. Default: 25.
final_divide_factor (float, optional): Minimum learning rate will be determined by minimum = max_lr/final_divide_factor. Default: 1e4.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

中英文公式一致

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改


Examples:
.. code-block:: python
import paddle
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

import paddle 必须要和上面空一行,否则会有格式问题;
示例代码整体注意增加一些空行,保证阅读体验~

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

@Asthestarsfalll Asthestarsfalll dismissed stale reviews from DDDivano and TCChenlong via 98f9a9e May 10, 2022 13:22
@Asthestarsfalll
Copy link
Contributor Author

@zhiboniu 已更新

three_phase=False,
last_epoch=-1,
verbose=False):
# Check type and value of end_lr
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

参数部分将max_learning_rate更改为learning_rate与paddle中其他调度器对齐;
去除了额外的epochs和steps_per_epoch,实际上用户调用时可直接使用
OneCycleLR(learning_rate=0.1, total_steps=epochs*steps_per_epoch)
max_lr将由参数scale_factor推断而来;
min_lr则直接接收一个浮点数。

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

learning_rate的参数不需要修改,保持原来的就好。
这里涉及一个使用中关注点的问题,使用OneCycleLR其实对学习率的设置最关注的是max_lr,这个值直接影响模型训练效果和是否收敛,是比较明确的。初始学习率和最终学习率一般只是表示一个很小的值,并不具有精确的意义。所以仍然建议参数中使用原来的max_learning_rate。
这里再改一下吧,其他部分的修改我觉得没有问题。

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

对于别人好的设计也是可以借鉴的,并不要求全盘否定。主要实现代码和描述独立实现就好。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的~已更新

min_lr = float(end_lr)

if three_phase:
if phase_pct >= 0.5:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

增加了在三阶段情况下对phase_pct不能超过0.5的异常检查

self._start_steps[2] - self._start_steps[1],
self._start_steps[3] - self._start_steps[2],
self._start_steps[3] - self._start_steps[2],
]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里使用self._start_steps的元素进行相减是因为写计算表达式会因为浮点运算不精确而造成误差

Sets the learning rate according to the one cycle learning rate scheduler.
The scheduler adjusts the learning rate from an initial learning rate to the maximum learning rate and then
from that maximum learning rate to the minimum learning rate, which is much less than the initial learning rate.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里的描述没有什么可以修改的方案,其本身已经足够简洁明了。

break
start_step = end_step

return computed_lr
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

单测代码部分没有作过多修改,仍是之前的逻辑

self.total_steps - 1,
self.total_steps - 1, # for the last step.
]
# step size of each phase.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里又做了一点修改,在get_lr时可以进行更少的计算;
同时添加了一些注释。

zhiboniu
zhiboniu previously approved these changes May 12, 2022
Copy link
Contributor

@zhiboniu zhiboniu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@Asthestarsfalll
Copy link
Contributor Author

@TCChenlong @DDDivano 劳烦进行后续review!

TCChenlong
TCChenlong previously approved these changes May 13, 2022
Copy link
Contributor

@TCChenlong TCChenlong left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

DDDivano
DDDivano previously approved these changes May 13, 2022
max_learning_rate,
total_steps,
divide_factor=25.,
end_lr=0.0001,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里建议用全称吧,因为前面max_learning_rate用的是全称,保持一致
end_lr -> end_learning_rate

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改~
这里使用end_lr主要是因为注意到前面的一些调度器如PolynomialDecay、PolynomialDecay等也都使用了end_lr

Copy link
Contributor

@XiaoguangHu01 XiaoguangHu01 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LG API

anneal_strategy='cos',
three_phase=False,
last_epoch=-1,
verbose=False):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there are total 11 parameters of OneCycleLR API in RFC, but only 9 parameters here, which is right? RFC and code must be consistency.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here I commit a new pull request to modify RFC file.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
contributor External developers
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

7 participants