Collective's PREMUL_SUM support with PyTorch 1.13 #15201

Merged 7 commits on Oct 20, 2022
24 changes: 14 additions & 10 deletions src/lightning_lite/plugins/collectives/torch_collective.py
@@ -7,8 +7,8 @@
from typing_extensions import Self

from lightning_lite.plugins.collectives.collective import Collective
-from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_10
-from lightning_lite.utilities.types import _TORCH_REDUCE_OP, CollectibleGroup
+from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_10, _TORCH_GREATER_EQUAL_1_13
+from lightning_lite.utilities.types import CollectibleGroup, RedOpType, ReduceOp

if dist.is_available():
from torch.distributed.constants import default_pg_timeout
@@ -34,12 +34,12 @@ def broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
dist.broadcast(tensor, src, group=self.group)
return tensor

-def all_reduce(self, tensor: torch.Tensor, op: Union[str, _TORCH_REDUCE_OP] = "sum") -> torch.Tensor:
+def all_reduce(self, tensor: torch.Tensor, op: Union[str, ReduceOp, RedOpType] = "sum") -> torch.Tensor:
op = self._convert_to_native_op(op)
dist.all_reduce(tensor, op=op, group=self.group)
return tensor

-def reduce(self, tensor: torch.Tensor, dst: int, op: Union[str, _TORCH_REDUCE_OP] = "sum") -> torch.Tensor:
+def reduce(self, tensor: torch.Tensor, dst: int, op: Union[str, ReduceOp, RedOpType] = "sum") -> torch.Tensor:
op = self._convert_to_native_op(op)
dist.reduce(tensor, dst, op=op, group=self.group)
return tensor
@@ -57,7 +57,7 @@ def scatter(self, tensor: torch.Tensor, scatter_list: List[torch.Tensor], src: i
return tensor

def reduce_scatter(
-self, output: torch.Tensor, input_list: List[torch.Tensor], op: Union[str, _TORCH_REDUCE_OP] = "sum"
+self, output: torch.Tensor, input_list: List[torch.Tensor], op: Union[str, ReduceOp, RedOpType] = "sum"
) -> torch.Tensor:
op = self._convert_to_native_op(op)
dist.reduce_scatter(output, input_list, op=op, group=self.group)
@@ -155,13 +155,17 @@ def destroy_group(cls, group: CollectibleGroup) -> None:
dist.destroy_process_group(group)

@classmethod
-def _convert_to_native_op(cls, op: Union[str, _TORCH_REDUCE_OP]) -> _TORCH_REDUCE_OP:
-if isinstance(op, _TORCH_REDUCE_OP):
+def _convert_to_native_op(cls, op: Union[str, ReduceOp, RedOpType]) -> Union[ReduceOp, RedOpType]:
+# in 1.13, `ReduceOp` has become an empty shell for `RedOpType`, the latter being the actually returned class.
+# for example, `ReduceOp.SUM` returns a `RedOpType.SUM`. the only exception is `RedOpType.PREMUL_SUM` where
+# `ReduceOp` is still the desired class, but it's created via a special `_make_nccl_premul_sum` function
+if isinstance(op, ReduceOp) or _TORCH_GREATER_EQUAL_1_13 and isinstance(op, RedOpType):
return op
if not isinstance(op, str):
raise ValueError(f"op {op!r} should be a `str` or `{_TORCH_REDUCE_OP.__name__}`")
raise ValueError(f"Unsupported op {op!r} of type {type(op).__name__}")
op = op.upper()
-value = getattr(_TORCH_REDUCE_OP, op, None)
+# `ReduceOp` should contain `RedOpType`'s members
+value = getattr(ReduceOp, op, None)
if value is None:
raise ValueError(f"op {op!r} is not a member of `{_TORCH_REDUCE_OP.__name__}`")
raise ValueError(f"op {op!r} is not a member of `ReduceOp`")
return value
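
To make the `ReduceOp`/`RedOpType` comment in `_convert_to_native_op` concrete, here is a minimal sketch of the accepted input forms, mirroring the tests added further down in this PR (assumes PyTorch >= 1.13 with `torch.distributed` available; `_make_nccl_premul_sum` is private torch API):

import torch.distributed as dist
from torch.distributed import ReduceOp

from lightning_lite.plugins.collectives import TorchCollective

# A string is upper-cased and looked up as a `ReduceOp` member
assert TorchCollective._convert_to_native_op("sum") == ReduceOp.SUM

# A `ReduceOp.RedOpType` member (what `ReduceOp.AVG` actually resolves to in 1.13) passes through
assert TorchCollective._convert_to_native_op(ReduceOp.RedOpType.AVG) == ReduceOp.AVG

# PREMUL_SUM carries a scale factor, so it is built via the private helper and comes
# back as a `ReduceOp` instance, which also passes through unchanged
premul = dist._make_nccl_premul_sum(2.0)
assert TorchCollective._convert_to_native_op(premul) == ReduceOp.PREMUL_SUM
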
2 changes: 1 addition & 1 deletion src/lightning_lite/utilities/imports.py
@@ -27,5 +27,5 @@
_TORCH_LESSER_EQUAL_1_10_2 = compare_version("torch", operator.le, "1.10.2")
_TORCH_GREATER_EQUAL_1_11 = compare_version("torch", operator.ge, "1.11.0")
_TORCH_GREATER_EQUAL_1_12 = compare_version("torch", operator.ge, "1.12.0")
-_TORCH_GREATER_EQUAL_1_13 = compare_version("torch", operator.ge, "1.13.0", use_base_version=True)
+_TORCH_GREATER_EQUAL_1_13 = compare_version("torch", operator.ge, "1.13.0")
_TORCH_GREATER_EQUAL_1_14 = compare_version("torch", operator.ge, "1.14.0", use_base_version=True)
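
The only change above is dropping `use_base_version=True` from the 1.13 check, so pre-release/nightly builds of 1.13 no longer satisfy `_TORCH_GREATER_EQUAL_1_13`. A rough sketch of the distinction, assuming `compare_version` follows PEP 440 semantics via `packaging` (the version string is only an illustrative example):

from packaging.version import Version

nightly = Version("1.13.0.dev20220920")  # hypothetical dev/nightly build

# A dev pre-release sorts below the final release...
assert not (nightly >= Version("1.13.0"))
# ...but its base version is exactly 1.13.0, which is what `use_base_version=True` compared against
assert Version(nightly.base_version) >= Version("1.13.0")
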
9 changes: 2 additions & 7 deletions src/lightning_lite/utilities/types.py
@@ -30,16 +30,11 @@
if torch.distributed.is_available():
from torch.distributed import ProcessGroup, ReduceOp

-# `ReduceOp` has no attribute `RedOpType`: fall back if you have this version but it is missing.
-# This is the case when you have installed PyTorch from master, for example the last few NGC dockers (22.09)
-if _TORCH_GREATER_EQUAL_1_13 and hasattr(ReduceOp, "RedOpType"):
-_TORCH_REDUCE_OP = ReduceOp.RedOpType
-else:
-_TORCH_REDUCE_OP = ReduceOp
+RedOpType = ReduceOp.RedOpType if _TORCH_GREATER_EQUAL_1_13 else object
else:
ProcessGroup = Any # type: ignore[assignment,misc]
ReduceOp = object # type: ignore[assignment,misc] # we are using isinstance check once
-_TORCH_REDUCE_OP = object
+RedOpType = object


_DictKey = TypeVar("_DictKey")
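
A hedged reading of the fallback above (not something stated in the PR): when `torch.distributed` is unavailable, `ReduceOp` and `RedOpType` are both aliased to `object`, which keeps the annotations importable and turns the single `isinstance` check in `_convert_to_native_op` into a pass-through, since every Python value is an instance of `object`:

# Illustration only: mimic the branch where torch.distributed is unavailable
ReduceOp = object
RedOpType = object

op = "sum"
# `isinstance(x, object)` is always True, so the op is returned as given rather than converted
assert isinstance(op, ReduceOp) and isinstance(op, RedOpType)
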
38 changes: 32 additions & 6 deletions tests/tests_lite/plugins/collectives/test_torch_.py
@@ -7,12 +7,12 @@
import torch
from tests_lite.helpers.runif import RunIf

-from lightning_lite.accelerators import CPUAccelerator
+from lightning_lite.accelerators import CPUAccelerator, CUDAAccelerator
from lightning_lite.plugins.collectives import TorchCollective
from lightning_lite.plugins.environments import LightningEnvironment
from lightning_lite.strategies.ddp_spawn import DDPSpawnStrategy
from lightning_lite.strategies.launchers.multiprocessing import _MultiProcessingLauncher
-from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_11
+from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_11, _TORCH_GREATER_EQUAL_1_13

if TorchCollective.is_available():
from torch.distributed import ReduceOp
@@ -133,13 +133,20 @@ def test_convert_ops():
assert TorchCollective._convert_to_native_op("avg") == ReduceOp.AVG

# Test invalid type
with pytest.raises(ValueError, match="op 1 should be a `str` or `Red"):
with pytest.raises(ValueError, match="Unsupported op 1 of type int"):
TorchCollective._convert_to_native_op(1)

# Test invalid string
with pytest.raises(ValueError, match="op 'INVALID' is not a member of `Red"):
TorchCollective._convert_to_native_op("invalid")

+# Test RedOpType
+if _TORCH_GREATER_EQUAL_1_13:
+assert TorchCollective._convert_to_native_op(ReduceOp.RedOpType.AVG) == ReduceOp.AVG
+op = torch.distributed._make_nccl_premul_sum(2.0)  # this returns a ReduceOp
+assert TorchCollective._convert_to_native_op(op) == ReduceOp.PREMUL_SUM
+assert TorchCollective._convert_to_native_op("premul_sum") == ReduceOp.PREMUL_SUM


@skip_distributed_unavailable
@mock.patch.dict(os.environ, {}, clear=True)
Expand Down Expand Up @@ -174,8 +181,10 @@ def test_repeated_create_and_destroy():


def collective_launch(fn, parallel_devices, num_groups=1):
+device_to_accelerator = {"cuda": CUDAAccelerator, "cpu": CPUAccelerator}
+accelerator_cls = device_to_accelerator[parallel_devices[0].type]
strategy = DDPSpawnStrategy(
-accelerator=CPUAccelerator(), parallel_devices=parallel_devices, cluster_environment=LightningEnvironment()
+accelerator=accelerator_cls(), parallel_devices=parallel_devices, cluster_environment=LightningEnvironment()
)
launcher = _MultiProcessingLauncher(strategy=strategy)
collectives = [TorchCollective() for _ in range(num_groups)]
@@ -188,7 +197,7 @@ def wrap_launch_function(fn, strategy, collective, *args, **kwargs):
collective.setup(
world_size=strategy.num_processes,
main_address="localhost",
backend="gloo",
backend=strategy._get_process_group_backend(),
rank=strategy.global_rank,
)
return fn(*args, **kwargs)
@@ -212,7 +221,7 @@ def _test_distributed_collectives_fn(strategy, collective):

# all_reduce
this = torch.tensor(strategy.global_rank + 1)
out = collective.all_reduce(this, op="min")
out = collective.all_reduce(this, op=ReduceOp.MIN)
expected = torch.tensor(1)
torch.testing.assert_close(out, expected)

@@ -225,6 +234,23 @@ def test_collectives_distributed(n):
collective_launch(_test_distributed_collectives_fn, [torch.device("cpu")] * n)


+def _test_distributed_collectives_cuda_fn(strategy, collective):
+collective.create_group()
+
+this = torch.tensor(1.5, device=strategy.root_device)
+premul_sum = torch.distributed._make_nccl_premul_sum(2.0)
+out = collective.all_reduce(this, op=premul_sum)
+assert out == 3
+
+collective.teardown()
+
+
+@skip_distributed_unavailable
+@RunIf(min_cuda_gpus=1, min_torch="1.13")
+def test_collectives_distributed_cuda():
+collective_launch(_test_distributed_collectives_cuda_fn, [torch.device("cuda")])
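
For reference, a small local sketch of the arithmetic the CUDA test above expects: PREMUL_SUM scales each rank's tensor by the given factor before the sum reduction, so with a single rank and factor 2.0 the all-reduced value of 1.5 is 3.0 (the multi-rank generalization is assumed from the op's semantics, not exercised here):

import torch

factor = 2.0
rank_tensors = [torch.tensor(1.5)]  # single-rank case, matching the test above
expected = factor * torch.stack(rank_tensors).sum()
assert torch.equal(expected, torch.tensor(3.0))
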


def _test_two_groups(strategy, left_collective, right_collective):
left_collective.create_group(ranks=[0, 1])
right_collective.create_group(ranks=[1, 2])