
torchrun AttributeError caused by file_based_local_timer on Windows #85427

Closed
ejguan opened this issue Sep 21, 2022 · 8 comments · Fixed by Lightning-AI/pytorch-lightning#15645
Labels
high priority module: windows Windows support for PyTorch oncall: distributed Add this issue/PR to distributed oncall triage queue oncall: r2p Add this issue/PR to R2P (elastic) oncall triage queue triage review
Milestone

Comments

@ejguan
Contributor

ejguan commented Sep 21, 2022

🐛 Describe the bug

During import of torchrun, an AttributeError is raised on Windows because module 'signal' has no attribute 'SIGKILL'. Here is the culprit:

def __init__(self, file_path: str, signal=signal.SIGKILL) -> None:

I encountered this problem when running subprocess.run(["torchrun", ..., "script.py"]).
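A minimal sketch of the failure mode (not PyTorch code, names are illustrative): default arguments are evaluated while the class body executes, i.e. at import time, and Windows has no `signal.SIGKILL`, so merely importing the module raises AttributeError there. A portable lookup with a SIGTERM fallback avoids this:

```python
import signal

# Resolved safely on every platform: SIGKILL where it exists, else SIGTERM.
DEFAULT_KILL = getattr(signal, "SIGKILL", signal.SIGTERM)

class FileTimerClientSketch:
    # DEFAULT_KILL was computed above, so defining this class cannot
    # raise AttributeError the way the original default argument does
    def __init__(self, file_path: str, sig: signal.Signals = DEFAULT_KILL) -> None:
        self._file_path = file_path
        self.signal = sig

client = FileTimerClientSketch("/tmp/timer")
print(client.signal in (signal.SIGKILL, signal.SIGTERM))  # → True
```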

Versions

PyTorch Nightly

cc @ezyang @gchanan @zou3519 @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @peterjc123 @mszhanyi @skyline75489 @nbcsm @pietern @SciPioneer

@ezyang ezyang added oncall: distributed Add this issue/PR to distributed oncall triage queue module: windows Windows support for PyTorch triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module labels Sep 21, 2022
@daavoo

daavoo commented Oct 31, 2022

This has now been released in v1.13.0 and it has caused some CI failures on Windows.

daavoo added a commit to iterative/dvclive that referenced this issue Oct 31, 2022
* fastai: Remove ProgressCallback in tests.

Per fastai/fastai#3809

* Pin PyTorch version.

Per pytorch/pytorch#85427
@H-Huang
Member

H-Huang commented Oct 31, 2022

cc @bchen2020 @d4l3k windows breakage

@H-Huang H-Huang added oncall: r2p Add this issue/PR to R2P (elastic) oncall triage queue and removed triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module labels Oct 31, 2022
@d4l3k
Collaborator

d4l3k commented Oct 31, 2022

cc @clee2000 @janeyx99 who disabled distributed tests in windows in #76848

Are we not supporting distributed on Windows now ("minimal interest in maintaining distributed for windows")? If that's the case, we should make it explicit rather than do it via quiet breakages like this.

cc @malfet

@malfet
Contributor

malfet commented Oct 31, 2022

I'm willing to stamp #76848 post factum: Windows distributed issues have been piling up and nobody on the engineering side was looking into fixing them (which this issue perfectly highlights, as it was filed on Sep 21st and nobody cared to submit a fix for it or mark it as release blocking).

By the way, I've added a test that imports all public package names; let me check whether torch.distributed is just blocklisted on Windows, as I could not get an ack from any maintainer that they care about it.

pytorch/test/test_testing.py, lines 1785 to 1787 at ff94494:

if IS_WINDOWS or IS_MACOS:
# Distributed does not work on Windows or by default on Mac
ignored_modules.append("torch.distributed.")

@d4l3k
Collaborator

d4l3k commented Oct 31, 2022

We can just add an import test for torch.distributed.elastic if distributed is broken. Though if distributed doesn't work, it doesn't seem like torchrun has much value.
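A hedged sketch of such an import smoke test (not PyTorch's actual test code): importing the module is enough to catch import-time failures like the AttributeError above.

```python
import importlib

def assert_importable(module_name: str) -> None:
    """Fail loudly if the module cannot even be imported."""
    importlib.import_module(module_name)

# "signal" is a stand-in here; the real test would target
# "torch.distributed.elastic"
assert_importable("signal")
```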

@daavoo @ejguan can you share a bit about how you're using torchrun on windows?

cc @kiukchung

@kiukchung
Collaborator

The straightforward fix is to get the platform-based kill signal from

def _get_kill_signal() -> signal.Signals:

instead of defaulting to signal.SIGKILL in the offending constructor above.
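A hedged sketch of what such a platform-based helper could look like, modeled on the `_get_kill_signal()` mentioned above (the Windows branch is an assumption, not verified against PyTorch's implementation):

```python
import signal
import sys

def get_kill_signal_sketch() -> signal.Signals:
    """Return a forceful-termination signal appropriate for this platform."""
    if sys.platform == "win32":
        # Windows has no SIGKILL; CTRL_C_EVENT is a common stand-in
        return signal.CTRL_C_EVENT
    return signal.SIGKILL
```

Using this helper at call time, rather than `signal.SIGKILL` in a default argument, keeps the module importable on Windows.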

@daavoo

daavoo commented Oct 31, 2022

@daavoo @ejguan can you share a bit about how you're using torchrun on windows?

Well ... I am not using torchrun on Windows at all 😅

I just encountered the error in CI while importing a downstream library (transformers):

Traceback
_________________ ERROR collecting tests/test_huggingface.py __________________
.nox\tests-3-9\lib\site-packages\transformers\utils\import_utils.py:1063: in _get_module
    return importlib.import_module("." + module_name, self.__name__)
C:\hostedtoolcache\windows\Python\3.9.13\x64\lib\importlib\__init__.py:127: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
<frozen importlib._bootstrap>:1030: in _gcd_import
    ???
<frozen importlib._bootstrap>:1007: in _find_and_load
    ???
<frozen importlib._bootstrap>:986: in _find_and_load_unlocked
    ???
<frozen importlib._bootstrap>:680: in _load_unlocked
    ???
<frozen importlib._bootstrap_external>:850: in exec_module
    ???
<frozen importlib._bootstrap>:228: in _call_with_frames_removed
    ???
.nox\tests-3-9\lib\site-packages\transformers\modeling_utils.py:78: in <module>
    from accelerate import __version__ as accelerate_version
.nox\tests-3-9\lib\site-packages\accelerate\__init__.py:7: in <module>
    from .accelerator import Accelerator
.nox\tests-3-9\lib\site-packages\accelerate\accelerator.py:27: in <module>
    from .checkpointing import load_accelerator_state, load_custom_state, save_accelerator_state, save_custom_state
.nox\tests-3-9\lib\site-packages\accelerate\checkpointing.py:24: in <module>
    from .utils import (
.nox\tests-3-9\lib\site-packages\accelerate\utils\__init__.py:96: in <module>
    from .launch import PrepareForLaunch, _filter_args, get_launch_prefix
.nox\tests-3-9\lib\site-packages\accelerate\utils\launch.py:25: in <module>
    import torch.distributed.run as distrib_run
.nox\tests-3-9\lib\site-packages\torch\distributed\run.py:386: in <module>
    from torch.distributed.launcher.api import LaunchConfig, elastic_launch
.nox\tests-3-9\lib\site-packages\torch\distributed\launcher\__init__.py:10: in <module>
    from torch.distributed.launcher.api import (  # noqa: F401
.nox\tests-3-9\lib\site-packages\torch\distributed\launcher\api.py:16: in <module>
    from torch.distributed.elastic.agent.server.local_elastic_agent import LocalElasticAgent
.nox\tests-3-9\lib\site-packages\torch\distributed\elastic\agent\server\__init__.py:40: in <module>
    from .local_elastic_agent import TORCHELASTIC_ENABLE_FILE_TIMER, TORCHELASTIC_TIMER_FILE
.nox\tests-3-9\lib\site-packages\torch\distributed\elastic\agent\server\local_elastic_agent.py:19: in <module>
    import torch.distributed.elastic.timer as timer
.nox\tests-3-9\lib\site-packages\torch\distributed\elastic\timer\__init__.py:44: in <module>
    from .file_based_local_timer import FileTimerClient, FileTimerServer, FileTimerRequest  # noqa: F401
.nox\tests-3-9\lib\site-packages\torch\distributed\elastic\timer\file_based_local_timer.py:63: in <module>
    class FileTimerClient(TimerClient):
.nox\tests-3-9\lib\site-packages\torch\distributed\elastic\timer\file_based_local_timer.py:81: in FileTimerClient
    def __init__(self, file_path: str, signal=signal.SIGKILL) -> None:
E   AttributeError: module 'signal' has no attribute 'SIGKILL'

The above exception was the direct cause of the following exception:
tests\test_huggingface.py:7: in <module>
    from transformers import (
<frozen importlib._bootstrap>:1055: in _handle_fromlist
    ???
.nox\tests-3-9\lib\site-packages\transformers\utils\import_utils.py:1053: in __getattr__
    module = self._get_module(self._class_to_module[name])
.nox\tests-3-9\lib\site-packages\transformers\utils\import_utils.py:1065: in _get_module
    raise RuntimeError(
E   RuntimeError: Failed to import transformers.modeling_utils because of the following error (look up to see its traceback):
E   module 'signal' has no attribute 'SIGKILL'

@awaelchli
Contributor

Encountered this issue the same way as @daavoo in CI when importing torchrun. I believe a fix could be to defer the SIGKILL lookup out of the default argument (note the `signal` parameter shadows the `signal` module, so the module needs an alias, e.g. `import signal as _signal`):

    def __init__(self, file_path: str, signal=None) -> None:
        super().__init__()
        self._file_path = file_path
        self.signal = _signal.SIGKILL if signal is None else signal

@malfet malfet added this to the 1.13.1 milestone Nov 2, 2022
izaitsevfb pushed a commit to izaitsevfb/pytorch that referenced this issue Dec 2, 2022
Also, add `torch.distributed` to test imports, so that we would not
regress in the future

Fixes pytorch#85427
Pull Request resolved: pytorch#88522
Approved by: https://github.com/d4l3k

(cherry picked from commit f98edfc)
izaitsevfb pushed a commit to izaitsevfb/pytorch that referenced this issue Dec 6, 2022 (same cherry-pick of f98edfc)
atalman pushed a commit that referenced this issue Dec 6, 2022 (same cherry-pick of f98edfc; co-authored-by: Nikita Shulga <nshulga@meta.com>)
kulinseth pushed a commit to kulinseth/pytorch that referenced this issue Dec 10, 2022 (same cherry-pick of f98edfc)
Projects
Status: Done
8 participants