Fix BF16 teardown for TPU precision plugin #10990

Merged
131 commits merged on Dec 22, 2021
Commits (131)
d05363f
improve spawn queue
awaelchli Oct 20, 2021
d650e26
clean up
awaelchli Oct 20, 2021
5fda23a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 20, 2021
d6b4a34
Merge branch 'master' into feature/simple-spawn
awaelchli Nov 30, 2021
bcfb853
fix
awaelchli Nov 30, 2021
97b4bf6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 30, 2021
38b3a54
rename
awaelchli Nov 30, 2021
955b6c8
delete dead code
awaelchli Nov 30, 2021
13393e8
Merge remote-tracking branch 'origin/feature/simple-spawn' into featu…
awaelchli Nov 30, 2021
f3216b2
clean up
awaelchli Nov 30, 2021
2d00231
update lite
awaelchli Nov 30, 2021
7aa3646
retain the queue interface in hooks
awaelchli Nov 30, 2021
fb0c0d8
update tests
awaelchli Nov 30, 2021
1bc59ae
Merge branch 'master' into feature/simple-spawn
awaelchli Nov 30, 2021
7e6c75e
_notebooks
awaelchli Nov 30, 2021
b7efc50
reset notebooks
awaelchli Nov 30, 2021
84ca8b4
avoid circular import
awaelchli Nov 30, 2021
965c724
fix unused imports
awaelchli Nov 30, 2021
1aae8dd
reset debugging script
awaelchli Nov 30, 2021
4b998db
typing _ExtraQueue
awaelchli Nov 30, 2021
5871a4b
bring changes to tpu_spawn plugin
awaelchli Nov 30, 2021
aa76840
unify
awaelchli Nov 30, 2021
37f9db9
remove dead code
awaelchli Nov 30, 2021
d68cb35
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 30, 2021
dd80be9
remove queue from tpu spawn
awaelchli Nov 30, 2021
f97eee8
type annotation for new_process
awaelchli Nov 30, 2021
ad61d27
Merge remote-tracking branch 'origin/feature/simple-spawn' into refac…
awaelchli Nov 30, 2021
459121e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 30, 2021
72535ff
unused imports
awaelchli Nov 30, 2021
3095da9
Merge remote-tracking branch 'origin/feature/simple-spawn' into refac…
awaelchli Nov 30, 2021
61192df
move check
awaelchli Nov 30, 2021
801f529
revert
awaelchli Nov 30, 2021
1cd258b
collect results on tpu
awaelchli Nov 30, 2021
ae6019e
Merge branch 'master' into refactor/spawn/simple-spawn
awaelchli Nov 30, 2021
10ecbfd
rename
awaelchli Nov 30, 2021
ebba63f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 30, 2021
d7df4d9
fix merge errors
awaelchli Nov 30, 2021
4c547aa
fix merge errors
awaelchli Nov 30, 2021
e4e2a77
re-add clean_logger
awaelchli Dec 1, 2021
86e43b2
Merge branch 'master' into refactor/spawn/simple-spawn
awaelchli Dec 1, 2021
acac29d
fix typing
awaelchli Dec 1, 2021
0ae457a
Merge branch 'master' into refactor/spawn/simple-spawn
awaelchli Dec 1, 2021
880c8fc
changelog entries
awaelchli Dec 1, 2021
5eeb02a
Merge branch 'master' into refactor/spawn/simple-spawn
awaelchli Dec 1, 2021
7520adc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 1, 2021
96f2749
rename _ExtraQueue -> _FakeQueue
awaelchli Dec 1, 2021
65d183c
missing typing updates
awaelchli Dec 1, 2021
8c4e2e4
Introducing NamedTuple for spawn output typing
awaelchli Dec 1, 2021
213b447
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 1, 2021
de4617f
remove post_dispatch
awaelchli Dec 2, 2021
815172e
step 1
awaelchli Dec 2, 2021
be735bd
update flow
awaelchli Dec 2, 2021
2879ccb
fix it
awaelchli Dec 2, 2021
ace196e
jackpot!
awaelchli Dec 2, 2021
4ff41a9
Merge branch 'master' into refactor/spawn/dispatch
awaelchli Dec 2, 2021
34a889a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 2, 2021
ad3f39d
update sharded and tests
awaelchli Dec 2, 2021
c897a20
pull down spawn call
awaelchli Dec 2, 2021
90054cf
simplify test
awaelchli Dec 2, 2021
009abfa
attach model as early as possible
awaelchli Dec 2, 2021
376e4fe
demonstrate which tests fails
awaelchli Dec 2, 2021
de1811e
set module
awaelchli Dec 3, 2021
ef61a0b
update exception
awaelchli Dec 3, 2021
809014a
imports
awaelchli Dec 3, 2021
440b639
transfer trainer state
awaelchli Dec 3, 2021
ab5559e
fix problem with getqueue
awaelchli Dec 3, 2021
f4f1269
deprecation calls don't come through ddp_spawn
awaelchli Dec 3, 2021
b30c352
prepare data only gets called on rank 0
awaelchli Dec 3, 2021
5434ae5
import
awaelchli Dec 3, 2021
24f05f1
update test
awaelchli Dec 3, 2021
3959955
update exception
awaelchli Dec 3, 2021
f491abe
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 3, 2021
0c808ce
adapt tpu spawn
awaelchli Dec 3, 2021
d6dd343
imports
awaelchli Dec 3, 2021
63e6e21
Merge remote-tracking branch 'origin/refactor/spawn/dispatch' into re…
awaelchli Dec 3, 2021
15dabb8
Merge branch 'master' into refactor/spawn/dispatch
awaelchli Dec 3, 2021
b047687
update
awaelchli Dec 3, 2021
c524e52
add missing arg
awaelchli Dec 3, 2021
223e7aa
fix exception import on torch < 1.8
awaelchli Dec 3, 2021
ed309d6
debug
awaelchli Dec 3, 2021
12eed61
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 3, 2021
9401e66
Merge branch 'master' into refactor/spawn/dispatch
awaelchli Dec 3, 2021
be73261
debug tpu
awaelchli Dec 3, 2021
c71fc57
fix docs
awaelchli Dec 3, 2021
2ed6333
fix teardown being called twice
awaelchli Dec 3, 2021
2a8b9b4
revert a sate check
awaelchli Dec 3, 2021
5335664
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 3, 2021
5c7e159
Merge remote-tracking branch 'origin/refactor/spawn/dispatch' into re…
awaelchli Dec 3, 2021
93cfaf8
fix
awaelchli Dec 3, 2021
26408b8
Merge branch 'master' into refactor/spawn/dispatch
awaelchli Dec 4, 2021
70b332d
Merge branch 'master' into refactor/spawn/dispatch
awaelchli Dec 6, 2021
d9669c7
Merge branch 'master' into refactor/spawn/dispatch
awaelchli Dec 6, 2021
3663bd7
Merge branch 'master' into refactor/spawn/dispatch
awaelchli Dec 6, 2021
3d81c11
Merge branch 'master' into refactor/spawn/dispatch
awaelchli Dec 6, 2021
fb47802
Merge branch 'master' into refactor/spawn/dispatch
awaelchli Dec 6, 2021
dde5a3a
reset bug report model
awaelchli Dec 6, 2021
77329b2
fix merge error
awaelchli Dec 6, 2021
eb05fc9
barrier clean ups
awaelchli Dec 7, 2021
dbcb76c
update comments in trainer
awaelchli Dec 7, 2021
ed0defa
unused import
awaelchli Dec 7, 2021
79975f2
debug
awaelchli Dec 7, 2021
d5ec0b7
changelog
awaelchli Dec 7, 2021
b2f8347
update changelog
awaelchli Dec 7, 2021
d8e6218
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 7, 2021
436572b
update changelog
awaelchli Dec 7, 2021
a3bc1b1
Update tests/trainer/test_trainer.py
awaelchli Dec 7, 2021
b2ce8eb
Merge remote-tracking branch 'origin/refactor/spawn/dispatch' into re…
awaelchli Dec 7, 2021
bafd95c
add clarification comment
awaelchli Dec 8, 2021
338605a
update changelog
awaelchli Dec 8, 2021
c992a55
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 8, 2021
ac1428d
skip test that can't run on too old torch version on windows
awaelchli Dec 8, 2021
77ee0ef
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 8, 2021
c7dd23d
remove todo
awaelchli Dec 8, 2021
ec50a5e
remove deletion of XLA_USE_BF16 env variable
awaelchli Dec 8, 2021
82572c8
add teardown method
awaelchli Dec 13, 2021
752b382
add changelog
awaelchli Dec 13, 2021
7840727
add test
awaelchli Dec 13, 2021
b0a7b48
Merge branch 'master' into bugfix/bf16-env-variable
awaelchli Dec 13, 2021
5eefe91
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 13, 2021
4e1fd18
Update tests/plugins/precision/test_tpu_bf16_plugin.py
awaelchli Dec 14, 2021
4599c33
Merge branch 'master' into bugfix/bf16-env-variable
awaelchli Dec 14, 2021
a891411
Merge branch 'master' into bugfix/bf16-env-variable
awaelchli Dec 21, 2021
6a7f462
rm
awaelchli Dec 21, 2021
7561465
Update CHANGELOG.md
kaushikb11 Dec 21, 2021
990063b
Update CHANGELOG.md
kaushikb11 Dec 21, 2021
ce2be66
call missing super().teardown()
awaelchli Dec 21, 2021
a93a15f
Merge branch 'master' into bugfix/bf16-env-variable
awaelchli Dec 21, 2021
4c5a480
remove abstract
awaelchli Dec 21, 2021
ac480f4
Merge branch 'master' into bugfix/bf16-env-variable
awaelchli Dec 21, 2021
db5c6cb
Merge branch 'master' into bugfix/bf16-env-variable
awaelchli Dec 22, 2021
3528234
reorder
awaelchli Dec 22, 2021
22 changes: 15 additions & 7 deletions CHANGELOG.md
@@ -55,6 +55,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `console_kwargs` for `RichProgressBar` to initialize inner Console ([#10875](https://github.com/PyTorchLightning/pytorch-lightning/pull/10875))


- Added a `PrecisionPlugin.teardown` method ([#10990](https://github.com/PyTorchLightning/pytorch-lightning/issues/10990))



### Changed

- Raised exception in `init_dist_connection()` when torch distibuted is not available ([#10418](https://github.com/PyTorchLightning/pytorch-lightning/issues/10418))
@@ -140,16 +144,17 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Renamed the `ParallelPlugin` to `ParallelStrategy` ([#11123](https://github.com/PyTorchLightning/pytorch-lightning/pull/11123))
* Renamed the `DataParallelPlugin` to `DataParallelStrategy` ([#11183](https://github.com/PyTorchLightning/pytorch-lightning/pull/11183))
* Renamed the `DDPPlugin` to `DDPStrategy` ([#11142](https://github.com/PyTorchLightning/pytorch-lightning/pull/11142))
* Renamed the `DeepSpeedPlugin` to `DeepSpeedStrategy` ([#11194](https://github.com/PyTorchLightning/pytorch-lightning/pull/11194))
* Renamed the `IPUPlugin` to `IPUStrategy` ([#11193](https://github.com/PyTorchLightning/pytorch-lightning/pull/11193))
* Renamed the `TPUSpawnPlugin` to `TPUSpawnStrategy` ([#11190](https://github.com/PyTorchLightning/pytorch-lightning/pull/11190))
* Renamed the `DDPShardedPlugin` to `DDPShardedStrategy` ([#11186](https://github.com/PyTorchLightning/pytorch-lightning/pull/11186))
* Renamed the `DDP2Plugin` to `DDP2Strategy` ([#11185](https://github.com/PyTorchLightning/pytorch-lightning/pull/11185))
* Renamed the `SingleTPUPlugin` to `SingleTPUStrategy` ([#11182](https://github.com/PyTorchLightning/pytorch-lightning/pull/11182))
* Renamed the `DDPSpawnPlugin` to `DDPSpawnStrategy` ([#11145](https://github.com/PyTorchLightning/pytorch-lightning/pull/11145))
* Renamed the `DDPShardedPlugin` to `DDPShardedStrategy` ([#11186](https://github.com/PyTorchLightning/pytorch-lightning/pull/11186))
* Renamed the `DDPFullyShardedPlugin` to `DDPFullyShardedStrategy` ([#11143](https://github.com/PyTorchLightning/pytorch-lightning/pull/11143))
* Renamed the `SingleDevicePlugin` to `SingleDeviceStrategy` ([#11182](https://github.com/PyTorchLightning/pytorch-lightning/pull/11182))
* Renamed the `DDPSpawnPlugin` to `DDPSpawnStrategy` ([#11145](https://github.com/PyTorchLightning/pytorch-lightning/pull/11145))
* Renamed the `DDPSpawnShardedPlugin` to `DDPSpawnShardedStrategy` ([#11210](https://github.com/PyTorchLightning/pytorch-lightning/pull/11210))
* Renamed the `DeepSpeedPlugin` to `DeepSpeedStrategy` ([#11194](https://github.com/PyTorchLightning/pytorch-lightning/pull/11194))
* Renamed the `HorovodPlugin` to `HorovodStrategy` ([#11195](https://github.com/PyTorchLightning/pytorch-lightning/pull/11195))
* Renamed the `TPUSpawnPlugin` to `TPUSpawnStrategy` ([#11190](https://github.com/PyTorchLightning/pytorch-lightning/pull/11190))
* Renamed the `IPUPlugin` to `IPUStrategy` ([#11193](https://github.com/PyTorchLightning/pytorch-lightning/pull/11193))
* Renamed the `SingleDevicePlugin` to `SingleDeviceStrategy` ([#11182](https://github.com/PyTorchLightning/pytorch-lightning/pull/11182))
* Renamed the `SingleTPUPlugin` to `SingleTPUStrategy` ([#11182](https://github.com/PyTorchLightning/pytorch-lightning/pull/11182))


- Marked the `ResultCollection`, `ResultMetric`, and `ResultMetricCollection` classes as protected ([#11130](https://github.com/PyTorchLightning/pytorch-lightning/pull/11130))
@@ -337,6 +342,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed double evaluation bug with fault-tolerance enabled where the second call was completely skipped ([#11119](https://github.com/PyTorchLightning/pytorch-lightning/pull/11119))


- Fixed an issue with the `TPUSpawnPlugin` handling the `XLA_USE_BF16` environment variable incorrectly ([#10990](https://github.com/PyTorchLightning/pytorch-lightning/pull/10990))



## [1.5.7] - 2021-12-21

6 changes: 6 additions & 0 deletions pytorch_lightning/plugins/precision/precision_plugin.py
@@ -236,3 +236,9 @@ def predict_step_context(self) -> Generator[None, None, None]:
        """A contextmanager for the predict step."""
        with self.forward_context():
            yield
+
+    def teardown(self) -> None:
+        """This method is called to teardown the training process.
+
+        It is the right place to release memory and free other resources.
+        """
5 changes: 4 additions & 1 deletion pytorch_lightning/plugins/precision/tpu_bf16.py
@@ -28,5 +28,8 @@ class TPUBf16PrecisionPlugin(TPUPrecisionPlugin):
    def connect(
        self, model: nn.Module, optimizers: List[Optimizer], lr_schedulers: List[Any]
    ) -> Tuple[nn.Module, List[Optimizer], List[Any]]:
-        os.environ["XLA_USE_BF16"] = str(1)
+        os.environ["XLA_USE_BF16"] = "1"
        return super().connect(model=model, optimizers=optimizers, lr_schedulers=lr_schedulers)
+
+    def teardown(self) -> None:
+        os.environ.pop("XLA_USE_BF16", None)
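
A short sketch of what this hunk changes (it mirrors the new test added at the end of this PR): `connect` exports `XLA_USE_BF16`, and `teardown` now removes it again, so the flag can no longer leak into a later full-precision run in the same process. No TPU is needed for this check, since only `os.environ` is touched:

```python
import os
from unittest.mock import Mock

from pytorch_lightning.plugins import TPUBf16PrecisionPlugin

plugin = TPUBf16PrecisionPlugin()
plugin.connect(Mock(), Mock(), Mock())        # bf16 run: the flag gets exported
assert os.environ.get("XLA_USE_BF16") == "1"

plugin.teardown()                             # after this PR: the flag is removed again
assert "XLA_USE_BF16" not in os.environ       # a later 32-bit run is unaffected
```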
1 change: 1 addition & 0 deletions pytorch_lightning/plugins/training_type/single_device.py
@@ -86,6 +86,7 @@ def broadcast(self, obj: object, src: int = 0) -> object:
        return obj

    def teardown(self) -> None:
+        super().teardown()
        if self.on_gpu:
            # GPU teardown
            self.lightning_module.cpu()
1 change: 1 addition & 0 deletions pytorch_lightning/plugins/training_type/single_tpu.py
@@ -74,6 +74,7 @@ def model_to_device(self) -> None:
        self.model.to(self.root_device)

    def teardown(self) -> None:
+        super().teardown()
        # TPU teardown
        os.environ.pop("PT_XLA_DEBUG", None)

4 changes: 1 addition & 3 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -244,9 +244,6 @@ def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[st
        }

    def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Optional[Union[Any, "_SpawnOutput"]]:
-        # todo: precision pluging is call in accelerator setup and should be moved
-        if "XLA_USE_BF16" in os.environ:
-            del os.environ["XLA_USE_BF16"]
        context = mp.get_context(self.start_method or "fork")
        return_queue = context.SimpleQueue()
        xmp.spawn(self._wrapped_function, args=(function, args, kwargs, return_queue), **self.get_mp_spawn_kwargs())
@@ -340,6 +337,7 @@ def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_gra
        return xm.all_gather(tensor)

    def teardown(self) -> None:
+        super().teardown()
        os.environ.pop("PT_XLA_DEBUG", None)

@classmethod
@@ -437,13 +437,13 @@ def model_sharded_context(self) -> Generator:
        """
        yield

-    @abstractmethod
    def teardown(self) -> None:
        """This method is called to teardown the training process.

        It is the right place to release memory and free other resources.
        """
        self._move_optimizer_state(torch.device("cpu"))
+        self.precision_plugin.teardown()

    @classmethod
    def register_plugins(cls, plugin_registry) -> None:
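
Putting the pieces together, here is a rough, simplified sketch (generic class names, not the actual Lightning classes) of the teardown chain the hunks above wire up: each strategy's `teardown` calls `super().teardown()`, and the base implementation now forwards to the precision plugin:

```python
class DemoPrecisionPlugin:
    def teardown(self) -> None:
        print("precision cleanup")  # e.g. pop XLA_USE_BF16


class DemoBaseStrategy:
    def __init__(self, precision_plugin: DemoPrecisionPlugin) -> None:
        self.precision_plugin = precision_plugin

    def teardown(self) -> None:
        # mirrors the base teardown above: delegate to the precision plugin
        self.precision_plugin.teardown()


class DemoTPUStrategy(DemoBaseStrategy):
    def teardown(self) -> None:
        super().teardown()        # the "missing super().teardown()" calls added in this PR
        print("strategy cleanup")  # e.g. pop PT_XLA_DEBUG


DemoTPUStrategy(DemoPrecisionPlugin()).teardown()
# prints:
#   precision cleanup
#   strategy cleanup
```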
2 changes: 0 additions & 2 deletions tests/models/test_tpu.py
@@ -122,7 +122,6 @@ def test_model_16bit_tpu_cores_1(tmpdir):

    model = BoringModel()
    tpipes.run_model_test(trainer_options, model, on_gpu=False)
-    assert os.environ.get("XLA_USE_BF16") == str(1), "XLA_USE_BF16 was not set in environment variables"


@pytest.mark.parametrize("tpu_core", [1, 5])
Expand All @@ -144,7 +143,6 @@ def test_model_16bit_tpu_index(tmpdir, tpu_core):
model = BoringModel()
tpipes.run_model_test(trainer_options, model, on_gpu=False)
assert torch_xla._XLAC._xla_get_default_device() == f"xla:{tpu_core}"
assert os.environ.get("XLA_USE_BF16") == str(1), "XLA_USE_BF16 was not set in environment variables"
awaelchli marked this conversation as resolved.
Show resolved Hide resolved


@RunIf(tpu=True)
Empty file.
25 changes: 25 additions & 0 deletions tests/plugins/precision/test_tpu_bf16_plugin.py
@@ -0,0 +1,25 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from unittest.mock import Mock

from pytorch_lightning.plugins import TPUBf16PrecisionPlugin


def test_teardown():
    plugin = TPUBf16PrecisionPlugin()
    plugin.connect(Mock(), Mock(), Mock())
    assert os.environ.get("XLA_USE_BF16") == "1"
    plugin.teardown()
    assert "XLA_USE_BF16" not in os.environ