Fix BF16 teardown for TPU precision plugin (#10990)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com>
Co-authored-by: thomas chaton <thomas@grid.ai>
6 people committed Dec 22, 2021
1 parent 235efb3 commit ba8e7cd
Showing 10 changed files with 54 additions and 14 deletions.
22 changes: 15 additions & 7 deletions CHANGELOG.md
@@ -55,6 +55,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `console_kwargs` for `RichProgressBar` to initialize inner Console ([#10875](https://github.com/PyTorchLightning/pytorch-lightning/pull/10875))


- Added a `PrecisionPlugin.teardown` method ([#10990](https://github.com/PyTorchLightning/pytorch-lightning/issues/10990))



### Changed

- Raised exception in `init_dist_connection()` when torch distributed is not available ([#10418](https://github.com/PyTorchLightning/pytorch-lightning/issues/10418))
@@ -140,16 +144,17 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Renamed the `ParallelPlugin` to `ParallelStrategy` ([#11123](https://github.com/PyTorchLightning/pytorch-lightning/pull/11123))
* Renamed the `DataParallelPlugin` to `DataParallelStrategy` ([#11183](https://github.com/PyTorchLightning/pytorch-lightning/pull/11183))
* Renamed the `DDPPlugin` to `DDPStrategy` ([#11142](https://github.com/PyTorchLightning/pytorch-lightning/pull/11142))
* Renamed the `DeepSpeedPlugin` to `DeepSpeedStrategy` ([#11194](https://github.com/PyTorchLightning/pytorch-lightning/pull/11194))
* Renamed the `IPUPlugin` to `IPUStrategy` ([#11193](https://github.com/PyTorchLightning/pytorch-lightning/pull/11193))
* Renamed the `TPUSpawnPlugin` to `TPUSpawnStrategy` ([#11190](https://github.com/PyTorchLightning/pytorch-lightning/pull/11190))
* Renamed the `DDPShardedPlugin` to `DDPShardedStrategy` ([#11186](https://github.com/PyTorchLightning/pytorch-lightning/pull/11186))
* Renamed the `DDP2Plugin` to `DDP2Strategy` ([#11185](https://github.com/PyTorchLightning/pytorch-lightning/pull/11185))
* Renamed the `SingleTPUPlugin` to `SingleTPUStrategy` ([#11182](https://github.com/PyTorchLightning/pytorch-lightning/pull/11182))
* Renamed the `DDPSpawnPlugin` to `DDPSpawnStrategy` ([#11145](https://github.com/PyTorchLightning/pytorch-lightning/pull/11145))
* Renamed the `DDPShardedPlugin` to `DDPShardedStrategy` ([#11186](https://github.com/PyTorchLightning/pytorch-lightning/pull/11186))
* Renamed the `DDPFullyShardedPlugin` to `DDPFullyShardedStrategy` ([#11143](https://github.com/PyTorchLightning/pytorch-lightning/pull/11143))
* Renamed the `SingleDevicePlugin` to `SingleDeviceStrategy` ([#11182](https://github.com/PyTorchLightning/pytorch-lightning/pull/11182))
* Renamed the `DDPSpawnPlugin` to `DDPSpawnStrategy` ([#11145](https://github.com/PyTorchLightning/pytorch-lightning/pull/11145))
* Renamed the `DDPSpawnShardedPlugin` to `DDPSpawnShardedStrategy` ([#11210](https://github.com/PyTorchLightning/pytorch-lightning/pull/11210))
* Renamed the `DeepSpeedPlugin` to `DeepSpeedStrategy` ([#11194](https://github.com/PyTorchLightning/pytorch-lightning/pull/11194))
* Renamed the `HorovodPlugin` to `HorovodStrategy` ([#11195](https://github.com/PyTorchLightning/pytorch-lightning/pull/11195))
* Renamed the `TPUSpawnPlugin` to `TPUSpawnStrategy` ([#11190](https://github.com/PyTorchLightning/pytorch-lightning/pull/11190))
* Renamed the `IPUPlugin` to `IPUStrategy` ([#11193](https://github.com/PyTorchLightning/pytorch-lightning/pull/11193))
* Renamed the `SingleDevicePlugin` to `SingleDeviceStrategy` ([#11182](https://github.com/PyTorchLightning/pytorch-lightning/pull/11182))
* Renamed the `SingleTPUPlugin` to `SingleTPUStrategy` ([#11182](https://github.com/PyTorchLightning/pytorch-lightning/pull/11182))


- Marked the `ResultCollection`, `ResultMetric`, and `ResultMetricCollection` classes as protected ([#11130](https://github.com/PyTorchLightning/pytorch-lightning/pull/11130))
@@ -337,6 +342,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed double evaluation bug with fault-tolerance enabled where the second call was completely skipped ([#11119](https://github.com/PyTorchLightning/pytorch-lightning/pull/11119))


- Fixed an issue with the `TPUSpawnPlugin` handling the `XLA_USE_BF16` environment variable incorrectly ([#10990](https://github.com/PyTorchLightning/pytorch-lightning/pull/10990))



## [1.5.7] - 2021-12-21

6 changes: 6 additions & 0 deletions pytorch_lightning/plugins/precision/precision_plugin.py
@@ -236,3 +236,9 @@ def predict_step_context(self) -> Generator[None, None, None]:
"""A contextmanager for the predict step."""
with self.forward_context():
yield

def teardown(self) -> None:
"""This method is called to tear down the training process.
It is the right place to release memory and free other resources.
"""
5 changes: 4 additions & 1 deletion pytorch_lightning/plugins/precision/tpu_bf16.py
@@ -28,5 +28,8 @@ class TPUBf16PrecisionPlugin(TPUPrecisionPlugin):
def connect(
self, model: nn.Module, optimizers: List[Optimizer], lr_schedulers: List[Any]
) -> Tuple[nn.Module, List[Optimizer], List[Any]]:
os.environ["XLA_USE_BF16"] = str(1)
os.environ["XLA_USE_BF16"] = "1"
return super().connect(model=model, optimizers=optimizers, lr_schedulers=lr_schedulers)

def teardown(self) -> None:
os.environ.pop("XLA_USE_BF16", None)
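Taken together, `connect` now exports the flag and `teardown` retracts it. A rough lifecycle sketch that exercises only the environment-variable handling, so it assumes no TPU is attached (mirroring the new test added at the bottom of this diff):

```python
import os
from unittest.mock import Mock

from pytorch_lightning.plugins import TPUBf16PrecisionPlugin

plugin = TPUBf16PrecisionPlugin()
plugin.connect(Mock(), Mock(), Mock())  # exports XLA_USE_BF16=1 for the XLA runtime
assert os.environ.get("XLA_USE_BF16") == "1"

plugin.teardown()  # removes the flag so later runs in the same process are unaffected
plugin.teardown()  # safe to call twice: os.environ.pop(..., None) tolerates a missing key
assert "XLA_USE_BF16" not in os.environ
```

Using `os.environ.pop("XLA_USE_BF16", None)` rather than `del` also avoids a `KeyError` when the variable was never set, for example when teardown runs after a failed setup.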
1 change: 1 addition & 0 deletions pytorch_lightning/plugins/training_type/single_device.py
@@ -86,6 +86,7 @@ def broadcast(self, obj: object, src: int = 0) -> object:
return obj

def teardown(self) -> None:
super().teardown()
if self.on_gpu:
# GPU teardown
self.lightning_module.cpu()
1 change: 1 addition & 0 deletions pytorch_lightning/plugins/training_type/single_tpu.py
@@ -74,6 +74,7 @@ def model_to_device(self) -> None:
self.model.to(self.root_device)

def teardown(self) -> None:
super().teardown()
# TPU teardown
os.environ.pop("PT_XLA_DEBUG", None)

4 changes: 1 addition & 3 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -244,9 +244,6 @@ def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[st
}

def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Optional[Union[Any, "_SpawnOutput"]]:
# todo: precision plugin is called in accelerator setup and should be moved
if "XLA_USE_BF16" in os.environ:
del os.environ["XLA_USE_BF16"]
context = mp.get_context(self.start_method or "fork")
return_queue = context.SimpleQueue()
xmp.spawn(self._wrapped_function, args=(function, args, kwargs, return_queue), **self.get_mp_spawn_kwargs())
@@ -340,6 +337,7 @@ def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_gra
return xm.all_gather(tensor)

def teardown(self) -> None:
super().teardown()
os.environ.pop("PT_XLA_DEBUG", None)

@classmethod
@@ -437,13 +437,13 @@ def model_sharded_context(self) -> Generator:
"""
yield

@abstractmethod
def teardown(self) -> None:
"""This method is called to tear down the training process.
It is the right place to release memory and free other resources.
"""
self._move_optimizer_state(torch.device("cpu"))
self.precision_plugin.teardown()

@classmethod
def register_plugins(cls, plugin_registry) -> None:
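Putting the pieces together, the strategy-level teardown now cascades into the precision plugin, which is what guarantees the flag is cleared at the end of a run. A self-contained sketch of that cascade using hypothetical stand-in classes (the real implementations are in the hunks above, not these stubs):

```python
import os


class _StubStrategy:
    """Hypothetical stand-in for the training type plugin base shown above."""

    def __init__(self, precision_plugin) -> None:
        self.precision_plugin = precision_plugin

    def teardown(self) -> None:
        # the base teardown now also tears down the attached precision plugin
        self.precision_plugin.teardown()


class _StubBf16Precision:
    """Hypothetical stand-in for TPUBf16PrecisionPlugin."""

    def teardown(self) -> None:
        os.environ.pop("XLA_USE_BF16", None)


os.environ["XLA_USE_BF16"] = "1"
_StubStrategy(_StubBf16Precision()).teardown()
assert "XLA_USE_BF16" not in os.environ  # the flag no longer leaks past the run
```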
2 changes: 0 additions & 2 deletions tests/models/test_tpu.py
@@ -122,7 +122,6 @@ def test_model_16bit_tpu_cores_1(tmpdir):

model = BoringModel()
tpipes.run_model_test(trainer_options, model, on_gpu=False)
assert os.environ.get("XLA_USE_BF16") == str(1), "XLA_USE_BF16 was not set in environment variables"


@pytest.mark.parametrize("tpu_core", [1, 5])
@@ -144,7 +143,6 @@ def test_model_16bit_tpu_index(tmpdir, tpu_core):
model = BoringModel()
tpipes.run_model_test(trainer_options, model, on_gpu=False)
assert torch_xla._XLAC._xla_get_default_device() == f"xla:{tpu_core}"
assert os.environ.get("XLA_USE_BF16") == str(1), "XLA_USE_BF16 was not set in environment variables"


@RunIf(tpu=True)
Empty file.
25 changes: 25 additions & 0 deletions tests/plugins/precision/test_tpu_bf16_plugin.py
@@ -0,0 +1,25 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from unittest.mock import Mock

from pytorch_lightning.plugins import TPUBf16PrecisionPlugin


def test_teardown():
plugin = TPUBf16PrecisionPlugin()
plugin.connect(Mock(), Mock(), Mock())
assert os.environ.get("XLA_USE_BF16") == "1"
plugin.teardown()
assert "XLA_USE_BF16" not in os.environ
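A stricter variant of this test (not in this commit, just a sketch reusing the imports above) could use pytest's `monkeypatch` fixture so the assertions never clobber an `XLA_USE_BF16` value already set in the developer's shell:

```python
def test_teardown_isolated(monkeypatch):
    # start from a clean slate regardless of the surrounding environment;
    # monkeypatch restores any pre-existing value after the test finishes
    monkeypatch.delenv("XLA_USE_BF16", raising=False)
    plugin = TPUBf16PrecisionPlugin()
    plugin.connect(Mock(), Mock(), Mock())
    assert os.environ.get("XLA_USE_BF16") == "1"
    plugin.teardown()
    assert "XLA_USE_BF16" not in os.environ
```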
