diff --git a/python/paddle/fluid/tests/unittests/test_lr_scheduler.py b/python/paddle/fluid/tests/unittests/test_lr_scheduler.py
index fb2038819fcc5..def22575eea91 100644
--- a/python/paddle/fluid/tests/unittests/test_lr_scheduler.py
+++ b/python/paddle/fluid/tests/unittests/test_lr_scheduler.py
@@ -389,6 +389,53 @@ def anneal_func(start, end, pct):
     return computed_lr
 
 
+def cyclic_lr(epoch_num,
+              base_learning_rate,
+              max_learning_rate,
+              step_size_up,
+              step_size_down,
+              mode,
+              exp_gamma=0.1,
+              scale_fn=None,
+              scale_mode='cycle',
+              verbose=False):
+    total_steps = step_size_up + step_size_down
+    step_ratio = step_size_up / total_steps
+
+    def triangular(x):
+        return 1.
+
+    def triangular2(x):
+        return 1 / (2.**(x - 1))
+
+    def exp_range(x):
+        return exp_gamma**x
+
+    if scale_fn is None:
+        if mode == 'triangular':
+            scale_fn = triangular
+            scale_mode = 'cycle'
+        elif mode == 'triangular2':
+            scale_fn = triangular2
+            scale_mode = 'cycle'
+        elif mode == 'exp_range':
+            scale_fn = exp_range
+            scale_mode = 'iterations'
+
+    cycle = math.floor(1 + epoch_num / total_steps)
+    iterations = epoch_num
+    x = 1. + epoch_num / total_steps - cycle
+
+    if x <= step_ratio:
+        scale_factor = x / step_ratio
+    else:
+        scale_factor = (x - 1) / (step_ratio - 1)
+
+    base_height = (max_learning_rate - base_learning_rate) * scale_factor
+
+    return base_learning_rate + base_height * scale_fn(eval(scale_mode))
+
+
 class TestLRScheduler(unittest.TestCase):
 
     def _test_static(self, python_func, paddle_api, kwarg, place):
@@ -533,35 +580,89 @@ def test_scheduler(self):
             paddle.optimizer.lr.MultiStepDecay(learning_rate=0.5,
                                                milestones=[1, 2, 3],
                                                gamma=2)
+        # check type of max_learning_rate
         with self.assertRaises(TypeError):
             paddle.optimizer.lr.OneCycleLR(max_learning_rate='test',
                                            total_steps=20)
+        # check value of max_learning_rate
         with self.assertRaises(ValueError):
             paddle.optimizer.lr.OneCycleLR(max_learning_rate=-1.5,
                                            total_steps=20)
+        # check type of end_learning_rate
         with self.assertRaises(TypeError):
             paddle.optimizer.lr.OneCycleLR(max_learning_rate=0.1,
                                            total_steps=20,
                                            end_learning_rate='test')
+        # check value of end_learning_rate
         with self.assertRaises(ValueError):
             paddle.optimizer.lr.OneCycleLR(max_learning_rate=0.1,
                                            total_steps=20,
                                            end_learning_rate=-1)
+        # check type of total_steps
         with self.assertRaises(TypeError):
             paddle.optimizer.lr.OneCycleLR(max_learning_rate=0.1,
                                            total_steps='test')
+        # check value of total_steps
         with self.assertRaises(ValueError):
             paddle.optimizer.lr.OneCycleLR(max_learning_rate=0.1,
                                            total_steps=-10)
+        # check value of anneal_strategy
         with self.assertRaises(ValueError):
             paddle.optimizer.lr.OneCycleLR(max_learning_rate=0.1,
                                            total_steps=20,
                                            anneal_strategy='test')
+        # check value of phase_pct when three_phase is True
         with self.assertRaises(ValueError):
             paddle.optimizer.lr.OneCycleLR(max_learning_rate=0.1,
                                            total_steps=20,
                                            phase_pct=0.6,
                                            three_phase=True)
+        # check type of max_learning_rate
+        with self.assertRaises(TypeError):
+            paddle.optimizer.lr.CyclicLR(base_learning_rate=0.5,
+                                         max_learning_rate='test',
+                                         step_size_up=10)
+        # check value of max_learning_rate
+        with self.assertRaises(ValueError):
+            paddle.optimizer.lr.CyclicLR(base_learning_rate=0.5,
+                                         max_learning_rate=-1,
+                                         step_size_up=10)
+        # check type of step_size_up
+        with self.assertRaises(TypeError):
+            paddle.optimizer.lr.CyclicLR(base_learning_rate=0.5,
+                                         max_learning_rate=1.0,
+                                         step_size_up='test')
+        # check value of step_size_up
+        with self.assertRaises(ValueError):
+            paddle.optimizer.lr.CyclicLR(base_learning_rate=0.5,
+                                         max_learning_rate=1.0,
+                                         step_size_up=-1)
+        # check type of step_size_down
+        with self.assertRaises(TypeError):
+            paddle.optimizer.lr.CyclicLR(base_learning_rate=0.5,
+                                         max_learning_rate=1.0,
+                                         step_size_up=500,
+                                         step_size_down='test')
+        # check value of step_size_down
+        with self.assertRaises(ValueError):
+            paddle.optimizer.lr.CyclicLR(base_learning_rate=0.5,
+                                         max_learning_rate=1.0,
+                                         step_size_up=500,
+                                         step_size_down=-1)
+        # check value of mode
+        with self.assertRaises(ValueError):
+            paddle.optimizer.lr.CyclicLR(base_learning_rate=0.5,
+                                         max_learning_rate=1.0,
+                                         step_size_up=500,
+                                         step_size_down=500,
+                                         mode='test')
+        # check value of scale_mode
+        with self.assertRaises(ValueError):
+            paddle.optimizer.lr.CyclicLR(base_learning_rate=0.5,
+                                         max_learning_rate=1.0,
+                                         step_size_up=500,
+                                         step_size_down=500,
+                                         scale_mode='test')
 
         func_api_kwargs = [
             (noam_lr, paddle.optimizer.lr.NoamDecay, {
@@ -671,6 +772,61 @@ def test_scheduler(self):
                 "anneal_strategy": 'linear',
                 "phase_pct": 0.2,
                 "three_phase": True,
+            }),
+            (cyclic_lr, paddle.optimizer.lr.CyclicLR, {
+                "base_learning_rate": 0.5,
+                "max_learning_rate": 1.0,
+                "step_size_up": 15,
+                "step_size_down": 5,
+                "mode": 'triangular',
+                "exp_gamma": 1.,
+                "scale_fn": None,
+                "scale_mode": 'cycle',
+                "verbose": False
+            }),
+            (cyclic_lr, paddle.optimizer.lr.CyclicLR, {
+                "base_learning_rate": 0.5,
+                "max_learning_rate": 1.0,
+                "step_size_up": 15,
+                "step_size_down": 5,
+                "mode": 'triangular2',
+                "exp_gamma": 1.,
+                "scale_fn": None,
+                "scale_mode": 'cycle',
+                "verbose": False
+            }),
+            (cyclic_lr, paddle.optimizer.lr.CyclicLR, {
+                "base_learning_rate": 0.5,
+                "max_learning_rate": 1.0,
+                "step_size_up": 15,
+                "step_size_down": 5,
+                "mode": 'exp_range',
+                "exp_gamma": 0.8,
+                "scale_fn": None,
+                "scale_mode": 'cycle',
+                "verbose": False
+            }),
+            (cyclic_lr, paddle.optimizer.lr.CyclicLR, {
+                "base_learning_rate": 0.5,
+                "max_learning_rate": 1.0,
+                "step_size_up": 15,
+                "step_size_down": 5,
+                "mode": 'exp_range',
+                "exp_gamma": 1.,
+                "scale_fn": lambda x: 0.95**x,
+                "scale_mode": 'cycle',
+                "verbose": False
+            }),
+            (cyclic_lr, paddle.optimizer.lr.CyclicLR, {
+                "base_learning_rate": 0.5,
+                "max_learning_rate": 1.0,
+                "step_size_up": 15,
+                "step_size_down": 5,
+                "mode": 'exp_range',
+                "exp_gamma": 1.,
+                "scale_fn": lambda x: 0.95,
+                "scale_mode": 'iterations',
+                "verbose": False
             })
         ]
 
diff --git a/python/paddle/optimizer/lr.py b/python/paddle/optimizer/lr.py
index 883b2c1481703..4d7d128e05e49 100644
--- a/python/paddle/optimizer/lr.py
+++ b/python/paddle/optimizer/lr.py
@@ -20,10 +20,22 @@
 from ..fluid.framework import _in_legacy_dygraph
 
 __all__ = [  # noqa
-    'LRScheduler', 'NoamDecay', 'PiecewiseDecay', 'NaturalExpDecay',
-    'InverseTimeDecay', 'PolynomialDecay', 'LinearWarmup', 'ExponentialDecay',
-    'MultiStepDecay', 'StepDecay', 'LambdaDecay', 'ReduceOnPlateau',
-    'CosineAnnealingDecay', 'MultiplicativeDecay', 'OneCycleLR'
+    'LRScheduler',
+    'NoamDecay',
+    'PiecewiseDecay',
+    'NaturalExpDecay',
+    'InverseTimeDecay',
+    'PolynomialDecay',
+    'LinearWarmup',
+    'ExponentialDecay',
+    'MultiStepDecay',
+    'StepDecay',
+    'LambdaDecay',
+    'ReduceOnPlateau',
+    'CosineAnnealingDecay',
+    'MultiplicativeDecay',
+    'OneCycleLR',
+    'CyclicLR',
 ]
 
 
@@ -1681,7 +1693,7 @@ def __init__(self,
         if not isinstance(max_learning_rate, (float, int)):
             raise TypeError(
                 "'max_learning_rate' must be 'float' or 'int', but received {}".
-                format(type(total_steps)))
+                format(type(max_learning_rate)))
         if max_learning_rate < 0:
             raise ValueError("'max_learning_rate' must be a positive integer.")
 
@@ -1689,7 +1701,7 @@ def __init__(self,
         if not isinstance(end_learning_rate, (float, int)):
             raise TypeError(
                 "'end_learning_rate' must be 'float' or 'int', but received {}".
-                format(type(total_steps)))
+                format(type(end_learning_rate)))
         if end_learning_rate < 0:
             raise ValueError("'end_learning_rate' must be a positive integer.")
 
@@ -1792,3 +1804,205 @@ def get_lr(self):
             percentage = (current_step - self._step_config[i]) / step_size
             return self.anneal_func(self._lr_config[i], self._lr_config[i + 1],
                                     percentage)
+
+
+class CyclicLR(LRScheduler):
+    r"""
+    Set the learning rate according to the cyclic learning rate (CLR) scheduler.
+    The scheduler regards the process of learning rate adjustment as one cycle after another.
+    It cycles the learning rate between two boundaries with a constant frequency.
+    The distance between the two boundaries can be scaled on a per-iteration or per-cycle basis.
+
+    It has been proposed in `Cyclical Learning Rates for Training Neural Networks <https://arxiv.org/abs/1506.01186>`_.
+
+    According to the paper, the cyclic learning rate schedule has three built-in scaling methods:
+
+    * "triangular": A basic triangular cycle without any amplitude scaling.
+    * "triangular2": A basic triangular cycle that halves the initial amplitude in each cycle.
+    * "exp_range": A cycle that scales the initial amplitude by a scale function defined as :math:`gamma^{iterations}` .
+
+    The initial amplitude is defined as max_learning_rate - base_learning_rate.
+    Also note that you should update the learning rate each step.
+
+    Args:
+        base_learning_rate (float): Initial learning rate, which is the lower boundary in the cycle. The paper recommends
+            setting base_learning_rate to 1/3 or 1/4 of max_learning_rate.
+        max_learning_rate (float): Maximum learning rate in the cycle. It defines the cycle amplitude as above.
+            Since the amplitude is scaled during learning rate adjustment,
+            max_learning_rate may not actually be reached.
+        step_size_up (int): Number of training steps used to increase the learning rate in a cycle.
+            The length of one cycle is step_size_up + step_size_down. According to the paper, the step
+            size should be set to at least 3 or 4 times the number of steps in one epoch.
+        step_size_down (int, optional): Number of training steps used to decrease the learning rate in a cycle.
+            If not specified, its value defaults to ``step_size_up`` . Default: None
+        mode (str, optional): One of 'triangular', 'triangular2' or 'exp_range'.
+            If scale_fn is specified, this argument will be ignored. Default: 'triangular'
+        exp_gamma (float, optional): Constant in the 'exp_range' scaling function: exp_gamma**iterations. Used only when mode = 'exp_range'. Default: 1.0
+        scale_fn (function, optional): A custom scaling function, which is used to replace the three built-in methods.
+            It should take one argument and satisfy 0 <= scale_fn(x) <= 1 for all x >= 0.
+            If specified, then 'mode' will be ignored. Default: None
+        scale_mode (str, optional): One of 'cycle' or 'iterations'. Defines whether scale_fn is evaluated on the cycle
+            number or on cycle iterations (total iterations since the start of training). Default: 'cycle'
+        last_epoch (int, optional): The index of the last epoch. Can be set to restart training. Default: -1, which means the initial learning rate.
+        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .
+
+    Returns:
+        ``CyclicLR`` instance to schedule learning rate.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+            import numpy as np
+
+            # train on default dynamic graph mode
+            linear = paddle.nn.Linear(10, 10)
+            scheduler = paddle.optimizer.lr.CyclicLR(base_learning_rate=0.5, max_learning_rate=1.0, step_size_up=15, step_size_down=5, verbose=True)
+            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
+            for epoch in range(5):
+                for batch_id in range(20):
+                    x = paddle.uniform([10, 10])
+                    out = linear(x)
+                    loss = paddle.mean(out)
+                    loss.backward()
+                    sgd.step()
+                    sgd.clear_gradients()
+                    scheduler.step()    # You should update learning rate each step
+
+            # train on static graph mode
+            paddle.enable_static()
+            main_prog = paddle.static.Program()
+            start_prog = paddle.static.Program()
+            with paddle.static.program_guard(main_prog, start_prog):
+                x = paddle.static.data(name='x', shape=[None, 4, 5])
+                y = paddle.static.data(name='y', shape=[None, 4, 5])
+                z = paddle.static.nn.fc(x, 100)
+                loss = paddle.mean(z)
+                scheduler = paddle.optimizer.lr.CyclicLR(base_learning_rate=0.5,
+                    max_learning_rate=1.0, step_size_up=15, step_size_down=5, verbose=True)
+                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
+                sgd.minimize(loss)
+
+            exe = paddle.static.Executor()
+            exe.run(start_prog)
+            for epoch in range(5):
+                for batch_id in range(20):
+                    out = exe.run(
+                        main_prog,
+                        feed={
+                            'x': np.random.randn(3, 4, 5).astype('float32'),
+                            'y': np.random.randn(3, 4, 5).astype('float32')
+                        },
+                        fetch_list=loss.name)
+                    scheduler.step()    # You should update learning rate each step
+    """
+
+    def __init__(self,
+                 base_learning_rate,
+                 max_learning_rate,
+                 step_size_up,
+                 step_size_down=None,
+                 mode='triangular',
+                 exp_gamma=1.,
+                 scale_fn=None,
+                 scale_mode='cycle',
+                 last_epoch=-1,
+                 verbose=False):
+        # check type and value of max_learning_rate
+        if not isinstance(max_learning_rate, (float, int)):
+            raise TypeError(
+                "'max_learning_rate' must be 'float' or 'int', but received {}".
+                format(type(max_learning_rate)))
+        if max_learning_rate < 0:
+            raise ValueError(
+                "'max_learning_rate' must be a positive value, but received {}"
+                .format(max_learning_rate))
+
+        # check type and value of step_size_up
+        if not isinstance(step_size_up, int):
+            raise TypeError(
+                "The type of 'step_size_up' must be int, but received {}".
+                format(type(step_size_up)))
+        if step_size_up <= 0:
+            raise ValueError(
+                "'step_size_up' must be a positive integer, but received {}".
+                format(step_size_up))
+
+        # check type and value of step_size_down
+        if step_size_down is not None:
+            if not isinstance(step_size_down, int):
+                raise TypeError(
+                    "The type of 'step_size_down' must be int, but received {}".
+                    format(type(step_size_down)))
+            if step_size_down <= 0:
+                raise ValueError(
+                    "'step_size_down' must be a positive integer, but received {}"
+                    .format(step_size_down))
+
+        # check type of exp_gamma
+        if not isinstance(exp_gamma, float):
+            raise TypeError(
+                "The type of 'exp_gamma' must be float, but received {}".format(
+                    type(exp_gamma)))
+
+        step_size_up = float(step_size_up)
+        step_size_down = float(
+            step_size_down) if step_size_down is not None else step_size_up
+
+        self.cycle_size = step_size_up + step_size_down
+        self.step_up_pct = step_size_up / self.cycle_size
+        self.max_lr = float(max_learning_rate)
+        self.amplitude = self.max_lr - base_learning_rate
+
+        if mode not in ['triangular', 'triangular2', 'exp_range'
+                        ] and scale_fn is None:
+            raise ValueError(
+                "'mode' is invalid and 'scale_fn' is not specified, make sure one of 'mode' or 'scale_fn' is valid"
+            )
+        if scale_mode not in ['cycle', 'iterations']:
+            raise ValueError(
+                "'scale_mode' must be one of 'cycle' or 'iterations'")
+
+        self.mode = mode
+        self.gamma = exp_gamma  # only for exp_range mode
+
+        if scale_fn is None:
+            if self.mode == 'triangular':
+                self.scale_fn = self._triangular_scale_fn
+                self.scale_mode = 'cycle'
+            elif self.mode == 'triangular2':
+                self.scale_fn = self._triangular2_scale_fn
+                self.scale_mode = 'cycle'
+            elif self.mode == 'exp_range':
+                self.scale_fn = self._exp_range_scale_fn
+                self.scale_mode = 'iterations'
+        else:
+            self.scale_fn = scale_fn
+            self.scale_mode = scale_mode
+        super().__init__(base_learning_rate, last_epoch, verbose)
+
+    def _triangular_scale_fn(self, x):
+        return 1.
+
+    def _triangular2_scale_fn(self, x):
+        return 1 / (2.**(x - 1))
+
+    def _exp_range_scale_fn(self, x):
+        return self.gamma**x
+
+    def get_lr(self):
+        iterations = self.last_epoch
+
+        cycle = 1 + iterations // self.cycle_size
+        pct_per_cycle = 1. + iterations / self.cycle_size - cycle
+
+        if pct_per_cycle <= self.step_up_pct:
+            scale_factor = pct_per_cycle / self.step_up_pct
+        else:
+            scale_factor = (1 - pct_per_cycle) / (1 - self.step_up_pct)
+
+        base_height = self.amplitude * scale_factor
+
+        lr = self.base_lr + base_height * self.scale_fn(eval(self.scale_mode))
+
+        return lr
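
For reference, below is a minimal standalone sanity check of the triangular schedule math used by CyclicLR.get_lr() above. It is not part of the patch; it simply mirrors the reference cyclic_lr helper from the test file, using the same values as the test kwargs (base_learning_rate=0.5, max_learning_rate=1.0, step_size_up=15, step_size_down=5):

# Standalone sanity check of the 'triangular' CLR math (not part of the patch).
base_lr, max_lr = 0.5, 1.0
step_up, step_down = 15.0, 5.0
cycle_size = step_up + step_down      # 20 steps per cycle
step_up_pct = step_up / cycle_size    # 0.75


def triangular_lr(step):
    cycle = 1 + step // cycle_size
    pct = 1. + step / cycle_size - cycle      # position inside the current cycle, in [0, 1)
    if pct <= step_up_pct:
        scale_factor = pct / step_up_pct                  # rising edge
    else:
        scale_factor = (1 - pct) / (1 - step_up_pct)      # falling edge
    # scale_fn(x) == 1. in 'triangular' mode, so no extra amplitude scaling here
    return base_lr + (max_lr - base_lr) * scale_factor


lrs = [triangular_lr(step) for step in range(40)]
assert abs(lrs[0] - 0.5) < 1e-12    # starts at base_learning_rate
assert abs(lrs[15] - 1.0) < 1e-12   # peaks at max_learning_rate after step_size_up steps
assert abs(lrs[20] - 0.5) < 1e-12   # back at base_learning_rate when the next cycle starts

With these values one cycle is 20 steps: the learning rate rises linearly from 0.5 to 1.0 over the first 15 steps and falls back to 0.5 over the remaining 5, then the pattern repeats.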