From 3e10985cc56c2b9eb67bebe62b27400fa54f4052 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Fri, 29 Oct 2021 14:51:52 +0530 Subject: [PATCH 01/35] add LightningModule.scheduler_step --- pytorch_lightning/core/lightning.py | 12 ++++++++++++ .../loops/epoch/training_epoch_loop.py | 14 ++++++++------ tests/models/test_hooks.py | 11 +++++++++++ 3 files changed, 31 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 3d11a06deb841..e3acbeff9acfe 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1493,6 +1493,18 @@ def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_va optimizer, gradient_clip_val=gradient_clip_val, gradient_clip_algorithm=gradient_clip_algorithm ) + def scheduler_step( + self, + scheduler: Any, + step: int, + optimizer_idx: Optional[int] = None, + monitor_val: Optional[Union[float, torch.Tensor]] = None, + ): + if monitor_val is None: + scheduler.step() + else: + scheduler.step(monitor_val) + def optimizer_step( self, epoch: int, diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index ce9e82bc93efd..c41ed128f90c5 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -503,14 +503,16 @@ def _update_learning_rates( self.scheduler_progress.increment_ready() # update LR - if lr_scheduler["reduce_on_plateau"]: - lr_scheduler["scheduler"].step(monitor_val) - else: - lr_scheduler["scheduler"].step() - + step = self.trainer.global_step if interval == "step" else self.trainer.current_epoch + self.trainer.lightning_module.scheduler_step( + lr_scheduler["scheduler"], + step, + optimizer_idx=lr_scheduler["opt_idx"], + monitor_val=monitor_val, + ) self.scheduler_progress.increment_completed() - def _get_monitor_value(self, key: str) -> Any: + def _get_monitor_value(self, key: str) -> Optional[Union[float, torch.Tensor]]: # this is a separate method to aid in testing return self.trainer.callback_metrics.get(key) diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index acc12151739db..102b89f85b14e 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -326,6 +326,17 @@ def _auto_train_batch(trainer, model, batches, device=torch.device("cpu"), curre args=(current_epoch, i, ANY, 0, ANY), kwargs=dict(on_tpu=False, using_lbfgs=False, using_native_amp=using_native_amp), ), + *( + [ + dict( + name="scheduler_step", + args=(ANY, current_epoch), + kwargs=dict(optimizer_idx=None, monitor_val=None), + ) + ] + if i == (batches - 1) + else [] + ), dict(name="Callback.on_train_batch_end", args=(trainer, model, dict(loss=ANY), ANY, i)), dict(name="on_train_batch_end", args=(dict(loss=ANY), ANY, i)), dict(name="Callback.on_batch_end", args=(trainer, model)), From 9b07fa2c4236e92e43edcf2ba0186318c3c5e92d Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sat, 30 Oct 2021 20:18:58 +0530 Subject: [PATCH 02/35] add tests --- pytorch_lightning/core/lightning.py | 2 +- .../loops/epoch/training_epoch_loop.py | 2 +- pytorch_lightning/strategies/horovod.py | 4 +- tests/models/test_hooks.py | 4 +- tests/trainer/optimization/test_optimizers.py | 65 +++++++++++++++++++ 5 files changed, 70 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index e3acbeff9acfe..14509aaa4728c 100644 --- a/pytorch_lightning/core/lightning.py +++ 
b/pytorch_lightning/core/lightning.py @@ -1493,7 +1493,7 @@ def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_va optimizer, gradient_clip_val=gradient_clip_val, gradient_clip_algorithm=gradient_clip_algorithm ) - def scheduler_step( + def lr_scheduler_step( self, scheduler: Any, step: int, diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index c41ed128f90c5..d21cb94a65bc0 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -504,7 +504,7 @@ def _update_learning_rates( # update LR step = self.trainer.global_step if interval == "step" else self.trainer.current_epoch - self.trainer.lightning_module.scheduler_step( + self.trainer.lightning_module.lr_scheduler_step( lr_scheduler["scheduler"], step, optimizer_idx=lr_scheduler["opt_idx"], diff --git a/pytorch_lightning/strategies/horovod.py b/pytorch_lightning/strategies/horovod.py index 7d7ecbe6a7b2c..a1c34fa87b8d5 100644 --- a/pytorch_lightning/strategies/horovod.py +++ b/pytorch_lightning/strategies/horovod.py @@ -17,7 +17,6 @@ import torch import torch.nn as nn from torch.optim import Optimizer -from torch.optim.lr_scheduler import _LRScheduler import pytorch_lightning as pl from pytorch_lightning.core.optimizer import LightningOptimizer @@ -105,8 +104,7 @@ def _unpack_lightning_optimizer(opt): lr_schedulers = self.lightning_module.trainer.lr_schedulers for scheduler in lr_schedulers: scheduler = scheduler["scheduler"] - if isinstance(scheduler, _LRScheduler): - scheduler.base_lrs = [lr * self.world_size for lr in scheduler.base_lrs] + scheduler.base_lrs = [lr * self.world_size for lr in scheduler.base_lrs] # Horovod: broadcast parameters & optimizer state to ensure consistent initialization hvd.broadcast_parameters(self.lightning_module.state_dict(), root_rank=0) diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 102b89f85b14e..7ea291aa7a772 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -329,12 +329,12 @@ def _auto_train_batch(trainer, model, batches, device=torch.device("cpu"), curre *( [ dict( - name="scheduler_step", + name="lr_scheduler_step", args=(ANY, current_epoch), kwargs=dict(optimizer_idx=None, monitor_val=None), ) ] - if i == (batches - 1) + if i == (trainer.num_training_batches - 1) else [] ), dict(name="Callback.on_train_batch_end", args=(trainer, model, dict(loss=ANY), ANY, i)), diff --git a/tests/trainer/optimization/test_optimizers.py b/tests/trainer/optimization/test_optimizers.py index d2e193b59e1ea..2a0efe5c766f9 100644 --- a/tests/trainer/optimization/test_optimizers.py +++ b/tests/trainer/optimization/test_optimizers.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import math from unittest import mock import pytest @@ -661,3 +662,67 @@ def on_save_checkpoint(self, checkpoint): model.training_epoch_end = None trainer.fit(model) assert model.on_save_checkpoint_called + + +def test_lr_scheduler_step_hook(tmpdir): + class CustomEpochScheduler: + def __init__(self, optimizer): + self.optimizer = optimizer + + def step(self, epoch): + for param_group in self.optimizer.param_groups: + param_group["lr"] = param_group["lr"] / (epoch + 1) + + class CustomBoringModel(BoringModel): + def __init__(self, learning_rate): + super().__init__() + self.learning_rate = learning_rate + self.layer1 = torch.nn.Linear(32, 2) + self.layer2 = torch.nn.Linear(32, 2) + + def training_step(self, batch, batch_idx, optimizer_idx): + if optimizer_idx == 0: + output = self.layer1(batch) + else: + output = self.layer2(batch) + + return self.loss(batch, output) + + def training_epoch_end(self, *args, **kwargs): + pass + + def lr_scheduler_step(self, scheduler, step, optimizer_idx, monitor_val=None): + if optimizer_idx == 0: + assert step == self.trainer.global_step + super().lr_scheduler_step(scheduler, step, optimizer_idx, monitor_val) + if optimizer_idx == 1: + assert step == self.trainer.current_epoch + scheduler.step(epoch=step) + + def configure_optimizers(self): + opt1 = torch.optim.SGD(self.layer1.parameters(), lr=self.learning_rate) + lr_scheduler1 = {"scheduler": torch.optim.lr_scheduler.StepLR(opt1, step_size=1), "interval": "step"} + opt2 = torch.optim.SGD(self.layer2.parameters(), lr=self.learning_rate) + lr_scheduler2 = CustomEpochScheduler(opt2) + return {"optimizer": opt1, "lr_scheduler": lr_scheduler1}, { + "optimizer": opt2, + "lr_scheduler": lr_scheduler2, + } + + lr = 1e-2 + max_epochs = 3 + model = CustomBoringModel(learning_rate=lr) + trainer = Trainer( + default_root_dir=tmpdir, + enable_checkpointing=False, + logger=False, + max_epochs=max_epochs, + limit_train_batches=2, + limit_val_batches=0, + ) + trainer.fit(model) + + for param_group in trainer.optimizers[1].param_groups: + assert param_group["lr"] == lr / math.factorial(max_epochs) + + breakpoint() From a718c849038b9eb68afaca745059b95d18012baa Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sat, 30 Oct 2021 23:51:05 +0530 Subject: [PATCH 03/35] update types --- _notebooks | 2 +- pytorch_lightning/strategies/deepspeed.py | 5 +---- pytorch_lightning/trainer/trainer.py | 3 +++ pytorch_lightning/utilities/cli.py | 11 +++-------- pytorch_lightning/utilities/types.py | 3 +++ 5 files changed, 11 insertions(+), 13 deletions(-) diff --git a/_notebooks b/_notebooks index 0c325829101d5..a2fb6468112b7 160000 --- a/_notebooks +++ b/_notebooks @@ -1 +1 @@ -Subproject commit 0c325829101d5a6ebf32ed99bbf5b09badf04a59 +Subproject commit a2fb6468112b7e1dad501c3b6a17533a4adfeabc diff --git a/pytorch_lightning/strategies/deepspeed.py b/pytorch_lightning/strategies/deepspeed.py index 452f3c8e1a8b4..d6a996f58cbe2 100644 --- a/pytorch_lightning/strategies/deepspeed.py +++ b/pytorch_lightning/strategies/deepspeed.py @@ -24,7 +24,6 @@ import torch from torch.nn import Module from torch.optim import Optimizer -from torch.optim.lr_scheduler import _LRScheduler import pytorch_lightning as pl from pytorch_lightning.core.optimizer import _get_default_scheduler_config, _init_optimizers_and_lr_schedulers @@ -398,9 +397,7 @@ def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer] self._set_deepspeed_activation_checkpointing() return self.model, [optimizer] - def _setup_model_and_optimizer( - self, model: Module, 
optimizer: Optimizer, lr_scheduler: Optional[_LRScheduler] = None - ): + def _setup_model_and_optimizer(self, model: Module, optimizer: Optimizer, lr_scheduler: Optional[Any] = None): """Initialize one model and one optimizer with an optional learning rate scheduler. This calls :func:`deepspeed.initialize` internally. diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index b9ae7f2dc2036..351d21eb06e88 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -41,6 +41,8 @@ from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop from pytorch_lightning.loops.fit_loop import FitLoop from pytorch_lightning.loops.utilities import _parse_loop_limits, _reset_progress +from pytorch_lightning.utilities.cli import LR_SCHEDULER_REGISTRY +from pytorch_lightning.loops.utilities import _parse_loop_limits from pytorch_lightning.plugins import ( ApexMixedPrecisionPlugin, DDPSpawnStrategy, @@ -110,6 +112,7 @@ LRSchedulerConfig, STEP_OUTPUT, TRAIN_DATALOADERS, + LRSchedulerTypeUnion, ) from pytorch_lightning.utilities.warnings import PossibleUserWarning diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index 51fb22a301035..3f4822b9df50c 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -84,7 +84,6 @@ def __str__(self) -> str: OPTIMIZER_REGISTRY.register_classes(torch.optim, Optimizer) LR_SCHEDULER_REGISTRY = _Registry() -LR_SCHEDULER_REGISTRY.register_classes(torch.optim.lr_scheduler, torch.optim.lr_scheduler._LRScheduler) class ReduceLROnPlateau(torch.optim.lr_scheduler.ReduceLROnPlateau): @@ -202,21 +201,17 @@ def add_optimizer_args( def add_lr_scheduler_args( self, - lr_scheduler_class: Union[LRSchedulerType, Tuple[LRSchedulerType, ...]], + lr_scheduler_class: Any, nested_key: str = "lr_scheduler", link_to: str = "AUTOMATIC", ) -> None: """Adds arguments from a learning rate scheduler class to a nested key of the parser. Args: - lr_scheduler_class: Any subclass of ``torch.optim.lr_scheduler.{_LRScheduler, ReduceLROnPlateau}``. + lr_scheduler_class: Learning rate scheduler. nested_key: Name of the nested namespace to store arguments. link_to: Dot notation of a parser key to set arguments or AUTOMATIC. 
""" - if isinstance(lr_scheduler_class, tuple): - assert all(issubclass(o, LRSchedulerTypeTuple) for o in lr_scheduler_class) - else: - assert issubclass(lr_scheduler_class, LRSchedulerTypeTuple) kwargs = {"instantiate": False, "fail_untyped": False, "skip": {"optimizer"}} if isinstance(lr_scheduler_class, tuple): self.add_subclass_arguments(lr_scheduler_class, nested_key, **kwargs) @@ -749,7 +744,7 @@ def get_automatic( return automatic optimizers = get_automatic(Optimizer, parser._optimizers) - lr_schedulers = get_automatic(LRSchedulerTypeTuple, parser._lr_schedulers) + lr_schedulers = get_automatic(Any, parser._lr_schedulers) if len(optimizers) == 0: return diff --git a/pytorch_lightning/utilities/types.py b/pytorch_lightning/utilities/types.py index 44a3b88d530d6..9ec979bfd2999 100644 --- a/pytorch_lightning/utilities/types.py +++ b/pytorch_lightning/utilities/types.py @@ -21,6 +21,9 @@ import torch from torch.optim import Optimizer +from typing import Any, Dict, Iterator, List, Mapping, Sequence, Union + +import torch from torch.utils.data import DataLoader from torchmetrics import Metric from typing_extensions import TypedDict From 39059e54765a4ea9c6606cad8e167688e3598ad9 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sun, 31 Oct 2021 00:21:08 +0530 Subject: [PATCH 04/35] docs --- docs/source/common/optimizers.rst | 25 +++++++++++++++++++++++++ pytorch_lightning/core/lightning.py | 28 +++++++++++++++++++++++++++- 2 files changed, 52 insertions(+), 1 deletion(-) diff --git a/docs/source/common/optimizers.rst b/docs/source/common/optimizers.rst index 8080d12e2b6fe..1fd7771958fa1 100644 --- a/docs/source/common/optimizers.rst +++ b/docs/source/common/optimizers.rst @@ -252,6 +252,31 @@ If you want to call schedulers that require a metric value after each epoch, con ----- +Bring your own custom learning rate schedulers +---------------------------------------------- +Lightning allows custom learning rate schedulers which are not present in +`PyTorch natively `_. +One good example is `Timm Schedulers `_. +You can configure how your learning rate will be updated based on your custom implementation +and lightning will handle when they should be updated based on the scheduler config provided inside +:meth:`~pytorch_lightning.core.lightning.LightningModule.configure_optimizers`. For you custom +implementation you must override :meth:`~pytorch_lightning.core.lightning.LightningModule.lr_scheduler_step` +if necessary. If you are using native PyTorch schedulers, there is no need to override this hook since +Lightning will handle it optimally by default. + +.. testcode:: python + + def configure_optimizers(self): + optimizer = ... + scheduler = ... + return [optimizer], [scheduler] + + + def lr_scheduler_step(self, scheduler, step, optimizer_idx, monitor_val=None): + scheduler.step(epoch=step) # timm's scheduler need the epoch value + +----- + Use closure for LBFGS-like optimizers ------------------------------------- It is a good practice to provide the optimizer with a closure function that performs a ``forward``, ``zero_grad`` and diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 14509aaa4728c..e617ccb9cf244 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1177,7 +1177,7 @@ def configure_optimizers(self): # `scheduler.step()`. 1 corresponds to updating the learning # rate after every epoch/step. 
"frequency": 1, - # Metric to to monitor for schedulers like `ReduceLROnPlateau` + # Metric to to monitor for schedulers like ``ReduceLROnPlateau`` "monitor": "val_loss", # If set to `True`, will enforce that the value specified 'monitor' # is available when the scheduler is updated, thus stopping @@ -1500,6 +1500,32 @@ def lr_scheduler_step( optimizer_idx: Optional[int] = None, monitor_val: Optional[Union[float, torch.Tensor]] = None, ): + r""" + Override this method to adjust the default way the + :class:`~pytorch_lightning.trainer.trainer.Trainer` calls each scheduler. + By default, Lightning calls ``step()`` and as shown in the example + for each scheduler based on its ``interval``. + + Args: + scheduler: Learning rate scheduler. + step: Epoch or global step based on the interval of individual scheduler. + optimizer_idx: Index of the optimizer associated with scheduler. + monitor_val: Value of the metric used for schedulers like ``ReduceLROnPlateau``. + + Examples:: + + # DEFAULT + def lr_scheduler_step(self, scheduler, step, optimizer_idx, monitor_val): + if monitor_val is None: + scheduler.step() + else: + scheduler.step(monitor_val) + + # Alternative way to do step if scheduler requires an epoch value + def lr_scheduler_step(self, scheduler, step, optimizer_idx, monitor_val): + scheduler.step(epoch=step) + + """ if monitor_val is None: scheduler.step() else: From 3c667689555f9052ef7c1449c2a56bbb2e6ead59 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sun, 31 Oct 2021 00:27:20 +0530 Subject: [PATCH 05/35] update .gitignore --- .gitignore | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index e1d165dd5dbb1..97ba65852b821 100644 --- a/.gitignore +++ b/.gitignore @@ -142,7 +142,7 @@ mnist/ legacy/checkpoints/ *.gz *ubyte - +MNIST/ # pl tests ml-runs/ From e437242a1f691cb33bda6ff73ae8e39034c4d2e5 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sun, 31 Oct 2021 00:29:06 +0530 Subject: [PATCH 06/35] chlog --- CHANGELOG.md | 4 ++++ _notebooks | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 09623b6b52e25..49443e5f09983 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -58,6 +58,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added a `PrecisionPlugin.teardown` method ([#10990](https://github.com/PyTorchLightning/pytorch-lightning/issues/10990)) +- Added `LightningModule.lr_scheduler_step` ([#10249](https://github.com/PyTorchLightning/pytorch-lightning/pull/10249)) + ### Changed @@ -229,6 +231,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Remove deprecated method `ClusterEnvironment.creates_children` ([#10339](https://github.com/PyTorchLightning/pytorch-lightning/issues/10339)) +### Changed + - Removed deprecated `TrainerModelHooksMixin.is_function_implemented` and `TrainerModelHooksMixin.has_arg` ([#10322](https://github.com/PyTorchLightning/pytorch-lightning/pull/10322)) diff --git a/_notebooks b/_notebooks index a2fb6468112b7..0c325829101d5 160000 --- a/_notebooks +++ b/_notebooks @@ -1 +1 @@ -Subproject commit a2fb6468112b7e1dad501c3b6a17533a4adfeabc +Subproject commit 0c325829101d5a6ebf32ed99bbf5b09badf04a59 From fc8bc16fa8dd2cfa156b49742dede2b97256ee8b Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sun, 31 Oct 2021 00:34:02 +0530 Subject: [PATCH 07/35] mypy --- pytorch_lightning/utilities/cli.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index 3f4822b9df50c..2e702bc73f658 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -732,9 +732,7 @@ def _add_configure_optimizers_method_to_model(self, subcommand: Optional[str]) - 'AUTOMATIC'.""" parser = self._parser(subcommand) - def get_automatic( - class_type: Union[Type, Tuple[Type, ...]], register: Dict[str, Tuple[Union[Type, Tuple[Type, ...]], str]] - ) -> List[str]: + def get_automatic(class_type: Any, register: Dict[str, Tuple[Union[Type, Tuple[Type, ...]], str]]) -> List[str]: automatic = [] for key, (base_class, link_to) in register.items(): if not isinstance(base_class, tuple): From b4dd1d82736f9f284f6a4b9f143dd3057b0a4a46 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sat, 18 Dec 2021 20:13:47 +0530 Subject: [PATCH 08/35] remove step --- CHANGELOG.md | 2 -- docs/source/common/optimizers.rst | 6 +++--- pytorch_lightning/core/lightning.py | 10 ++++------ .../loops/epoch/training_epoch_loop.py | 2 -- pytorch_lightning/strategies/deepspeed.py | 5 ++++- pytorch_lightning/trainer/trainer.py | 1 + pytorch_lightning/utilities/cli.py | 15 +++++++++++---- pytorch_lightning/utilities/types.py | 2 ++ tests/models/test_hooks.py | 2 +- tests/trainer/optimization/test_optimizers.py | 14 ++++++-------- 10 files changed, 32 insertions(+), 27 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 49443e5f09983..6ffb0ec52d0f6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -231,8 +231,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Remove deprecated method `ClusterEnvironment.creates_children` ([#10339](https://github.com/PyTorchLightning/pytorch-lightning/issues/10339)) -### Changed - - Removed deprecated `TrainerModelHooksMixin.is_function_implemented` and `TrainerModelHooksMixin.has_arg` ([#10322](https://github.com/PyTorchLightning/pytorch-lightning/pull/10322)) diff --git a/docs/source/common/optimizers.rst b/docs/source/common/optimizers.rst index 1fd7771958fa1..25245433b1eee 100644 --- a/docs/source/common/optimizers.rst +++ b/docs/source/common/optimizers.rst @@ -252,7 +252,7 @@ If you want to call schedulers that require a metric value after each epoch, con ----- -Bring your own custom learning rate schedulers +Bring your own Custom Learning Rate Schedulers ---------------------------------------------- Lightning allows custom learning rate schedulers which are not present in `PyTorch natively `_. @@ -272,8 +272,8 @@ Lightning will handle it optimally by default. 
return [optimizer], [scheduler] - def lr_scheduler_step(self, scheduler, step, optimizer_idx, monitor_val=None): - scheduler.step(epoch=step) # timm's scheduler need the epoch value + def lr_scheduler_step(self, scheduler, optimizer_idx, monitor_val=None): + scheduler.step(epoch=self.current_epoch) # timm's scheduler need the epoch value ----- diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index e617ccb9cf244..e35eeeb2458ab 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1496,7 +1496,6 @@ def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_va def lr_scheduler_step( self, scheduler: Any, - step: int, optimizer_idx: Optional[int] = None, monitor_val: Optional[Union[float, torch.Tensor]] = None, ): @@ -1508,28 +1507,27 @@ def lr_scheduler_step( Args: scheduler: Learning rate scheduler. - step: Epoch or global step based on the interval of individual scheduler. optimizer_idx: Index of the optimizer associated with scheduler. monitor_val: Value of the metric used for schedulers like ``ReduceLROnPlateau``. Examples:: # DEFAULT - def lr_scheduler_step(self, scheduler, step, optimizer_idx, monitor_val): + def lr_scheduler_step(self, scheduler, optimizer_idx, monitor_val): if monitor_val is None: scheduler.step() else: scheduler.step(monitor_val) # Alternative way to do step if scheduler requires an epoch value - def lr_scheduler_step(self, scheduler, step, optimizer_idx, monitor_val): - scheduler.step(epoch=step) + def lr_scheduler_step(self, scheduler, optimizer_idx, monitor_val): + scheduler.step(epoch=self.current_epoch) """ if monitor_val is None: scheduler.step() else: - scheduler.step(monitor_val) + scheduler.step(metrics=monitor_val) def optimizer_step( self, diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index d21cb94a65bc0..3c9ae335f9e1a 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -503,10 +503,8 @@ def _update_learning_rates( self.scheduler_progress.increment_ready() # update LR - step = self.trainer.global_step if interval == "step" else self.trainer.current_epoch self.trainer.lightning_module.lr_scheduler_step( lr_scheduler["scheduler"], - step, optimizer_idx=lr_scheduler["opt_idx"], monitor_val=monitor_val, ) diff --git a/pytorch_lightning/strategies/deepspeed.py b/pytorch_lightning/strategies/deepspeed.py index d6a996f58cbe2..452f3c8e1a8b4 100644 --- a/pytorch_lightning/strategies/deepspeed.py +++ b/pytorch_lightning/strategies/deepspeed.py @@ -24,6 +24,7 @@ import torch from torch.nn import Module from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler import pytorch_lightning as pl from pytorch_lightning.core.optimizer import _get_default_scheduler_config, _init_optimizers_and_lr_schedulers @@ -397,7 +398,9 @@ def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer] self._set_deepspeed_activation_checkpointing() return self.model, [optimizer] - def _setup_model_and_optimizer(self, model: Module, optimizer: Optimizer, lr_scheduler: Optional[Any] = None): + def _setup_model_and_optimizer( + self, model: Module, optimizer: Optimizer, lr_scheduler: Optional[_LRScheduler] = None + ): """Initialize one model and one optimizer with an optional learning rate scheduler. This calls :func:`deepspeed.initialize` internally. 
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 351d21eb06e88..3c4292c4bd4c2 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -113,6 +113,7 @@ STEP_OUTPUT, TRAIN_DATALOADERS, LRSchedulerTypeUnion, + TRAIN_DATALOADERS, ) from pytorch_lightning.utilities.warnings import PossibleUserWarning diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index 2e702bc73f658..51fb22a301035 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -84,6 +84,7 @@ def __str__(self) -> str: OPTIMIZER_REGISTRY.register_classes(torch.optim, Optimizer) LR_SCHEDULER_REGISTRY = _Registry() +LR_SCHEDULER_REGISTRY.register_classes(torch.optim.lr_scheduler, torch.optim.lr_scheduler._LRScheduler) class ReduceLROnPlateau(torch.optim.lr_scheduler.ReduceLROnPlateau): @@ -201,17 +202,21 @@ def add_optimizer_args( def add_lr_scheduler_args( self, - lr_scheduler_class: Any, + lr_scheduler_class: Union[LRSchedulerType, Tuple[LRSchedulerType, ...]], nested_key: str = "lr_scheduler", link_to: str = "AUTOMATIC", ) -> None: """Adds arguments from a learning rate scheduler class to a nested key of the parser. Args: - lr_scheduler_class: Learning rate scheduler. + lr_scheduler_class: Any subclass of ``torch.optim.lr_scheduler.{_LRScheduler, ReduceLROnPlateau}``. nested_key: Name of the nested namespace to store arguments. link_to: Dot notation of a parser key to set arguments or AUTOMATIC. """ + if isinstance(lr_scheduler_class, tuple): + assert all(issubclass(o, LRSchedulerTypeTuple) for o in lr_scheduler_class) + else: + assert issubclass(lr_scheduler_class, LRSchedulerTypeTuple) kwargs = {"instantiate": False, "fail_untyped": False, "skip": {"optimizer"}} if isinstance(lr_scheduler_class, tuple): self.add_subclass_arguments(lr_scheduler_class, nested_key, **kwargs) @@ -732,7 +737,9 @@ def _add_configure_optimizers_method_to_model(self, subcommand: Optional[str]) - 'AUTOMATIC'.""" parser = self._parser(subcommand) - def get_automatic(class_type: Any, register: Dict[str, Tuple[Union[Type, Tuple[Type, ...]], str]]) -> List[str]: + def get_automatic( + class_type: Union[Type, Tuple[Type, ...]], register: Dict[str, Tuple[Union[Type, Tuple[Type, ...]], str]] + ) -> List[str]: automatic = [] for key, (base_class, link_to) in register.items(): if not isinstance(base_class, tuple): @@ -742,7 +749,7 @@ def get_automatic(class_type: Any, register: Dict[str, Tuple[Union[Type, Tuple[T return automatic optimizers = get_automatic(Optimizer, parser._optimizers) - lr_schedulers = get_automatic(Any, parser._lr_schedulers) + lr_schedulers = get_automatic(LRSchedulerTypeTuple, parser._lr_schedulers) if len(optimizers) == 0: return diff --git a/pytorch_lightning/utilities/types.py b/pytorch_lightning/utilities/types.py index 9ec979bfd2999..4d262f301062d 100644 --- a/pytorch_lightning/utilities/types.py +++ b/pytorch_lightning/utilities/types.py @@ -22,8 +22,10 @@ import torch from torch.optim import Optimizer from typing import Any, Dict, Iterator, List, Mapping, Sequence, Union +from typing import Any, Dict, Iterator, List, Mapping, Sequence, Type, Union import torch +from torch.optim.lr_scheduler import _LRScheduler, ReduceLROnPlateau from torch.utils.data import DataLoader from torchmetrics import Metric from typing_extensions import TypedDict diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 7ea291aa7a772..ab48aa315d75c 100644 --- a/tests/models/test_hooks.py +++ 
b/tests/models/test_hooks.py @@ -330,7 +330,7 @@ def _auto_train_batch(trainer, model, batches, device=torch.device("cpu"), curre [ dict( name="lr_scheduler_step", - args=(ANY, current_epoch), + args=(ANY,), kwargs=dict(optimizer_idx=None, monitor_val=None), ) ] diff --git a/tests/trainer/optimization/test_optimizers.py b/tests/trainer/optimization/test_optimizers.py index 2a0efe5c766f9..fcc4859a8efe3 100644 --- a/tests/trainer/optimization/test_optimizers.py +++ b/tests/trainer/optimization/test_optimizers.py @@ -691,13 +691,13 @@ def training_step(self, batch, batch_idx, optimizer_idx): def training_epoch_end(self, *args, **kwargs): pass - def lr_scheduler_step(self, scheduler, step, optimizer_idx, monitor_val=None): + def lr_scheduler_step(self, scheduler, optimizer_idx, monitor_val=None): + # step-level if optimizer_idx == 0: - assert step == self.trainer.global_step - super().lr_scheduler_step(scheduler, step, optimizer_idx, monitor_val) - if optimizer_idx == 1: - assert step == self.trainer.current_epoch - scheduler.step(epoch=step) + super().lr_scheduler_step(scheduler, optimizer_idx, monitor_val) + # epoch-level + elif optimizer_idx == 1: + scheduler.step(epoch=self.current_epoch) def configure_optimizers(self): opt1 = torch.optim.SGD(self.layer1.parameters(), lr=self.learning_rate) @@ -724,5 +724,3 @@ def configure_optimizers(self): for param_group in trainer.optimizers[1].param_groups: assert param_group["lr"] == lr / math.factorial(max_epochs) - - breakpoint() From 18e6bb477e27a4037d963fd876a33dc6ee5c2791 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sat, 18 Dec 2021 23:08:47 +0530 Subject: [PATCH 09/35] add protocol api --- pytorch_lightning/strategies/deepspeed.py | 5 ++--- pytorch_lightning/strategies/strategy.py | 1 + pytorch_lightning/trainer/optimizers.py | 19 +++++++++++++++++++ tests/trainer/optimization/test_optimizers.py | 8 ++++++++ 4 files changed, 30 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/strategies/deepspeed.py b/pytorch_lightning/strategies/deepspeed.py index 452f3c8e1a8b4..dc63021ee1eaf 100644 --- a/pytorch_lightning/strategies/deepspeed.py +++ b/pytorch_lightning/strategies/deepspeed.py @@ -24,7 +24,6 @@ import torch from torch.nn import Module from torch.optim import Optimizer -from torch.optim.lr_scheduler import _LRScheduler import pytorch_lightning as pl from pytorch_lightning.core.optimizer import _get_default_scheduler_config, _init_optimizers_and_lr_schedulers @@ -41,7 +40,7 @@ from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.seed import reset_seed -from pytorch_lightning.utilities.types import _PATH, LRSchedulerTypeTuple, STEP_OUTPUT +from pytorch_lightning.utilities.types import _PATH, LRSchedulerTypeTuple, LRSchedulerTypeUnion, STEP_OUTPUT from pytorch_lightning.utilities.warnings import rank_zero_warn, WarningCache warning_cache = WarningCache() @@ -399,7 +398,7 @@ def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer] return self.model, [optimizer] def _setup_model_and_optimizer( - self, model: Module, optimizer: Optimizer, lr_scheduler: Optional[_LRScheduler] = None + self, model: Module, optimizer: Optimizer, lr_scheduler: Optional[LRSchedulerTypeUnion] = None ): """Initialize one model and one optimizer with an optional learning rate scheduler. 
diff --git a/pytorch_lightning/strategies/strategy.py b/pytorch_lightning/strategies/strategy.py index fe9093838c157..d2488b7d81ba1 100644 --- a/pytorch_lightning/strategies/strategy.py +++ b/pytorch_lightning/strategies/strategy.py @@ -33,6 +33,7 @@ from pytorch_lightning.utilities.distributed import ReduceOp from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.types import _PATH, LRSchedulerConfig, STEP_OUTPUT +from pytorch_lightning.utilities.types import _PATH, LRSchedulerTypeUnion, STEP_OUTPUT TBroadcast = TypeVar("TBroadcast") diff --git a/pytorch_lightning/trainer/optimizers.py b/pytorch_lightning/trainer/optimizers.py index 0b44786873867..09bb55533a4de 100644 --- a/pytorch_lightning/trainer/optimizers.py +++ b/pytorch_lightning/trainer/optimizers.py @@ -14,6 +14,11 @@ from abc import ABC from typing import List, Optional, Tuple +from typing import Any, Dict, List, Optional, Protocol, runtime_checkable, Tuple, Union + +import torch +from torch import optim +from torch.optim.optimizer import Optimizer import pytorch_lightning as pl from pytorch_lightning.core.optimizer import ( @@ -24,6 +29,20 @@ from pytorch_lightning.utilities import rank_zero_deprecation +@runtime_checkable +class _SupportedLRScheduler(Protocol): + """This class is used to detect if an object is stateful using `isinstance(obj, _SupportedLRScheduler)`""" + + def step(self, *args: Any, **kwargs: Any) -> None: + ... + + def state_dict(self) -> Dict[str, Any]: + ... + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + ... + + class TrainerOptimizersMixin(ABC): r""" .. deprecated:: v1.6 diff --git a/tests/trainer/optimization/test_optimizers.py b/tests/trainer/optimization/test_optimizers.py index fcc4859a8efe3..99d5985dbb0fc 100644 --- a/tests/trainer/optimization/test_optimizers.py +++ b/tests/trainer/optimization/test_optimizers.py @@ -665,6 +665,8 @@ def on_save_checkpoint(self, checkpoint): def test_lr_scheduler_step_hook(tmpdir): + """Test that custom lr_schedulers works and `lr_scheduler_hook` is called at appropriate time.""" + class CustomEpochScheduler: def __init__(self, optimizer): self.optimizer = optimizer @@ -673,6 +675,12 @@ def step(self, epoch): for param_group in self.optimizer.param_groups: param_group["lr"] = param_group["lr"] / (epoch + 1) + def state_dict(self): + return {key: value for key, value in self.__dict__.items() if key != "optimizer"} + + def load_state_dict(self, state_dict): + self.__dict__.update(state_dict) + class CustomBoringModel(BoringModel): def __init__(self, learning_rate): super().__init__() From d7bdd0e4d1979914a5bb8fc2df1f88dfc0a08d29 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sat, 18 Dec 2021 23:11:56 +0530 Subject: [PATCH 10/35] update --- pytorch_lightning/core/lightning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index e35eeeb2458ab..76c44dc5205a3 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1519,7 +1519,7 @@ def lr_scheduler_step(self, scheduler, optimizer_idx, monitor_val): else: scheduler.step(monitor_val) - # Alternative way to do step if scheduler requires an epoch value + # Alternative way to update schedulers if it requires an epoch value def lr_scheduler_step(self, scheduler, optimizer_idx, monitor_val): scheduler.step(epoch=self.current_epoch) From ec2aa5d995d3680ec015a1fe46a0bd7f1f895522 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sat, 18 Dec 
2021 23:25:32 +0530 Subject: [PATCH 11/35] add more test --- tests/trainer/optimization/test_optimizers.py | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tests/trainer/optimization/test_optimizers.py b/tests/trainer/optimization/test_optimizers.py index 99d5985dbb0fc..68b149e85b94b 100644 --- a/tests/trainer/optimization/test_optimizers.py +++ b/tests/trainer/optimization/test_optimizers.py @@ -732,3 +732,26 @@ def configure_optimizers(self): for param_group in trainer.optimizers[1].param_groups: assert param_group["lr"] == lr / math.factorial(max_epochs) + + +def test_invalid_lr_scheduler(tmpdir): + """Test that custom lr_schedulers works and `lr_scheduler_hook` is called at appropriate time.""" + + class CustomScheduler: + def __init__(self, optimizer): + self.optimizer = optimizer + + def step(self, epoch): + for param_group in self.optimizer.param_groups: + param_group["lr"] = param_group["lr"] / (epoch + 1) + + class CustomBoringModel(BoringModel): + def configure_optimizers(self): + opt = torch.optim.SGD(self.parameters(), lr=1e-2) + lr_scheduler = CustomScheduler(opt) + return {"optimizer": opt, "lr_scheduler": lr_scheduler} + + model = CustomBoringModel() + trainer = Trainer(default_root_dir=tmpdir) + with pytest.raises(ValueError, match="provided lr scheduler .* is invalid"): + trainer.init_optimizers(model) From 555c49f5e8b32cdde68927a8ee3aa8d36aedf2c8 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sat, 18 Dec 2021 23:29:02 +0530 Subject: [PATCH 12/35] use extensions --- pytorch_lightning/trainer/optimizers.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pytorch_lightning/trainer/optimizers.py b/pytorch_lightning/trainer/optimizers.py index 09bb55533a4de..27b222502c373 100644 --- a/pytorch_lightning/trainer/optimizers.py +++ b/pytorch_lightning/trainer/optimizers.py @@ -15,10 +15,12 @@ from abc import ABC from typing import List, Optional, Tuple from typing import Any, Dict, List, Optional, Protocol, runtime_checkable, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import torch from torch import optim from torch.optim.optimizer import Optimizer +from typing_extensions import Protocol, runtime_checkable import pytorch_lightning as pl from pytorch_lightning.core.optimizer import ( From 5e8d37117137c123b344a016f7fee49b4d22e0f9 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sun, 19 Dec 2021 00:07:02 +0530 Subject: [PATCH 13/35] register_hook --- pytorch_lightning/loops/epoch/training_epoch_loop.py | 3 ++- .../trainer/connectors/logger_connector/fx_validator.py | 1 + tests/trainer/logging_/test_logger_connector.py | 1 + 3 files changed, 4 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 3c9ae335f9e1a..9aba7e72a1bf1 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -503,7 +503,8 @@ def _update_learning_rates( self.scheduler_progress.increment_ready() # update LR - self.trainer.lightning_module.lr_scheduler_step( + self.trainer._call_lightning_module_hook( + "lr_scheduler_step", lr_scheduler["scheduler"], optimizer_idx=lr_scheduler["opt_idx"], monitor_val=monitor_val, diff --git a/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py b/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py index e73bf54825269..c33320185d76a 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py +++ 
b/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py @@ -43,6 +43,7 @@ class _LogOptions(TypedDict): "optimizer_step": _LogOptions( allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False ), + "lr_scheduler_step": None, "on_before_zero_grad": _LogOptions( allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False ), diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index b965478684b87..8d6b3551e8579 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -233,6 +233,7 @@ def test_fx_validator_integration(tmpdir): "configure_callbacks": "You can't", "on_validation_model_eval": "You can't", "on_validation_model_train": "You can't", + "lr_scheduler_step": "You can't", "summarize": "not managed by the `Trainer", } model = HookedModel(not_supported) From f6b3e10f0b116686c2dd3840110b59cbdae69171 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Tue, 21 Dec 2021 00:48:08 +0530 Subject: [PATCH 14/35] address reviews --- docs/source/common/optimizers.rst | 24 +++++++++---------- pytorch_lightning/core/lightning.py | 14 +++++------ .../loops/epoch/training_epoch_loop.py | 4 ++-- tests/models/test_hooks.py | 2 +- tests/trainer/optimization/test_optimizers.py | 4 ++-- 5 files changed, 23 insertions(+), 25 deletions(-) diff --git a/docs/source/common/optimizers.rst b/docs/source/common/optimizers.rst index 25245433b1eee..8fd26a22dadc1 100644 --- a/docs/source/common/optimizers.rst +++ b/docs/source/common/optimizers.rst @@ -254,25 +254,23 @@ If you want to call schedulers that require a metric value after each epoch, con Bring your own Custom Learning Rate Schedulers ---------------------------------------------- -Lightning allows custom learning rate schedulers which are not present in -`PyTorch natively `_. -One good example is `Timm Schedulers `_. -You can configure how your learning rate will be updated based on your custom implementation -and lightning will handle when they should be updated based on the scheduler config provided inside -:meth:`~pytorch_lightning.core.lightning.LightningModule.configure_optimizers`. For you custom -implementation you must override :meth:`~pytorch_lightning.core.lightning.LightningModule.lr_scheduler_step` -if necessary. If you are using native PyTorch schedulers, there is no need to override this hook since -Lightning will handle it optimally by default. +Lightning allows using custom learning rate schedulers that aren't available in `PyTorch natively `_. +One good example is `Timm Schedulers `_. When using custom learning rate schedulers +relying on a different API from Native PyTorch ones, you should override the :meth:`~pytorch_lightning.core.lightning.LightningModule.lr_scheduler_step` with your desired logic. +If you are using native PyTorch schedulers, there is no need to override this hook since Lightning will handle it optimally by default. + +.. code-block:: python + + from timm.scheduler import TanhLRScheduler -.. testcode:: python def configure_optimizers(self): optimizer = ... - scheduler = ... - return [optimizer], [scheduler] + scheduler = TanhLRScheduler(optimizer, ...) 
+ return [optimizer], [{"scheduler": scheduler, "interval": "epoch"}] - def lr_scheduler_step(self, scheduler, optimizer_idx, monitor_val=None): + def lr_scheduler_step(self, scheduler, optimizer_idx, metrics=None): scheduler.step(epoch=self.current_epoch) # timm's scheduler need the epoch value ----- diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 76c44dc5205a3..37382c485b381 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1177,7 +1177,7 @@ def configure_optimizers(self): # `scheduler.step()`. 1 corresponds to updating the learning # rate after every epoch/step. "frequency": 1, - # Metric to to monitor for schedulers like ``ReduceLROnPlateau`` + # Metric to to monitor for schedulers like `ReduceLROnPlateau` "monitor": "val_loss", # If set to `True`, will enforce that the value specified 'monitor' # is available when the scheduler is updated, thus stopping @@ -1497,8 +1497,8 @@ def lr_scheduler_step( self, scheduler: Any, optimizer_idx: Optional[int] = None, - monitor_val: Optional[Union[float, torch.Tensor]] = None, - ): + metrics: Optional[Union[float, torch.Tensor]] = None, + ) -> None: r""" Override this method to adjust the default way the :class:`~pytorch_lightning.trainer.trainer.Trainer` calls each scheduler. @@ -1513,21 +1513,21 @@ def lr_scheduler_step( Examples:: # DEFAULT - def lr_scheduler_step(self, scheduler, optimizer_idx, monitor_val): + def lr_scheduler_step(self, scheduler, optimizer_idx, metrics): if monitor_val is None: scheduler.step() else: scheduler.step(monitor_val) # Alternative way to update schedulers if it requires an epoch value - def lr_scheduler_step(self, scheduler, optimizer_idx, monitor_val): + def lr_scheduler_step(self, scheduler, optimizer_idx, metrics): scheduler.step(epoch=self.current_epoch) """ - if monitor_val is None: + if metrics is None: scheduler.step() else: - scheduler.step(metrics=monitor_val) + scheduler.step(metrics=metrics) def optimizer_step( self, diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 9aba7e72a1bf1..00d6142bcd38f 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -507,11 +507,11 @@ def _update_learning_rates( "lr_scheduler_step", lr_scheduler["scheduler"], optimizer_idx=lr_scheduler["opt_idx"], - monitor_val=monitor_val, + metrics=monitor_val, ) self.scheduler_progress.increment_completed() - def _get_monitor_value(self, key: str) -> Optional[Union[float, torch.Tensor]]: + def _get_monitor_value(self, key: str) -> Any: # this is a separate method to aid in testing return self.trainer.callback_metrics.get(key) diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index ab48aa315d75c..7c13577dee4d7 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -331,7 +331,7 @@ def _auto_train_batch(trainer, model, batches, device=torch.device("cpu"), curre dict( name="lr_scheduler_step", args=(ANY,), - kwargs=dict(optimizer_idx=None, monitor_val=None), + kwargs=dict(optimizer_idx=None, metrics=None), ) ] if i == (trainer.num_training_batches - 1) diff --git a/tests/trainer/optimization/test_optimizers.py b/tests/trainer/optimization/test_optimizers.py index 68b149e85b94b..c0d604756f7d5 100644 --- a/tests/trainer/optimization/test_optimizers.py +++ b/tests/trainer/optimization/test_optimizers.py @@ -699,10 +699,10 @@ def training_step(self, batch, batch_idx, 
optimizer_idx): def training_epoch_end(self, *args, **kwargs): pass - def lr_scheduler_step(self, scheduler, optimizer_idx, monitor_val=None): + def lr_scheduler_step(self, scheduler, optimizer_idx, metrics=None): # step-level if optimizer_idx == 0: - super().lr_scheduler_step(scheduler, optimizer_idx, monitor_val) + super().lr_scheduler_step(scheduler, optimizer_idx, metrics) # epoch-level elif optimizer_idx == 1: scheduler.step(epoch=self.current_epoch) From 3c095ce1b40460e414e713cc212e35648556df4e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 21 Dec 2021 19:28:28 +0000 Subject: [PATCH 15/35] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/trainer/trainer.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 3c4292c4bd4c2..ce5680059a945 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -112,8 +112,6 @@ LRSchedulerConfig, STEP_OUTPUT, TRAIN_DATALOADERS, - LRSchedulerTypeUnion, - TRAIN_DATALOADERS, ) from pytorch_lightning.utilities.warnings import PossibleUserWarning From b01d2cf0b513be5613e6a65cfec7f76c4712e9ed Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Tue, 28 Dec 2021 18:17:22 +0530 Subject: [PATCH 16/35] fix and rebase --- docs/source/common/optimizers.rst | 2 +- pytorch_lightning/core/lightning.py | 22 +++++----- pytorch_lightning/core/optimizer.py | 40 +++++++++++++++---- .../loops/epoch/training_epoch_loop.py | 2 +- pytorch_lightning/strategies/deepspeed.py | 4 +- pytorch_lightning/strategies/strategy.py | 1 - pytorch_lightning/trainer/optimizers.py | 21 ---------- pytorch_lightning/trainer/trainer.py | 2 - pytorch_lightning/utilities/types.py | 16 +------- tests/models/test_hooks.py | 2 +- tests/trainer/optimization/test_optimizers.py | 7 ++-- 11 files changed, 55 insertions(+), 64 deletions(-) diff --git a/docs/source/common/optimizers.rst b/docs/source/common/optimizers.rst index 8fd26a22dadc1..bb9799f2d234f 100644 --- a/docs/source/common/optimizers.rst +++ b/docs/source/common/optimizers.rst @@ -270,7 +270,7 @@ If you are using native PyTorch schedulers, there is no need to override this ho return [optimizer], [{"scheduler": scheduler, "interval": "epoch"}] - def lr_scheduler_step(self, scheduler, optimizer_idx, metrics=None): + def lr_scheduler_step(self, scheduler, optimizer_idx, metric=None): scheduler.step(epoch=self.current_epoch) # timm's scheduler need the epoch value ----- diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 37382c485b381..292ed728d5868 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -53,7 +53,7 @@ from pytorch_lightning.utilities.model_summary import ModelSummary, summarize from pytorch_lightning.utilities.parsing import collect_init_args from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature -from pytorch_lightning.utilities.types import _METRIC_COLLECTION, EPOCH_OUTPUT, STEP_OUTPUT +from pytorch_lightning.utilities.types import _METRIC_COLLECTION, EPOCH_OUTPUT, LRSchedulerTypeUnion, STEP_OUTPUT from pytorch_lightning.utilities.warnings import WarningCache warning_cache = WarningCache() @@ -1495,9 +1495,9 @@ def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_va def lr_scheduler_step( self, - scheduler: Any, + scheduler: 
LRSchedulerTypeUnion, optimizer_idx: Optional[int] = None, - metrics: Optional[Union[float, torch.Tensor]] = None, + metric: Optional[Union[float, torch.Tensor]] = None, ) -> None: r""" Override this method to adjust the default way the @@ -1507,27 +1507,27 @@ def lr_scheduler_step( Args: scheduler: Learning rate scheduler. - optimizer_idx: Index of the optimizer associated with scheduler. - monitor_val: Value of the metric used for schedulers like ``ReduceLROnPlateau``. + optimizer_idx: Index of the optimizer associated with this scheduler. + metric: Value of the metric used for schedulers like ``ReduceLROnPlateau``. Examples:: # DEFAULT - def lr_scheduler_step(self, scheduler, optimizer_idx, metrics): - if monitor_val is None: + def lr_scheduler_step(self, scheduler, optimizer_idx, metric): + if metric is None: scheduler.step() else: - scheduler.step(monitor_val) + scheduler.step(metric) # Alternative way to update schedulers if it requires an epoch value - def lr_scheduler_step(self, scheduler, optimizer_idx, metrics): + def lr_scheduler_step(self, scheduler, optimizer_idx, metric): scheduler.step(epoch=self.current_epoch) """ - if metrics is None: + if metric is None: scheduler.step() else: - scheduler.step(metrics=metrics) + scheduler.step(metric) def optimizer_step( self, diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py index 5859687eed345..401b95078095e 100644 --- a/pytorch_lightning/core/optimizer.py +++ b/pytorch_lightning/core/optimizer.py @@ -19,10 +19,12 @@ import torch from torch import optim from torch.optim import Optimizer +from typing_extensions import Protocol, runtime_checkable import pytorch_lightning as pl from pytorch_lightning.utilities import AMPType, rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.types import LRSchedulerConfig def do_nothing_closure() -> None: @@ -168,7 +170,9 @@ def closure_dis(): trainer.strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs) -def _init_optimizers_and_lr_schedulers(model: "pl.LightningModule") -> Tuple[List, List, List]: +def _init_optimizers_and_lr_schedulers( + model: "pl.LightningModule", +) -> Tuple[List[Optimizer], List[LRSchedulerConfig], List[int]]: """Calls `LightningModule.configure_optimizers` and parses and validates the output.""" model.trainer._lightning_optimizers = None optim_conf = model.trainer._call_lightning_module_hook("configure_optimizers", pl_module=model) @@ -252,7 +256,7 @@ def _configure_optimizers( return optimizers, lr_schedulers, optimizer_frequencies, monitor -def _configure_schedulers_automatic_opt(schedulers: list, monitor: Optional[str]) -> List[Dict[str, Any]]: +def _configure_schedulers_automatic_opt(schedulers: list, monitor: Optional[str]) -> List[LRSchedulerConfig]: """Convert each scheduler into dict structure with relevant information, when using automatic optimization.""" lr_schedulers = [] default_config = _get_default_scheduler_config() @@ -298,14 +302,17 @@ def _configure_schedulers_automatic_opt(schedulers: list, monitor: Optional[str] lr_schedulers.append( {**default_config, "scheduler": scheduler, "reduce_on_plateau": True, "monitor": monitor} ) - elif isinstance(scheduler, optim.lr_scheduler._LRScheduler): - lr_schedulers.append({**default_config, "scheduler": scheduler}) else: - raise ValueError(f'The provided lr scheduler "{scheduler}" is invalid') + lr_schedulers.append({**default_config, "scheduler": scheduler}) + + current_scheduler = 
lr_schedulers[-1]["scheduler"] + if not isinstance(current_scheduler, _SupportedLRScheduler): + raise ValueError(f"The provided lr scheduler `{current_scheduler.__class__.__name__}` is invalid.") + return lr_schedulers -def _configure_schedulers_manual_opt(schedulers: list, monitor: Optional[str]) -> List[Dict[str, Any]]: +def _configure_schedulers_manual_opt(schedulers: list, monitor: Optional[str]) -> List[LRSchedulerConfig]: """Convert each scheduler into dict structure with relevant information, when using manual optimization.""" lr_schedulers = [] default_config = _get_default_scheduler_config() @@ -325,6 +332,11 @@ def _configure_schedulers_manual_opt(schedulers: list, monitor: Optional[str]) - lr_schedulers.append({**default_config, **scheduler}) else: lr_schedulers.append({**default_config, "scheduler": scheduler}) + + current_scheduler = lr_schedulers[-1]["scheduler"] + if not isinstance(current_scheduler, _SupportedLRScheduler): + raise ValueError(f"The provided lr scheduler `{current_scheduler.__class__.__name__}` is invalid.") + return lr_schedulers @@ -341,7 +353,7 @@ def _get_default_scheduler_config() -> Dict[str, Any]: } -def _validate_scheduler_optimizer(optimizers: List[Any], lr_schedulers: List[Any]) -> None: +def _validate_scheduler_optimizer(optimizers: List[Optimizer], lr_schedulers: List[LRSchedulerConfig]) -> None: if any(sch["scheduler"].optimizer not in optimizers for sch in lr_schedulers): raise MisconfigurationException( "Some schedulers are attached with an optimizer that wasn't returned from `configure_optimizers`." @@ -394,3 +406,17 @@ def zero_grad(self, set_to_none: Optional[bool] = False) -> None: def __repr__(self) -> str: return "No Optimizer" + + +@runtime_checkable +class _SupportedLRScheduler(Protocol): + """This class is used to detect if an object is stateful using `isinstance(obj, _SupportedLRScheduler)`""" + + def step(self, *args: Any, **kwargs: Any) -> None: + ... + + def state_dict(self) -> Dict[str, Any]: + ... + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + ... 
diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 00d6142bcd38f..7f42689b8b4a5 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -507,7 +507,7 @@ def _update_learning_rates( "lr_scheduler_step", lr_scheduler["scheduler"], optimizer_idx=lr_scheduler["opt_idx"], - metrics=monitor_val, + metric=monitor_val, ) self.scheduler_progress.increment_completed() diff --git a/pytorch_lightning/strategies/deepspeed.py b/pytorch_lightning/strategies/deepspeed.py index dc63021ee1eaf..84c6f0b0dcfd9 100644 --- a/pytorch_lightning/strategies/deepspeed.py +++ b/pytorch_lightning/strategies/deepspeed.py @@ -40,7 +40,7 @@ from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.seed import reset_seed -from pytorch_lightning.utilities.types import _PATH, LRSchedulerTypeTuple, LRSchedulerTypeUnion, STEP_OUTPUT +from pytorch_lightning.utilities.types import _PATH, LRSchedulerConfig, LRSchedulerTypeUnion, STEP_OUTPUT from pytorch_lightning.utilities.warnings import rank_zero_warn, WarningCache warning_cache = WarningCache() @@ -444,7 +444,7 @@ def init_deepspeed(self): else: self._initialize_deepspeed_inference(model) - def _init_optimizers(self) -> Tuple[Optimizer, Optional[Union[LRSchedulerTypeTuple]], Optional[int]]: + def _init_optimizers(self) -> Tuple[Optimizer, Optional[List[LRSchedulerConfig]], Optional[int]]: optimizers, schedulers, optimizer_frequencies = _init_optimizers_and_lr_schedulers(self.lightning_module) if len(optimizers) > 1 or len(schedulers) > 1: raise MisconfigurationException( diff --git a/pytorch_lightning/strategies/strategy.py b/pytorch_lightning/strategies/strategy.py index d2488b7d81ba1..fe9093838c157 100644 --- a/pytorch_lightning/strategies/strategy.py +++ b/pytorch_lightning/strategies/strategy.py @@ -33,7 +33,6 @@ from pytorch_lightning.utilities.distributed import ReduceOp from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.types import _PATH, LRSchedulerConfig, STEP_OUTPUT -from pytorch_lightning.utilities.types import _PATH, LRSchedulerTypeUnion, STEP_OUTPUT TBroadcast = TypeVar("TBroadcast") diff --git a/pytorch_lightning/trainer/optimizers.py b/pytorch_lightning/trainer/optimizers.py index 27b222502c373..0b44786873867 100644 --- a/pytorch_lightning/trainer/optimizers.py +++ b/pytorch_lightning/trainer/optimizers.py @@ -14,13 +14,6 @@ from abc import ABC from typing import List, Optional, Tuple -from typing import Any, Dict, List, Optional, Protocol, runtime_checkable, Tuple, Union -from typing import Any, Dict, List, Optional, Tuple, Union - -import torch -from torch import optim -from torch.optim.optimizer import Optimizer -from typing_extensions import Protocol, runtime_checkable import pytorch_lightning as pl from pytorch_lightning.core.optimizer import ( @@ -31,20 +24,6 @@ from pytorch_lightning.utilities import rank_zero_deprecation -@runtime_checkable -class _SupportedLRScheduler(Protocol): - """This class is used to detect if an object is stateful using `isinstance(obj, _SupportedLRScheduler)`""" - - def step(self, *args: Any, **kwargs: Any) -> None: - ... - - def state_dict(self) -> Dict[str, Any]: - ... - - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: - ... - - class TrainerOptimizersMixin(ABC): r""" .. 
deprecated:: v1.6 diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index ce5680059a945..b9ae7f2dc2036 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -41,8 +41,6 @@ from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop from pytorch_lightning.loops.fit_loop import FitLoop from pytorch_lightning.loops.utilities import _parse_loop_limits, _reset_progress -from pytorch_lightning.utilities.cli import LR_SCHEDULER_REGISTRY -from pytorch_lightning.loops.utilities import _parse_loop_limits from pytorch_lightning.plugins import ( ApexMixedPrecisionPlugin, DDPSpawnStrategy, diff --git a/pytorch_lightning/utilities/types.py b/pytorch_lightning/utilities/types.py index 4d262f301062d..8d2b8f37763da 100644 --- a/pytorch_lightning/utilities/types.py +++ b/pytorch_lightning/utilities/types.py @@ -21,11 +21,6 @@ import torch from torch.optim import Optimizer -from typing import Any, Dict, Iterator, List, Mapping, Sequence, Union -from typing import Any, Dict, Iterator, List, Mapping, Sequence, Type, Union - -import torch -from torch.optim.lr_scheduler import _LRScheduler, ReduceLROnPlateau from torch.utils.data import DataLoader from torchmetrics import Metric from typing_extensions import TypedDict @@ -51,8 +46,7 @@ EVAL_DATALOADERS = Union[DataLoader, Sequence[DataLoader]] -# Copied from `torch.optim.lr_scheduler.pyi` -# Missing attributes were added to improve typing +# Inferred from `torch.optim.lr_scheduler.pyi` class _LRScheduler: optimizer: Optimizer @@ -65,13 +59,7 @@ def state_dict(self) -> dict: def load_state_dict(self, state_dict: dict) -> None: ... - def get_last_lr(self) -> List[float]: - ... - - def get_lr(self) -> float: - ... - - def step(self, epoch: Optional[int] = ...) -> None: + def step(self, *args: Any, **kwargs: Any) -> None: ... 
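Several of the signatures touched above now return or accept `LRSchedulerConfig` entries instead of loosely typed dicts. For orientation, a rough sketch of one such entry as the trainer tracks it, restricted to the keys visible in these hunks and with illustrative values (not taken from a real run):

from typing import Optional, Union

import torch
from torch.optim.lr_scheduler import _LRScheduler, ReduceLROnPlateau
from typing_extensions import TypedDict


class LRSchedulerConfig(TypedDict):
    scheduler: Union[_LRScheduler, ReduceLROnPlateau]
    name: Optional[str]
    interval: str           # "epoch" or "step"
    frequency: int          # step the scheduler every N intervals
    monitor: Optional[str]  # metric key for ReduceLROnPlateau-style scheduling
    strict: bool
    opt_idx: Optional[int]


layer = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(layer.parameters(), lr=0.1)

config: LRSchedulerConfig = {
    "scheduler": torch.optim.lr_scheduler.StepLR(optimizer, step_size=1),
    "name": None,
    "interval": "epoch",
    "frequency": 1,
    "monitor": None,
    "strict": True,
    "opt_idx": 0,
}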
diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 7c13577dee4d7..950bc9fe5c78b 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -331,7 +331,7 @@ def _auto_train_batch(trainer, model, batches, device=torch.device("cpu"), curre dict( name="lr_scheduler_step", args=(ANY,), - kwargs=dict(optimizer_idx=None, metrics=None), + kwargs=dict(optimizer_idx=None, metric=None), ) ] if i == (trainer.num_training_batches - 1) diff --git a/tests/trainer/optimization/test_optimizers.py b/tests/trainer/optimization/test_optimizers.py index c0d604756f7d5..a3101428d3020 100644 --- a/tests/trainer/optimization/test_optimizers.py +++ b/tests/trainer/optimization/test_optimizers.py @@ -699,10 +699,10 @@ def training_step(self, batch, batch_idx, optimizer_idx): def training_epoch_end(self, *args, **kwargs): pass - def lr_scheduler_step(self, scheduler, optimizer_idx, metrics=None): + def lr_scheduler_step(self, scheduler, optimizer_idx, metric=None): # step-level if optimizer_idx == 0: - super().lr_scheduler_step(scheduler, optimizer_idx, metrics) + super().lr_scheduler_step(scheduler, optimizer_idx, metric) # epoch-level elif optimizer_idx == 1: scheduler.step(epoch=self.current_epoch) @@ -753,5 +753,6 @@ def configure_optimizers(self): model = CustomBoringModel() trainer = Trainer(default_root_dir=tmpdir) + model.trainer = trainer with pytest.raises(ValueError, match="provided lr scheduler .* is invalid"): - trainer.init_optimizers(model) + _init_optimizers_and_lr_schedulers(model) From 78ebd316e8a40ea6d21db6ee92930c0e05f5537d Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Tue, 28 Dec 2021 18:35:56 +0530 Subject: [PATCH 17/35] mypy --- pytorch_lightning/core/optimizer.py | 2 +- pytorch_lightning/utilities/types.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py index 401b95078095e..a2d8f53005051 100644 --- a/pytorch_lightning/core/optimizer.py +++ b/pytorch_lightning/core/optimizer.py @@ -340,7 +340,7 @@ def _configure_schedulers_manual_opt(schedulers: list, monitor: Optional[str]) - return lr_schedulers -def _get_default_scheduler_config() -> Dict[str, Any]: +def _get_default_scheduler_config() -> LRSchedulerConfig: return { "scheduler": None, "name": None, # no custom name diff --git a/pytorch_lightning/utilities/types.py b/pytorch_lightning/utilities/types.py index 8d2b8f37763da..fc00aff1be26a 100644 --- a/pytorch_lightning/utilities/types.py +++ b/pytorch_lightning/utilities/types.py @@ -101,7 +101,7 @@ def load_state_dict(self, state_dict: dict) -> None: class LRSchedulerConfig(TypedDict): - scheduler: Union[_LRScheduler, ReduceLROnPlateau] + scheduler: Optional[Union[_LRScheduler, ReduceLROnPlateau]] name: Optional[str] interval: str frequency: int From ff11e76d59bca963adbd70cd3529fe0c2b247d1d Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Mon, 3 Jan 2022 16:40:14 +0530 Subject: [PATCH 18/35] try fix mypy --- pytorch_lightning/core/optimizer.py | 2 +- pytorch_lightning/utilities/types.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py index a2d8f53005051..401b95078095e 100644 --- a/pytorch_lightning/core/optimizer.py +++ b/pytorch_lightning/core/optimizer.py @@ -340,7 +340,7 @@ def _configure_schedulers_manual_opt(schedulers: list, monitor: Optional[str]) - return lr_schedulers -def _get_default_scheduler_config() -> LRSchedulerConfig: +def 
_get_default_scheduler_config() -> Dict[str, Any]: return { "scheduler": None, "name": None, # no custom name diff --git a/pytorch_lightning/utilities/types.py b/pytorch_lightning/utilities/types.py index fc00aff1be26a..76b7d6545f0d7 100644 --- a/pytorch_lightning/utilities/types.py +++ b/pytorch_lightning/utilities/types.py @@ -50,7 +50,7 @@ class _LRScheduler: optimizer: Optimizer - def __init__(self, optimizer: Optimizer, last_epoch: int = ...) -> None: + def __init__(self, optimizer: Optimizer, *args: Any, **kwargs: Any) -> None: ... def state_dict(self) -> dict: @@ -101,7 +101,7 @@ def load_state_dict(self, state_dict: dict) -> None: class LRSchedulerConfig(TypedDict): - scheduler: Optional[Union[_LRScheduler, ReduceLROnPlateau]] + scheduler: Union[_LRScheduler, ReduceLROnPlateau] name: Optional[str] interval: str frequency: int From f8de4d06ad4ccc4e3a7b83cf68e2527d6ca9133a Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Mon, 3 Jan 2022 16:52:07 +0530 Subject: [PATCH 19/35] try fix mypy --- pytorch_lightning/core/optimizer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py index 401b95078095e..c6fca03d3ff2d 100644 --- a/pytorch_lightning/core/optimizer.py +++ b/pytorch_lightning/core/optimizer.py @@ -258,7 +258,7 @@ def _configure_optimizers( def _configure_schedulers_automatic_opt(schedulers: list, monitor: Optional[str]) -> List[LRSchedulerConfig]: """Convert each scheduler into dict structure with relevant information, when using automatic optimization.""" - lr_schedulers = [] + lr_schedulers: List[LRSchedulerConfig] = [] default_config = _get_default_scheduler_config() for scheduler in schedulers: if isinstance(scheduler, dict): @@ -314,7 +314,7 @@ def _configure_schedulers_automatic_opt(schedulers: list, monitor: Optional[str] def _configure_schedulers_manual_opt(schedulers: list, monitor: Optional[str]) -> List[LRSchedulerConfig]: """Convert each scheduler into dict structure with relevant information, when using manual optimization.""" - lr_schedulers = [] + lr_schedulers: List[LRSchedulerConfig] = [] default_config = _get_default_scheduler_config() for scheduler in schedulers: if isinstance(scheduler, dict): From 404ba6b129f7920b98fe34937dcf7b98060b6cac Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Mon, 3 Jan 2022 17:03:32 +0530 Subject: [PATCH 20/35] try fix mypy --- pytorch_lightning/core/optimizer.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py index c6fca03d3ff2d..7e31ad13f5be5 100644 --- a/pytorch_lightning/core/optimizer.py +++ b/pytorch_lightning/core/optimizer.py @@ -24,7 +24,6 @@ import pytorch_lightning as pl from pytorch_lightning.utilities import AMPType, rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.types import LRSchedulerConfig def do_nothing_closure() -> None: @@ -172,7 +171,7 @@ def closure_dis(): def _init_optimizers_and_lr_schedulers( model: "pl.LightningModule", -) -> Tuple[List[Optimizer], List[LRSchedulerConfig], List[int]]: +) -> Tuple[List[Optimizer], List[Dict[str, Any]], List[int]]: """Calls `LightningModule.configure_optimizers` and parses and validates the output.""" model.trainer._lightning_optimizers = None optim_conf = model.trainer._call_lightning_module_hook("configure_optimizers", pl_module=model) @@ -256,9 +255,9 @@ def _configure_optimizers( return optimizers, 
lr_schedulers, optimizer_frequencies, monitor -def _configure_schedulers_automatic_opt(schedulers: list, monitor: Optional[str]) -> List[LRSchedulerConfig]: +def _configure_schedulers_automatic_opt(schedulers: list, monitor: Optional[str]) -> List[Dict[str, Any]]: """Convert each scheduler into dict structure with relevant information, when using automatic optimization.""" - lr_schedulers: List[LRSchedulerConfig] = [] + lr_schedulers = [] default_config = _get_default_scheduler_config() for scheduler in schedulers: if isinstance(scheduler, dict): @@ -312,9 +311,9 @@ def _configure_schedulers_automatic_opt(schedulers: list, monitor: Optional[str] return lr_schedulers -def _configure_schedulers_manual_opt(schedulers: list, monitor: Optional[str]) -> List[LRSchedulerConfig]: +def _configure_schedulers_manual_opt(schedulers: list, monitor: Optional[str]) -> List[Dict[str, Any]]: """Convert each scheduler into dict structure with relevant information, when using manual optimization.""" - lr_schedulers: List[LRSchedulerConfig] = [] + lr_schedulers = [] default_config = _get_default_scheduler_config() for scheduler in schedulers: if isinstance(scheduler, dict): @@ -353,7 +352,7 @@ def _get_default_scheduler_config() -> Dict[str, Any]: } -def _validate_scheduler_optimizer(optimizers: List[Optimizer], lr_schedulers: List[LRSchedulerConfig]) -> None: +def _validate_scheduler_optimizer(optimizers: List[Optimizer], lr_schedulers: List[Dict[str, Any]]) -> None: if any(sch["scheduler"].optimizer not in optimizers for sch in lr_schedulers): raise MisconfigurationException( "Some schedulers are attached with an optimizer that wasn't returned from `configure_optimizers`." From 013f9ce145fd3490cb73ac2b81a053ccf521ce65 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Tue, 4 Jan 2022 03:46:26 +0530 Subject: [PATCH 21/35] use existing state_dict protocol --- .gitignore | 2 +- pytorch_lightning/core/lightning.py | 4 ++-- pytorch_lightning/core/optimizer.py | 24 ++++++------------- .../connectors/checkpoint_connector.py | 12 ++++++++++ pytorch_lightning/utilities/auto_restart.py | 16 ++++--------- pytorch_lightning/utilities/types.py | 3 --- tests/trainer/optimization/test_optimizers.py | 4 ++-- 7 files changed, 28 insertions(+), 37 deletions(-) diff --git a/.gitignore b/.gitignore index 97ba65852b821..923c2a1829c22 100644 --- a/.gitignore +++ b/.gitignore @@ -139,10 +139,10 @@ ENV/ .data/ Datasets/ mnist/ +MNIST/ legacy/checkpoints/ *.gz *ubyte -MNIST/ # pl tests ml-runs/ diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 292ed728d5868..5402dda631a60 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1497,7 +1497,7 @@ def lr_scheduler_step( self, scheduler: LRSchedulerTypeUnion, optimizer_idx: Optional[int] = None, - metric: Optional[Union[float, torch.Tensor]] = None, + metric: Any = None, ) -> None: r""" Override this method to adjust the default way the @@ -1508,7 +1508,7 @@ def lr_scheduler_step( Args: scheduler: Learning rate scheduler. optimizer_idx: Index of the optimizer associated with this scheduler. - metric: Value of the metric used for schedulers like ``ReduceLROnPlateau``. + metric: Value of the monitor used for schedulers like ``ReduceLROnPlateau``. 
Examples:: diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py index 7e31ad13f5be5..9d17e74de7f57 100644 --- a/pytorch_lightning/core/optimizer.py +++ b/pytorch_lightning/core/optimizer.py @@ -19,7 +19,6 @@ import torch from torch import optim from torch.optim import Optimizer -from typing_extensions import Protocol, runtime_checkable import pytorch_lightning as pl from pytorch_lightning.utilities import AMPType, rank_zero_warn @@ -257,6 +256,8 @@ def _configure_optimizers( def _configure_schedulers_automatic_opt(schedulers: list, monitor: Optional[str]) -> List[Dict[str, Any]]: """Convert each scheduler into dict structure with relevant information, when using automatic optimization.""" + from pytorch_lightning.trainer.connectors.checkpoint_connector import _SupportsStateDict + lr_schedulers = [] default_config = _get_default_scheduler_config() for scheduler in schedulers: @@ -305,7 +306,7 @@ def _configure_schedulers_automatic_opt(schedulers: list, monitor: Optional[str] lr_schedulers.append({**default_config, "scheduler": scheduler}) current_scheduler = lr_schedulers[-1]["scheduler"] - if not isinstance(current_scheduler, _SupportedLRScheduler): + if not isinstance(current_scheduler, _SupportsStateDict): raise ValueError(f"The provided lr scheduler `{current_scheduler.__class__.__name__}` is invalid.") return lr_schedulers @@ -313,6 +314,9 @@ def _configure_schedulers_automatic_opt(schedulers: list, monitor: Optional[str] def _configure_schedulers_manual_opt(schedulers: list, monitor: Optional[str]) -> List[Dict[str, Any]]: """Convert each scheduler into dict structure with relevant information, when using manual optimization.""" + + from pytorch_lightning.trainer.connectors.checkpoint_connector import _SupportsStateDict + lr_schedulers = [] default_config = _get_default_scheduler_config() for scheduler in schedulers: @@ -333,7 +337,7 @@ def _configure_schedulers_manual_opt(schedulers: list, monitor: Optional[str]) - lr_schedulers.append({**default_config, "scheduler": scheduler}) current_scheduler = lr_schedulers[-1]["scheduler"] - if not isinstance(current_scheduler, _SupportedLRScheduler): + if not isinstance(current_scheduler, _SupportsStateDict): raise ValueError(f"The provided lr scheduler `{current_scheduler.__class__.__name__}` is invalid.") return lr_schedulers @@ -405,17 +409,3 @@ def zero_grad(self, set_to_none: Optional[bool] = False) -> None: def __repr__(self) -> str: return "No Optimizer" - - -@runtime_checkable -class _SupportedLRScheduler(Protocol): - """This class is used to detect if an object is stateful using `isinstance(obj, _SupportedLRScheduler)`""" - - def step(self, *args: Any, **kwargs: Any) -> None: - ... - - def state_dict(self) -> Dict[str, Any]: - ... - - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: - ... 
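The `Examples::` section of the `lr_scheduler_step` docstring is cut off by the hunk context above, so here is a sketch of the kind of override it points at: a scheduler whose `step` needs the current epoch, similar to the timm-style schedulers the docs change elsewhere in this series mentions. `EpochOnlyScheduler` and `LitModel` are illustrative stand-ins, not part of the patch, and the hook is written against the keyword-argument signature used at this point in the series (training-loop methods are omitted):

import torch
from pytorch_lightning import LightningModule


class EpochOnlyScheduler:
    """Illustrative scheduler whose ``step`` requires the current epoch."""

    def __init__(self, optimizer):
        self.optimizer = optimizer

    def step(self, epoch):
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = param_group["lr"] / (epoch + 1)

    def state_dict(self):
        return {k: v for k, v in self.__dict__.items() if k != "optimizer"}

    def load_state_dict(self, state_dict):
        self.__dict__.update(state_dict)


class LitModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def lr_scheduler_step(self, scheduler, optimizer_idx=None, metric=None):
        # the default hook would call ``scheduler.step()`` (or pass the monitor value),
        # which does not match this scheduler's epoch-based API
        scheduler.step(epoch=self.current_epoch)

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.layer.parameters(), lr=1e-2)
        return {"optimizer": optimizer, "lr_scheduler": EpochOnlyScheduler(optimizer)}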
diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index aa333a29942a9..af8b06f2db27b 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -18,6 +18,7 @@ import torch from torchmetrics import Metric +from typing_extensions import Protocol, runtime_checkable import pytorch_lightning as pl from pytorch_lightning.loggers import LightningLoggerBase @@ -489,3 +490,14 @@ def _get_loops_state_dict(self) -> Dict[str, Any]: "test_loop": self.trainer.test_loop.state_dict(), "predict_loop": self.trainer.predict_loop.state_dict(), } + + +@runtime_checkable +class _SupportsStateDict(Protocol): + """This class is used to detect if an object is stateful using `isinstance(obj, _SupportsStateDict)`.""" + + def state_dict(self) -> Dict[str, Any]: + ... + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + ... diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index 9d26f4a6e0736..c8cb854565f6b 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -36,7 +36,6 @@ DataLoader, IterableDataset, ) -from typing_extensions import Protocol, runtime_checkable import pytorch_lightning as pl from pytorch_lightning.utilities.apply_func import apply_to_collection @@ -577,6 +576,8 @@ def _reload_dataloader_state_dict_manual(dataloader: DataLoader, state_dict: Dic # In manual mode, we don't wrap the user objects with `CaptureMapDataset` or `CaptureIterableDataset` # therefore, we need to reload the states manually. + from pytorch_lightning.trainer.connectors.checkpoint_connector import _SupportsStateDict + latest_worker_id = state_dict["latest_worker_id"] num_workers = state_dict["state"][latest_worker_id]["num_workers"] sampler_state = state_dict["state"][latest_worker_id].get("sampler_state", None) @@ -635,17 +636,6 @@ def _rotate_worker_indices(state: Dict[int, Any], latest_worker_id: int, num_wor return {new_id: state[old_id] for old_id, new_id in old_to_new_worker_id_map if old_id in state} -@runtime_checkable -class _SupportsStateDict(Protocol): - """This class is used to detect if an object is stateful using `isinstance(obj, _SupportsStateDict)`.""" - - def state_dict(self) -> Dict[str, Any]: - ... - - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: - ... - - class _StatefulDataLoaderIter: """This mixin is used to make PyTorch DataLoaderIter stateful.""" @@ -656,6 +646,8 @@ def __accumulate_state(self, sampler_state: Dict[str, Any]) -> None: def _store_sampler_state(self) -> None: """This function is used to extract the sampler states if any.""" + from pytorch_lightning.trainer.connectors.checkpoint_connector import _SupportsStateDict + sampler_state = { k: v.state_dict() for k, v in self._loader.__dict__.items() diff --git a/pytorch_lightning/utilities/types.py b/pytorch_lightning/utilities/types.py index 76b7d6545f0d7..4ee037259af9a 100644 --- a/pytorch_lightning/utilities/types.py +++ b/pytorch_lightning/utilities/types.py @@ -59,9 +59,6 @@ def state_dict(self) -> dict: def load_state_dict(self, state_dict: dict) -> None: ... - def step(self, *args: Any, **kwargs: Any) -> None: - ... 
- # Copied from `torch.optim.lr_scheduler.pyi` # Missing attributes were added to improve typing diff --git a/tests/trainer/optimization/test_optimizers.py b/tests/trainer/optimization/test_optimizers.py index a3101428d3020..cec23b320f80e 100644 --- a/tests/trainer/optimization/test_optimizers.py +++ b/tests/trainer/optimization/test_optimizers.py @@ -665,7 +665,7 @@ def on_save_checkpoint(self, checkpoint): def test_lr_scheduler_step_hook(tmpdir): - """Test that custom lr_schedulers works and `lr_scheduler_hook` is called at appropriate time.""" + """Test that custom lr scheduler works and `lr_scheduler_step` is called at appropriate time.""" class CustomEpochScheduler: def __init__(self, optimizer): @@ -735,7 +735,7 @@ def configure_optimizers(self): def test_invalid_lr_scheduler(tmpdir): - """Test that custom lr_schedulers works and `lr_scheduler_hook` is called at appropriate time.""" + """Test that custom lr scheduler raises an error if it doesn't follow basic protocol API.""" class CustomScheduler: def __init__(self, optimizer): From 65bda5f4f1f0e898cb710ea762a7c89eb81dcb8f Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Tue, 4 Jan 2022 03:55:41 +0530 Subject: [PATCH 22/35] update import --- tests/utilities/test_auto_restart.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index e33bc91621a2b..19ca791bcf7e7 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -38,6 +38,7 @@ import tests.helpers.utils as tutils from pytorch_lightning import Callback, LightningModule, seed_everything, Trainer +from pytorch_lightning.trainer.connectors.checkpoint_connector import _SupportsStateDict from pytorch_lightning.trainer.states import RunningStage, TrainerState from pytorch_lightning.trainer.supporters import CombinedLoader from pytorch_lightning.utilities.auto_restart import ( @@ -48,7 +49,6 @@ _reload_dataloader_state_dict, _rotate_worker_indices, _SingleProcessDataLoaderIterStateful, - _SupportsStateDict, _teardown_dataloader_get_iterators, _validate_fault_tolerant_automatic, CaptureIterableDataset, From 26182db1b5fbf3abad481e900caa3c103de7703b Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Tue, 4 Jan 2022 20:29:01 +0530 Subject: [PATCH 23/35] small updates --- docs/source/common/optimizers.rst | 2 +- pytorch_lightning/core/optimizer.py | 10 ++++++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/docs/source/common/optimizers.rst b/docs/source/common/optimizers.rst index bb9799f2d234f..2f4bccffd05a0 100644 --- a/docs/source/common/optimizers.rst +++ b/docs/source/common/optimizers.rst @@ -257,7 +257,7 @@ Bring your own Custom Learning Rate Schedulers Lightning allows using custom learning rate schedulers that aren't available in `PyTorch natively `_. One good example is `Timm Schedulers `_. When using custom learning rate schedulers relying on a different API from Native PyTorch ones, you should override the :meth:`~pytorch_lightning.core.lightning.LightningModule.lr_scheduler_step` with your desired logic. -If you are using native PyTorch schedulers, there is no need to override this hook since Lightning will handle it optimally by default. +If you are using native PyTorch schedulers, there is no need to override this hook since Lightning will handle it automatically by default. .. 
code-block:: python diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py index 9d17e74de7f57..361895c30884c 100644 --- a/pytorch_lightning/core/optimizer.py +++ b/pytorch_lightning/core/optimizer.py @@ -307,7 +307,10 @@ def _configure_schedulers_automatic_opt(schedulers: list, monitor: Optional[str] current_scheduler = lr_schedulers[-1]["scheduler"] if not isinstance(current_scheduler, _SupportsStateDict): - raise ValueError(f"The provided lr scheduler `{current_scheduler.__class__.__name__}` is invalid.") + raise ValueError( + f"The provided lr scheduler `{current_scheduler.__class__.__name__}` is invalid." + " It should have `state_dict` and `load_state_dict` methods defined." + ) return lr_schedulers @@ -338,7 +341,10 @@ def _configure_schedulers_manual_opt(schedulers: list, monitor: Optional[str]) - current_scheduler = lr_schedulers[-1]["scheduler"] if not isinstance(current_scheduler, _SupportsStateDict): - raise ValueError(f"The provided lr scheduler `{current_scheduler.__class__.__name__}` is invalid.") + raise ValueError( + f"The provided lr scheduler `{current_scheduler.__class__.__name__}` is invalid." + " It should have `state_dict` and `load_state_dict` methods defined." + ) return lr_schedulers From b4fb944c9ff101d2937bd9e17b841e2988171c69 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Wed, 5 Jan 2022 00:10:27 +0530 Subject: [PATCH 24/35] add edge case check --- pytorch_lightning/core/optimizer.py | 50 +++++++++++++------ tests/trainer/optimization/test_optimizers.py | 33 +++++++++++- 2 files changed, 66 insertions(+), 17 deletions(-) diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py index 0369bd190be41..4e2443de99652 100644 --- a/pytorch_lightning/core/optimizer.py +++ b/pytorch_lightning/core/optimizer.py @@ -19,10 +19,12 @@ import torch from torch import optim from torch.optim import Optimizer +from typing_extensions import Protocol, runtime_checkable import pytorch_lightning as pl from pytorch_lightning.utilities import AMPType, rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.model_helpers import is_overridden def do_nothing_closure() -> None: @@ -187,6 +189,7 @@ def _init_optimizers_and_lr_schedulers( ) lr_schedulers = _configure_schedulers(lr_schedulers, monitor) _set_scheduler_opt_idx(optimizers, lr_schedulers) + _validate_scheduler_api(lr_schedulers, model) return optimizers, lr_schedulers, optimizer_frequencies @@ -256,8 +259,6 @@ def _configure_optimizers( def _configure_schedulers_automatic_opt(schedulers: list, monitor: Optional[str]) -> List[Dict[str, Any]]: """Convert each scheduler into dict structure with relevant information, when using automatic optimization.""" - from pytorch_lightning.trainer.connectors.checkpoint_connector import _SupportsStateDict - lr_schedulers = [] default_config = _get_default_scheduler_config() for scheduler in schedulers: @@ -305,21 +306,11 @@ def _configure_schedulers_automatic_opt(schedulers: list, monitor: Optional[str] else: lr_schedulers.append({**default_config, "scheduler": scheduler}) - current_scheduler = lr_schedulers[-1]["scheduler"] - if not isinstance(current_scheduler, _SupportsStateDict): - raise ValueError( - f"The provided lr scheduler `{current_scheduler.__class__.__name__}` is invalid." - " It should have `state_dict` and `load_state_dict` methods defined." 
- ) - return lr_schedulers def _configure_schedulers_manual_opt(schedulers: list, monitor: Optional[str]) -> List[Dict[str, Any]]: """Convert each scheduler into dict structure with relevant information, when using manual optimization.""" - - from pytorch_lightning.trainer.connectors.checkpoint_connector import _SupportsStateDict - lr_schedulers = [] default_config = _get_default_scheduler_config() for scheduler in schedulers: @@ -339,14 +330,26 @@ def _configure_schedulers_manual_opt(schedulers: list, monitor: Optional[str]) - else: lr_schedulers.append({**default_config, "scheduler": scheduler}) - current_scheduler = lr_schedulers[-1]["scheduler"] - if not isinstance(current_scheduler, _SupportsStateDict): + return lr_schedulers + + +def _validate_scheduler_api(lr_schedulers: List[Dict[str, Any]], model: "pl.LightningModule") -> None: + from pytorch_lightning.trainer.connectors.checkpoint_connector import _SupportsStateDict + + for scheduler_config in lr_schedulers: + scheduler = scheduler_config["scheduler"] + if not isinstance(scheduler, _SupportsStateDict): raise ValueError( - f"The provided lr scheduler `{current_scheduler.__class__.__name__}` is invalid." + f"The provided lr scheduler `{scheduler.__class__.__name__}` is invalid." " It should have `state_dict` and `load_state_dict` methods defined." ) - return lr_schedulers + if not isinstance(scheduler, _SupportsLRScheduler) and not is_overridden("lr_scheduler_step", model): + raise MisconfigurationException( + f"The provided lr scheduler `{scheduler.__class__.__name__}` doesn't follow the PyTorch LR Scheduler" + " Protocol. You should override the `LightningModule.lr_scheduler_step` hook with your own logic if" + " you are using a custom LR scheduler." + ) def _get_default_scheduler_config() -> Dict[str, Any]: @@ -427,3 +430,18 @@ def zero_grad(self, set_to_none: Optional[bool] = False) -> None: def __repr__(self) -> str: return "No Optimizer" + + +@runtime_checkable +class _SupportsLRScheduler(Protocol): + """This class is used to detect if a learning rate scheduler is supported for default configuration using + `isinstance(obj, _SupportsLRScheduler)`.""" + + def step(self, *args: Any, **kwargs: Any) -> None: + ... + + def state_dict(self) -> Dict[str, Any]: + ... + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + ... 
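The second check above fires only when the LightningModule has not overridden `lr_scheduler_step`, which is what `is_overridden` from `pytorch_lightning.utilities.model_helpers` decides. A simplified sketch of that kind of detection, an approximation of the idea rather than the real helper:

def _is_overridden(method_name: str, instance: object, parent: type) -> bool:
    # a subclass that redefines the method binds a different function object than the parent's
    return getattr(type(instance), method_name, None) is not getattr(parent, method_name, None)


class Base:
    def lr_scheduler_step(self, scheduler, optimizer_idx=None, metric=None):
        scheduler.step()


class WithCustomStep(Base):
    def lr_scheduler_step(self, scheduler, optimizer_idx=None, metric=None):
        scheduler.step(epoch=0)


assert _is_overridden("lr_scheduler_step", WithCustomStep(), Base)  # overridden
assert not _is_overridden("lr_scheduler_step", Base(), Base)        # inherited default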
diff --git a/tests/trainer/optimization/test_optimizers.py b/tests/trainer/optimization/test_optimizers.py index 97f74cd6dd047..cd97c33c41974 100644 --- a/tests/trainer/optimization/test_optimizers.py +++ b/tests/trainer/optimization/test_optimizers.py @@ -752,7 +752,7 @@ def configure_optimizers(self): assert param_group["lr"] == lr / math.factorial(max_epochs) -def test_invalid_lr_scheduler(tmpdir): +def test_invalid_basic_lr_scheduler(tmpdir): """Test that custom lr scheduler raises an error if it doesn't follow basic protocol API.""" class CustomScheduler: @@ -774,3 +774,34 @@ def configure_optimizers(self): model.trainer = trainer with pytest.raises(ValueError, match="provided lr scheduler .* is invalid"): _init_optimizers_and_lr_schedulers(model) + + +def test_invalid_pt_lr_scheduler(tmpdir): + """Test that custom lr scheduler raises an error if it doesn't follow PyTorch LR Scheduler protocol API and + `lr_scheduler_step` is also not overridden.""" + + class CustomScheduler: + def __init__(self, optimizer): + self.optimizer = optimizer + + def update(self, epoch): + for param_group in self.optimizer.param_groups: + param_group["lr"] = param_group["lr"] / (epoch + 1) + + def state_dict(self): + return {key: value for key, value in self.__dict__.items() if key != "optimizer"} + + def load_state_dict(self, state_dict): + self.__dict__.update(state_dict) + + class CustomBoringModel(BoringModel): + def configure_optimizers(self): + opt = torch.optim.SGD(self.parameters(), lr=1e-2) + lr_scheduler = CustomScheduler(opt) + return {"optimizer": opt, "lr_scheduler": lr_scheduler} + + model = CustomBoringModel() + trainer = Trainer(default_root_dir=tmpdir) + model.trainer = trainer + with pytest.raises(MisconfigurationException, match="doesn't follow the PyTorch LR Scheduler Protocol"): + _init_optimizers_and_lr_schedulers(model) From af8b1c371106ef4eec0f3ec1966e241ea522fac8 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Wed, 5 Jan 2022 00:16:57 +0530 Subject: [PATCH 25/35] rebase --- tests/models/test_hooks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 950bc9fe5c78b..3dc5ada9bf72d 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -331,7 +331,7 @@ def _auto_train_batch(trainer, model, batches, device=torch.device("cpu"), curre dict( name="lr_scheduler_step", args=(ANY,), - kwargs=dict(optimizer_idx=None, metric=None), + kwargs=dict(optimizer_idx=0, metric=None), ) ] if i == (trainer.num_training_batches - 1) From 4c8ada65604eb23fbec8a82ac9239aaec67db02b Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Wed, 5 Jan 2022 17:05:03 +0530 Subject: [PATCH 26/35] avoid protocol --- pytorch_lightning/core/optimizer.py | 19 ++----------------- tests/trainer/optimization/test_optimizers.py | 2 +- 2 files changed, 3 insertions(+), 18 deletions(-) diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py index 4e2443de99652..50e91d94f071a 100644 --- a/pytorch_lightning/core/optimizer.py +++ b/pytorch_lightning/core/optimizer.py @@ -19,12 +19,12 @@ import torch from torch import optim from torch.optim import Optimizer -from typing_extensions import Protocol, runtime_checkable import pytorch_lightning as pl from pytorch_lightning.utilities import AMPType, rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden +from pytorch_lightning.utilities.types import 
LRSchedulerTypeTuple def do_nothing_closure() -> None: @@ -344,7 +344,7 @@ def _validate_scheduler_api(lr_schedulers: List[Dict[str, Any]], model: "pl.Ligh " It should have `state_dict` and `load_state_dict` methods defined." ) - if not isinstance(scheduler, _SupportsLRScheduler) and not is_overridden("lr_scheduler_step", model): + if not isinstance(scheduler, LRSchedulerTypeTuple) and not is_overridden("lr_scheduler_step", model): raise MisconfigurationException( f"The provided lr scheduler `{scheduler.__class__.__name__}` doesn't follow the PyTorch LR Scheduler" " Protocol. You should override the `LightningModule.lr_scheduler_step` hook with your own logic if" @@ -430,18 +430,3 @@ def zero_grad(self, set_to_none: Optional[bool] = False) -> None: def __repr__(self) -> str: return "No Optimizer" - - -@runtime_checkable -class _SupportsLRScheduler(Protocol): - """This class is used to detect if a learning rate scheduler is supported for default configuration using - `isinstance(obj, _SupportsLRScheduler)`.""" - - def step(self, *args: Any, **kwargs: Any) -> None: - ... - - def state_dict(self) -> Dict[str, Any]: - ... - - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: - ... diff --git a/tests/trainer/optimization/test_optimizers.py b/tests/trainer/optimization/test_optimizers.py index cd97c33c41974..2cfbb06367e3d 100644 --- a/tests/trainer/optimization/test_optimizers.py +++ b/tests/trainer/optimization/test_optimizers.py @@ -784,7 +784,7 @@ class CustomScheduler: def __init__(self, optimizer): self.optimizer = optimizer - def update(self, epoch): + def step(self, epoch): for param_group in self.optimizer.param_groups: param_group["lr"] = param_group["lr"] / (epoch + 1) From c497bdf3d55c473e2f0a5eaaad31712b9272f775 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sat, 8 Jan 2022 00:54:57 +0530 Subject: [PATCH 27/35] move to types --- pytorch_lightning/core/optimizer.py | 4 +--- .../trainer/connectors/checkpoint_connector.py | 12 ------------ pytorch_lightning/utilities/auto_restart.py | 6 +----- pytorch_lightning/utilities/types.py | 13 ++++++++++++- tests/utilities/test_auto_restart.py | 2 +- 5 files changed, 15 insertions(+), 22 deletions(-) diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py index 50e91d94f071a..a9db26acaffd7 100644 --- a/pytorch_lightning/core/optimizer.py +++ b/pytorch_lightning/core/optimizer.py @@ -24,7 +24,7 @@ from pytorch_lightning.utilities import AMPType, rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden -from pytorch_lightning.utilities.types import LRSchedulerTypeTuple +from pytorch_lightning.utilities.types import _SupportsStateDict, LRSchedulerTypeTuple def do_nothing_closure() -> None: @@ -334,8 +334,6 @@ def _configure_schedulers_manual_opt(schedulers: list, monitor: Optional[str]) - def _validate_scheduler_api(lr_schedulers: List[Dict[str, Any]], model: "pl.LightningModule") -> None: - from pytorch_lightning.trainer.connectors.checkpoint_connector import _SupportsStateDict - for scheduler_config in lr_schedulers: scheduler = scheduler_config["scheduler"] if not isinstance(scheduler, _SupportsStateDict): diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index d3c5801371797..918615019ea9e 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ 
b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -18,7 +18,6 @@ import torch from torchmetrics import Metric -from typing_extensions import Protocol, runtime_checkable import pytorch_lightning as pl from pytorch_lightning.loops.utilities import _is_max_limit_reached @@ -476,14 +475,3 @@ def hpc_save_path(folderpath: _PATH) -> str: ckpt_number = (max_suffix if max_suffix is not None else 0) + 1 filepath = os.path.join(folderpath, f"hpc_ckpt_{ckpt_number}.ckpt") return filepath - - -@runtime_checkable -class _SupportsStateDict(Protocol): - """This class is used to detect if an object is stateful using `isinstance(obj, _SupportsStateDict)`.""" - - def state_dict(self) -> Dict[str, Any]: - ... - - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: - ... diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index c8cb854565f6b..ec630f795d8cc 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -42,6 +42,7 @@ from pytorch_lightning.utilities.distributed import _collect_states_on_rank_zero from pytorch_lightning.utilities.enums import _FaultTolerantMode, AutoRestartBatchKeys from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.types import _SupportsStateDict class FastForwardSampler(Sampler): @@ -575,9 +576,6 @@ def _reload_dataloader_state_dict_automatic(dataloader: DataLoader, state_dict: def _reload_dataloader_state_dict_manual(dataloader: DataLoader, state_dict: Dict[str, Any]) -> None: # In manual mode, we don't wrap the user objects with `CaptureMapDataset` or `CaptureIterableDataset` # therefore, we need to reload the states manually. - - from pytorch_lightning.trainer.connectors.checkpoint_connector import _SupportsStateDict - latest_worker_id = state_dict["latest_worker_id"] num_workers = state_dict["state"][latest_worker_id]["num_workers"] sampler_state = state_dict["state"][latest_worker_id].get("sampler_state", None) @@ -646,8 +644,6 @@ def __accumulate_state(self, sampler_state: Dict[str, Any]) -> None: def _store_sampler_state(self) -> None: """This function is used to extract the sampler states if any.""" - from pytorch_lightning.trainer.connectors.checkpoint_connector import _SupportsStateDict - sampler_state = { k: v.state_dict() for k, v in self._loader.__dict__.items() diff --git a/pytorch_lightning/utilities/types.py b/pytorch_lightning/utilities/types.py index 4ee037259af9a..c70a82fd377de 100644 --- a/pytorch_lightning/utilities/types.py +++ b/pytorch_lightning/utilities/types.py @@ -23,7 +23,7 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader from torchmetrics import Metric -from typing_extensions import TypedDict +from typing_extensions import Protocol, runtime_checkable, TypedDict _NUMBER = Union[int, float] _METRIC = Union[Metric, torch.Tensor, _NUMBER] @@ -106,3 +106,14 @@ class LRSchedulerConfig(TypedDict): monitor: Optional[str] strict: bool opt_idx: Optional[int] + + +@runtime_checkable +class _SupportsStateDict(Protocol): + """This class is used to detect if an object is stateful using `isinstance(obj, _SupportsStateDict)`.""" + + def state_dict(self) -> Dict[str, Any]: + ... + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + ... 
diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index c02f36d851ffd..e467436238f31 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -38,7 +38,6 @@ import tests.helpers.utils as tutils from pytorch_lightning import Callback, LightningModule, seed_everything, Trainer -from pytorch_lightning.trainer.connectors.checkpoint_connector import _SupportsStateDict from pytorch_lightning.trainer.states import RunningStage, TrainerState from pytorch_lightning.trainer.supporters import CombinedLoader from pytorch_lightning.utilities.auto_restart import ( @@ -60,6 +59,7 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.fetching import DataFetcher from pytorch_lightning.utilities.imports import _fault_tolerant_training +from pytorch_lightning.utilities.types import _SupportsStateDict from tests.helpers.boring_model import BoringModel, RandomDataset from tests.helpers.runif import RunIf From f1553ee4142b2291a7761aab006c94eab68780e0 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sat, 8 Jan 2022 19:41:34 +0100 Subject: [PATCH 28/35] Inherit from the state dict protocol --- pytorch_lightning/utilities/types.py | 42 ++++++++++------------------ 1 file changed, 14 insertions(+), 28 deletions(-) diff --git a/pytorch_lightning/utilities/types.py b/pytorch_lightning/utilities/types.py index c70a82fd377de..1d5cd272267d5 100644 --- a/pytorch_lightning/utilities/types.py +++ b/pytorch_lightning/utilities/types.py @@ -46,23 +46,29 @@ EVAL_DATALOADERS = Union[DataLoader, Sequence[DataLoader]] -# Inferred from `torch.optim.lr_scheduler.pyi` -class _LRScheduler: - optimizer: Optimizer +@runtime_checkable +class _SupportsStateDict(Protocol): + """This class is used to detect if an object is stateful using `isinstance(obj, _SupportsStateDict)`.""" - def __init__(self, optimizer: Optimizer, *args: Any, **kwargs: Any) -> None: + def state_dict(self) -> Dict[str, Any]: ... - def state_dict(self) -> dict: + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: ... - def load_state_dict(self, state_dict: dict) -> None: + +# Inferred from `torch.optim.lr_scheduler.pyi` +# Missing attributes were added to improve typing +class _LRScheduler(_SupportsStateDict): + optimizer: Optimizer + + def __init__(self, optimizer: Optimizer, *args: Any, **kwargs: Any) -> None: ... -# Copied from `torch.optim.lr_scheduler.pyi` +# Inferred from `torch.optim.lr_scheduler.pyi` # Missing attributes were added to improve typing -class ReduceLROnPlateau: +class ReduceLROnPlateau(_SupportsStateDict): in_cooldown: bool optimizer: Optimizer @@ -81,15 +87,6 @@ def __init__( ) -> None: ... - def step(self, metrics: Any, epoch: Optional[int] = ...) -> None: - ... - - def state_dict(self) -> dict: - ... - - def load_state_dict(self, state_dict: dict) -> None: - ... - # todo: improve LRSchedulerType naming/typing LRSchedulerTypeTuple = (torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau) @@ -106,14 +103,3 @@ class LRSchedulerConfig(TypedDict): monitor: Optional[str] strict: bool opt_idx: Optional[int] - - -@runtime_checkable -class _SupportsStateDict(Protocol): - """This class is used to detect if an object is stateful using `isinstance(obj, _SupportsStateDict)`.""" - - def state_dict(self) -> Dict[str, Any]: - ... - - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: - ... 
From 99c92a5500d00b11c7571565f87e660084f3e016 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sat, 8 Jan 2022 19:49:43 +0100 Subject: [PATCH 29/35] All positional, optimizer index always int --- docs/source/common/optimizers.rst | 2 +- pytorch_lightning/core/lightning.py | 4 ++-- pytorch_lightning/loops/epoch/training_epoch_loop.py | 4 ++-- tests/models/test_hooks.py | 8 +------- tests/trainer/optimization/test_optimizers.py | 2 +- 5 files changed, 7 insertions(+), 13 deletions(-) diff --git a/docs/source/common/optimizers.rst b/docs/source/common/optimizers.rst index a7df3dfc6c7f8..4ed54aed66410 100644 --- a/docs/source/common/optimizers.rst +++ b/docs/source/common/optimizers.rst @@ -537,7 +537,7 @@ If you are using native PyTorch schedulers, there is no need to override this ho return [optimizer], [{"scheduler": scheduler, "interval": "epoch"}] - def lr_scheduler_step(self, scheduler, optimizer_idx, metric=None): + def lr_scheduler_step(self, scheduler, optimizer_idx, metric): scheduler.step(epoch=self.current_epoch) # timm's scheduler need the epoch value diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index e3cf7c913f184..25472d5295cbe 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1496,8 +1496,8 @@ def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_va def lr_scheduler_step( self, scheduler: LRSchedulerTypeUnion, - optimizer_idx: Optional[int] = None, - metric: Any = None, + optimizer_idx: int, + metric: Optional[Any], ) -> None: r""" Override this method to adjust the default way the diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 6fd71da7357e8..caa64558f110d 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -457,8 +457,8 @@ def _update_learning_rates( self.trainer._call_lightning_module_hook( "lr_scheduler_step", lr_scheduler["scheduler"], - optimizer_idx=lr_scheduler["opt_idx"], - metric=monitor_val, + lr_scheduler["opt_idx"] or 0, + monitor_val, ) self.scheduler_progress.increment_completed() diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index f17069c88aae5..5f20d7bb4115a 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -327,13 +327,7 @@ def _auto_train_batch(trainer, model, batches, device=torch.device("cpu"), curre kwargs=dict(on_tpu=False, using_lbfgs=False, using_native_amp=using_native_amp), ), *( - [ - dict( - name="lr_scheduler_step", - args=(ANY,), - kwargs=dict(optimizer_idx=0, metric=None), - ) - ] + [dict(name="lr_scheduler_step", args=(ANY, 0, None))] if i == (trainer.num_training_batches - 1) else [] ), diff --git a/tests/trainer/optimization/test_optimizers.py b/tests/trainer/optimization/test_optimizers.py index 2cfbb06367e3d..00caca48e65d9 100644 --- a/tests/trainer/optimization/test_optimizers.py +++ b/tests/trainer/optimization/test_optimizers.py @@ -717,7 +717,7 @@ def training_step(self, batch, batch_idx, optimizer_idx): def training_epoch_end(self, *args, **kwargs): pass - def lr_scheduler_step(self, scheduler, optimizer_idx, metric=None): + def lr_scheduler_step(self, scheduler, optimizer_idx, metric): # step-level if optimizer_idx == 0: super().lr_scheduler_step(scheduler, optimizer_idx, metric) From ae8ae092ad6668e66da71226379d9c788477c45b Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sat, 8 Jan 2022 20:20:55 +0100 Subject: [PATCH 
30/35] Simplify tests --- pytorch_lightning/core/optimizer.py | 4 +- tests/trainer/optimization/test_optimizers.py | 40 +++++++++++-------- 2 files changed, 25 insertions(+), 19 deletions(-) diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py index a9db26acaffd7..c2716fdbdea72 100644 --- a/pytorch_lightning/core/optimizer.py +++ b/pytorch_lightning/core/optimizer.py @@ -344,8 +344,8 @@ def _validate_scheduler_api(lr_schedulers: List[Dict[str, Any]], model: "pl.Ligh if not isinstance(scheduler, LRSchedulerTypeTuple) and not is_overridden("lr_scheduler_step", model): raise MisconfigurationException( - f"The provided lr scheduler `{scheduler.__class__.__name__}` doesn't follow the PyTorch LR Scheduler" - " Protocol. You should override the `LightningModule.lr_scheduler_step` hook with your own logic if" + f"The provided lr scheduler `{scheduler.__class__.__name__}` doesn't follow PyTorch's LRScheduler" + " API. You should override the `LightningModule.lr_scheduler_step` hook with your own logic if" " you are using a custom LR scheduler." ) diff --git a/tests/trainer/optimization/test_optimizers.py b/tests/trainer/optimization/test_optimizers.py index 00caca48e65d9..6e5c2c0b50cca 100644 --- a/tests/trainer/optimization/test_optimizers.py +++ b/tests/trainer/optimization/test_optimizers.py @@ -752,16 +752,15 @@ def configure_optimizers(self): assert param_group["lr"] == lr / math.factorial(max_epochs) -def test_invalid_basic_lr_scheduler(tmpdir): - """Test that custom lr scheduler raises an error if it doesn't follow basic protocol API.""" +def test_invalid_scheduler_missing_state_dict(): + """Test that custom lr scheduler raises an error if it's missing the state dict.""" class CustomScheduler: def __init__(self, optimizer): self.optimizer = optimizer - def step(self, epoch): - for param_group in self.optimizer.param_groups: - param_group["lr"] = param_group["lr"] / (epoch + 1) + def step(self): + ... class CustomBoringModel(BoringModel): def configure_optimizers(self): @@ -770,13 +769,13 @@ def configure_optimizers(self): return {"optimizer": opt, "lr_scheduler": lr_scheduler} model = CustomBoringModel() - trainer = Trainer(default_root_dir=tmpdir) - model.trainer = trainer - with pytest.raises(ValueError, match="provided lr scheduler .* is invalid"): + model.trainer = Trainer() + with pytest.raises(ValueError, match="provided lr scheduler `CustomScheduler` is invalid"): _init_optimizers_and_lr_schedulers(model) -def test_invalid_pt_lr_scheduler(tmpdir): +@pytest.mark.parametrize("override", (False, True)) +def test_invalid_lr_scheduler_with_custom_step_method(override): """Test that custom lr scheduler raises an error if it doesn't follow PyTorch LR Scheduler protocol API and `lr_scheduler_step` is also not overridden.""" @@ -784,15 +783,14 @@ class CustomScheduler: def __init__(self, optimizer): self.optimizer = optimizer - def step(self, epoch): - for param_group in self.optimizer.param_groups: - param_group["lr"] = param_group["lr"] / (epoch + 1) + def step(self, foobar): # breaks the API, forces user to override `lr_scheduler_step` + ... def state_dict(self): - return {key: value for key, value in self.__dict__.items() if key != "optimizer"} + ... def load_state_dict(self, state_dict): - self.__dict__.update(state_dict) + ... 
class CustomBoringModel(BoringModel): def configure_optimizers(self): @@ -801,7 +799,15 @@ def configure_optimizers(self): return {"optimizer": opt, "lr_scheduler": lr_scheduler} model = CustomBoringModel() - trainer = Trainer(default_root_dir=tmpdir) - model.trainer = trainer - with pytest.raises(MisconfigurationException, match="doesn't follow the PyTorch LR Scheduler Protocol"): + model.trainer = Trainer() + if override: + + def lr_scheduler_step(*args): + ... + + # the user did override the hook, no error + model.lr_scheduler_step = lr_scheduler_step _init_optimizers_and_lr_schedulers(model) + else: + with pytest.raises(MisconfigurationException, match="CustomScheduler` doesn't follow"): + _init_optimizers_and_lr_schedulers(model) From 236b55d71dcfb10c05e1233dfbd3037a283dd5ef Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sat, 8 Jan 2022 20:25:00 +0100 Subject: [PATCH 31/35] Minor test changes --- tests/trainer/optimization/test_optimizers.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/trainer/optimization/test_optimizers.py b/tests/trainer/optimization/test_optimizers.py index 6e5c2c0b50cca..a9c54d691478f 100644 --- a/tests/trainer/optimization/test_optimizers.py +++ b/tests/trainer/optimization/test_optimizers.py @@ -776,8 +776,7 @@ def configure_optimizers(self): @pytest.mark.parametrize("override", (False, True)) def test_invalid_lr_scheduler_with_custom_step_method(override): - """Test that custom lr scheduler raises an error if it doesn't follow PyTorch LR Scheduler protocol API and - `lr_scheduler_step` is also not overridden.""" + """Test that custom lr scheduler raises an error if it doesn't follow PyTorch LR Scheduler API.""" class CustomScheduler: def __init__(self, optimizer): @@ -802,7 +801,7 @@ def configure_optimizers(self): model.trainer = Trainer() if override: - def lr_scheduler_step(*args): + def lr_scheduler_step(*_): ... # the user did override the hook, no error From 7e82d1d7e3c0339af27cfb0a61b8b47414311cd7 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sun, 9 Jan 2022 01:37:36 +0530 Subject: [PATCH 32/35] simplify test --- .../loops/epoch/training_epoch_loop.py | 2 +- tests/trainer/optimization/test_optimizers.py | 37 ++++++++++--------- 2 files changed, 20 insertions(+), 19 deletions(-) diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index caa64558f110d..69432ee07dd0b 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -457,7 +457,7 @@ def _update_learning_rates( self.trainer._call_lightning_module_hook( "lr_scheduler_step", lr_scheduler["scheduler"], - lr_scheduler["opt_idx"] or 0, + lr_scheduler["opt_idx"], monitor_val, ) self.scheduler_progress.increment_completed() diff --git a/tests/trainer/optimization/test_optimizers.py b/tests/trainer/optimization/test_optimizers.py index a9c54d691478f..a2a981204c0cc 100644 --- a/tests/trainer/optimization/test_optimizers.py +++ b/tests/trainer/optimization/test_optimizers.py @@ -11,8 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import math from unittest import mock +from unittest.mock import patch import pytest import torch @@ -690,19 +690,17 @@ def __init__(self, optimizer): self.optimizer = optimizer def step(self, epoch): - for param_group in self.optimizer.param_groups: - param_group["lr"] = param_group["lr"] / (epoch + 1) + ... def state_dict(self): - return {key: value for key, value in self.__dict__.items() if key != "optimizer"} + ... def load_state_dict(self, state_dict): - self.__dict__.update(state_dict) + ... class CustomBoringModel(BoringModel): - def __init__(self, learning_rate): + def __init__(self): super().__init__() - self.learning_rate = learning_rate self.layer1 = torch.nn.Linear(32, 2) self.layer2 = torch.nn.Linear(32, 2) @@ -714,9 +712,6 @@ def training_step(self, batch, batch_idx, optimizer_idx): return self.loss(batch, output) - def training_epoch_end(self, *args, **kwargs): - pass - def lr_scheduler_step(self, scheduler, optimizer_idx, metric): # step-level if optimizer_idx == 0: @@ -726,30 +721,36 @@ def lr_scheduler_step(self, scheduler, optimizer_idx, metric): scheduler.step(epoch=self.current_epoch) def configure_optimizers(self): - opt1 = torch.optim.SGD(self.layer1.parameters(), lr=self.learning_rate) + opt1 = torch.optim.SGD(self.layer1.parameters(), lr=1e-2) lr_scheduler1 = {"scheduler": torch.optim.lr_scheduler.StepLR(opt1, step_size=1), "interval": "step"} - opt2 = torch.optim.SGD(self.layer2.parameters(), lr=self.learning_rate) + opt2 = torch.optim.SGD(self.layer2.parameters(), lr=1e-2) lr_scheduler2 = CustomEpochScheduler(opt2) return {"optimizer": opt1, "lr_scheduler": lr_scheduler1}, { "optimizer": opt2, "lr_scheduler": lr_scheduler2, } - lr = 1e-2 max_epochs = 3 - model = CustomBoringModel(learning_rate=lr) + model = CustomBoringModel() + model.training_epoch_end = None + limit_train_batches = 2 trainer = Trainer( default_root_dir=tmpdir, enable_checkpointing=False, logger=False, max_epochs=max_epochs, - limit_train_batches=2, + limit_train_batches=limit_train_batches, limit_val_batches=0, ) - trainer.fit(model) - for param_group in trainer.optimizers[1].param_groups: - assert param_group["lr"] == lr / math.factorial(max_epochs) + with patch.object(CustomEpochScheduler, "step") as mock_method_epoch, patch.object( + torch.optim.lr_scheduler.StepLR, "step" + ) as mock_method_step: + trainer.fit(model) + assert mock_method_epoch.call_count == max_epochs + assert ( + mock_method_step.call_count == max_epochs * limit_train_batches + 1 + ) # first step is called by PyTorch _LRScheduler def test_invalid_scheduler_missing_state_dict(): From 43532fdb23fff575f2216fedf1134dd3f8f87b44 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sun, 9 Jan 2022 01:40:11 +0530 Subject: [PATCH 33/35] one line --- tests/trainer/optimization/test_optimizers.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/tests/trainer/optimization/test_optimizers.py b/tests/trainer/optimization/test_optimizers.py index a2a981204c0cc..dda167ffdf342 100644 --- a/tests/trainer/optimization/test_optimizers.py +++ b/tests/trainer/optimization/test_optimizers.py @@ -705,12 +705,7 @@ def __init__(self): self.layer2 = torch.nn.Linear(32, 2) def training_step(self, batch, batch_idx, optimizer_idx): - if optimizer_idx == 0: - output = self.layer1(batch) - else: - output = self.layer2(batch) - - return self.loss(batch, output) + return (self.layer1 if optimizer_idx == 0 else self.layer2)(batch).sum() def lr_scheduler_step(self, scheduler, optimizer_idx, metric): # step-level From 
4ec0e5c32d1a86fcc3f5be2f166db7600764ba60 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Sat, 8 Jan 2022 21:32:09 +0100 Subject: [PATCH 34/35] Reduce further, test calls --- tests/trainer/optimization/test_optimizers.py | 25 ++++++++----------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/tests/trainer/optimization/test_optimizers.py b/tests/trainer/optimization/test_optimizers.py index dda167ffdf342..42327edfb8a70 100644 --- a/tests/trainer/optimization/test_optimizers.py +++ b/tests/trainer/optimization/test_optimizers.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from unittest import mock -from unittest.mock import patch +from unittest.mock import call, patch import pytest import torch @@ -699,13 +699,8 @@ def load_state_dict(self, state_dict): ... class CustomBoringModel(BoringModel): - def __init__(self): - super().__init__() - self.layer1 = torch.nn.Linear(32, 2) - self.layer2 = torch.nn.Linear(32, 2) - - def training_step(self, batch, batch_idx, optimizer_idx): - return (self.layer1 if optimizer_idx == 0 else self.layer2)(batch).sum() + def training_step(self, batch, batch_idx, optimizer_idx=0): + return super().training_step(batch, batch_idx) def lr_scheduler_step(self, scheduler, optimizer_idx, metric): # step-level @@ -716,18 +711,18 @@ def lr_scheduler_step(self, scheduler, optimizer_idx, metric): scheduler.step(epoch=self.current_epoch) def configure_optimizers(self): - opt1 = torch.optim.SGD(self.layer1.parameters(), lr=1e-2) + opt1 = torch.optim.SGD(self.layer.parameters(), lr=1e-2) lr_scheduler1 = {"scheduler": torch.optim.lr_scheduler.StepLR(opt1, step_size=1), "interval": "step"} - opt2 = torch.optim.SGD(self.layer2.parameters(), lr=1e-2) + opt2 = torch.optim.SGD(self.layer.parameters(), lr=1e-2) lr_scheduler2 = CustomEpochScheduler(opt2) return {"optimizer": opt1, "lr_scheduler": lr_scheduler1}, { "optimizer": opt2, "lr_scheduler": lr_scheduler2, } - max_epochs = 3 model = CustomBoringModel() model.training_epoch_end = None + max_epochs = 3 limit_train_batches = 2 trainer = Trainer( default_root_dir=tmpdir, @@ -742,10 +737,10 @@ def configure_optimizers(self): torch.optim.lr_scheduler.StepLR, "step" ) as mock_method_step: trainer.fit(model) - assert mock_method_epoch.call_count == max_epochs - assert ( - mock_method_step.call_count == max_epochs * limit_train_batches + 1 - ) # first step is called by PyTorch _LRScheduler + + assert mock_method_epoch.mock_calls == [call(epoch=e) for e in range(max_epochs)] + # first step is called by PyTorch _LRScheduler + assert mock_method_step.call_count == max_epochs * limit_train_batches + 1 def test_invalid_scheduler_missing_state_dict(): From b55504fa6a81d84755948277638a8b885949b7b7 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Mon, 10 Jan 2022 20:01:31 +0530 Subject: [PATCH 35/35] use typeerror --- pytorch_lightning/core/optimizer.py | 2 +- tests/trainer/optimization/test_optimizers.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py index c2716fdbdea72..3b0cdffff497e 100644 --- a/pytorch_lightning/core/optimizer.py +++ b/pytorch_lightning/core/optimizer.py @@ -337,7 +337,7 @@ def _validate_scheduler_api(lr_schedulers: List[Dict[str, Any]], model: "pl.Ligh for scheduler_config in lr_schedulers: scheduler = scheduler_config["scheduler"] if not isinstance(scheduler, _SupportsStateDict): - raise ValueError( + raise TypeError( f"The provided lr 
scheduler `{scheduler.__class__.__name__}` is invalid." " It should have `state_dict` and `load_state_dict` methods defined." ) diff --git a/tests/trainer/optimization/test_optimizers.py b/tests/trainer/optimization/test_optimizers.py index 42327edfb8a70..e960eabcb9b62 100644 --- a/tests/trainer/optimization/test_optimizers.py +++ b/tests/trainer/optimization/test_optimizers.py @@ -761,7 +761,7 @@ def configure_optimizers(self): model = CustomBoringModel() model.trainer = Trainer() - with pytest.raises(ValueError, match="provided lr scheduler `CustomScheduler` is invalid"): + with pytest.raises(TypeError, match="provided lr scheduler `CustomScheduler` is invalid"): _init_optimizers_and_lr_schedulers(model)