diff --git a/CHANGELOG.md b/CHANGELOG.md index 0ad5733118758..ea62090e4447a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -34,6 +34,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Show a better error message when a custom `DataLoader` implementation is not well implemented and we need to reconstruct it ([#10719](https://github.com/PyTorchLightning/pytorch-lightning/issues/10719)) +- Added `Loop.replace` to easily switch one loop for another ([#10324](https://github.com/PyTorchLightning/pytorch-lightning/issues/10324)) + + - Added support for `--lr_scheduler=ReduceLROnPlateau` to the `LightningCLI` ([#10860](https://github.com/PyTorchLightning/pytorch-lightning/issues/10860)) diff --git a/docs/source/extensions/loops.rst b/docs/source/extensions/loops.rst index a4cf8abc1963e..d2e5d1467b454 100644 --- a/docs/source/extensions/loops.rst +++ b/docs/source/extensions/loops.rst @@ -266,17 +266,25 @@ run (optional) Subloops -------- -When you want to customize nested loops within loops, use the :meth:`~pytorch_lightning.loops.base.Loop.connect` method: +When you want to customize nested loops within loops, use the :meth:`~pytorch_lightning.loops.base.Loop.replace` method: .. code-block:: python - # Step 1: create your loop - my_epoch_loop = MyEpochLoop() + # This takes care of properly instantiating the new Loop and setting all references + trainer.fit_loop.replace(epoch_loop=MyEpochLoop) + # Trainer runs the fit loop with your new epoch loop! + trainer.fit(model) - # Step 2: use connect() - trainer.fit_loop.connect(epoch_loop=my_epoch_loop) +Alternatively, for more fine-grained control, use the :meth:`~pytorch_lightning.loops.base.Loop.connect` method: - # Trainer runs the fit loop with your new epoch loop! +.. code-block:: python + + # Optional: stitch back the trainer arguments + epoch_loop = MyEpochLoop(trainer.fit_loop.epoch_loop.min_steps, trainer.fit_loop.epoch_loop.max_steps) + # Optional: connect children loops as they might have existing state + epoch_loop.connect(trainer.fit_loop.epoch_loop.batch_loop, trainer.fit_loop.epoch_loop.val_loop) + # Instantiate and connect the loop. + trainer.fit_loop.connect(epoch_loop=epoch_loop) trainer.fit(model) More about the built-in loops and how they are composed is explained in the next section. diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index be691669a93f5..b47d1935e89be 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -11,9 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import inspect from abc import ABC, abstractmethod -from typing import Any, Dict, Generic, Optional, TypeVar +from typing import Any, Dict, Generic, Optional, Type, TypeVar, Union from deprecate import void from torchmetrics import Metric @@ -99,6 +99,51 @@ def connect(self, **kwargs: "Loop") -> None: Linked loops should form a tree. """ + def replace(self, **loops: Union["Loop", Type["Loop"]]) -> None: + """Optionally replace one or multiple of this loop's sub-loops. + + This methods takes care of instantiating the class (if necessary) with all existing arguments, connecting all + sub-loops of the old loop to the new instance, setting the ``Trainer`` reference, and connecting the new loop to + the parent. + + Args: + **loops: ``Loop`` subclasses or instances. The name used should match the loop attribute name you want to + replace. + + Raises: + MisconfigurationException: When passing a ``Loop`` class, if the ``__init__`` arguments do not match those + of the Loop class it replaces. + """ + new_loops = {} + + for name, type_or_object in loops.items(): + old_loop = getattr(self, name) + + if isinstance(type_or_object, type): + # compare the signatures + old_parameters = inspect.signature(old_loop.__class__.__init__).parameters + current_parameters = inspect.signature(type_or_object.__init__).parameters + if old_parameters != current_parameters: + raise MisconfigurationException( + f"`{self.__class__.__name__}.replace({type_or_object.__name__})` can only be used if the" + f" `__init__` signatures match but `{old_loop.__class__.__name__}` does not." + ) + # instantiate the loop + kwargs = {p: getattr(old_loop, p) for p in old_parameters if p != "self"} + loop = type_or_object(**kwargs) # type: ignore[call-arg] + else: + loop = type_or_object + + # connect sub-loops + kwargs = {n: l for n, l in old_loop.__dict__.items() if isinstance(l, Loop)} + loop.connect(**kwargs) + # set the trainer reference + loop.trainer = self.trainer + + new_loops[name] = loop + # connect to self + self.connect(**new_loops) + def on_skip(self) -> T: """The function to run when :meth:`run` should be skipped, determined by the condition in :attr:`skip`. diff --git a/tests/loops/test_evaluation_loop.py b/tests/loops/test_evaluation_loop.py index d553d386205f5..1d4b0ea4cd6b1 100644 --- a/tests/loops/test_evaluation_loop.py +++ b/tests/loops/test_evaluation_loop.py @@ -127,6 +127,6 @@ def on_advance_end(self): assert not is_overridden("test_epoch_end", model) trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=3) - trainer.test_loop.connect(TestLoop()) + trainer.test_loop.replace(epoch_loop=TestLoop) trainer.test(model) assert did_assert diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index b1f93d82ab616..08ef7153e6bc3 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -24,8 +24,9 @@ from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.callbacks import Callback, ModelCheckpoint -from pytorch_lightning.loops import Loop, TrainingBatchLoop +from pytorch_lightning.loops import EvaluationLoop, Loop, TrainingBatchLoop, TrainingEpochLoop from pytorch_lightning.trainer.progress import BaseProgress +from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers import BoringModel, RandomDataset from tests.helpers.runif import RunIf @@ -102,6 +103,49 @@ def test_connect_subloops(tmpdir): assert new_batch_loop.trainer is trainer +def test_replace_loops(): + class TestLoop(TrainingEpochLoop): + def __init__(self, foo): + super().__init__() + + trainer = Trainer(min_steps=123, max_steps=321) + + with pytest.raises( + MisconfigurationException, match=r"FitLoop.replace\(TestLoop\)`.*`__init__`.*`TrainingEpochLoop`" + ): + trainer.fit_loop.replace(epoch_loop=TestLoop) + + class TestLoop(TrainingEpochLoop): + ... + + # test passing a loop where previous state should be connected + old_loop = trainer.fit_loop.epoch_loop + trainer.fit_loop.replace(epoch_loop=TestLoop) + new_loop = trainer.fit_loop.epoch_loop + + assert isinstance(new_loop, TestLoop) + assert trainer.fit_loop.epoch_loop is new_loop + assert new_loop.min_steps == 123 + assert new_loop.max_steps == 321 + assert new_loop.batch_loop is old_loop.batch_loop + assert new_loop.val_loop is old_loop.val_loop + assert new_loop.trainer is trainer + + class MyBatchLoop(TrainingBatchLoop): + ... + + class MyEvalLoop(EvaluationLoop): + ... + + # test passing more than one where one is an instance and the other a class + trainer.fit_loop.epoch_loop.replace(batch_loop=MyBatchLoop, val_loop=MyEvalLoop()) + new_batch_loop = trainer.fit_loop.epoch_loop.batch_loop + new_val_loop = trainer.fit_loop.epoch_loop.val_loop + + assert isinstance(new_batch_loop, MyBatchLoop) + assert isinstance(new_val_loop, MyEvalLoop) + + class CustomException(Exception): pass