From 256213db6cb936179b2482f67f1604f8fc08328d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 19 Dec 2022 11:50:08 +0100 Subject: [PATCH] Re-enable Lite CLI on Windows + PyTorch 1.13 (#15645) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com> --- src/lightning_fabric/cli.py | 9 --------- tests/tests_fabric/test_cli.py | 21 +-------------------- 2 files changed, 1 insertion(+), 29 deletions(-) diff --git a/src/lightning_fabric/cli.py b/src/lightning_fabric/cli.py index 913ed4d98cf17..c294dccf21e14 100644 --- a/src/lightning_fabric/cli.py +++ b/src/lightning_fabric/cli.py @@ -148,15 +148,6 @@ def _get_num_processes(accelerator: str, devices: str) -> int: def _torchrun_launch(args: Namespace, script_args: List[str]) -> None: """This will invoke `torchrun` programmatically to launch the given script in new processes.""" - - if _IS_WINDOWS and _TORCH_GREATER_EQUAL_1_13: # pragma: no cover - # TODO: remove once import issue is resolved: https://github.com/pytorch/pytorch/issues/85427 - _log.error( - "On the Windows platform, this launcher is currently only supported on torch < 1.13 due to a bug" - " upstream: https://github.com/pytorch/pytorch/issues/85427" - ) - raise SystemExit(1) - import torch.distributed.run as torchrun if args.strategy == "dp": diff --git a/tests/tests_fabric/test_cli.py b/tests/tests_fabric/test_cli.py index 87fc3dd528e6e..3536ed6f6e990 100644 --- a/tests/tests_fabric/test_cli.py +++ b/tests/tests_fabric/test_cli.py @@ -16,22 +16,12 @@ from unittest.mock import Mock import pytest +import torch.distributed.run from tests_fabric.helpers.runif import RunIf from lightning_fabric.cli import _run_model from lightning_fabric.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_1_13 -if not (_IS_WINDOWS and _TORCH_GREATER_EQUAL_1_13): - import torch.distributed.run - - -def skip_windows_pt_1_13(): - # https://github.com/pytorch/pytorch/issues/85427 - return pytest.mark.skipif( - condition=(_IS_WINDOWS and _TORCH_GREATER_EQUAL_1_13), - reason="Torchelastic import bug in 1.13 affecting Windows", - ) - @pytest.fixture def fake_script(tmp_path): @@ -40,7 +30,6 @@ def fake_script(tmp_path): return str(script) -@skip_windows_pt_1_13() @mock.patch.dict(os.environ, os.environ.copy(), clear=True) def test_cli_env_vars_defaults(monkeypatch, fake_script): monkeypatch.setattr(torch.distributed, "run", Mock()) @@ -55,7 +44,6 @@ def test_cli_env_vars_defaults(monkeypatch, fake_script): assert os.environ["LT_PRECISION"] == "32" -@skip_windows_pt_1_13() @pytest.mark.parametrize("accelerator", ["cpu", "gpu", "cuda", pytest.param("mps", marks=RunIf(mps=True))]) @mock.patch.dict(os.environ, os.environ.copy(), clear=True) @mock.patch("lightning_fabric.accelerators.cuda.num_cuda_devices", return_value=2) @@ -67,7 +55,6 @@ def test_cli_env_vars_accelerator(_, accelerator, monkeypatch, fake_script): assert os.environ["LT_ACCELERATOR"] == accelerator -@skip_windows_pt_1_13() @pytest.mark.parametrize("strategy", ["dp", "ddp", "deepspeed"]) @mock.patch.dict(os.environ, os.environ.copy(), clear=True) @mock.patch("lightning_fabric.accelerators.cuda.num_cuda_devices", return_value=2) @@ -79,7 +66,6 @@ def test_cli_env_vars_strategy(_, strategy, monkeypatch, fake_script): assert os.environ["LT_STRATEGY"] == strategy -@skip_windows_pt_1_13() @pytest.mark.parametrize("devices", ["1", "2", "0,", "1,0", "-1"]) @mock.patch.dict(os.environ, os.environ.copy(), clear=True) @mock.patch("lightning_fabric.accelerators.cuda.num_cuda_devices", return_value=2) @@ -92,7 +78,6 @@ def test_cli_env_vars_devices_cuda(_, devices, monkeypatch, fake_script): @RunIf(mps=True) -@skip_windows_pt_1_13() @pytest.mark.parametrize("accelerator", ["mps", "gpu"]) @mock.patch.dict(os.environ, os.environ.copy(), clear=True) def test_cli_env_vars_devices_mps(accelerator, monkeypatch, fake_script): @@ -103,7 +88,6 @@ def test_cli_env_vars_devices_mps(accelerator, monkeypatch, fake_script): assert os.environ["LT_DEVICES"] == "1" -@skip_windows_pt_1_13() @pytest.mark.parametrize("num_nodes", ["1", "2", "3"]) @mock.patch.dict(os.environ, os.environ.copy(), clear=True) def test_cli_env_vars_num_nodes(num_nodes, monkeypatch, fake_script): @@ -114,7 +98,6 @@ def test_cli_env_vars_num_nodes(num_nodes, monkeypatch, fake_script): assert os.environ["LT_NUM_NODES"] == num_nodes -@skip_windows_pt_1_13() @pytest.mark.parametrize("precision", ["64", "32", "16", "bf16"]) @mock.patch.dict(os.environ, os.environ.copy(), clear=True) def test_cli_env_vars_precision(precision, monkeypatch, fake_script): @@ -125,7 +108,6 @@ def test_cli_env_vars_precision(precision, monkeypatch, fake_script): assert os.environ["LT_PRECISION"] == precision -@skip_windows_pt_1_13() @mock.patch.dict(os.environ, os.environ.copy(), clear=True) def test_cli_torchrun_defaults(monkeypatch, fake_script): torchrun_mock = Mock() @@ -145,7 +127,6 @@ def test_cli_torchrun_defaults(monkeypatch, fake_script): ) -@skip_windows_pt_1_13() @pytest.mark.parametrize( "devices,expected", [