Skip to content

Commit

Permalink
Fix torch.distributed._* import statements in tests (#11416)
Browse files · Browse the repository at this point in the history
Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
  • Loading branch information
akihironitta and carmocca committed Jan 12, 2022
1 parent b058025 commit ba71937
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions tests/core/test_lightning_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

import pytest
import torch
import torch.distributed as dist
from torch import nn
from torch.optim import Adam, SGD

Expand Down Expand Up @@ -304,16 +303,18 @@ def assert_device(device: torch.device) -> None:
assert_device(torch.device("cpu"))


class BoringModelWithShardedTensor(BoringModel):
def __init__(self, spec):
super().__init__()
self.sharded_tensor = dist._sharded_tensor.empty(spec, 10, 20)
self.sharded_tensor.local_shards()[0].tensor.fill_(0)


@RunIf(min_torch="1.10", skip_windows=True)
def test_sharded_tensor_state_dict(tmpdir, single_process_pg):
spec = dist._sharding_spec.ChunkShardingSpec(
from torch.distributed._sharded_tensor import empty as sharded_tensor_empty
from torch.distributed._sharding_spec import ChunkShardingSpec

class BoringModelWithShardedTensor(BoringModel):
def __init__(self, spec):
super().__init__()
self.sharded_tensor = sharded_tensor_empty(spec, 10, 20)
self.sharded_tensor.local_shards()[0].tensor.fill_(0)

spec = ChunkShardingSpec(
dim=0,
placements=[
"rank:0/cpu",
Expand Down

0 comments on commit ba71937

Please sign in to comment.