From 3463e59924cedfb15d5f26b455320e25adbd630c Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Fri, 19 Nov 2021 08:53:32 -0500 Subject: [PATCH 01/11] update --- pytorch_lightning/utilities/distributed.py | 32 +++++++++++++++++++++- tests/utilities/test_distributed.py | 27 ++++++++++++++++++ 2 files changed, 58 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index 1740518923c0f..e30d0141f87a4 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -16,12 +16,13 @@ import os from functools import wraps from platform import python_version -from typing import Any, Callable, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch from torch.nn.parallel.distributed import DistributedDataParallel import pytorch_lightning as pl +from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8, _TORCH_GREATER_EQUAL_1_9, _TPU_AVAILABLE if _TPU_AVAILABLE: @@ -376,3 +377,32 @@ def init_dist_connection( f"All distributed processes registered. Starting with {world_size} processes\n" f"{'-' * 100}\n" ) + + +def _collect_states_on_rank_zero(state: Dict[str, Any], device: torch.device): + """This distributed utility collects dictionary state across all processes. + + Args: + state: Dictionary contain the state of the current process + device: Current process device. + + Returns: + states: On global rank 0, a dictionary where the primary keys are + the process rank and the values their associated states. Otherwise, returns None. + """ + if not distributed_available(): + return {0: state} + states = {} + state = apply_to_collection(state, torch.Tensor, lambda x: x.to(device)) + for rank in range(1, torch.distributed.get_world_size()): + if torch.distributed.get_rank() == rank: + # Assumes world_size of 3. 
+ objects = [state] + else: + objects = [None] + torch.distributed.broadcast_object_list(objects, src=rank, device=device) + states[rank] = objects[0] + if torch.distributed.get_rank() != 0: + return None + states[0] = state + return apply_to_collection(states, torch.Tensor, lambda x: x.to("cpu")) diff --git a/tests/utilities/test_distributed.py b/tests/utilities/test_distributed.py index e27b4264df126..2970fef46ee21 100644 --- a/tests/utilities/test_distributed.py +++ b/tests/utilities/test_distributed.py @@ -16,6 +16,12 @@ from unittest import mock import pytest +import torch +import torch.multiprocessing as mp + +import tests.helpers.utils as tutils +from pytorch_lightning.utilities.distributed import _collect_states_on_rank_zero +from tests.helpers.runif import RunIf @pytest.mark.parametrize("env_vars", [{"RANK": "0"}, {"SLURM_PROCID": "0"}]) @@ -53,3 +59,24 @@ def foo(): x = foo() assert x is None + + +def _test_collect_states(rank, worldsize): + os.environ["MASTER_ADDR"] = "localhost" + + # initialize the process group + torch.distributed.init_process_group("nccl", rank=rank, world_size=worldsize) + + state = {"something": torch.tensor([rank])} + collected_state = _collect_states_on_rank_zero(state, device=torch.device(f"cuda:{rank}")) + if rank == 1: + assert collected_state is None + else: + assert collected_state == {1: {"something": torch.tensor([1])}, 0: {"something": torch.tensor([0])}} + + +@RunIf(skip_windows=True, min_gpus=2) +def test_collect_states(): + """Make sure result logging works with DDP.""" + tutils.set_random_main_port() + mp.spawn(_test_collect_states, args=(2,), nprocs=2) From e48ca64eb0e98b2f7850952fcfe0f032155e3eff Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 19 Nov 2021 14:02:58 +0000 Subject: [PATCH 02/11] update --- pytorch_lightning/utilities/distributed.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index e30d0141f87a4..7fc647f5a691e 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -379,7 +379,7 @@ def init_dist_connection( ) -def _collect_states_on_rank_zero(state: Dict[str, Any], device: torch.device): +def _collect_states_on_rank_zero(state: Dict[str, Any], device: torch.device) -> Optional[Dict[int, Dict[str, Any]]]: """This distributed utility collects dictionary state across all processes. Args: @@ -396,7 +396,6 @@ def _collect_states_on_rank_zero(state: Dict[str, Any], device: torch.device): state = apply_to_collection(state, torch.Tensor, lambda x: x.to(device)) for rank in range(1, torch.distributed.get_world_size()): if torch.distributed.get_rank() == rank: - # Assumes world_size of 3. objects = [state] else: objects = [None] From 90e3c4a7212ca3c54cacc11b56b754266dcc217b Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Fri, 19 Nov 2021 09:25:33 -0500 Subject: [PATCH 03/11] update --- pytorch_lightning/utilities/distributed.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index 7fc647f5a691e..6249f63ec3d16 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import inspect import logging import os from functools import wraps @@ -394,12 +395,17 @@ def _collect_states_on_rank_zero(state: Dict[str, Any], device: torch.device) -> return {0: state} states = {} state = apply_to_collection(state, torch.Tensor, lambda x: x.to(device)) + # `broadcast_object_list` API changed between PyTorch 1.9.x and 1.10.x devices has been added + params = inspect.signature(torch.distributed.broadcast_object_list).parameters for rank in range(1, torch.distributed.get_world_size()): if torch.distributed.get_rank() == rank: objects = [state] else: objects = [None] - torch.distributed.broadcast_object_list(objects, src=rank, device=device) + kwargs = dict(src=rank) + if "device" in params: + kwargs["device"] = device + torch.distributed.broadcast_object_list(objects, **kwargs) states[rank] = objects[0] if torch.distributed.get_rank() != 0: return None From 85d709cde85ca8386769af034930cdb105daa42e Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Fri, 19 Nov 2021 09:50:10 -0500 Subject: [PATCH 04/11] update --- pytorch_lightning/utilities/distributed.py | 12 ++---------- tests/utilities/test_distributed.py | 3 ++- 2 files changed, 4 insertions(+), 11 deletions(-) diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index 6249f63ec3d16..906d1ce847384 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import inspect import logging import os from functools import wraps @@ -23,7 +22,6 @@ from torch.nn.parallel.distributed import DistributedDataParallel import pytorch_lightning as pl -from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8, _TORCH_GREATER_EQUAL_1_9, _TPU_AVAILABLE if _TPU_AVAILABLE: @@ -394,20 +392,14 @@ def _collect_states_on_rank_zero(state: Dict[str, Any], device: torch.device) -> if not distributed_available(): return {0: state} states = {} - state = apply_to_collection(state, torch.Tensor, lambda x: x.to(device)) - # `broadcast_object_list` API changed between PyTorch 1.9.x and 1.10.x devices has been added - params = inspect.signature(torch.distributed.broadcast_object_list).parameters for rank in range(1, torch.distributed.get_world_size()): if torch.distributed.get_rank() == rank: objects = [state] else: objects = [None] - kwargs = dict(src=rank) - if "device" in params: - kwargs["device"] = device - torch.distributed.broadcast_object_list(objects, **kwargs) + torch.distributed.broadcast_object_list(objects, src=rank, device=device) states[rank] = objects[0] if torch.distributed.get_rank() != 0: return None states[0] = state - return apply_to_collection(states, torch.Tensor, lambda x: x.to("cpu")) + return states diff --git a/tests/utilities/test_distributed.py b/tests/utilities/test_distributed.py index 2970fef46ee21..4e11d9035fece 100644 --- a/tests/utilities/test_distributed.py +++ b/tests/utilities/test_distributed.py @@ -72,10 +72,11 @@ def _test_collect_states(rank, worldsize): if rank == 1: assert collected_state is None else: + print(collected_state) assert collected_state == {1: {"something": torch.tensor([1])}, 0: {"something": torch.tensor([0])}} -@RunIf(skip_windows=True, min_gpus=2) +@RunIf(skip_windows=True, min_gpus=2, min_torch="1.10") def test_collect_states(): """Make sure result logging works with DDP.""" 
tutils.set_random_main_port() From 5400b076c520b2c77ac3c8c326c68d0178728311 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Fri, 19 Nov 2021 09:58:34 -0500 Subject: [PATCH 05/11] update --- pytorch_lightning/utilities/distributed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index 906d1ce847384..89fa1aa704705 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -378,7 +378,7 @@ def init_dist_connection( ) -def _collect_states_on_rank_zero(state: Dict[str, Any], device: torch.device) -> Optional[Dict[int, Dict[str, Any]]]: +def _collect_states_on_rank_zero(state: Dict[str, Any], device: torch.device) -> Optional[Dict[int, Any]]: """This distributed utility collects dictionary state across all processes. Args: From c93252237af3feab1f12ccb2f18d076b92f01cf5 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 19 Nov 2021 16:29:52 +0000 Subject: [PATCH 06/11] update --- pytorch_lightning/utilities/distributed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index 89fa1aa704705..244367832fb69 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -396,7 +396,7 @@ def _collect_states_on_rank_zero(state: Dict[str, Any], device: torch.device) -> if torch.distributed.get_rank() == rank: objects = [state] else: - objects = [None] + objects = [None] # type: ignore torch.distributed.broadcast_object_list(objects, src=rank, device=device) states[rank] = objects[0] if torch.distributed.get_rank() != 0: From 05b389b02d0782645f8aed022e11ebc549c98fdd Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 19 Nov 2021 16:55:48 +0000 Subject: [PATCH 07/11] update --- pytorch_lightning/utilities/distributed.py | 8 +++----- tests/utilities/test_distributed.py | 3 ++- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index 244367832fb69..b95440ba02bec 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -392,14 +392,12 @@ def _collect_states_on_rank_zero(state: Dict[str, Any], device: torch.device) -> if not distributed_available(): return {0: state} states = {} + current_rank = torch.distributed.get_rank() for rank in range(1, torch.distributed.get_world_size()): - if torch.distributed.get_rank() == rank: - objects = [state] - else: - objects = [None] # type: ignore + objects = [state if current_rank == rank else None] # type: ignore torch.distributed.broadcast_object_list(objects, src=rank, device=device) states[rank] = objects[0] - if torch.distributed.get_rank() != 0: + if current_rank != 0: return None states[0] = state return states diff --git a/tests/utilities/test_distributed.py b/tests/utilities/test_distributed.py index 4e11d9035fece..f0bba003d56ae 100644 --- a/tests/utilities/test_distributed.py +++ b/tests/utilities/test_distributed.py @@ -72,9 +72,10 @@ def _test_collect_states(rank, worldsize): if rank == 1: assert collected_state is None else: - print(collected_state) assert collected_state == {1: {"something": torch.tensor([1])}, 0: {"something": torch.tensor([0])}} + torch.distributed.destroy_process_group() + @RunIf(skip_windows=True, min_gpus=2, min_torch="1.10") def test_collect_states(): From 
5e0fb6d02de47d78e21c61660e5be5c5d1a6e533 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Fri, 19 Nov 2021 17:16:18 +0000 Subject: [PATCH 08/11] Update pytorch_lightning/utilities/distributed.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos MocholĂ­ --- pytorch_lightning/utilities/distributed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index b95440ba02bec..192e175b5b73b 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -394,7 +394,7 @@ def _collect_states_on_rank_zero(state: Dict[str, Any], device: torch.device) -> states = {} current_rank = torch.distributed.get_rank() for rank in range(1, torch.distributed.get_world_size()): - objects = [state if current_rank == rank else None] # type: ignore + objects = [state if current_rank == rank else None] torch.distributed.broadcast_object_list(objects, src=rank, device=device) states[rank] = objects[0] if current_rank != 0: From 1eab18d4fdf1ed37c124d3453c041ad4b01739ef Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Fri, 19 Nov 2021 12:56:28 -0500 Subject: [PATCH 09/11] update --- tests/utilities/test_distributed.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/utilities/test_distributed.py b/tests/utilities/test_distributed.py index f0bba003d56ae..e732bceb6de81 100644 --- a/tests/utilities/test_distributed.py +++ b/tests/utilities/test_distributed.py @@ -74,8 +74,6 @@ def _test_collect_states(rank, worldsize): else: assert collected_state == {1: {"something": torch.tensor([1])}, 0: {"something": torch.tensor([0])}} - torch.distributed.destroy_process_group() - @RunIf(skip_windows=True, min_gpus=2, min_torch="1.10") def test_collect_states(): From e6da4f61cb09a3fb0e0adaf722ab93090f77e968 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 23 Nov 2021 13:52:45 +0000 Subject: [PATCH 10/11] update on comments --- pytorch_lightning/utilities/distributed.py | 2 +- tests/utilities/test_distributed.py | 15 +++++++++------ 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index 192e175b5b73b..7c6e4f4048181 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -382,7 +382,7 @@ def _collect_states_on_rank_zero(state: Dict[str, Any], device: torch.device) -> """This distributed utility collects dictionary state across all processes. Args: - state: Dictionary contain the state of the current process + state: Dictionary containing the state of the current process device: Current process device. 
Returns: diff --git a/tests/utilities/test_distributed.py b/tests/utilities/test_distributed.py index e732bceb6de81..a48b4486a470f 100644 --- a/tests/utilities/test_distributed.py +++ b/tests/utilities/test_distributed.py @@ -61,22 +61,25 @@ def foo(): assert x is None -def _test_collect_states(rank, worldsize): +def _test_collect_states(rank, world_size): os.environ["MASTER_ADDR"] = "localhost" # initialize the process group - torch.distributed.init_process_group("nccl", rank=rank, world_size=worldsize) + torch.distributed.init_process_group("nccl", rank=rank, world_size=world_size) state = {"something": torch.tensor([rank])} collected_state = _collect_states_on_rank_zero(state, device=torch.device(f"cuda:{rank}")) - if rank == 1: - assert collected_state is None - else: + if rank == 0: assert collected_state == {1: {"something": torch.tensor([1])}, 0: {"something": torch.tensor([0])}} + else: + assert collected_state is None @RunIf(skip_windows=True, min_gpus=2, min_torch="1.10") def test_collect_states(): - """Make sure result logging works with DDP.""" + """This test ensures state are properly collected across processes. + + This would be used to collect dataloader states as an example. + """ tutils.set_random_main_port() mp.spawn(_test_collect_states, args=(2,), nprocs=2) From 0e852f539cfe07faf2bf862bf3db2ea533dcd6f1 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 23 Nov 2021 13:53:55 +0000 Subject: [PATCH 11/11] update changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index fd9cf54a9730e..c2cdc5fda828e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). * Add `_FaultTolerantMode` enum used to track different supported fault tolerant modes ([#10645](https://github.com/PyTorchLightning/pytorch-lightning/issues/10645)) * Add a `_rotate_worker_indices` utility to reload the state according the latest worker ([#10647](https://github.com/PyTorchLightning/pytorch-lightning/issues/10647)) * Add stateful workers ([#10674](https://github.com/PyTorchLightning/pytorch-lightning/issues/10674)) + * Add an utility to collect the states across processes ([#10639](https://github.com/PyTorchLightning/pytorch-lightning/issues/10639)) -
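
Note on usage (not part of the patches above): after patches 07–08, `_collect_states_on_rank_zero` has every non-zero rank broadcast its state dictionary with `torch.distributed.broadcast_object_list`, rank 0 merges the results into a rank-indexed dictionary, and every other rank gets `None` back. The sketch below is a minimal, illustrative way to exercise that behaviour outside the test suite; it swaps the NCCL/GPU setup used in `_test_collect_states` for a two-process `gloo` group on CPU so it can run without GPUs, hard-codes an arbitrary `MASTER_PORT` instead of the test's random one, and assumes PyTorch >= 1.10 (needed for the `device` argument of `broadcast_object_list`) plus a PyTorch Lightning checkout that already contains the patched module.

import os

import torch
import torch.multiprocessing as mp

from pytorch_lightning.utilities.distributed import _collect_states_on_rank_zero


def _worker(rank: int, world_size: int) -> None:
    # Mirrors the test setup, but with the CPU-friendly "gloo" backend.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"  # arbitrary free port; the test picks a random one
    torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size)

    # Each process contributes its own small state dictionary.
    state = {"batches_seen": torch.tensor([rank * 10])}
    collected = _collect_states_on_rank_zero(state, device=torch.device("cpu"))

    if rank == 0:
        # On rank 0 the result maps each rank to its state, e.g.
        # {1: {'batches_seen': tensor([10])}, 0: {'batches_seen': tensor([0])}}
        print(collected)
    else:
        # All other ranks receive None.
        assert collected is None

    torch.distributed.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_worker, args=(2,), nprocs=2)

As the test docstring in patch 10/11 notes, the intended use is collecting per-process state (for example dataloader/worker states for fault-tolerant training), where only the main process needs the merged view and the other ranks can discard it.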