diff --git a/test/test_testing.py b/test/test_testing.py index d1bfe7c63ff5676..1add378326413a6 100644 --- a/test/test_testing.py +++ b/test/test_testing.py @@ -18,7 +18,7 @@ from torch.testing import make_tensor from torch.testing._internal.common_utils import \ - (IS_FBCODE, IS_SANDCASTLE, IS_WINDOWS, TestCase, run_tests, skipIfRocm, slowTest, + (IS_FBCODE, IS_MACOS, IS_SANDCASTLE, IS_WINDOWS, TestCase, run_tests, skipIfRocm, slowTest, parametrize, subtest, instantiate_parametrized_tests, dtype_name, TEST_WITH_ROCM) from torch.testing._internal.common_device_type import \ (PYTORCH_TESTING_DEVICE_EXCEPT_FOR_KEY, PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY, dtypes, @@ -1782,9 +1782,15 @@ def test_circular_dependencies(self) -> None: # See https://github.com/pytorch/pytorch/issues/77801 if not sys.version_info >= (3, 9): ignored_modules.append("torch.utils.benchmark") - if IS_WINDOWS: - # Distributed does not work on Windows - ignored_modules.append("torch.distributed.") + if IS_WINDOWS or IS_MACOS: + # Distributed should be importable on Windows(except nn.api.), but not on Mac + if IS_MACOS: + ignored_modules.append("torch.distributed.") + else: + ignored_modules.append("torch.distributed.nn.api.") + ignored_modules.append("torch.distributed.optim.") + ignored_modules.append("torch.distributed.pipeline.") + ignored_modules.append("torch.distributed.rpc.") ignored_modules.append("torch.testing._internal.dist_utils") torch_dir = os.path.dirname(torch.__file__) diff --git a/torch/distributed/elastic/timer/file_based_local_timer.py b/torch/distributed/elastic/timer/file_based_local_timer.py index fc259b545320f34..56ab4b074191ce3 100644 --- a/torch/distributed/elastic/timer/file_based_local_timer.py +++ b/torch/distributed/elastic/timer/file_based_local_timer.py @@ -10,6 +10,7 @@ import os import select import signal +import sys import threading import time from typing import Callable, Dict, List, Optional, Set, Tuple @@ -78,7 +79,8 @@ class FileTimerClient(TimerClient): signal: singal, the signal to use to kill the process. Using a negative or zero signal will not kill the process. """ - def __init__(self, file_path: str, signal=signal.SIGKILL) -> None: + def __init__(self, file_path: str, signal=(signal.SIGKILL if sys.platform != "win32" else + signal.CTRL_C_EVENT)) -> None: # type: ignore[attr-defined] super().__init__() self._file_path = file_path self.signal = signal