SequentialLR cannot be used with ReduceLROnPlateau due to .step() not allowing for optional arguments #68978
I would like to take a jab at this if no one has already! Before I open a PR, I'd like to discuss a few proposals.
We can simply edge-case `ReduceLROnPlateau` inside `SequentialLR.step()`:

```python
# class SequentialLR
def step(self, metrics=None):
    self.last_epoch += 1
    # Pick the scheduler in charge of the current epoch.
    idx = bisect_right(self._milestones, self.last_epoch)
    scheduler = self._schedulers[idx]

    is_reduce_lr_on_plateau = isinstance(scheduler, ReduceLROnPlateau)
    is_new_scheduler = idx > 0 and self._milestones[idx - 1] == self.last_epoch

    if is_reduce_lr_on_plateau and is_new_scheduler:
        scheduler.step(metrics, 0)
    elif is_reduce_lr_on_plateau:
        scheduler.step(metrics)
    elif is_new_scheduler:
        scheduler.step(0)
    else:
        scheduler.step()
```
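For illustration, this is how such a patched `SequentialLR` could then be driven; the model, warm-up length, and stand-in metric below are made up, and the snippet assumes the edge-cased `step()` above:

```python
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import LinearLR, ReduceLROnPlateau, SequentialLR

optimizer = SGD(torch.nn.Linear(10, 1).parameters(), lr=0.1)

# Hypothetical schedule: 5 warm-up epochs, then plateau-based decay.
scheduler = SequentialLR(
    optimizer,
    schedulers=[
        LinearLR(optimizer, start_factor=0.1, total_iters=5),
        ReduceLROnPlateau(optimizer, mode='min', patience=3),
    ],
    milestones=[5],
)

for epoch in range(20):
    optimizer.step()
    val_loss = 1.0 / (epoch + 1)  # stand-in for a real validation metric
    # With the edge-cased step(), the metric is always forwarded: the warm-up
    # scheduler ignores it and ReduceLROnPlateau consumes it.
    scheduler.step(val_loss)
```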
A problem with the previous approach is that users might create custom schedulers that inherit from neither `_LRScheduler` nor `ReduceLROnPlateau` and take their own arguments in `.step()`. A more general version only assumes that `_LRScheduler` subclasses take no arguments and forwards everything it receives to the others:

```python
# class SequentialLR
def step(self, *args, **kwargs):
    self.last_epoch += 1
    # Pick the scheduler in charge of the current epoch.
    idx = bisect_right(self._milestones, self.last_epoch)
    scheduler = self._schedulers[idx]

    is_lr_scheduler = isinstance(scheduler, _LRScheduler)
    is_new_scheduler = idx > 0 and self._milestones[idx - 1] == self.last_epoch

    if is_lr_scheduler and is_new_scheduler:
        scheduler.step(0)
    elif is_lr_scheduler:
        scheduler.step()
    elif is_new_scheduler:
        scheduler.step(*args, **kwargs, epoch=0)
    else:
        scheduler.step(*args, **kwargs)
```
Given issues like #67760, #68332, and #68979, perhaps we should rewrite ReduceLROnPlateau altogether. What do you think? @jbschlosser @albanD
Hi!
An alternative is to make the base `_LRScheduler` interface flexible enough to accept extra arguments. For instance, we could add a method (say `get_get_lr`) that `step()` would call with whatever arguments it receives, and whose default implementation just delegates to `get_lr`:

```python
def get_get_lr(self, *args, **kwargs) -> List[float]:
    return self.get_lr()
```

A scheduler with arguments would simply need to overwrite `get_get_lr`. For example, `ReduceLROnPlateau` could then be rewritten as a regular `_LRScheduler`:

```python
from typing import List, Union

from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler


def lrs(optimizer: Optimizer) -> List[float]:
    return [group['lr'] for group in optimizer.param_groups]


class ReduceLROnPlateau(_LRScheduler):
    r"""Reduce learning rate when a metric has stopped improving."""

    def __init__(
        self,
        optimizer: Optimizer,
        gamma: float = 0.5,  # <= 1
        patience: int = 7,
        cooldown: int = 1,
        threshold: float = 1e-2,
        mode: str = 'minimize',  # or 'maximize'
        min_lr: Union[float, List[float]] = 1e-6,
        last_epoch: int = -1,
        verbose: bool = False,
    ):
        self.gamma = gamma
        self.patience = patience
        self.cooldown = cooldown
        self.threshold = threshold
        self.mode = mode

        if type(min_lr) is float:
            min_lr = [min_lr] * len(optimizer.param_groups)
        self.min_lrs = min_lr

        self.best = float('-inf') if self.mode == 'maximize' else float('inf')
        self.last_best = last_epoch
        self.last_reduce = last_epoch

        super().__init__(optimizer, last_epoch, verbose)

    def get_get_lr(self, last: float, *args, **kwargs) -> List[float]:
        return self.get_lr(last)

    def get_lr(self, last: float) -> List[float]:
        # Did the metric improve by more than the threshold?
        if self.mode == 'maximize':
            accept = last >= self.best * (1 + self.threshold)
        else:  # mode == 'minimize'
            accept = last <= self.best * (1 - self.threshold)

        if accept:
            self.best = last
            self.last_best = self.last_epoch
            return lrs(self.optimizer)

        # Still within patience (or cooling down): keep the current rates.
        if self.last_epoch - max(self.last_best, self.last_reduce + self.cooldown) <= self.patience:
            return lrs(self.optimizer)

        # Otherwise reduce, clipping at the minimum learning rates.
        self.last_reduce = self.last_epoch
        return [
            max(lr * self.gamma, min_lr)
            for lr, min_lr in zip(lrs(self.optimizer), self.min_lrs)
        ]
```
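The sketch above leaves the base class's side implicit: presumably `_LRScheduler.step()` would forward its arguments to `get_get_lr`, roughly as below. This is an assumption about the proposal, not actual PyTorch code, and it omits what the real `step()` does (the deprecated `epoch` argument, verbose printing, call-order bookkeeping).

```python
# class _LRScheduler (hypothetical, simplified)
def step(self, *args, **kwargs):
    self.last_epoch += 1
    # Forward whatever the caller passed; schedulers without extra arguments
    # fall back to the default get_get_lr -> get_lr path.
    values = self.get_get_lr(*args, **kwargs)
    for param_group, lr in zip(self.optimizer.param_groups, values):
        param_group['lr'] = lr
```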
Hey @francois-rozet, thanks for the reply and suggestion! It definitely makes sense. I'm personally in support of any solution that prioritizes implementation simplicity and BC. Gently pinging @albanD to see if there are any updates from the internal PyTorch team on this!
Any updates on this?
No update I'm afraid. My general stance on the LR schedulers is that they would require a much bigger (most likely BC-breaking) cleanup to get them into a position where they are reliable.
What about the schedulers is currently not "reliable"? One problem I encountered which led me to finding this existing issue is that a
Just to ping about this without opening a new issue - is there any possibility of a PR which allows for a scheduler such as ReduceLROnPlateau to be used inside SequentialLR?
Has anyone found a workaround for this? The above solutions could not help me; it says that the metric is not passed and is always None, but when using ReduceLROnPlateau alone it works.
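Until something lands upstream, one user-side workaround (a sketch only, not an official API; the warm-up length and settings below are illustrative) is to skip `SequentialLR` entirely and switch schedulers by hand at the milestone epoch:

```python
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import LinearLR, ReduceLROnPlateau

optimizer = SGD(torch.nn.Linear(4, 1).parameters(), lr=0.1)
warmup = LinearLR(optimizer, start_factor=0.1, total_iters=5)
plateau = ReduceLROnPlateau(optimizer, mode='min', patience=3)

for epoch in range(20):
    optimizer.step()
    val_loss = 1.0 / (epoch + 1)  # stand-in for a real validation metric
    if epoch < 5:                 # warm-up phase
        warmup.step()
    else:                         # plateau phase gets the metric explicitly
        plateau.step(val_loss)
```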
🐛 Bug
Currently, SequentialLR can only be used with schedulers that inherit from _LRScheduler, or in other words adhere to the way .step() is called, i.e. without any arguments. This is not the case for the built-in ReduceLROnPlateau, as it takes a metric value. Therefore, when calling SequentialLR.step() without arguments, ReduceLROnPlateau will raise an error once its milestone is reached. However, when calling SequentialLR.step(metric), SequentialLR will raise an error due to too many arguments.
Also see Lightning-AI/pytorch-lightning#10759
To Reproduce
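A minimal sketch of the failure described above (the optimizer, schedulers, and milestone are illustrative, not the reporter's original snippet):

```python
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import ConstantLR, ReduceLROnPlateau, SequentialLR

optimizer = SGD(torch.nn.Linear(4, 1).parameters(), lr=0.1)
scheduler = SequentialLR(
    optimizer,
    schedulers=[
        ConstantLR(optimizer, factor=0.5, total_iters=2),
        ReduceLROnPlateau(optimizer),
    ],
    milestones=[2],
)

for epoch in range(5):
    optimizer.step()
    val_loss = 1.0  # stand-in metric
    scheduler.step(val_loss)  # SequentialLR.step() rejects the extra argument
    # scheduler.step()        # without the metric, ReduceLROnPlateau.step()
    #                         # fails once the milestone is reached
```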
Expected behavior
SequentialLR should at least work with all internal PyTorch LR schedulers, hence allow a metric value to be passed to ReduceLROnPlateau if it is used (or ReduceLROnPlateau needs a rewrite). However, I think SequentialLR should allow for arbitrary arguments in step, which would allow for arbitrary schedulers (probably another issue in itself).
Either the user takes care of passing the right argument at the right time, or there is a built-in mechanism (if that is even possible).
cc @vincentqb @jbschlosser @albanD