From 1973904d59f94e2ae9468a38f1b86651d40af0ff Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 3 Nov 2021 00:23:11 +0100 Subject: [PATCH 1/8] Add `Loop.replace` --- docs/source/extensions/loops.rst | 20 +++++++--- pytorch_lightning/loops/base.py | 33 +++++++++++++++- .../loops/batch/training_batch_loop.py | 2 +- .../loops/epoch/evaluation_epoch_loop.py | 7 +--- .../loops/epoch/training_epoch_loop.py | 8 ++-- pytorch_lightning/loops/fit_loop.py | 6 +-- tests/loops/test_evaluation_loop.py | 2 +- tests/loops/test_loops.py | 38 +++++++++++++++++-- 8 files changed, 88 insertions(+), 28 deletions(-) diff --git a/docs/source/extensions/loops.rst b/docs/source/extensions/loops.rst index 9291fca4819d2..267f637be63a1 100644 --- a/docs/source/extensions/loops.rst +++ b/docs/source/extensions/loops.rst @@ -267,17 +267,25 @@ run (optional) Subloops -------- -When you want to customize nested loops within loops, use the :meth:`~pytorch_lightning.loops.base.Loop.connect` method: +When you want to customize nested loops within loops, use the :meth:`~pytorch_lightning.loops.base.Loop.replace` method: .. code-block:: python - # Step 1: create your loop - my_epoch_loop = MyEpochLoop() + # This takes care of properly instantiating the new Loop and setting all references + trainer.fit_loop.replace(MyEpochLoop) + # Trainer runs the fit loop with your new epoch loop! + trainer.fit(model) - # Step 2: use connect() - trainer.fit_loop.connect(epoch_loop=my_epoch_loop) +Alternatively, for more fine-grained control, use the :meth:`~pytorch_lightning.loops.base.Loop.connect` method: - # Trainer runs the fit loop with your new epoch loop! +.. code-block:: python + + # Optional: stitch back the trainer arguments + epoch_loop = MyEpochLoop(trainer.fit_loop.epoch_loop.min_steps, trainer.fit_loop.epoch_loop.max_steps) + # Optional: connect children loops as they might have existing state + epoch_loop.connect(trainer.fit_loop.epoch_loop.batch_loop, trainer.fit_loop.epoch_loop.val_loop) + # Instantiate and connect the loop. + trainer.fit_loop.connect(epoch_loop=epoch_loop) trainer.fit(model) More about the built-in loops and how they are composed is explained in the next section. diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 38b0d652e5d2f..37d63fca885aa 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -11,9 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import inspect from abc import ABC, abstractmethod -from typing import Any, Dict, Generic, Optional, TypeVar +from typing import Any, Dict, Generic, Optional, Type, TypeVar from deprecate import void from torchmetrics import Metric @@ -99,6 +99,35 @@ def connect(self, **kwargs: "Loop") -> None: Linked loops should form a tree. 
""" + def replace(self, loop_cls: Type["Loop"]) -> "Loop": + # find the target + for name, old_loop in self.__dict__.items(): + if issubclass(loop_cls, type(old_loop)): + break + else: + raise MisconfigurationException( + f"Did not find an attribute with the same parent class as `{loop_cls.__name__}`" + ) + # compare the signatures + old_parameters = inspect.signature(old_loop.__class__.__init__).parameters + current_parameters = inspect.signature(loop_cls.__init__).parameters + if old_parameters != current_parameters: + raise MisconfigurationException( + f"`{self.__class__.__name__}.replace({loop_cls.__name__})` can only be used if the `__init__`" + f" signatures match but `{old_loop.__class__.__name__}` does not." + ) + # instantiate the loop + kwargs = {p: getattr(old_loop, p) for p in old_parameters if p != "self"} + loop = loop_cls(**kwargs) + # connect subloops + kwargs = {n: l for n, l in old_loop.__dict__.items() if isinstance(l, Loop)} + loop.connect(**kwargs) + # set the trainer reference + loop.trainer = self.trainer + # connect to self + self.connect(**{name: loop}) + return loop + def on_skip(self) -> Optional[Any]: """The function to run when :meth:`run` should be skipped, determined by the condition in :attr:`skip`. diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py index c1d800c42d853..7ed199e56be13 100644 --- a/pytorch_lightning/loops/batch/training_batch_loop.py +++ b/pytorch_lightning/loops/batch/training_batch_loop.py @@ -48,7 +48,7 @@ def done(self) -> bool: return len(self._remaining_splits) == 0 def connect( - self, optimizer_loop: Optional["Loop"] = None, manual_loop: Optional[ManualOptimization] = None + self, optimizer_loop: Optional[OptimizerLoop] = None, manual_loop: Optional[ManualOptimization] = None ) -> None: if optimizer_loop is not None: self.optimizer_loop = optimizer_loop diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py index 2fc572ea252e6..303685c002410 100644 --- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py @@ -44,16 +44,13 @@ def __init__(self) -> None: self._num_dataloaders: Optional[int] = None self._dataloader_iter: Optional[Iterator] = None self._data_fetcher: Optional[DataFetcher] = None - self._dataloader_state_dict: Dict[str, Any] = None + self._dataloader_state_dict: Dict[str, Any] = {} @property def done(self) -> bool: """Returns ``True`` if the current iteration count reaches the number of dataloader batches.""" return self.batch_progress.current.completed >= self._dl_max_batches - def connect(self, **kwargs: "Loop") -> None: - raise NotImplementedError(f"{self.__class__.__name__} does not connect any child loops.") - def reset(self) -> None: """Resets the loop's internal state.""" self._dl_max_batches = None @@ -183,7 +180,7 @@ def on_load_checkpoint(self, state_dict: Dict) -> None: def _reload_dataloader_state_dict(self, data_fetcher: AbstractDataFetcher): if not self.trainer.sanity_checking and self._dataloader_state_dict: _reload_dataloader_state_dict(data_fetcher.dataloader, self._dataloader_state_dict) - self._dataloader_state_dict = None + self._dataloader_state_dict = {} def _num_completed_batches_reached(self) -> bool: epoch_finished_on_completed = self.batch_progress.current.completed == self._dl_max_batches diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py 
b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 8ddca3ad505e8..8c3e3143fd01b 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -61,8 +61,8 @@ def __init__(self, min_steps: Optional[int] = 0, max_steps: int = -1) -> None: self.batch_progress = BatchProgress() self.scheduler_progress = SchedulerProgress() - self.batch_loop: Optional[TrainingBatchLoop] = None - self.val_loop: Optional["loops.EvaluationLoop"] = None + self.batch_loop = TrainingBatchLoop() + self.val_loop = loops.EvaluationLoop() self._results = ResultCollection(training=True) self._outputs: _OUTPUTS_TYPE = [] @@ -106,7 +106,7 @@ def done(self) -> bool: def connect( self, - batch_loop: TrainingBatchLoop = None, + batch_loop: Optional[TrainingBatchLoop] = None, val_loop: Optional["loops.EvaluationLoop"] = None, ) -> None: """Optionally connect a custom batch or validation loop to this training epoch loop.""" @@ -117,8 +117,6 @@ def connect( def reset(self) -> None: """Resets the internal state of the loop for a new run.""" - assert self.batch_loop is not None - assert self.batch_loop.optimizer_loop is not None if self.restarting: self.batch_progress.reset_on_restart() self.scheduler_progress.reset_on_restart() diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index df6634c963851..4040d08d4f3dd 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -48,7 +48,7 @@ def __init__( self.max_epochs = max_epochs self.min_epochs = min_epochs - self.epoch_loop: Optional[TrainingEpochLoop] = None + self.epoch_loop = TrainingEpochLoop() self.epoch_progress = Progress() self._is_fresh_start_epoch: bool = True @@ -128,15 +128,11 @@ def running_loss(self) -> TensorRunningAccum: @property def _skip_backward(self) -> bool: """Determines whether the loop will skip backward during automatic optimization.""" - assert self.epoch_loop.batch_loop is not None - assert self.epoch_loop.batch_loop.optimizer_loop is not None return self.epoch_loop.batch_loop.optimizer_loop._skip_backward @_skip_backward.setter def _skip_backward(self, value: bool) -> None: """Determines whether the loop will skip backward during automatic optimization.""" - assert self.epoch_loop.batch_loop is not None - assert self.epoch_loop.batch_loop.optimizer_loop is not None self.epoch_loop.batch_loop.optimizer_loop._skip_backward = value @property diff --git a/tests/loops/test_evaluation_loop.py b/tests/loops/test_evaluation_loop.py index d6b2c15553fb9..1507817357299 100644 --- a/tests/loops/test_evaluation_loop.py +++ b/tests/loops/test_evaluation_loop.py @@ -127,6 +127,6 @@ def on_advance_end(self): assert not is_overridden("test_epoch_end", model) trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=3) - trainer.test_loop.connect(TestLoop()) + trainer.test_loop.replace(TestLoop) trainer.test(model) assert did_assert diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index 3c8912e145305..de32fc7820121 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -23,9 +23,11 @@ from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoader from pytorch_lightning import LightningModule, Trainer -from pytorch_lightning.callbacks import Callback, ModelCheckpoint -from pytorch_lightning.loops import Loop, TrainingBatchLoop +from pytorch_lightning.callbacks import Callback +from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.loops import Loop, 
PredictionEpochLoop, TrainingBatchLoop, TrainingEpochLoop from pytorch_lightning.trainer.progress import BaseProgress +from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers import BoringModel, RandomDataset from tests.helpers.runif import RunIf @@ -61,7 +63,7 @@ def test_connect_loops_direct(loop_name): trainer = Trainer() - # trainer.loop = loop + # trainer.loop_name = loop setattr(trainer, loop_name, loop) assert loop.trainer is trainer @@ -102,6 +104,36 @@ def test_connect_subloops(tmpdir): assert new_batch_loop.trainer is trainer +def test_replace_loops(): + class TestLoop(TrainingEpochLoop): + def __init__(self, foo): + super().__init__() + + trainer = Trainer(min_steps=123, max_steps=321) + + with pytest.raises( + MisconfigurationException, match=r"FitLoop.replace\(TestLoop\)`.*`__init__`.*`TrainingEpochLoop`" + ): + trainer.fit_loop.replace(TestLoop) + + with pytest.raises(MisconfigurationException, match="Did not find.*same parent class as `PredictionEpochLoop`"): + trainer.fit_loop.replace(PredictionEpochLoop) + + class TestLoop(TrainingEpochLoop): + ... + + old_loop = trainer.fit_loop.epoch_loop + new_loop = trainer.fit_loop.replace(TestLoop) + + assert isinstance(new_loop, TestLoop) + assert trainer.fit_loop.epoch_loop is new_loop + assert new_loop.min_steps == 123 + assert new_loop.max_steps == 321 + assert new_loop.batch_loop is old_loop.batch_loop + assert new_loop.val_loop is old_loop.val_loop + assert new_loop.trainer is trainer + + class CustomException(Exception): pass From 631d8e0ebb997055bfff40970bd5069dd0b820d3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 24 Nov 2021 17:09:00 +0000 Subject: [PATCH 2/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/loops/test_loops.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index de32fc7820121..81ac04e726670 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -23,8 +23,7 @@ from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoader from pytorch_lightning import LightningModule, Trainer -from pytorch_lightning.callbacks import Callback -from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.callbacks import Callback, ModelCheckpoint from pytorch_lightning.loops import Loop, PredictionEpochLoop, TrainingBatchLoop, TrainingEpochLoop from pytorch_lightning.trainer.progress import BaseProgress from pytorch_lightning.utilities.exceptions import MisconfigurationException From f43bfdb706cb61c46413eaf70ef70dce8ac45612 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 24 Nov 2021 19:11:42 +0100 Subject: [PATCH 3/8] Add docstring --- docs/source/extensions/loops.rst | 2 +- pytorch_lightning/loops/base.py | 58 +++++++++++++++-------------- tests/loops/test_evaluation_loop.py | 2 +- tests/loops/test_loops.py | 25 ++++++++++--- 4 files changed, 51 insertions(+), 36 deletions(-) diff --git a/docs/source/extensions/loops.rst b/docs/source/extensions/loops.rst index 267f637be63a1..db8893452b305 100644 --- a/docs/source/extensions/loops.rst +++ b/docs/source/extensions/loops.rst @@ -272,7 +272,7 @@ When you want to customize nested loops within loops, use the :meth:`~pytorch_li .. 
code-block:: python # This takes care of properly instantiating the new Loop and setting all references - trainer.fit_loop.replace(MyEpochLoop) + trainer.fit_loop.replace(epoch_loop=MyEpochLoop) # Trainer runs the fit loop with your new epoch loop! trainer.fit(model) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 37d63fca885aa..8d5cb08d5591d 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -13,7 +13,7 @@ # limitations under the License. import inspect from abc import ABC, abstractmethod -from typing import Any, Dict, Generic, Optional, Type, TypeVar +from typing import Any, Dict, Generic, Optional, Type, TypeVar, Union from deprecate import void from torchmetrics import Metric @@ -99,34 +99,36 @@ def connect(self, **kwargs: "Loop") -> None: Linked loops should form a tree. """ - def replace(self, loop_cls: Type["Loop"]) -> "Loop": - # find the target - for name, old_loop in self.__dict__.items(): - if issubclass(loop_cls, type(old_loop)): - break - else: - raise MisconfigurationException( - f"Did not find an attribute with the same parent class as `{loop_cls.__name__}`" - ) - # compare the signatures - old_parameters = inspect.signature(old_loop.__class__.__init__).parameters - current_parameters = inspect.signature(loop_cls.__init__).parameters - if old_parameters != current_parameters: - raise MisconfigurationException( - f"`{self.__class__.__name__}.replace({loop_cls.__name__})` can only be used if the `__init__`" - f" signatures match but `{old_loop.__class__.__name__}` does not." - ) - # instantiate the loop - kwargs = {p: getattr(old_loop, p) for p in old_parameters if p != "self"} - loop = loop_cls(**kwargs) - # connect subloops - kwargs = {n: l for n, l in old_loop.__dict__.items() if isinstance(l, Loop)} - loop.connect(**kwargs) - # set the trainer reference - loop.trainer = self.trainer + def replace(self, **loops: Union["Loop", Type["Loop"]]) -> None: + new_loops = {} + + for name, type_or_object in loops.items(): + old_loop = getattr(self, name) + + if isinstance(type_or_object, type): + # compare the signatures + old_parameters = inspect.signature(old_loop.__class__.__init__).parameters + current_parameters = inspect.signature(type_or_object.__init__).parameters + if old_parameters != current_parameters: + raise MisconfigurationException( + f"`{self.__class__.__name__}.replace({type_or_object.__name__})` can only be used if the" + f" `__init__` signatures match but `{old_loop.__class__.__name__}` does not." + ) + # instantiate the loop + kwargs = {p: getattr(old_loop, p) for p in old_parameters if p != "self"} + loop = type_or_object(**kwargs) + else: + loop = type_or_object + + # connect sub-loops + kwargs = {n: l for n, l in old_loop.__dict__.items() if isinstance(l, Loop)} + loop.connect(**kwargs) + # set the trainer reference + loop.trainer = self.trainer + + new_loops[name] = loop # connect to self - self.connect(**{name: loop}) - return loop + self.connect(**new_loops) def on_skip(self) -> Optional[Any]: """The function to run when :meth:`run` should be skipped, determined by the condition in :attr:`skip`. 
diff --git a/tests/loops/test_evaluation_loop.py b/tests/loops/test_evaluation_loop.py index 1507817357299..4b58f2a20b93b 100644 --- a/tests/loops/test_evaluation_loop.py +++ b/tests/loops/test_evaluation_loop.py @@ -127,6 +127,6 @@ def on_advance_end(self): assert not is_overridden("test_epoch_end", model) trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=3) - trainer.test_loop.replace(TestLoop) + trainer.test_loop.replace(epoch_loop=TestLoop) trainer.test(model) assert did_assert diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index 81ac04e726670..52a62b1ae50a1 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -24,7 +24,7 @@ from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.callbacks import Callback, ModelCheckpoint -from pytorch_lightning.loops import Loop, PredictionEpochLoop, TrainingBatchLoop, TrainingEpochLoop +from pytorch_lightning.loops import EvaluationLoop, Loop, TrainingBatchLoop, TrainingEpochLoop from pytorch_lightning.trainer.progress import BaseProgress from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers import BoringModel, RandomDataset @@ -113,16 +113,15 @@ def __init__(self, foo): with pytest.raises( MisconfigurationException, match=r"FitLoop.replace\(TestLoop\)`.*`__init__`.*`TrainingEpochLoop`" ): - trainer.fit_loop.replace(TestLoop) - - with pytest.raises(MisconfigurationException, match="Did not find.*same parent class as `PredictionEpochLoop`"): - trainer.fit_loop.replace(PredictionEpochLoop) + trainer.fit_loop.replace(epoch_loop=TestLoop) class TestLoop(TrainingEpochLoop): ... + # test passing a loop where previous state should be connected old_loop = trainer.fit_loop.epoch_loop - new_loop = trainer.fit_loop.replace(TestLoop) + trainer.fit_loop.replace(epoch_loop=TestLoop) + new_loop = trainer.fit_loop.epoch_loop assert isinstance(new_loop, TestLoop) assert trainer.fit_loop.epoch_loop is new_loop @@ -132,6 +131,20 @@ class TestLoop(TrainingEpochLoop): assert new_loop.val_loop is old_loop.val_loop assert new_loop.trainer is trainer + class MyBatchLoop(TrainingBatchLoop): + ... + + class MyEvalLoop(EvaluationLoop): + ... + + # test passing more than one where one is an instance and the other a class + trainer.fit_loop.epoch_loop.replace(batch_loop=MyBatchLoop, val_loop=MyEvalLoop()) + new_batch_loop = trainer.fit_loop.epoch_loop.batch_loop + new_val_loop = trainer.fit_loop.epoch_loop.val_loop + + assert isinstance(new_batch_loop, MyBatchLoop) + assert isinstance(new_val_loop, MyEvalLoop) + class CustomException(Exception): pass From 1680b34093f1621062501d4ec309ef45b3280ef7 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 24 Nov 2021 19:11:03 +0100 Subject: [PATCH 4/8] Add docstring --- pytorch_lightning/loops/base.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 8d5cb08d5591d..2191293a2e858 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -100,6 +100,20 @@ def connect(self, **kwargs: "Loop") -> None: """ def replace(self, **loops: Union["Loop", Type["Loop"]]) -> None: + """Optionally replace one or multiple of this loop's sub-loops. + + This methods takes care of instantiating the class (if necessary) with all existing arguments, connecting all + sub-loops of the old loop to the new instance, setting the ``Trainer`` reference, and conneting the new loop to + the parent. 
+ + Args: + **loops: A ``Loop`` subclass or instance. The name used should match the loop attribute name you want to + replace. + + Raises: + MisconfigurationException: When passing a ``Loop`` class, if the ``__init__`` arguments do not match those + of the Loop class it replaces. + """ new_loops = {} for name, type_or_object in loops.items(): From 3543a816369d8dabaecc72aa5f3d06ff558904e6 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 26 Nov 2021 19:25:40 +0100 Subject: [PATCH 5/8] Update CHANGELOG --- CHANGELOG.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f52369b443164..ce6e06e7fbf32 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -31,7 +31,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Show a better error message when a custom `DataLoader` implementation is not well implemented and we need to reconstruct it ([#10719](https://github.com/PyTorchLightning/pytorch-lightning/issues/10719)) -- +- Added `Loop.replace` to easily switch one loop for another ([#10324](https://github.com/PyTorchLightning/pytorch-lightning/issues/10324)) + ### Changed From bfc1f88424976e36c6971aeb2c6b4479b57099ce Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 7 Dec 2021 13:53:32 +0000 Subject: [PATCH 6/8] update --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ce6e06e7fbf32..43eb32e3331f8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -66,7 +66,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Raised an error if the `batch_size` cannot be inferred from the current batch if it contained a string or was a custom batch object ([#10541](https://github.com/PyTorchLightning/pytorch-lightning/pull/10541)) -- +- Update ImageClassification Task to the new DataModule API ([#1025](https://github.com/PyTorchLightning/pytorch-lightning/pull/1025)) - From 23f064fed3cf2f477e74d84df18b7a9970965b26 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Wed, 15 Dec 2021 18:05:38 +0100 Subject: [PATCH 7/8] Update pytorch_lightning/loops/base.py Co-authored-by: Rohit Gupta --- pytorch_lightning/loops/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index b6b877e19f189..adebf0d8d43da 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -103,7 +103,7 @@ def replace(self, **loops: Union["Loop", Type["Loop"]]) -> None: """Optionally replace one or multiple of this loop's sub-loops. This methods takes care of instantiating the class (if necessary) with all existing arguments, connecting all - sub-loops of the old loop to the new instance, setting the ``Trainer`` reference, and conneting the new loop to + sub-loops of the old loop to the new instance, setting the ``Trainer`` reference, and connecting the new loop to the parent. Args: From a3018ac2c649d53a44360721d25ae400615cdce1 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 15 Dec 2021 18:29:43 +0100 Subject: [PATCH 8/8] Plural and mypy --- pytorch_lightning/loops/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index adebf0d8d43da..b47d1935e89be 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -107,7 +107,7 @@ def replace(self, **loops: Union["Loop", Type["Loop"]]) -> None: the parent. 
Args: - **loops: A ``Loop`` subclass or instance. The name used should match the loop attribute name you want to + **loops: ``Loop`` subclasses or instances. The name used should match the loop attribute name you want to replace. Raises: @@ -130,7 +130,7 @@ def replace(self, **loops: Union["Loop", Type["Loop"]]) -> None: ) # instantiate the loop kwargs = {p: getattr(old_loop, p) for p in old_parameters if p != "self"} - loop = type_or_object(**kwargs) + loop = type_or_object(**kwargs) # type: ignore[call-arg] else: loop = type_or_object
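
A minimal sketch of the behaviour the series leaves in place after patch 8, mirroring what `test_replace_loops` exercises; the `MyEpochLoop` subclass and the step counts here are illustrative, not part of the patches:

    from pytorch_lightning import Trainer
    from pytorch_lightning.loops import TrainingEpochLoop

    class MyEpochLoop(TrainingEpochLoop):
        # No custom `__init__`: the inherited signature matches
        # `TrainingEpochLoop`, which is what `replace` validates.
        def advance(self, *args, **kwargs):
            # Illustrative override that defers to the stock behaviour.
            return super().advance(*args, **kwargs)

    trainer = Trainer(min_steps=5, max_steps=10)

    # Passing the class: `replace` re-instantiates it from the old loop's
    # `__init__` arguments, reconnects the existing sub-loops
    # (`batch_loop`, `val_loop`), and sets the trainer reference.
    trainer.fit_loop.replace(epoch_loop=MyEpochLoop)

    assert isinstance(trainer.fit_loop.epoch_loop, MyEpochLoop)
    assert trainer.fit_loop.epoch_loop.min_steps == 5
    assert trainer.fit_loop.epoch_loop.trainer is trainer

    # Classes and instances can be mixed, as the test also covers:
    # trainer.fit_loop.epoch_loop.replace(batch_loop=MyBatchLoop, val_loop=MyEvalLoop())

Requiring matching `__init__` signatures is what makes the class form safe: `replace` can only forward the old loop's constructor arguments to the new class if it accepts exactly the same ones, which is why the `TestLoop.__init__(self, foo)` variant in `test_replace_loops` raises a `MisconfigurationException`.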