
Fault Tolerant Manual: Add support for collecting states across processes #10639

Merged
13 commits merged on Nov 23, 2021
27 changes: 26 additions & 1 deletion pytorch_lightning/utilities/distributed.py
@@ -16,7 +16,7 @@
import os
from functools import wraps
from platform import python_version
-from typing import Any, Callable, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from torch.nn.parallel.distributed import DistributedDataParallel
@@ -376,3 +376,28 @@ def init_dist_connection(
f"All distributed processes registered. Starting with {world_size} processes\n"
f"{'-' * 100}\n"
)


def _collect_states_on_rank_zero(state: Dict[str, Any], device: torch.device) -> Optional[Dict[int, Any]]:
"""This distributed utility collects dictionary state across all processes.

Args:
state: Dictionary contain the state of the current process
tchaton marked this conversation as resolved.
Show resolved Hide resolved
device: Current process device.

Returns:
states: On global rank 0, a dictionary where the primary keys are
the process rank and the values their associated states. Otherwise, returns None.
"""
    if not distributed_available():
        return {0: state}
    states = {}
    current_rank = torch.distributed.get_rank()
    # Each rank other than 0 broadcasts its state in turn; every process takes
    # part in each broadcast, and rank 0 keeps the received states.
    for rank in range(1, torch.distributed.get_world_size()):
        objects = [state if current_rank == rank else None]
        torch.distributed.broadcast_object_list(objects, src=rank, device=device)
        states[rank] = objects[0]
    if current_rank != 0:
        return None
    states[0] = state
    return states
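
For readers of this diff, a minimal usage sketch: each rank passes its own state dictionary, and only global rank 0 receives the merged mapping. The `run` helper below is hypothetical and not part of this PR; it assumes the process group has already been initialized.

import torch
from pytorch_lightning.utilities.distributed import _collect_states_on_rank_zero

def run(rank: int) -> None:
    # Assumes torch.distributed.init_process_group(...) has already run.
    device = torch.device(f"cuda:{rank}")
    state = {"step": rank}  # per-process state to collect
    states = _collect_states_on_rank_zero(state, device=device)
    if torch.distributed.get_rank() == 0:
        # states maps rank -> that rank's state, e.g. {0: {"step": 0}, 1: {"step": 1}}
        assert set(states) == set(range(torch.distributed.get_world_size()))
    else:
        assert states is None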
29 changes: 29 additions & 0 deletions tests/utilities/test_distributed.py
@@ -16,6 +16,12 @@
from unittest import mock

import pytest
import torch
import torch.multiprocessing as mp

import tests.helpers.utils as tutils
from pytorch_lightning.utilities.distributed import _collect_states_on_rank_zero
from tests.helpers.runif import RunIf


@pytest.mark.parametrize("env_vars", [{"RANK": "0"}, {"SLURM_PROCID": "0"}])
@@ -53,3 +59,26 @@ def foo():

x = foo()
assert x is None


def _test_collect_states(rank, worldsize):
    os.environ["MASTER_ADDR"] = "localhost"

    # initialize the process group
    torch.distributed.init_process_group("nccl", rank=rank, world_size=worldsize)

    state = {"something": torch.tensor([rank])}
    collected_state = _collect_states_on_rank_zero(state, device=torch.device(f"cuda:{rank}"))
    if rank == 1:
        assert collected_state is None
    else:
        assert collected_state == {1: {"something": torch.tensor([1])}, 0: {"something": torch.tensor([0])}}

    torch.distributed.destroy_process_group()


@RunIf(skip_windows=True, min_gpus=2, min_torch="1.10")
def test_collect_states():
    """Make sure states are properly collected across processes with DDP."""
    tutils.set_random_main_port()
    mp.spawn(_test_collect_states, args=(2,), nprocs=2)
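
The test above requires two GPUs and the NCCL backend. For a quick local check without GPUs, a similar script can be sketched with the gloo backend; this variant is an illustration only and not part of this PR.

# CPU-only sketch using the gloo backend (assumption: not part of this diff).
import os

import torch
import torch.multiprocessing as mp

from pytorch_lightning.utilities.distributed import _collect_states_on_rank_zero

def _worker(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29501"  # any free port works here
    torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size)
    states = _collect_states_on_rank_zero({"rank": rank}, device=torch.device("cpu"))
    if rank == 0:
        assert states == {0: {"rank": 0}, 1: {"rank": 1}}
    else:
        assert states is None
    torch.distributed.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(_worker, args=(2,), nprocs=2)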