diff --git a/CHANGELOG.md b/CHANGELOG.md
index fd9cf54a9730e..c2cdc5fda828e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
     * Add `_FaultTolerantMode` enum used to track different supported fault tolerant modes ([#10645](https://github.com/PyTorchLightning/pytorch-lightning/issues/10645))
     * Add a `_rotate_worker_indices` utility to reload the state according the latest worker ([#10647](https://github.com/PyTorchLightning/pytorch-lightning/issues/10647))
     * Add stateful workers ([#10674](https://github.com/PyTorchLightning/pytorch-lightning/issues/10674))
+    * Add a utility to collect the states across processes ([#10639](https://github.com/PyTorchLightning/pytorch-lightning/issues/10639))
 
 
 -
diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py
index 1740518923c0f..7c6e4f4048181 100644
--- a/pytorch_lightning/utilities/distributed.py
+++ b/pytorch_lightning/utilities/distributed.py
@@ -16,7 +16,7 @@
 import os
 from functools import wraps
 from platform import python_version
-from typing import Any, Callable, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 from torch.nn.parallel.distributed import DistributedDataParallel
@@ -376,3 +376,28 @@ def init_dist_connection(
         f"All distributed processes registered. Starting with {world_size} processes\n"
         f"{'-' * 100}\n"
     )
+
+
+def _collect_states_on_rank_zero(state: Dict[str, Any], device: torch.device) -> Optional[Dict[int, Any]]:
+    """This distributed utility collects the state dictionaries from all processes onto global rank 0.
+
+    Args:
+        state: Dictionary containing the state of the current process.
+        device: Current process device.
+
+    Returns:
+        states: On global rank 0, a dictionary whose keys are the process ranks and whose
+            values are the associated states. On every other rank, returns None.
+ """ + if not distributed_available(): + return {0: state} + states = {} + current_rank = torch.distributed.get_rank() + for rank in range(1, torch.distributed.get_world_size()): + objects = [state if current_rank == rank else None] + torch.distributed.broadcast_object_list(objects, src=rank, device=device) + states[rank] = objects[0] + if current_rank != 0: + return None + states[0] = state + return states diff --git a/tests/utilities/test_distributed.py b/tests/utilities/test_distributed.py index e27b4264df126..a48b4486a470f 100644 --- a/tests/utilities/test_distributed.py +++ b/tests/utilities/test_distributed.py @@ -16,6 +16,12 @@ from unittest import mock import pytest +import torch +import torch.multiprocessing as mp + +import tests.helpers.utils as tutils +from pytorch_lightning.utilities.distributed import _collect_states_on_rank_zero +from tests.helpers.runif import RunIf @pytest.mark.parametrize("env_vars", [{"RANK": "0"}, {"SLURM_PROCID": "0"}]) @@ -53,3 +59,27 @@ def foo(): x = foo() assert x is None + + +def _test_collect_states(rank, world_size): + os.environ["MASTER_ADDR"] = "localhost" + + # initialize the process group + torch.distributed.init_process_group("nccl", rank=rank, world_size=world_size) + + state = {"something": torch.tensor([rank])} + collected_state = _collect_states_on_rank_zero(state, device=torch.device(f"cuda:{rank}")) + if rank == 0: + assert collected_state == {1: {"something": torch.tensor([1])}, 0: {"something": torch.tensor([0])}} + else: + assert collected_state is None + + +@RunIf(skip_windows=True, min_gpus=2, min_torch="1.10") +def test_collect_states(): + """This test ensures state are properly collected across processes. + + This would be used to collect dataloader states as an example. + """ + tutils.set_random_main_port() + mp.spawn(_test_collect_states, args=(2,), nprocs=2)