Collective's PREMUL_SUM support with PyTorch 1.13 (#15201)
* Collective's PREMUL_SUM support with PyTorch 1.13
* Fix test
* Skip under 1.13
carmocca committed Oct 20, 2022
1 parent ec0d6d2 commit b866dc3
Showing 4 changed files with 49 additions and 24 deletions.
24 changes: 14 additions & 10 deletions src/lightning_lite/plugins/collectives/torch_collective.py
@@ -7,8 +7,8 @@
 from typing_extensions import Self
 
 from lightning_lite.plugins.collectives.collective import Collective
-from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_10
-from lightning_lite.utilities.types import _TORCH_REDUCE_OP, CollectibleGroup
+from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_10, _TORCH_GREATER_EQUAL_1_13
+from lightning_lite.utilities.types import CollectibleGroup, RedOpType, ReduceOp
 
 if dist.is_available():
     from torch.distributed.constants import default_pg_timeout
@@ -34,12 +34,12 @@ def broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
         dist.broadcast(tensor, src, group=self.group)
         return tensor
 
-    def all_reduce(self, tensor: torch.Tensor, op: Union[str, _TORCH_REDUCE_OP] = "sum") -> torch.Tensor:
+    def all_reduce(self, tensor: torch.Tensor, op: Union[str, ReduceOp, RedOpType] = "sum") -> torch.Tensor:
         op = self._convert_to_native_op(op)
         dist.all_reduce(tensor, op=op, group=self.group)
         return tensor
 
-    def reduce(self, tensor: torch.Tensor, dst: int, op: Union[str, _TORCH_REDUCE_OP] = "sum") -> torch.Tensor:
+    def reduce(self, tensor: torch.Tensor, dst: int, op: Union[str, ReduceOp, RedOpType] = "sum") -> torch.Tensor:
         op = self._convert_to_native_op(op)
         dist.reduce(tensor, dst, op=op, group=self.group)
         return tensor
@@ -57,7 +57,7 @@ def scatter(self, tensor: torch.Tensor, scatter_list: List[torch.Tensor], src: i
         return tensor
 
     def reduce_scatter(
-        self, output: torch.Tensor, input_list: List[torch.Tensor], op: Union[str, _TORCH_REDUCE_OP] = "sum"
+        self, output: torch.Tensor, input_list: List[torch.Tensor], op: Union[str, ReduceOp, RedOpType] = "sum"
     ) -> torch.Tensor:
         op = self._convert_to_native_op(op)
         dist.reduce_scatter(output, input_list, op=op, group=self.group)
@@ -155,13 +155,17 @@ def destroy_group(cls, group: CollectibleGroup) -> None:
         dist.destroy_process_group(group)
 
     @classmethod
-    def _convert_to_native_op(cls, op: Union[str, _TORCH_REDUCE_OP]) -> _TORCH_REDUCE_OP:
-        if isinstance(op, _TORCH_REDUCE_OP):
+    def _convert_to_native_op(cls, op: Union[str, ReduceOp, RedOpType]) -> Union[ReduceOp, RedOpType]:
+        # in 1.13, `ReduceOp` has become an empty shell for `RedOpType`, the latter being the actually returned class.
+        # for example, `ReduceOp.SUM` returns a `RedOpType.SUM`. the only exception is `RedOpType.PREMUL_SUM` where
+        # `ReduceOp` is still the desired class, but it's created via a special `_make_nccl_premul_sum` function
+        if isinstance(op, ReduceOp) or _TORCH_GREATER_EQUAL_1_13 and isinstance(op, RedOpType):
             return op
         if not isinstance(op, str):
-            raise ValueError(f"op {op!r} should be a `str` or `{_TORCH_REDUCE_OP.__name__}`")
+            raise ValueError(f"Unsupported op {op!r} of type {type(op).__name__}")
         op = op.upper()
-        value = getattr(_TORCH_REDUCE_OP, op, None)
+        # `ReduceOp` should contain `RedOpType`'s members
+        value = getattr(ReduceOp, op, None)
         if value is None:
-            raise ValueError(f"op {op!r} is not a member of `{_TORCH_REDUCE_OP.__name__}`")
+            raise ValueError(f"op {op!r} is not a member of `ReduceOp`")
         return value
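
As a quick reference for the behavior this file now encodes, here is a minimal sketch of what `_convert_to_native_op` accepts, assuming PyTorch 1.13 with `torch.distributed` available; it only restates the relationships described by the comments and tests in this commit:

import torch.distributed as dist
from torch.distributed import ReduceOp

from lightning_lite.plugins.collectives import TorchCollective

# Plain strings are upper-cased and looked up on `ReduceOp`.
assert TorchCollective._convert_to_native_op("sum") == ReduceOp.SUM

# On 1.13, `ReduceOp.SUM` is really a `ReduceOp.RedOpType` member, so RedOpType
# values pass through the isinstance check unchanged.
assert TorchCollective._convert_to_native_op(ReduceOp.RedOpType.SUM) == ReduceOp.SUM

# PREMUL_SUM is the exception: it is built by a helper and stays a `ReduceOp`.
premul = dist._make_nccl_premul_sum(2.0)
assert TorchCollective._convert_to_native_op(premul) == ReduceOp.PREMUL_SUM
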
2 changes: 1 addition & 1 deletion src/lightning_lite/utilities/imports.py
@@ -27,5 +27,5 @@
 _TORCH_LESSER_EQUAL_1_10_2 = compare_version("torch", operator.le, "1.10.2")
 _TORCH_GREATER_EQUAL_1_11 = compare_version("torch", operator.ge, "1.11.0")
 _TORCH_GREATER_EQUAL_1_12 = compare_version("torch", operator.ge, "1.12.0")
-_TORCH_GREATER_EQUAL_1_13 = compare_version("torch", operator.ge, "1.13.0", use_base_version=True)
+_TORCH_GREATER_EQUAL_1_13 = compare_version("torch", operator.ge, "1.13.0")
 _TORCH_GREATER_EQUAL_1_14 = compare_version("torch", operator.ge, "1.14.0", use_base_version=True)
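
Dropping `use_base_version=True` presumably tightens the check for pre-release builds: a base-version comparison treats a nightly whose version starts with `1.13.0` as already being 1.13, while a strict comparison keeps it below the final release. A small illustration of that general idea using `packaging` (the version string is a hypothetical example; this is not the repository's `compare_version` implementation):

from packaging.version import Version

nightly = Version("1.13.0.dev20220920")  # hypothetical pre-release build
release = Version("1.13.0")

# Strict comparison: a pre-release of 1.13.0 is not yet >= 1.13.0.
assert not nightly >= release

# Base-version comparison: the pre-release suffix is ignored, so it passes.
assert Version(nightly.base_version) >= release
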
9 changes: 2 additions & 7 deletions src/lightning_lite/utilities/types.py
@@ -30,16 +30,11 @@
 if torch.distributed.is_available():
     from torch.distributed import ProcessGroup, ReduceOp
 
-    # `ReduceOp has no attribute "RedOpType"`fall back if you have this version, but it is missing
-    # this si a case when you have installed PyTorch from master, for example a few last NGC dockers 22.09
-    if _TORCH_GREATER_EQUAL_1_13 and hasattr(ReduceOp, "RedOpType"):
-        _TORCH_REDUCE_OP = ReduceOp.RedOpType
-    else:
-        _TORCH_REDUCE_OP = ReduceOp
+    RedOpType = ReduceOp.RedOpType if _TORCH_GREATER_EQUAL_1_13 else object
 else:
     ProcessGroup = Any  # type: ignore[assignment,misc]
     ReduceOp = object  # type: ignore[assignment,misc]  # we are using isinstance check once
-    _TORCH_REDUCE_OP = object
+    RedOpType = object
 
 
 _DictKey = TypeVar("_DictKey")
38 changes: 32 additions & 6 deletions tests/tests_lite/plugins/collectives/test_torch_.py
@@ -7,12 +7,12 @@
 import torch
 from tests_lite.helpers.runif import RunIf
 
-from lightning_lite.accelerators import CPUAccelerator
+from lightning_lite.accelerators import CPUAccelerator, CUDAAccelerator
 from lightning_lite.plugins.collectives import TorchCollective
 from lightning_lite.plugins.environments import LightningEnvironment
 from lightning_lite.strategies.ddp_spawn import DDPSpawnStrategy
 from lightning_lite.strategies.launchers.multiprocessing import _MultiProcessingLauncher
-from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_11
+from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_11, _TORCH_GREATER_EQUAL_1_13
 
 if TorchCollective.is_available():
     from torch.distributed import ReduceOp
@@ -133,13 +133,20 @@ def test_convert_ops():
     assert TorchCollective._convert_to_native_op("avg") == ReduceOp.AVG
 
     # Test invalid type
-    with pytest.raises(ValueError, match="op 1 should be a `str` or `Red"):
+    with pytest.raises(ValueError, match="Unsupported op 1 of type int"):
         TorchCollective._convert_to_native_op(1)
 
     # Test invalid string
     with pytest.raises(ValueError, match="op 'INVALID' is not a member of `Red"):
         TorchCollective._convert_to_native_op("invalid")
 
+    # Test RedOpType
+    if _TORCH_GREATER_EQUAL_1_13:
+        assert TorchCollective._convert_to_native_op(ReduceOp.RedOpType.AVG) == ReduceOp.AVG
+        op = torch.distributed._make_nccl_premul_sum(2.0)  # this returns a ReduceOp
+        assert TorchCollective._convert_to_native_op(op) == ReduceOp.PREMUL_SUM
+        assert TorchCollective._convert_to_native_op("premul_sum") == ReduceOp.PREMUL_SUM
+
 
 @skip_distributed_unavailable
 @mock.patch.dict(os.environ, {}, clear=True)
@@ -174,8 +181,10 @@ def test_repeated_create_and_destroy():
 
 
 def collective_launch(fn, parallel_devices, num_groups=1):
+    device_to_accelerator = {"cuda": CUDAAccelerator, "cpu": CPUAccelerator}
+    accelerator_cls = device_to_accelerator[parallel_devices[0].type]
     strategy = DDPSpawnStrategy(
-        accelerator=CPUAccelerator(), parallel_devices=parallel_devices, cluster_environment=LightningEnvironment()
+        accelerator=accelerator_cls(), parallel_devices=parallel_devices, cluster_environment=LightningEnvironment()
     )
     launcher = _MultiProcessingLauncher(strategy=strategy)
     collectives = [TorchCollective() for _ in range(num_groups)]
@@ -188,7 +197,7 @@ def wrap_launch_function(fn, strategy, collective, *args, **kwargs):
     collective.setup(
         world_size=strategy.num_processes,
         main_address="localhost",
-        backend="gloo",
+        backend=strategy._get_process_group_backend(),
         rank=strategy.global_rank,
     )
     return fn(*args, **kwargs)
@@ -212,7 +221,7 @@ def _test_distributed_collectives_fn(strategy, collective):
 
     # all_reduce
     this = torch.tensor(strategy.global_rank + 1)
-    out = collective.all_reduce(this, op="min")
+    out = collective.all_reduce(this, op=ReduceOp.MIN)
     expected = torch.tensor(1)
     torch.testing.assert_close(out, expected)
 
@@ -225,6 +234,23 @@ def test_collectives_distributed(n):
     collective_launch(_test_distributed_collectives_fn, [torch.device("cpu")] * n)
 
 
+def _test_distributed_collectives_cuda_fn(strategy, collective):
+    collective.create_group()
+
+    this = torch.tensor(1.5, device=strategy.root_device)
+    premul_sum = torch.distributed._make_nccl_premul_sum(2.0)
+    out = collective.all_reduce(this, op=premul_sum)
+    assert out == 3
+
+    collective.teardown()
+
+
+@skip_distributed_unavailable
+@RunIf(min_cuda_gpus=1, min_torch="1.13")
+def test_collectives_distributed_cuda():
+    collective_launch(_test_distributed_collectives_cuda_fn, [torch.device("cuda")])
+
+
 def _test_two_groups(strategy, left_collective, right_collective):
     left_collective.create_group(ranks=[0, 1])
     right_collective.create_group(ranks=[1, 2])
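
The expected value in the single-GPU CUDA test follows from PREMUL_SUM's semantics: each contribution is scaled by the supplied factor before the summation, so with one rank, factor 2.0 and input 1.5 the reduced value is 3.0. A sketch of the same arithmetic with plain tensor math, standing in for the NCCL collective:

import torch

factor = 2.0
rank_inputs = [torch.tensor(1.5)]  # a single rank, as in the test above

# PREMUL_SUM: scale every rank's contribution by `factor`, then sum them.
result = sum(factor * t for t in rank_inputs)
assert result == torch.tensor(3.0)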
