Add Loop.replace
carmocca committed Nov 2, 2021
1 parent 4cd7e77 commit 010e846
Showing 8 changed files with 87 additions and 29 deletions.
20 changes: 14 additions & 6 deletions docs/source/extensions/loops.rst
@@ -267,17 +267,25 @@ run (optional)
Subloops
--------

-When you want to customize nested loops within loops, use the :meth:`~pytorch_lightning.loops.base.Loop.connect` method:
+When you want to customize nested loops within loops, use the :meth:`~pytorch_lightning.loops.base.Loop.replace` method:

 .. code-block:: python

-    # Step 1: create your loop
-    my_epoch_loop = MyEpochLoop()
+    # This takes care of properly instantiating the new Loop and setting all references
+    trainer.fit_loop.replace(MyEpochLoop)
+
+    # Trainer runs the fit loop with your new epoch loop!
+    trainer.fit(model)

-    # Step 2: use connect()
-    trainer.fit_loop.connect(epoch_loop=my_epoch_loop)
-    # Trainer runs the fit loop with your new epoch loop!
+Alternatively, for more fine-grained control, use the :meth:`~pytorch_lightning.loops.base.Loop.connect` method:
+
+.. code-block:: python
+
+    # Optional: stitch back the trainer arguments
+    epoch_loop = MyEpochLoop(trainer.fit_loop.epoch_loop.min_steps, trainer.fit_loop.epoch_loop.max_steps)
+    # Optional: connect children loops as they might have existing state
+    epoch_loop.connect(trainer.fit_loop.epoch_loop.batch_loop, trainer.fit_loop.epoch_loop.val_loop)
+    # Instantiate and connect the loop.
+    trainer.fit_loop.connect(epoch_loop=epoch_loop)
+    trainer.fit(model)
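
For illustration, a minimal sketch of how such a custom loop might be defined (``MyEpochLoop`` is a hypothetical subclass that keeps the parent ``__init__`` signature, which ``replace`` requires):

.. code-block:: python

    from pytorch_lightning import Trainer
    from pytorch_lightning.loops import TrainingEpochLoop


    class MyEpochLoop(TrainingEpochLoop):
        def advance(self, *args, **kwargs):
            # custom behaviour before each iteration, then defer to the built-in logic
            return super().advance(*args, **kwargs)


    trainer = Trainer()
    trainer.fit_loop.replace(MyEpochLoop)  # re-creates epoch_loop and reconnects its subloops
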
More about the built-in loops and how they are composed is explained in the next section.
33 changes: 31 additions & 2 deletions pytorch_lightning/loops/base.py
@@ -11,9 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from abc import ABC, abstractmethod
from typing import Any, Dict, Generic, Optional, TypeVar
from typing import Any, Dict, Generic, Optional, Type, TypeVar

from deprecate import void
from torchmetrics import Metric
@@ -99,6 +99,35 @@ def connect(self, **kwargs: "Loop") -> None:
Linked loops should form a tree.
"""

def replace(self, loop_cls: Type["Loop"]) -> "Loop":
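    """Replace a nested loop with a new instance of ``loop_cls``.

    The target is the attribute whose current loop class is a parent class of ``loop_cls``. It is
    re-created with the same ``__init__`` arguments, reconnected to the existing child loops and the
    trainer, and returned. The old and new loop classes must have matching ``__init__`` signatures.
    """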
# find the target
for name, old_loop in self.__dict__.items():
if issubclass(loop_cls, type(old_loop)):
break
else:
raise MisconfigurationException(
f"Did not find an attribute with the same parent class as `{loop_cls.__name__}`"
)
# compare the signatures
old_parameters = inspect.signature(old_loop.__class__.__init__).parameters
current_parameters = inspect.signature(loop_cls.__init__).parameters
if old_parameters != current_parameters:
raise MisconfigurationException(
f"`{self.__class__.__name__}.replace({loop_cls.__name__})` can only be used if the `__init__`"
f" signatures match but `{old_loop.__class__.__name__}` does not."
)
# instantiate the loop
kwargs = {p: getattr(old_loop, p) for p in old_parameters if p != "self"}
loop = loop_cls(**kwargs)
# connect subloops
kwargs = {n: l for n, l in old_loop.__dict__.items() if isinstance(l, Loop)}
loop.connect(**kwargs)
# set the trainer reference
loop.trainer = self.trainer
# connect to self
self.connect(**{name: loop})
return loop

def on_skip(self) -> Optional[Any]:
"""The function to run when :meth:`run` should be skipped, determined by the condition in :attr:`skip`.
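
To make the signature check above concrete, a small sketch with hypothetical subclasses (mirroring the test added below in ``tests/loops/test_loops.py``): only a replacement whose ``__init__`` signature matches the loop it replaces is accepted.

.. code-block:: python

    from pytorch_lightning import Trainer
    from pytorch_lightning.loops import TrainingEpochLoop


    class CompatibleLoop(TrainingEpochLoop):
        ...  # inherits the parent __init__, so the signatures match


    class IncompatibleLoop(TrainingEpochLoop):
        def __init__(self, extra_arg):  # different signature
            super().__init__()


    trainer = Trainer()
    trainer.fit_loop.replace(CompatibleLoop)  # ok: epoch_loop is re-created as CompatibleLoop
    trainer.fit_loop.replace(IncompatibleLoop)  # raises MisconfigurationException
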
2 changes: 1 addition & 1 deletion pytorch_lightning/loops/batch/training_batch_loop.py
@@ -48,7 +48,7 @@ def done(self) -> bool:
return len(self._remaining_splits) == 0

def connect(
self, optimizer_loop: Optional["Loop"] = None, manual_loop: Optional[ManualOptimization] = None
self, optimizer_loop: Optional[OptimizerLoop] = None, manual_loop: Optional[ManualOptimization] = None
) -> None:
if optimizer_loop is not None:
self.optimizer_loop = optimizer_loop
7 changes: 2 additions & 5 deletions pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
@@ -45,16 +45,13 @@ def __init__(self) -> None:
self._num_dataloaders: Optional[int] = None
self._dataloader_iter: Optional[Iterator] = None
self._data_fetcher: Optional[DataFetcher] = None
self._dataloader_state_dict: Dict[str, Any] = None
self._dataloader_state_dict: Dict[str, Any] = {}

@property
def done(self) -> bool:
"""Returns ``True`` if the current iteration count reaches the number of dataloader batches."""
return self.batch_progress.current.completed >= self._dl_max_batches

def connect(self, **kwargs: "Loop") -> None:
raise NotImplementedError(f"{self.__class__.__name__} does not connect any child loops.")

def reset(self) -> None:
"""Resets the loop's internal state."""
self._dl_max_batches = None
@@ -181,7 +178,7 @@ def on_load_checkpoint(self, state_dict: Dict) -> None:
def _reload_dataloader_state_dict(self, data_fetcher: AbstractDataFetcher):
if not self.trainer.sanity_checking and self._dataloader_state_dict:
reload_dataloader_state_dict(data_fetcher.dataloader, self._dataloader_state_dict)
self._dataloader_state_dict = None
self._dataloader_state_dict = {}

def _num_completed_batches_reached(self) -> bool:
epoch_finished_on_completed = self.batch_progress.current.completed == self._dl_max_batches
8 changes: 3 additions & 5 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -61,8 +61,8 @@ def __init__(self, min_steps: Optional[int] = 0, max_steps: int = -1) -> None:
self.batch_progress = BatchProgress()
self.scheduler_progress = SchedulerProgress()

self.batch_loop: Optional[TrainingBatchLoop] = None
self.val_loop: Optional["loops.EvaluationLoop"] = None
self.batch_loop = TrainingBatchLoop()
self.val_loop = loops.EvaluationLoop()

self._results = ResultCollection(training=True)
self._outputs: _OUTPUTS_TYPE = []
@@ -106,7 +106,7 @@ def done(self) -> bool:

def connect(
self,
batch_loop: TrainingBatchLoop = None,
batch_loop: Optional[TrainingBatchLoop] = None,
val_loop: Optional["loops.EvaluationLoop"] = None,
) -> None:
"""Optionally connect a custom batch or validation loop to this training epoch loop."""
@@ -117,8 +117,6 @@ def connect(

def reset(self) -> None:
"""Resets the internal state of the loop for a new run."""
assert self.batch_loop is not None
assert self.batch_loop.optimizer_loop is not None
if self.restarting:
self.batch_progress.reset_on_restart()
self.scheduler_progress.reset_on_restart()
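
Because the subloops are now created eagerly in ``__init__``, ``connect`` becomes purely optional. A rough sketch of the resulting usage, assuming a hypothetical ``MyBatchLoop``:

.. code-block:: python

    from pytorch_lightning import Trainer
    from pytorch_lightning.loops import TrainingBatchLoop


    class MyBatchLoop(TrainingBatchLoop):
        ...


    trainer = Trainer()
    # the default subloops exist immediately, no None checks required
    assert isinstance(trainer.fit_loop.epoch_loop.batch_loop, TrainingBatchLoop)
    # connect only the loop being customized; val_loop keeps its default
    trainer.fit_loop.epoch_loop.connect(batch_loop=MyBatchLoop())
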
6 changes: 1 addition & 5 deletions pytorch_lightning/loops/fit_loop.py
@@ -48,7 +48,7 @@ def __init__(

self.max_epochs = max_epochs
self.min_epochs = min_epochs
self.epoch_loop: Optional[TrainingEpochLoop] = None
self.epoch_loop = TrainingEpochLoop()
self.epoch_progress = Progress()
self._is_fresh_start_epoch: bool = True

@@ -128,15 +128,11 @@ def running_loss(self) -> TensorRunningAccum:
@property
def _skip_backward(self) -> bool:
"""Determines whether the loop will skip backward during automatic optimization."""
assert self.epoch_loop.batch_loop is not None
assert self.epoch_loop.batch_loop.optimizer_loop is not None
return self.epoch_loop.batch_loop.optimizer_loop._skip_backward

@_skip_backward.setter
def _skip_backward(self, value: bool) -> None:
"""Determines whether the loop will skip backward during automatic optimization."""
assert self.epoch_loop.batch_loop is not None
assert self.epoch_loop.batch_loop.optimizer_loop is not None
self.epoch_loop.batch_loop.optimizer_loop._skip_backward = value

@property
2 changes: 1 addition & 1 deletion tests/loops/test_evaluation_loop.py
@@ -127,6 +127,6 @@ def on_advance_end(self):
assert not is_overridden("test_epoch_end", model)

trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=3)
trainer.test_loop.connect(TestLoop())
trainer.test_loop.replace(TestLoop)
trainer.test(model)
assert did_assert
38 changes: 34 additions & 4 deletions tests/loops/test_loops.py
@@ -22,12 +22,12 @@
import torch
from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoader

from pl_examples.bug_report_model import RandomDataset
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loops import Loop, TrainingBatchLoop
from pytorch_lightning.loops import Loop, PredictionEpochLoop, TrainingBatchLoop, TrainingEpochLoop
from pytorch_lightning.trainer.progress import BaseProgress
from tests.helpers import BoringModel
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel, RandomDataset
from tests.helpers.runif import RunIf


@@ -62,7 +62,7 @@ def test_connect_loops_direct(loop_name):

trainer = Trainer()

# trainer.loop = loop
# trainer.loop_name = loop
setattr(trainer, loop_name, loop)
assert loop.trainer is trainer

@@ -103,6 +103,36 @@ def test_connect_subloops(tmpdir):
assert new_batch_loop.trainer is trainer


def test_replace_loops():
class TestLoop(TrainingEpochLoop):
def __init__(self, foo):
super().__init__()

trainer = Trainer(min_steps=123, max_steps=321)

with pytest.raises(
MisconfigurationException, match=r"FitLoop.replace\(TestLoop\)`.*`__init__`.*`TrainingEpochLoop`"
):
trainer.fit_loop.replace(TestLoop)

with pytest.raises(MisconfigurationException, match="Did not find.*same parent class as `PredictionEpochLoop`"):
trainer.fit_loop.replace(PredictionEpochLoop)

class TestLoop(TrainingEpochLoop):
...

old_loop = trainer.fit_loop.epoch_loop
new_loop = trainer.fit_loop.replace(TestLoop)

assert isinstance(new_loop, TestLoop)
assert trainer.fit_loop.epoch_loop is new_loop
assert new_loop.min_steps == 123
assert new_loop.max_steps == 321
assert new_loop.batch_loop is old_loop.batch_loop
assert new_loop.val_loop is old_loop.val_loop
assert new_loop.trainer is trainer


class CustomException(Exception):
pass

