[Test][Distributed] Make more tests multi-threaded.
This conversion covers all tests under the 'test/distributed/_tensor' directory.

Fixes pytorch#108744
fkouteib committed Apr 27, 2024
1 parent 368f521 · commit 0362967
Showing 19 changed files with 158 additions and 341 deletions.
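
Every file in this diff applies the same mechanical conversion: switch the test class from the multi-process DTensorTestBase to the multi-threaded DTensorOpTestBase, drop the per-test @with_comms decorator, and replace explicit DeviceMesh(...) construction with the base class's build_device_mesh() helper. The sketch below condenses that pattern into one hypothetical test; the class and test names are illustrative only, and it assumes (the commit itself does not say so) that build_device_mesh() returns a 1-D mesh over all ranks, equivalent to the removed DeviceMesh(self.device_type, list(range(self.world_size))) calls.

import torch
from torch.distributed._tensor import distribute_tensor, Shard
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorOpTestBase,  # multi-threaded: ranks run as threads in one process
)


class ExampleConvertedTest(DTensorOpTestBase):  # hypothetical test, for illustration only
    @property
    def world_size(self) -> int:
        return 2

    # No @with_comms here: the multi-threaded base class manages per-test
    # communication setup, unlike the multi-process DTensorTestBase.
    def test_shard_roundtrip(self):
        device_mesh = self.build_device_mesh()  # replaces explicit DeviceMesh(...)
        # Deterministic input so every "rank" thread holds the same global tensor.
        tensor = torch.arange(4.0 * self.world_size).reshape(self.world_size, 4)
        dtensor = distribute_tensor(tensor, device_mesh, [Shard(0)])
        self.assertEqual(dtensor.full_tensor(), tensor)


if __name__ == "__main__":
    run_tests()
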
test/distributed/_tensor/experimental/test_local_map.py (9 changes: 2 additions & 7 deletions)
@@ -13,8 +13,7 @@
 from torch.distributed._tensor.experimental import local_map
 from torch.testing._internal.common_utils import run_tests
 from torch.testing._internal.distributed._tensor.common_dtensor import (
-    DTensorTestBase,
-    with_comms,
+    DTensorOpTestBase,
 )


@@ -38,13 +37,12 @@ def mul_forward(device_mesh, X, scalar):
     return torch.mul(X, scalar)


-class TestLocalMap(DTensorTestBase):
+class TestLocalMap(DTensorOpTestBase):
     @property
     def world_size(self):
         return 2

     # simple correctness check
-    @with_comms
     def test_local_map_correctness(self):
         device_mesh = init_device_mesh(
             device_type=self.device_type, mesh_shape=(self.world_size,)
@@ -86,7 +84,6 @@ def test_local_map_correctness(self):
         self.assertEqual(Y_dt.to_local(), Y)

     # check for `out_placements`
-    @with_comms
     def test_local_map_out_placements(self):
         device_mesh = init_device_mesh(
             device_type=self.device_type, mesh_shape=(self.world_size,)
@@ -108,7 +105,6 @@ def test_local_map_out_placements(self):
         self.assertTrue(not (X.equal(Y)))

     # check for `in_placements` handling
-    @with_comms
     def test_local_map_in_placements(self):
         device_mesh = init_device_mesh(
             device_type=self.device_type, mesh_shape=(self.world_size,)
@@ -174,7 +170,6 @@ def test_local_map_in_placements(self):
         self.assertEqual(Y_dt.full_tensor(), Y)

     # check for `redistribute_inputs` handling
-    @with_comms
     def test_local_map_redistribute(self):
         device_mesh = init_device_mesh(
             device_type=self.device_type, mesh_shape=(self.world_size,)
test/distributed/_tensor/experimental/test_tp_transform.py (8 changes: 2 additions & 6 deletions)
@@ -13,8 +13,7 @@
 )
 from torch.testing._internal.common_utils import run_tests
 from torch.testing._internal.distributed._tensor.common_dtensor import (
-    DTensorTestBase,
-    with_comms,
+    DTensorOpTestBase,
 )


@@ -52,7 +51,7 @@ def forward(self, x):
         return self.bn(self.fc(x))


-class TensorParallelTest(DTensorTestBase):
+class TensorParallelTest(DTensorOpTestBase):
     def setUp(self) -> None:
         super().setUp()

@@ -66,7 +65,6 @@ def assert_has_c10d_ops(
                 actual_ops_count[str(node.target)] += 1
         self.assertDictEqual(expected_ops_count, actual_ops_count)

-    @with_comms
     def test_tp_transform_with_uncovered_op(self):
         model = DummyModel().to(device=self.device_type)
         inputs = (torch.randn(7, 3, requires_grad=False).to(device=self.device_type),)
@@ -96,7 +94,6 @@ def test_tp_transform_with_uncovered_op(self):
             },
         )

-    @with_comms
     def test_tp_transform_e2e(self):
         torch.manual_seed(0)
         model = MLPListModule(2).to(device=self.device_type)
@@ -134,7 +131,6 @@ def test_tp_transform_e2e(self):
             },
         )

-    @with_comms
     def test_tp_transform_no_bias(self):
         torch.manual_seed(0)
         model = MLPListModule(1, bias=False).to(device=self.device_type)
test/distributed/_tensor/test_api.py (27 changes: 9 additions & 18 deletions)
@@ -13,8 +13,7 @@
 )
 from torch.testing._internal.common_utils import run_tests
 from torch.testing._internal.distributed._tensor.common_dtensor import (
-    DTensorTestBase,
-    with_comms,
+    DTensorOpTestBase,
 )


@@ -33,16 +32,15 @@ def reset_parameters(self):
             m.reset_parameters()


-class DTensorAPITest(DTensorTestBase):
+class DTensorAPITest(DTensorOpTestBase):
     @property
     def world_size(self) -> int:
         # hard code world size to 4 as we need to test
         # at least with 2d mesh
         return 4

-    @with_comms
     def test_distribute_tensor(self):
-        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+        device_mesh = self.build_device_mesh()
         shard_spec = [Shard(0)]

         for requires_grad in [True, False]:
@@ -63,7 +61,6 @@ def test_distribute_tensor(self):
         dist_tensor = distribute_tensor(tensor_to_shard, device_mesh, shard_minus_spec)
         self.assertEqual(dist_tensor.placements[0].dim, 1)

-    @with_comms
     def test_distribute_tensor_errors(self):
         device_mesh = DeviceMesh(
             self.device_type, torch.arange(self.world_size).reshape(2, 2)
@@ -92,9 +89,8 @@ def test_distribute_tensor_errors(self):
             new_spec = [Shard(0), Replicate()]
             distribute_tensor(dtensor, device_mesh, new_spec)

-    @with_comms
     def test_distribute_tensor_uneven_sharding(self):
-        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+        device_mesh = self.build_device_mesh()
         input_sizes_and_shard_dims = [
             ((self.world_size * 3 + 1, 3, 3), 0),
             ((self.world_size * 3 + 2, 3, 3), 0),
@@ -114,9 +110,8 @@ def test_distribute_tensor_uneven_sharding(self):
             local_tensor = dist_tensor.to_local()
             self.assertEqual(local_tensor, splitted_tensor_list[self.rank])

-    @with_comms
     def test_distribute_module(self):
-        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+        device_mesh = self.build_device_mesh()
         # fully shard all linear modules on dim 0
         module_to_shard = MyModel(5 * self.world_size, 20, device=self.device_type)
         shard_spec = [Shard(0)]
@@ -177,9 +172,8 @@ def shard_fn(name, module, device_mesh):
             else:
                 self.assertEqual(param.placements, replica_spec)

-    @with_comms
     def test_distribute_module_input_fn_output_fn(self):
-        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+        device_mesh = self.build_device_mesh()

         # fully replicate all linear modules
         module_to_replicate = MyModel(20, 1, device=self.device_type)
@@ -222,9 +216,8 @@ def replicate_input_fn(mod, inputs, device_mesh):
             self.assertTrue(isinstance(param_grad, DTensor))
             self.assertTrue(isinstance(param_grad.placements[0], Replicate))

-    @with_comms
     def test_distribute_module_input_fn_output_fn_warning(self):
-        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+        device_mesh = self.build_device_mesh()

         # fully replicate all linear modules
         module_to_replicate = MyModel(20, 1, device=self.device_type)
@@ -250,9 +243,8 @@ def output_fn(outputs, device_mesh):
         self.assertIsInstance(local_out, torch.Tensor)
         self.assertNotIsInstance(local_out, DTensor)

-    @with_comms
     def test_distribute_module_casting(self):
-        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+        device_mesh = self.build_device_mesh()

         # check DTensor casting
         dt = DTensor.from_local(torch.rand(10), device_mesh, [Replicate()])
@@ -288,11 +280,10 @@ def test_distribute_module_casting(self):
         output = replica_model(dt)
         self.assertEqual(output.dtype, torch.bfloat16)

-    @with_comms
     def test_distribute_module_meta(self):
         # If the model is too big, the user may first the create entire model on the meta device and then initialize
         # it on the device in the partition function.
-        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+        device_mesh = self.build_device_mesh()

         # fully shard all parameters on dim 0
         module_to_shard = MyModel(5 * self.world_size, 20, device="meta")
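
For reference while reading the build_device_mesh() substitutions above and in test_attention.py below: a plausible standalone equivalent of the helper is sketched here. The body is an assumption inferred from the calls it replaces, not quoted from common_dtensor.py.

import torch
from torch.distributed._tensor import DeviceMesh


def build_device_mesh(device_type: str, world_size: int) -> DeviceMesh:
    # Assumed equivalent of the removed explicit constructions:
    #   DeviceMesh(self.device_type, list(range(self.world_size)))
    # i.e. a 1-D mesh spanning every rank in the default world.
    return DeviceMesh(device_type, list(range(world_size)))
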
test/distributed/_tensor/test_attention.py (31 changes: 7 additions & 24 deletions)
@@ -4,7 +4,7 @@

 import torch
 from torch import nn
-from torch.distributed._tensor import DeviceMesh, distribute_tensor, Shard
+from torch.distributed._tensor import distribute_tensor, Shard
 from torch.distributed._tensor.debug import CommDebugMode
 from torch.distributed._tensor.experimental.attention import (
     _CausalBehavior,
@@ -25,17 +25,16 @@
     run_tests,
 )
 from torch.testing._internal.distributed._tensor.common_dtensor import (
-    DTensorTestBase,
+    DTensorOpTestBase,
     ModelArgs,
     Transformer,
-    with_comms,
 )


 c10d_functional = torch.ops.c10d_functional


-class RingAttentionTest(DTensorTestBase):
+class RingAttentionTest(DTensorOpTestBase):
     @property
     def world_size(self) -> int:
         return 2
@@ -44,13 +43,9 @@ def world_size(self) -> int:
     @unittest.skipIf(
         not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support flash attention"
     )
-    @with_comms
     @parametrize("is_causal", [True, False])
     def test_ring_attention_sdpa(self, is_causal: bool) -> None:
-        device_mesh = DeviceMesh(
-            self.device_type,
-            torch.arange(0, self.world_size),
-        )
+        device_mesh = self.build_device_mesh()
         dtype = torch.bfloat16
         bs = 8
         query_tokens = 8
@@ -168,14 +163,10 @@ def test_ring_attention_sdpa(self, is_causal: bool) -> None:
     @unittest.skipIf(
         not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support flash attention"
     )
-    @with_comms
     @sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION])
     @parametrize("is_causal", [True, False])
     def test_ring_attention_native_transformer(self, is_causal: bool) -> None:
-        device_mesh = DeviceMesh(
-            self.device_type,
-            torch.arange(0, self.world_size),
-        )
+        device_mesh = self.build_device_mesh()
         dtype = torch.bfloat16
         bs = 8
         ntokens = 8
@@ -250,13 +241,9 @@ def test_is_causal_behavior(self) -> None:
     @unittest.skipIf(
         not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support flash attention"
     )
-    @with_comms
     @sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION])
     def test_ring_attention_custom_transformer(self) -> None:
-        device_mesh = DeviceMesh(
-            self.device_type,
-            torch.arange(0, self.world_size),
-        )
+        device_mesh = self.build_device_mesh()
         dtype = torch.bfloat16
         bs = 2
         args = ModelArgs()
@@ -301,7 +288,6 @@ def test_ring_attention_custom_transformer(self) -> None:
     @unittest.skipIf(
         not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support flash attention"
     )
-    @with_comms
     @parametrize(
         "attention_fn",
         [
@@ -311,10 +297,7 @@ def test_ring_attention_custom_transformer(self) -> None:
         ],
     )
     def test_ring_attention_compile(self, attention_fn: object) -> None:
-        device_mesh = DeviceMesh(
-            self.device_type,
-            torch.arange(0, self.world_size),
-        )
+        device_mesh = self.build_device_mesh()
         dtype = torch.bfloat16
         bs = 8
         query_tokens = 8
