Commit

update
Asthestarsfalll committed May 11, 2022
1 parent 98f9a9e commit 70f97a8
Showing 2 changed files with 89 additions and 62 deletions.
57 changes: 31 additions & 26 deletions python/paddle/fluid/tests/unittests/test_lr_scheduler.py
@@ -322,43 +322,43 @@ def step_lr(epoch_num, learning_rate, step_size, gamma=0.1, verbose=False):


def one_cycle_lr(epoch_num,
- learning_rate,
+ max_learning_rate,
total_steps,
- scale_factor=25,
+ divide_factor=25,
end_lr=0.0001,
phase_pct=0.3,
anneal_strategy='cos',
three_phase=False,
verbose=False):
- max_lr = learning_rate * scale_factor
+ initial_lr = max_learning_rate / divide_factor
if three_phase:
_end_steps = [
float(phase_pct * total_steps) - 1,
float(2 * phase_pct * total_steps) - 2, total_steps - 1
]
_schedule_phases = [
{
- 'start_lr': learning_rate,
- 'end_lr': max_lr,
+ 'start_lr': initial_lr,
+ 'end_lr': max_learning_rate,
},
{
- 'start_lr': max_lr,
- 'end_lr': learning_rate,
+ 'start_lr': max_learning_rate,
+ 'end_lr': initial_lr,
},
{
- 'start_lr': learning_rate,
+ 'start_lr': initial_lr,
'end_lr': end_lr,
},
]
else:
_end_steps = [float(phase_pct * total_steps) - 1, total_steps - 1]
_schedule_phases = [
{
- 'start_lr': learning_rate,
- 'end_lr': max_lr,
+ 'start_lr': initial_lr,
+ 'end_lr': max_learning_rate,
},
{
- 'start_lr': max_lr,
+ 'start_lr': max_learning_rate,
'end_lr': end_lr,
},
]
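As a quick check on the renamed parameters, this is what the helper's phase setup evaluates to for one of the three-phase configurations used later in this test file (max_learning_rate=1.0, divide_factor=9, total_steps=20, phase_pct=0.3, end_lr=0.0001). The numbers are worked out by hand from the code above, not produced by the test:

# initial_lr = 1.0 / 9 ≈ 0.111
# _end_steps = [0.3 * 20 - 1, 2 * 0.3 * 20 - 2, 20 - 1] = [5.0, 10.0, 19]
# _schedule_phases: 0.111 -> 1.0, then 1.0 -> 0.111, then 0.111 -> 0.0001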
@@ -532,24 +532,29 @@ def test_scheduler(self):
paddle.optimizer.lr.MultiStepDecay(
learning_rate=0.5, milestones=[1, 2, 3], gamma=2)
with self.assertRaises(TypeError):
- paddle.optimizer.lr.OneCycleLR(learning_rate='test', total_steps=20)
+ paddle.optimizer.lr.OneCycleLR(
+ max_learning_rate='test', total_steps=20)
+ with self.assertRaises(ValueError):
+ paddle.optimizer.lr.OneCycleLR(
+ max_learning_rate=-1.5, total_steps=20)
with self.assertRaises(TypeError):
paddle.optimizer.lr.OneCycleLR(
- learning_rate=0.1, total_steps=20, end_lr='test')
+ max_learning_rate=0.1, total_steps=20, end_lr='test')
with self.assertRaises(ValueError):
paddle.optimizer.lr.OneCycleLR(
- learning_rate=0.1, total_steps=20, end_lr=-1)
+ max_learning_rate=0.1, total_steps=20, end_lr=-1)
with self.assertRaises(TypeError):
paddle.optimizer.lr.OneCycleLR(
- learning_rate=0.1, total_steps='test')
+ max_learning_rate=0.1, total_steps='test')
with self.assertRaises(ValueError):
- paddle.optimizer.lr.OneCycleLR(learning_rate=0.1, total_steps=-10)
+ paddle.optimizer.lr.OneCycleLR(
+ max_learning_rate=0.1, total_steps=-10)
with self.assertRaises(ValueError):
paddle.optimizer.lr.OneCycleLR(
- learning_rate=0.1, total_steps=20, anneal_strategy='test')
+ max_learning_rate=0.1, total_steps=20, anneal_strategy='test')
with self.assertRaises(ValueError):
paddle.optimizer.lr.OneCycleLR(
- learning_rate=0.1,
+ max_learning_rate=0.1,
total_steps=20,
phase_pct=0.6,
three_phase=True)
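For contrast with the failure cases above, calls along these lines are expected to construct successfully under the new validation (an illustrative sketch, not part of the test file):

# Both should pass the type/value checks exercised above.
paddle.optimizer.lr.OneCycleLR(max_learning_rate=0.1, total_steps=20)
paddle.optimizer.lr.OneCycleLR(max_learning_rate=0.1, total_steps=20, phase_pct=0.4, three_phase=True)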
@@ -614,33 +619,33 @@ def test_scheduler(self):
"T_max": 10,
"verbose": False
}), (one_cycle_lr, paddle.optimizer.lr.OneCycleLR, {
"learning_rate": 0.1,
"max_learning_rate": 0.1,
"total_steps": 20,
"scale_factor": 5,
"divide_factor": 5,
"end_lr": 0.0001,
"anneal_strategy": 'cos',
"phase_pct": 0.3,
"three_phase": False,
}), (one_cycle_lr, paddle.optimizer.lr.OneCycleLR, {
"learning_rate": 0.5,
"max_learning_rate": 0.5,
"total_steps": 20,
"scale_factor": 10,
"divide_factor": 10,
"end_lr": 0.001,
"anneal_strategy": 'linear',
"phase_pct": 0.4,
"three_phase": False,
}), (one_cycle_lr, paddle.optimizer.lr.OneCycleLR, {
"learning_rate": 0.1,
"max_learning_rate": 1.0,
"total_steps": 20,
"scale_factor": 9,
"divide_factor": 9,
"end_lr": 0.0001,
"anneal_strategy": 'cos',
"phase_pct": 0.3,
"three_phase": True,
}), (one_cycle_lr, paddle.optimizer.lr.OneCycleLR, {
"learning_rate": 0.3,
"max_learning_rate": 0.3,
"total_steps": 20,
"scale_factor": 25,
"divide_factor": 25,
"end_lr": 0.0005,
"anneal_strategy": 'linear',
"phase_pct": 0.2,
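Each dictionary above is a keyword configuration for the updated API; for example, the second one constructs a scheduler directly (a sketch, not part of the test):

# Sketch only: build the scheduler from one of the configurations above.
config = {"max_learning_rate": 0.5, "total_steps": 20, "divide_factor": 10,
          "end_lr": 0.001, "anneal_strategy": 'linear', "phase_pct": 0.4, "three_phase": False}
scheduler = paddle.optimizer.lr.OneCycleLR(**config)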
94 changes: 58 additions & 36 deletions python/paddle/optimizer/lr.py
@@ -1609,11 +1609,11 @@ class OneCycleLR(LRScheduler):
Also note that you should update learning rate each step.
Args:
- learning_rate (float): The initial learning rate. It is a python float number.
- Functionally, it defines the maximum learning rate by ``scale_factor`` .
+ max_learning_rate (float): The maximum learning rate. It is a python float number.
+ Functionally, it defines the initial learning rate by ``divide_factor`` .
total_steps (int): Number of total training steps.
- scale_factor (float): Maximum learning rate will be determined by maximum_lr = learning_rate * scale_factor. Default: 25.
- end_lr (float, optional): The minimum learning rate of schedule, it should be much less than initial learning rate.
+ divide_factor (float): Initial learning rate will be determined by initial_learning_rate = max_learning_rate / divide_factor. Default: 25.
+ end_lr (float, optional): The minimum learning rate during training, it should be much less than initial learning rate.
phase_pct (float): The percentage of total steps used to increase the learning rate. Default: 0.3.
anneal_strategy (str, optional): Strategy of adjusting the learning rate. 'cos' for cosine annealing,
'linear' for linear annealing. Default: 'cos'.
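Read together, the updated arguments imply the following (a sketch, not docstring text): the schedule starts at initial_lr = max_learning_rate / divide_factor, climbs to max_learning_rate over roughly phase_pct * total_steps steps, and then anneals down to end_lr; with three_phase=True it first descends back to initial_lr before the final anneal. For the example below (max_learning_rate=1.0, total_steps=100, defaults elsewhere):

# initial_lr = 1.0 / 25 = 0.04          (divide_factor defaults to 25)
# peak of 1.0 reached after about 0.3 * 100 = 30 steps   (phase_pct defaults to 0.3)
# the remaining ~70 steps anneal towards end_lr = 0.0001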
@@ -1639,7 +1639,7 @@ class OneCycleLR(LRScheduler):
# train on default dynamic graph mode
linear = paddle.nn.Linear(10, 10)
- scheduler = paddle.optimizer.lr.OneCycleLR(learning_rate=1.0, total_steps=100, scale_factor=5, verbose=True)
+ scheduler = paddle.optimizer.lr.OneCycleLR(max_learning_rate=1.0, total_steps=100, verbose=True)
sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
for epoch in range(5):
for batch_id in range(20):
@@ -1660,7 +1660,7 @@ class OneCycleLR(LRScheduler):
y = paddle.static.data(name='y', shape=[None, 4, 5])
z = paddle.static.nn.fc(x, 100)
loss = paddle.mean(z)
- scheduler = paddle.optimizer.lr.OneCycleLR(learning_rate=1.0, total_steps=100, scale_factor=5, verbose=True)
+ scheduler = paddle.optimizer.lr.OneCycleLR(max_learning_rate=1.0, total_steps=100, verbose=True)
sgd = paddle.optimizer.SGD(learning_rate=scheduler)
sgd.minimize(loss)
@@ -1679,19 +1679,27 @@ class OneCycleLR(LRScheduler):
"""

def __init__(self,
- learning_rate,
+ max_learning_rate,
total_steps,
- scale_factor=25,
+ divide_factor=25.,
end_lr=0.0001,
phase_pct=0.3,
anneal_strategy='cos',
three_phase=False,
last_epoch=-1,
verbose=False):
+ # Check type and value of max_learning_rate
+ if not isinstance(max_learning_rate, (float, int)):
+ raise TypeError(
+ "'max_learning_rate' must be 'float' or 'int', but received {}".
+ format(type(max_learning_rate)))
+ if max_learning_rate < 0:
+ raise ValueError("'max_learning_rate' must be a positive integer.")

+ # Check type and value of end_lr
if not isinstance(end_lr, (float, int)):
raise TypeError(
"'total_step' must be 'float' or 'int', but received {}".format(
"'end_lr' must be 'float' or 'int', but received {}".format(
type(total_steps)))
if end_lr < 0:
raise ValueError("'end_lr' must be a positive integer.")
@@ -1713,49 +1721,62 @@ def __init__(self,
"'phase_pct' must be between 0 and 1, but received {}".format(
phase_pct))

- max_lr = learning_rate * scale_factor
+ # Check type and value of divide_factor
+ if not isinstance(divide_factor, (float, int)):
+ raise TypeError(
+ "'divide_factor' must be 'float' or 'int', but received {}".
+ format(type(divide_factor)))

+ initial_lr = max_learning_rate / float(divide_factor)
min_lr = float(end_lr)

if three_phase:
if phase_pct >= 0.5:
raise ValueError(
"When three_phase is True, 'phase_pct' must be smaller than 0.5"
"When three_phase is True, 'phase_pct' must be less than 0.5"
)
- self._start_steps = [
+ # start step and end step of each phase.
+ self._step_config = [
0,
phase_pct * self.total_steps - 1,
2 * phase_pct * self.total_steps - 2,
- self.total_steps - 1,
+ self.total_steps - 1, # for the last step.
]
+ # step size of each phase.
self._steps_size = [
- self._start_steps[1] - self._start_steps[0],
- self._start_steps[2] - self._start_steps[1],
- self._start_steps[3] - self._start_steps[2],
- self._start_steps[3] - self._start_steps[2],
+ self._step_config[1] - self._step_config[0],
+ self._step_config[2] - self._step_config[1],
+ self._step_config[3] - self._step_config[2],
+ self._step_config[3] -
+ self._step_config[2], # for the last step.
]
+ # start lr and end lr of each phase.
+ self._lr_config = [
+ initial_lr, max_learning_rate, initial_lr, min_lr
+ ]
- self._lr_config = [learning_rate, max_lr, learning_rate, min_lr]
else:
- self._start_steps = [
- 0, phase_pct * self.total_steps - 1, self.total_steps - 1
+ self._step_config = [
+ 0, phase_pct * self.total_steps - 1, self.total_steps - 1,
+ self.total_steps - 1
]
self._steps_size = [
- self._start_steps[1] - self._start_steps[0],
- self._start_steps[2] - self._start_steps[1],
- self._start_steps[2] - self._start_steps[1],
+ self._step_config[1] - self._step_config[0],
+ self._step_config[2] - self._step_config[1],
+ self._step_config[2] - self._step_config[1],
]
- self._lr_config = [learning_rate, max_lr, min_lr]
+ self._lr_config = [initial_lr, max_learning_rate, min_lr]

+ # Check anneal_strategy
- if anneal_strategy not in ['cos', 'linear']:
- raise ValueError(
- "'anneal_strategy' must by one of 'cos' or 'linear', but received {}".
- format(anneal_strategy))
- elif anneal_strategy == 'cos':
+ if anneal_strategy == 'cos':
self.anneal_func = self._cos_annealing
elif anneal_strategy == 'linear':
self.anneal_func = self._linear_annealing

- super(OneCycleLR, self).__init__(learning_rate, last_epoch, verbose)
+ else:
+ raise ValueError(
+ "'anneal_strategy' must by one of 'cos' or 'linear', but received {}".
+ format(anneal_strategy))
+ super(OneCycleLR, self).__init__(initial_lr, last_epoch, verbose)

def _cos_annealing(self, start_lr, end_lr, pct):
cos_out = math.cos(math.pi * pct) + 1
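The rest of this helper falls outside the hunk; based on the line shown, the interpolation presumably completes along these lines, mapping pct = 0 to start_lr and pct = 1 to end_lr (a sketch, not part of the displayed diff):

# return end_lr + (start_lr - end_lr) / 2.0 * cos_out
# pct = 0 -> cos_out = 2 -> start_lr;  pct = 1 -> cos_out = 0 -> end_lr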
@@ -1769,13 +1790,14 @@ def get_lr(self):

if current_step > self.total_steps:
raise ValueError(
"Tried to step {} times. The specified number of total steps is {}"
"Tried to step {} times. However the number of total steps is {}"
.format(current_step, self.total_steps))

- for (i, (start_step, step_size)
- ) in enumerate(zip(self._start_steps, self._steps_size)):
- if current_step <= (start_step + step_size
- ) or i == len(self._lr_config) - 2:
- percentage = (current_step - start_step) / step_size
+ for (i, (end_step, step_size)
+ ) in enumerate(zip(self._step_config[1:], self._steps_size)):
+ # i == len(self._lr_config) - 2 catch the last step, otherwise it will return None.
+ if current_step <= end_step or i == len(self._lr_config) - 2:
+ # self._step_config[i] means start step of a phase.
+ percentage = (current_step - self._step_config[i]) / step_size
return self.anneal_func(self._lr_config[i],
self._lr_config[i + 1], percentage)
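To see the rewritten loop end to end, here is a hand trace with the first OneCycleLR configuration from the test file above (max_learning_rate=0.1, divide_factor=5, total_steps=20, phase_pct=0.3, end_lr=0.0001, two-phase). The values are worked out from the code shown, not captured from a run:

# initial_lr = 0.1 / 5 = 0.02
# _step_config = [0, 5.0, 19, 19]; _steps_size = [5.0, 14.0, 14.0]; _lr_config = [0.02, 0.1, 0.0001]
# current_step 0  -> phase 0, percentage 0.0 -> lr = 0.02  (initial_lr)
# current_step 5  -> phase 0, percentage 1.0 -> lr = 0.1   (max_learning_rate)
# current_step 19 -> phase 1, percentage 1.0 -> lr = 0.0001 (end_lr)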
