Add LightningModule.lr_scheduler_step (#10249)
Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
rohitgr7 and carmocca committed Jan 12, 2022
1 parent ba71937 commit 82c8875
Showing 15 changed files with 246 additions and 62 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -139,11 +139,11 @@ ENV/
.data/
Datasets/
mnist/
MNIST/
legacy/checkpoints/
*.gz
*ubyte


# pl tests
ml-runs/
mlruns/
4 changes: 3 additions & 1 deletion CHANGELOG.md
@@ -63,13 +63,15 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added a `PrecisionPlugin.teardown` method ([#10990](https://github.com/PyTorchLightning/pytorch-lightning/issues/10990))


- Added `LightningModule.lr_scheduler_step` ([#10249](https://github.com/PyTorchLightning/pytorch-lightning/pull/10249))


- Added `opt_idx` to scheduler config if not assigned by user ([#11247](https://github.com/PyTorchLightning/pytorch-lightning/issues/11247))


- Added a `MisconfigurationException` if the `opt_idx` provided by the user in the scheduler config doesn't match the index of its respective optimizer ([#11247](https://github.com/PyTorchLightning/pytorch-lightning/issues/11247))



### Changed

- Raised exception in `init_dist_connection()` when torch distributed is not available ([#10418](https://github.com/PyTorchLightning/pytorch-lightning/issues/10418))
26 changes: 24 additions & 2 deletions docs/source/common/optimizers.rst
@@ -518,9 +518,31 @@ to perform a step, Lightning won't be able to support accelerators, precision an
optimizer.step(closure=optimizer_closure)


***************************
Bring your own Custom Learning Rate Schedulers
==============================================

Lightning allows using custom learning rate schedulers that aren't available in `PyTorch natively <https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate>`_.
One good example is the `Timm schedulers <https://github.com/rwightman/pytorch-image-models/blob/master/timm/scheduler/scheduler.py>`_. When using a custom learning rate scheduler
that relies on an API different from PyTorch's native schedulers, you should override :meth:`~pytorch_lightning.core.lightning.LightningModule.lr_scheduler_step` with your desired logic.
If you are using native PyTorch schedulers, there is no need to override this hook since Lightning handles it automatically by default.

.. code-block:: python

    from timm.scheduler import TanhLRScheduler


    def configure_optimizers(self):
        optimizer = ...
        scheduler = TanhLRScheduler(optimizer, ...)
        return [optimizer], [{"scheduler": scheduler, "interval": "epoch"}]


    def lr_scheduler_step(self, scheduler, optimizer_idx, metric):
        scheduler.step(epoch=self.current_epoch)  # timm's schedulers need the epoch value

Configure Gradient Clipping
***************************
===========================

To configure custom gradient clipping, consider overriding
the :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_gradient_clipping` method.
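
As a minimal sketch (illustrative, not part of this change), such an override might reuse the ``clip_gradients`` helper shown further below in this diff and, for example, clip only the first optimizer's gradients:

.. code-block:: python

    def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_val, gradient_clip_algorithm):
        if optimizer_idx == 0:
            # clip gradients for the first optimizer only, using the built-in helper
            self.clip_gradients(
                optimizer, gradient_clip_val=gradient_clip_val, gradient_clip_algorithm=gradient_clip_algorithm
            )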
38 changes: 37 additions & 1 deletion pytorch_lightning/core/lightning.py
@@ -53,7 +53,7 @@
from pytorch_lightning.utilities.model_summary import ModelSummary, summarize
from pytorch_lightning.utilities.parsing import collect_init_args
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from pytorch_lightning.utilities.types import _METRIC_COLLECTION, EPOCH_OUTPUT, STEP_OUTPUT
from pytorch_lightning.utilities.types import _METRIC_COLLECTION, EPOCH_OUTPUT, LRSchedulerTypeUnion, STEP_OUTPUT
from pytorch_lightning.utilities.warnings import WarningCache

warning_cache = WarningCache()
@@ -1493,6 +1493,42 @@ def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_va
optimizer, gradient_clip_val=gradient_clip_val, gradient_clip_algorithm=gradient_clip_algorithm
)

def lr_scheduler_step(
    self,
    scheduler: LRSchedulerTypeUnion,
    optimizer_idx: int,
    metric: Optional[Any],
) -> None:
    r"""
    Override this method to adjust the default way the
    :class:`~pytorch_lightning.trainer.trainer.Trainer` calls each scheduler.
    By default, Lightning calls ``step()`` on each scheduler, as shown in the example below,
    based on its ``interval``.

    Args:
        scheduler: Learning rate scheduler.
        optimizer_idx: Index of the optimizer associated with this scheduler.
        metric: Value of the monitor used for schedulers like ``ReduceLROnPlateau``.

    Examples::

        # DEFAULT
        def lr_scheduler_step(self, scheduler, optimizer_idx, metric):
            if metric is None:
                scheduler.step()
            else:
                scheduler.step(metric)

        # Alternative way to update schedulers if the scheduler requires an epoch value
        def lr_scheduler_step(self, scheduler, optimizer_idx, metric):
            scheduler.step(epoch=self.current_epoch)
    """
    if metric is None:
        scheduler.step()
    else:
        scheduler.step(metric)

def optimizer_step(
self,
epoch: int,
32 changes: 27 additions & 5 deletions pytorch_lightning/core/optimizer.py
@@ -23,6 +23,8 @@
import pytorch_lightning as pl
from pytorch_lightning.utilities import AMPType, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import _SupportsStateDict, LRSchedulerTypeTuple


def do_nothing_closure() -> None:
@@ -168,7 +170,9 @@ def closure_dis():
trainer.strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)


def _init_optimizers_and_lr_schedulers(model: "pl.LightningModule") -> Tuple[List, List, List]:
def _init_optimizers_and_lr_schedulers(
model: "pl.LightningModule",
) -> Tuple[List[Optimizer], List[Dict[str, Any]], List[int]]:
"""Calls `LightningModule.configure_optimizers` and parses and validates the output."""
model.trainer._lightning_optimizers = None
optim_conf = model.trainer._call_lightning_module_hook("configure_optimizers", pl_module=model)
@@ -185,6 +189,7 @@ def _init_optimizers_and_lr_schedulers(model: "pl.LightningModule") -> Tuple[Lis
)
lr_schedulers = _configure_schedulers(lr_schedulers, monitor)
_set_scheduler_opt_idx(optimizers, lr_schedulers)
_validate_scheduler_api(lr_schedulers, model)
return optimizers, lr_schedulers, optimizer_frequencies


@@ -298,10 +303,9 @@ def _configure_schedulers_automatic_opt(schedulers: list, monitor: Optional[str]
lr_schedulers.append(
{**default_config, "scheduler": scheduler, "reduce_on_plateau": True, "monitor": monitor}
)
elif isinstance(scheduler, optim.lr_scheduler._LRScheduler):
lr_schedulers.append({**default_config, "scheduler": scheduler})
else:
raise ValueError(f'The provided lr scheduler "{scheduler}" is invalid')
lr_schedulers.append({**default_config, "scheduler": scheduler})

return lr_schedulers


@@ -325,9 +329,27 @@ def _configure_schedulers_manual_opt(schedulers: list, monitor: Optional[str]) -
lr_schedulers.append({**default_config, **scheduler})
else:
lr_schedulers.append({**default_config, "scheduler": scheduler})

return lr_schedulers


def _validate_scheduler_api(lr_schedulers: List[Dict[str, Any]], model: "pl.LightningModule") -> None:
for scheduler_config in lr_schedulers:
scheduler = scheduler_config["scheduler"]
if not isinstance(scheduler, _SupportsStateDict):
raise TypeError(
f"The provided lr scheduler `{scheduler.__class__.__name__}` is invalid."
" It should have `state_dict` and `load_state_dict` methods defined."
)

if not isinstance(scheduler, LRSchedulerTypeTuple) and not is_overridden("lr_scheduler_step", model):
raise MisconfigurationException(
f"The provided lr scheduler `{scheduler.__class__.__name__}` doesn't follow PyTorch's LRScheduler"
" API. You should override the `LightningModule.lr_scheduler_step` hook with your own logic if"
" you are using a custom LR scheduler."
)
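
# For reference, a minimal sketch of a scheduler that this validation accepts (illustrative only,
# not part of this commit; the class name and decay rule are hypothetical):
class MultiplicativeDecay:
    """Toy scheduler: multiplies every param group's lr by `factor` on each step() call."""

    def __init__(self, optimizer, factor=0.95):
        self.optimizer = optimizer
        self.factor = factor

    def step(self):
        for group in self.optimizer.param_groups:
            group["lr"] *= self.factor

    # These two methods satisfy the `_SupportsStateDict` check above; without them a TypeError is raised
    def state_dict(self):
        return {"factor": self.factor}

    def load_state_dict(self, state_dict):
        self.factor = state_dict["factor"]

# Because `MultiplicativeDecay` is not an instance of `LRSchedulerTypeTuple`, the user's LightningModule
# must also override `lr_scheduler_step`, otherwise the `MisconfigurationException` above is raised.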


def _get_default_scheduler_config() -> Dict[str, Any]:
return {
"scheduler": None,
@@ -341,7 +363,7 @@ def _get_default_scheduler_config() -> Dict[str, Any]:
}


def _set_scheduler_opt_idx(optimizers: List[Any], lr_schedulers: List[Any]) -> None:
def _set_scheduler_opt_idx(optimizers: List[Optimizer], lr_schedulers: List[Dict[str, Any]]) -> None:
for sch in lr_schedulers:

for opt_idx, opt in enumerate(optimizers):
11 changes: 6 additions & 5 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -454,11 +454,12 @@ def _update_learning_rates(
self.scheduler_progress.increment_ready()

# update LR
if lr_scheduler["reduce_on_plateau"]:
lr_scheduler["scheduler"].step(monitor_val)
else:
lr_scheduler["scheduler"].step()

self.trainer._call_lightning_module_hook(
"lr_scheduler_step",
lr_scheduler["scheduler"],
lr_scheduler["opt_idx"],
monitor_val,
)
self.scheduler_progress.increment_completed()

def _get_monitor_value(self, key: str) -> Any:
7 changes: 3 additions & 4 deletions pytorch_lightning/strategies/deepspeed.py
@@ -24,7 +24,6 @@
import torch
from torch.nn import Module
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler

import pytorch_lightning as pl
from pytorch_lightning.core.optimizer import _get_default_scheduler_config, _init_optimizers_and_lr_schedulers
@@ -41,7 +40,7 @@
from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import _PATH, LRSchedulerTypeTuple, STEP_OUTPUT
from pytorch_lightning.utilities.types import _PATH, LRSchedulerConfig, LRSchedulerTypeUnion, STEP_OUTPUT
from pytorch_lightning.utilities.warnings import rank_zero_warn, WarningCache

warning_cache = WarningCache()
@@ -399,7 +398,7 @@ def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]
return self.model, [optimizer]

def _setup_model_and_optimizer(
self, model: Module, optimizer: Optimizer, lr_scheduler: Optional[_LRScheduler] = None
self, model: Module, optimizer: Optimizer, lr_scheduler: Optional[LRSchedulerTypeUnion] = None
):
"""Initialize one model and one optimizer with an optional learning rate scheduler.
@@ -445,7 +444,7 @@ def init_deepspeed(self):
else:
self._initialize_deepspeed_inference(model)

def _init_optimizers(self) -> Tuple[Optimizer, Optional[Union[LRSchedulerTypeTuple]], Optional[int]]:
def _init_optimizers(self) -> Tuple[Optimizer, Optional[List[LRSchedulerConfig]], Optional[int]]:
optimizers, schedulers, optimizer_frequencies = _init_optimizers_and_lr_schedulers(self.lightning_module)
if len(optimizers) > 1 or len(schedulers) > 1:
raise MisconfigurationException(
4 changes: 1 addition & 3 deletions pytorch_lightning/strategies/horovod.py
@@ -17,7 +17,6 @@
import torch
import torch.nn as nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler

import pytorch_lightning as pl
from pytorch_lightning.core.optimizer import LightningOptimizer
@@ -105,8 +104,7 @@ def _unpack_lightning_optimizer(opt):
lr_schedulers = self.lightning_module.trainer.lr_schedulers
for scheduler in lr_schedulers:
scheduler = scheduler["scheduler"]
if isinstance(scheduler, _LRScheduler):
scheduler.base_lrs = [lr * self.world_size for lr in scheduler.base_lrs]
scheduler.base_lrs = [lr * self.world_size for lr in scheduler.base_lrs]

# Horovod: broadcast parameters & optimizer state to ensure consistent initialization
hvd.broadcast_parameters(self.lightning_module.state_dict(), root_rank=0)
@@ -43,6 +43,7 @@ class _LogOptions(TypedDict):
"optimizer_step": _LogOptions(
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
),
"lr_scheduler_step": None,
"on_before_zero_grad": _LogOptions(
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
),
14 changes: 1 addition & 13 deletions pytorch_lightning/utilities/auto_restart.py
@@ -36,13 +36,13 @@
DataLoader,
IterableDataset,
)
from typing_extensions import Protocol, runtime_checkable

import pytorch_lightning as pl
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.distributed import _collect_states_on_rank_zero
from pytorch_lightning.utilities.enums import _FaultTolerantMode, AutoRestartBatchKeys
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import _SupportsStateDict


class FastForwardSampler(Sampler):
@@ -576,7 +576,6 @@ def _reload_dataloader_state_dict_automatic(dataloader: DataLoader, state_dict:
def _reload_dataloader_state_dict_manual(dataloader: DataLoader, state_dict: Dict[str, Any]) -> None:
# In manual mode, we don't wrap the user objects with `CaptureMapDataset` or `CaptureIterableDataset`
# therefore, we need to reload the states manually.

latest_worker_id = state_dict["latest_worker_id"]
num_workers = state_dict["state"][latest_worker_id]["num_workers"]
sampler_state = state_dict["state"][latest_worker_id].get("sampler_state", None)
@@ -635,17 +634,6 @@ def _rotate_worker_indices(state: Dict[int, Any], latest_worker_id: int, num_wor
return {new_id: state[old_id] for old_id, new_id in old_to_new_worker_id_map if old_id in state}


@runtime_checkable
class _SupportsStateDict(Protocol):
"""This class is used to detect if an object is stateful using `isinstance(obj, _SupportsStateDict)`."""

def state_dict(self) -> Dict[str, Any]:
...

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
...


class _StatefulDataLoaderIter:
"""This mixin is used to make PyTorch DataLoaderIter stateful."""

39 changes: 13 additions & 26 deletions pytorch_lightning/utilities/types.py
@@ -23,7 +23,7 @@
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from torchmetrics import Metric
from typing_extensions import TypedDict
from typing_extensions import Protocol, runtime_checkable, TypedDict

_NUMBER = Union[int, float]
_METRIC = Union[Metric, torch.Tensor, _NUMBER]
@@ -46,33 +46,29 @@
EVAL_DATALOADERS = Union[DataLoader, Sequence[DataLoader]]


# Copied from `torch.optim.lr_scheduler.pyi`
# Missing attributes were added to improve typing
class _LRScheduler:
optimizer: Optimizer
@runtime_checkable
class _SupportsStateDict(Protocol):
"""This class is used to detect if an object is stateful using `isinstance(obj, _SupportsStateDict)`."""

def __init__(self, optimizer: Optimizer, last_epoch: int = ...) -> None:
def state_dict(self) -> Dict[str, Any]:
...

def state_dict(self) -> dict:
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
...

def load_state_dict(self, state_dict: dict) -> None:
...

def get_last_lr(self) -> List[float]:
...

def get_lr(self) -> float:
...
# Inferred from `torch.optim.lr_scheduler.pyi`
# Missing attributes were added to improve typing
class _LRScheduler(_SupportsStateDict):
optimizer: Optimizer

def step(self, epoch: Optional[int] = ...) -> None:
def __init__(self, optimizer: Optimizer, *args: Any, **kwargs: Any) -> None:
...


# Copied from `torch.optim.lr_scheduler.pyi`
# Inferred from `torch.optim.lr_scheduler.pyi`
# Missing attributes were added to improve typing
class ReduceLROnPlateau:
class ReduceLROnPlateau(_SupportsStateDict):
in_cooldown: bool
optimizer: Optimizer

@@ -91,15 +87,6 @@ def __init__(
) -> None:
...

def step(self, metrics: Any, epoch: Optional[int] = ...) -> None:
...

def state_dict(self) -> dict:
...

def load_state_dict(self, state_dict: dict) -> None:
...


# todo: improve LRSchedulerType naming/typing
LRSchedulerTypeTuple = (torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)
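
# Standalone sketch (not part of this diff) of why the protocol is marked `runtime_checkable`:
# `isinstance` checks become structural, so any object defining both methods qualifies without
# inheriting from anything. The `Stateful` class below is hypothetical.
from typing import Any, Dict

from typing_extensions import Protocol, runtime_checkable


@runtime_checkable
class _SupportsStateDict(Protocol):
    def state_dict(self) -> Dict[str, Any]:
        ...

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        ...


class Stateful:
    def state_dict(self) -> Dict[str, Any]:
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        pass


assert isinstance(Stateful(), _SupportsStateDict)    # passes: structural match, no inheritance needed
assert not isinstance(object(), _SupportsStateDict)  # fails the duck-typing check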
5 changes: 5 additions & 0 deletions tests/models/test_hooks.py
@@ -326,6 +326,11 @@ def _auto_train_batch(trainer, model, batches, device=torch.device("cpu"), curre
args=(current_epoch, i, ANY, 0, ANY),
kwargs=dict(on_tpu=False, using_lbfgs=False, using_native_amp=using_native_amp),
),
*(
[dict(name="lr_scheduler_step", args=(ANY, 0, None))]
if i == (trainer.num_training_batches - 1)
else []
),
dict(name="Callback.on_train_batch_end", args=(trainer, model, dict(loss=ANY), ANY, i)),
dict(name="on_train_batch_end", args=(dict(loss=ANY), ANY, i)),
dict(name="Callback.on_batch_end", args=(trainer, model)),
