diff --git a/CHANGELOG.md b/CHANGELOG.md
index 9fa0710eb8da3..2f6d78ce4aaa1 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -55,6 +55,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added `console_kwargs` for `RichProgressBar` to initialize inner Console ([#10875](https://github.com/PyTorchLightning/pytorch-lightning/pull/10875))
 
+- Added a `PrecisionPlugin.teardown` method ([#10990](https://github.com/PyTorchLightning/pytorch-lightning/issues/10990))
+
+
+
 ### Changed
 
 - Raised exception in `init_dist_connection()` when torch distibuted is not available ([#10418](https://github.com/PyTorchLightning/pytorch-lightning/issues/10418))
@@ -140,16 +144,17 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
     * Renamed the `ParallelPlugin` to `ParallelStrategy` ([#11123](https://github.com/PyTorchLightning/pytorch-lightning/pull/11123))
     * Renamed the `DataParallelPlugin` to `DataParallelStrategy` ([#11183](https://github.com/PyTorchLightning/pytorch-lightning/pull/11183))
     * Renamed the `DDPPlugin` to `DDPStrategy` ([#11142](https://github.com/PyTorchLightning/pytorch-lightning/pull/11142))
-    * Renamed the `DeepSpeedPlugin` to `DeepSpeedStrategy` ([#11194](https://github.com/PyTorchLightning/pytorch-lightning/pull/11194))
-    * Renamed the `IPUPlugin` to `IPUStrategy` ([#11193](https://github.com/PyTorchLightning/pytorch-lightning/pull/11193))
-    * Renamed the `TPUSpawnPlugin` to `TPUSpawnStrategy` ([#11190](https://github.com/PyTorchLightning/pytorch-lightning/pull/11190))
-    * Renamed the `DDPShardedPlugin` to `DDPShardedStrategy` ([#11186](https://github.com/PyTorchLightning/pytorch-lightning/pull/11186))
     * Renamed the `DDP2Plugin` to `DDP2Strategy` ([#11185](https://github.com/PyTorchLightning/pytorch-lightning/pull/11185))
-    * Renamed the `SingleTPUPlugin` to `SingleTPUStrategy` ([#11182](https://github.com/PyTorchLightning/pytorch-lightning/pull/11182))
-    * Renamed the `DDPSpawnPlugin` to `DDPSpawnStrategy` ([#11145](https://github.com/PyTorchLightning/pytorch-lightning/pull/11145))
+    * Renamed the `DDPShardedPlugin` to `DDPShardedStrategy` ([#11186](https://github.com/PyTorchLightning/pytorch-lightning/pull/11186))
     * Renamed the `DDPFullyShardedPlugin` to `DDPFullyShardedStrategy` ([#11143](https://github.com/PyTorchLightning/pytorch-lightning/pull/11143))
-    * Renamed the `SingleDevicePlugin` to `SingleDeviceStrategy` ([#11182](https://github.com/PyTorchLightning/pytorch-lightning/pull/11182))
+    * Renamed the `DDPSpawnPlugin` to `DDPSpawnStrategy` ([#11145](https://github.com/PyTorchLightning/pytorch-lightning/pull/11145))
     * Renamed the `DDPSpawnShardedPlugin` to `DDPSpawnShardedStrategy` ([#11210](https://github.com/PyTorchLightning/pytorch-lightning/pull/11210))
+    * Renamed the `DeepSpeedPlugin` to `DeepSpeedStrategy` ([#11194](https://github.com/PyTorchLightning/pytorch-lightning/pull/11194))
+    * Renamed the `HorovodPlugin` to `HorovodStrategy` ([#11195](https://github.com/PyTorchLightning/pytorch-lightning/pull/11195))
+    * Renamed the `TPUSpawnPlugin` to `TPUSpawnStrategy` ([#11190](https://github.com/PyTorchLightning/pytorch-lightning/pull/11190))
+    * Renamed the `IPUPlugin` to `IPUStrategy` ([#11193](https://github.com/PyTorchLightning/pytorch-lightning/pull/11193))
+    * Renamed the `SingleDevicePlugin` to `SingleDeviceStrategy` ([#11182](https://github.com/PyTorchLightning/pytorch-lightning/pull/11182))
+    * Renamed the `SingleTPUPlugin` to `SingleTPUStrategy` ([#11182](https://github.com/PyTorchLightning/pytorch-lightning/pull/11182))
 
 - Marked the `ResultCollection`, `ResultMetric`, and `ResultMetricCollection` classes as protected ([#11130](https://github.com/PyTorchLightning/pytorch-lightning/pull/11130))
 
@@ -337,6 +342,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed double evaluation bug with fault-tolerance enabled where the second call was completely skipped ([#11119](https://github.com/PyTorchLightning/pytorch-lightning/pull/11119))
 
+- Fixed an issue with the `TPUSpawnPlugin` handling the `XLA_USE_BF16` environment variable incorrectly ([#10990](https://github.com/PyTorchLightning/pytorch-lightning/pull/10990))
+
+
 
 
 ## [1.5.7] - 2021-12-21
diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py
index bddc8793f51e2..e875fe51f19e7 100644
--- a/pytorch_lightning/plugins/precision/precision_plugin.py
+++ b/pytorch_lightning/plugins/precision/precision_plugin.py
@@ -236,3 +236,9 @@ def predict_step_context(self) -> Generator[None, None, None]:
         """A contextmanager for the predict step."""
         with self.forward_context():
             yield
+
+    def teardown(self) -> None:
+        """This method is called to teardown the training process.
+
+        It is the right place to release memory and free other resources.
+        """
diff --git a/pytorch_lightning/plugins/precision/tpu_bf16.py b/pytorch_lightning/plugins/precision/tpu_bf16.py
index 0cece48ac8057..94254313b85be 100644
--- a/pytorch_lightning/plugins/precision/tpu_bf16.py
+++ b/pytorch_lightning/plugins/precision/tpu_bf16.py
@@ -28,5 +28,8 @@ class TPUBf16PrecisionPlugin(TPUPrecisionPlugin):
     def connect(
         self, model: nn.Module, optimizers: List[Optimizer], lr_schedulers: List[Any]
     ) -> Tuple[nn.Module, List[Optimizer], List[Any]]:
-        os.environ["XLA_USE_BF16"] = str(1)
+        os.environ["XLA_USE_BF16"] = "1"
         return super().connect(model=model, optimizers=optimizers, lr_schedulers=lr_schedulers)
+
+    def teardown(self) -> None:
+        os.environ.pop("XLA_USE_BF16", None)
diff --git a/pytorch_lightning/plugins/training_type/single_device.py b/pytorch_lightning/plugins/training_type/single_device.py
index b077a448045b9..95458ba642d97 100644
--- a/pytorch_lightning/plugins/training_type/single_device.py
+++ b/pytorch_lightning/plugins/training_type/single_device.py
@@ -86,6 +86,7 @@ def broadcast(self, obj: object, src: int = 0) -> object:
         return obj
 
     def teardown(self) -> None:
+        super().teardown()
         if self.on_gpu:
             # GPU teardown
             self.lightning_module.cpu()
diff --git a/pytorch_lightning/plugins/training_type/single_tpu.py b/pytorch_lightning/plugins/training_type/single_tpu.py
index 7753e331c0577..a21b168925fcb 100644
--- a/pytorch_lightning/plugins/training_type/single_tpu.py
+++ b/pytorch_lightning/plugins/training_type/single_tpu.py
@@ -74,6 +74,7 @@ def model_to_device(self) -> None:
         self.model.to(self.root_device)
 
     def teardown(self) -> None:
+        super().teardown()
         # TPU teardown
         os.environ.pop("PT_XLA_DEBUG", None)
 
diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py
index 2e6c579e84764..c3e6e5d623291 100644
--- a/pytorch_lightning/plugins/training_type/tpu_spawn.py
+++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -244,9 +244,6 @@ def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[st
         }
 
     def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Optional[Union[Any, "_SpawnOutput"]]:
-        # todo: precision pluging is call in accelerator setup and should be moved
-        if "XLA_USE_BF16" in os.environ:
-            del os.environ["XLA_USE_BF16"]
         context = mp.get_context(self.start_method or "fork")
         return_queue = context.SimpleQueue()
         xmp.spawn(self._wrapped_function, args=(function, args, kwargs, return_queue), **self.get_mp_spawn_kwargs())
@@ -340,6 +337,7 @@ def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_gra
         return xm.all_gather(tensor)
 
     def teardown(self) -> None:
+        super().teardown()
         os.environ.pop("PT_XLA_DEBUG", None)
 
     @classmethod
diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py
index 77611233e79d8..75561a34ca056 100644
--- a/pytorch_lightning/plugins/training_type/training_type_plugin.py
+++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -437,13 +437,13 @@ def model_sharded_context(self) -> Generator:
         """
         yield
 
-    @abstractmethod
     def teardown(self) -> None:
         """This method is called to teardown the training process.
 
         It is the right place to release memory and free other resources.
         """
         self._move_optimizer_state(torch.device("cpu"))
+        self.precision_plugin.teardown()
 
     @classmethod
     def register_plugins(cls, plugin_registry) -> None:
diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py
index 4e59761f2e934..833d6dd316bf5 100644
--- a/tests/models/test_tpu.py
+++ b/tests/models/test_tpu.py
@@ -122,7 +122,6 @@ def test_model_16bit_tpu_cores_1(tmpdir):
 
     model = BoringModel()
     tpipes.run_model_test(trainer_options, model, on_gpu=False)
-    assert os.environ.get("XLA_USE_BF16") == str(1), "XLA_USE_BF16 was not set in environment variables"
 
 
 @pytest.mark.parametrize("tpu_core", [1, 5])
@@ -144,7 +143,6 @@ def test_model_16bit_tpu_index(tmpdir, tpu_core):
     model = BoringModel()
     tpipes.run_model_test(trainer_options, model, on_gpu=False)
     assert torch_xla._XLAC._xla_get_default_device() == f"xla:{tpu_core}"
-    assert os.environ.get("XLA_USE_BF16") == str(1), "XLA_USE_BF16 was not set in environment variables"
 
 
 @RunIf(tpu=True)
diff --git a/tests/plugins/precision/__init__.py b/tests/plugins/precision/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/tests/plugins/precision/test_tpu_bf16_plugin.py b/tests/plugins/precision/test_tpu_bf16_plugin.py
new file mode 100644
index 0000000000000..abf02548fde7d
--- /dev/null
+++ b/tests/plugins/precision/test_tpu_bf16_plugin.py
@@ -0,0 +1,25 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+from unittest.mock import Mock
+
+from pytorch_lightning.plugins import TPUBf16PrecisionPlugin
+
+
+def test_teardown():
+    plugin = TPUBf16PrecisionPlugin()
+    plugin.connect(Mock(), Mock(), Mock())
+    assert os.environ.get("XLA_USE_BF16") == "1"
+    plugin.teardown()
+    assert "XLA_USE_BF16" not in os.environ
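
For context, the sketch below illustrates the teardown chain this patch introduces: the training type plugin's `teardown` now also calls `teardown` on its precision plugin, so `TPUBf16PrecisionPlugin` can undo the `XLA_USE_BF16` change it made in `connect`, instead of `TPUSpawnPlugin.spawn` deleting the variable ahead of time. The classes here are simplified stand-ins written for illustration, not the actual Lightning implementations.

```python
import os


class PrecisionPlugin:
    """Simplified stand-in for Lightning's PrecisionPlugin."""

    def connect(self, model, optimizers, lr_schedulers):
        return model, optimizers, lr_schedulers

    def teardown(self) -> None:
        """Called at the end of training; release memory and other resources."""


class TPUBf16PrecisionPlugin(PrecisionPlugin):
    """Sets XLA_USE_BF16 when connected and removes it again on teardown."""

    def connect(self, model, optimizers, lr_schedulers):
        os.environ["XLA_USE_BF16"] = "1"
        return super().connect(model, optimizers, lr_schedulers)

    def teardown(self) -> None:
        os.environ.pop("XLA_USE_BF16", None)


class TrainingTypePlugin:
    """The strategy owns a precision plugin and tears it down with itself."""

    def __init__(self, precision_plugin: PrecisionPlugin) -> None:
        self.precision_plugin = precision_plugin

    def teardown(self) -> None:
        # in Lightning this also moves optimizer state back to the CPU, etc.
        self.precision_plugin.teardown()


strategy = TrainingTypePlugin(TPUBf16PrecisionPlugin())
strategy.precision_plugin.connect(None, [], [])
assert os.environ.get("XLA_USE_BF16") == "1"

strategy.teardown()
assert "XLA_USE_BF16" not in os.environ
```

The design point is symmetry: the plugin that mutates the environment in `connect` is now the one responsible for restoring it, which is why the ad-hoc cleanup in `TPUSpawnPlugin.spawn` and the environment-variable asserts in `tests/models/test_tpu.py` could be removed.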