
Add Loop.replace #10324

Merged
merged 12 commits on Dec 16, 2021
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -31,7 +31,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Show a better error message when a custom `DataLoader` implementation is not well implemented and we need to reconstruct it ([#10719](https://github.com/PyTorchLightning/pytorch-lightning/issues/10719))


- Added `Loop.replace` to easily switch one loop for another ([#10324](https://github.com/PyTorchLightning/pytorch-lightning/issues/10324))


### Changed

20 changes: 14 additions & 6 deletions docs/source/extensions/loops.rst
@@ -267,17 +267,25 @@ run (optional)
Subloops
--------

When you want to customize nested loops within loops, use the :meth:`~pytorch_lightning.loops.base.Loop.replace` method:

.. code-block:: python

    # This takes care of properly instantiating the new Loop and setting all references
    trainer.fit_loop.replace(epoch_loop=MyEpochLoop)
    # Trainer runs the fit loop with your new epoch loop!
    trainer.fit(model)

Alternatively, for more fine-grained control, use the :meth:`~pytorch_lightning.loops.base.Loop.connect` method:

.. code-block:: python

    # Optional: stitch back the trainer arguments
    epoch_loop = MyEpochLoop(trainer.fit_loop.epoch_loop.min_steps, trainer.fit_loop.epoch_loop.max_steps)
    # Optional: connect children loops as they might have existing state
    epoch_loop.connect(trainer.fit_loop.epoch_loop.batch_loop, trainer.fit_loop.epoch_loop.val_loop)
    # Connect the new loop to its parent
    trainer.fit_loop.connect(epoch_loop=epoch_loop)
    trainer.fit(model)
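
For context, ``MyEpochLoop`` in the snippets above is assumed to be a user-defined subclass of the built-in epoch loop; a minimal sketch (the overridden hook here is purely illustrative) could look like:

.. code-block:: python

    from pytorch_lightning.loops import TrainingEpochLoop


    class MyEpochLoop(TrainingEpochLoop):
        """A custom epoch loop that keeps the built-in ``__init__`` signature."""

        def advance(self, *args, **kwargs):
            # add custom per-step logic here, then defer to the built-in behavior
            return super().advance(*args, **kwargs)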

More about the built-in loops and how they are composed is explained in the next section.
49 changes: 47 additions & 2 deletions pytorch_lightning/loops/base.py
@@ -11,9 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from abc import ABC, abstractmethod
from typing import Any, Dict, Generic, Optional, Type, TypeVar, Union

from deprecate import void
from torchmetrics import Metric
@@ -99,6 +99,51 @@ def connect(self, **kwargs: "Loop") -> None:
    Linked loops should form a tree.
    """

def replace(self, **loops: Union["Loop", Type["Loop"]]) -> None:
    """Optionally replace one or multiple of this loop's sub-loops.

    This method takes care of instantiating the class (if necessary) with all existing arguments, connecting all
    sub-loops of the old loop to the new instance, setting the ``Trainer`` reference, and connecting the new loop
    to the parent.

    Args:
        **loops: A ``Loop`` subclass or instance. The name used should match the loop attribute name you want
            to replace.

    Raises:
        MisconfigurationException: When passing a ``Loop`` class, if the ``__init__`` arguments do not match
            those of the Loop class it replaces.
    """
    new_loops = {}

    for name, type_or_object in loops.items():
        old_loop = getattr(self, name)

        if isinstance(type_or_object, type):
            # compare the signatures
            old_parameters = inspect.signature(old_loop.__class__.__init__).parameters
            current_parameters = inspect.signature(type_or_object.__init__).parameters
            if old_parameters != current_parameters:
                raise MisconfigurationException(
                    f"`{self.__class__.__name__}.replace({type_or_object.__name__})` can only be used if the"
                    f" `__init__` signatures match but `{old_loop.__class__.__name__}` does not."
                )
            # instantiate the new loop with the old loop's constructor arguments
            kwargs = {p: getattr(old_loop, p) for p in old_parameters if p != "self"}
            loop = type_or_object(**kwargs)
        else:
            loop = type_or_object

        # connect sub-loops
        kwargs = {n: l for n, l in old_loop.__dict__.items() if isinstance(l, Loop)}
        loop.connect(**kwargs)
        # set the trainer reference
        loop.trainer = self.trainer

        new_loops[name] = loop
    # connect to self
    self.connect(**new_loops)

def on_skip(self) -> Optional[Any]:
    """The function to run when :meth:`run` should be skipped, determined by the condition in :attr:`skip`.

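To make the signature check in ``replace`` above concrete, here is a minimal sketch of the two call styles (``CustomEpochLoop`` is hypothetical; compare the tests further down):

.. code-block:: python

    from pytorch_lightning import Trainer
    from pytorch_lightning.loops import TrainingEpochLoop

    trainer = Trainer(min_steps=5, max_steps=10)


    class CustomEpochLoop(TrainingEpochLoop):
        ...  # inherits ``__init__(min_steps, max_steps)``, so the signatures match


    # Passing the class: `replace` re-instantiates it with the old loop's arguments
    trainer.fit_loop.replace(epoch_loop=CustomEpochLoop)
    assert trainer.fit_loop.epoch_loop.max_steps == 10

    # Passing an instance: it is connected as-is, with no signature check needed
    trainer.fit_loop.replace(epoch_loop=CustomEpochLoop(min_steps=5, max_steps=10))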
2 changes: 1 addition & 1 deletion tests/loops/test_evaluation_loop.py
@@ -127,6 +127,6 @@ def on_advance_end(self):
    assert not is_overridden("test_epoch_end", model)

    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=3)
    trainer.test_loop.replace(epoch_loop=TestLoop)
    trainer.test(model)
    assert did_assert
46 changes: 45 additions & 1 deletion tests/loops/test_loops.py
@@ -24,8 +24,9 @@

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import Callback, ModelCheckpoint
from pytorch_lightning.loops import EvaluationLoop, Loop, TrainingBatchLoop, TrainingEpochLoop
from pytorch_lightning.trainer.progress import BaseProgress
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel, RandomDataset
from tests.helpers.runif import RunIf

@@ -102,6 +103,49 @@ def test_connect_subloops(tmpdir):
    assert new_batch_loop.trainer is trainer


def test_replace_loops():
    class TestLoop(TrainingEpochLoop):
        def __init__(self, foo):
            super().__init__()

    trainer = Trainer(min_steps=123, max_steps=321)

    with pytest.raises(
        MisconfigurationException, match=r"FitLoop.replace\(TestLoop\)`.*`__init__`.*`TrainingEpochLoop`"
    ):
        trainer.fit_loop.replace(epoch_loop=TestLoop)

    class TestLoop(TrainingEpochLoop):
        ...

    # test passing a loop where previous state should be connected
    old_loop = trainer.fit_loop.epoch_loop
    trainer.fit_loop.replace(epoch_loop=TestLoop)
    new_loop = trainer.fit_loop.epoch_loop

    assert isinstance(new_loop, TestLoop)
    assert trainer.fit_loop.epoch_loop is new_loop
    assert new_loop.min_steps == 123
    assert new_loop.max_steps == 321
    assert new_loop.batch_loop is old_loop.batch_loop
    assert new_loop.val_loop is old_loop.val_loop
    assert new_loop.trainer is trainer

    class MyBatchLoop(TrainingBatchLoop):
        ...

    class MyEvalLoop(EvaluationLoop):
        ...

    # test passing more than one where one is an instance and the other a class
    trainer.fit_loop.epoch_loop.replace(batch_loop=MyBatchLoop, val_loop=MyEvalLoop())
    new_batch_loop = trainer.fit_loop.epoch_loop.batch_loop
    new_val_loop = trainer.fit_loop.epoch_loop.val_loop

    assert isinstance(new_batch_loop, MyBatchLoop)
    assert isinstance(new_val_loop, MyEvalLoop)

class CustomException(Exception):
    pass
