From d05363fca612e8cf8751dc5abac21c025bb5c9ef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 20 Oct 2021 04:53:23 +0200 Subject: [PATCH 001/104] improve spawn queue --- pl_examples/bug_report_model.py | 3 + pytorch_lightning/core/lightning.py | 4 +- .../plugins/training_type/ddp_spawn.py | 135 ++++++++++-------- .../plugins/training_type/sharded_spawn.py | 4 +- .../training_type/training_type_plugin.py | 16 +-- pytorch_lightning/trainer/trainer.py | 14 +- tests/plugins/test_ddp_spawn_plugin.py | 12 +- 7 files changed, 101 insertions(+), 87 deletions(-) diff --git a/pl_examples/bug_report_model.py b/pl_examples/bug_report_model.py index 270b0cd2abe8d..6a804c981033f 100644 --- a/pl_examples/bug_report_model.py +++ b/pl_examples/bug_report_model.py @@ -56,6 +56,9 @@ def run(): num_sanity_val_steps=0, max_epochs=1, enable_model_summary=False, + accelerator="cpu", + strategy="ddp_spawn", + devices=2, ) trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data) trainer.test(model, dataloaders=test_data) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index ca4b2af7eee17..7989d2bf87c65 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1992,7 +1992,7 @@ def model_size(self) -> float: ) return get_model_size_mb(self) - def add_to_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None: + def add_to_queue(self, queue: List[Any]) -> None: """Appends the :attr:`trainer.callback_metrics` dictionary to the given queue. To avoid issues with memory sharing, we cast the data to numpy. @@ -2006,7 +2006,7 @@ def add_to_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None: if self.trainer and isinstance(self.trainer.training_type_plugin, pl.plugins.training_type.DDPSpawnPlugin): self.trainer.training_type_plugin.add_to_queue(self.trainer, queue) - def get_from_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None: + def get_from_queue(self, queue: List[Any]) -> None: """Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. To preserve consistency, we cast back the data to ``torch.Tensor``. diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index c72cc7f31d0cc..8f8a3b14cb01c 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -15,7 +15,7 @@ import os import re from multiprocessing.queues import SimpleQueue -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union, Tuple import numpy as np import torch @@ -38,7 +38,7 @@ rank_zero_deprecation, rank_zero_warn, ) -from pytorch_lightning.utilities.apply_func import apply_to_collection +from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device from pytorch_lightning.utilities.cloud_io import atomic_save from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.distributed import distributed_available @@ -91,7 +91,6 @@ def __init__( self._sync_batchnorm = sync_batchnorm or False self._ddp_kwargs = kwargs self.num_processes = len(parallel_devices) if parallel_devices is not None else 0 - self.mp_queue = None self._ddp_comm_state = ddp_comm_state self._ddp_comm_hook = ddp_comm_hook self._ddp_comm_wrapper = ddp_comm_wrapper @@ -120,11 +119,11 @@ def sync_batchnorm(self, sync_batchnorm: bool) -> None: def local_rank(self) -> int: return self._local_rank - def __getstate__(self): - """Makes this plugin pickleable without destroying the queue in the current process.""" - state = self.__dict__.copy() - state["mp_queue"] = None - return state + # def __getstate__(self): + # """Makes this plugin pickleable without destroying the queue in the current process.""" + # state = self.__dict__.copy() + # state["mp_queue"] = None # TODO: is this anymoe needed? + # return state def __setstate__(self, state): self.__dict__ = state @@ -144,9 +143,6 @@ def _is_single_process_single_device(self): def setup(self) -> None: os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port()) - # pass in a state q - smp = mp.get_context("spawn") - self.mp_queue = smp.SimpleQueue() def _setup_model(self, model: Module) -> DistributedDataParallel: """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module.""" @@ -163,18 +159,24 @@ def set_world_ranks(self, process_idx: int = 0) -> None: def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[str, Any]: return {"nprocs": self.num_processes} - def start_training(self, trainer: "pl.Trainer") -> None: - self.spawn(self.new_process, trainer, self.mp_queue) + def start_training(self, trainer: "pl.Trainer") -> Any: + best_model_path, last_path, results, extra = self.spawn(self.new_process, trainer) + self.__recover_child_process_weights(best_model_path, last_path, extra, trainer) # reset optimizers, since main process is never used for training and thus does not have a valid optim state trainer.optimizers = [] + return results def start_evaluating(self, trainer: "pl.Trainer") -> None: - self.spawn(self.new_process, trainer, self.mp_queue) + best_model_path, last_path, results, extra = self.spawn(self.new_process, trainer) + self.__recover_child_process_weights(best_model_path, last_path, extra, trainer) + return results def start_predicting(self, trainer: "pl.Trainer") -> None: - self.spawn(self.new_process, trainer, self.mp_queue) + best_model_path, last_path, results, extra = self.spawn(self.new_process, trainer) + self.__recover_child_process_weights(best_model_path, last_path, extra, trainer) + return results - def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> None: + def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Any: """Spawn processes that run the given function. Args: @@ -185,11 +187,18 @@ def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> None: These arguments must be pickleable. """ os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port()) - mp.spawn(self._wrapped_function, args=(function, args, kwargs), **self.get_mp_spawn_kwargs()) + smp = mp.get_context("spawn") + mp_queue = smp.SimpleQueue() + mp.spawn(self._wrapped_function, args=(function, args, kwargs, mp_queue), nprocs=self.num_processes) + return mp_queue.get() - def _wrapped_function(self, process_idx: int, function: Callable, args: Any, kwargs: Any) -> None: + def _wrapped_function( + self, process_idx: int, function: Callable, args: Any, kwargs: Any, mp_queue: SimpleQueue + ) -> None: self._worker_setup(process_idx) - function(*args, **kwargs) + result = function(*args, **kwargs) + if self.is_global_zero: + mp_queue.put(move_data_to_device(result, "cpu")) def _worker_setup(self, process_idx: int): reset_seed() @@ -197,9 +206,7 @@ def _worker_setup(self, process_idx: int): rank_zero_only.rank = self.global_rank init_ddp_connection(self.cluster_environment, self.torch_distributed_backend, self.global_rank, self.world_size) - def new_process(self, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None: - self.mp_queue = mp_queue - + def new_process(self, trainer: "pl.Trainer") -> Any: # move the model to the correct device self.model_to_device() @@ -214,28 +221,16 @@ def new_process(self, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None: self.barrier() results = trainer.run_stage() - - # persist info in ddp_spawn - self.__transfer_distrib_spawn_state_on_fit_end(trainer, results) + outputs = self.__transfer_distrib_spawn_state_on_fit_end(trainer, results) # ensure that spawned processes go through teardown before joining trainer._call_teardown_hook() + return outputs def post_dispatch(self, trainer: "pl.Trainer"): # restore main state with best weights - best_path = self.mp_queue.get() - last_path = self.mp_queue.get() - self._results = self.mp_queue.get() - # get the `callback_metrics` and set it to the trainer - # only in case the user does not override it. - # TODO: Remove the if in v1.7 - if is_overridden("get_from_queue", self.lightning_module): - self.lightning_module.get_from_queue(self.mp_queue) - else: - self.get_from_queue(trainer, self.mp_queue) - # recover the weights of the processes trained in the children - self.__recover_child_process_weights(best_path, last_path) + pass def pre_configure_ddp(self): # if unset, default `find_unused_parameters` `True` @@ -276,34 +271,40 @@ def determine_ddp_device_ids(self): return None return [self.root_device.index] - def __transfer_distrib_spawn_state_on_fit_end(self, trainer: "pl.Trainer", results: Any) -> None: + def __transfer_distrib_spawn_state_on_fit_end( + self, trainer: "pl.Trainer", results: Any + ) -> Optional[Tuple[Optional[str], Optional[str], Any, List[Any]]]: + checkpoint_callback = trainer.checkpoint_callback best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None # requires to compute the state_dict on all processes in case Metrics are present state_dict = self.lightning_module.state_dict() - if self.global_rank == 0 and self.mp_queue is not None: - rank_zero_warn("cleaning up ddp environment...") - - # save the last weights - last_path = None - if trainer.state.fn == TrainerFn.FITTING and best_model_path is not None and len(best_model_path) > 0: - last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path) - atomic_save(state_dict, last_path) - - # todo, pass complete checkpoint as state dictionary - self.mp_queue.put(best_model_path) - self.mp_queue.put(last_path) - self.mp_queue.put(results) - # adds the `callback_metrics` to the queue - # TODO: Remove the if in v1.7 - if is_overridden("add_to_queue", self.lightning_module): - self.lightning_module.add_to_queue(self.mp_queue) - else: - self.add_to_queue(trainer, self.mp_queue) - - def __recover_child_process_weights(self, best_path, last_path): + if not self.is_global_zero: + return + + rank_zero_warn("cleaning up ddp environment...") + + # save the last weights + last_path = None + if trainer.state.fn == TrainerFn.FITTING and best_model_path is not None and len(best_model_path) > 0: + last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path) + atomic_save(state_dict, last_path) + + extra = [] + # adds the `callback_metrics` to the queue + # TODO: Remove the if in v1.7 + if is_overridden("add_to_queue", self.lightning_module): + self.lightning_module.add_to_queue(extra) + else: + self.add_to_queue(trainer, extra) + + return best_model_path, last_path, results, extra + + def __recover_child_process_weights( + self, best_path: Optional[str], last_path: Optional[str], extra: List[Any], trainer + ) -> None: # transfer back the best path to the trainer if self.lightning_module.trainer.checkpoint_callback: self.lightning_module.trainer.checkpoint_callback.best_model_path = best_path @@ -314,6 +315,14 @@ def __recover_child_process_weights(self, best_path, last_path): ckpt = pl_load(last_path, map_location=lambda storage, loc: storage) self.lightning_module.load_state_dict(ckpt) + # get the `callback_metrics` and set it to the trainer + # only in case the user does not override it. + # TODO: Remove the if in v1.7 + if is_overridden("get_from_queue", self.lightning_module): + self.lightning_module.get_from_queue(extra) + else: + self.get_from_queue(trainer, extra) + def barrier(self, *args, **kwargs) -> None: if not distributed_available(): return @@ -379,7 +388,7 @@ def post_training_step(self): if not self.lightning_module.automatic_optimization: self.model.require_backward_grad_sync = True - def add_to_queue(self, trainer: "pl.Trainer", queue: torch.multiprocessing.SimpleQueue) -> None: + def add_to_queue(self, trainer: "pl.Trainer", queue: List[Any]) -> None: """Appends the :attr:`trainer.callback_metrics` dictionary to the given queue. To avoid issues with memory sharing, we cast the data to numpy. @@ -389,9 +398,9 @@ def add_to_queue(self, trainer: "pl.Trainer", queue: torch.multiprocessing.Simpl callback_metrics: dict = apply_to_collection( trainer.callback_metrics, torch.Tensor, lambda x: x.cpu().numpy() ) # send as numpy to avoid issues with memory sharing - queue.put(callback_metrics) + queue.append(callback_metrics) - def get_from_queue(self, trainer: "pl.Trainer", queue: torch.multiprocessing.SimpleQueue) -> None: + def get_from_queue(self, trainer: "pl.Trainer", queue: List[Any]) -> None: """Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. To preserve consistency, we cast back the data to ``torch.Tensor``. @@ -399,7 +408,7 @@ def get_from_queue(self, trainer: "pl.Trainer", queue: torch.multiprocessing.Sim queue: the instance of the queue from where to get the data. """ # NOTE: `add_to_queue` needs to be called before - callback_metrics: dict = queue.get() + callback_metrics: dict = queue.pop(0) trainer.callback_metrics.update(apply_to_collection(callback_metrics, np.ndarray, lambda x: torch.tensor(x))) @classmethod diff --git a/pytorch_lightning/plugins/training_type/sharded_spawn.py b/pytorch_lightning/plugins/training_type/sharded_spawn.py index 78b54d029a5f6..9fd55a4e79bc0 100644 --- a/pytorch_lightning/plugins/training_type/sharded_spawn.py +++ b/pytorch_lightning/plugins/training_type/sharded_spawn.py @@ -101,13 +101,13 @@ def pre_backward(self, closure_loss: torch.Tensor) -> None: def post_training_step(self): pass - def new_process(self, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None: + def new_process(self, trainer: "pl.Trainer") -> None: # Ensure that the scaler points to the correct process group # which is re-initialized in a new process precision_plugin = trainer.accelerator.precision_plugin if isinstance(precision_plugin, ShardedNativeMixedPrecisionPlugin): precision_plugin.scaler = ShardedGradScaler() - return super().new_process(trainer, mp_queue) + return super().new_process(trainer) @classmethod def register_plugins(cls, plugin_registry: Dict) -> None: diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 481b9ee1c4087..39227b0008d84 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -35,7 +35,6 @@ class TrainingTypePlugin(ABC): def __init__(self, checkpoint_io: Optional[CheckpointIO] = None) -> None: self._model: Optional[Module] = None - self._results: Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]] = None checkpoint_io = checkpoint_io if checkpoint_io is not None else TorchCheckpointIO() self._checkpoint_io = checkpoint_io @@ -188,7 +187,8 @@ def results(self) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]: start-method send the result to the master process through a `multiprocessing queue (shared memory) `_. """ - return self._results + # TODO: deprecate this + return None def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]: torch.cuda.empty_cache() @@ -202,17 +202,17 @@ def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None: for optimizer, opt_state in zip(self.lightning_module.trainer.accelerator.optimizers, optimizer_states): optimizer.load_state_dict(opt_state) - def start_training(self, trainer: "pl.Trainer") -> None: + def start_training(self, trainer: "pl.Trainer") -> Any: # double dispatch to initiate the training loop - self._results = trainer.run_stage() + return trainer.run_stage() - def start_evaluating(self, trainer: "pl.Trainer") -> None: + def start_evaluating(self, trainer: "pl.Trainer") -> Any: # double dispatch to initiate the test loop - self._results = trainer.run_stage() + return trainer.run_stage() - def start_predicting(self, trainer: "pl.Trainer") -> None: + def start_predicting(self, trainer: "pl.Trainer") -> Any: # double dispatch to initiate the predicting loop - self._results = trainer.run_stage() + return trainer.run_stage() def training_step(self, *args, **kwargs): return self.model.training_step(*args, **kwargs) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index f9c18d0a8462f..4548ccc8702bc 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1097,9 +1097,9 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, self.checkpoint_connector.restore_training_state() # dispatch `start_training` or `start_evaluating` or `start_predicting` - self._dispatch() + results = self._dispatch() - # plugin will finalized fitting (e.g. ddp_spawn will load trained model) + # TODO: needed? self._post_dispatch() # ---------------------------- @@ -1118,7 +1118,7 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, self.state.status = TrainerStatus.FINISHED self.state.stage = None - return self.training_type_plugin.results + return results def _pre_dispatch(self): self.accelerator.pre_dispatch(self) @@ -1170,13 +1170,13 @@ def _post_dispatch(self): self._active_loop.teardown() self.logger_connector.teardown() - def _dispatch(self): + def _dispatch(self) -> Any: if self.evaluating: - self.training_type_plugin.start_evaluating(self) + return self.training_type_plugin.start_evaluating(self) elif self.predicting: - self.training_type_plugin.start_predicting(self) + return self.training_type_plugin.start_predicting(self) else: - self.training_type_plugin.start_training(self) + return self.training_type_plugin.start_training(self) def run_stage(self): self.accelerator.dispatch(self) diff --git a/tests/plugins/test_ddp_spawn_plugin.py b/tests/plugins/test_ddp_spawn_plugin.py index c389cf9290c78..804cf3d9e2ee0 100644 --- a/tests/plugins/test_ddp_spawn_plugin.py +++ b/tests/plugins/test_ddp_spawn_plugin.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import List, Any + import torch from torch.nn.parallel.distributed import DistributedDataParallel @@ -38,11 +40,11 @@ def validation_step(self, batch, batch_idx): return super().validation_step(batch, batch_idx) def add_to_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None: - queue.put("test_val") + queue.append("test_val") return super().add_to_queue(queue) def get_from_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None: - self.test_val = queue.get() + self.test_val = queue.pop(0) return super().get_from_queue(queue) @@ -83,11 +85,11 @@ def test_ddp_spawn_extra_parameters(tmpdir): class TestDDPSpawnPlugin(DDPSpawnPlugin): def add_to_queue(self, trainer: Trainer, queue: torch.multiprocessing.SimpleQueue) -> None: - queue.put("new_test_val") + queue.append("new_test_val") return super().add_to_queue(trainer, queue) - def get_from_queue(self, trainer: Trainer, queue: torch.multiprocessing.SimpleQueue) -> None: - self.new_test_val = queue.get() + def get_from_queue(self, trainer: Trainer, queue: List[Any]) -> None: + self.new_test_val = queue.pop(0) return super().get_from_queue(trainer, queue) From d650e2641e7649363b2ec7a8192b60f686fff2da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 20 Oct 2021 05:04:10 +0200 Subject: [PATCH 002/104] clean up --- .../plugins/training_type/ddp_spawn.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 8f8a3b14cb01c..517ea4042767d 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -119,10 +119,11 @@ def sync_batchnorm(self, sync_batchnorm: bool) -> None: def local_rank(self) -> int: return self._local_rank + # TODO: this should no longer be needed # def __getstate__(self): # """Makes this plugin pickleable without destroying the queue in the current process.""" # state = self.__dict__.copy() - # state["mp_queue"] = None # TODO: is this anymoe needed? + # state["mp_queue"] = None # return state def __setstate__(self, state): @@ -142,6 +143,7 @@ def _is_single_process_single_device(self): return True def setup(self) -> None: + # TODO: is this needed here? already getting set in spawn() os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port()) def _setup_model(self, model: Module) -> DistributedDataParallel: @@ -227,11 +229,6 @@ def new_process(self, trainer: "pl.Trainer") -> Any: trainer._call_teardown_hook() return outputs - def post_dispatch(self, trainer: "pl.Trainer"): - # restore main state with best weights - - pass - def pre_configure_ddp(self): # if unset, default `find_unused_parameters` `True` # Many models require setting this parameter to True, as there are corner cases @@ -292,10 +289,10 @@ def __transfer_distrib_spawn_state_on_fit_end( last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path) atomic_save(state_dict, last_path) - extra = [] # adds the `callback_metrics` to the queue - # TODO: Remove the if in v1.7 + extra = [] if is_overridden("add_to_queue", self.lightning_module): + # TODO: Remove the if in v1.7 self.lightning_module.add_to_queue(extra) else: self.add_to_queue(trainer, extra) @@ -316,9 +313,9 @@ def __recover_child_process_weights( self.lightning_module.load_state_dict(ckpt) # get the `callback_metrics` and set it to the trainer - # only in case the user does not override it. - # TODO: Remove the if in v1.7 if is_overridden("get_from_queue", self.lightning_module): + # only in case the user does not override it. + # TODO: Remove the if in v1.7 self.lightning_module.get_from_queue(extra) else: self.get_from_queue(trainer, extra) @@ -393,6 +390,7 @@ def add_to_queue(self, trainer: "pl.Trainer", queue: List[Any]) -> None: sharing, we cast the data to numpy. Args: + trainer: reference to the Trainer. queue: the instance of the queue to append the data. """ callback_metrics: dict = apply_to_collection( @@ -405,6 +403,7 @@ def get_from_queue(self, trainer: "pl.Trainer", queue: List[Any]) -> None: we cast back the data to ``torch.Tensor``. Args: + trainer: reference to the Trainer. queue: the instance of the queue from where to get the data. """ # NOTE: `add_to_queue` needs to be called before From 5fda23a99a7d3a5501e064735bdda18fe9fe8a6c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 20 Oct 2021 03:10:00 +0000 Subject: [PATCH 003/104] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/plugins/training_type/ddp_spawn.py | 2 +- tests/plugins/test_ddp_spawn_plugin.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 517ea4042767d..8802a5d016433 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -15,7 +15,7 @@ import os import re from multiprocessing.queues import SimpleQueue -from typing import Any, Callable, Dict, List, Optional, Union, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import torch diff --git a/tests/plugins/test_ddp_spawn_plugin.py b/tests/plugins/test_ddp_spawn_plugin.py index 804cf3d9e2ee0..3b293ea315f59 100644 --- a/tests/plugins/test_ddp_spawn_plugin.py +++ b/tests/plugins/test_ddp_spawn_plugin.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Any +from typing import Any, List import torch from torch.nn.parallel.distributed import DistributedDataParallel From bcfb853fa6ad7ed5318c6e37eda569532cd44d2a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 30 Nov 2021 15:04:55 +0100 Subject: [PATCH 004/104] fix --- pytorch_lightning/plugins/training_type/ddp_spawn.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 9f5b193dcefd6..4862518a872e2 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -125,7 +125,7 @@ def distributed_sampler_kwargs(self): @property def _is_single_process_single_device(self): return True - + def setup(self, trainer: "pl.Trainer") -> None: # TODO: is this needed here? already getting set in spawn() os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port) @@ -173,8 +173,6 @@ def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Any: function: The function to spawn processes from. *args: Optional positional arguments that will be passed to the function in addition to the process index. These arguments must be pickleable. - return_result: If ``True``, copies the output of the function from process 0 to the main process and - returns it. **kwargs: Optional named arguments that will be passed to the function in addition to the process index. These arguments must be pickleable. @@ -183,9 +181,9 @@ def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Any: """ os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port) context = mp.get_context("spawn") - return_queue = context.SimpleQueue() if return_result else None + return_queue = context.SimpleQueue() mp.spawn(self._wrapped_function, args=(function, args, kwargs, return_queue), nprocs=self.num_processes) - return mp_queue.get() + return return_queue.get() def _wrapped_function( self, process_idx: int, function: Callable, args: Any, kwargs: Any, return_queue: SimpleQueue From 97b4bf6046f2bbfe55431ad0d4110645103c3093 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 30 Nov 2021 14:09:24 +0000 Subject: [PATCH 005/104] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/plugins/test_ddp_spawn_plugin.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/plugins/test_ddp_spawn_plugin.py b/tests/plugins/test_ddp_spawn_plugin.py index f85545e83ea83..4f165fe963a46 100644 --- a/tests/plugins/test_ddp_spawn_plugin.py +++ b/tests/plugins/test_ddp_spawn_plugin.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Any, List + import pytest import torch from torch.nn.parallel.distributed import DistributedDataParallel From 38b3a548731bbd9066ae994b26e0ac0434b19e83 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 30 Nov 2021 15:14:18 +0100 Subject: [PATCH 006/104] rename --- pytorch_lightning/plugins/training_type/ddp_spawn.py | 4 ++-- pytorch_lightning/plugins/training_type/tpu_spawn.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 4862518a872e2..5214f9a12ce45 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -216,7 +216,7 @@ def new_process(self, trainer: "pl.Trainer") -> Any: self.barrier() results = trainer.run_stage() - outputs = self.__transfer_distrib_spawn_state_on_fit_end(trainer, results) + outputs = self.__collect_rank_zero_results(trainer, results) # ensure that spawned processes go through teardown before joining trainer._call_teardown_hook() @@ -259,7 +259,7 @@ def determine_ddp_device_ids(self): return None return [self.root_device.index] - def __transfer_distrib_spawn_state_on_fit_end( + def __collect_rank_zero_results( self, trainer: "pl.Trainer", results: Any ) -> Optional[Tuple[Optional[str], Optional[str], Any, List[Any]]]: diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 5ef8a46d7127f..7d36dbfef8d33 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -188,7 +188,7 @@ def new_process(self, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None: results = trainer.run_stage() - self.__transfer_distrib_spawn_state_on_fit_end(trainer, results) + self.__collect_rank_zero_results(trainer, results) # https://github.com/pytorch/xla/issues/1801#issuecomment-602799542 self.barrier("end-process") @@ -207,7 +207,7 @@ def barrier(self, name: Optional[str] = None) -> None: if self.is_distributed: rendezvous(name) - def __transfer_distrib_spawn_state_on_fit_end(self, trainer: "pl.Trainer", results: Any) -> None: + def __collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> None: checkpoint_callback = trainer.checkpoint_callback best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None From 955b6c8f6c66bc6a2578ef919187a98bad1fea61 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 30 Nov 2021 15:14:24 +0100 Subject: [PATCH 007/104] delete dead code --- pytorch_lightning/plugins/training_type/ddp_spawn.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 5214f9a12ce45..efe8df41d570e 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -103,16 +103,6 @@ def num_nodes(self, num_nodes: int) -> None: def local_rank(self) -> int: return self._local_rank - # TODO: this should no longer be needed - # def __getstate__(self): - # """Makes this plugin pickleable without destroying the queue in the current process.""" - # state = self.__dict__.copy() - # state["mp_queue"] = None - # return state - - def __setstate__(self, state): - self.__dict__ = state - @property def root_device(self): return self.parallel_devices[self.local_rank] From f3216b21b49d04944912f546df1230f64043f9ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 30 Nov 2021 15:24:41 +0100 Subject: [PATCH 008/104] clean up --- .../plugins/training_type/training_type_plugin.py | 13 ------------- pytorch_lightning/trainer/trainer.py | 1 - 2 files changed, 14 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 9923500574bfb..0dcc31c4482d6 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -287,19 +287,6 @@ def lightning_module(self) -> Optional["pl.LightningModule"]: """Returns the pure LightningModule without potential wrappers.""" return unwrap_lightning_module(self._model) if self._model is not None else None - @property - def results(self) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]: - """Enables plugin-agnostic access to the result returned by the training/evaluation/prediction run. - - The result is - cached instead of returned directly, because some plugins require transmitting the results from one - multiprocessing context to another in a separate step. For example, the plugins that use the "spawn" - start-method send the result to the main process through a - `multiprocessing queue (shared memory) `_. - """ - # TODO: deprecate this - return None - def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]: torch.cuda.empty_cache() return self.checkpoint_io.load_checkpoint(checkpoint_path) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 87d71db7576d4..f2ec7ef4b7332 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1162,7 +1162,6 @@ def _run( # dispatch `start_training` or `start_evaluating` or `start_predicting` results = self._dispatch() - # TODO: needed? self._post_dispatch() # ---------------------------- From 2d00231acb52eb8a6f55890ff90ce5303995d80f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 30 Nov 2021 15:26:11 +0100 Subject: [PATCH 009/104] update lite --- pytorch_lightning/lite/lite.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py index fede7f5df7291..3a6a814ce9200 100644 --- a/pytorch_lightning/lite/lite.py +++ b/pytorch_lightning/lite/lite.py @@ -403,7 +403,7 @@ def _run_impl(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any: run_method = partial(self._run_with_sharded_context, run_method) if isinstance(self._strategy, DDPSpawnPlugin): - return self._strategy.spawn(run_method, *args, return_result=True, **kwargs) + return self._strategy.spawn(run_method, *args, **kwargs) else: return run_method(*args, **kwargs) From 7aa36461c05632f016ddcf07f9617ba49dad9f06 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 30 Nov 2021 15:40:28 +0100 Subject: [PATCH 010/104] retain the queue interface in hooks --- pytorch_lightning/core/lightning.py | 5 +++-- .../plugins/training_type/ddp_spawn.py | 21 +++++++++++++++---- tests/plugins/test_ddp_spawn_plugin.py | 6 +++--- 3 files changed, 23 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 814caa3e6cf6d..3f8a547149f1a 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -36,6 +36,7 @@ from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin, HyperparametersMixin from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.core.saving import ModelIO +from pytorch_lightning.plugins.training_type.ddp_spawn import _SimpleQueue from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator from pytorch_lightning.utilities import ( _IS_WINDOWS, @@ -1928,7 +1929,7 @@ def model_size(self) -> float: ) return get_model_size_mb(self) - def add_to_queue(self, queue: List[Any]) -> None: + def add_to_queue(self, queue: _SimpleQueue) -> None: """Appends the :attr:`trainer.callback_metrics` dictionary to the given queue. To avoid issues with memory sharing, we cast the data to numpy. @@ -1942,7 +1943,7 @@ def add_to_queue(self, queue: List[Any]) -> None: if self.trainer and isinstance(self.trainer.training_type_plugin, pl.plugins.training_type.DDPSpawnPlugin): self.trainer.training_type_plugin.add_to_queue(self.trainer, queue) - def get_from_queue(self, queue: List[Any]) -> None: + def get_from_queue(self, queue: _SimpleQueue) -> None: """Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. To preserve consistency, we cast back the data to ``torch.Tensor``. diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index efe8df41d570e..cba316e2c514e 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -366,7 +366,7 @@ def post_training_step(self): if not self.lightning_module.automatic_optimization: self.model.require_backward_grad_sync = True - def add_to_queue(self, trainer: "pl.Trainer", queue: List[Any]) -> None: + def add_to_queue(self, trainer: "pl.Trainer", queue: "_SimpleQueue") -> None: """Appends the :attr:`trainer.callback_metrics` dictionary to the given queue. To avoid issues with memory sharing, we cast the data to numpy. @@ -377,9 +377,9 @@ def add_to_queue(self, trainer: "pl.Trainer", queue: List[Any]) -> None: callback_metrics: dict = apply_to_collection( trainer.callback_metrics, torch.Tensor, lambda x: x.cpu().numpy() ) # send as numpy to avoid issues with memory sharing - queue.append(callback_metrics) + queue.put(callback_metrics) - def get_from_queue(self, trainer: "pl.Trainer", queue: List[Any]) -> None: + def get_from_queue(self, trainer: "pl.Trainer", queue: "_SimpleQueue") -> None: """Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. To preserve consistency, we cast back the data to ``torch.Tensor``. @@ -388,7 +388,7 @@ def get_from_queue(self, trainer: "pl.Trainer", queue: List[Any]) -> None: queue: the instance of the queue from where to get the data. """ # NOTE: `add_to_queue` needs to be called before - callback_metrics: dict = queue.pop(0) + callback_metrics: dict = queue.get() trainer.callback_metrics.update(apply_to_collection(callback_metrics, np.ndarray, lambda x: torch.tensor(x))) @classmethod @@ -422,3 +422,16 @@ def _clean_logger(trainer: "pl.Trainer") -> None: # we want to make sure these are closed before we spawn our own threads. # assuming nothing else references the experiment object, python should instantly `__del__` it. logger._experiment = None + + +class _SimpleQueue(list): + """Simulates a :class:`torch.multiprocessing.queue.SimpleQueue` using the Python list interface.""" + + def get(self) -> Any: + return self.pop(0) + + def put(self, item: Any) -> None: + self.append(item) + + def empty(self) -> bool: + return len(self) == 0 diff --git a/tests/plugins/test_ddp_spawn_plugin.py b/tests/plugins/test_ddp_spawn_plugin.py index 4f165fe963a46..21736837d4ed1 100644 --- a/tests/plugins/test_ddp_spawn_plugin.py +++ b/tests/plugins/test_ddp_spawn_plugin.py @@ -86,12 +86,12 @@ def test_ddp_spawn_extra_parameters(tmpdir): class TestDDPSpawnPlugin(DDPSpawnPlugin): - def add_to_queue(self, trainer: Trainer, queue: torch.multiprocessing.SimpleQueue) -> None: - queue.append("new_test_val") + def add_to_queue(self, trainer, queue) -> None: + queue.put("new_test_val") return super().add_to_queue(trainer, queue) def get_from_queue(self, trainer: Trainer, queue: List[Any]) -> None: - self.new_test_val = queue.pop(0) + self.new_test_val = queue.get() return super().get_from_queue(trainer, queue) From fb0c0d8bb4f35e8a59f01527d4c18b052f8b40be Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 30 Nov 2021 15:42:13 +0100 Subject: [PATCH 011/104] update tests --- tests/plugins/test_ddp_spawn_plugin.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/tests/plugins/test_ddp_spawn_plugin.py b/tests/plugins/test_ddp_spawn_plugin.py index 21736837d4ed1..6ea265a4bb575 100644 --- a/tests/plugins/test_ddp_spawn_plugin.py +++ b/tests/plugins/test_ddp_spawn_plugin.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List import pytest import torch @@ -40,12 +39,12 @@ def validation_step(self, batch, batch_idx): self.log(self.name, self.val) return super().validation_step(batch, batch_idx) - def add_to_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None: - queue.append("test_val") + def add_to_queue(self, queue) -> None: + queue.put("test_val") return super().add_to_queue(queue) - def get_from_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None: - self.test_val = queue.pop(0) + def get_from_queue(self, queue) -> None: + self.test_val = queue.get() return super().get_from_queue(queue) @@ -90,7 +89,7 @@ def add_to_queue(self, trainer, queue) -> None: queue.put("new_test_val") return super().add_to_queue(trainer, queue) - def get_from_queue(self, trainer: Trainer, queue: List[Any]) -> None: + def get_from_queue(self, trainer: Trainer, queue) -> None: self.new_test_val = queue.get() return super().get_from_queue(trainer, queue) From 7e6c75ea5ee5ac4bebd909e033ac719c63e50c6b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 30 Nov 2021 15:43:47 +0100 Subject: [PATCH 012/104] _notebooks --- _notebooks | 1 - 1 file changed, 1 deletion(-) delete mode 160000 _notebooks diff --git a/_notebooks b/_notebooks deleted file mode 160000 index a2fb6468112b7..0000000000000 --- a/_notebooks +++ /dev/null @@ -1 +0,0 @@ -Subproject commit a2fb6468112b7e1dad501c3b6a17533a4adfeabc From b7efc5052b156aec31156d53d06413672b63367e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 30 Nov 2021 15:44:06 +0100 Subject: [PATCH 013/104] reset notebooks --- _notebooks | 1 + 1 file changed, 1 insertion(+) create mode 160000 _notebooks diff --git a/_notebooks b/_notebooks new file mode 160000 index 0000000000000..0c325829101d5 --- /dev/null +++ b/_notebooks @@ -0,0 +1 @@ +Subproject commit 0c325829101d5a6ebf32ed99bbf5b09badf04a59 From 84ca8b4ad50230da1fdbd314853cc30ff20c274d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 30 Nov 2021 15:58:14 +0100 Subject: [PATCH 014/104] avoid circular import --- pytorch_lightning/core/lightning.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 3f8a547149f1a..736e8f5a9b560 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -36,7 +36,6 @@ from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin, HyperparametersMixin from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.core.saving import ModelIO -from pytorch_lightning.plugins.training_type.ddp_spawn import _SimpleQueue from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator from pytorch_lightning.utilities import ( _IS_WINDOWS, @@ -1929,7 +1928,7 @@ def model_size(self) -> float: ) return get_model_size_mb(self) - def add_to_queue(self, queue: _SimpleQueue) -> None: + def add_to_queue(self, queue: pl.plugins.training_type.ddp_spawn._SimpleQueue) -> None: """Appends the :attr:`trainer.callback_metrics` dictionary to the given queue. To avoid issues with memory sharing, we cast the data to numpy. @@ -1943,7 +1942,7 @@ def add_to_queue(self, queue: _SimpleQueue) -> None: if self.trainer and isinstance(self.trainer.training_type_plugin, pl.plugins.training_type.DDPSpawnPlugin): self.trainer.training_type_plugin.add_to_queue(self.trainer, queue) - def get_from_queue(self, queue: _SimpleQueue) -> None: + def get_from_queue(self, queue: pl.plugins.training_type.ddp_spawn._SimpleQueue) -> None: """Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. To preserve consistency, we cast back the data to ``torch.Tensor``. From 965c7245aa7785012ec671b5e3963409ab846235 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 30 Nov 2021 16:04:29 +0100 Subject: [PATCH 015/104] fix unused imports --- pytorch_lightning/plugins/training_type/sharded_spawn.py | 1 - pytorch_lightning/plugins/training_type/training_type_plugin.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/sharded_spawn.py b/pytorch_lightning/plugins/training_type/sharded_spawn.py index e79c0aa9a1c24..c9a968fa94fbd 100644 --- a/pytorch_lightning/plugins/training_type/sharded_spawn.py +++ b/pytorch_lightning/plugins/training_type/sharded_spawn.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. from contextlib import contextmanager -from multiprocessing.queues import SimpleQueue from typing import Dict, Generator, List, Optional, Tuple import torch diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index f97bec44bc0ea..ef5f78a1e09f3 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -30,7 +30,7 @@ from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device from pytorch_lightning.utilities.distributed import ReduceOp -from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PATH, _PREDICT_OUTPUT +from pytorch_lightning.utilities.types import _PATH TBroadcast = TypeVar("TBroadcast") From 1aae8ddc78848fda16215b92cc1e83e00966d401 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 30 Nov 2021 16:08:11 +0100 Subject: [PATCH 016/104] reset debugging script --- pl_examples/bug_report/bug_report_model.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/pl_examples/bug_report/bug_report_model.py b/pl_examples/bug_report/bug_report_model.py index 5e43eeca17308..7739630237d32 100644 --- a/pl_examples/bug_report/bug_report_model.py +++ b/pl_examples/bug_report/bug_report_model.py @@ -57,9 +57,6 @@ def run(): num_sanity_val_steps=0, max_epochs=1, enable_model_summary=False, - accelerator="cpu", - strategy="ddp_spawn", - devices=2, ) trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data) trainer.test(model, dataloaders=test_data) From 4b998db77a32c1fc5b100ed0a618fe305cb843ca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 30 Nov 2021 16:20:43 +0100 Subject: [PATCH 017/104] typing _ExtraQueue --- pytorch_lightning/core/lightning.py | 4 ++-- .../plugins/training_type/ddp_spawn.py | 17 ++++++++--------- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 736e8f5a9b560..68611b93079c0 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1928,7 +1928,7 @@ def model_size(self) -> float: ) return get_model_size_mb(self) - def add_to_queue(self, queue: pl.plugins.training_type.ddp_spawn._SimpleQueue) -> None: + def add_to_queue(self, queue: pl.plugins.training_type.ddp_spawn._ExtraQueue) -> None: """Appends the :attr:`trainer.callback_metrics` dictionary to the given queue. To avoid issues with memory sharing, we cast the data to numpy. @@ -1942,7 +1942,7 @@ def add_to_queue(self, queue: pl.plugins.training_type.ddp_spawn._SimpleQueue) - if self.trainer and isinstance(self.trainer.training_type_plugin, pl.plugins.training_type.DDPSpawnPlugin): self.trainer.training_type_plugin.add_to_queue(self.trainer, queue) - def get_from_queue(self, queue: pl.plugins.training_type.ddp_spawn._SimpleQueue) -> None: + def get_from_queue(self, queue: pl.plugins.training_type.ddp_spawn._ExtraQueue) -> None: """Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. To preserve consistency, we cast back the data to ``torch.Tensor``. diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index cba316e2c514e..59849bc08c389 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -117,7 +117,6 @@ def _is_single_process_single_device(self): return True def setup(self, trainer: "pl.Trainer") -> None: - # TODO: is this needed here? already getting set in spawn() os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port) super().setup(trainer) @@ -191,7 +190,7 @@ def _worker_setup(self, process_idx: int): self.cluster_environment, self.torch_distributed_backend, self.global_rank, self.world_size ) - def new_process(self, trainer: "pl.Trainer") -> Any: + def new_process(self, trainer: "pl.Trainer") -> Optional[Tuple[Optional[str], Optional[str], Any, "_ExtraQueue"]]: # move the model to the correct device self.model_to_device() @@ -251,7 +250,7 @@ def determine_ddp_device_ids(self): def __collect_rank_zero_results( self, trainer: "pl.Trainer", results: Any - ) -> Optional[Tuple[Optional[str], Optional[str], Any, List[Any]]]: + ) -> Optional[Tuple[Optional[str], Optional[str], Any, "_ExtraQueue"]]: checkpoint_callback = trainer.checkpoint_callback best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None @@ -271,7 +270,7 @@ def __collect_rank_zero_results( atomic_save(state_dict, last_path) # adds the `callback_metrics` to the queue - extra = [] + extra = _ExtraQueue() if is_overridden("add_to_queue", self.lightning_module): # TODO: Remove the if in v1.7 self.lightning_module.add_to_queue(extra) @@ -281,7 +280,7 @@ def __collect_rank_zero_results( return best_model_path, last_path, results, extra def __recover_child_process_weights( - self, best_path: Optional[str], last_path: Optional[str], extra: List[Any], trainer + self, best_path: Optional[str], last_path: Optional[str], extra: "_ExtraQueue", trainer ) -> None: # transfer back the best path to the trainer if self.lightning_module.trainer.checkpoint_callback: @@ -366,7 +365,7 @@ def post_training_step(self): if not self.lightning_module.automatic_optimization: self.model.require_backward_grad_sync = True - def add_to_queue(self, trainer: "pl.Trainer", queue: "_SimpleQueue") -> None: + def add_to_queue(self, trainer: "pl.Trainer", queue: "_ExtraQueue") -> None: """Appends the :attr:`trainer.callback_metrics` dictionary to the given queue. To avoid issues with memory sharing, we cast the data to numpy. @@ -379,7 +378,7 @@ def add_to_queue(self, trainer: "pl.Trainer", queue: "_SimpleQueue") -> None: ) # send as numpy to avoid issues with memory sharing queue.put(callback_metrics) - def get_from_queue(self, trainer: "pl.Trainer", queue: "_SimpleQueue") -> None: + def get_from_queue(self, trainer: "pl.Trainer", queue: "_ExtraQueue") -> None: """Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. To preserve consistency, we cast back the data to ``torch.Tensor``. @@ -424,8 +423,8 @@ def _clean_logger(trainer: "pl.Trainer") -> None: logger._experiment = None -class _SimpleQueue(list): - """Simulates a :class:`torch.multiprocessing.queue.SimpleQueue` using the Python list interface.""" +class _ExtraQueue(list): + """Simulates a :class:`torch.multiprocessing.queue.SimpleQueue` interface using the Python list.""" def get(self) -> Any: return self.pop(0) From 5871a4bacc64a868744618764ef341bb2a7ac6f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 30 Nov 2021 16:55:38 +0100 Subject: [PATCH 018/104] bring changes to tpu_spawn plugin --- .../plugins/training_type/ddp_spawn.py | 2 +- .../plugins/training_type/tpu_spawn.py | 51 +++++++++++-------- 2 files changed, 31 insertions(+), 22 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 59849bc08c389..4eac619a08888 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -258,7 +258,7 @@ def __collect_rank_zero_results( # requires to compute the state_dict on all processes in case Metrics are present state_dict = self.lightning_module.state_dict() - if not self.is_global_zero: + if self.local_rank != 0: return rank_zero_warn("cleaning up ddp environment...") diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 7d36dbfef8d33..d3f03eb122ebc 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -16,7 +16,7 @@ import re import time from multiprocessing.queues import SimpleQueue -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union, Tuple import torch import torch.multiprocessing as mp @@ -28,7 +28,7 @@ from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin -from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin +from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin, _ExtraQueue from pytorch_lightning.trainer.connectors.data_connector import DataConnector from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, rank_zero_warn, set_shared_parameters @@ -207,28 +207,37 @@ def barrier(self, name: Optional[str] = None) -> None: if self.is_distributed: rendezvous(name) - def __collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> None: + # TODO: this implementation is identical to the one in the super class, up to the self.save() call + def __collect_rank_zero_results( + self, trainer: "pl.Trainer", results: Any + ) -> Optional[Tuple[Optional[str], Optional[str], Any, _ExtraQueue]]: + checkpoint_callback = trainer.checkpoint_callback best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None # requires to compute the state_dict on all processes in case Metrics are present state_dict = self.lightning_module.state_dict() - if self.mp_queue is not None: - rank_zero_warn("cleaning up tpu spawn environment...") + if self.local_rank != 0: + return + + rank_zero_warn("cleaning up ddp environment...") - # save the last weights - last_path = None - if trainer.state.fn == TrainerFn.FITTING and best_model_path is not None and len(best_model_path) > 0: - last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path) - self.save(state_dict, last_path) + # save the last weights + last_path = None + if trainer.state.fn == TrainerFn.FITTING and best_model_path is not None and len(best_model_path) > 0: + last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path) + self.save(state_dict, last_path) - if self.local_rank == 0: - # todo, pass complete checkpoint as state dictionary - self.mp_queue.put(best_model_path) - self.mp_queue.put(last_path) - self.mp_queue.put(results) - self.lightning_module.add_to_queue(self.mp_queue) # adds the `callback_metrics` to the queue + # adds the `callback_metrics` to the queue + extra = _ExtraQueue() + if is_overridden("add_to_queue", self.lightning_module): + # TODO: Remove the if in v1.7 + self.lightning_module.add_to_queue(extra) + else: + self.add_to_queue(trainer, extra) + + return best_model_path, last_path, results, extra def save(self, state_dict: Dict, path: _PATH) -> None: xm.save(state_dict, path) @@ -275,18 +284,18 @@ def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[st "start_method": self.start_method, } - def spawn(self, function: Callable, *args: Any, return_result: bool = True, **kwargs: Any) -> Optional[Any]: + def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Optional[Any]: context = mp.get_context(self.start_method or "fork") - return_queue = context.SimpleQueue() if return_result else None + return_queue = context.SimpleQueue() xmp.spawn(self._wrapped_function, args=(function, args, kwargs, return_queue), **self.get_mp_spawn_kwargs()) - return return_queue.get() if return_result else None + return return_queue.get() def _wrapped_function( - self, process_idx: int, function: Callable, args: Any, kwargs: Any, return_queue: Optional[SimpleQueue] + self, process_idx: int, function: Callable, args: Any, kwargs: Any, return_queue: SimpleQueue ) -> None: self._worker_setup(process_idx) result = function(*args, **kwargs) - if return_queue is not None and self.local_rank == 0: + if self.local_rank == 0: return_queue.put(move_data_to_device(result, "cpu")) self.barrier("end-process") From aa76840fdffad9047ee5ff9a4474df87d9ffb26b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 30 Nov 2021 18:18:05 +0100 Subject: [PATCH 019/104] unify --- .../plugins/training_type/ddp_spawn.py | 2 +- .../plugins/training_type/tpu_spawn.py | 32 ------------------- 2 files changed, 1 insertion(+), 33 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 4eac619a08888..8ab51f3fb0ee2 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -267,7 +267,7 @@ def __collect_rank_zero_results( last_path = None if trainer.state.fn == TrainerFn.FITTING and best_model_path is not None and len(best_model_path) > 0: last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path) - atomic_save(state_dict, last_path) + self.save_checkpoint(state_dict, last_path) # adds the `callback_metrics` to the queue extra = _ExtraQueue() diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index d3f03eb122ebc..da1ed20dd405c 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -207,38 +207,6 @@ def barrier(self, name: Optional[str] = None) -> None: if self.is_distributed: rendezvous(name) - # TODO: this implementation is identical to the one in the super class, up to the self.save() call - def __collect_rank_zero_results( - self, trainer: "pl.Trainer", results: Any - ) -> Optional[Tuple[Optional[str], Optional[str], Any, _ExtraQueue]]: - - checkpoint_callback = trainer.checkpoint_callback - best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None - - # requires to compute the state_dict on all processes in case Metrics are present - state_dict = self.lightning_module.state_dict() - - if self.local_rank != 0: - return - - rank_zero_warn("cleaning up ddp environment...") - - # save the last weights - last_path = None - if trainer.state.fn == TrainerFn.FITTING and best_model_path is not None and len(best_model_path) > 0: - last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path) - self.save(state_dict, last_path) - - # adds the `callback_metrics` to the queue - extra = _ExtraQueue() - if is_overridden("add_to_queue", self.lightning_module): - # TODO: Remove the if in v1.7 - self.lightning_module.add_to_queue(extra) - else: - self.add_to_queue(trainer, extra) - - return best_model_path, last_path, results, extra - def save(self, state_dict: Dict, path: _PATH) -> None: xm.save(state_dict, path) From 37f9db9f9082313c388d0adb45804b112f697917 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 30 Nov 2021 18:18:44 +0100 Subject: [PATCH 020/104] remove dead code --- pytorch_lightning/plugins/training_type/tpu_spawn.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index da1ed20dd405c..ec773f1fc70ac 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -207,9 +207,6 @@ def barrier(self, name: Optional[str] = None) -> None: if self.is_distributed: rendezvous(name) - def save(self, state_dict: Dict, path: _PATH) -> None: - xm.save(state_dict, path) - def broadcast(self, obj: object, src: int = 0) -> object: if not self.is_distributed: return obj From d68cb35abd187c5abf80b29b3b449b34608a03e0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 30 Nov 2021 17:20:17 +0000 Subject: [PATCH 021/104] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/plugins/training_type/tpu_spawn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index ec773f1fc70ac..1fec30cc5fb9a 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -16,7 +16,7 @@ import re import time from multiprocessing.queues import SimpleQueue -from typing import Any, Callable, Dict, List, Optional, Union, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch import torch.multiprocessing as mp @@ -28,7 +28,7 @@ from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin -from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin, _ExtraQueue +from pytorch_lightning.plugins.training_type.ddp_spawn import _ExtraQueue, DDPSpawnPlugin from pytorch_lightning.trainer.connectors.data_connector import DataConnector from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, rank_zero_warn, set_shared_parameters From dd80be94e74fb190c1442b41633bd2a1aa98ea7c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 30 Nov 2021 18:25:05 +0100 Subject: [PATCH 022/104] remove queue from tpu spawn --- .../plugins/training_type/tpu_spawn.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index ec773f1fc70ac..01f3c8dac9134 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -122,7 +122,7 @@ def pre_dispatch(self): os.environ["PT_XLA_DEBUG"] = str(1) def setup(self, trainer: "pl.Trainer") -> None: - self.create_mp_queue() + self.start_method = "fork" if not self.setup_optimizers_in_pre_dispatch: self.setup_optimizers(trainer) self.setup_precision_plugin() @@ -138,11 +138,6 @@ def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None: def _setup_model(self, model: Module) -> Module: return model - def create_mp_queue(self): - self.start_method = "fork" - smp = mp.get_context(self.start_method) - self.mp_queue = smp.SimpleQueue() - @property def distributed_sampler_kwargs(self) -> Dict[str, int]: return dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()) @@ -168,9 +163,7 @@ def init_dist_connection(self, global_rank: int, world_size: int) -> None: def set_world_ranks(self, process_idx: int = 0) -> None: pass - def new_process(self, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None: - self.mp_queue = mp_queue - + def new_process(self, trainer: "pl.Trainer") -> Optional[Tuple[Optional[str], Optional[str], Any, "_ExtraQueue"]]: if self.tpu_global_core_rank != 0 and trainer.progress_bar_callback is not None: trainer.progress_bar_callback.disable() @@ -188,7 +181,7 @@ def new_process(self, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None: results = trainer.run_stage() - self.__collect_rank_zero_results(trainer, results) + outputs = self.__collect_rank_zero_results(trainer, results) # https://github.com/pytorch/xla/issues/1801#issuecomment-602799542 self.barrier("end-process") @@ -199,6 +192,7 @@ def new_process(self, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None: # ensure that spawned processes go through teardown before joining trainer._call_teardown_hook() + return outputs def model_to_device(self) -> None: self.model = self.wrapped_model.to(self.root_device) From f97eee894fa63c217019eb0b7687a6dc99656fc6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 30 Nov 2021 18:25:22 +0100 Subject: [PATCH 023/104] type annotation for new_process --- pytorch_lightning/plugins/training_type/sharded_spawn.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/sharded_spawn.py b/pytorch_lightning/plugins/training_type/sharded_spawn.py index c9a968fa94fbd..b6dff138803ed 100644 --- a/pytorch_lightning/plugins/training_type/sharded_spawn.py +++ b/pytorch_lightning/plugins/training_type/sharded_spawn.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from contextlib import contextmanager -from typing import Dict, Generator, List, Optional, Tuple +from typing import Dict, Generator, List, Optional, Tuple, Any import torch from torch.nn import Module @@ -20,7 +20,7 @@ import pytorch_lightning as pl from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin -from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin +from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin, _ExtraQueue from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, rank_zero_only from pytorch_lightning.utilities.enums import _StrategyType @@ -114,7 +114,7 @@ def pre_backward(self, closure_loss: torch.Tensor) -> None: def post_training_step(self): pass - def new_process(self, trainer: "pl.Trainer") -> None: + def new_process(self, trainer: "pl.Trainer") -> Optional[Tuple[Optional[str], Optional[str], Any, _ExtraQueue]]: # Ensure that the scaler points to the correct process group # which is re-initialized in a new process if isinstance(self.precision_plugin, ShardedNativeMixedPrecisionPlugin): From 459121ebbc5a74d5cb39f7acefb2b66e94f6e97d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 30 Nov 2021 17:27:05 +0000 Subject: [PATCH 024/104] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/plugins/training_type/sharded_spawn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/sharded_spawn.py b/pytorch_lightning/plugins/training_type/sharded_spawn.py index b6dff138803ed..55b9253d3101c 100644 --- a/pytorch_lightning/plugins/training_type/sharded_spawn.py +++ b/pytorch_lightning/plugins/training_type/sharded_spawn.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from contextlib import contextmanager -from typing import Dict, Generator, List, Optional, Tuple, Any +from typing import Any, Dict, Generator, List, Optional, Tuple import torch from torch.nn import Module @@ -20,7 +20,7 @@ import pytorch_lightning as pl from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin -from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin, _ExtraQueue +from pytorch_lightning.plugins.training_type.ddp_spawn import _ExtraQueue, DDPSpawnPlugin from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, rank_zero_only from pytorch_lightning.utilities.enums import _StrategyType From 72535ff34d5de83c1bf87a02a61d383ea338ea28 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 30 Nov 2021 18:30:13 +0100 Subject: [PATCH 025/104] unused imports --- pytorch_lightning/plugins/training_type/ddp_spawn.py | 1 - pytorch_lightning/plugins/training_type/tpu_spawn.py | 4 +--- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 8ab51f3fb0ee2..b9aa02bed397b 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -35,7 +35,6 @@ from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8, rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device -from pytorch_lightning.utilities.cloud_io import atomic_save from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.distributed import distributed_available from pytorch_lightning.utilities.distributed import group as _group diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 0ed9563177274..e81cb5bfbe16b 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -13,7 +13,6 @@ # limitations under the License. import io import os -import re import time from multiprocessing.queues import SimpleQueue from typing import Any, Callable, Dict, List, Optional, Tuple, Union @@ -30,8 +29,7 @@ from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.plugins.training_type.ddp_spawn import _ExtraQueue, DDPSpawnPlugin from pytorch_lightning.trainer.connectors.data_connector import DataConnector -from pytorch_lightning.trainer.states import TrainerFn -from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, rank_zero_warn, set_shared_parameters +from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, set_shared_parameters from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device from pytorch_lightning.utilities.data import has_len from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp From 61192df022c6c95ec6b65c872068c41299a0f2a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 30 Nov 2021 18:59:30 +0100 Subject: [PATCH 026/104] move check --- pytorch_lightning/plugins/training_type/ddp_spawn.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index b9aa02bed397b..370e9ceda1525 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -257,9 +257,6 @@ def __collect_rank_zero_results( # requires to compute the state_dict on all processes in case Metrics are present state_dict = self.lightning_module.state_dict() - if self.local_rank != 0: - return - rank_zero_warn("cleaning up ddp environment...") # save the last weights @@ -268,6 +265,9 @@ def __collect_rank_zero_results( last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path) self.save_checkpoint(state_dict, last_path) + if self.local_rank != 0: + return + # adds the `callback_metrics` to the queue extra = _ExtraQueue() if is_overridden("add_to_queue", self.lightning_module): From 801f529dbc8698d5d3b2a197c364f49fbeae4274 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 30 Nov 2021 23:05:30 +0100 Subject: [PATCH 027/104] revert --- .../plugins/training_type/tpu_spawn.py | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index e81cb5bfbe16b..58585b9b79a52 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -199,6 +199,29 @@ def barrier(self, name: Optional[str] = None) -> None: if self.is_distributed: rendezvous(name) + def __transfer_distrib_spawn_state_on_fit_end(self, trainer: "pl.Trainer", results: Any) -> None: + checkpoint_callback = trainer.checkpoint_callback + best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None + + # requires to compute the state_dict on all processes in case Metrics are present + state_dict = self.lightning_module.state_dict() + + if self.mp_queue is not None: + rank_zero_warn("cleaning up tpu spawn environment...") + + # save the last weights + last_path = None + if trainer.state.fn == TrainerFn.FITTING and best_model_path is not None and len(best_model_path) > 0: + last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path) + self.checkpoint_io.save_checkpoint(state_dict, last_path) + + if self.local_rank == 0: + # todo, pass complete checkpoint as state dictionary + self.mp_queue.put(best_model_path) + self.mp_queue.put(last_path) + self.mp_queue.put(results) + self.lightning_module.add_to_queue(self.mp_queue) # adds the `callback_metrics` to the queue + def broadcast(self, obj: object, src: int = 0) -> object: if not self.is_distributed: return obj From 1cd258b70ad3c9c0869b74e71093f0ff4a24cbbe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 30 Nov 2021 23:13:37 +0100 Subject: [PATCH 028/104] collect results on tpu --- .../plugins/training_type/tpu_spawn.py | 43 +++++++++++-------- 1 file changed, 26 insertions(+), 17 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 58585b9b79a52..64926f00af21a 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -13,6 +13,7 @@ # limitations under the License. import io import os +import re import time from multiprocessing.queues import SimpleQueue from typing import Any, Callable, Dict, List, Optional, Tuple, Union @@ -29,7 +30,8 @@ from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.plugins.training_type.ddp_spawn import _ExtraQueue, DDPSpawnPlugin from pytorch_lightning.trainer.connectors.data_connector import DataConnector -from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, set_shared_parameters +from pytorch_lightning.trainer.states import TrainerFn +from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, set_shared_parameters, rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device from pytorch_lightning.utilities.data import has_len from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp @@ -199,28 +201,35 @@ def barrier(self, name: Optional[str] = None) -> None: if self.is_distributed: rendezvous(name) - def __transfer_distrib_spawn_state_on_fit_end(self, trainer: "pl.Trainer", results: Any) -> None: + def __collect_rank_zero_results( + self, trainer: "pl.Trainer", results: Any + ) -> Optional[Tuple[Optional[str], Optional[str], Any, "_ExtraQueue"]]: checkpoint_callback = trainer.checkpoint_callback best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None # requires to compute the state_dict on all processes in case Metrics are present state_dict = self.lightning_module.state_dict() - if self.mp_queue is not None: - rank_zero_warn("cleaning up tpu spawn environment...") - - # save the last weights - last_path = None - if trainer.state.fn == TrainerFn.FITTING and best_model_path is not None and len(best_model_path) > 0: - last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path) - self.checkpoint_io.save_checkpoint(state_dict, last_path) - - if self.local_rank == 0: - # todo, pass complete checkpoint as state dictionary - self.mp_queue.put(best_model_path) - self.mp_queue.put(last_path) - self.mp_queue.put(results) - self.lightning_module.add_to_queue(self.mp_queue) # adds the `callback_metrics` to the queue + rank_zero_warn("cleaning up tpu spawn environment...") + + # save the last weights + last_path = None + if trainer.state.fn == TrainerFn.FITTING and best_model_path is not None and len(best_model_path) > 0: + last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path) + self.checkpoint_io.save_checkpoint(state_dict, last_path) + + if self.local_rank != 0: + return + + # adds the `callback_metrics` to the queue + extra = _ExtraQueue() + if is_overridden("add_to_queue", self.lightning_module): + # TODO: Remove the if in v1.7 + self.lightning_module.add_to_queue(extra) + else: + self.add_to_queue(trainer, extra) + + return best_model_path, last_path, results, extra def broadcast(self, obj: object, src: int = 0) -> object: if not self.is_distributed: From 10ecbfd0da11d08fc6e12932b9153c2ae1352396 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 30 Nov 2021 23:37:11 +0100 Subject: [PATCH 029/104] rename --- .../plugins/training_type/ddp_spawn.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 3781e40b8f343..e52d403de0ffb 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -134,19 +134,19 @@ def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[st def start_training(self, trainer: "pl.Trainer") -> Any: best_model_path, last_path, results, extra = self.spawn(self.new_process, trainer) - self.__recover_child_process_weights(best_model_path, last_path, extra, trainer) + self.__recover_results_in_main_process(best_model_path, last_path, extra, trainer) # reset optimizers, since main process is never used for training and thus does not have a valid optim state trainer.optimizers = [] return results def start_evaluating(self, trainer: "pl.Trainer") -> None: best_model_path, last_path, results, extra = self.spawn(self.new_process, trainer) - self.__recover_child_process_weights(best_model_path, last_path, extra, trainer) + self.__recover_results_in_main_process(best_model_path, last_path, extra, trainer) return results def start_predicting(self, trainer: "pl.Trainer") -> None: best_model_path, last_path, results, extra = self.spawn(self.new_process, trainer) - self.__recover_child_process_weights(best_model_path, last_path, extra, trainer) + self.__recover_results_in_main_process(best_model_path, last_path, extra, trainer) return results def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Any: @@ -254,11 +254,11 @@ def __collect_rank_zero_results( rank_zero_warn("cleaning up ddp environment...") - # save the last weights - last_path = None - if trainer.state.fn == TrainerFn.FITTING and best_model_path is not None and len(best_model_path) > 0: - last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path) - self.checkpoint_io.save_checkpoint(state_dict, last_path) + # save the last weights + last_path = None + if trainer.state.fn == TrainerFn.FITTING and best_model_path is not None and len(best_model_path) > 0: + last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path) + self.checkpoint_io.save_checkpoint(state_dict, last_path) if self.local_rank != 0: return @@ -273,7 +273,7 @@ def __collect_rank_zero_results( return best_model_path, last_path, results, extra - def __recover_child_process_weights( + def __recover_results_in_main_process( self, best_path: Optional[str], last_path: Optional[str], extra: "_ExtraQueue", trainer ) -> None: # transfer back the best path to the trainer From ebba63f4be7ad1a8f041c735938207adc6968ab9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 30 Nov 2021 22:38:29 +0000 Subject: [PATCH 030/104] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/plugins/training_type/tpu_spawn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index bbfd894db82ca..cf46f2224b437 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -32,7 +32,7 @@ from pytorch_lightning.plugins.training_type.ddp_spawn import _ExtraQueue, DDPSpawnPlugin from pytorch_lightning.trainer.connectors.data_connector import DataConnector from pytorch_lightning.trainer.states import TrainerFn -from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, set_shared_parameters, rank_zero_warn +from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, rank_zero_warn, set_shared_parameters from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device from pytorch_lightning.utilities.data import has_len from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp From d7df4d93be44783a3695b6218d0b931bd39facf9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 30 Nov 2021 23:48:38 +0100 Subject: [PATCH 031/104] fix merge errors --- .../plugins/training_type/ddp_spawn.py | 6 +++--- .../plugins/training_type/tpu_spawn.py | 21 ++++++------------- 2 files changed, 9 insertions(+), 18 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index e52d403de0ffb..393a93ea69a61 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -252,6 +252,9 @@ def __collect_rank_zero_results( # requires to compute the state_dict on all processes in case Metrics are present state_dict = self.lightning_module.state_dict() + if self.local_rank != 0: + return + rank_zero_warn("cleaning up ddp environment...") # save the last weights @@ -260,9 +263,6 @@ def __collect_rank_zero_results( last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path) self.checkpoint_io.save_checkpoint(state_dict, last_path) - if self.local_rank != 0: - return - # adds the `callback_metrics` to the queue extra = _ExtraQueue() if is_overridden("add_to_queue", self.lightning_module): diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index cf46f2224b437..abe297cdb3b58 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -204,7 +204,7 @@ def barrier(self, name: Optional[str] = None) -> None: def __collect_rank_zero_results( self, trainer: "pl.Trainer", results: Any - ) -> Optional[Tuple[Optional[str], Optional[str], Any, "_ExtraQueue"]]: + ) -> Optional[Tuple[Optional[str], Optional[str], Any, _ExtraQueue]]: checkpoint_callback = trainer.checkpoint_callback best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None @@ -213,11 +213,11 @@ def __collect_rank_zero_results( rank_zero_warn("cleaning up tpu spawn environment...") - # save the last weights - last_path = None - if trainer.state.fn == TrainerFn.FITTING and best_model_path is not None and len(best_model_path) > 0: - last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path) - self.checkpoint_io.save_checkpoint(state_dict, last_path) + # save the last weights + last_path = None + if trainer.state.fn == TrainerFn.FITTING and best_model_path is not None and len(best_model_path) > 0: + last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path) + self.checkpoint_io.save_checkpoint(state_dict, last_path) if self.local_rank != 0: return @@ -303,17 +303,8 @@ def start_training(self, trainer: "pl.Trainer") -> None: # todo: precision pluging is call in accelerator setup and should be moved if "XLA_USE_BF16" in os.environ: del os.environ["XLA_USE_BF16"] - self._clean_logger(trainer) return super().start_training(trainer) - def start_evaluating(self, trainer: "pl.Trainer") -> None: - self._clean_logger(trainer) - return super().start_evaluating(trainer) - - def start_predicting(self, trainer: "pl.Trainer") -> None: - self._clean_logger(trainer) - return super().start_predicting(trainer) - def training_step(self, *args, **kwargs): return self.model(*args, **kwargs) From 4c547aa95b9a1cd2669a7e66da9749aecb9ed255 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 30 Nov 2021 23:49:07 +0100 Subject: [PATCH 032/104] fix merge errors --- pytorch_lightning/plugins/training_type/tpu_spawn.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index abe297cdb3b58..2dfe9709cc1cf 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -380,13 +380,3 @@ def checkpoint_io(self) -> CheckpointIO: @checkpoint_io.setter def checkpoint_io(self, plugin: CheckpointIO) -> None: raise MisconfigurationException("TPU Spawn Plugin currently does not support custom checkpoint plugins.") - - @staticmethod - def _clean_logger(trainer: "pl.Trainer") -> None: - loggers = trainer.logger._logger_iterable if isinstance(trainer.logger, LoggerCollection) else [trainer.logger] - for logger in loggers: - if isinstance(logger, TensorBoardLogger) and logger._experiment is not None: - # the experiment class of `TensorBoard` holds a multiprocessing queue which can make ours hang. - # we want to make sure these are closed before we spawn our own threads. - # assuming nothing else references the experiment object, python should instantly `__del__` it. - logger._experiment = None From e4e2a771f195fb445e1916b2b9df02579fdd707b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 1 Dec 2021 04:01:07 +0100 Subject: [PATCH 033/104] re-add clean_logger --- .../plugins/training_type/tpu_spawn.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 2dfe9709cc1cf..ac69b501993ef 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -303,8 +303,17 @@ def start_training(self, trainer: "pl.Trainer") -> None: # todo: precision pluging is call in accelerator setup and should be moved if "XLA_USE_BF16" in os.environ: del os.environ["XLA_USE_BF16"] + self._clean_logger(trainer) return super().start_training(trainer) + def start_evaluating(self, trainer: "pl.Trainer") -> None: + self._clean_logger(trainer) + return super().start_evaluating(trainer) + + def start_predicting(self, trainer: "pl.Trainer") -> None: + self._clean_logger(trainer) + return super().start_predicting(trainer) + def training_step(self, *args, **kwargs): return self.model(*args, **kwargs) @@ -380,3 +389,13 @@ def checkpoint_io(self) -> CheckpointIO: @checkpoint_io.setter def checkpoint_io(self, plugin: CheckpointIO) -> None: raise MisconfigurationException("TPU Spawn Plugin currently does not support custom checkpoint plugins.") + + @staticmethod + def _clean_logger(trainer: "pl.Trainer") -> None: + loggers = trainer.logger._logger_iterable if isinstance(trainer.logger, LoggerCollection) else [trainer.logger] + for logger in loggers: + if isinstance(logger, TensorBoardLogger) and logger._experiment is not None: + # the experiment class of `TensorBoard` holds a multiprocessing queue which can make ours hang. + # we want to make sure these are closed before we spawn our own threads. + # assuming nothing else references the experiment object, python should instantly `__del__` it. + logger._experiment = None From acac29db559dc8a29d5295c2c93be7d1f54e7d14 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 1 Dec 2021 15:28:38 +0100 Subject: [PATCH 034/104] fix typing --- pytorch_lightning/plugins/training_type/ddp_spawn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 393a93ea69a61..9a7439c42990d 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -149,7 +149,7 @@ def start_predicting(self, trainer: "pl.Trainer") -> None: self.__recover_results_in_main_process(best_model_path, last_path, extra, trainer) return results - def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Any: + def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Optional[Any]: """Spawn processes that run the given function. Args: From 880c8fc8db0a013cb342004fea96ab7cda42e821 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 1 Dec 2021 17:28:09 +0100 Subject: [PATCH 035/104] changelog entries --- CHANGELOG.md | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 629b28e392792..0e1a7993aa687 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -80,7 +80,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Moved `batch_to_device` method from `Accelerator` to `TrainingTypePlugin` ([#10649](https://github.com/PyTorchLightning/pytorch-lightning/pull/10649)) -- +- The `DDPSpawnPlugin` no longer overrides the `post_dispatch` plugin hook ([#10034](https://github.com/PyTorchLightning/pytorch-lightning/pull/10034)) + + +- The `LightningModule.{add_to_queue,get_from_queue}` hooks no longer get a `torch.multiprocessing.SimpleQueue` and instead receive a list based queue ([#10034](https://github.com/PyTorchLightning/pytorch-lightning/pull/10034)) + ### Deprecated @@ -188,6 +192,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Removed argument `return_result` from the `DDPSpawnPlugin.spawn()` method ([#10867](https://github.com/PyTorchLightning/pytorch-lightning/pull/10867)) +- Removed the property `TrainingTypePlugin.results` and corresponding properties in subclasses ([#10034](https://github.com/PyTorchLightning/pytorch-lightning/pull/10034)) + + +- Removed the `mp_queue` attribute from `DDPSpawnPlugin` and `TPUSpawnPlugin` ([#10034](https://github.com/PyTorchLightning/pytorch-lightning/pull/10034)) + + ### Fixed - Fixed an issue with `SignalConnector` not restoring the default signal handlers on teardown when running on SLURM or with fault-tolerant training enabled ([#10611](https://github.com/PyTorchLightning/pytorch-lightning/pull/10611)) From 7520adcce70d9abdbe23a7a28059a9f68a1e49e4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 1 Dec 2021 17:34:21 +0000 Subject: [PATCH 036/104] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/plugins/training_type/tpu_spawn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 9396c19060833..1f73595c37418 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -217,7 +217,7 @@ def __collect_rank_zero_results( if trainer.state.fn == TrainerFn.FITTING and best_model_path is not None and len(best_model_path) > 0: last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path) self.checkpoint_io.save_checkpoint(state_dict, last_path) - + # We use `local_rank` here as separate filesystems are used for each VM for TPU Pod Training if self.local_rank != 0: return From 96f2749cea3d9dbf8dd0fe0b6509642fd01f0b24 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 1 Dec 2021 21:11:53 +0100 Subject: [PATCH 037/104] rename _ExtraQueue -> _FakeQueue --- pytorch_lightning/core/lightning.py | 4 ++-- .../plugins/training_type/ddp_spawn.py | 14 +++++++------- .../plugins/training_type/sharded_spawn.py | 4 ++-- .../plugins/training_type/tpu_spawn.py | 8 ++++---- 4 files changed, 15 insertions(+), 15 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 6ebc320a12e19..e02c9d32ecb80 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1917,7 +1917,7 @@ def model_size(self) -> float: ) return get_model_size_mb(self) - def add_to_queue(self, queue: pl.plugins.training_type.ddp_spawn._ExtraQueue) -> None: + def add_to_queue(self, queue: pl.plugins.training_type.ddp_spawn._FakeQueue) -> None: """Appends the :attr:`trainer.callback_metrics` dictionary to the given queue. To avoid issues with memory sharing, we cast the data to numpy. @@ -1931,7 +1931,7 @@ def add_to_queue(self, queue: pl.plugins.training_type.ddp_spawn._ExtraQueue) -> if self.trainer and isinstance(self.trainer.training_type_plugin, pl.plugins.training_type.DDPSpawnPlugin): self.trainer.training_type_plugin.add_to_queue(self.trainer, queue) - def get_from_queue(self, queue: pl.plugins.training_type.ddp_spawn._ExtraQueue) -> None: + def get_from_queue(self, queue: pl.plugins.training_type.ddp_spawn._FakeQueue) -> None: """Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. To preserve consistency, we cast back the data to ``torch.Tensor``. diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 89250120a3711..7003c2617037f 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -184,7 +184,7 @@ def _worker_setup(self, process_idx: int): self.cluster_environment, self.torch_distributed_backend, self.global_rank, self.world_size ) - def new_process(self, trainer: "pl.Trainer") -> Optional[Tuple[Optional[str], Optional[str], Any, "_ExtraQueue"]]: + def new_process(self, trainer: "pl.Trainer") -> Optional[Tuple[Optional[str], Optional[str], Any, "_FakeQueue"]]: # move the model to the correct device self.model_to_device() @@ -244,7 +244,7 @@ def determine_ddp_device_ids(self): def __collect_rank_zero_results( self, trainer: "pl.Trainer", results: Any - ) -> Optional[Tuple[Optional[str], Optional[str], Any, "_ExtraQueue"]]: + ) -> Optional[Tuple[Optional[str], Optional[str], Any, "_FakeQueue"]]: rank_zero_warn("cleaning up ddp environment...") checkpoint_callback = trainer.checkpoint_callback best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None @@ -262,7 +262,7 @@ def __collect_rank_zero_results( self.checkpoint_io.save_checkpoint(state_dict, last_path) # adds the `callback_metrics` to the queue - extra = _ExtraQueue() + extra = _FakeQueue() if is_overridden("add_to_queue", self.lightning_module): # TODO: Remove the if in v1.7 self.lightning_module.add_to_queue(extra) @@ -272,7 +272,7 @@ def __collect_rank_zero_results( return best_model_path, last_path, results, extra def __recover_results_in_main_process( - self, best_path: Optional[str], last_path: Optional[str], extra: "_ExtraQueue", trainer + self, best_path: Optional[str], last_path: Optional[str], extra: "_FakeQueue", trainer ) -> None: # transfer back the best path to the trainer if self.lightning_module.trainer.checkpoint_callback: @@ -357,7 +357,7 @@ def post_training_step(self): if not self.lightning_module.automatic_optimization: self.model.require_backward_grad_sync = True - def add_to_queue(self, trainer: "pl.Trainer", queue: "_ExtraQueue") -> None: + def add_to_queue(self, trainer: "pl.Trainer", queue: "_FakeQueue") -> None: """Appends the :attr:`trainer.callback_metrics` dictionary to the given queue. To avoid issues with memory sharing, we cast the data to numpy. @@ -370,7 +370,7 @@ def add_to_queue(self, trainer: "pl.Trainer", queue: "_ExtraQueue") -> None: ) # send as numpy to avoid issues with memory sharing queue.put(callback_metrics) - def get_from_queue(self, trainer: "pl.Trainer", queue: "_ExtraQueue") -> None: + def get_from_queue(self, trainer: "pl.Trainer", queue: "_FakeQueue") -> None: """Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. To preserve consistency, we cast back the data to ``torch.Tensor``. @@ -402,7 +402,7 @@ def teardown(self) -> None: torch.cuda.empty_cache() -class _ExtraQueue(list): +class _FakeQueue(list): """Simulates a :class:`torch.multiprocessing.queue.SimpleQueue` interface using the Python list.""" def get(self) -> Any: diff --git a/pytorch_lightning/plugins/training_type/sharded_spawn.py b/pytorch_lightning/plugins/training_type/sharded_spawn.py index 55b9253d3101c..5e10155cc3ca5 100644 --- a/pytorch_lightning/plugins/training_type/sharded_spawn.py +++ b/pytorch_lightning/plugins/training_type/sharded_spawn.py @@ -20,7 +20,7 @@ import pytorch_lightning as pl from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin -from pytorch_lightning.plugins.training_type.ddp_spawn import _ExtraQueue, DDPSpawnPlugin +from pytorch_lightning.plugins.training_type.ddp_spawn import _FakeQueue, DDPSpawnPlugin from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, rank_zero_only from pytorch_lightning.utilities.enums import _StrategyType @@ -114,7 +114,7 @@ def pre_backward(self, closure_loss: torch.Tensor) -> None: def post_training_step(self): pass - def new_process(self, trainer: "pl.Trainer") -> Optional[Tuple[Optional[str], Optional[str], Any, _ExtraQueue]]: + def new_process(self, trainer: "pl.Trainer") -> Optional[Tuple[Optional[str], Optional[str], Any, _FakeQueue]]: # Ensure that the scaler points to the correct process group # which is re-initialized in a new process if isinstance(self.precision_plugin, ShardedNativeMixedPrecisionPlugin): diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 1f73595c37418..4dcdb589150ca 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -29,7 +29,7 @@ from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin -from pytorch_lightning.plugins.training_type.ddp_spawn import _ExtraQueue, DDPSpawnPlugin +from pytorch_lightning.plugins.training_type.ddp_spawn import _FakeQueue, DDPSpawnPlugin from pytorch_lightning.trainer.connectors.data_connector import DataConnector from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, rank_zero_warn, set_shared_parameters @@ -164,7 +164,7 @@ def init_dist_connection(self, global_rank: int, world_size: int) -> None: def set_world_ranks(self, process_idx: int = 0) -> None: pass - def new_process(self, trainer: "pl.Trainer") -> Optional[Tuple[Optional[str], Optional[str], Any, "_ExtraQueue"]]: + def new_process(self, trainer: "pl.Trainer") -> Optional[Tuple[Optional[str], Optional[str], Any, "_FakeQueue"]]: if self.tpu_global_core_rank != 0 and trainer.progress_bar_callback is not None: trainer.progress_bar_callback.disable() @@ -204,7 +204,7 @@ def barrier(self, name: Optional[str] = None) -> None: def __collect_rank_zero_results( self, trainer: "pl.Trainer", results: Any - ) -> Optional[Tuple[Optional[str], Optional[str], Any, _ExtraQueue]]: + ) -> Optional[Tuple[Optional[str], Optional[str], Any, _FakeQueue]]: rank_zero_warn("cleaning up tpu spawn environment...") checkpoint_callback = trainer.checkpoint_callback best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None @@ -223,7 +223,7 @@ def __collect_rank_zero_results( return # adds the `callback_metrics` to the queue - extra = _ExtraQueue() + extra = _FakeQueue() if is_overridden("add_to_queue", self.lightning_module): # TODO: Remove the if in v1.7 self.lightning_module.add_to_queue(extra) From 65d183c25113fcc43334865a57c092bbdd3cd841 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 1 Dec 2021 21:12:56 +0100 Subject: [PATCH 038/104] missing typing updates --- pytorch_lightning/plugins/training_type/ddp_spawn.py | 4 ++-- pytorch_lightning/plugins/training_type/tpu_spawn.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 7003c2617037f..3e9840e33c4ca 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -139,12 +139,12 @@ def start_training(self, trainer: "pl.Trainer") -> Any: trainer.optimizers = [] return results - def start_evaluating(self, trainer: "pl.Trainer") -> None: + def start_evaluating(self, trainer: "pl.Trainer") -> Any: best_model_path, last_path, results, extra = self.spawn(self.new_process, trainer) self.__recover_results_in_main_process(best_model_path, last_path, extra, trainer) return results - def start_predicting(self, trainer: "pl.Trainer") -> None: + def start_predicting(self, trainer: "pl.Trainer") -> Any: best_model_path, last_path, results, extra = self.spawn(self.new_process, trainer) self.__recover_results_in_main_process(best_model_path, last_path, extra, trainer) return results diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 4dcdb589150ca..7679fbffa8e50 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -299,18 +299,18 @@ def _worker_setup(self, process_idx: int): self.tpu_global_core_rank = xm.get_ordinal() rank_zero_only.rank = self.global_rank - def start_training(self, trainer: "pl.Trainer") -> None: + def start_training(self, trainer: "pl.Trainer") -> Any: # todo: precision pluging is call in accelerator setup and should be moved if "XLA_USE_BF16" in os.environ: del os.environ["XLA_USE_BF16"] self._clean_logger(trainer) return super().start_training(trainer) - def start_evaluating(self, trainer: "pl.Trainer") -> None: + def start_evaluating(self, trainer: "pl.Trainer") -> Any: self._clean_logger(trainer) return super().start_evaluating(trainer) - def start_predicting(self, trainer: "pl.Trainer") -> None: + def start_predicting(self, trainer: "pl.Trainer") -> Any: self._clean_logger(trainer) return super().start_predicting(trainer) From 8c4e2e49a229794846c015fe414c6e04a4fce8d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 1 Dec 2021 22:25:35 +0100 Subject: [PATCH 039/104] Introducing NamedTuple for spawn output typing --- .../plugins/training_type/ddp_spawn.py | 57 ++++++++++--------- .../plugins/training_type/tpu_spawn.py | 10 ++-- 2 files changed, 35 insertions(+), 32 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 3e9840e33c4ca..f42edd0ae7763 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -15,7 +15,7 @@ import os import re from multiprocessing.queues import SimpleQueue -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Union, NamedTuple import numpy as np import torch @@ -45,7 +45,7 @@ from pytorch_lightning.utilities.enums import _StrategyType from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.seed import reset_seed -from pytorch_lightning.utilities.types import STEP_OUTPUT +from pytorch_lightning.utilities.types import STEP_OUTPUT, _PATH if _TORCH_GREATER_EQUAL_1_8: from pytorch_lightning.utilities.distributed import register_ddp_comm_hook @@ -133,23 +133,23 @@ def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[st return {"nprocs": self.num_processes} def start_training(self, trainer: "pl.Trainer") -> Any: - best_model_path, last_path, results, extra = self.spawn(self.new_process, trainer) - self.__recover_results_in_main_process(best_model_path, last_path, extra, trainer) + spawn_output: _SpawnOutput = self.spawn(self.new_process, trainer) + self.__recover_results_in_main_process(spawn_output, trainer) # reset optimizers, since main process is never used for training and thus does not have a valid optim state trainer.optimizers = [] - return results + return spawn_output.trainer_results def start_evaluating(self, trainer: "pl.Trainer") -> Any: - best_model_path, last_path, results, extra = self.spawn(self.new_process, trainer) - self.__recover_results_in_main_process(best_model_path, last_path, extra, trainer) - return results + spawn_output: _SpawnOutput = self.spawn(self.new_process, trainer) + self.__recover_results_in_main_process(spawn_output, trainer) + return spawn_output.trainer_results def start_predicting(self, trainer: "pl.Trainer") -> Any: - best_model_path, last_path, results, extra = self.spawn(self.new_process, trainer) - self.__recover_results_in_main_process(best_model_path, last_path, extra, trainer) - return results + spawn_output: _SpawnOutput = self.spawn(self.new_process, trainer) + self.__recover_results_in_main_process(spawn_output, trainer) + return spawn_output.trainer_results - def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Optional[Any]: + def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Optional[Union[Any, "_SpawnOutput"]]: """Spawn processes that run the given function. Args: @@ -184,7 +184,7 @@ def _worker_setup(self, process_idx: int): self.cluster_environment, self.torch_distributed_backend, self.global_rank, self.world_size ) - def new_process(self, trainer: "pl.Trainer") -> Optional[Tuple[Optional[str], Optional[str], Any, "_FakeQueue"]]: + def new_process(self, trainer: "pl.Trainer") -> Optional["_SpawnOutput"]: # move the model to the correct device self.model_to_device() @@ -242,9 +242,7 @@ def determine_ddp_device_ids(self): return None return [self.root_device.index] - def __collect_rank_zero_results( - self, trainer: "pl.Trainer", results: Any - ) -> Optional[Tuple[Optional[str], Optional[str], Any, "_FakeQueue"]]: + def __collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_SpawnOutput"]: rank_zero_warn("cleaning up ddp environment...") checkpoint_callback = trainer.checkpoint_callback best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None @@ -269,28 +267,28 @@ def __collect_rank_zero_results( else: self.add_to_queue(trainer, extra) - return best_model_path, last_path, results, extra + return _SpawnOutput(best_model_path, last_path, results, extra) - def __recover_results_in_main_process( - self, best_path: Optional[str], last_path: Optional[str], extra: "_FakeQueue", trainer - ) -> None: + def __recover_results_in_main_process(self, spawn_output: "_SpawnOutput", trainer) -> None: # transfer back the best path to the trainer if self.lightning_module.trainer.checkpoint_callback: - self.lightning_module.trainer.checkpoint_callback.best_model_path = best_path - # todo, pass also best score + self.lightning_module.trainer.checkpoint_callback.best_model_path = spawn_output.best_model_path + # TODO: pass also best score # load last weights - if last_path is not None and self.lightning_module.trainer.state.fn == TrainerFn.FITTING: - ckpt = self.checkpoint_io.load_checkpoint(last_path, map_location=(lambda storage, loc: storage)) + if spawn_output.last_path is not None and self.lightning_module.trainer.state.fn == TrainerFn.FITTING: + ckpt = self.checkpoint_io.load_checkpoint( + spawn_output.last_path, map_location=(lambda storage, loc: storage) + ) self.lightning_module.load_state_dict(ckpt) # get the `callback_metrics` and set it to the trainer if is_overridden("get_from_queue", self.lightning_module): # only in case the user does not override it. # TODO: Remove the if in v1.7 - self.lightning_module.get_from_queue(extra) + self.lightning_module.get_from_queue(spawn_output.extra) else: - self.get_from_queue(trainer, extra) + self.get_from_queue(trainer, spawn_output.extra) def barrier(self, *args, **kwargs) -> None: if not distributed_available(): @@ -413,3 +411,10 @@ def put(self, item: Any) -> None: def empty(self) -> bool: return len(self) == 0 + + +class _SpawnOutput(NamedTuple): + best_model_path: Optional[_PATH] + last_path: Optional[_PATH] + trainer_results: Any + extra: _FakeQueue diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 7679fbffa8e50..73b6e9f8a39b9 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -29,7 +29,7 @@ from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin -from pytorch_lightning.plugins.training_type.ddp_spawn import _FakeQueue, DDPSpawnPlugin +from pytorch_lightning.plugins.training_type.ddp_spawn import _FakeQueue, DDPSpawnPlugin, _SpawnOutput from pytorch_lightning.trainer.connectors.data_connector import DataConnector from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, rank_zero_warn, set_shared_parameters @@ -202,9 +202,7 @@ def barrier(self, name: Optional[str] = None) -> None: if self.is_distributed: rendezvous(name) - def __collect_rank_zero_results( - self, trainer: "pl.Trainer", results: Any - ) -> Optional[Tuple[Optional[str], Optional[str], Any, _FakeQueue]]: + def __collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_SpawnOutput"]: rank_zero_warn("cleaning up tpu spawn environment...") checkpoint_callback = trainer.checkpoint_callback best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None @@ -230,7 +228,7 @@ def __collect_rank_zero_results( else: self.add_to_queue(trainer, extra) - return best_model_path, last_path, results, extra + return _SpawnOutput(best_model_path, last_path, results, extra) def broadcast(self, obj: object, src: int = 0) -> object: if not self.is_distributed: @@ -274,7 +272,7 @@ def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[st "start_method": self.start_method, } - def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Optional[Any]: + def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Optional[Union[Any, "_SpawnOutput"]]: context = mp.get_context(self.start_method or "fork") return_queue = context.SimpleQueue() xmp.spawn(self._wrapped_function, args=(function, args, kwargs, return_queue), **self.get_mp_spawn_kwargs()) From 213b447278153e2eb52badd535933bcb551d5e2d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 1 Dec 2021 21:27:59 +0000 Subject: [PATCH 040/104] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/plugins/training_type/ddp_spawn.py | 4 ++-- pytorch_lightning/plugins/training_type/tpu_spawn.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index f42edd0ae7763..7620329b60e7b 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -15,7 +15,7 @@ import os import re from multiprocessing.queues import SimpleQueue -from typing import Any, Callable, Dict, List, Optional, Union, NamedTuple +from typing import Any, Callable, Dict, List, NamedTuple, Optional, Union import numpy as np import torch @@ -45,7 +45,7 @@ from pytorch_lightning.utilities.enums import _StrategyType from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.seed import reset_seed -from pytorch_lightning.utilities.types import STEP_OUTPUT, _PATH +from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT if _TORCH_GREATER_EQUAL_1_8: from pytorch_lightning.utilities.distributed import register_ddp_comm_hook diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 73b6e9f8a39b9..b0284c88d6566 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -29,7 +29,7 @@ from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin -from pytorch_lightning.plugins.training_type.ddp_spawn import _FakeQueue, DDPSpawnPlugin, _SpawnOutput +from pytorch_lightning.plugins.training_type.ddp_spawn import _FakeQueue, _SpawnOutput, DDPSpawnPlugin from pytorch_lightning.trainer.connectors.data_connector import DataConnector from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, rank_zero_warn, set_shared_parameters From de4617fbb887d4b5c8812c4622bc393106d14ad0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 2 Dec 2021 05:00:27 +0100 Subject: [PATCH 041/104] remove post_dispatch --- pytorch_lightning/plugins/training_type/ddp.py | 4 +--- pytorch_lightning/plugins/training_type/ddp_spawn.py | 1 + pytorch_lightning/plugins/training_type/dp.py | 1 + pytorch_lightning/plugins/training_type/horovod.py | 1 + pytorch_lightning/plugins/training_type/ipu.py | 1 + pytorch_lightning/plugins/training_type/parallel.py | 4 ++++ pytorch_lightning/plugins/training_type/tpu_spawn.py | 2 +- pytorch_lightning/trainer/trainer.py | 12 +++++++----- 8 files changed, 17 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 6d1b168d5ac7a..5f6517f89c37f 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -355,9 +355,6 @@ def pre_dispatch(self): if trainer_fn == TrainerFn.FITTING: self.configure_ddp() - def post_dispatch(self, trainer: "pl.Trainer") -> None: - self.cluster_environment.teardown() - def barrier(self, *args, **kwargs) -> None: if not distributed_available(): return @@ -495,6 +492,7 @@ def reconciliate_processes(self, trace: str) -> None: raise DeadlockDetectedException(f"DeadLock detected from rank: {self.global_rank} \n {trace}") def teardown(self) -> None: + super().teardown() if isinstance(self.model, DistributedDataParallel): self.model = self.lightning_module diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 7620329b60e7b..ceebec58f6ac9 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -390,6 +390,7 @@ def register_plugins(cls, plugin_registry: Dict) -> None: ) def teardown(self) -> None: + super().teardown() if isinstance(self.model, DistributedDataParallel): self.model = self.lightning_module diff --git a/pytorch_lightning/plugins/training_type/dp.py b/pytorch_lightning/plugins/training_type/dp.py index c12068c025860..32f19f51fd1b5 100644 --- a/pytorch_lightning/plugins/training_type/dp.py +++ b/pytorch_lightning/plugins/training_type/dp.py @@ -146,6 +146,7 @@ def test_step_end(self, output): return output def teardown(self) -> None: + super().teardown() if self.on_gpu: # GPU teardown self.lightning_module.cpu() diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py index 4aef238abb5db..82dbc0641665f 100644 --- a/pytorch_lightning/plugins/training_type/horovod.py +++ b/pytorch_lightning/plugins/training_type/horovod.py @@ -217,6 +217,7 @@ def _filter_named_parameters(model: nn.Module, optimizer: Optimizer) -> List[Tup return [(name, p) for name, p in model.named_parameters() if p in opt_params] def teardown(self) -> None: + super().teardown() if self.on_gpu: # GPU teardown self.lightning_module.cpu() diff --git a/pytorch_lightning/plugins/training_type/ipu.py b/pytorch_lightning/plugins/training_type/ipu.py index b072ea8437ea8..5874bacebce8d 100644 --- a/pytorch_lightning/plugins/training_type/ipu.py +++ b/pytorch_lightning/plugins/training_type/ipu.py @@ -272,6 +272,7 @@ def predict_step(self, *args, **kwargs): return self._step(RunningStage.PREDICTING, *args, **kwargs) def teardown(self) -> None: + super().teardown() # undo dataloader patching pl.trainer.data_loading._update_dataloader = self._update_dataloader_original diff --git a/pytorch_lightning/plugins/training_type/parallel.py b/pytorch_lightning/plugins/training_type/parallel.py index 3a05455b87990..b4bf96a3a8861 100644 --- a/pytorch_lightning/plugins/training_type/parallel.py +++ b/pytorch_lightning/plugins/training_type/parallel.py @@ -132,3 +132,7 @@ def block_backward_sync(self): yield None else: yield None + + def teardown(self) -> None: + super().teardown() + self.cluster_environment.teardown() diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index b0284c88d6566..c3302d3183d0e 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -368,7 +368,7 @@ def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_gra return xm.all_gather(tensor) def teardown(self) -> None: - # TPU teardown + super().teardown() os.environ.pop("PT_XLA_DEBUG", None) self.barrier("teardown") diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 8ffc4d41c86e7..ce4301a12b25a 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1161,7 +1161,7 @@ def _run( # dispatch `start_training` or `start_evaluating` or `start_predicting` results = self._dispatch() - self._post_dispatch() + self._teardown() # ---------------------------- # POST-Training CLEAN UP @@ -1222,10 +1222,12 @@ def _log_hyperparams(self) -> None: self.logger.log_graph(self.lightning_module) self.logger.save() - def _post_dispatch(self): - self.accelerator.post_dispatch(self) - # these `teardown` calls are here instead of in `_call_teardown_hook` since they are internal teardowns - # which need to happen before. + def _teardown(self): + """This is the Trainer's internal teardown, unrelated to the `teardown` hooks in LightningModule and Callback. + Those are handled by :meth:`_call_teardown_hook`. + """ + self.training_type_plugin.teardown() + # TODO: once accelerator is part of TTP, call teardown in TTP's teardown() method self.accelerator.teardown() self._data_connector.teardown() self._active_loop.teardown() From 815172efdee408dd96c9bbcf71373309f719e78a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 2 Dec 2021 05:52:35 +0100 Subject: [PATCH 042/104] step 1 --- .../plugins/training_type/ddp_spawn.py | 37 ++++--------------- pytorch_lightning/trainer/trainer.py | 33 ++++++++--------- 2 files changed, 23 insertions(+), 47 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index ceebec58f6ac9..95b551eeee8cc 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -132,23 +132,6 @@ def set_world_ranks(self, process_idx: int = 0) -> None: def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[str, Any]: return {"nprocs": self.num_processes} - def start_training(self, trainer: "pl.Trainer") -> Any: - spawn_output: _SpawnOutput = self.spawn(self.new_process, trainer) - self.__recover_results_in_main_process(spawn_output, trainer) - # reset optimizers, since main process is never used for training and thus does not have a valid optim state - trainer.optimizers = [] - return spawn_output.trainer_results - - def start_evaluating(self, trainer: "pl.Trainer") -> Any: - spawn_output: _SpawnOutput = self.spawn(self.new_process, trainer) - self.__recover_results_in_main_process(spawn_output, trainer) - return spawn_output.trainer_results - - def start_predicting(self, trainer: "pl.Trainer") -> Any: - spawn_output: _SpawnOutput = self.spawn(self.new_process, trainer) - self.__recover_results_in_main_process(spawn_output, trainer) - return spawn_output.trainer_results - def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Optional[Union[Any, "_SpawnOutput"]]: """Spawn processes that run the given function. @@ -184,7 +167,7 @@ def _worker_setup(self, process_idx: int): self.cluster_environment, self.torch_distributed_backend, self.global_rank, self.world_size ) - def new_process(self, trainer: "pl.Trainer") -> Optional["_SpawnOutput"]: + def pre_dispatch(self) -> None: # move the model to the correct device self.model_to_device() @@ -196,15 +179,9 @@ def new_process(self, trainer: "pl.Trainer") -> Optional["_SpawnOutput"]: if trainer_fn == TrainerFn.FITTING: self.configure_ddp() + # TODO: needed? self.barrier() - results = trainer.run_stage() - outputs = self.__collect_rank_zero_results(trainer, results) - - # ensure that spawned processes go through teardown before joining - trainer._call_teardown_hook() - return outputs - def pre_configure_ddp(self): # if unset, default `find_unused_parameters` `True` # Many models require setting this parameter to True, as there are corner cases @@ -242,7 +219,7 @@ def determine_ddp_device_ids(self): return None return [self.root_device.index] - def __collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_SpawnOutput"]: + def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_SpawnOutput"]: rank_zero_warn("cleaning up ddp environment...") checkpoint_callback = trainer.checkpoint_callback best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None @@ -269,14 +246,14 @@ def __collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Op return _SpawnOutput(best_model_path, last_path, results, extra) - def __recover_results_in_main_process(self, spawn_output: "_SpawnOutput", trainer) -> None: + def _recover_results_in_main_process(self, spawn_output: "_SpawnOutput", trainer: "pl.Trainer") -> None: # transfer back the best path to the trainer - if self.lightning_module.trainer.checkpoint_callback: - self.lightning_module.trainer.checkpoint_callback.best_model_path = spawn_output.best_model_path + if trainer.checkpoint_callback: + trainer.checkpoint_callback.best_model_path = spawn_output.best_model_path # TODO: pass also best score # load last weights - if spawn_output.last_path is not None and self.lightning_module.trainer.state.fn == TrainerFn.FITTING: + if spawn_output.last_path is not None and trainer.state.fn == TrainerFn.FITTING: ckpt = self.checkpoint_io.load_checkpoint( spawn_output.last_path, map_location=(lambda storage, loc: storage) ) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index ce4301a12b25a..30e44d14932a9 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -19,6 +19,7 @@ import warnings from argparse import ArgumentParser, Namespace from datetime import timedelta +from functools import partial from pathlib import Path from typing import Any, Callable, cast, Dict, Iterable, List, Optional, Tuple, Union from weakref import proxy @@ -48,6 +49,7 @@ TrainingTypePlugin, ) from pytorch_lightning.plugins.environments.slurm_environment import SLURMEnvironment +from pytorch_lightning.plugins.training_type.ddp_spawn import _SpawnOutput from pytorch_lightning.profiler import ( AdvancedProfiler, BaseProfiler, @@ -722,9 +724,15 @@ def fit( datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`. """ - self._call_and_handle_interrupt( - self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path - ) + function = partial(self._call_and_handle_interrupt, trainer_fn=self._fit_impl) + args = (model, train_dataloaders, val_dataloaders, datamodule, ckpt_path) + if isinstance(self.training_type_plugin, DDPSpawnPlugin): + spawn_output: _SpawnOutput = self.training_type_plugin.spawn(function, *args) + self.training_type_plugin._recover_results_in_main_process(spawn_output, self) + output = spawn_output.trainer_results + else: + output = function(*args) + return output def _fit_impl( self, @@ -1158,8 +1166,10 @@ def _run( self.checkpoint_connector.resume_end() - # dispatch `start_training` or `start_evaluating` or `start_predicting` - results = self._dispatch() + results = self.run_stage() + + if isinstance(self.training_type_plugin, DDPSpawnPlugin): + results = self.training_type_plugin._collect_rank_zero_results(self, results) self._teardown() @@ -1170,10 +1180,7 @@ def _run( if self.state.fn == TrainerFn.FITTING: self.call_hook("on_fit_end") - # teardown if necessary (similar calls for spawn plugins are excluded as they have - # been included at the end of `new_process` functions) - if not isinstance(self.training_type_plugin, DDPSpawnPlugin): - self._call_teardown_hook() + self._call_teardown_hook() if self.state.status != TrainerStatus.INTERRUPTED: self.state.status = TrainerStatus.FINISHED @@ -1234,14 +1241,6 @@ def _teardown(self): self.logger_connector.teardown() self.signal_connector.teardown() - def _dispatch(self) -> Any: - if self.evaluating: - return self.training_type_plugin.start_evaluating(self) - elif self.predicting: - return self.training_type_plugin.start_predicting(self) - else: - return self.training_type_plugin.start_training(self) - def run_stage(self): self.accelerator.dispatch(self) self.__setup_profiler() From be735bd774060ea8c6277a471047b2eb4624098f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 2 Dec 2021 05:55:54 +0100 Subject: [PATCH 043/104] update flow --- pytorch_lightning/trainer/trainer.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 30e44d14932a9..cd8a1df99fc74 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1123,16 +1123,11 @@ def _run( {Trainer.fit} or {Trainer.test} or {Trainer.predict} || | || create accelerator || - | || - {self._dispatch} || | || LIGHTNING - {self.training_type_plugin.start_training} || - or {self.training_type_plugin.start_evaluating} || - or {self.training_type_plugin.start_predicting} || FLOW - | || - {self.run_stage} || - | || DIRECTION - {self._run_train} || + | || + {self.run_stage} || FLOW + | || + {self._run_train} || DIRECTION or {self._run_evaluate} || or {self._run_predict} || | || From 2879ccb1aa15442e5baef53462c68f543bbbadea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 2 Dec 2021 06:06:25 +0100 Subject: [PATCH 044/104] fix it --- pl_examples/bug_report/bug_report_model.py | 4 +++- pytorch_lightning/trainer/trainer.py | 5 +++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/pl_examples/bug_report/bug_report_model.py b/pl_examples/bug_report/bug_report_model.py index 7739630237d32..c5494071130fc 100644 --- a/pl_examples/bug_report/bug_report_model.py +++ b/pl_examples/bug_report/bug_report_model.py @@ -57,9 +57,11 @@ def run(): num_sanity_val_steps=0, max_epochs=1, enable_model_summary=False, + accelerator="cpu", + strategy="ddp_spawn", + num_processes=2, ) trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data) - trainer.test(model, dataloaders=test_data) if __name__ == "__main__": diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index cd8a1df99fc74..3a00573f18e9c 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -724,7 +724,7 @@ def fit( datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`. """ - function = partial(self._call_and_handle_interrupt, trainer_fn=self._fit_impl) + function = partial(self._call_and_handle_interrupt, self._fit_impl) args = (model, train_dataloaders, val_dataloaders, datamodule, ckpt_path) if isinstance(self.training_type_plugin, DDPSpawnPlugin): spawn_output: _SpawnOutput = self.training_type_plugin.spawn(function, *args) @@ -765,10 +765,11 @@ def _fit_impl( # TODO: ckpt_path only in v1.7 ckpt_path = ckpt_path or self.resume_from_checkpoint - self._run(model, ckpt_path=ckpt_path) + output = self._run(model, ckpt_path=ckpt_path) assert self.state.stopped self.training = False + return output def validate( self, From ace196e1f915936fed00e0aebc36bb9edfedbaee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 2 Dec 2021 06:26:49 +0100 Subject: [PATCH 045/104] jackpot! --- pl_examples/bug_report/bug_report_model.py | 3 ++ pytorch_lightning/trainer/trainer.py | 40 ++++++++++++++++------ 2 files changed, 33 insertions(+), 10 deletions(-) diff --git a/pl_examples/bug_report/bug_report_model.py b/pl_examples/bug_report/bug_report_model.py index c5494071130fc..7cbc797af49df 100644 --- a/pl_examples/bug_report/bug_report_model.py +++ b/pl_examples/bug_report/bug_report_model.py @@ -27,6 +27,7 @@ def forward(self, x): return self.layer(x) def training_step(self, batch, batch_idx): + print(self.global_rank, self.global_step) loss = self(batch).sum() self.log("train_loss", loss) return {"loss": loss} @@ -62,6 +63,8 @@ def run(): num_processes=2, ) trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data) + trainer.validate(model, dataloaders=val_data) + trainer.test(model, dataloaders=test_data) if __name__ == "__main__": diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 3a00573f18e9c..b1db8acbd9d90 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -729,10 +729,8 @@ def fit( if isinstance(self.training_type_plugin, DDPSpawnPlugin): spawn_output: _SpawnOutput = self.training_type_plugin.spawn(function, *args) self.training_type_plugin._recover_results_in_main_process(spawn_output, self) - output = spawn_output.trainer_results else: - output = function(*args) - return output + function(*args) def _fit_impl( self, @@ -765,11 +763,11 @@ def _fit_impl( # TODO: ckpt_path only in v1.7 ckpt_path = ckpt_path or self.resume_from_checkpoint - output = self._run(model, ckpt_path=ckpt_path) + results = self._run(model, ckpt_path=ckpt_path) assert self.state.stopped self.training = False - return output + return results def validate( self, @@ -803,7 +801,15 @@ def validate( :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_epoch_end`, etc. The length of the list corresponds to the number of validation dataloaders used. """ - return self._call_and_handle_interrupt(self._validate_impl, model, dataloaders, ckpt_path, verbose, datamodule) + function = partial(self._call_and_handle_interrupt, self._validate_impl) + args = (model, dataloaders, ckpt_path, verbose, datamodule) + if isinstance(self.training_type_plugin, DDPSpawnPlugin): + spawn_output: _SpawnOutput = self.training_type_plugin.spawn(function, *args) + self.training_type_plugin._recover_results_in_main_process(spawn_output, self) + output = spawn_output.trainer_results + else: + output = function(*args) + return output def _validate_impl( self, @@ -886,7 +892,15 @@ def test( :meth:`~pytorch_lightning.core.lightning.LightningModule.test_epoch_end`, etc. The length of the list corresponds to the number of test dataloaders used. """ - return self._call_and_handle_interrupt(self._test_impl, model, dataloaders, ckpt_path, verbose, datamodule) + function = partial(self._call_and_handle_interrupt, self._test_impl) + args = (model, dataloaders, ckpt_path, verbose, datamodule) + if isinstance(self.training_type_plugin, DDPSpawnPlugin): + spawn_output: _SpawnOutput = self.training_type_plugin.spawn(function, *args) + self.training_type_plugin._recover_results_in_main_process(spawn_output, self) + output = spawn_output.trainer_results + else: + output = function(*args) + return output def _test_impl( self, @@ -968,9 +982,15 @@ def predict( Returns: Returns a list of dictionaries, one for each provided dataloader containing their respective predictions. """ - return self._call_and_handle_interrupt( - self._predict_impl, model, dataloaders, datamodule, return_predictions, ckpt_path - ) + function = partial(self._call_and_handle_interrupt, self._predict_impl) + args = (model, dataloaders, datamodule, return_predictions, ckpt_path) + if isinstance(self.training_type_plugin, DDPSpawnPlugin): + spawn_output: _SpawnOutput = self.training_type_plugin.spawn(function, *args) + self.training_type_plugin._recover_results_in_main_process(spawn_output, self) + output = spawn_output.trainer_results + else: + output = function(*args) + return output def _predict_impl( self, From 34a889afe097617090d23ec3671b2ff04ccf1303 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 2 Dec 2021 12:09:54 +0000 Subject: [PATCH 046/104] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/trainer/trainer.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index e535ae2364806..f8c23a8d7d7db 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1145,9 +1145,9 @@ def _run( | || create accelerator || | || LIGHTNING - | || + | || {self.run_stage} || FLOW - | || + | || {self._run_train} || DIRECTION or {self._run_evaluate} || or {self._run_predict} || @@ -1246,7 +1246,9 @@ def _log_hyperparams(self) -> None: self.logger.save() def _teardown(self): - """This is the Trainer's internal teardown, unrelated to the `teardown` hooks in LightningModule and Callback. + """This is the Trainer's internal teardown, unrelated to the `teardown` hooks in LightningModule and + Callback. + Those are handled by :meth:`_call_teardown_hook`. """ self.training_type_plugin.teardown() From ad3f39d3f581a6e08432e4f4f11522c2b347c283 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 2 Dec 2021 17:16:41 +0100 Subject: [PATCH 047/104] update sharded and tests --- .../plugins/training_type/sharded_spawn.py | 4 +- tests/plugins/test_sharded_plugin.py | 44 +++++-------------- 2 files changed, 14 insertions(+), 34 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/sharded_spawn.py b/pytorch_lightning/plugins/training_type/sharded_spawn.py index 5e10155cc3ca5..3a211f96abb26 100644 --- a/pytorch_lightning/plugins/training_type/sharded_spawn.py +++ b/pytorch_lightning/plugins/training_type/sharded_spawn.py @@ -114,12 +114,12 @@ def pre_backward(self, closure_loss: torch.Tensor) -> None: def post_training_step(self): pass - def new_process(self, trainer: "pl.Trainer") -> Optional[Tuple[Optional[str], Optional[str], Any, _FakeQueue]]: + def pre_dispatch(self) -> None: # Ensure that the scaler points to the correct process group # which is re-initialized in a new process if isinstance(self.precision_plugin, ShardedNativeMixedPrecisionPlugin): self._precision_plugin.scaler = ShardedGradScaler() - return super().new_process(trainer) + return super().pre_dispatch() @classmethod def register_plugins(cls, plugin_registry: Dict) -> None: diff --git a/tests/plugins/test_sharded_plugin.py b/tests/plugins/test_sharded_plugin.py index f6b58692aa221..c135d1715789f 100644 --- a/tests/plugins/test_sharded_plugin.py +++ b/tests/plugins/test_sharded_plugin.py @@ -32,43 +32,23 @@ def test_ddp_sharded_precision_16_clip_gradients(mock_oss_clip_grad_norm, clip_v @RunIf(fairscale=True) -@pytest.mark.parametrize(["strategy"], [("ddp_sharded",), ("ddp_sharded_spawn",)]) -def test_sharded_ddp_choice(tmpdir, strategy): +@pytest.mark.parametrize( + "strategy,expected", [("ddp_sharded", DDPShardedPlugin), ("ddp_sharded_spawn", DDPSpawnShardedPlugin)] +) +def test_sharded_ddp_choice(tmpdir, strategy, expected): """Test to ensure that plugin is correctly chosen.""" - - class CB(Callback): - def on_fit_start(self, trainer, pl_module): - if strategy == "ddp_sharded": - assert isinstance(trainer.accelerator.training_type_plugin, DDPShardedPlugin) - elif strategy == "ddp_sharded_spawn": - assert isinstance(trainer.accelerator.training_type_plugin, DDPSpawnShardedPlugin) - raise SystemExit() - - model = BoringModel() - trainer = Trainer(fast_dev_run=True, strategy=strategy, callbacks=[CB()]) - - with pytest.raises(SystemExit): - trainer.fit(model) + trainer = Trainer(fast_dev_run=True, strategy=strategy) + assert isinstance(trainer.accelerator.training_type_plugin, expected) @RunIf(min_gpus=1, fairscale=True) -@pytest.mark.parametrize(["strategy"], [("ddp_sharded",), ("ddp_sharded_spawn",)]) -def test_ddp_choice_sharded_amp(tmpdir, strategy): +@pytest.mark.parametrize( + "strategy,expected", [("ddp_sharded", DDPShardedPlugin), ("ddp_sharded_spawn", DDPSpawnShardedPlugin)] +) +def test_ddp_choice_sharded_amp(tmpdir, strategy, expected): """Test to ensure that plugin native amp plugin is correctly chosen when using sharded.""" - - class CB(Callback): - def on_fit_start(self, trainer, pl_module): - if strategy == "ddp_sharded": - assert isinstance(trainer.accelerator.training_type_plugin, DDPShardedPlugin) - elif strategy == "ddp_sharded_spawn": - assert isinstance(trainer.accelerator.training_type_plugin, DDPSpawnShardedPlugin) - raise SystemExit() - - model = BoringModel() - trainer = Trainer(fast_dev_run=True, gpus=1, precision=16, strategy=strategy, callbacks=[CB()]) - - with pytest.raises(SystemExit): - trainer.fit(model) + trainer = Trainer(fast_dev_run=True, gpus=1, precision=16, strategy=strategy) + assert isinstance(trainer.accelerator.training_type_plugin, expected) @RunIf(skip_windows=True, fairscale=True) From c897a2021bbee3eb1642325dfea6aa66a061f5d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 2 Dec 2021 18:09:41 +0100 Subject: [PATCH 048/104] pull down spawn call --- pytorch_lightning/trainer/trainer.py | 49 ++++++++-------------------- 1 file changed, 14 insertions(+), 35 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index f8c23a8d7d7db..e1dc6b19c7909 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -676,7 +676,12 @@ def _call_and_handle_interrupt(self, trainer_fn: Callable, *args: Any, **kwargs: **kwargs: keyword arguments to be passed to `trainer_fn` """ try: - return trainer_fn(*args, **kwargs) + if isinstance(self.training_type_plugin, DDPSpawnPlugin): + spawn_output: _SpawnOutput = self.training_type_plugin.spawn(trainer_fn, *args, **kwargs) + self.training_type_plugin._recover_results_in_main_process(spawn_output, self) + return spawn_output.trainer_results + else: + return trainer_fn(*args, **kwargs) # TODO: treat KeyboardInterrupt as BaseException (delete the code below) in v1.7 except KeyboardInterrupt as exception: rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...") @@ -724,13 +729,9 @@ def fit( datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`. """ - function = partial(self._call_and_handle_interrupt, self._fit_impl) - args = (model, train_dataloaders, val_dataloaders, datamodule, ckpt_path) - if isinstance(self.training_type_plugin, DDPSpawnPlugin): - spawn_output: _SpawnOutput = self.training_type_plugin.spawn(function, *args) - self.training_type_plugin._recover_results_in_main_process(spawn_output, self) - else: - function(*args) + self._call_and_handle_interrupt( + self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path + ) def _fit_impl( self, @@ -801,15 +802,7 @@ def validate( :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_epoch_end`, etc. The length of the list corresponds to the number of validation dataloaders used. """ - function = partial(self._call_and_handle_interrupt, self._validate_impl) - args = (model, dataloaders, ckpt_path, verbose, datamodule) - if isinstance(self.training_type_plugin, DDPSpawnPlugin): - spawn_output: _SpawnOutput = self.training_type_plugin.spawn(function, *args) - self.training_type_plugin._recover_results_in_main_process(spawn_output, self) - output = spawn_output.trainer_results - else: - output = function(*args) - return output + return self._call_and_handle_interrupt(self._validate_impl, model, dataloaders, ckpt_path, verbose, datamodule) def _validate_impl( self, @@ -892,15 +885,7 @@ def test( :meth:`~pytorch_lightning.core.lightning.LightningModule.test_epoch_end`, etc. The length of the list corresponds to the number of test dataloaders used. """ - function = partial(self._call_and_handle_interrupt, self._test_impl) - args = (model, dataloaders, ckpt_path, verbose, datamodule) - if isinstance(self.training_type_plugin, DDPSpawnPlugin): - spawn_output: _SpawnOutput = self.training_type_plugin.spawn(function, *args) - self.training_type_plugin._recover_results_in_main_process(spawn_output, self) - output = spawn_output.trainer_results - else: - output = function(*args) - return output + return self._call_and_handle_interrupt(self._test_impl, model, dataloaders, ckpt_path, verbose, datamodule) def _test_impl( self, @@ -982,15 +967,9 @@ def predict( Returns: Returns a list of dictionaries, one for each provided dataloader containing their respective predictions. """ - function = partial(self._call_and_handle_interrupt, self._predict_impl) - args = (model, dataloaders, datamodule, return_predictions, ckpt_path) - if isinstance(self.training_type_plugin, DDPSpawnPlugin): - spawn_output: _SpawnOutput = self.training_type_plugin.spawn(function, *args) - self.training_type_plugin._recover_results_in_main_process(spawn_output, self) - output = spawn_output.trainer_results - else: - output = function(*args) - return output + return self._call_and_handle_interrupt( + self._predict_impl, model, dataloaders, datamodule, return_predictions, ckpt_path + ) def _predict_impl( self, From 90054cf4daa809abfa08ad572956d3ae25650360 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 2 Dec 2021 18:10:35 +0100 Subject: [PATCH 049/104] simplify test --- tests/trainer/test_trainer.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index e440f5f703f75..dc7070c9c2ac0 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1515,14 +1515,10 @@ def write_on_batch_end(self, trainer, pl_module, prediction, batch_indices, *arg def test_spawn_predict_return_predictions(_, __, accelerator): """Test that `return_predictions=True` raise a MisconfigurationException with spawn training type plugins.""" model = BoringModel() - - def run(expected_plugin, **trainer_kwargs): - trainer = Trainer(**trainer_kwargs, fast_dev_run=True) - assert isinstance(trainer.training_type_plugin, expected_plugin) - with pytest.raises(MisconfigurationException, match="`return_predictions` should be set to `False`"): - trainer.predict(model, dataloaders=model.train_dataloader(), return_predictions=True) - - run(DDPSpawnPlugin, accelerator=accelerator, strategy="ddp_spawn", devices=2) + trainer = Trainer(accelerator=accelerator, strategy="ddp_spawn", devices=2, fast_dev_run=True) + assert isinstance(trainer.training_type_plugin, DDPSpawnPlugin) + with pytest.raises(MisconfigurationException, match="`return_predictions` should be set to `False`"): + trainer.predict(model, dataloaders=model.train_dataloader(), return_predictions=True) @pytest.mark.parametrize("return_predictions", [None, False, True]) From 009abfadeba656724db6397367b32ccc7c002023 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 2 Dec 2021 20:55:51 +0100 Subject: [PATCH 050/104] attach model as early as possible --- pytorch_lightning/trainer/trainer.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index e1dc6b19c7909..e2095135d15b7 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -729,6 +729,7 @@ def fit( datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`. """ + self.training_type_plugin.model = model self._call_and_handle_interrupt( self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path ) @@ -802,6 +803,7 @@ def validate( :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_epoch_end`, etc. The length of the list corresponds to the number of validation dataloaders used. """ + self.training_type_plugin.model = model return self._call_and_handle_interrupt(self._validate_impl, model, dataloaders, ckpt_path, verbose, datamodule) def _validate_impl( @@ -885,6 +887,7 @@ def test( :meth:`~pytorch_lightning.core.lightning.LightningModule.test_epoch_end`, etc. The length of the list corresponds to the number of test dataloaders used. """ + self.training_type_plugin.model = model return self._call_and_handle_interrupt(self._test_impl, model, dataloaders, ckpt_path, verbose, datamodule) def _test_impl( @@ -967,6 +970,7 @@ def predict( Returns: Returns a list of dictionaries, one for each provided dataloader containing their respective predictions. """ + self.training_type_plugin.model = model return self._call_and_handle_interrupt( self._predict_impl, model, dataloaders, datamodule, return_predictions, ckpt_path ) From 376e4fe23d1d823fd3b936047021638bb04a2efe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 2 Dec 2021 20:56:02 +0100 Subject: [PATCH 051/104] demonstrate which tests fails --- tests/checkpointing/test_model_checkpoint.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index fa08057733f68..7360fb2ca78ac 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -402,7 +402,6 @@ def test_model_checkpoint_no_extraneous_invocations(tmpdir): max_epochs=num_epochs, ) trainer.fit(model) - assert trainer.state.finished, f"Training failed with {trainer.state}" def test_model_checkpoint_format_checkpoint_name(tmpdir): From de1811ed5658011d001d25acfbe937333eb247ae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 3 Dec 2021 11:10:19 +0100 Subject: [PATCH 052/104] set module --- pytorch_lightning/trainer/trainer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index e2095135d15b7..5b0d08385a511 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -803,7 +803,7 @@ def validate( :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_epoch_end`, etc. The length of the list corresponds to the number of validation dataloaders used. """ - self.training_type_plugin.model = model + self.training_type_plugin.model = model or self.lightning_module return self._call_and_handle_interrupt(self._validate_impl, model, dataloaders, ckpt_path, verbose, datamodule) def _validate_impl( @@ -887,7 +887,7 @@ def test( :meth:`~pytorch_lightning.core.lightning.LightningModule.test_epoch_end`, etc. The length of the list corresponds to the number of test dataloaders used. """ - self.training_type_plugin.model = model + self.training_type_plugin.model = model or self.lightning_module return self._call_and_handle_interrupt(self._test_impl, model, dataloaders, ckpt_path, verbose, datamodule) def _test_impl( @@ -970,7 +970,7 @@ def predict( Returns: Returns a list of dictionaries, one for each provided dataloader containing their respective predictions. """ - self.training_type_plugin.model = model + self.training_type_plugin.model = model or self.lightning_module return self._call_and_handle_interrupt( self._predict_impl, model, dataloaders, datamodule, return_predictions, ckpt_path ) From ef61a0b7327e9c36ca710da6d53e6415b545e8b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 3 Dec 2021 11:16:09 +0100 Subject: [PATCH 053/104] update exception --- tests/trainer/test_trainer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index dc7070c9c2ac0..ced4eaacdf3b9 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -26,6 +26,7 @@ import cloudpickle import pytest import torch +from torch.multiprocessing import ProcessRaisedException from torch.nn.parallel.distributed import DistributedDataParallel from torch.optim import SGD from torch.utils.data import DataLoader, IterableDataset @@ -1517,7 +1518,7 @@ def test_spawn_predict_return_predictions(_, __, accelerator): model = BoringModel() trainer = Trainer(accelerator=accelerator, strategy="ddp_spawn", devices=2, fast_dev_run=True) assert isinstance(trainer.training_type_plugin, DDPSpawnPlugin) - with pytest.raises(MisconfigurationException, match="`return_predictions` should be set to `False`"): + with pytest.raises(ProcessRaisedException, match="`return_predictions` should be set to `False`"): trainer.predict(model, dataloaders=model.train_dataloader(), return_predictions=True) From 809014a0c2292ba4182ceb6cac2a8ebef59541b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 3 Dec 2021 11:19:39 +0100 Subject: [PATCH 054/104] imports --- pytorch_lightning/plugins/training_type/ddp.py | 1 - pytorch_lightning/plugins/training_type/sharded_spawn.py | 4 ++-- pytorch_lightning/trainer/trainer.py | 1 - tests/plugins/test_sharded_plugin.py | 1 - 4 files changed, 2 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 5f6517f89c37f..569b45f851305 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -30,7 +30,6 @@ from torch.nn import Module from torch.nn.parallel.distributed import DistributedDataParallel -import pytorch_lightning as pl from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.overrides import LightningDistributedModule from pytorch_lightning.overrides.distributed import prepare_for_backward diff --git a/pytorch_lightning/plugins/training_type/sharded_spawn.py b/pytorch_lightning/plugins/training_type/sharded_spawn.py index 3a211f96abb26..2531824bf2ce0 100644 --- a/pytorch_lightning/plugins/training_type/sharded_spawn.py +++ b/pytorch_lightning/plugins/training_type/sharded_spawn.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from contextlib import contextmanager -from typing import Any, Dict, Generator, List, Optional, Tuple +from typing import Dict, Generator, List, Optional, Tuple import torch from torch.nn import Module @@ -20,7 +20,7 @@ import pytorch_lightning as pl from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin -from pytorch_lightning.plugins.training_type.ddp_spawn import _FakeQueue, DDPSpawnPlugin +from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, rank_zero_only from pytorch_lightning.utilities.enums import _StrategyType diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 5b0d08385a511..ca22394288b8f 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -19,7 +19,6 @@ import warnings from argparse import ArgumentParser, Namespace from datetime import timedelta -from functools import partial from pathlib import Path from typing import Any, Callable, cast, Dict, Iterable, List, Optional, Tuple, Union from weakref import proxy diff --git a/tests/plugins/test_sharded_plugin.py b/tests/plugins/test_sharded_plugin.py index c135d1715789f..e3b7e4986d9fb 100644 --- a/tests/plugins/test_sharded_plugin.py +++ b/tests/plugins/test_sharded_plugin.py @@ -6,7 +6,6 @@ import torch from pytorch_lightning import LightningModule, Trainer -from pytorch_lightning.callbacks import Callback from pytorch_lightning.plugins import DDPShardedPlugin, DDPSpawnShardedPlugin from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE From 440b639f29ae37229327c5b387231e76e94b14b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 3 Dec 2021 11:28:52 +0100 Subject: [PATCH 055/104] transfer trainer state --- pytorch_lightning/plugins/training_type/ddp_spawn.py | 8 ++++++-- pytorch_lightning/trainer/trainer.py | 6 +++--- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index e2c8dccf35bbd..15671775150f7 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -32,7 +32,7 @@ from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin -from pytorch_lightning.trainer.states import TrainerFn +from pytorch_lightning.trainer.states import TrainerFn, TrainerState from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8, rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device from pytorch_lightning.utilities.distributed import distributed_available @@ -245,7 +245,8 @@ def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Opt else: self.add_to_queue(trainer, extra) - return _SpawnOutput(best_model_path, last_path, results, extra) + state = trainer.state + return _SpawnOutput(best_model_path, last_path, state, results, extra) def _recover_results_in_main_process(self, spawn_output: "_SpawnOutput", trainer: "pl.Trainer") -> None: # transfer back the best path to the trainer @@ -260,6 +261,8 @@ def _recover_results_in_main_process(self, spawn_output: "_SpawnOutput", trainer ) self.lightning_module.load_state_dict(ckpt) + trainer.state = spawn_output.trainer_state + # get the `callback_metrics` and set it to the trainer if is_overridden("get_from_queue", self.lightning_module): # only in case the user does not override it. @@ -395,5 +398,6 @@ def empty(self) -> bool: class _SpawnOutput(NamedTuple): best_model_path: Optional[_PATH] last_path: Optional[_PATH] + trainer_state: TrainerState trainer_results: Any extra: _FakeQueue diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index ca22394288b8f..f2ca20db3cd22 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1166,9 +1166,6 @@ def _run( results = self.run_stage() - if isinstance(self.training_type_plugin, DDPSpawnPlugin): - results = self.training_type_plugin._collect_rank_zero_results(self, results) - self._teardown() # ---------------------------- @@ -1184,6 +1181,9 @@ def _run( self.state.status = TrainerStatus.FINISHED self.state.stage = None + if isinstance(self.training_type_plugin, DDPSpawnPlugin): + results = self.training_type_plugin._collect_rank_zero_results(self, results) + return results def _pre_dispatch(self): From ab5559e73da53500c3f41f7413c7e42ec710ccd2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 3 Dec 2021 12:12:26 +0100 Subject: [PATCH 056/104] fix problem with getqueue --- pytorch_lightning/core/lightning.py | 4 ---- pytorch_lightning/plugins/training_type/ddp_spawn.py | 6 ++---- pytorch_lightning/plugins/training_type/tpu_spawn.py | 5 ++--- 3 files changed, 4 insertions(+), 11 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index e02c9d32ecb80..b0055e65bb2cd 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1928,8 +1928,6 @@ def add_to_queue(self, queue: pl.plugins.training_type.ddp_spawn._FakeQueue) -> This method was deprecated in v1.5 in favor of `DDPSpawnPlugin.add_to_queue` and will be removed in v1.7. """ - if self.trainer and isinstance(self.trainer.training_type_plugin, pl.plugins.training_type.DDPSpawnPlugin): - self.trainer.training_type_plugin.add_to_queue(self.trainer, queue) def get_from_queue(self, queue: pl.plugins.training_type.ddp_spawn._FakeQueue) -> None: """Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. To preserve consistency, @@ -1942,8 +1940,6 @@ def get_from_queue(self, queue: pl.plugins.training_type.ddp_spawn._FakeQueue) - This method was deprecated in v1.5 in favor of `DDPSpawnPlugin.get_from_queue` and will be removed in v1.7. """ - if self.trainer and isinstance(self.trainer.training_type_plugin, pl.plugins.training_type.DDPSpawnPlugin): - self.trainer.training_type_plugin.get_from_queue(self.trainer, queue) @contextmanager def _prevent_trainer_and_dataloaders_deepcopy(self) -> None: diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 15671775150f7..74cd9667d85af 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -242,8 +242,7 @@ def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Opt if is_overridden("add_to_queue", self.lightning_module): # TODO: Remove the if in v1.7 self.lightning_module.add_to_queue(extra) - else: - self.add_to_queue(trainer, extra) + self.add_to_queue(trainer, extra) state = trainer.state return _SpawnOutput(best_model_path, last_path, state, results, extra) @@ -268,8 +267,7 @@ def _recover_results_in_main_process(self, spawn_output: "_SpawnOutput", trainer # only in case the user does not override it. # TODO: Remove the if in v1.7 self.lightning_module.get_from_queue(spawn_output.extra) - else: - self.get_from_queue(trainer, spawn_output.extra) + self.get_from_queue(trainer, spawn_output.extra) def barrier(self, *args, **kwargs) -> None: if not distributed_available(): diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index b3c387b2aef5e..c8c19b1bcb26e 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -194,7 +194,7 @@ def barrier(self, name: Optional[str] = None) -> None: if self.is_distributed: rendezvous(name) - def __collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_SpawnOutput"]: + def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_SpawnOutput"]: rank_zero_warn("cleaning up tpu spawn environment...") checkpoint_callback = trainer.checkpoint_callback best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None @@ -217,8 +217,7 @@ def __collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Op if is_overridden("add_to_queue", self.lightning_module): # TODO: Remove the if in v1.7 self.lightning_module.add_to_queue(extra) - else: - self.add_to_queue(trainer, extra) + self.add_to_queue(trainer, extra) return _SpawnOutput(best_model_path, last_path, results, extra) From f4f1269a8950993d874eb44757a34c9ca1b43a71 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 3 Dec 2021 12:15:47 +0100 Subject: [PATCH 057/104] deprecation calls don't come through ddp_spawn --- tests/plugins/test_ddp_spawn_plugin.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/plugins/test_ddp_spawn_plugin.py b/tests/plugins/test_ddp_spawn_plugin.py index 6ea265a4bb575..784e76ec900a6 100644 --- a/tests/plugins/test_ddp_spawn_plugin.py +++ b/tests/plugins/test_ddp_spawn_plugin.py @@ -78,8 +78,7 @@ def test_ddp_spawn_extra_parameters(tmpdir): val_name: str = "val_acc" model = BoringCallbackDDPSpawnModel(val_name, val) dm = BoringDataModule() - with pytest.deprecated_call(match="add_to_queue` method was deprecated in v1.5"): - trainer.fit(model, datamodule=dm) + trainer.fit(model, datamodule=dm) assert trainer.callback_metrics[val_name] == torch.tensor(val) assert model.test_val == "test_val" @@ -105,8 +104,7 @@ def test_ddp_spawn_add_get_queue(tmpdir): val_name: str = "val_acc" model = BoringCallbackDDPSpawnModel(val_name, val) dm = BoringDataModule() - with pytest.deprecated_call(match="add_to_queue` method was deprecated in v1.5"): - trainer.fit(model, datamodule=dm) + trainer.fit(model, datamodule=dm) assert trainer.callback_metrics[val_name] == torch.tensor(val) assert ddp_spawn_plugin.new_test_val == "new_test_val" From b30c35218fd72bab9fe9433645c6568dacdead2a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 3 Dec 2021 12:16:02 +0100 Subject: [PATCH 058/104] prepare data only gets called on rank 0 --- tests/helpers/boring_model.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/helpers/boring_model.py b/tests/helpers/boring_model.py index d51fb44bff0d2..3a1c4f30fe1f4 100644 --- a/tests/helpers/boring_model.py +++ b/tests/helpers/boring_model.py @@ -151,8 +151,6 @@ def __init__(self, data_dir: str = "./"): self.data_dir = data_dir self.non_picklable = None self.checkpoint_state: Optional[str] = None - - def prepare_data(self): self.random_full = RandomDataset(32, 64 * 4) def setup(self, stage: Optional[str] = None): From 5434ae579f3fdf948ba7bca7f2c431e9d3535d4d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 3 Dec 2021 12:27:10 +0100 Subject: [PATCH 059/104] import --- tests/trainer/test_trainer.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index ced4eaacdf3b9..c70fcb96ccc27 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -26,7 +26,6 @@ import cloudpickle import pytest import torch -from torch.multiprocessing import ProcessRaisedException from torch.nn.parallel.distributed import DistributedDataParallel from torch.optim import SGD from torch.utils.data import DataLoader, IterableDataset @@ -1518,7 +1517,9 @@ def test_spawn_predict_return_predictions(_, __, accelerator): model = BoringModel() trainer = Trainer(accelerator=accelerator, strategy="ddp_spawn", devices=2, fast_dev_run=True) assert isinstance(trainer.training_type_plugin, DDPSpawnPlugin) - with pytest.raises(ProcessRaisedException, match="`return_predictions` should be set to `False`"): + with pytest.raises( + torch.multiprocessing.ProcessRaisedException, match="`return_predictions` should be set to `False`" + ): trainer.predict(model, dataloaders=model.train_dataloader(), return_predictions=True) From 24f05f1a295f268d31147d4d6f3acf943695c2c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 3 Dec 2021 12:55:56 +0100 Subject: [PATCH 060/104] update test --- tests/core/test_datamodules.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index 57574c074cd77..11ae6b13aaafe 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -16,7 +16,7 @@ from dataclasses import dataclass from typing import Any, Dict from unittest import mock -from unittest.mock import call, PropertyMock +from unittest.mock import call, PropertyMock, Mock import pytest import torch @@ -40,51 +40,52 @@ @mock.patch("pytorch_lightning.trainer.trainer.Trainer.node_rank", new_callable=PropertyMock) @mock.patch("pytorch_lightning.trainer.trainer.Trainer.local_rank", new_callable=PropertyMock) def test_can_prepare_data(local_rank, node_rank): - dm = BoringDataModule() + dm = Mock(spec=BoringDataModule) + dm.prepare_data_per_node = True trainer = Trainer() trainer.datamodule = dm # 1 no DM # prepare_data_per_node = True # local rank = 0 (True) - dm.random_full = None + dm.prepare_data.assert_not_called() local_rank.return_value = 0 assert trainer.local_rank == 0 trainer._data_connector.prepare_data() - assert dm.random_full is not None + dm.prepare_data.assert_called_once() # local rank = 1 (False) - dm.random_full = None + dm.reset_mock() local_rank.return_value = 1 assert trainer.local_rank == 1 trainer._data_connector.prepare_data() - assert dm.random_full is None + dm.prepare_data.assert_not_called() # prepare_data_per_node = False (prepare across all nodes) # global rank = 0 (True) - dm.random_full = None + dm.reset_mock() dm.prepare_data_per_node = False node_rank.return_value = 0 local_rank.return_value = 0 trainer._data_connector.prepare_data() - assert dm.random_full is not None + dm.prepare_data.assert_called_once() # global rank = 1 (False) - dm.random_full = None + dm.reset_mock() node_rank.return_value = 1 local_rank.return_value = 0 trainer._data_connector.prepare_data() - assert dm.random_full is None + dm.prepare_data.assert_not_called() node_rank.return_value = 0 local_rank.return_value = 1 trainer._data_connector.prepare_data() - assert dm.random_full is None + dm.prepare_data.assert_not_called() # 2 dm # prepar per node = True @@ -92,10 +93,9 @@ def test_can_prepare_data(local_rank, node_rank): dm.prepare_data_per_node = True local_rank.return_value = 0 - with mock.patch.object(trainer.datamodule, "prepare_data") as dm_mock: - # is_overridden prepare data = True - trainer._data_connector.prepare_data() - dm_mock.assert_called_once() + # is_overridden prepare data = True + trainer._data_connector.prepare_data() + dm.prepare_data.assert_called_once() def test_hooks_no_recursion_error(): From 39599554aac6e22372643090787e849c4dbf5c83 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 3 Dec 2021 13:10:06 +0100 Subject: [PATCH 061/104] update exception --- tests/trainer/test_trainer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index c70fcb96ccc27..d6a5c05368bfd 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1411,7 +1411,9 @@ def predict( callbacks=[cb, cb_1] if use_callbacks else [], ) if strategy == "ddp_spawn": - with pytest.raises(MisconfigurationException): + with pytest.raises( + torch.multiprocessing.ProcessRaisedException, match="`return_predictions` should be set to `False`" + ): trainer.predict(model, datamodule=dm, return_predictions=True) if datamodule: From f491abe469924fa08e1ec6c3896d0be6d75afd06 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 3 Dec 2021 12:11:44 +0000 Subject: [PATCH 062/104] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/core/test_datamodules.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index 11ae6b13aaafe..7f26c5f1f41a6 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -16,7 +16,7 @@ from dataclasses import dataclass from typing import Any, Dict from unittest import mock -from unittest.mock import call, PropertyMock, Mock +from unittest.mock import call, Mock, PropertyMock import pytest import torch From 0c808ce4f386af4f4891e1b0a4e68252ca33758a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 3 Dec 2021 15:08:42 +0100 Subject: [PATCH 063/104] adapt tpu spawn --- .../plugins/training_type/tpu_spawn.py | 71 ++++++------------- pytorch_lightning/trainer/trainer.py | 14 +++- tests/plugins/test_ddp_spawn_plugin.py | 1 - 3 files changed, 36 insertions(+), 50 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index c8c19b1bcb26e..87ae3b84690f3 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -118,10 +118,25 @@ def connect(self, model: "pl.LightningModule") -> None: self.wrapped_model = xmp.MpModelWrapper(LightningDistributedModule(model)) return super().connect(model) - def pre_dispatch(self): + def pre_dispatch(self) -> None: if self.debug: os.environ["PT_XLA_DEBUG"] = str(1) + trainer = self.lightning_module.trainer + if self.tpu_global_core_rank != 0 and trainer.progress_bar_callback is not None: + # TODO: this is already done in the trainer, still needed? + trainer.progress_bar_callback.disable() + + shared_params = find_shared_parameters(self.model) + self.model_to_device() + if is_overridden("on_post_move_to_device", self.lightning_module): + self.model.module.on_post_move_to_device() + else: + set_shared_parameters(self.model.module, shared_params) + + self.setup_optimizers(trainer) + self.precision_plugin.connect(self._model, None, None) + def setup(self, trainer: "pl.Trainer") -> None: self.start_method = "fork" if not self.setup_optimizers_in_pre_dispatch: @@ -156,37 +171,6 @@ def init_dist_connection(self, global_rank: int, world_size: int) -> None: def set_world_ranks(self, process_idx: int = 0) -> None: pass - def new_process(self, trainer: "pl.Trainer") -> Optional[Tuple[Optional[str], Optional[str], Any, "_FakeQueue"]]: - if self.tpu_global_core_rank != 0 and trainer.progress_bar_callback is not None: - trainer.progress_bar_callback.disable() - - shared_params = find_shared_parameters(self.model) - self.model_to_device() - if is_overridden("on_post_move_to_device", self.lightning_module): - self.model.module.on_post_move_to_device() - else: - set_shared_parameters(self.model.module, shared_params) - - trainer.training_type_plugin.setup_optimizers(trainer) - trainer.precision_plugin.connect(self._model, None, None) - - self.barrier("pre-run-stage") - - results = trainer.run_stage() - - outputs = self.__collect_rank_zero_results(trainer, results) - - # https://github.com/pytorch/xla/issues/1801#issuecomment-602799542 - self.barrier("end-process") - - # https://github.com/pytorch/xla/issues/2190#issuecomment-641665358 - if self.local_rank == 0: - time.sleep(2) - - # ensure that spawned processes go through teardown before joining - trainer._call_teardown_hook() - return outputs - def model_to_device(self) -> None: self.model = self.wrapped_model.to(self.root_device) @@ -219,7 +203,8 @@ def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Opt self.lightning_module.add_to_queue(extra) self.add_to_queue(trainer, extra) - return _SpawnOutput(best_model_path, last_path, results, extra) + state = trainer.state + return _SpawnOutput(best_model_path, last_path, state, results, extra) def broadcast(self, obj: object, src: int = 0) -> object: if not self.is_distributed: @@ -264,6 +249,10 @@ def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[st } def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Optional[Union[Any, "_SpawnOutput"]]: + # TODO: this todo is unclear, does it still apply? + # todo: precision pluging is call in accelerator setup and should be moved + if "XLA_USE_BF16" in os.environ: + del os.environ["XLA_USE_BF16"] context = mp.get_context(self.start_method or "fork") return_queue = context.SimpleQueue() xmp.spawn(self._wrapped_function, args=(function, args, kwargs, return_queue), **self.get_mp_spawn_kwargs()) @@ -288,21 +277,6 @@ def _worker_setup(self, process_idx: int): self.tpu_global_core_rank = xm.get_ordinal() rank_zero_only.rank = self.global_rank - def start_training(self, trainer: "pl.Trainer") -> Any: - # todo: precision pluging is call in accelerator setup and should be moved - if "XLA_USE_BF16" in os.environ: - del os.environ["XLA_USE_BF16"] - self._clean_logger(trainer) - return super().start_training(trainer) - - def start_evaluating(self, trainer: "pl.Trainer") -> Any: - self._clean_logger(trainer) - return super().start_evaluating(trainer) - - def start_predicting(self, trainer: "pl.Trainer") -> Any: - self._clean_logger(trainer) - return super().start_predicting(trainer) - def training_step(self, *args, **kwargs): return self.model(*args, **kwargs) @@ -379,6 +353,7 @@ def checkpoint_io(self) -> CheckpointIO: def checkpoint_io(self, plugin: CheckpointIO) -> None: raise MisconfigurationException("TPU Spawn Plugin currently does not support custom checkpoint plugins.") + # TODO: still needed? @staticmethod def _clean_logger(trainer: "pl.Trainer") -> None: loggers = trainer.logger._logger_iterable if isinstance(trainer.logger, LoggerCollection) else [trainer.logger] diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index f2ca20db3cd22..c583c190abb03 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -15,6 +15,7 @@ import inspect import logging import os +import time import traceback import warnings from argparse import ArgumentParser, Namespace @@ -1164,7 +1165,18 @@ def _run( self.checkpoint_connector.resume_end() - results = self.run_stage() + # TODO: needed? (was originally in TPUSpawnPLugin) + # self.training_type_plugin.barrier("pre-run-stage") + + self.run_stage() + + # TODO: needed? (was originally in TPUSpawnPLugin) + # https://github.com/pytorch/xla/issues/1801#issuecomment-602799542 + # self.training_type_plugin.barrier("end-process") + + # https://github.com/pytorch/xla/issues/2190#issuecomment-641665358 + # if self.local_rank == 0: + # time.sleep(2) self._teardown() diff --git a/tests/plugins/test_ddp_spawn_plugin.py b/tests/plugins/test_ddp_spawn_plugin.py index 784e76ec900a6..8929027e78eb6 100644 --- a/tests/plugins/test_ddp_spawn_plugin.py +++ b/tests/plugins/test_ddp_spawn_plugin.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytest import torch from torch.nn.parallel.distributed import DistributedDataParallel From d6dd3433f0e2b1f086cfd01ad7c2e6f7ee1e5e98 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 3 Dec 2021 15:09:10 +0100 Subject: [PATCH 064/104] imports --- pytorch_lightning/plugins/training_type/tpu_spawn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 87ae3b84690f3..744a0112fe8cc 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -16,7 +16,7 @@ import re import time from multiprocessing.queues import SimpleQueue -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Union import torch import torch.multiprocessing as mp From b04768731a397795a154252ed8980b7ea8e0098c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 3 Dec 2021 15:34:51 +0100 Subject: [PATCH 065/104] update --- pytorch_lightning/trainer/trainer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index c583c190abb03..79d7fb8dbe8a7 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1168,7 +1168,7 @@ def _run( # TODO: needed? (was originally in TPUSpawnPLugin) # self.training_type_plugin.barrier("pre-run-stage") - self.run_stage() + results = self.run_stage() # TODO: needed? (was originally in TPUSpawnPLugin) # https://github.com/pytorch/xla/issues/1801#issuecomment-602799542 @@ -1196,6 +1196,7 @@ def _run( if isinstance(self.training_type_plugin, DDPSpawnPlugin): results = self.training_type_plugin._collect_rank_zero_results(self, results) + # TODO: The reslts no longer need to pass through _collect_rank_zero_results and can be returned directly here return results def _pre_dispatch(self): From c524e52d9aaa41f1e4fb31045a29e5ed2cd1e3e1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 3 Dec 2021 15:46:35 +0100 Subject: [PATCH 066/104] add missing arg --- pytorch_lightning/plugins/training_type/sharded_spawn.py | 4 ++-- pytorch_lightning/trainer/trainer.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/sharded_spawn.py b/pytorch_lightning/plugins/training_type/sharded_spawn.py index 2531824bf2ce0..951a0be78e7b9 100644 --- a/pytorch_lightning/plugins/training_type/sharded_spawn.py +++ b/pytorch_lightning/plugins/training_type/sharded_spawn.py @@ -114,12 +114,12 @@ def pre_backward(self, closure_loss: torch.Tensor) -> None: def post_training_step(self): pass - def pre_dispatch(self) -> None: + def pre_dispatch(self, trainer: "pl.Trainer") -> None: # Ensure that the scaler points to the correct process group # which is re-initialized in a new process if isinstance(self.precision_plugin, ShardedNativeMixedPrecisionPlugin): self._precision_plugin.scaler = ShardedGradScaler() - return super().pre_dispatch() + return super().pre_dispatch(trainer) @classmethod def register_plugins(cls, plugin_registry: Dict) -> None: diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 79d7fb8dbe8a7..9b143ba81ab0a 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1154,7 +1154,7 @@ def _run( if self.state.fn == TrainerFn.FITTING: self.call_hook("on_fit_start") - # plugin will setup fitting (e.g. ddp will launch child processes) + # plugin will move model to device self._pre_dispatch() if self.training_type_plugin.restore_checkpoint_after_pre_dispatch: From 223e7aa704879abb249858d2db27dd5fdab90d82 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 3 Dec 2021 16:52:04 +0100 Subject: [PATCH 067/104] fix exception import on torch < 1.8 --- tests/trainer/test_trainer.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index d6a5c05368bfd..d139a896d7b47 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -50,7 +50,7 @@ from pytorch_lightning.utilities import _AcceleratorType, _StrategyType from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException -from pytorch_lightning.utilities.imports import _OMEGACONF_AVAILABLE +from pytorch_lightning.utilities.imports import _OMEGACONF_AVAILABLE, _TORCH_GREATER_EQUAL_1_8 from pytorch_lightning.utilities.seed import seed_everything from tests.base import EvalModelTemplate from tests.helpers import BoringModel, RandomDataset @@ -62,6 +62,10 @@ if _OMEGACONF_AVAILABLE: from omegaconf import OmegaConf +ProcessRaisedException = Exception +if _TORCH_GREATER_EQUAL_1_8: + from torch.multiprocessing import ProcessRaisedException + @pytest.mark.parametrize("url_ckpt", [True, False]) def test_no_val_module(monkeypatch, tmpdir, tmpdir_server, url_ckpt): @@ -1411,9 +1415,7 @@ def predict( callbacks=[cb, cb_1] if use_callbacks else [], ) if strategy == "ddp_spawn": - with pytest.raises( - torch.multiprocessing.ProcessRaisedException, match="`return_predictions` should be set to `False`" - ): + with pytest.raises(ProcessRaisedException, match="`return_predictions` should be set to `False`"): trainer.predict(model, datamodule=dm, return_predictions=True) if datamodule: @@ -1519,9 +1521,7 @@ def test_spawn_predict_return_predictions(_, __, accelerator): model = BoringModel() trainer = Trainer(accelerator=accelerator, strategy="ddp_spawn", devices=2, fast_dev_run=True) assert isinstance(trainer.training_type_plugin, DDPSpawnPlugin) - with pytest.raises( - torch.multiprocessing.ProcessRaisedException, match="`return_predictions` should be set to `False`" - ): + with pytest.raises(ProcessRaisedException, match="`return_predictions` should be set to `False`"): trainer.predict(model, dataloaders=model.train_dataloader(), return_predictions=True) From ed309d6db25225449eb00e3998d38ffff1a6fcf2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 3 Dec 2021 17:12:39 +0100 Subject: [PATCH 068/104] debug --- pl_examples/bug_report/bug_report_model.py | 29 ++++++++++++++-------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/pl_examples/bug_report/bug_report_model.py b/pl_examples/bug_report/bug_report_model.py index 7cbc797af49df..1a754c8950bc2 100644 --- a/pl_examples/bug_report/bug_report_model.py +++ b/pl_examples/bug_report/bug_report_model.py @@ -1,9 +1,10 @@ import os +from copy import deepcopy import torch from torch.utils.data import DataLoader, Dataset -from pytorch_lightning import LightningModule, Trainer +from pytorch_lightning import LightningModule, Trainer, seed_everything class RandomDataset(Dataset): @@ -49,22 +50,28 @@ def run(): val_data = DataLoader(RandomDataset(32, 64), batch_size=2) test_data = DataLoader(RandomDataset(32, 64), batch_size=2) + seed_everything(42) + model = BoringModel() + model_copy = deepcopy(model) + model.val_dataloader = None + model.training_epoch_end = None + + limit_train_batches = 8 trainer = Trainer( - default_root_dir=os.getcwd(), - limit_train_batches=1, - limit_val_batches=1, - limit_test_batches=1, - num_sanity_val_steps=0, + limit_train_batches=limit_train_batches, + limit_val_batches=2, max_epochs=1, - enable_model_summary=False, + log_every_n_steps=1, accelerator="cpu", + gpus=2, strategy="ddp_spawn", - num_processes=2, ) - trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data) - trainer.validate(model, dataloaders=val_data) - trainer.test(model, dataloaders=test_data) + + trainer.fit(model, train_data) + + for param, param_copy in zip(model.parameters(), model_copy.parameters()): + assert not torch.equal(param.cpu().data, param_copy.data) if __name__ == "__main__": From 12eed616ad7c4cff636a79f9c08fa4d5dd4a3a17 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 3 Dec 2021 16:14:01 +0000 Subject: [PATCH 069/104] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pl_examples/bug_report/bug_report_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pl_examples/bug_report/bug_report_model.py b/pl_examples/bug_report/bug_report_model.py index 1a754c8950bc2..11a86289df5f7 100644 --- a/pl_examples/bug_report/bug_report_model.py +++ b/pl_examples/bug_report/bug_report_model.py @@ -4,7 +4,7 @@ import torch from torch.utils.data import DataLoader, Dataset -from pytorch_lightning import LightningModule, Trainer, seed_everything +from pytorch_lightning import LightningModule, seed_everything, Trainer class RandomDataset(Dataset): From be73261e8299a32b337141383191362623894adb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 3 Dec 2021 18:35:37 +0100 Subject: [PATCH 070/104] debug tpu --- pytorch_lightning/trainer/trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 73bcc7448b9fb..c0f72251794e7 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1166,13 +1166,13 @@ def _run( self.checkpoint_connector.resume_end() # TODO: needed? (was originally in TPUSpawnPLugin) - # self.training_type_plugin.barrier("pre-run-stage") + self.training_type_plugin.barrier("pre-run-stage") results = self.run_stage() # TODO: needed? (was originally in TPUSpawnPLugin) # https://github.com/pytorch/xla/issues/1801#issuecomment-602799542 - # self.training_type_plugin.barrier("end-process") + self.training_type_plugin.barrier("end-process") # https://github.com/pytorch/xla/issues/2190#issuecomment-641665358 # if self.local_rank == 0: From c71fc5777a583a20be4d22dd90c8f2c648988666 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 3 Dec 2021 18:54:22 +0100 Subject: [PATCH 071/104] fix docs --- pytorch_lightning/plugins/training_type/ddp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index e428b348e6e5a..94ea9675fae1c 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -27,9 +27,9 @@ import numpy as np import torch import torch.distributed +import pytorch_lightning as pl from torch.nn import Module from torch.nn.parallel.distributed import DistributedDataParallel - from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.overrides import LightningDistributedModule from pytorch_lightning.overrides.distributed import prepare_for_backward From 2ed6333ae413550e50c1a734a2ba33593957e52c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 3 Dec 2021 19:03:34 +0100 Subject: [PATCH 072/104] fix teardown being called twice --- pytorch_lightning/trainer/trainer.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index c0f72251794e7..bea9f614f2fc2 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1246,8 +1246,6 @@ def _teardown(self): Those are handled by :meth:`_call_teardown_hook`. """ - self.training_type_plugin.teardown() - # TODO: once accelerator is part of TTP, call teardown in TTP's teardown() method self.accelerator.teardown() self._data_connector.teardown() self._active_loop.teardown() From 2a8b9b4212a19070b26cdb79bee240894f73fb37 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 3 Dec 2021 19:04:46 +0100 Subject: [PATCH 073/104] revert a sate check --- tests/checkpointing/test_model_checkpoint.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 7360fb2ca78ac..fa08057733f68 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -402,6 +402,7 @@ def test_model_checkpoint_no_extraneous_invocations(tmpdir): max_epochs=num_epochs, ) trainer.fit(model) + assert trainer.state.finished, f"Training failed with {trainer.state}" def test_model_checkpoint_format_checkpoint_name(tmpdir): From 5335664c6859394cf1c6b1f56d88ea8410e0df5c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 3 Dec 2021 18:04:54 +0000 Subject: [PATCH 074/104] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/plugins/training_type/ddp.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 94ea9675fae1c..f5eb95638d607 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -27,9 +27,10 @@ import numpy as np import torch import torch.distributed -import pytorch_lightning as pl from torch.nn import Module from torch.nn.parallel.distributed import DistributedDataParallel + +import pytorch_lightning as pl from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.overrides import LightningDistributedModule from pytorch_lightning.overrides.distributed import prepare_for_backward From 93cfaf8cabb8b24ae4cc2cb3762c5eb10d636577 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 3 Dec 2021 23:58:00 +0100 Subject: [PATCH 075/104] fix --- pytorch_lightning/plugins/training_type/ddp_spawn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index bd4c7405ddc84..c66904372ba98 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -254,7 +254,7 @@ def _recover_results_in_main_process(self, spawn_output: "_SpawnOutput", trainer # TODO: pass also best score # load last weights - if spawn_output.last_path is not None and trainer.state.fn == TrainerFn.FITTING: + if spawn_output.last_path is not None: ckpt = self.checkpoint_io.load_checkpoint( spawn_output.last_path, map_location=(lambda storage, loc: storage) ) From dde5a3a6742494d06a6983e1412b7f607636b8ca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 6 Dec 2021 23:33:05 +0100 Subject: [PATCH 076/104] reset bug report model --- pl_examples/bug_report/bug_report_model.py | 30 +++++++--------------- 1 file changed, 9 insertions(+), 21 deletions(-) diff --git a/pl_examples/bug_report/bug_report_model.py b/pl_examples/bug_report/bug_report_model.py index 11a86289df5f7..7739630237d32 100644 --- a/pl_examples/bug_report/bug_report_model.py +++ b/pl_examples/bug_report/bug_report_model.py @@ -1,10 +1,9 @@ import os -from copy import deepcopy import torch from torch.utils.data import DataLoader, Dataset -from pytorch_lightning import LightningModule, seed_everything, Trainer +from pytorch_lightning import LightningModule, Trainer class RandomDataset(Dataset): @@ -28,7 +27,6 @@ def forward(self, x): return self.layer(x) def training_step(self, batch, batch_idx): - print(self.global_rank, self.global_step) loss = self(batch).sum() self.log("train_loss", loss) return {"loss": loss} @@ -50,28 +48,18 @@ def run(): val_data = DataLoader(RandomDataset(32, 64), batch_size=2) test_data = DataLoader(RandomDataset(32, 64), batch_size=2) - seed_everything(42) - model = BoringModel() - model_copy = deepcopy(model) - model.val_dataloader = None - model.training_epoch_end = None - - limit_train_batches = 8 trainer = Trainer( - limit_train_batches=limit_train_batches, - limit_val_batches=2, + default_root_dir=os.getcwd(), + limit_train_batches=1, + limit_val_batches=1, + limit_test_batches=1, + num_sanity_val_steps=0, max_epochs=1, - log_every_n_steps=1, - accelerator="cpu", - gpus=2, - strategy="ddp_spawn", + enable_model_summary=False, ) - - trainer.fit(model, train_data) - - for param, param_copy in zip(model.parameters(), model_copy.parameters()): - assert not torch.equal(param.cpu().data, param_copy.data) + trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data) + trainer.test(model, dataloaders=test_data) if __name__ == "__main__": From 77329b21eb4e305efd2453d03909fe0345698d80 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 6 Dec 2021 23:43:16 +0100 Subject: [PATCH 077/104] fix merge error --- tests/plugins/test_ddp_spawn_plugin.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/plugins/test_ddp_spawn_plugin.py b/tests/plugins/test_ddp_spawn_plugin.py index 62562e5401b62..c8c861050d844 100644 --- a/tests/plugins/test_ddp_spawn_plugin.py +++ b/tests/plugins/test_ddp_spawn_plugin.py @@ -14,6 +14,7 @@ from pathlib import Path from unittest.mock import Mock +import pytest import torch from torch.nn.parallel.distributed import DistributedDataParallel From eb05fc9058c8310365c8a507bdd3c2e52b290adb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 7 Dec 2021 02:05:06 +0100 Subject: [PATCH 078/104] barrier clean ups --- .../plugins/training_type/ddp_spawn.py | 3 --- .../plugins/training_type/tpu_spawn.py | 16 ++++------------ pytorch_lightning/trainer/trainer.py | 13 +------------ 3 files changed, 5 insertions(+), 27 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 9bcdd59dbb82f..2f870ae973ace 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -179,9 +179,6 @@ def pre_dispatch(self, trainer: "pl.Trainer") -> None: if trainer_fn == TrainerFn.FITTING: self.configure_ddp() - # TODO: needed? - self.barrier() - def pre_configure_ddp(self): # if unset, default `find_unused_parameters` `True` # Many models require setting this parameter to True, as there are corner cases diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 7244e74a292f3..44ad9d9a66bbc 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -329,9 +329,12 @@ def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_gra return xm.all_gather(tensor) def teardown(self) -> None: - super().teardown() os.environ.pop("PT_XLA_DEBUG", None) + # https://github.com/pytorch/xla/issues/1801#issuecomment-602799542 self.barrier("teardown") + # https://github.com/pytorch/xla/issues/2190#issuecomment-641665358 + if self.local_rank == 0: + time.sleep(2) @property def should_rank_save_checkpoint(self) -> bool: @@ -348,14 +351,3 @@ def checkpoint_io(self) -> CheckpointIO: @checkpoint_io.setter def checkpoint_io(self, plugin: CheckpointIO) -> None: raise MisconfigurationException("TPU Spawn Plugin currently does not support custom checkpoint plugins.") - - # TODO: still needed? - @staticmethod - def _clean_logger(trainer: "pl.Trainer") -> None: - loggers = trainer.logger._logger_iterable if isinstance(trainer.logger, LoggerCollection) else [trainer.logger] - for logger in loggers: - if isinstance(logger, TensorBoardLogger) and logger._experiment is not None: - # the experiment class of `TensorBoard` holds a multiprocessing queue which can make ours hang. - # we want to make sure these are closed before we spawn our own threads. - # assuming nothing else references the experiment object, python should instantly `__del__` it. - logger._experiment = None diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index c21f224285203..ac23a62dbad12 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1165,19 +1165,7 @@ def _run( self.checkpoint_connector.resume_end() - # TODO: needed? (was originally in TPUSpawnPLugin) - self.training_type_plugin.barrier("pre-run-stage") - results = self.run_stage() - - # TODO: needed? (was originally in TPUSpawnPLugin) - # https://github.com/pytorch/xla/issues/1801#issuecomment-602799542 - self.training_type_plugin.barrier("end-process") - - # https://github.com/pytorch/xla/issues/2190#issuecomment-641665358 - # if self.local_rank == 0: - # time.sleep(2) - self._teardown() # ---------------------------- @@ -1255,6 +1243,7 @@ def _teardown(self): self.signal_connector.teardown() def run_stage(self): + self.training_type_plugin.barrier("run-stage") self.training_type_plugin.dispatch(self) self.__setup_profiler() From dbcb76ca81f8e52eca92a20b781eb1bf9f495ee4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 7 Dec 2021 02:05:17 +0100 Subject: [PATCH 079/104] update comments in trainer --- pytorch_lightning/trainer/trainer.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index ac23a62dbad12..af855f288b92b 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1125,7 +1125,8 @@ def _run( Lightning internal flow looks like this: {Trainer.fit} or {Trainer.test} or {Trainer.predict} || | || - create accelerator || + setup accelerator || + and strategy || | || LIGHTNING | || {self.run_stage} || FLOW @@ -1137,8 +1138,6 @@ def _run( results \/ This is used to guide readers to the core loops: train, test, predict. {self._run_predict} is the simplest to understand, use `Go to Definition` to read it :) - Search for `start_training` or `start_evaluating` or `start_predicting` in - `pytorch_lightning/plugins/training_type_plugin` to find accelerator dispatch functions. """ # ---------------------------- @@ -1185,7 +1184,6 @@ def _run( if isinstance(self.training_type_plugin, DDPSpawnPlugin): results = self.training_type_plugin._collect_rank_zero_results(self, results) - # TODO: The reslts no longer need to pass through _collect_rank_zero_results and can be returned directly here return results def _pre_dispatch(self): From ed0defa0dbfb30e4ad6e7215aba2bb6beb168c32 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 7 Dec 2021 02:08:06 +0100 Subject: [PATCH 080/104] unused import --- pytorch_lightning/trainer/trainer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index af855f288b92b..6924a26a69646 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -15,7 +15,6 @@ import inspect import logging import os -import time import traceback import warnings from argparse import ArgumentParser, Namespace From 79975f24623045e1d1185880bb0a134bbcd8cd63 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 7 Dec 2021 12:56:21 +0100 Subject: [PATCH 081/104] debug --- pytorch_lightning/plugins/training_type/ddp_spawn.py | 2 ++ pytorch_lightning/plugins/training_type/tpu_spawn.py | 3 +-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 2f870ae973ace..b5373e22023c0 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -168,6 +168,8 @@ def _worker_setup(self, process_idx: int): ) def pre_dispatch(self, trainer: "pl.Trainer") -> None: + super().pre_dispatch(trainer) + # move the model to the correct device self.model_to_device() diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 44ad9d9a66bbc..463936cd4e41d 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -118,12 +118,11 @@ def connect(self, model: "pl.LightningModule") -> None: return super().connect(model) def pre_dispatch(self, trainer: "pl.Trainer") -> None: - super().pre_dispatch(trainer) + self._move_optimizer_state() if self.debug: os.environ["PT_XLA_DEBUG"] = str(1) if self.tpu_global_core_rank != 0 and trainer.progress_bar_callback is not None: - # TODO: this is already done in the trainer, still needed? trainer.progress_bar_callback.disable() shared_params = find_shared_parameters(self.model) From d5ec0b7a40737c4116fc50c478ffa63a60343721 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 7 Dec 2021 13:42:38 +0100 Subject: [PATCH 082/104] changelog --- CHANGELOG.md | 3 +++ pytorch_lightning/trainer/trainer.py | 4 +--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index fa9bd9c0ce71b..0315911e0c7c3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -96,6 +96,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Changed the name of the temporary checkpoint that the `DDPSpawnPlugin` and related plugins save ([#10934](https://github.com/PyTorchLightning/pytorch-lightning/pull/10934)) +- All spawn-based plugins (`DDPSpawnPlugin`, `TPUSpawnPlugin`, etc.) now spawn processes immediately upon calling `Trainer.{fit,validate,test,predict}` ([#10896](https://github.com/PyTorchLightning/pytorch-lightning/pull/10896)) + + ### Deprecated - Deprecated `ClusterEnvironment.master_{address,port}` in favor of `ClusterEnvironment.main_{address,port}` ([#10103](https://github.com/PyTorchLightning/pytorch-lightning/issues/10103)) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 6924a26a69646..cdd7f92900c44 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1228,9 +1228,7 @@ def _log_hyperparams(self) -> None: def _teardown(self): """This is the Trainer's internal teardown, unrelated to the `teardown` hooks in LightningModule and - Callback. - - Those are handled by :meth:`_call_teardown_hook`. + Callback; those are handled by :meth:`_call_teardown_hook`. """ self.training_type_plugin.post_dispatch(self) self.accelerator.teardown() From b2f8347168d97d9de0cc8390def802a9db78caea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 7 Dec 2021 13:51:11 +0100 Subject: [PATCH 083/104] update changelog --- CHANGELOG.md | 2 ++ pytorch_lightning/plugins/training_type/tpu_spawn.py | 1 - 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0315911e0c7c3..afdb0a6171723 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -99,6 +99,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - All spawn-based plugins (`DDPSpawnPlugin`, `TPUSpawnPlugin`, etc.) now spawn processes immediately upon calling `Trainer.{fit,validate,test,predict}` ([#10896](https://github.com/PyTorchLightning/pytorch-lightning/pull/10896)) +- The `setup` and `teardown` hooks/callbacks now run under initialized process group for spawn-based plugins just like their non-spawn counterparts ([#10896](https://github.com/PyTorchLightning/pytorch-lightning/pull/10896)) + ### Deprecated - Deprecated `ClusterEnvironment.master_{address,port}` in favor of `ClusterEnvironment.main_{address,port}` ([#10103](https://github.com/PyTorchLightning/pytorch-lightning/issues/10103)) diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 463936cd4e41d..816ad7cccfa41 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -23,7 +23,6 @@ from torch.utils.data import DataLoader import pytorch_lightning as pl -from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger from pytorch_lightning.overrides import LightningDistributedModule from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO From d8e62187bb82aef5bdf55e149e5ae064384e0186 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 7 Dec 2021 12:52:43 +0000 Subject: [PATCH 084/104] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/trainer/trainer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index cdd7f92900c44..cde8e8889b479 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1228,8 +1228,7 @@ def _log_hyperparams(self) -> None: def _teardown(self): """This is the Trainer's internal teardown, unrelated to the `teardown` hooks in LightningModule and - Callback; those are handled by :meth:`_call_teardown_hook`. - """ + Callback; those are handled by :meth:`_call_teardown_hook`.""" self.training_type_plugin.post_dispatch(self) self.accelerator.teardown() self._data_connector.teardown() From 436572b0c02b259ba5b7ae7bcd020f50bc6c29d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 7 Dec 2021 14:52:22 +0100 Subject: [PATCH 085/104] update changelog --- CHANGELOG.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index afdb0a6171723..786d9d85143d4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -99,7 +99,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - All spawn-based plugins (`DDPSpawnPlugin`, `TPUSpawnPlugin`, etc.) now spawn processes immediately upon calling `Trainer.{fit,validate,test,predict}` ([#10896](https://github.com/PyTorchLightning/pytorch-lightning/pull/10896)) -- The `setup` and `teardown` hooks/callbacks now run under initialized process group for spawn-based plugins just like their non-spawn counterparts ([#10896](https://github.com/PyTorchLightning/pytorch-lightning/pull/10896)) +- The hooks/callbacks `prepare_data`, `setup`, `configure_sharded_model` and `teardown` now run under initialized process group for spawn-based plugins just like their non-spawn counterparts ([#10896](https://github.com/PyTorchLightning/pytorch-lightning/pull/10896)) + ### Deprecated From a3bc1b15a8523197e0309135e198f52b7e496c8e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 7 Dec 2021 18:52:09 +0100 Subject: [PATCH 086/104] Update tests/trainer/test_trainer.py Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com> --- tests/trainer/test_trainer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 5a3bbf5e89074..9d9ff2a77cfbb 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -62,9 +62,10 @@ if _OMEGACONF_AVAILABLE: from omegaconf import OmegaConf -ProcessRaisedException = Exception if _TORCH_GREATER_EQUAL_1_8: from torch.multiprocessing import ProcessRaisedException +else: + ProcessRaisedException = Exception @pytest.mark.parametrize("url_ckpt", [True, False]) From bafd95c7a55f12d11c80d805b16f3e5e84d39c31 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 8 Dec 2021 02:05:46 +0100 Subject: [PATCH 087/104] add clarification comment --- pytorch_lightning/plugins/training_type/tpu_spawn.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 816ad7cccfa41..9ba281cca9d57 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -260,7 +260,10 @@ def _wrapped_function( if self.local_rank == 0: return_queue.put(move_data_to_device(result, "cpu")) + # https://github.com/pytorch/xla/issues/1801#issuecomment-602799542 self.barrier("end-process") + + # Ensure that the rank 0 process is the one exiting last # https://github.com/pytorch/xla/issues/2190#issuecomment-641665358 if self.local_rank == 0: time.sleep(2) @@ -328,11 +331,6 @@ def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_gra def teardown(self) -> None: os.environ.pop("PT_XLA_DEBUG", None) - # https://github.com/pytorch/xla/issues/1801#issuecomment-602799542 - self.barrier("teardown") - # https://github.com/pytorch/xla/issues/2190#issuecomment-641665358 - if self.local_rank == 0: - time.sleep(2) @property def should_rank_save_checkpoint(self) -> bool: From 338605a3b4efc3db124aef5396ed508be7974c14 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 8 Dec 2021 03:03:00 +0100 Subject: [PATCH 088/104] update changelog --- CHANGELOG.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 786d9d85143d4..2b2a9f2c60508 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -96,10 +96,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Changed the name of the temporary checkpoint that the `DDPSpawnPlugin` and related plugins save ([#10934](https://github.com/PyTorchLightning/pytorch-lightning/pull/10934)) -- All spawn-based plugins (`DDPSpawnPlugin`, `TPUSpawnPlugin`, etc.) now spawn processes immediately upon calling `Trainer.{fit,validate,test,predict}` ([#10896](https://github.com/PyTorchLightning/pytorch-lightning/pull/10896)) - - -- The hooks/callbacks `prepare_data`, `setup`, `configure_sharded_model` and `teardown` now run under initialized process group for spawn-based plugins just like their non-spawn counterparts ([#10896](https://github.com/PyTorchLightning/pytorch-lightning/pull/10896)) +- Redesigned process creation for spawn-based plugins (`DDPSpawnPlugin`, `TPUSpawnPlugin`, etc.) ([#10896](https://github.com/PyTorchLightning/pytorch-lightning/pull/10896)) + * All spawn-based plugins now spawn processes immediately upon calling `Trainer.{fit,validate,test,predict}` + * The hooks/callbacks `prepare_data`, `setup`, `configure_sharded_model` and `teardown` now run under initialized process group for spawn-based plugins just like their non-spawn counterparts + * Some configuration errors that were previously raised as `MisconfigurationException`s will now be raised as `ProcessRaisedException` (torch>=1.8) or as `Exception` (torch<1.8) ### Deprecated From c992a55f8924001c3d994b70a54164bed72efef2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 8 Dec 2021 02:04:25 +0000 Subject: [PATCH 089/104] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2b2a9f2c60508..65d3bfe92ac2f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -97,7 +97,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Redesigned process creation for spawn-based plugins (`DDPSpawnPlugin`, `TPUSpawnPlugin`, etc.) ([#10896](https://github.com/PyTorchLightning/pytorch-lightning/pull/10896)) - * All spawn-based plugins now spawn processes immediately upon calling `Trainer.{fit,validate,test,predict}` + * All spawn-based plugins now spawn processes immediately upon calling `Trainer.{fit,validate,test,predict}` * The hooks/callbacks `prepare_data`, `setup`, `configure_sharded_model` and `teardown` now run under initialized process group for spawn-based plugins just like their non-spawn counterparts * Some configuration errors that were previously raised as `MisconfigurationException`s will now be raised as `ProcessRaisedException` (torch>=1.8) or as `Exception` (torch<1.8) From ac1428d81dc67d56db17c314f6536e2fae6bc982 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 8 Dec 2021 03:50:59 +0100 Subject: [PATCH 090/104] skip test that can't run on too old torch version on windows --- tests/trainer/test_trainer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 9d9ff2a77cfbb..2cd86fc71d3c7 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -50,7 +50,7 @@ from pytorch_lightning.utilities import _AcceleratorType, _StrategyType from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException -from pytorch_lightning.utilities.imports import _OMEGACONF_AVAILABLE, _TORCH_GREATER_EQUAL_1_8 +from pytorch_lightning.utilities.imports import _OMEGACONF_AVAILABLE, _IS_WINDOWS, _TORCH_GREATER_EQUAL_1_8 from pytorch_lightning.utilities.seed import seed_everything from tests.base import EvalModelTemplate from tests.helpers import BoringModel, RandomDataset @@ -1514,6 +1514,7 @@ def write_on_batch_end(self, trainer, pl_module, prediction, batch_indices, *arg assert len(predictions) == 8 +@pytest.mark.skipif(_IS_WINDOWS and not _TORCH_GREATER_EQUAL_1_8, reason="torch.distributed support required") @patch("torch.cuda.device_count", return_value=2) @patch("torch.cuda.is_available", return_value=True) @pytest.mark.parametrize("accelerator", ("cpu", "gpu")) From 77ee0ef2cc57f57693abdb4db56d085674c07f2f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 8 Dec 2021 02:52:33 +0000 Subject: [PATCH 091/104] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/trainer/test_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 2cd86fc71d3c7..d5cbd37ac7496 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -50,7 +50,7 @@ from pytorch_lightning.utilities import _AcceleratorType, _StrategyType from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException -from pytorch_lightning.utilities.imports import _OMEGACONF_AVAILABLE, _IS_WINDOWS, _TORCH_GREATER_EQUAL_1_8 +from pytorch_lightning.utilities.imports import _IS_WINDOWS, _OMEGACONF_AVAILABLE, _TORCH_GREATER_EQUAL_1_8 from pytorch_lightning.utilities.seed import seed_everything from tests.base import EvalModelTemplate from tests.helpers import BoringModel, RandomDataset From c7dd23d68ef26ec3bcf075521881c83fc1859794 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 8 Dec 2021 06:55:41 +0100 Subject: [PATCH 092/104] remove todo --- pytorch_lightning/plugins/training_type/tpu_spawn.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 9ba281cca9d57..4050f71fc00fb 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -243,7 +243,6 @@ def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[st } def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Optional[Union[Any, "_SpawnOutput"]]: - # TODO: this todo is unclear, does it still apply? # todo: precision pluging is call in accelerator setup and should be moved if "XLA_USE_BF16" in os.environ: del os.environ["XLA_USE_BF16"] From ec50a5ebe2dd36d663b6c1e7237cedf587b17d40 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 8 Dec 2021 06:56:39 +0100 Subject: [PATCH 093/104] remove deletion of XLA_USE_BF16 env variable --- pytorch_lightning/plugins/training_type/tpu_spawn.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 4050f71fc00fb..3914b37365ca3 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -243,9 +243,6 @@ def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[st } def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Optional[Union[Any, "_SpawnOutput"]]: - # todo: precision pluging is call in accelerator setup and should be moved - if "XLA_USE_BF16" in os.environ: - del os.environ["XLA_USE_BF16"] context = mp.get_context(self.start_method or "fork") return_queue = context.SimpleQueue() xmp.spawn(self._wrapped_function, args=(function, args, kwargs, return_queue), **self.get_mp_spawn_kwargs()) From 82572c8b256223f56acb8cbdb285b32ecc96d7dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 12 Dec 2021 22:54:59 -0500 Subject: [PATCH 094/104] add teardown method --- pytorch_lightning/plugins/precision/precision_plugin.py | 6 ++++++ pytorch_lightning/plugins/precision/tpu_bf16.py | 5 ++++- .../plugins/training_type/training_type_plugin.py | 1 + 3 files changed, 11 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py index 109be55b8dd63..85dead9aa7088 100644 --- a/pytorch_lightning/plugins/precision/precision_plugin.py +++ b/pytorch_lightning/plugins/precision/precision_plugin.py @@ -235,3 +235,9 @@ def predict_step_context(self) -> Generator[None, None, None]: """A contextmanager for the predict step.""" with self.forward_context(): yield + + def teardown(self) -> None: + """This method is called to teardown the training process. + + It is the right place to release memory and free other resources. + """ \ No newline at end of file diff --git a/pytorch_lightning/plugins/precision/tpu_bf16.py b/pytorch_lightning/plugins/precision/tpu_bf16.py index 0cece48ac8057..94254313b85be 100644 --- a/pytorch_lightning/plugins/precision/tpu_bf16.py +++ b/pytorch_lightning/plugins/precision/tpu_bf16.py @@ -28,5 +28,8 @@ class TPUBf16PrecisionPlugin(TPUPrecisionPlugin): def connect( self, model: nn.Module, optimizers: List[Optimizer], lr_schedulers: List[Any] ) -> Tuple[nn.Module, List[Optimizer], List[Any]]: - os.environ["XLA_USE_BF16"] = str(1) + os.environ["XLA_USE_BF16"] = "1" return super().connect(model=model, optimizers=optimizers, lr_schedulers=lr_schedulers) + + def teardown(self) -> None: + os.environ.pop("XLA_USE_BF16", None) diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index fc5de9863665a..ce6e84af3b7e4 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -438,6 +438,7 @@ def teardown(self) -> None: It is the right place to release memory and free other resources. """ + self.precision_plugin.teardown() @classmethod def register_plugins(cls, plugin_registry) -> None: From 752b3820c6414d5eaa42f0e208b81a6299a54861 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 12 Dec 2021 23:00:42 -0500 Subject: [PATCH 095/104] add changelog --- CHANGELOG.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 65d3bfe92ac2f..ea83a139ae094 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -43,6 +43,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added a warning that shows when `max_epochs` in the `Trainer` is not set ([#10700](https://github.com/PyTorchLightning/pytorch-lightning/issues/10700)) +- Added a `PrecisionPlugin.teardown` method ([#????](https://github.com/PyTorchLightning/pytorch-lightning/issues/????)) + + + ### Changed - Raised exception in `init_dist_connection()` when torch distibuted is not available ([#10418](https://github.com/PyTorchLightning/pytorch-lightning/issues/10418)) @@ -274,6 +278,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed TypeError cause failure in `singal_connector` `teardown` method by adding None check ([#10961](https://github.com/PyTorchLightning/pytorch-lightning/pull/10961)) +- Fixed an issue with the `TPUSpawnPlugin` handling the `XLA_USE_BF16` environment variable incorrectly ([#????](https://github.com/PyTorchLightning/pytorch-lightning/pull/????)) + + ## [1.5.4] - 2021-11-30 ### Fixed From 7840727ed94c2fedbd2c0245e28a633ed24195dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 13 Dec 2021 13:38:50 +0100 Subject: [PATCH 096/104] add test --- tests/models/test_tpu.py | 2 -- tests/plugins/precision/__init__.py | 0 .../plugins/precision/test_tpu_bf16_plugin.py | 25 +++++++++++++++++++ 3 files changed, 25 insertions(+), 2 deletions(-) create mode 100644 tests/plugins/precision/__init__.py create mode 100644 tests/plugins/precision/test_tpu_bf16_plugin.py diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py index ea8d430918e3f..ab62c9ebb840a 100644 --- a/tests/models/test_tpu.py +++ b/tests/models/test_tpu.py @@ -122,7 +122,6 @@ def test_model_16bit_tpu_cores_1(tmpdir): model = BoringModel() tpipes.run_model_test(trainer_options, model, on_gpu=False) - assert os.environ.get("XLA_USE_BF16") == str(1), "XLA_USE_BF16 was not set in environment variables" @pytest.mark.parametrize("tpu_core", [1, 5]) @@ -144,7 +143,6 @@ def test_model_16bit_tpu_index(tmpdir, tpu_core): model = BoringModel() tpipes.run_model_test(trainer_options, model, on_gpu=False) assert torch_xla._XLAC._xla_get_default_device() == f"xla:{tpu_core}" - assert os.environ.get("XLA_USE_BF16") == str(1), "XLA_USE_BF16 was not set in environment variables" @RunIf(tpu=True) diff --git a/tests/plugins/precision/__init__.py b/tests/plugins/precision/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/plugins/precision/test_tpu_bf16_plugin.py b/tests/plugins/precision/test_tpu_bf16_plugin.py new file mode 100644 index 0000000000000..7610a22406082 --- /dev/null +++ b/tests/plugins/precision/test_tpu_bf16_plugin.py @@ -0,0 +1,25 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from unittest.mock import Mock + +from pytorch_lightning.plugins import TPUBf16PrecisionPlugin + + +def test_teardown(): + plugin = TPUBf16PrecisionPlugin() + plugin.connect(Mock(), Mock(), Mock()) + assert "XLA_USE_BF16" in os.environ + plugin.teardown() + assert "XLA_USE_BF16" not in os.environ From 5eefe913bc1a66c73c868e859a9fb264205b130a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 13 Dec 2021 13:04:27 +0000 Subject: [PATCH 097/104] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/plugins/precision/precision_plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py index 85dead9aa7088..01b3c303f0681 100644 --- a/pytorch_lightning/plugins/precision/precision_plugin.py +++ b/pytorch_lightning/plugins/precision/precision_plugin.py @@ -240,4 +240,4 @@ def teardown(self) -> None: """This method is called to teardown the training process. It is the right place to release memory and free other resources. - """ \ No newline at end of file + """ From 4e1fd18a4c5e6a4cf9ab5c21d49a3b2d604bf4cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 14 Dec 2021 16:42:11 +0100 Subject: [PATCH 098/104] Update tests/plugins/precision/test_tpu_bf16_plugin.py Co-authored-by: Rohit Gupta --- tests/plugins/precision/test_tpu_bf16_plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/plugins/precision/test_tpu_bf16_plugin.py b/tests/plugins/precision/test_tpu_bf16_plugin.py index 7610a22406082..abf02548fde7d 100644 --- a/tests/plugins/precision/test_tpu_bf16_plugin.py +++ b/tests/plugins/precision/test_tpu_bf16_plugin.py @@ -20,6 +20,6 @@ def test_teardown(): plugin = TPUBf16PrecisionPlugin() plugin.connect(Mock(), Mock(), Mock()) - assert "XLA_USE_BF16" in os.environ + assert os.environ.get("XLA_USE_BF16") == "1" plugin.teardown() assert "XLA_USE_BF16" not in os.environ From 6a7f462c1737be06bba59e0d6f47d330a85f570e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 21 Dec 2021 03:59:24 +0100 Subject: [PATCH 099/104] rm --- pytorch_lightning/plugins/training_type/tpu_spawn.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 62a270a17fd5a..632d28eb9e9f5 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -244,9 +244,6 @@ def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[st } def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Optional[Union[Any, "_SpawnOutput"]]: - # todo: precision pluging is call in accelerator setup and should be moved - if "XLA_USE_BF16" in os.environ: - del os.environ["XLA_USE_BF16"] context = mp.get_context(self.start_method or "fork") return_queue = context.SimpleQueue() xmp.spawn(self._wrapped_function, args=(function, args, kwargs, return_queue), **self.get_mp_spawn_kwargs()) From 756146568dbc6f2533ad87f4459e4aee42035b35 Mon Sep 17 00:00:00 2001 From: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Date: Tue, 21 Dec 2021 14:15:27 +0530 Subject: [PATCH 100/104] Update CHANGELOG.md Co-authored-by: thomas chaton --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a083da05d985a..096c1d17d50f4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -55,7 +55,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `console_kwargs` for `RichProgressBar` to initialize inner Console ([#10875](https://github.com/PyTorchLightning/pytorch-lightning/pull/10875)) -- Added a `PrecisionPlugin.teardown` method ([#????](https://github.com/PyTorchLightning/pytorch-lightning/issues/????)) +- Added a `PrecisionPlugin.teardown` method ([#10990](https://github.com/PyTorchLightning/pytorch-lightning/issues/10990)) From 990063b0c66c440237d3afd4a6cad9bcce09efdd Mon Sep 17 00:00:00 2001 From: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Date: Tue, 21 Dec 2021 14:15:42 +0530 Subject: [PATCH 101/104] Update CHANGELOG.md Co-authored-by: thomas chaton --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 096c1d17d50f4..d0d049528d71e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -336,7 +336,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed double evaluation bug with fault-tolerance enabled where the second call was completely skipped ([#11119](https://github.com/PyTorchLightning/pytorch-lightning/pull/11119)) -- Fixed an issue with the `TPUSpawnPlugin` handling the `XLA_USE_BF16` environment variable incorrectly ([#????](https://github.com/PyTorchLightning/pytorch-lightning/pull/????)) +- Fixed an issue with the `TPUSpawnPlugin` handling the `XLA_USE_BF16` environment variable incorrectly ([#10990](https://github.com/PyTorchLightning/pytorch-lightning/pull/10990)) ## [1.5.6] - 2021-12-15 From ce2be66a023f6152adcda3e7f43f1f1a27adf5de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 21 Dec 2021 15:02:16 +0100 Subject: [PATCH 102/104] call missing super().teardown() --- pytorch_lightning/plugins/training_type/single_device.py | 1 + pytorch_lightning/plugins/training_type/single_tpu.py | 1 + pytorch_lightning/plugins/training_type/tpu_spawn.py | 1 + 3 files changed, 3 insertions(+) diff --git a/pytorch_lightning/plugins/training_type/single_device.py b/pytorch_lightning/plugins/training_type/single_device.py index 480ae623c5a25..99441eee64a7c 100644 --- a/pytorch_lightning/plugins/training_type/single_device.py +++ b/pytorch_lightning/plugins/training_type/single_device.py @@ -86,6 +86,7 @@ def broadcast(self, obj: object, src: int = 0) -> object: return obj def teardown(self) -> None: + super().teardown() if self.on_gpu: # GPU teardown self.lightning_module.cpu() diff --git a/pytorch_lightning/plugins/training_type/single_tpu.py b/pytorch_lightning/plugins/training_type/single_tpu.py index df438e08d2c59..34a079f675488 100644 --- a/pytorch_lightning/plugins/training_type/single_tpu.py +++ b/pytorch_lightning/plugins/training_type/single_tpu.py @@ -84,6 +84,7 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: _PATH) -> None: return self.checkpoint_io.save_checkpoint(checkpoint, filepath) def teardown(self) -> None: + super().teardown() # TPU teardown os.environ.pop("PT_XLA_DEBUG", None) diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 632d28eb9e9f5..112a4d01ad978 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -327,6 +327,7 @@ def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_gra return xm.all_gather(tensor) def teardown(self) -> None: + super().teardown() os.environ.pop("PT_XLA_DEBUG", None) @property From 4c5a480a7f14ca9fe198522fc05dc716e947185d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 21 Dec 2021 17:04:11 +0100 Subject: [PATCH 103/104] remove abstract --- pytorch_lightning/plugins/training_type/training_type_plugin.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index f7a45efbaad66..2440b95a28726 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -437,7 +437,6 @@ def model_sharded_context(self) -> Generator: """ yield - @abstractmethod def teardown(self) -> None: """This method is called to teardown the training process. From 35282349f26a79cb8e256419f5002feeb1619bcf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 22 Dec 2021 04:10:44 +0100 Subject: [PATCH 104/104] reorder --- CHANGELOG.md | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6688c865943eb..2f6d78ce4aaa1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -144,16 +144,17 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). * Renamed the `ParallelPlugin` to `ParallelStrategy` ([#11123](https://github.com/PyTorchLightning/pytorch-lightning/pull/11123)) * Renamed the `DataParallelPlugin` to `DataParallelStrategy` ([#11183](https://github.com/PyTorchLightning/pytorch-lightning/pull/11183)) * Renamed the `DDPPlugin` to `DDPStrategy` ([#11142](https://github.com/PyTorchLightning/pytorch-lightning/pull/11142)) - * Renamed the `DeepSpeedPlugin` to `DeepSpeedStrategy` ([#11194](https://github.com/PyTorchLightning/pytorch-lightning/pull/11194)) - * Renamed the `IPUPlugin` to `IPUStrategy` ([#11193](https://github.com/PyTorchLightning/pytorch-lightning/pull/11193)) - * Renamed the `TPUSpawnPlugin` to `TPUSpawnStrategy` ([#11190](https://github.com/PyTorchLightning/pytorch-lightning/pull/11190)) - * Renamed the `DDPShardedPlugin` to `DDPShardedStrategy` ([#11186](https://github.com/PyTorchLightning/pytorch-lightning/pull/11186)) * Renamed the `DDP2Plugin` to `DDP2Strategy` ([#11185](https://github.com/PyTorchLightning/pytorch-lightning/pull/11185)) - * Renamed the `SingleTPUPlugin` to `SingleTPUStrategy` ([#11182](https://github.com/PyTorchLightning/pytorch-lightning/pull/11182)) - * Renamed the `DDPSpawnPlugin` to `DDPSpawnStrategy` ([#11145](https://github.com/PyTorchLightning/pytorch-lightning/pull/11145)) + * Renamed the `DDPShardedPlugin` to `DDPShardedStrategy` ([#11186](https://github.com/PyTorchLightning/pytorch-lightning/pull/11186)) * Renamed the `DDPFullyShardedPlugin` to `DDPFullyShardedStrategy` ([#11143](https://github.com/PyTorchLightning/pytorch-lightning/pull/11143)) - * Renamed the `SingleDevicePlugin` to `SingleDeviceStrategy` ([#11182](https://github.com/PyTorchLightning/pytorch-lightning/pull/11182)) + * Renamed the `DDPSpawnPlugin` to `DDPSpawnStrategy` ([#11145](https://github.com/PyTorchLightning/pytorch-lightning/pull/11145)) * Renamed the `DDPSpawnShardedPlugin` to `DDPSpawnShardedStrategy` ([#11210](https://github.com/PyTorchLightning/pytorch-lightning/pull/11210)) + * Renamed the `DeepSpeedPlugin` to `DeepSpeedStrategy` ([#11194](https://github.com/PyTorchLightning/pytorch-lightning/pull/11194)) + * Renamed the `HorovodPlugin` to `HorovodStrategy` ([#11195](https://github.com/PyTorchLightning/pytorch-lightning/pull/11195)) + * Renamed the `TPUSpawnPlugin` to `TPUSpawnStrategy` ([#11190](https://github.com/PyTorchLightning/pytorch-lightning/pull/11190)) + * Renamed the `IPUPlugin` to `IPUStrategy` ([#11193](https://github.com/PyTorchLightning/pytorch-lightning/pull/11193)) + * Renamed the `SingleDevicePlugin` to `SingleDeviceStrategy` ([#11182](https://github.com/PyTorchLightning/pytorch-lightning/pull/11182)) + * Renamed the `SingleTPUPlugin` to `SingleTPUStrategy` ([#11182](https://github.com/PyTorchLightning/pytorch-lightning/pull/11182)) - Marked the `ResultCollection`, `ResultMetric`, and `ResultMetricCollection` classes as protected ([#11130](https://github.com/PyTorchLightning/pytorch-lightning/pull/11130))