Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon No.13】为 Paddle 新增 CyclicLR 优化调度器 #40698

Merged
merged 33 commits into from Jun 8, 2022
Merged
Show file tree
Hide file tree
Changes from 27 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
620ac56
add paddle.optimizer.lr.CyclicLR
Asthestarsfalll Mar 18, 2022
4fe10c4
add unittest of CyclicLR
Asthestarsfalll Mar 18, 2022
41a2fe0
fix code format
Asthestarsfalll Mar 28, 2022
2ba48b6
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Asthestarsfalll Mar 28, 2022
d47a8c0
fix bug
Asthestarsfalll Mar 31, 2022
9318274
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Asthestarsfalll Mar 31, 2022
efdb2fe
try
Asthestarsfalll Apr 2, 2022
ac13dfb
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Asthestarsfalll Apr 2, 2022
fc0d68c
fix CI-Coverage
Asthestarsfalll Apr 3, 2022
3893c1c
fix ValueError
Asthestarsfalll Apr 3, 2022
4bc9e22
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Asthestarsfalll Apr 3, 2022
e0b0558
fix arguments assgin
Asthestarsfalll Apr 3, 2022
e9acc0e
fix code format and retry pulling develop to pass ci
Asthestarsfalll Apr 3, 2022
4189f26
fix typo
Asthestarsfalll Apr 9, 2022
86403df
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Asthestarsfalll Apr 9, 2022
1ff05ac
Merge branch 'PaddlePaddle:develop' into cycliclr
Asthestarsfalll Apr 11, 2022
8502161
Merge branch 'develop' into cycliclr
Asthestarsfalll May 14, 2022
ab8cc5e
Refactor
Asthestarsfalll May 14, 2022
301256c
fix function-redefined in test_lr_scheduler.py
Asthestarsfalll May 14, 2022
558ee62
update
Asthestarsfalll May 16, 2022
0d6fb3f
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Asthestarsfalll May 16, 2022
12932f1
fix conflict
Asthestarsfalll May 16, 2022
b9d6355
update
Asthestarsfalll May 20, 2022
a896f5a
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Asthestarsfalll May 20, 2022
ef0b383
gamma->exp_gamma
Asthestarsfalll May 20, 2022
03d7fac
polish docs
Asthestarsfalll May 22, 2022
85e5eb1
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Asthestarsfalll May 22, 2022
5ebf5c1
fix code-style
Asthestarsfalll Jun 1, 2022
36ca157
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Asthestarsfalll Jun 1, 2022
cab314a
fix conflit and update code format
Asthestarsfalll Jun 6, 2022
f86f9e3
adjust code format again
Asthestarsfalll Jun 7, 2022
07db9a0
change format of __all__ in lr.py
Asthestarsfalll Jun 7, 2022
5227cd6
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Asthestarsfalll Jun 7, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
155 changes: 155 additions & 0 deletions python/paddle/fluid/tests/unittests/test_lr_scheduler.py
Expand Up @@ -385,6 +385,53 @@ def anneal_func(start, end, pct):
return computed_lr


def cyclic_lr(epoch_num,
base_learning_rate,
max_learning_rate,
step_size_up,
step_size_down,
mode,
exp_gamma=0.1,
scale_fn=None,
scale_mode='cycle',
verbose=False):
total_steps = step_size_up + step_size_down
step_ratio = step_size_up / total_steps

def triangular(x):
return 1.

def triangular2(x):
return 1 / (2.**(x - 1))

def exp_range(x):
return exp_gamma**x

if scale_fn is None:
if mode == 'triangular':
scale_fn = triangular
scale_mode = 'cycle'
elif mode == 'triangular2':
scale_fn = triangular2
scale_mode = 'cycle'
elif mode == 'exp_range':
scale_fn = exp_range
scale_mode = 'iterations'

cycle = math.floor(1 + epoch_num / total_steps)
iterations = epoch_num
x = 1. + epoch_num / total_steps - cycle

if x <= step_ratio:
scale_factor = x / step_ratio
else:
scale_factor = (x - 1) / (step_ratio - 1)

base_height = (max_learning_rate - base_learning_rate) * scale_factor

return base_learning_rate + base_height * scale_fn(eval(scale_mode))


class TestLRScheduler(unittest.TestCase):
def _test_static(self, python_func, paddle_api, kwarg, place):
scheduler = paddle_api(**kwarg)
Expand Down Expand Up @@ -531,33 +578,91 @@ def test_scheduler(self):
with self.assertRaises(ValueError):
paddle.optimizer.lr.MultiStepDecay(
learning_rate=0.5, milestones=[1, 2, 3], gamma=2)
# check type of max_learning_rate
with self.assertRaises(TypeError):
paddle.optimizer.lr.OneCycleLR(
max_learning_rate='test', total_steps=20)
# check value of max_learning_rate
with self.assertRaises(ValueError):
paddle.optimizer.lr.OneCycleLR(
max_learning_rate=-1.5, total_steps=20)
# check type of end_learning_rate
with self.assertRaises(TypeError):
paddle.optimizer.lr.OneCycleLR(
max_learning_rate=0.1, total_steps=20, end_learning_rate='test')
# check value of end_learning_rate
with self.assertRaises(ValueError):
paddle.optimizer.lr.OneCycleLR(
max_learning_rate=0.1, total_steps=20, end_learning_rate=-1)
# check type of total_steps
with self.assertRaises(TypeError):
paddle.optimizer.lr.OneCycleLR(
max_learning_rate=0.1, total_steps='test')
# check value of total_steps
with self.assertRaises(ValueError):
paddle.optimizer.lr.OneCycleLR(
max_learning_rate=0.1, total_steps=-10)
# check value of anneal_strategy
with self.assertRaises(ValueError):
paddle.optimizer.lr.OneCycleLR(
max_learning_rate=0.1, total_steps=20, anneal_strategy='test')
# check value of phase_pct when three_phase is True
with self.assertRaises(ValueError):
paddle.optimizer.lr.OneCycleLR(
max_learning_rate=0.1,
total_steps=20,
phase_pct=0.6,
three_phase=True)
# check type of max_learning_rate
with self.assertRaises(TypeError):
paddle.optimizer.lr.CyclicLR(
base_learning_rate=0.5,
max_learning_rate='test',
step_size_up=10)
# check value of max_learning_rate
with self.assertRaises(ValueError):
paddle.optimizer.lr.CyclicLR(
Asthestarsfalll marked this conversation as resolved.
Show resolved Hide resolved
base_learning_rate=0.5, max_learning_rate=-1, step_size_up=10)
# check type of step_size_up
with self.assertRaises(TypeError):
paddle.optimizer.lr.CyclicLR(
base_learning_rate=0.5,
max_learning_rate=1.0,
step_size_up='test')
# check value of step_size_up
with self.assertRaises(ValueError):
paddle.optimizer.lr.CyclicLR(
base_learning_rate=0.5, max_learning_rate=1.0, step_size_up=-1)
# check type of step_size_down
with self.assertRaises(TypeError):
paddle.optimizer.lr.CyclicLR(
base_learning_rate=0.5,
max_learning_rate=1.0,
step_size_up=500,
step_size_down='test')
# check type of step_size_down
with self.assertRaises(ValueError):
Asthestarsfalll marked this conversation as resolved.
Show resolved Hide resolved
paddle.optimizer.lr.CyclicLR(
base_learning_rate=0.5,
max_learning_rate=1.0,
step_size_up=500,
step_size_down=-1)
# check value of mode
with self.assertRaises(ValueError):
paddle.optimizer.lr.CyclicLR(
base_learning_rate=0.5,
max_learning_rate=1.0,
step_size_up=500,
step_size_down=500,
mode='test')
# check type value of scale_mode
with self.assertRaises(ValueError):
paddle.optimizer.lr.CyclicLR(
base_learning_rate=0.5,
max_learning_rate=1.0,
step_size_up=500,
step_size_down=-1,
scale_mode='test')

func_api_kwargs = [(noam_lr, paddle.optimizer.lr.NoamDecay, {
"d_model": 0.01,
Expand Down Expand Up @@ -650,6 +755,56 @@ def test_scheduler(self):
"anneal_strategy": 'linear',
"phase_pct": 0.2,
"three_phase": True,
}),(cyclic_lr, paddle.optimizer.lr.CyclicLR, {
"base_learning_rate": 0.5,
"max_learning_rate": 1.0,
"step_size_up": 15,
"step_size_down": 5,
"mode": 'triangular',
"exp_gamma": 1.,
"scale_fn": None,
"scale_mode": 'cycle',
"verbose": False
}), (cyclic_lr, paddle.optimizer.lr.CyclicLR, {
"base_learning_rate": 0.5,
"max_learning_rate": 1.0,
"step_size_up": 15,
"step_size_down": 5,
"mode": 'triangular2',
"exp_gamma": 1.,
"scale_fn": None,
"scale_mode": 'cycle',
"verbose": False
}), (cyclic_lr, paddle.optimizer.lr.CyclicLR, {
"base_learning_rate": 0.5,
"max_learning_rate": 1.0,
"step_size_up": 15,
"step_size_down": 5,
"mode": 'exp_range',
"exp_gamma": 0.8,
"scale_fn": None,
"scale_mode": 'cycle',
"verbose": False
}), (cyclic_lr, paddle.optimizer.lr.CyclicLR, {
"base_learning_rate": 0.5,
"max_learning_rate": 1.0,
"step_size_up": 15,
"step_size_down": 5,
"mode": 'exp_range',
"exp_gamma": 1.,
"scale_fn": lambda x: 0.95**x,
"scale_mode": 'cycle',
"verbose": False
}), (cyclic_lr, paddle.optimizer.lr.CyclicLR, {
"base_learning_rate": 0.5,
"max_learning_rate": 1.0,
"step_size_up": 15,
"step_size_down": 5,
"mode": 'exp_range',
"exp_gamma": 1.,
"scale_fn": lambda x: 0.95,
"scale_mode": 'iterations',
"verbose": False
})]

for python_func, paddle_api, kwarg in func_api_kwargs:
Expand Down