diff --git a/python/paddle/fluid/tests/unittests/test_lr_scheduler.py b/python/paddle/fluid/tests/unittests/test_lr_scheduler.py index 60dd4948f996e..96a818549e700 100644 --- a/python/paddle/fluid/tests/unittests/test_lr_scheduler.py +++ b/python/paddle/fluid/tests/unittests/test_lr_scheduler.py @@ -321,6 +321,70 @@ def step_lr(epoch_num, learning_rate, step_size, gamma=0.1, verbose=False): return learning_rate * math.pow(gamma, epoch_num // step_size) +def one_cycle_lr(epoch_num, + max_learning_rate, + total_steps, + divide_factor=25, + end_learning_rate=0.0001, + phase_pct=0.3, + anneal_strategy='cos', + three_phase=False, + verbose=False): + initial_lr = max_learning_rate / divide_factor + if three_phase: + _end_steps = [ + float(phase_pct * total_steps) - 1, + float(2 * phase_pct * total_steps) - 2, total_steps - 1 + ] + _schedule_phases = [ + { + 'start_lr': initial_lr, + 'end_lr': max_learning_rate, + }, + { + 'start_lr': max_learning_rate, + 'end_lr': initial_lr, + }, + { + 'start_lr': initial_lr, + 'end_lr': end_learning_rate, + }, + ] + else: + _end_steps = [float(phase_pct * total_steps) - 1, total_steps - 1] + _schedule_phases = [ + { + 'start_lr': initial_lr, + 'end_lr': max_learning_rate, + }, + { + 'start_lr': max_learning_rate, + 'end_lr': end_learning_rate, + }, + ] + + if anneal_strategy == 'cos': + + def anneal_func(start, end, pct): + cos_out = math.cos(math.pi * pct) + 1 + return end + (start - end) / 2.0 * cos_out + else: + + def anneal_func(start, end, pct): + return (end - start) * pct + start + + start_step = 0 + for i, phase in enumerate(_schedule_phases): + end_step = _end_steps[i] + if epoch_num <= end_step or i == len(_schedule_phases) - 1: + pct = (epoch_num - start_step) / (end_step - start_step) + computed_lr = anneal_func(phase['start_lr'], phase['end_lr'], pct) + break + start_step = end_step + + return computed_lr + + class TestLRScheduler(unittest.TestCase): def _test_static(self, python_func, paddle_api, kwarg, place): scheduler = paddle_api(**kwarg) @@ -467,6 +531,33 @@ def test_scheduler(self): with self.assertRaises(ValueError): paddle.optimizer.lr.MultiStepDecay( learning_rate=0.5, milestones=[1, 2, 3], gamma=2) + with self.assertRaises(TypeError): + paddle.optimizer.lr.OneCycleLR( + max_learning_rate='test', total_steps=20) + with self.assertRaises(ValueError): + paddle.optimizer.lr.OneCycleLR( + max_learning_rate=-1.5, total_steps=20) + with self.assertRaises(TypeError): + paddle.optimizer.lr.OneCycleLR( + max_learning_rate=0.1, total_steps=20, end_learning_rate='test') + with self.assertRaises(ValueError): + paddle.optimizer.lr.OneCycleLR( + max_learning_rate=0.1, total_steps=20, end_learning_rate=-1) + with self.assertRaises(TypeError): + paddle.optimizer.lr.OneCycleLR( + max_learning_rate=0.1, total_steps='test') + with self.assertRaises(ValueError): + paddle.optimizer.lr.OneCycleLR( + max_learning_rate=0.1, total_steps=-10) + with self.assertRaises(ValueError): + paddle.optimizer.lr.OneCycleLR( + max_learning_rate=0.1, total_steps=20, anneal_strategy='test') + with self.assertRaises(ValueError): + paddle.optimizer.lr.OneCycleLR( + max_learning_rate=0.1, + total_steps=20, + phase_pct=0.6, + three_phase=True) func_api_kwargs = [(noam_lr, paddle.optimizer.lr.NoamDecay, { "d_model": 0.01, @@ -527,6 +618,38 @@ def test_scheduler(self): "learning_rate": 0.5, "T_max": 10, "verbose": False + }), (one_cycle_lr, paddle.optimizer.lr.OneCycleLR, { + "max_learning_rate": 0.1, + "total_steps": 20, + "divide_factor": 5, + "end_learning_rate": 0.0001, + "anneal_strategy": 'cos', + "phase_pct": 0.3, + "three_phase": False, + }), (one_cycle_lr, paddle.optimizer.lr.OneCycleLR, { + "max_learning_rate": 0.5, + "total_steps": 20, + "divide_factor": 10, + "end_learning_rate": 0.001, + "anneal_strategy": 'linear', + "phase_pct": 0.4, + "three_phase": False, + }), (one_cycle_lr, paddle.optimizer.lr.OneCycleLR, { + "max_learning_rate": 1.0, + "total_steps": 20, + "divide_factor": 9, + "end_learning_rate": 0.0001, + "anneal_strategy": 'cos', + "phase_pct": 0.3, + "three_phase": True, + }), (one_cycle_lr, paddle.optimizer.lr.OneCycleLR, { + "max_learning_rate": 0.3, + "total_steps": 20, + "divide_factor": 25, + "end_learning_rate": 0.0005, + "anneal_strategy": 'linear', + "phase_pct": 0.2, + "three_phase": True, })] for python_func, paddle_api, kwarg in func_api_kwargs: diff --git a/python/paddle/optimizer/lr.py b/python/paddle/optimizer/lr.py index ea4349bc0b2c5..12b8272707bd8 100644 --- a/python/paddle/optimizer/lr.py +++ b/python/paddle/optimizer/lr.py @@ -33,7 +33,8 @@ 'LambdaDecay', 'ReduceOnPlateau', 'CosineAnnealingDecay', - 'MultiplicativeDecay' + 'MultiplicativeDecay', + 'OneCycleLR' ] @@ -1591,3 +1592,212 @@ def get_lr(self): for epoch in range(1, self.last_epoch + 1): cur_lr = cur_lr * self.lr_lambda(epoch) return cur_lr + + +class OneCycleLR(LRScheduler): + r""" + Sets the learning rate according to the one cycle learning rate scheduler. + The scheduler adjusts the learning rate from an initial learning rate to the maximum learning rate and then + from that maximum learning rate to the minimum learning rate, which is much less than the initial learning rate. + + It has been proposed in `Super-Convergence: Very Fast Training of Neural Networks Using Large Learning Rates `_. + + Please note that the default behaviour of this scheduler follows the fastai implementation of one cycle, + which claims that “unpublished work has shown even better results by using only two phases”. + If you want the behaviour of this scheduler to be consistent with the paper, please set ``three_phase=True`` . + + Also note that you should update learning rate each step. + + Args: + max_learning_rate (float): The maximum learning rate. It is a python float number. + Functionally, it defines the initial learning rate by ``divide_factor`` . + total_steps (int): Number of total training steps. + divide_factor (float): Initial learning rate will be determined by initial_learning_rate = max_learning_rate / divide_factor. Default: 25. + end_learning_rate (float, optional): The minimum learning rate during training, it should be much less than initial learning rate. + phase_pct (float): The percentage of total steps which used to increasing learning rate. Default: 0.3. + anneal_strategy (str, optional): Strategy of adjusting learning rate.'cos' for cosine annealing, + 'linear' for linear annealing. Default: 'cos'. + three_phase (bool, optional): Whether to use three phase. + If ``True``: + 1. The learning rate will first increase from initial learning rate to maximum learning rate. + 2. Then it will decrease to initial learning rate. Number of step in this phase is the same as the one in first phase. + 3. Finally, it will decrease to minimum learning rate which is much less than initial learning rate. + If ``False``: + 1. The learning rate will increase to maximum learning rate. + 2. Then it will directly decrease to minimum learning rate. + last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate. + verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` . + + Returns: + ``OneCycleLR`` instance to schedule learning rate. + + Examples: + .. code-block:: python + + import paddle + import numpy as np + + # train on default dynamic graph mode + linear = paddle.nn.Linear(10, 10) + scheduler = paddle.optimizer.lr.OneCycleLR(max_learning_rate=1.0, total_steps=100, verbose=True) + sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters()) + for epoch in range(5): + for batch_id in range(20): + x = paddle.uniform([10, 10]) + out = linear(x) + loss = paddle.mean(out) + loss.backward() + sgd.step() + sgd.clear_gradients() + scheduler.step() # You should update learning rate each step + + # train on static graph mode + paddle.enable_static() + main_prog = paddle.static.Program() + start_prog = paddle.static.Program() + with paddle.static.program_guard(main_prog, start_prog): + x = paddle.static.data(name='x', shape=[None, 4, 5]) + y = paddle.static.data(name='y', shape=[None, 4, 5]) + z = paddle.static.nn.fc(x, 100) + loss = paddle.mean(z) + scheduler = paddle.optimizer.lr.OneCycleLR(max_learning_rate=1.0, total_steps=100, verbose=True) + sgd = paddle.optimizer.SGD(learning_rate=scheduler) + sgd.minimize(loss) + + exe = paddle.static.Executor() + exe.run(start_prog) + for epoch in range(5): + for batch_id in range(20): + out = exe.run( + main_prog, + feed={ + 'x': np.random.randn(3, 4, 5).astype('float32'), + 'y': np.random.randn(3, 4, 5).astype('float32') + }, + fetch_list=loss.name) + scheduler.step() # You should update learning rate each step + """ + + def __init__(self, + max_learning_rate, + total_steps, + divide_factor=25., + end_learning_rate=0.0001, + phase_pct=0.3, + anneal_strategy='cos', + three_phase=False, + last_epoch=-1, + verbose=False): + # Check type and value of max_learning_rate + if not isinstance(max_learning_rate, (float, int)): + raise TypeError( + "'max_learning_rate' must be 'float' or 'int', but received {}". + format(type(total_steps))) + if max_learning_rate < 0: + raise ValueError("'max_learning_rate' must be a positive integer.") + + # Check type and value of end_learning_rate + if not isinstance(end_learning_rate, (float, int)): + raise TypeError( + "'end_learning_rate' must be 'float' or 'int', but received {}". + format(type(total_steps))) + if end_learning_rate < 0: + raise ValueError("'end_learning_rate' must be a positive integer.") + + # Check type and value of total_steps + if not isinstance(total_steps, int): + raise TypeError("'total_step' must be 'int', but received {}". + format(type(total_steps))) + if total_steps <= 0: + raise ValueError("'total_step' must be a positive integer.") + self.total_steps = total_steps + + # Check type and value of pac_start + if not isinstance(phase_pct, float): + raise TypeError("'phase_pct' must be 'float', but received {}". + format(type(phase_pct))) + if phase_pct < 0 or phase_pct > 1: + raise ValueError( + "'phase_pct' must be between 0 and 1, but received {}".format( + phase_pct)) + + # Check type and value of divide_factor + if not isinstance(divide_factor, (float, int)): + raise TypeError( + "'divide_factor' must be 'float' or 'int', but received {}". + format(type(divide_factor))) + + initial_lr = max_learning_rate / float(divide_factor) + min_lr = float(end_learning_rate) + + if three_phase: + if phase_pct >= 0.5: + raise ValueError( + "When three_phase is True, 'phase_pct' must be less than 0.5" + ) + # start step and end step of each phase. + self._step_config = [ + 0, + phase_pct * self.total_steps - 1, + 2 * phase_pct * self.total_steps - 2, + self.total_steps - 1, + self.total_steps - 1, # for the last step. + ] + # step size of each phase. + self._steps_size = [ + self._step_config[1] - self._step_config[0], + self._step_config[2] - self._step_config[1], + self._step_config[3] - self._step_config[2], + self._step_config[3] - + self._step_config[2], # for the last step. + ] + # start lr and end lr of each phase. + self._lr_config = [ + initial_lr, max_learning_rate, initial_lr, min_lr + ] + else: + self._step_config = [ + 0, phase_pct * self.total_steps - 1, self.total_steps - 1, + self.total_steps - 1 + ] + self._steps_size = [ + self._step_config[1] - self._step_config[0], + self._step_config[2] - self._step_config[1], + self._step_config[2] - self._step_config[1], + ] + self._lr_config = [initial_lr, max_learning_rate, min_lr] + + # Check anneal_strategy + if anneal_strategy == 'cos': + self.anneal_func = self._cos_annealing + elif anneal_strategy == 'linear': + self.anneal_func = self._linear_annealing + else: + raise ValueError( + "'anneal_strategy' must by one of 'cos' or 'linear', but received {}". + format(anneal_strategy)) + super(OneCycleLR, self).__init__(initial_lr, last_epoch, verbose) + + def _cos_annealing(self, start_lr, end_lr, pct): + cos_out = math.cos(math.pi * pct) + 1 + return end_lr + (start_lr - end_lr) / 2.0 * cos_out + + def _linear_annealing(self, start_lr, end_lr, pct): + return (end_lr - start_lr) * pct + start_lr + + def get_lr(self): + current_step = self.last_epoch + + if current_step > self.total_steps: + raise ValueError( + "Tried to step {} times. However the number of total steps is {}" + .format(current_step, self.total_steps)) + + for (i, (end_step, step_size) + ) in enumerate(zip(self._step_config[1:], self._steps_size)): + # i == len(self._lr_config) - 2 catch the last step, otherwise it will return None. + if current_step <= end_step or i == len(self._lr_config) - 2: + # self._step_config[i] means start step of a phase. + percentage = (current_step - self._step_config[i]) / step_size + return self.anneal_func(self._lr_config[i], + self._lr_config[i + 1], percentage)