From 702f47c885be20fdaf4b8be33c0afc9c48139b6d Mon Sep 17 00:00:00 2001 From: Asthestarsfalll <1186454801@qq.com> Date: Thu, 14 Apr 2022 21:40:02 +0800 Subject: [PATCH 01/11] add OneCycleLR --- .../tests/unittests/test_lr_scheduler.py | 120 ++++++++++ python/paddle/optimizer/lr.py | 215 +++++++++++++++++- 2 files changed, 334 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/test_lr_scheduler.py b/python/paddle/fluid/tests/unittests/test_lr_scheduler.py index 60dd4948f996e..7f1b585f88f36 100644 --- a/python/paddle/fluid/tests/unittests/test_lr_scheduler.py +++ b/python/paddle/fluid/tests/unittests/test_lr_scheduler.py @@ -321,6 +321,74 @@ def step_lr(epoch_num, learning_rate, step_size, gamma=0.1, verbose=False): return learning_rate * math.pow(gamma, epoch_num // step_size) +def one_cycle_lr(epoch_num, + max_learning_rate, + total_steps=None, + epochs=None, + steps_per_epoch=None, + pct_start=0.3, + anneal_strategy='cos', + divide_factor=25., + final_divide_factor=1e4, + three_phase=False, + verbose=False): + total_steps = epochs * steps_per_epoch if total_steps is None else total_steps + initial_lr = max_learning_rate / divide_factor + min_lr = initial_lr / final_divide_factor + if three_phase: + _end_steps = [ + float(pct_start * total_steps) - 1, + float(2 * pct_start * total_steps) - 2, total_steps - 1 + ] + _schedule_phases = [ + { + 'start_lr': initial_lr, + 'end_lr': max_learning_rate, + }, + { + 'start_lr': max_learning_rate, + 'end_lr': initial_lr, + }, + { + 'start_lr': initial_lr, + 'end_lr': min_lr, + }, + ] + else: + _end_steps = [float(pct_start * total_steps) - 1, total_steps - 1] + _schedule_phases = [ + { + 'start_lr': initial_lr, + 'end_lr': max_learning_rate, + }, + { + 'start_lr': max_learning_rate, + 'end_lr': min_lr, + }, + ] + + if anneal_strategy == 'cos': + + def anneal_func(start, end, pct): + cos_out = math.cos(math.pi * pct) + 1 + return end + (start - end) / 2.0 * cos_out + else: + + def anneal_func(start, end, pct): + return (end - start) * pct + start + + start_step = 0 + for i, phase in enumerate(_schedule_phases): + end_step = _end_steps[i] + if epoch_num <= end_step or i == len(_schedule_phases) - 1: + pct = (epoch_num - start_step) / (end_step - start_step) + computed_lr = anneal_func(phase['start_lr'], phase['end_lr'], pct) + break + start_step = end_step + + return computed_lr + + class TestLRScheduler(unittest.TestCase): def _test_static(self, python_func, paddle_api, kwarg, place): scheduler = paddle_api(**kwarg) @@ -467,6 +535,25 @@ def test_scheduler(self): with self.assertRaises(ValueError): paddle.optimizer.lr.MultiStepDecay( learning_rate=0.5, milestones=[1, 2, 3], gamma=2) + with self.assertRaises(TypeError): + paddle.optimizer.lr.OneCycleLR( + max_learning_rate='test', total_steps=20) + with self.assertRaises(ValueError): + paddle.optimizer.lr.OneCycleLR(max_learning_rate=0.1) + with self.assertRaises(TypeError): + paddle.optimizer.lr.OneCycleLR( + max_learning_rate=0.1, total_steps='test') + with self.assertRaises(ValueError): + paddle.optimizer.lr.OneCycleLR( + max_learning_rate=0.1, total_steps=-10) + with self.assertRaises(TypeError): + paddle.optimizer.lr.OneCycleLR(max_learning_rate=0.1, epochs='test') + with self.assertRaises(TypeError): + paddle.optimizer.lr.OneCycleLR( + max_learning_rate=0.1, epochs=1, steps_per_epoch='t') + with self.assertRaises(ValueError): + paddle.optimizer.lr.OneCycleLR( + max_learning_rate=0.1, total_steps=20, anneal_strategy='test') func_api_kwargs = [(noam_lr, 
paddle.optimizer.lr.NoamDecay, { "d_model": 0.01, @@ -527,6 +614,39 @@ def test_scheduler(self): "learning_rate": 0.5, "T_max": 10, "verbose": False + }), (one_cycle_lr, paddle.optimizer.lr.OneCycleLR, { + "max_learning_rate": 0.5, + "total_steps": 20, + "pct_start": 0.3, + "anneal_strategy": 'cos', + "divide_factor": 25., + "final_divide_factor": 1e4, + "three_phase": False, + }), (one_cycle_lr, paddle.optimizer.lr.OneCycleLR, { + "max_learning_rate": 0.5, + "epochs": 10, + "steps_per_epoch": 2, + "pct_start": 0.2, + "anneal_strategy": 'linear', + "divide_factor": 20., + "final_divide_factor": 1000, + "three_phase": False, + }), (one_cycle_lr, paddle.optimizer.lr.OneCycleLR, { + "max_learning_rate": 1, + "total_steps": 20, + "pct_start": 0.4, + "anneal_strategy": 'cos', + "divide_factor": 15., + "final_divide_factor": 100, + "three_phase": True, + }), (one_cycle_lr, paddle.optimizer.lr.OneCycleLR, { + "max_learning_rate": 0.5, + "total_steps": 40, + "pct_start": 0.5, + "anneal_strategy": 'linear', + "divide_factor": 5., + "final_divide_factor": 50, + "three_phase": True, })] for python_func, paddle_api, kwarg in func_api_kwargs: diff --git a/python/paddle/optimizer/lr.py b/python/paddle/optimizer/lr.py index d0d5eef03c42c..eeba061d83b87 100644 --- a/python/paddle/optimizer/lr.py +++ b/python/paddle/optimizer/lr.py @@ -33,7 +33,8 @@ 'LambdaDecay', 'ReduceOnPlateau', 'CosineAnnealingDecay', - 'MultiplicativeDecay' + 'MultiplicativeDecay', + 'OneCycleLR' ] @@ -1591,3 +1592,215 @@ def get_lr(self): return self.last_lr * self.lr_lambda(self.last_epoch) else: return self.base_lr + + +class OneCycleLR(LRScheduler): + r""" + Sets the learning rate according to the 1cycle learning rate scheduler. + The scheduler adjusts the learning rate from an initial learning rate to the maximum learning rate and then + from that maximum learning rate to the minimum learning rate, which is much lower than the initial learning rate. + + It has been proposed in `Super-Convergence: Very Fast Training of Neural Networks Using Large Learning Rates `_. + + Please note that the default behaviour of this scheduler follows the fastai implementation of 1cycle, + which claims that “unpublished work has shown even better results by using only two phases”. + Set ``three_phase=True``, If you want the behaviour of this scheduler to be consistent with the paper. + + Args: + max_learning_rate (float): Upper boundary of learning rate in the whole training phase. + Functionally, it defines the initial learning rate and the minimum learning rate by ``divide_factor`` and + ``final_divide_factor`` respectively. + total_steps (int, optional): Number of total training steps. + Note that one of total_steps and (epochs, steps_per_epoch) must be specified. + If ``total_steps`` is not specified, it will be determined by ``epochs`` and ``steps_per_epoch``. Default: None. + epochs (int, optional): Number of total training epochs. Default: None. + steps_per_epoch (int, optional): Number of training steps for each epoch. Default: None. + pct_start (float): The percentage of learning rate increasing steps to total steps. Default: 0.3. + anneal_strategy (str, optional): Strategy of adjusting learning rate.'cos' for cosine annealing, + 'linear' for linear annealing. Default: 'cos'. + divide_factor (float, optional): Initial learning rate will be determined by initial_lr = max_lr/div_factor. Default: 25. + final_divide_factor (float, optional): Minimum learning rate will be determined by initial_lr = max_lr/div_factor. Default: 1e4. 
+ three_phase (bool, optional): Whether to use three phase. + If ``True``: + 1. The learning rate will first increase from initial learning rate to maximum learning rate. + 2. Then it will be decrease to learning rate. + 3. Finally, it decrease to minimum learning rate which is much less than initial learning rate. + If ``False``: + 1. The learning rate will increase to maximum learning rate. + 2. Then it will directly decrease to minimum learning rate. + last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate. + verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` . + + Returns: + ``OneCycleLR`` instance to schedule learning rate. + + Examples: + .. code-block:: python + import paddle + import numpy as np + # train on default dynamic graph mode + linear = paddle.nn.Linear(10, 10) + scheduler = paddle.optimizer.lr.OneCycleLR(max_learning_rate=1.0, total_steps=100, verbose=True) + sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters()) + for epoch in range(5): + for batch_id in range(20): + x = paddle.uniform([10, 10]) + out = linear(x) + loss = paddle.mean(out) + loss.backward() + sgd.step() + sgd.clear_gradients() + scheduler.step() # You should update learning rate each step + # train on static graph mode + paddle.enable_static() + main_prog = paddle.static.Program() + start_prog = paddle.static.Program() + with paddle.static.program_guard(main_prog, start_prog): + x = paddle.static.data(name='x', shape=[None, 4, 5]) + y = paddle.static.data(name='y', shape=[None, 4, 5]) + z = paddle.static.nn.fc(x, 100) + loss = paddle.mean(z) + scheduler = paddle.optimizer.lr.OneCycleLR(max_learning_rate=1.0, total_steps=100, verbose=True) + sgd = paddle.optimizer.SGD(learning_rate=scheduler) + sgd.minimize(loss) + exe = paddle.static.Executor() + exe.run(start_prog) + for epoch in range(5): + for batch_id in range(20): + out = exe.run( + main_prog, + feed={ + 'x': np.random.randn(3, 4, 5).astype('float32'), + 'y': np.random.randn(3, 4, 5).astype('float32') + }, + fetch_list=loss.name) + scheduler.step() # You should update learning rate each step + """ + + def __init__(self, + max_learning_rate, + total_steps=None, + epochs=None, + steps_per_epoch=None, + pct_start=0.3, + anneal_strategy='cos', + divide_factor=25., + final_divide_factor=1e4, + three_phase=False, + last_epoch=-1, + verbose=False): + if not isinstance(max_learning_rate, (float, int)): + raise TypeError( + "The type of learning rate must be float, but received {}". + format(type(max_learning_rate))) + + if total_steps is None and epochs is None and steps_per_epoch is None: + raise ValueError( + "either total_steps or (epochs, steps_per_epoch) must be specified" + ) + elif total_steps is not None: + if not isinstance(total_steps, int): + raise TypeError("'total_step' must be 'int', but received {}". + format(type(total_steps))) + if total_steps <= 0: + raise ValueError("'total_step' must be a positive integer.") + self.total_steps = total_steps + else: + if not isinstance(epochs, int): + raise TypeError("'epochs' must be 'int', but received {}". 
+ format(type(epochs))) + if not isinstance(steps_per_epoch, int): + raise TypeError( + "'steps_per_epoch', must be 'int', but received {}".format( + type(steps_per_epoch))) + if epochs < 0: + raise ValueError("'epochs' must be a positive integer.") + if steps_per_epoch < 0: + raise ValueError( + "'steps_per_epoch' must be a positive integer.") + + if not isinstance(pct_start, float): + raise TypeError("'pct_start' must be 'float', but received {}". + format(type(pct_start))) + if pct_start < 0 or pct_start > 1: + raise ValueError( + "'pct_start' must be between 0 and 1, but received {}".format( + pct_start)) + + max_lr = max_learning_rate + initial_lr = max_lr / divide_factor + min_lr = initial_lr / final_divide_factor + + if three_phase: + self._end_steps = [ + float(pct_start * self.total_steps) - 1, + float(2 * pct_start * self.total_steps) - 2, + self.total_steps - 1 + ] + self._schedule_phases = [ + { + 'start_lr': initial_lr, + 'end_lr': max_lr, + }, + { + 'start_lr': max_lr, + 'end_lr': initial_lr, + }, + { + 'start_lr': initial_lr, + 'end_lr': min_lr, + }, + ] + else: + self._end_steps = [ + float(pct_start * self.total_steps) - 1, self.total_steps - 1 + ] + self._schedule_phases = [ + { + 'start_lr': initial_lr, + 'end_lr': max_lr, + }, + { + 'start_lr': max_lr, + 'end_lr': min_lr, + }, + ] + + # Validate anneal_strategy + if anneal_strategy not in ['cos', 'linear']: + raise ValueError( + "'anneal_strategy' must by one of 'cos' or 'linear', but received {}". + format(anneal_strategy)) + elif anneal_strategy == 'cos': + self.anneal_func = self._annealing_cos + elif anneal_strategy == 'linear': + self.anneal_func = self._annealing_linear + + super(OneCycleLR, self).__init__(initial_lr, last_epoch, verbose) + + def _annealing_cos(self, start, end, pct): + cos_out = math.cos(math.pi * pct) + 1 + return end + (start - end) / 2.0 * cos_out + + def _annealing_linear(self, start, end, pct): + return (end - start) * pct + start + + def get_lr(self): + step_num = self.last_epoch + + if step_num > self.total_steps: + raise ValueError( + "Tried to step {} times. The specified number of total steps is {}" + .format(step_num + 1, self.total_steps)) + + start_step = 0 + for i, phase in enumerate(self._schedule_phases): + end_step = self._end_steps[i] + if step_num <= end_step or i == len(self._schedule_phases) - 1: + pct = (step_num - start_step) / (end_step - start_step) + computed_lr = self.anneal_func(phase['start_lr'], + phase['end_lr'], pct) + break + start_step = end_step + + return computed_lr From f8011a04fe20f6637e350cd6d8320fa97de3ba6d Mon Sep 17 00:00:00 2001 From: Asthestarsfalll <1186454801@qq.com> Date: Thu, 14 Apr 2022 22:42:02 +0800 Subject: [PATCH 02/11] add missing total_steps --- python/paddle/optimizer/lr.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/paddle/optimizer/lr.py b/python/paddle/optimizer/lr.py index eeba061d83b87..609306f1ed9e0 100644 --- a/python/paddle/optimizer/lr.py +++ b/python/paddle/optimizer/lr.py @@ -1718,6 +1718,7 @@ def __init__(self, if steps_per_epoch < 0: raise ValueError( "'steps_per_epoch' must be a positive integer.") + self.total_steps = epochs * steps_per_epoch if not isinstance(pct_start, float): raise TypeError("'pct_start' must be 'float', but received {}". 
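
Before the documentation revisions in the following patches, it may help to see the schedule that PATCH 01/11 and the `one_cycle_lr` test helper above actually compute. Below is a minimal, Paddle-independent sketch of the two-phase variant (``three_phase=False``); the helper name ``one_cycle_reference`` and the printed values are illustrative assumptions, not part of the patch. The three-phase variant only inserts a mirrored cool-down from the maximum rate back to the initial rate before the final annealing to the minimum rate.

    import math

    def one_cycle_reference(step, max_lr, total_steps, pct_start=0.3,
                            divide_factor=25., final_divide_factor=1e4,
                            anneal='cos'):
        # Mirrors the two-phase branch of the one_cycle_lr test helper above.
        initial_lr = max_lr / divide_factor        # warm-up starts here
        min_lr = initial_lr / final_divide_factor  # final annealing target
        peak_step = float(pct_start * total_steps) - 1
        last_step = total_steps - 1

        def cos_anneal(start, end, pct):
            return end + (start - end) / 2.0 * (math.cos(math.pi * pct) + 1)

        def linear_anneal(start, end, pct):
            return (end - start) * pct + start

        anneal_func = cos_anneal if anneal == 'cos' else linear_anneal

        if step <= peak_step:
            # Phase 1: anneal from initial_lr up to max_lr over the first
            # pct_start fraction of total_steps.
            return anneal_func(initial_lr, max_lr, step / peak_step)
        # Phase 2: anneal from max_lr down to min_lr over the remaining steps.
        return anneal_func(max_lr, min_lr,
                           (step - peak_step) / (last_step - peak_step))

    # Illustrative values for the first test configuration (max_lr=0.5, 20 steps):
    for s in range(20):
        print(s, round(one_cycle_reference(s, max_lr=0.5, total_steps=20), 6))

The phase boundary ``peak_step`` corresponds to the first entry of ``_end_steps`` in the patch, which is why the unit test can compare the scheduler output against this pure-Python helper step by step.
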
From 9be4aeb76e77072575226b9fb06f404dd4defe5c Mon Sep 17 00:00:00 2001 From: Asthestarsfalll <1186454801@qq.com> Date: Wed, 20 Apr 2022 23:05:06 +0800 Subject: [PATCH 03/11] try --- python/paddle/optimizer/lr.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/python/paddle/optimizer/lr.py b/python/paddle/optimizer/lr.py index 609306f1ed9e0..3c627f1289a14 100644 --- a/python/paddle/optimizer/lr.py +++ b/python/paddle/optimizer/lr.py @@ -1596,18 +1596,22 @@ def get_lr(self): class OneCycleLR(LRScheduler): r""" - Sets the learning rate according to the 1cycle learning rate scheduler. + Sets the learning rate according to the one cycle learning rate scheduler. The scheduler adjusts the learning rate from an initial learning rate to the maximum learning rate and then - from that maximum learning rate to the minimum learning rate, which is much lower than the initial learning rate. + from that maximum learning rate to the minimum learning rate, which is much less than the initial learning rate. It has been proposed in `Super-Convergence: Very Fast Training of Neural Networks Using Large Learning Rates `_. - Please note that the default behaviour of this scheduler follows the fastai implementation of 1cycle, + Please note that the default behaviour of this scheduler follows the fastai implementation of one cycle, which claims that “unpublished work has shown even better results by using only two phases”. Set ``three_phase=True``, If you want the behaviour of this scheduler to be consistent with the paper. + Also note that you should update learning rate each step. + + This implementation was adapted from PyTorch. + Args: - max_learning_rate (float): Upper boundary of learning rate in the whole training phase. + max_learning_rate (float): Upper boundary of learning rate during training. Functionally, it defines the initial learning rate and the minimum learning rate by ``divide_factor`` and ``final_divide_factor`` respectively. total_steps (int, optional): Number of total training steps. @@ -1615,7 +1619,7 @@ class OneCycleLR(LRScheduler): If ``total_steps`` is not specified, it will be determined by ``epochs`` and ``steps_per_epoch``. Default: None. epochs (int, optional): Number of total training epochs. Default: None. steps_per_epoch (int, optional): Number of training steps for each epoch. Default: None. - pct_start (float): The percentage of learning rate increasing steps to total steps. Default: 0.3. + pct_start (float): The percentage of total steps, which used to increasing learning rate. Default: 0.3. anneal_strategy (str, optional): Strategy of adjusting learning rate.'cos' for cosine annealing, 'linear' for linear annealing. Default: 'cos'. divide_factor (float, optional): Initial learning rate will be determined by initial_lr = max_lr/div_factor. Default: 25. @@ -1623,7 +1627,7 @@ class OneCycleLR(LRScheduler): three_phase (bool, optional): Whether to use three phase. If ``True``: 1. The learning rate will first increase from initial learning rate to maximum learning rate. - 2. Then it will be decrease to learning rate. + 2. Then it will be decrease to learning rate. Number of step in this phase is the same as the one in first phase. 3. Finally, it decrease to minimum learning rate which is much less than initial learning rate. If ``False``: 1. The learning rate will increase to maximum learning rate. 
@@ -1689,11 +1693,12 @@ def __init__(self, three_phase=False, last_epoch=-1, verbose=False): + # Check type of max_learning_rate if not isinstance(max_learning_rate, (float, int)): raise TypeError( "The type of learning rate must be float, but received {}". format(type(max_learning_rate))) - + # Check type and value of total_steps if total_steps is None and epochs is None and steps_per_epoch is None: raise ValueError( "either total_steps or (epochs, steps_per_epoch) must be specified" @@ -1706,6 +1711,7 @@ def __init__(self, raise ValueError("'total_step' must be a positive integer.") self.total_steps = total_steps else: + # Check type and value of epochs and steps_per_epochs if not isinstance(epochs, int): raise TypeError("'epochs' must be 'int', but received {}". format(type(epochs))) @@ -1719,7 +1725,7 @@ def __init__(self, raise ValueError( "'steps_per_epoch' must be a positive integer.") self.total_steps = epochs * steps_per_epoch - + # Check type and value of pac_start if not isinstance(pct_start, float): raise TypeError("'pct_start' must be 'float', but received {}". format(type(pct_start))) From 8bb405a9e1008e5688b7e806b25c5887b2433374 Mon Sep 17 00:00:00 2001 From: Asthestarsfalll <1186454801@qq.com> Date: Mon, 25 Apr 2022 12:04:38 +0800 Subject: [PATCH 04/11] update --- python/paddle/optimizer/lr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/optimizer/lr.py b/python/paddle/optimizer/lr.py index 3c627f1289a14..c6e2a554024ce 100644 --- a/python/paddle/optimizer/lr.py +++ b/python/paddle/optimizer/lr.py @@ -1608,7 +1608,7 @@ class OneCycleLR(LRScheduler): Also note that you should update learning rate each step. - This implementation was adapted from PyTorch. + This implementation was adapted from `there `_. Args: max_learning_rate (float): Upper boundary of learning rate during training. From de82b7b47f03320eead68dba4dbd0d0738b1e031 Mon Sep 17 00:00:00 2001 From: Asthestarsfalll <1186454801@qq.com> Date: Mon, 25 Apr 2022 16:19:20 +0800 Subject: [PATCH 05/11] fix conflict bug --- python/paddle/optimizer/lr.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/paddle/optimizer/lr.py b/python/paddle/optimizer/lr.py index c6e2a554024ce..4dac254e306ad 100644 --- a/python/paddle/optimizer/lr.py +++ b/python/paddle/optimizer/lr.py @@ -1588,10 +1588,10 @@ def __init__(self, learning_rate, lr_lambda, last_epoch=-1, verbose=False): verbose) def get_lr(self): - if self.last_epoch > 0: - return self.last_lr * self.lr_lambda(self.last_epoch) - else: - return self.base_lr + cur_lr = self.base_lr + for epoch in range(1, self.last_epoch + 1): + cur_lr = cur_lr * self.lr_lambda(epoch) + return cur_lr class OneCycleLR(LRScheduler): From af0420fed9ad79fc6256f680b364c5caf1027bd5 Mon Sep 17 00:00:00 2001 From: Asthestarsfalll <1186454801@qq.com> Date: Fri, 29 Apr 2022 18:42:48 +0800 Subject: [PATCH 06/11] fix typo --- python/paddle/optimizer/lr.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/paddle/optimizer/lr.py b/python/paddle/optimizer/lr.py index 4dac254e306ad..5d9f95b8971f5 100644 --- a/python/paddle/optimizer/lr.py +++ b/python/paddle/optimizer/lr.py @@ -1604,7 +1604,7 @@ class OneCycleLR(LRScheduler): Please note that the default behaviour of this scheduler follows the fastai implementation of one cycle, which claims that “unpublished work has shown even better results by using only two phases”. 
- Set ``three_phase=True``, If you want the behaviour of this scheduler to be consistent with the paper. + Set ``three_phase=True``, if you want the behaviour of this scheduler to be consistent with the paper. Also note that you should update learning rate each step. @@ -1623,12 +1623,12 @@ class OneCycleLR(LRScheduler): anneal_strategy (str, optional): Strategy of adjusting learning rate.'cos' for cosine annealing, 'linear' for linear annealing. Default: 'cos'. divide_factor (float, optional): Initial learning rate will be determined by initial_lr = max_lr/div_factor. Default: 25. - final_divide_factor (float, optional): Minimum learning rate will be determined by initial_lr = max_lr/div_factor. Default: 1e4. + final_divide_factor (float, optional): Minimum learning rate will be determined by minimum = max_lr/final_divide_factor. Default: 1e4. three_phase (bool, optional): Whether to use three phase. If ``True``: 1. The learning rate will first increase from initial learning rate to maximum learning rate. - 2. Then it will be decrease to learning rate. Number of step in this phase is the same as the one in first phase. - 3. Finally, it decrease to minimum learning rate which is much less than initial learning rate. + 2. Then it will decrease to initial learning rate. Number of step in this phase is the same as the one in first phase. + 3. Finally, it will decrease to minimum learning rate which is much less than initial learning rate. If ``False``: 1. The learning rate will increase to maximum learning rate. 2. Then it will directly decrease to minimum learning rate. From 6c76ab593d2cfb874be847ae4558f07c32641373 Mon Sep 17 00:00:00 2001 From: Asthestarsfalll <1186454801@qq.com> Date: Fri, 6 May 2022 20:20:19 +0800 Subject: [PATCH 07/11] update doc --- python/paddle/optimizer/lr.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/python/paddle/optimizer/lr.py b/python/paddle/optimizer/lr.py index 5d9f95b8971f5..1cb6eaebc34c0 100644 --- a/python/paddle/optimizer/lr.py +++ b/python/paddle/optimizer/lr.py @@ -1604,7 +1604,7 @@ class OneCycleLR(LRScheduler): Please note that the default behaviour of this scheduler follows the fastai implementation of one cycle, which claims that “unpublished work has shown even better results by using only two phases”. - Set ``three_phase=True``, if you want the behaviour of this scheduler to be consistent with the paper. + Set ``three_phase=True`` , if you want the behaviour of this scheduler to be consistent with the paper. Also note that you should update learning rate each step. @@ -1622,8 +1622,8 @@ class OneCycleLR(LRScheduler): pct_start (float): The percentage of total steps, which used to increasing learning rate. Default: 0.3. anneal_strategy (str, optional): Strategy of adjusting learning rate.'cos' for cosine annealing, 'linear' for linear annealing. Default: 'cos'. - divide_factor (float, optional): Initial learning rate will be determined by initial_lr = max_lr/div_factor. Default: 25. - final_divide_factor (float, optional): Minimum learning rate will be determined by minimum = max_lr/final_divide_factor. Default: 1e4. + divide_factor (float, optional): Initial learning rate will be determined by initial_lr = max_learning_rate/divide_factor. Default: 25. + final_divide_factor (float, optional): Minimum learning rate will be determined by minimum = max_learning_rate/final_divide_factor. Default: 1e4. three_phase (bool, optional): Whether to use three phase. If ``True``: 1. 
The learning rate will first increase from initial learning rate to maximum learning rate. @@ -1640,8 +1640,10 @@ class OneCycleLR(LRScheduler): Examples: .. code-block:: python + import paddle import numpy as np + # train on default dynamic graph mode linear = paddle.nn.Linear(10, 10) scheduler = paddle.optimizer.lr.OneCycleLR(max_learning_rate=1.0, total_steps=100, verbose=True) @@ -1655,6 +1657,7 @@ class OneCycleLR(LRScheduler): sgd.step() sgd.clear_gradients() scheduler.step() # You should update learning rate each step + # train on static graph mode paddle.enable_static() main_prog = paddle.static.Program() @@ -1667,6 +1670,7 @@ class OneCycleLR(LRScheduler): scheduler = paddle.optimizer.lr.OneCycleLR(max_learning_rate=1.0, total_steps=100, verbose=True) sgd = paddle.optimizer.SGD(learning_rate=scheduler) sgd.minimize(loss) + exe = paddle.static.Executor() exe.run(start_prog) for epoch in range(5): From b7c4a4b2182715265141e57d5f628e64aacfc6c9 Mon Sep 17 00:00:00 2001 From: Asthestarsfalll <1186454801@qq.com> Date: Fri, 6 May 2022 20:48:00 +0800 Subject: [PATCH 08/11] update and polish --- python/paddle/optimizer/lr.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/optimizer/lr.py b/python/paddle/optimizer/lr.py index 1cb6eaebc34c0..7f0ee4296e932 100644 --- a/python/paddle/optimizer/lr.py +++ b/python/paddle/optimizer/lr.py @@ -1604,7 +1604,7 @@ class OneCycleLR(LRScheduler): Please note that the default behaviour of this scheduler follows the fastai implementation of one cycle, which claims that “unpublished work has shown even better results by using only two phases”. - Set ``three_phase=True`` , if you want the behaviour of this scheduler to be consistent with the paper. + If you want the behaviour of this scheduler to be consistent with the paper, please set ``three_phase=True`` . Also note that you should update learning rate each step. @@ -1777,7 +1777,7 @@ def __init__(self, }, ] - # Validate anneal_strategy + # Check anneal_strategy if anneal_strategy not in ['cos', 'linear']: raise ValueError( "'anneal_strategy' must by one of 'cos' or 'linear', but received {}". 
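
Aside from the wording fixes, PATCH 05/11 above also rewrites ``MultiplicativeDecay.get_lr`` to rebuild the learning rate from ``base_lr`` instead of chaining off the previously returned value. Below is a rough standalone sketch contrasting the two behaviours; ``base_lr``, ``lr_lambda`` and the epoch value are illustrative assumptions, not values taken from the patch.

    base_lr = 0.5
    lr_lambda = lambda epoch: 0.95   # illustrative multiplier per epoch

    def get_lr_old(last_lr, last_epoch):
        # Previous behaviour: the result chains off whatever last_lr happened to be.
        return last_lr * lr_lambda(last_epoch) if last_epoch > 0 else base_lr

    def get_lr_new(last_epoch):
        # PATCH 05 behaviour: recompute the cumulative product from base_lr,
        # so the value depends only on last_epoch rather than on the previous call.
        cur_lr = base_lr
        for epoch in range(1, last_epoch + 1):
            cur_lr = cur_lr * lr_lambda(epoch)
        return cur_lr

    print(get_lr_new(3))   # 0.5 * 0.95 ** 3 = 0.4286875
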
From 348932b02d6784ae276782b20341725d3fcc94fe Mon Sep 17 00:00:00 2001 From: Asthestarsfalll <1186454801@qq.com> Date: Tue, 10 May 2022 21:21:46 +0800 Subject: [PATCH 09/11] Refactor --- .../tests/unittests/test_lr_scheduler.py | 102 +++++----- python/paddle/optimizer/lr.py | 186 +++++++----------- 2 files changed, 125 insertions(+), 163 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_lr_scheduler.py b/python/paddle/fluid/tests/unittests/test_lr_scheduler.py index 7f1b585f88f36..68e17507a6223 100644 --- a/python/paddle/fluid/tests/unittests/test_lr_scheduler.py +++ b/python/paddle/fluid/tests/unittests/test_lr_scheduler.py @@ -322,48 +322,44 @@ def step_lr(epoch_num, learning_rate, step_size, gamma=0.1, verbose=False): def one_cycle_lr(epoch_num, - max_learning_rate, - total_steps=None, - epochs=None, - steps_per_epoch=None, - pct_start=0.3, + learning_rate, + total_steps, + scale_factor=25, + end_lr=0.0001, + phase_pct=0.3, anneal_strategy='cos', - divide_factor=25., - final_divide_factor=1e4, three_phase=False, verbose=False): - total_steps = epochs * steps_per_epoch if total_steps is None else total_steps - initial_lr = max_learning_rate / divide_factor - min_lr = initial_lr / final_divide_factor + max_lr = learning_rate * scale_factor if three_phase: _end_steps = [ - float(pct_start * total_steps) - 1, - float(2 * pct_start * total_steps) - 2, total_steps - 1 + float(phase_pct * total_steps) - 1, + float(2 * phase_pct * total_steps) - 2, total_steps - 1 ] _schedule_phases = [ { - 'start_lr': initial_lr, - 'end_lr': max_learning_rate, + 'start_lr': learning_rate, + 'end_lr': max_lr, }, { - 'start_lr': max_learning_rate, - 'end_lr': initial_lr, + 'start_lr': max_lr, + 'end_lr': learning_rate, }, { - 'start_lr': initial_lr, - 'end_lr': min_lr, + 'start_lr': learning_rate, + 'end_lr': end_lr, }, ] else: - _end_steps = [float(pct_start * total_steps) - 1, total_steps - 1] + _end_steps = [float(phase_pct * total_steps) - 1, total_steps - 1] _schedule_phases = [ { - 'start_lr': initial_lr, - 'end_lr': max_learning_rate, + 'start_lr': learning_rate, + 'end_lr': max_lr, }, { - 'start_lr': max_learning_rate, - 'end_lr': min_lr, + 'start_lr': max_lr, + 'end_lr': end_lr, }, ] @@ -536,24 +532,27 @@ def test_scheduler(self): paddle.optimizer.lr.MultiStepDecay( learning_rate=0.5, milestones=[1, 2, 3], gamma=2) with self.assertRaises(TypeError): - paddle.optimizer.lr.OneCycleLR( - max_learning_rate='test', total_steps=20) - with self.assertRaises(ValueError): - paddle.optimizer.lr.OneCycleLR(max_learning_rate=0.1) + paddle.optimizer.lr.OneCycleLR(learning_rate='test', total_steps=20) with self.assertRaises(TypeError): paddle.optimizer.lr.OneCycleLR( - max_learning_rate=0.1, total_steps='test') + learning_rate=0.1, total_steps=20, end_lr='test') with self.assertRaises(ValueError): paddle.optimizer.lr.OneCycleLR( - max_learning_rate=0.1, total_steps=-10) - with self.assertRaises(TypeError): - paddle.optimizer.lr.OneCycleLR(max_learning_rate=0.1, epochs='test') + learning_rate=0.1, total_steps=20, end_lr=-1) with self.assertRaises(TypeError): paddle.optimizer.lr.OneCycleLR( - max_learning_rate=0.1, epochs=1, steps_per_epoch='t') + learning_rate=0.1, total_steps='test') + with self.assertRaises(ValueError): + paddle.optimizer.lr.OneCycleLR(learning_rate=0.1, total_steps=-10) with self.assertRaises(ValueError): paddle.optimizer.lr.OneCycleLR( - max_learning_rate=0.1, total_steps=20, anneal_strategy='test') + learning_rate=0.1, total_steps=20, anneal_strategy='test') + with 
self.assertRaises(ValueError): + paddle.optimizer.lr.OneCycleLR( + learning_rate=0.1, + total_steps=20, + phase_pct=0.6, + three_phase=True) func_api_kwargs = [(noam_lr, paddle.optimizer.lr.NoamDecay, { "d_model": 0.01, @@ -615,37 +614,36 @@ def test_scheduler(self): "T_max": 10, "verbose": False }), (one_cycle_lr, paddle.optimizer.lr.OneCycleLR, { - "max_learning_rate": 0.5, + "learning_rate": 0.1, "total_steps": 20, - "pct_start": 0.3, + "scale_factor": 5, + "end_lr": 0.0001, "anneal_strategy": 'cos', - "divide_factor": 25., - "final_divide_factor": 1e4, + "phase_pct": 0.3, "three_phase": False, }), (one_cycle_lr, paddle.optimizer.lr.OneCycleLR, { - "max_learning_rate": 0.5, - "epochs": 10, - "steps_per_epoch": 2, - "pct_start": 0.2, + "learning_rate": 0.5, + "total_steps": 20, + "scale_factor": 10, + "end_lr": 0.001, "anneal_strategy": 'linear', - "divide_factor": 20., - "final_divide_factor": 1000, + "phase_pct": 0.4, "three_phase": False, }), (one_cycle_lr, paddle.optimizer.lr.OneCycleLR, { - "max_learning_rate": 1, + "learning_rate": 0.1, "total_steps": 20, - "pct_start": 0.4, + "scale_factor": 9, + "end_lr": 0.0001, "anneal_strategy": 'cos', - "divide_factor": 15., - "final_divide_factor": 100, + "phase_pct": 0.3, "three_phase": True, }), (one_cycle_lr, paddle.optimizer.lr.OneCycleLR, { - "max_learning_rate": 0.5, - "total_steps": 40, - "pct_start": 0.5, + "learning_rate": 0.3, + "total_steps": 20, + "scale_factor": 25, + "end_lr": 0.0005, "anneal_strategy": 'linear', - "divide_factor": 5., - "final_divide_factor": 50, + "phase_pct": 0.2, "three_phase": True, })] diff --git a/python/paddle/optimizer/lr.py b/python/paddle/optimizer/lr.py index 7f0ee4296e932..d79b59ed78837 100644 --- a/python/paddle/optimizer/lr.py +++ b/python/paddle/optimizer/lr.py @@ -1608,22 +1608,15 @@ class OneCycleLR(LRScheduler): Also note that you should update learning rate each step. - This implementation was adapted from `there `_. - Args: - max_learning_rate (float): Upper boundary of learning rate during training. - Functionally, it defines the initial learning rate and the minimum learning rate by ``divide_factor`` and - ``final_divide_factor`` respectively. - total_steps (int, optional): Number of total training steps. - Note that one of total_steps and (epochs, steps_per_epoch) must be specified. - If ``total_steps`` is not specified, it will be determined by ``epochs`` and ``steps_per_epoch``. Default: None. - epochs (int, optional): Number of total training epochs. Default: None. - steps_per_epoch (int, optional): Number of training steps for each epoch. Default: None. - pct_start (float): The percentage of total steps, which used to increasing learning rate. Default: 0.3. + learning_rate (float): The initial learning rate. It is a python float number. + Functionally, it defines the maximum learning rate by ``scale_factor`` . + total_steps (int): Number of total training steps. + scale_factor (float): Maximum learning rate will be determined by maximum_lr = learning_rate * scale_factor. Default: 25. + end_lr (float, optional): The minimum learning rate of schedule, it should be much less than initial learning rate. + phase_pct (float): The percentage of total steps which used to increasing learning rate. Default: 0.3. anneal_strategy (str, optional): Strategy of adjusting learning rate.'cos' for cosine annealing, 'linear' for linear annealing. Default: 'cos'. - divide_factor (float, optional): Initial learning rate will be determined by initial_lr = max_learning_rate/divide_factor. Default: 25. 
- final_divide_factor (float, optional): Minimum learning rate will be determined by minimum = max_learning_rate/final_divide_factor. Default: 1e4. three_phase (bool, optional): Whether to use three phase. If ``True``: 1. The learning rate will first increase from initial learning rate to maximum learning rate. @@ -1646,7 +1639,7 @@ class OneCycleLR(LRScheduler): # train on default dynamic graph mode linear = paddle.nn.Linear(10, 10) - scheduler = paddle.optimizer.lr.OneCycleLR(max_learning_rate=1.0, total_steps=100, verbose=True) + scheduler = paddle.optimizer.lr.OneCycleLR(learning_rate=1.0, total_steps=100, scale_factor=5, verbose=True) sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters()) for epoch in range(5): for batch_id in range(20): @@ -1667,7 +1660,7 @@ class OneCycleLR(LRScheduler): y = paddle.static.data(name='y', shape=[None, 4, 5]) z = paddle.static.nn.fc(x, 100) loss = paddle.mean(z) - scheduler = paddle.optimizer.lr.OneCycleLR(max_learning_rate=1.0, total_steps=100, verbose=True) + scheduler = paddle.optimizer.lr.OneCycleLR(learning_rate=1.0, total_steps=100, scale_factor=5, verbose=True) sgd = paddle.optimizer.SGD(learning_rate=scheduler) sgd.minimize(loss) @@ -1686,96 +1679,71 @@ class OneCycleLR(LRScheduler): """ def __init__(self, - max_learning_rate, - total_steps=None, - epochs=None, - steps_per_epoch=None, - pct_start=0.3, + learning_rate, + total_steps, + scale_factor=25, + end_lr=0.0001, + phase_pct=0.3, anneal_strategy='cos', - divide_factor=25., - final_divide_factor=1e4, three_phase=False, last_epoch=-1, verbose=False): - # Check type of max_learning_rate - if not isinstance(max_learning_rate, (float, int)): + # Check type and value of end_lr + if not isinstance(end_lr, (float, int)): raise TypeError( - "The type of learning rate must be float, but received {}". - format(type(max_learning_rate))) + "'total_step' must be 'float' or 'int', but received {}".format( + type(total_steps))) + if end_lr < 0: + raise ValueError("'end_lr' must be a positive integer.") + # Check type and value of total_steps - if total_steps is None and epochs is None and steps_per_epoch is None: - raise ValueError( - "either total_steps or (epochs, steps_per_epoch) must be specified" - ) - elif total_steps is not None: - if not isinstance(total_steps, int): - raise TypeError("'total_step' must be 'int', but received {}". - format(type(total_steps))) - if total_steps <= 0: - raise ValueError("'total_step' must be a positive integer.") - self.total_steps = total_steps - else: - # Check type and value of epochs and steps_per_epochs - if not isinstance(epochs, int): - raise TypeError("'epochs' must be 'int', but received {}". - format(type(epochs))) - if not isinstance(steps_per_epoch, int): - raise TypeError( - "'steps_per_epoch', must be 'int', but received {}".format( - type(steps_per_epoch))) - if epochs < 0: - raise ValueError("'epochs' must be a positive integer.") - if steps_per_epoch < 0: - raise ValueError( - "'steps_per_epoch' must be a positive integer.") - self.total_steps = epochs * steps_per_epoch + if not isinstance(total_steps, int): + raise TypeError("'total_step' must be 'int', but received {}". + format(type(total_steps))) + if total_steps <= 0: + raise ValueError("'total_step' must be a positive integer.") + self.total_steps = total_steps + # Check type and value of pac_start - if not isinstance(pct_start, float): - raise TypeError("'pct_start' must be 'float', but received {}". 
- format(type(pct_start))) - if pct_start < 0 or pct_start > 1: + if not isinstance(phase_pct, float): + raise TypeError("'phase_pct' must be 'float', but received {}". + format(type(phase_pct))) + if phase_pct < 0 or phase_pct > 1: raise ValueError( - "'pct_start' must be between 0 and 1, but received {}".format( - pct_start)) + "'phase_pct' must be between 0 and 1, but received {}".format( + phase_pct)) - max_lr = max_learning_rate - initial_lr = max_lr / divide_factor - min_lr = initial_lr / final_divide_factor + max_lr = learning_rate * scale_factor + min_lr = float(end_lr) if three_phase: - self._end_steps = [ - float(pct_start * self.total_steps) - 1, - float(2 * pct_start * self.total_steps) - 2, - self.total_steps - 1 + if phase_pct >= 0.5: + raise ValueError( + "When three_phase is True, 'phase_pct' must be smaller than 0.5" + ) + self._start_steps = [ + 0, + phase_pct * self.total_steps - 1, + 2 * phase_pct * self.total_steps - 2, + self.total_steps - 1, ] - self._schedule_phases = [ - { - 'start_lr': initial_lr, - 'end_lr': max_lr, - }, - { - 'start_lr': max_lr, - 'end_lr': initial_lr, - }, - { - 'start_lr': initial_lr, - 'end_lr': min_lr, - }, + self._steps_size = [ + self._start_steps[1] - self._start_steps[0], + self._start_steps[2] - self._start_steps[1], + self._start_steps[3] - self._start_steps[2], + self._start_steps[3] - self._start_steps[2], ] + self._lr_config = [learning_rate, max_lr, learning_rate, min_lr] else: - self._end_steps = [ - float(pct_start * self.total_steps) - 1, self.total_steps - 1 + self._start_steps = [ + 0, phase_pct * self.total_steps - 1, self.total_steps - 1 ] - self._schedule_phases = [ - { - 'start_lr': initial_lr, - 'end_lr': max_lr, - }, - { - 'start_lr': max_lr, - 'end_lr': min_lr, - }, + self._steps_size = [ + self._start_steps[1] - self._start_steps[0], + self._start_steps[2] - self._start_steps[1], + self._start_steps[2] - self._start_steps[1], ] + self._lr_config = [learning_rate, max_lr, min_lr] # Check anneal_strategy if anneal_strategy not in ['cos', 'linear']: @@ -1783,35 +1751,31 @@ def __init__(self, "'anneal_strategy' must by one of 'cos' or 'linear', but received {}". format(anneal_strategy)) elif anneal_strategy == 'cos': - self.anneal_func = self._annealing_cos + self.anneal_func = self._cos_annealing elif anneal_strategy == 'linear': - self.anneal_func = self._annealing_linear + self.anneal_func = self._linear_annealing - super(OneCycleLR, self).__init__(initial_lr, last_epoch, verbose) + super(OneCycleLR, self).__init__(learning_rate, last_epoch, verbose) - def _annealing_cos(self, start, end, pct): + def _cos_annealing(self, start_lr, end_lr, pct): cos_out = math.cos(math.pi * pct) + 1 - return end + (start - end) / 2.0 * cos_out + return end_lr + (start_lr - end_lr) / 2.0 * cos_out - def _annealing_linear(self, start, end, pct): - return (end - start) * pct + start + def _linear_annealing(self, start_lr, end_lr, pct): + return (end_lr - start_lr) * pct + start_lr def get_lr(self): - step_num = self.last_epoch + current_step = self.last_epoch - if step_num > self.total_steps: + if current_step > self.total_steps: raise ValueError( "Tried to step {} times. 
The specified number of total steps is {}" - .format(step_num + 1, self.total_steps)) - - start_step = 0 - for i, phase in enumerate(self._schedule_phases): - end_step = self._end_steps[i] - if step_num <= end_step or i == len(self._schedule_phases) - 1: - pct = (step_num - start_step) / (end_step - start_step) - computed_lr = self.anneal_func(phase['start_lr'], - phase['end_lr'], pct) - break - start_step = end_step - - return computed_lr + .format(current_step, self.total_steps)) + + for (i, (start_step, step_size) + ) in enumerate(zip(self._start_steps, self._steps_size)): + if current_step <= (start_step + step_size + ) or i == len(self._lr_config) - 2: + percentage = (current_step - start_step) / step_size + return self.anneal_func(self._lr_config[i], + self._lr_config[i + 1], percentage) From 70f97a841459e76528c5107c98b76a4882a3397f Mon Sep 17 00:00:00 2001 From: Asthestarsfalll <1186454801@qq.com> Date: Wed, 11 May 2022 14:43:04 +0800 Subject: [PATCH 10/11] update --- .../tests/unittests/test_lr_scheduler.py | 57 ++++++----- python/paddle/optimizer/lr.py | 94 ++++++++++++------- 2 files changed, 89 insertions(+), 62 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_lr_scheduler.py b/python/paddle/fluid/tests/unittests/test_lr_scheduler.py index 68e17507a6223..ee8e5834967cc 100644 --- a/python/paddle/fluid/tests/unittests/test_lr_scheduler.py +++ b/python/paddle/fluid/tests/unittests/test_lr_scheduler.py @@ -322,15 +322,15 @@ def step_lr(epoch_num, learning_rate, step_size, gamma=0.1, verbose=False): def one_cycle_lr(epoch_num, - learning_rate, + max_learning_rate, total_steps, - scale_factor=25, + divide_factor=25, end_lr=0.0001, phase_pct=0.3, anneal_strategy='cos', three_phase=False, verbose=False): - max_lr = learning_rate * scale_factor + initial_lr = max_learning_rate / divide_factor if three_phase: _end_steps = [ float(phase_pct * total_steps) - 1, @@ -338,15 +338,15 @@ def one_cycle_lr(epoch_num, ] _schedule_phases = [ { - 'start_lr': learning_rate, - 'end_lr': max_lr, + 'start_lr': initial_lr, + 'end_lr': max_learning_rate, }, { - 'start_lr': max_lr, - 'end_lr': learning_rate, + 'start_lr': max_learning_rate, + 'end_lr': initial_lr, }, { - 'start_lr': learning_rate, + 'start_lr': initial_lr, 'end_lr': end_lr, }, ] @@ -354,11 +354,11 @@ def one_cycle_lr(epoch_num, _end_steps = [float(phase_pct * total_steps) - 1, total_steps - 1] _schedule_phases = [ { - 'start_lr': learning_rate, - 'end_lr': max_lr, + 'start_lr': initial_lr, + 'end_lr': max_learning_rate, }, { - 'start_lr': max_lr, + 'start_lr': max_learning_rate, 'end_lr': end_lr, }, ] @@ -532,24 +532,29 @@ def test_scheduler(self): paddle.optimizer.lr.MultiStepDecay( learning_rate=0.5, milestones=[1, 2, 3], gamma=2) with self.assertRaises(TypeError): - paddle.optimizer.lr.OneCycleLR(learning_rate='test', total_steps=20) + paddle.optimizer.lr.OneCycleLR( + max_learning_rate='test', total_steps=20) + with self.assertRaises(ValueError): + paddle.optimizer.lr.OneCycleLR( + max_learning_rate=-1.5, total_steps=20) with self.assertRaises(TypeError): paddle.optimizer.lr.OneCycleLR( - learning_rate=0.1, total_steps=20, end_lr='test') + max_learning_rate=0.1, total_steps=20, end_lr='test') with self.assertRaises(ValueError): paddle.optimizer.lr.OneCycleLR( - learning_rate=0.1, total_steps=20, end_lr=-1) + max_learning_rate=0.1, total_steps=20, end_lr=-1) with self.assertRaises(TypeError): paddle.optimizer.lr.OneCycleLR( - learning_rate=0.1, total_steps='test') + max_learning_rate=0.1, total_steps='test') with 
self.assertRaises(ValueError): - paddle.optimizer.lr.OneCycleLR(learning_rate=0.1, total_steps=-10) + paddle.optimizer.lr.OneCycleLR( + max_learning_rate=0.1, total_steps=-10) with self.assertRaises(ValueError): paddle.optimizer.lr.OneCycleLR( - learning_rate=0.1, total_steps=20, anneal_strategy='test') + max_learning_rate=0.1, total_steps=20, anneal_strategy='test') with self.assertRaises(ValueError): paddle.optimizer.lr.OneCycleLR( - learning_rate=0.1, + max_learning_rate=0.1, total_steps=20, phase_pct=0.6, three_phase=True) @@ -614,33 +619,33 @@ def test_scheduler(self): "T_max": 10, "verbose": False }), (one_cycle_lr, paddle.optimizer.lr.OneCycleLR, { - "learning_rate": 0.1, + "max_learning_rate": 0.1, "total_steps": 20, - "scale_factor": 5, + "divide_factor": 5, "end_lr": 0.0001, "anneal_strategy": 'cos', "phase_pct": 0.3, "three_phase": False, }), (one_cycle_lr, paddle.optimizer.lr.OneCycleLR, { - "learning_rate": 0.5, + "max_learning_rate": 0.5, "total_steps": 20, - "scale_factor": 10, + "divide_factor": 10, "end_lr": 0.001, "anneal_strategy": 'linear', "phase_pct": 0.4, "three_phase": False, }), (one_cycle_lr, paddle.optimizer.lr.OneCycleLR, { - "learning_rate": 0.1, + "max_learning_rate": 1.0, "total_steps": 20, - "scale_factor": 9, + "divide_factor": 9, "end_lr": 0.0001, "anneal_strategy": 'cos', "phase_pct": 0.3, "three_phase": True, }), (one_cycle_lr, paddle.optimizer.lr.OneCycleLR, { - "learning_rate": 0.3, + "max_learning_rate": 0.3, "total_steps": 20, - "scale_factor": 25, + "divide_factor": 25, "end_lr": 0.0005, "anneal_strategy": 'linear', "phase_pct": 0.2, diff --git a/python/paddle/optimizer/lr.py b/python/paddle/optimizer/lr.py index d79b59ed78837..c3de4b781174b 100644 --- a/python/paddle/optimizer/lr.py +++ b/python/paddle/optimizer/lr.py @@ -1609,11 +1609,11 @@ class OneCycleLR(LRScheduler): Also note that you should update learning rate each step. Args: - learning_rate (float): The initial learning rate. It is a python float number. - Functionally, it defines the maximum learning rate by ``scale_factor`` . + max_learning_rate (float): The maximum learning rate. It is a python float number. + Functionally, it defines the initial learning rate by ``divide_factor`` . total_steps (int): Number of total training steps. - scale_factor (float): Maximum learning rate will be determined by maximum_lr = learning_rate * scale_factor. Default: 25. - end_lr (float, optional): The minimum learning rate of schedule, it should be much less than initial learning rate. + divide_factor (float): Initial learning rate will be determined by initial_learning_rate = max_learning_rate / divide_factor. Default: 25. + end_lr (float, optional): The minimum learning rate during training, it should be much less than initial learning rate. phase_pct (float): The percentage of total steps which used to increasing learning rate. Default: 0.3. anneal_strategy (str, optional): Strategy of adjusting learning rate.'cos' for cosine annealing, 'linear' for linear annealing. Default: 'cos'. 
@@ -1639,7 +1639,7 @@ class OneCycleLR(LRScheduler): # train on default dynamic graph mode linear = paddle.nn.Linear(10, 10) - scheduler = paddle.optimizer.lr.OneCycleLR(learning_rate=1.0, total_steps=100, scale_factor=5, verbose=True) + scheduler = paddle.optimizer.lr.OneCycleLR(max_learning_rate=1.0, total_steps=100, verbose=True) sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters()) for epoch in range(5): for batch_id in range(20): @@ -1660,7 +1660,7 @@ class OneCycleLR(LRScheduler): y = paddle.static.data(name='y', shape=[None, 4, 5]) z = paddle.static.nn.fc(x, 100) loss = paddle.mean(z) - scheduler = paddle.optimizer.lr.OneCycleLR(learning_rate=1.0, total_steps=100, scale_factor=5, verbose=True) + scheduler = paddle.optimizer.lr.OneCycleLR(max_learning_rate=1.0, total_steps=100, verbose=True) sgd = paddle.optimizer.SGD(learning_rate=scheduler) sgd.minimize(loss) @@ -1679,19 +1679,27 @@ class OneCycleLR(LRScheduler): """ def __init__(self, - learning_rate, + max_learning_rate, total_steps, - scale_factor=25, + divide_factor=25., end_lr=0.0001, phase_pct=0.3, anneal_strategy='cos', three_phase=False, last_epoch=-1, verbose=False): + # Check type and value of max_learning_rate + if not isinstance(max_learning_rate, (float, int)): + raise TypeError( + "'max_learning_rate' must be 'float' or 'int', but received {}". + format(type(total_steps))) + if max_learning_rate < 0: + raise ValueError("'max_learning_rate' must be a positive integer.") + # Check type and value of end_lr if not isinstance(end_lr, (float, int)): raise TypeError( - "'total_step' must be 'float' or 'int', but received {}".format( + "'end_lr' must be 'float' or 'int', but received {}".format( type(total_steps))) if end_lr < 0: raise ValueError("'end_lr' must be a positive integer.") @@ -1713,49 +1721,62 @@ def __init__(self, "'phase_pct' must be between 0 and 1, but received {}".format( phase_pct)) - max_lr = learning_rate * scale_factor + # Check type and value of divide_factor + if not isinstance(divide_factor, (float, int)): + raise TypeError( + "'divide_factor' must be 'float' or 'int', but received {}". + format(type(divide_factor))) + + initial_lr = max_learning_rate / float(divide_factor) min_lr = float(end_lr) if three_phase: if phase_pct >= 0.5: raise ValueError( - "When three_phase is True, 'phase_pct' must be smaller than 0.5" + "When three_phase is True, 'phase_pct' must be less than 0.5" ) - self._start_steps = [ + # start step and end step of each phase. + self._step_config = [ 0, phase_pct * self.total_steps - 1, 2 * phase_pct * self.total_steps - 2, self.total_steps - 1, + self.total_steps - 1, # for the last step. ] + # step size of each phase. self._steps_size = [ - self._start_steps[1] - self._start_steps[0], - self._start_steps[2] - self._start_steps[1], - self._start_steps[3] - self._start_steps[2], - self._start_steps[3] - self._start_steps[2], + self._step_config[1] - self._step_config[0], + self._step_config[2] - self._step_config[1], + self._step_config[3] - self._step_config[2], + self._step_config[3] - + self._step_config[2], # for the last step. + ] + # start lr and end lr of each phase. 
+ self._lr_config = [ + initial_lr, max_learning_rate, initial_lr, min_lr ] - self._lr_config = [learning_rate, max_lr, learning_rate, min_lr] else: - self._start_steps = [ - 0, phase_pct * self.total_steps - 1, self.total_steps - 1 + self._step_config = [ + 0, phase_pct * self.total_steps - 1, self.total_steps - 1, + self.total_steps - 1 ] self._steps_size = [ - self._start_steps[1] - self._start_steps[0], - self._start_steps[2] - self._start_steps[1], - self._start_steps[2] - self._start_steps[1], + self._step_config[1] - self._step_config[0], + self._step_config[2] - self._step_config[1], + self._step_config[2] - self._step_config[1], ] - self._lr_config = [learning_rate, max_lr, min_lr] + self._lr_config = [initial_lr, max_learning_rate, min_lr] # Check anneal_strategy - if anneal_strategy not in ['cos', 'linear']: - raise ValueError( - "'anneal_strategy' must by one of 'cos' or 'linear', but received {}". - format(anneal_strategy)) - elif anneal_strategy == 'cos': + if anneal_strategy == 'cos': self.anneal_func = self._cos_annealing elif anneal_strategy == 'linear': self.anneal_func = self._linear_annealing - - super(OneCycleLR, self).__init__(learning_rate, last_epoch, verbose) + else: + raise ValueError( + "'anneal_strategy' must by one of 'cos' or 'linear', but received {}". + format(anneal_strategy)) + super(OneCycleLR, self).__init__(initial_lr, last_epoch, verbose) def _cos_annealing(self, start_lr, end_lr, pct): cos_out = math.cos(math.pi * pct) + 1 @@ -1769,13 +1790,14 @@ def get_lr(self): if current_step > self.total_steps: raise ValueError( - "Tried to step {} times. The specified number of total steps is {}" + "Tried to step {} times. However the number of total steps is {}" .format(current_step, self.total_steps)) - for (i, (start_step, step_size) - ) in enumerate(zip(self._start_steps, self._steps_size)): - if current_step <= (start_step + step_size - ) or i == len(self._lr_config) - 2: - percentage = (current_step - start_step) / step_size + for (i, (end_step, step_size) + ) in enumerate(zip(self._step_config[1:], self._steps_size)): + # i == len(self._lr_config) - 2 catch the last step, otherwise it will return None. + if current_step <= end_step or i == len(self._lr_config) - 2: + # self._step_config[i] means start step of a phase. 
+ percentage = (current_step - self._step_config[i]) / step_size return self.anneal_func(self._lr_config[i], self._lr_config[i + 1], percentage) From ebb04e289c4bb75e1886fe77187a8b866e2b363b Mon Sep 17 00:00:00 2001 From: Asthestarsfalll <1186454801@qq.com> Date: Fri, 13 May 2022 17:40:47 +0800 Subject: [PATCH 11/11] change end_lr to end_learning_rate --- .../fluid/tests/unittests/test_lr_scheduler.py | 18 +++++++++--------- python/paddle/optimizer/lr.py | 18 +++++++++--------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_lr_scheduler.py b/python/paddle/fluid/tests/unittests/test_lr_scheduler.py index ee8e5834967cc..96a818549e700 100644 --- a/python/paddle/fluid/tests/unittests/test_lr_scheduler.py +++ b/python/paddle/fluid/tests/unittests/test_lr_scheduler.py @@ -325,7 +325,7 @@ def one_cycle_lr(epoch_num, max_learning_rate, total_steps, divide_factor=25, - end_lr=0.0001, + end_learning_rate=0.0001, phase_pct=0.3, anneal_strategy='cos', three_phase=False, @@ -347,7 +347,7 @@ def one_cycle_lr(epoch_num, }, { 'start_lr': initial_lr, - 'end_lr': end_lr, + 'end_lr': end_learning_rate, }, ] else: @@ -359,7 +359,7 @@ def one_cycle_lr(epoch_num, }, { 'start_lr': max_learning_rate, - 'end_lr': end_lr, + 'end_lr': end_learning_rate, }, ] @@ -539,10 +539,10 @@ def test_scheduler(self): max_learning_rate=-1.5, total_steps=20) with self.assertRaises(TypeError): paddle.optimizer.lr.OneCycleLR( - max_learning_rate=0.1, total_steps=20, end_lr='test') + max_learning_rate=0.1, total_steps=20, end_learning_rate='test') with self.assertRaises(ValueError): paddle.optimizer.lr.OneCycleLR( - max_learning_rate=0.1, total_steps=20, end_lr=-1) + max_learning_rate=0.1, total_steps=20, end_learning_rate=-1) with self.assertRaises(TypeError): paddle.optimizer.lr.OneCycleLR( max_learning_rate=0.1, total_steps='test') @@ -622,7 +622,7 @@ def test_scheduler(self): "max_learning_rate": 0.1, "total_steps": 20, "divide_factor": 5, - "end_lr": 0.0001, + "end_learning_rate": 0.0001, "anneal_strategy": 'cos', "phase_pct": 0.3, "three_phase": False, @@ -630,7 +630,7 @@ def test_scheduler(self): "max_learning_rate": 0.5, "total_steps": 20, "divide_factor": 10, - "end_lr": 0.001, + "end_learning_rate": 0.001, "anneal_strategy": 'linear', "phase_pct": 0.4, "three_phase": False, @@ -638,7 +638,7 @@ def test_scheduler(self): "max_learning_rate": 1.0, "total_steps": 20, "divide_factor": 9, - "end_lr": 0.0001, + "end_learning_rate": 0.0001, "anneal_strategy": 'cos', "phase_pct": 0.3, "three_phase": True, @@ -646,7 +646,7 @@ def test_scheduler(self): "max_learning_rate": 0.3, "total_steps": 20, "divide_factor": 25, - "end_lr": 0.0005, + "end_learning_rate": 0.0005, "anneal_strategy": 'linear', "phase_pct": 0.2, "three_phase": True, diff --git a/python/paddle/optimizer/lr.py b/python/paddle/optimizer/lr.py index c3de4b781174b..12b8272707bd8 100644 --- a/python/paddle/optimizer/lr.py +++ b/python/paddle/optimizer/lr.py @@ -1613,7 +1613,7 @@ class OneCycleLR(LRScheduler): Functionally, it defines the initial learning rate by ``divide_factor`` . total_steps (int): Number of total training steps. divide_factor (float): Initial learning rate will be determined by initial_learning_rate = max_learning_rate / divide_factor. Default: 25. - end_lr (float, optional): The minimum learning rate during training, it should be much less than initial learning rate. 
+ end_learning_rate (float, optional): The minimum learning rate during training, it should be much less than initial learning rate. phase_pct (float): The percentage of total steps which used to increasing learning rate. Default: 0.3. anneal_strategy (str, optional): Strategy of adjusting learning rate.'cos' for cosine annealing, 'linear' for linear annealing. Default: 'cos'. @@ -1682,7 +1682,7 @@ def __init__(self, max_learning_rate, total_steps, divide_factor=25., - end_lr=0.0001, + end_learning_rate=0.0001, phase_pct=0.3, anneal_strategy='cos', three_phase=False, @@ -1696,13 +1696,13 @@ def __init__(self, if max_learning_rate < 0: raise ValueError("'max_learning_rate' must be a positive integer.") - # Check type and value of end_lr - if not isinstance(end_lr, (float, int)): + # Check type and value of end_learning_rate + if not isinstance(end_learning_rate, (float, int)): raise TypeError( - "'end_lr' must be 'float' or 'int', but received {}".format( - type(total_steps))) - if end_lr < 0: - raise ValueError("'end_lr' must be a positive integer.") + "'end_learning_rate' must be 'float' or 'int', but received {}". + format(type(total_steps))) + if end_learning_rate < 0: + raise ValueError("'end_learning_rate' must be a positive integer.") # Check type and value of total_steps if not isinstance(total_steps, int): @@ -1728,7 +1728,7 @@ def __init__(self, format(type(divide_factor))) initial_lr = max_learning_rate / float(divide_factor) - min_lr = float(end_lr) + min_lr = float(end_learning_rate) if three_phase: if phase_pct >= 0.5: