[PaddlePaddle Hackathon 2] 12. Add the OneCycleLR learning rate scheduler to Paddle #41825
@@ -1596,34 +1596,38 @@ def get_lr(self):

 class OneCycleLR(LRScheduler):
     r"""
-    Sets the learning rate according to the 1cycle learning rate scheduler.
+    Sets the learning rate according to the one cycle learning rate scheduler.
     The scheduler adjusts the learning rate from an initial learning rate to the maximum learning rate and then
-    from that maximum learning rate to the minimum learning rate, which is much lower than the initial learning rate.
+    from that maximum learning rate to the minimum learning rate, which is much less than the initial learning rate.
Review comment: There is nothing to revise in this description; it is already concise and clear.
     It has been proposed in `Super-Convergence: Very Fast Training of Neural Networks Using Large Learning Rates <https://arxiv.org/abs/1708.07120>`_.

-    Please note that the default behaviour of this scheduler follows the fastai implementation of 1cycle,
+    Please note that the default behaviour of this scheduler follows the fastai implementation of one cycle,
     which claims that “unpublished work has shown even better results by using only two phases”.
     Set ``three_phase=True`` if you want the behaviour of this scheduler to be consistent with the paper.

     Also note that you should update the learning rate each step.

+    This implementation was adapted from PyTorch.
Review comment: The reference should point to a specific file or function, e.g. “adapted from [file URL]”.
Reply: Fixed.
     Args:
-        max_learning_rate (float): Upper boundary of learning rate in the whole training phase.
+        max_learning_rate (float): Upper boundary of learning rate during training.
             Functionally, it defines the initial learning rate and the minimum learning rate by ``divide_factor`` and
             ``final_divide_factor`` respectively.
         total_steps (int, optional): Number of total training steps.
             Note that one of total_steps and (epochs, steps_per_epoch) must be specified.
+            If ``total_steps`` is not specified, it will be determined by ``epochs`` and ``steps_per_epoch``. Default: None.
Review comment: Use the form “Default: None, means xxx.”
Reply: This part already states the default behaviour, so it is not explained again afterwards.
         epochs (int, optional): Number of total training epochs. Default: None.
         steps_per_epoch (int, optional): Number of training steps for each epoch. Default: None.
-        pct_start (float): The percentage of learning rate increasing steps to total steps. Default: 0.3.
+        pct_start (float): The percentage of total steps used to increase the learning rate. Default: 0.3.
         anneal_strategy (str, optional): Strategy of adjusting learning rate. 'cos' for cosine annealing,
             'linear' for linear annealing. Default: 'cos'.
         divide_factor (float, optional): Initial learning rate will be determined by initial_lr = max_lr/div_factor. Default: 25.
Review comment: max_lr -> max_learning_rate
Reply: Fixed.
         final_divide_factor (float, optional): Minimum learning rate will be determined by min_lr = initial_lr/final_divide_factor. Default: 1e4.
         three_phase (bool, optional): Whether to use three-phase mode.
             If ``True``:
                 1. The learning rate will first increase from the initial learning rate to the maximum learning rate.
-                2. Then it will decrease back to the initial learning rate.
+                2. Then it will decrease back to the initial learning rate. The number of steps in this phase is the same as in the first phase.
                 3. Finally, it will decrease to the minimum learning rate, which is much less than the initial learning rate.
             If ``False``:
                 1. The learning rate will increase to the maximum learning rate.
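To make the interaction between ``pct_start``, ``divide_factor``, ``final_divide_factor``, and ``anneal_strategy`` concrete, here is a minimal sketch of the two-phase (``three_phase=False``) schedule described above. It only illustrates the documented behaviour; the function name one_cycle_lr is hypothetical and this is not the code under review.

import math

def one_cycle_lr(step, max_learning_rate, total_steps, pct_start=0.3,
                 anneal_strategy='cos', divide_factor=25.0,
                 final_divide_factor=1e4):
    # Boundary learning rates, as defined in the Args section above.
    initial_lr = max_learning_rate / divide_factor
    min_lr = initial_lr / final_divide_factor

    # The first pct_start of all steps increases the learning rate;
    # the remainder decreases it (two-phase behaviour).
    up_steps = int(total_steps * pct_start)
    down_steps = total_steps - up_steps

    def anneal(start, end, pct):
        # Interpolate from start to end as pct goes from 0 to 1.
        if anneal_strategy == 'cos':
            return end + (start - end) / 2.0 * (math.cos(math.pi * pct) + 1)
        return start + (end - start) * pct  # 'linear'

    if step < up_steps:
        return anneal(initial_lr, max_learning_rate, step / up_steps)
    return anneal(max_learning_rate, min_lr, (step - up_steps) / down_steps)

With the defaults, max_learning_rate=0.1 gives initial_lr = 0.004 and min_lr = 4e-7.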
@@ -1689,11 +1693,12 @@ def __init__(self,
                 three_phase=False,
                 last_epoch=-1,
                 verbose=False):
Review comment: There are 11 parameters for the OneCycleLR API in the RFC, but only 9 here. Which is correct? The RFC and the code must be consistent.
Reply: I have opened a new pull request to update the RFC file.
        # Check type of max_learning_rate
        if not isinstance(max_learning_rate, (float, int)):
            raise TypeError(
                "The type of learning rate must be float, but received {}".
                format(type(max_learning_rate)))

        # Check type and value of total_steps
        if total_steps is None and epochs is None and steps_per_epoch is None:
            raise ValueError(
                "either total_steps or (epochs, steps_per_epoch) must be specified"
@@ -1706,6 +1711,7 @@ def __init__(self,
                raise ValueError("'total_steps' must be a positive integer.")
            self.total_steps = total_steps
        else:
            # Check type and value of epochs and steps_per_epoch
            if not isinstance(epochs, int):
                raise TypeError("'epochs' must be 'int', but received {}".
                                format(type(epochs)))
@@ -1719,7 +1725,7 @@ def __init__(self,
                raise ValueError(
                    "'steps_per_epoch' must be a positive integer.")
            self.total_steps = epochs * steps_per_epoch
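As a quick illustration of the two accepted ways to satisfy the step-budget checks above (assuming the scheduler is exposed as paddle.optimizer.lr.OneCycleLR, which this diff suggests but does not show):

import paddle

# Either pass the total number of steps directly ...
scheduler = paddle.optimizer.lr.OneCycleLR(
    max_learning_rate=0.1, total_steps=1000)

# ... or let it be derived as epochs * steps_per_epoch.
scheduler = paddle.optimizer.lr.OneCycleLR(
    max_learning_rate=0.1, epochs=10, steps_per_epoch=100)

Passing neither raises the ValueError from the first check.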

        # Check type and value of pct_start
        if not isinstance(pct_start, float):
            raise TypeError("'pct_start' must be 'float', but received {}".
                            format(type(pct_start)))
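Putting the pieces together, a minimal end-to-end usage sketch consistent with the docstring's note that the learning rate should be updated every step might look like the following; the module path and constructor defaults are assumptions based on this diff, not confirmed API:

import paddle

linear = paddle.nn.Linear(10, 10)
scheduler = paddle.optimizer.lr.OneCycleLR(  # path assumed from this PR
    max_learning_rate=0.1,
    total_steps=1000,
    pct_start=0.3,
    anneal_strategy='cos',
    three_phase=False)
sgd = paddle.optimizer.SGD(learning_rate=scheduler,
                           parameters=linear.parameters())

for step in range(1000):
    x = paddle.uniform([4, 10])
    loss = linear(x).mean()
    loss.backward()
    sgd.step()
    sgd.clear_grad()
    scheduler.step()  # the one cycle schedule is stepped once per batch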
Review comment: If the code is adapted from someone else's implementation, you need to comply with the open-source license and add a note citing the source.
Reply: Understood.
Reply: Fixed.