Add LightningModule.lr_scheduler_step #10249

Merged: 40 commits merged into master from enhance/scheduler_step on Jan 12, 2022
Commits: 40 total (the diff below shows changes from 20 commits)
3e10985
add LightningModule.scheduler_step
rohitgr7 Oct 29, 2021
9b07fa2
add tests
rohitgr7 Oct 30, 2021
a718c84
update types
rohitgr7 Oct 30, 2021
39059e5
docs
rohitgr7 Oct 30, 2021
3c66768
update .gitignore
rohitgr7 Oct 30, 2021
e437242
chlog
rohitgr7 Oct 30, 2021
fc8bc16
mypy
rohitgr7 Oct 30, 2021
b4dd1d8
remove step
rohitgr7 Dec 18, 2021
18e6bb4
add protocol api
rohitgr7 Dec 18, 2021
d7bdd0e
update
rohitgr7 Dec 18, 2021
ec2aa5d
add more test
rohitgr7 Dec 18, 2021
555c49f
use extensions
rohitgr7 Dec 18, 2021
5e8d371
register_hook
rohitgr7 Dec 18, 2021
f6b3e10
address reviews
rohitgr7 Dec 20, 2021
3c095ce
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 21, 2021
b01d2cf
fix and rebase
rohitgr7 Dec 28, 2021
78ebd31
mypy
rohitgr7 Dec 28, 2021
ff11e76
try fix mypy
rohitgr7 Jan 3, 2022
f8de4d0
try fix mypy
rohitgr7 Jan 3, 2022
404ba6b
try fix mypy
rohitgr7 Jan 3, 2022
013f9ce
use existing state_dict protocol
rohitgr7 Jan 3, 2022
78c8133
Merge branch 'master' into enhance/scheduler_step
rohitgr7 Jan 3, 2022
65bda5f
update import
rohitgr7 Jan 3, 2022
26182db
small updates
rohitgr7 Jan 4, 2022
54e2af9
Merge branch 'master' into enhance/scheduler_step
rohitgr7 Jan 4, 2022
b4fb944
add edge case check
rohitgr7 Jan 4, 2022
af8b1c3
rebase
rohitgr7 Jan 4, 2022
4c8ada6
avoid protocol
rohitgr7 Jan 5, 2022
08be795
Merge branch 'master' into enhance/scheduler_step
rohitgr7 Jan 7, 2022
c497bdf
move to types
rohitgr7 Jan 7, 2022
f1553ee
Inherit from the state dict protocol
carmocca Jan 8, 2022
99c92a5
All positional, optimizer index always int
carmocca Jan 8, 2022
ae8ae09
Simplify tests
carmocca Jan 8, 2022
236b55d
Minor test changes
carmocca Jan 8, 2022
7e82d1d
simplify test
rohitgr7 Jan 8, 2022
43532fd
one line
rohitgr7 Jan 8, 2022
4ec0e5c
Reduce further, test calls
carmocca Jan 8, 2022
b55504f
use typeerror
rohitgr7 Jan 10, 2022
281b0ef
Merge remote-tracking branch 'origin/master' into enhance/scheduler_step
rohitgr7 Jan 11, 2022
16797dd
Merge branch 'master' into enhance/scheduler_step
carmocca Jan 12, 2022
2 changes: 1 addition & 1 deletion .gitignore
@@ -142,7 +142,7 @@ mnist/
legacy/checkpoints/
*.gz
*ubyte

MNIST/

# pl tests
ml-runs/
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -58,6 +58,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added a `PrecisionPlugin.teardown` method ([#10990](https://github.com/PyTorchLightning/pytorch-lightning/issues/10990))


- Added `LightningModule.lr_scheduler_step` ([#10249](https://github.com/PyTorchLightning/pytorch-lightning/pull/10249))


### Changed

23 changes: 23 additions & 0 deletions docs/source/common/optimizers.rst
@@ -252,6 +252,29 @@ If you want to call schedulers that require a metric value after each epoch, con

-----

Bring your own Custom Learning Rate Schedulers
----------------------------------------------
Lightning allows using custom learning rate schedulers that aren't available in `PyTorch natively <https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate>`_.
A good example is the `Timm schedulers <https://github.com/rwightman/pytorch-image-models/blob/master/timm/scheduler/scheduler.py>`_. When using a custom learning rate scheduler
whose API differs from the native PyTorch schedulers, override :meth:`~pytorch_lightning.core.lightning.LightningModule.lr_scheduler_step` with your desired logic, as shown below.
If you are using native PyTorch schedulers, there is no need to override this hook since Lightning handles them automatically by default.

.. code-block:: python

    from timm.scheduler import TanhLRScheduler


    def configure_optimizers(self):
        optimizer = ...
        scheduler = TanhLRScheduler(optimizer, ...)
        return [optimizer], [{"scheduler": scheduler, "interval": "epoch"}]


    def lr_scheduler_step(self, scheduler, optimizer_idx, metric=None):
        scheduler.step(epoch=self.current_epoch)  # timm's scheduler needs the epoch value
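
For reference, this is what the default hook does when it is not overridden (it mirrors the implementation added in this PR), which is why native PyTorch schedulers such as ``ReduceLROnPlateau`` keep working without any changes:

.. code-block:: python

    # Default implementation used by Lightning when lr_scheduler_step is not overridden
    def lr_scheduler_step(self, scheduler, optimizer_idx, metric=None):
        if metric is None:
            scheduler.step()
        else:
            scheduler.step(metric)  # e.g. ReduceLROnPlateau receives the monitored metric value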

-----

Use closure for LBFGS-like optimizers
-------------------------------------
It is a good practice to provide the optimizer with a closure function that performs a ``forward``, ``zero_grad`` and
38 changes: 37 additions & 1 deletion pytorch_lightning/core/lightning.py
@@ -53,7 +53,7 @@
from pytorch_lightning.utilities.model_summary import ModelSummary, summarize
from pytorch_lightning.utilities.parsing import collect_init_args
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from pytorch_lightning.utilities.types import _METRIC_COLLECTION, EPOCH_OUTPUT, STEP_OUTPUT
from pytorch_lightning.utilities.types import _METRIC_COLLECTION, EPOCH_OUTPUT, LRSchedulerTypeUnion, STEP_OUTPUT
from pytorch_lightning.utilities.warnings import WarningCache

warning_cache = WarningCache()
@@ -1493,6 +1493,42 @@ def configure_gradient_clipping(optimizer, optimizer_idx, gradient_clip_va
optimizer, gradient_clip_val=gradient_clip_val, gradient_clip_algorithm=gradient_clip_algorithm
)

def lr_scheduler_step(
    self,
    scheduler: LRSchedulerTypeUnion,
    optimizer_idx: Optional[int] = None,
    metric: Optional[Union[float, torch.Tensor]] = None,
) -> None:
    r"""
    Override this method to adjust the default way the
    :class:`~pytorch_lightning.trainer.trainer.Trainer` calls each scheduler.
    By default, Lightning calls ``step()`` on each scheduler as shown in the example below,
    according to its ``interval``.

    Args:
        scheduler: Learning rate scheduler.
        optimizer_idx: Index of the optimizer associated with this scheduler.
        metric: Value of the metric used for schedulers like ``ReduceLROnPlateau``.

    Examples::

        # DEFAULT
        def lr_scheduler_step(self, scheduler, optimizer_idx, metric):
            if metric is None:
                scheduler.step()
            else:
                scheduler.step(metric)

        # Alternative way to update schedulers that require an epoch value
        def lr_scheduler_step(self, scheduler, optimizer_idx, metric):
            scheduler.step(epoch=self.current_epoch)

    """
    if metric is None:
        scheduler.step()
    else:
        scheduler.step(metric)

def optimizer_step(
self,
epoch: int,
35 changes: 30 additions & 5 deletions pytorch_lightning/core/optimizer.py
@@ -19,6 +19,7 @@
import torch
from torch import optim
from torch.optim import Optimizer
from typing_extensions import Protocol, runtime_checkable

import pytorch_lightning as pl
from pytorch_lightning.utilities import AMPType, rank_zero_warn
@@ -168,7 +169,9 @@ def closure_dis():
trainer.strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)


def _init_optimizers_and_lr_schedulers(model: "pl.LightningModule") -> Tuple[List, List, List]:
def _init_optimizers_and_lr_schedulers(
model: "pl.LightningModule",
) -> Tuple[List[Optimizer], List[Dict[str, Any]], List[int]]:
"""Calls `LightningModule.configure_optimizers` and parses and validates the output."""
model.trainer._lightning_optimizers = None
optim_conf = model.trainer._call_lightning_module_hook("configure_optimizers", pl_module=model)
@@ -298,10 +301,13 @@ def _configure_schedulers_automatic_opt(schedulers: list, monitor: Optional[str]
lr_schedulers.append(
{**default_config, "scheduler": scheduler, "reduce_on_plateau": True, "monitor": monitor}
)
elif isinstance(scheduler, optim.lr_scheduler._LRScheduler):
lr_schedulers.append({**default_config, "scheduler": scheduler})
else:
raise ValueError(f'The provided lr scheduler "{scheduler}" is invalid')
lr_schedulers.append({**default_config, "scheduler": scheduler})

current_scheduler = lr_schedulers[-1]["scheduler"]
if not isinstance(current_scheduler, _SupportedLRScheduler):
raise ValueError(f"The provided lr scheduler `{current_scheduler.__class__.__name__}` is invalid.")

return lr_schedulers


@@ -325,6 +331,11 @@ def _configure_schedulers_manual_opt(schedulers: list, monitor: Optional[str]) -
lr_schedulers.append({**default_config, **scheduler})
else:
lr_schedulers.append({**default_config, "scheduler": scheduler})

current_scheduler = lr_schedulers[-1]["scheduler"]
if not isinstance(current_scheduler, _SupportedLRScheduler):
raise ValueError(f"The provided lr scheduler `{current_scheduler.__class__.__name__}` is invalid.")

return lr_schedulers


@@ -341,7 +352,7 @@ def _get_default_scheduler_config() -> Dict[str, Any]:
}


def _validate_scheduler_optimizer(optimizers: List[Any], lr_schedulers: List[Any]) -> None:
def _validate_scheduler_optimizer(optimizers: List[Optimizer], lr_schedulers: List[Dict[str, Any]]) -> None:
if any(sch["scheduler"].optimizer not in optimizers for sch in lr_schedulers):
raise MisconfigurationException(
"Some schedulers are attached with an optimizer that wasn't returned from `configure_optimizers`."
@@ -394,3 +405,17 @@ def zero_grad(self, set_to_none: Optional[bool] = False) -> None:

def __repr__(self) -> str:
return "No Optimizer"


@runtime_checkable
class _SupportedLRScheduler(Protocol):
    """This class is used to detect whether an object supports the LR scheduler API (``step``, ``state_dict``,
    ``load_state_dict``) using ``isinstance(obj, _SupportedLRScheduler)``."""

    def step(self, *args: Any, **kwargs: Any) -> None:
        ...

    def state_dict(self) -> Dict[str, Any]:
        ...

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        ...
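
A minimal sketch of how this runtime-checkable protocol behaves. `MyTimmStyleScheduler` below is a hypothetical custom scheduler, not part of this PR, and note that `isinstance` against a `runtime_checkable` protocol only checks that the three methods exist, not their signatures:

from typing import Any, Dict

from typing_extensions import Protocol, runtime_checkable


@runtime_checkable
class _SupportedLRScheduler(Protocol):
    def step(self, *args: Any, **kwargs: Any) -> None:
        ...

    def state_dict(self) -> Dict[str, Any]:
        ...

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        ...


class MyTimmStyleScheduler:  # hypothetical scheduler with a non-native step(epoch) signature
    def __init__(self) -> None:
        self.last_epoch = -1

    def step(self, epoch: int) -> None:
        self.last_epoch = epoch

    def state_dict(self) -> Dict[str, Any]:
        return {"last_epoch": self.last_epoch}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.last_epoch = state_dict["last_epoch"]


assert isinstance(MyTimmStyleScheduler(), _SupportedLRScheduler)  # structural check passes
assert not isinstance(object(), _SupportedLRScheduler)  # rejected: missing the required methods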
11 changes: 6 additions & 5 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -503,11 +503,12 @@ def _update_learning_rates(
self.scheduler_progress.increment_ready()

# update LR
if lr_scheduler["reduce_on_plateau"]:
lr_scheduler["scheduler"].step(monitor_val)
else:
lr_scheduler["scheduler"].step()

self.trainer._call_lightning_module_hook(
"lr_scheduler_step",
lr_scheduler["scheduler"],
optimizer_idx=lr_scheduler["opt_idx"],
metric=monitor_val,
)
self.scheduler_progress.increment_completed()

def _get_monitor_value(self, key: str) -> Any:
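
In effect, the training loop no longer special-cases ``ReduceLROnPlateau``: every scheduler update is routed through the ``lr_scheduler_step`` hook, so a user override transparently replaces the default stepping logic. A rough plain-Python sketch of the dispatch (the real ``_call_lightning_module_hook`` also handles profiling and hook bookkeeping, so this is only an approximation):

# Approximate equivalent of the loop change above, not the actual Trainer internals.
def _step_scheduler(lightning_module, lr_scheduler_config, monitor_val):
    # Calls the user's override if present, otherwise LightningModule's default,
    # which steps native PyTorch schedulers (passing the metric when one is monitored).
    lightning_module.lr_scheduler_step(
        lr_scheduler_config["scheduler"],
        optimizer_idx=lr_scheduler_config["opt_idx"],
        metric=monitor_val,
    )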
7 changes: 3 additions & 4 deletions pytorch_lightning/strategies/deepspeed.py
@@ -24,7 +24,6 @@
import torch
from torch.nn import Module
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler

import pytorch_lightning as pl
from pytorch_lightning.core.optimizer import _get_default_scheduler_config, _init_optimizers_and_lr_schedulers
@@ -41,7 +40,7 @@
from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import _PATH, LRSchedulerTypeTuple, STEP_OUTPUT
from pytorch_lightning.utilities.types import _PATH, LRSchedulerConfig, LRSchedulerTypeUnion, STEP_OUTPUT
from pytorch_lightning.utilities.warnings import rank_zero_warn, WarningCache

warning_cache = WarningCache()
@@ -399,7 +398,7 @@ def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]
return self.model, [optimizer]

def _setup_model_and_optimizer(
self, model: Module, optimizer: Optimizer, lr_scheduler: Optional[_LRScheduler] = None
self, model: Module, optimizer: Optimizer, lr_scheduler: Optional[LRSchedulerTypeUnion] = None
):
"""Initialize one model and one optimizer with an optional learning rate scheduler.

@@ -445,7 +444,7 @@ def init_deepspeed():
else:
self._initialize_deepspeed_inference(model)

def _init_optimizers(self) -> Tuple[Optimizer, Optional[Union[LRSchedulerTypeTuple]], Optional[int]]:
def _init_optimizers(self) -> Tuple[Optimizer, Optional[List[LRSchedulerConfig]], Optional[int]]:
optimizers, schedulers, optimizer_frequencies = _init_optimizers_and_lr_schedulers(self.lightning_module)
if len(optimizers) > 1 or len(schedulers) > 1:
raise MisconfigurationException(
4 changes: 1 addition & 3 deletions pytorch_lightning/strategies/horovod.py
@@ -17,7 +17,6 @@
import torch
import torch.nn as nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler

import pytorch_lightning as pl
from pytorch_lightning.core.optimizer import LightningOptimizer
@@ -105,8 +104,7 @@ def _unpack_lightning_optimizer(opt):
lr_schedulers = self.lightning_module.trainer.lr_schedulers
for scheduler in lr_schedulers:
scheduler = scheduler["scheduler"]
if isinstance(scheduler, _LRScheduler):
scheduler.base_lrs = [lr * self.world_size for lr in scheduler.base_lrs]
scheduler.base_lrs = [lr * self.world_size for lr in scheduler.base_lrs]

# Horovod: broadcast parameters & optimizer state to ensure consistent initialization
hvd.broadcast_parameters(self.lightning_module.state_dict(), root_rank=0)
@@ -43,6 +43,7 @@ class _LogOptions(TypedDict):
"optimizer_step": _LogOptions(
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
),
"lr_scheduler_step": None,
"on_before_zero_grad": _LogOptions(
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
),
13 changes: 3 additions & 10 deletions pytorch_lightning/utilities/types.py
@@ -46,12 +46,11 @@
EVAL_DATALOADERS = Union[DataLoader, Sequence[DataLoader]]


# Copied from `torch.optim.lr_scheduler.pyi`
# Missing attributes were added to improve typing
# Inferred from `torch.optim.lr_scheduler.pyi`
class _LRScheduler:
optimizer: Optimizer

def __init__(self, optimizer: Optimizer, last_epoch: int = ...) -> None:
def __init__(self, optimizer: Optimizer, *args: Any, **kwargs: Any) -> None:
...

def state_dict(self) -> dict:
@@ -60,13 +59,7 @@ def state_dict(self) -> dict:
def load_state_dict(self, state_dict: dict) -> None:
...

def get_last_lr(self) -> List[float]:
...

def get_lr(self) -> float:
...

def step(self, epoch: Optional[int] = ...) -> None:
def step(self, *args: Any, **kwargs: Any) -> None:
...


11 changes: 11 additions & 0 deletions tests/models/test_hooks.py
@@ -326,6 +326,17 @@ def _auto_train_batch(trainer, model, batches, device=torch.device("cpu"), curre
args=(current_epoch, i, ANY, 0, ANY),
kwargs=dict(on_tpu=False, using_lbfgs=False, using_native_amp=using_native_amp),
),
*(
[
dict(
name="lr_scheduler_step",
args=(ANY,),
kwargs=dict(optimizer_idx=None, metric=None),
)
]
if i == (trainer.num_training_batches - 1)
else []
),
dict(name="Callback.on_train_batch_end", args=(trainer, model, dict(loss=ANY), ANY, i)),
dict(name="on_train_batch_end", args=(dict(loss=ANY), ANY, i)),
dict(name="Callback.on_batch_end", args=(trainer, model)),
1 change: 1 addition & 0 deletions tests/trainer/logging_/test_logger_connector.py
@@ -233,6 +233,7 @@ def test_fx_validator_integration(tmpdir):
"configure_callbacks": "You can't",
"on_validation_model_eval": "You can't",
"on_validation_model_train": "You can't",
"lr_scheduler_step": "You can't",
"summarize": "not managed by the `Trainer",
}
model = HookedModel(not_supported)