Skip to content

Commit

Permalink
add more assertions
Browse files: browse the repository at this point in the history
  • Loading branch information
awaelchli committed May 6, 2024
1 parent b3103ce commit c917bee
Showing 1 changed file with 22 additions and 12 deletions.
34 changes: 22 additions & 12 deletions tests/tests_fabric/strategies/test_model_parallel_integration.py
Original file line number | Diff line number | Diff line change
@@ -138,11 +138,14 @@ def test_tensor_parallel():
  with fabric.init_module(empty_init=True):
  model = FeedForward()

+ model = fabric.setup(model)
  optimizer = torch.optim.AdamW(model.parameters())
- model, optimizer = fabric.setup(model, optimizer)
-
+ optimizer = fabric.setup_optimizers(optimizer)
+
+ device_mesh = fabric.strategy.device_mesh
+ assert all(tensor.device_mesh == device_mesh["tensor_parallel"] for tensor in optimizer.param_groups[0]["params"])
  assert all(isinstance(weight, DTensor) for weight in model.parameters())
- assert model.w1.weight.device_mesh == fabric.strategy.device_mesh["tensor_parallel"]
+ assert model.w1.weight.device_mesh == device_mesh["tensor_parallel"]

  dataset_size = 6
  dataset = RandomDataset(32, dataset_size)
@@ -160,6 +163,8 @@ def test_tensor_parallel():

  output = model(batch)
  fabric.backward(output.sum())
+ assert isinstance(model.w1.weight.grad, DTensor)
+ assert model.w1.weight.grad.device_mesh == device_mesh["tensor_parallel"]
  optimizer.step()
  optimizer.zero_grad()

@@ -181,23 +186,28 @@ def test_fsdp2_tensor_parallel():
  with fabric.init_module(empty_init=True):
  model = FeedForward()

+ model = fabric.setup(model)
  optimizer = torch.optim.AdamW(model.parameters())
- model, optimizer = fabric.setup(model, optimizer)
+ optimizer = fabric.setup_optimizers(optimizer)

  assert all(isinstance(weight, DTensor) for weight in model.parameters())
+ assert all(isinstance(tensor, DTensor) for tensor in optimizer.param_groups[0]["params"])
  assert model.w1.weight.device_mesh.ndim == 2
  assert model.w1.weight.device_mesh.size(0) == 2
  assert model.w1.weight.device_mesh.size(1) == 2
  assert all(weight.device.type != "meta" for weight in model.parameters())
+ assert all(tensor.device_mesh.ndim == 2 for tensor in optimizer.param_groups[0]["params"])
+ assert all(tensor.device.type != "meta" for tensor in optimizer.param_groups[0]["params"])

  dataset_size = 8
  dataset = RandomDataset(32, dataset_size)
  dataloader = DataLoader(dataset, batch_size=2)
  dataloader = fabric.setup_dataloaders(dataloader)

  # No data sharding across TP dimension, sharding across data-parallel dimension only
- dp_mesh = fabric.strategy.device_mesh["data_parallel"]
- tp_mesh = fabric.strategy.device_mesh["tensor_parallel"]
+ device_mesh = fabric.strategy.device_mesh
+ dp_mesh = device_mesh["data_parallel"]
+ tp_mesh = device_mesh["tensor_parallel"]
  assert len(dataloader) == dataset_size // dataloader.batch_size // dp_mesh.size()
  assert isinstance(dataloader.sampler, DistributedSampler)

@@ -212,6 +222,8 @@ def test_fsdp2_tensor_parallel():

  output = model(batch)
  fabric.backward(output.sum())
+ assert isinstance(model.w1.weight.grad, DTensor)
+ assert model.w1.weight.grad.device_mesh == device_mesh
  optimizer.step()
  optimizer.zero_grad()

@@ -385,10 +397,7 @@ def _parallelize_single_linear_tp_fsdp2(model, device_mesh):
  "clip_type",
  [
  pytest.param("norm", marks=pytest.mark.skip("Gradient clipping by norm is not correct.")),
- pytest.param(
- "val",
- marks=pytest.mark.xfail(raises=RuntimeError, reason="Clipping DTensor by value raises error in PyTorch"),
- ),
+ "val",
  ],
  )
  def test_clip_gradients(clip_type, precision):
@@ -403,8 +412,9 @@ def test_clip_gradients(clip_type, precision):
  model = torch.nn.Linear(in_features, out_features, bias=False)
  model.weight.data.fill_(0.01)

+ model = fabric.setup(model)
  optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
- model, optimizer = fabric.setup(model, optimizer)
+ optimizer = fabric.setup_optimizers(optimizer)

  batch = torch.full((1, in_features), 0.1, device=fabric.device)
  loss = model(batch).sum()
@@ -424,7 +434,7 @@ def test_clip_gradients(clip_type, precision):
  val = model.weight.grad.full_tensor()[0, 0].item()
  new_val = val / 2.0
  fabric.clip_gradients(model, optimizer, clip_val=new_val)
- assert torch.allclose(model.weight.full_tensor().grad, torch.full_like(model.weight.grad, new_val))
+ assert torch.allclose(model.weight.grad.full_tensor(), torch.full_like(model.weight.grad.full_tensor(), new_val))
  else:
  raise AssertionError(f"Unknown clip type: {clip_type}")

Expand Down

0 comments on commit c917bee

Please sign in to comment.