[Test][Distributed] Make more tests multi-threaded. #125095

Closed · wants to merge 3 commits
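The common pattern across all four files: switch the test classes from the multi-process DTensorTestBase (which needs the @with_comms decorator to set up process groups) to the multi-threaded DTensorOpTestBase, and build the mesh with self.build_device_mesh() instead of constructing a DeviceMesh by hand. Below is a minimal sketch of a test written against the new base class, assuming the usual PyTorch test internals; the class and test names are hypothetical, and only the imports, base class, and helper calls are taken from the diff itself.

# Hypothetical example of the post-migration test shape (not part of the PR).
import torch
from torch.distributed._tensor import distribute_tensor, Shard
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import DTensorOpTestBase


class ExampleShardingTest(DTensorOpTestBase):
    # DTensorOpTestBase runs each rank as a thread in a single process,
    # so no @with_comms process-group setup is required.
    @property
    def world_size(self) -> int:
        return 2

    def test_shard_dim0(self):
        # build_device_mesh() replaces the manual
        # DeviceMesh(self.device_type, list(range(self.world_size))) construction.
        mesh = self.build_device_mesh()
        x = torch.ones(4 * self.world_size, 8, device=self.device_type)
        dx = distribute_tensor(x, mesh, [Shard(0)])
        # With world_size=2, each rank holds a (4, 8) shard of dim 0.
        self.assertEqual(dx.to_local().shape, (4, 8))
        self.assertEqual(dx.full_tensor(), x)


if __name__ == "__main__":
    run_tests()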
11 changes: 2 additions & 9 deletions test/distributed/_tensor/experimental/test_local_map.py
@@ -12,10 +12,7 @@
from torch.distributed._tensor.debug import CommDebugMode
from torch.distributed._tensor.experimental import local_map
from torch.testing._internal.common_utils import run_tests
-from torch.testing._internal.distributed._tensor.common_dtensor import (
-    DTensorTestBase,
-    with_comms,
-)
+from torch.testing._internal.distributed._tensor.common_dtensor import DTensorOpTestBase


def equal_forward(device_mesh, X, Y):
@@ -38,13 +35,12 @@ def mul_forward(device_mesh, X, scalar):
return torch.mul(X, scalar)


-class TestLocalMap(DTensorTestBase):
+class TestLocalMap(DTensorOpTestBase):
@property
def world_size(self):
return 2

# simple correctness check
-    @with_comms
def test_local_map_correctness(self):
device_mesh = init_device_mesh(
device_type=self.device_type, mesh_shape=(self.world_size,)
@@ -86,7 +82,6 @@ def test_local_map_correctness(self):
self.assertEqual(Y_dt.to_local(), Y)

# check for `out_placements`
-    @with_comms
def test_local_map_out_placements(self):
device_mesh = init_device_mesh(
device_type=self.device_type, mesh_shape=(self.world_size,)
@@ -108,7 +103,6 @@ def test_local_map_out_placements(self):
self.assertTrue(not (X.equal(Y)))

# check for `in_placements` handling
-    @with_comms
def test_local_map_in_placements(self):
device_mesh = init_device_mesh(
device_type=self.device_type, mesh_shape=(self.world_size,)
@@ -174,7 +168,6 @@ def test_local_map_in_placements(self):
self.assertEqual(Y_dt.full_tensor(), Y)

# check for `redistribute_inputs` handling
-    @with_comms
def test_local_map_redistribute(self):
device_mesh = init_device_mesh(
device_type=self.device_type, mesh_shape=(self.world_size,)
10 changes: 2 additions & 8 deletions test/distributed/_tensor/experimental/test_tp_transform.py
@@ -12,10 +12,7 @@
RowwiseParallel,
)
from torch.testing._internal.common_utils import run_tests
-from torch.testing._internal.distributed._tensor.common_dtensor import (
-    DTensorTestBase,
-    with_comms,
-)
+from torch.testing._internal.distributed._tensor.common_dtensor import DTensorOpTestBase


class MLPListModule(torch.nn.Module):
@@ -52,7 +49,7 @@ def forward(self, x):
return self.bn(self.fc(x))


-class TensorParallelTest(DTensorTestBase):
+class TensorParallelTest(DTensorOpTestBase):
def setUp(self) -> None:
super().setUp()

@@ -66,7 +63,6 @@ def assert_has_c10d_ops(
actual_ops_count[str(node.target)] += 1
self.assertDictEqual(expected_ops_count, actual_ops_count)

-    @with_comms
def test_tp_transform_with_uncovered_op(self):
model = DummyModel().to(device=self.device_type)
inputs = (torch.randn(7, 3, requires_grad=False).to(device=self.device_type),)
@@ -96,7 +92,6 @@ def test_tp_transform_with_uncovered_op(self):
},
)

-    @with_comms
def test_tp_transform_e2e(self):
torch.manual_seed(0)
model = MLPListModule(2).to(device=self.device_type)
@@ -134,7 +129,6 @@ def test_tp_transform_e2e(self):
},
)

-    @with_comms
def test_tp_transform_no_bias(self):
torch.manual_seed(0)
model = MLPListModule(1, bias=False).to(device=self.device_type)
29 changes: 9 additions & 20 deletions test/distributed/_tensor/test_api.py
@@ -12,10 +12,7 @@
Shard,
)
from torch.testing._internal.common_utils import run_tests
-from torch.testing._internal.distributed._tensor.common_dtensor import (
-    DTensorTestBase,
-    with_comms,
-)
+from torch.testing._internal.distributed._tensor.common_dtensor import DTensorOpTestBase


class MyModel(nn.Module):
@@ -33,16 +30,15 @@ def reset_parameters(self):
m.reset_parameters()


-class DTensorAPITest(DTensorTestBase):
+class DTensorAPITest(DTensorOpTestBase):
@property
def world_size(self) -> int:
# hard code world size to 4 as we need to test
# at least with 2d mesh
return 4

-    @with_comms
def test_distribute_tensor(self):
-        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+        device_mesh = self.build_device_mesh()
shard_spec = [Shard(0)]

for requires_grad in [True, False]:
@@ -63,7 +59,6 @@ def test_distribute_tensor(self):
dist_tensor = distribute_tensor(tensor_to_shard, device_mesh, shard_minus_spec)
self.assertEqual(dist_tensor.placements[0].dim, 1)

-    @with_comms
def test_distribute_tensor_errors(self):
device_mesh = DeviceMesh(
self.device_type, torch.arange(self.world_size).reshape(2, 2)
@@ -92,9 +87,8 @@ def test_distribute_tensor_errors(self):
new_spec = [Shard(0), Replicate()]
distribute_tensor(dtensor, device_mesh, new_spec)

-    @with_comms
def test_distribute_tensor_uneven_sharding(self):
-        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+        device_mesh = self.build_device_mesh()
input_sizes_and_shard_dims = [
((self.world_size * 3 + 1, 3, 3), 0),
((self.world_size * 3 + 2, 3, 3), 0),
@@ -114,9 +108,8 @@ def test_distribute_tensor_uneven_sharding(self):
local_tensor = dist_tensor.to_local()
self.assertEqual(local_tensor, splitted_tensor_list[self.rank])

-    @with_comms
def test_distribute_module(self):
-        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+        device_mesh = self.build_device_mesh()
# fully shard all linear modules on dim 0
module_to_shard = MyModel(5 * self.world_size, 20, device=self.device_type)
shard_spec = [Shard(0)]
@@ -177,9 +170,8 @@ def shard_fn(name, module, device_mesh):
else:
self.assertEqual(param.placements, replica_spec)

-    @with_comms
def test_distribute_module_input_fn_output_fn(self):
-        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+        device_mesh = self.build_device_mesh()

# fully replicate all linear modules
module_to_replicate = MyModel(20, 1, device=self.device_type)
@@ -222,9 +214,8 @@ def replicate_input_fn(mod, inputs, device_mesh):
self.assertTrue(isinstance(param_grad, DTensor))
self.assertTrue(isinstance(param_grad.placements[0], Replicate))

-    @with_comms
def test_distribute_module_input_fn_output_fn_warning(self):
-        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+        device_mesh = self.build_device_mesh()

# fully replicate all linear modules
module_to_replicate = MyModel(20, 1, device=self.device_type)
@@ -250,9 +241,8 @@ def output_fn(outputs, device_mesh):
self.assertIsInstance(local_out, torch.Tensor)
self.assertNotIsInstance(local_out, DTensor)

-    @with_comms
def test_distribute_module_casting(self):
-        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+        device_mesh = self.build_device_mesh()

# check DTensor casting
dt = DTensor.from_local(torch.rand(10), device_mesh, [Replicate()])
@@ -288,11 +278,10 @@ def test_distribute_module_casting(self):
output = replica_model(dt)
self.assertEqual(output.dtype, torch.bfloat16)

-    @with_comms
def test_distribute_module_meta(self):
# If the model is too big, the user may first the create entire model on the meta device and then initialize
# it on the device in the partition function.
-        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+        device_mesh = self.build_device_mesh()

# fully shard all parameters on dim 0
module_to_shard = MyModel(5 * self.world_size, 20, device="meta")
31 changes: 7 additions & 24 deletions test/distributed/_tensor/test_attention.py
@@ -4,7 +4,7 @@

import torch
from torch import nn
-from torch.distributed._tensor import DeviceMesh, distribute_tensor, Shard
+from torch.distributed._tensor import distribute_tensor, Shard
from torch.distributed._tensor.debug import CommDebugMode
from torch.distributed._tensor.experimental.attention import (
_CausalBehavior,
@@ -25,17 +25,16 @@
run_tests,
)
from torch.testing._internal.distributed._tensor.common_dtensor import (
-    DTensorTestBase,
+    DTensorOpTestBase,
ModelArgs,
Transformer,
-    with_comms,
)


c10d_functional = torch.ops.c10d_functional


-class RingAttentionTest(DTensorTestBase):
+class RingAttentionTest(DTensorOpTestBase):
@property
def world_size(self) -> int:
return 2
@@ -44,13 +43,9 @@ def world_size(self) -> int:
@unittest.skipIf(
not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support flash attention"
)
-    @with_comms
@parametrize("is_causal", [True, False])
def test_ring_attention_sdpa(self, is_causal: bool) -> None:
-        device_mesh = DeviceMesh(
-            self.device_type,
-            torch.arange(0, self.world_size),
-        )
+        device_mesh = self.build_device_mesh()
dtype = torch.bfloat16
bs = 8
query_tokens = 8
@@ -168,14 +163,10 @@ def test_ring_attention_sdpa(self, is_causal: bool) -> None:
@unittest.skipIf(
not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support flash attention"
)
-    @with_comms
@sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION])
@parametrize("is_causal", [True, False])
def test_ring_attention_native_transformer(self, is_causal: bool) -> None:
-        device_mesh = DeviceMesh(
-            self.device_type,
-            torch.arange(0, self.world_size),
-        )
+        device_mesh = self.build_device_mesh()
dtype = torch.bfloat16
bs = 8
ntokens = 8
@@ -250,13 +241,9 @@ def test_is_causal_behavior(self) -> None:
@unittest.skipIf(
not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support flash attention"
)
-    @with_comms
@sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION])
def test_ring_attention_custom_transformer(self) -> None:
-        device_mesh = DeviceMesh(
-            self.device_type,
-            torch.arange(0, self.world_size),
-        )
+        device_mesh = self.build_device_mesh()
dtype = torch.bfloat16
bs = 2
args = ModelArgs()
@@ -301,7 +288,6 @@ def test_ring_attention_custom_transformer(self) -> None:
@unittest.skipIf(
not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support flash attention"
)
-    @with_comms
@parametrize(
"attention_fn",
[
@@ -311,10 +297,7 @@ def test_ring_attention_custom_transformer(self) -> None:
],
)
def test_ring_attention_compile(self, attention_fn: object) -> None:
-        device_mesh = DeviceMesh(
-            self.device_type,
-            torch.arange(0, self.world_size),
-        )
+        device_mesh = self.build_device_mesh()
dtype = torch.bfloat16
bs = 8
query_tokens = 8