
Commit

[PT-D][Checkpoint] Update import and update docstring for distributed checkpoint (pytorch#89256)

Update test imports and docstrings, since distributed checkpointing has moved from torch.distributed._shard.checkpoint to torch.distributed.checkpoint (pytorch#88698).
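
For reference, a minimal sketch of the relocated API as a caller would use it after this change (the toy model, checkpoint path, and no_dist setting below are illustrative placeholders, not part of this PR):

import torch
import torch.distributed.checkpoint as dist_cp  # new location; was torch.distributed._shard.checkpoint

# Placeholder model and checkpoint directory, used only for illustration.
model = torch.nn.Linear(4, 4)
state_dict = model.state_dict()

# Save through the relocated module (no_dist=True keeps this a single-process example).
dist_cp.save_state_dict(
    state_dict=state_dict,
    storage_writer=dist_cp.FileSystemWriter("/tmp/checkpoint/1"),
    no_dist=True,
)

# Load the checkpoint back into the same state_dict with the matching reader.
dist_cp.load_state_dict(
    state_dict=state_dict,
    storage_reader=dist_cp.FileSystemReader("/tmp/checkpoint/1"),
    no_dist=True,
)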

Test: CI
Pull Request resolved: pytorch#89256
Approved by: https://github.com/fduwjj
wz337 authored and kulinseth committed Dec 9, 2022
1 parent f5f76e1 commit 03de71e
Showing 4 changed files with 68 additions and 67 deletions.
125 changes: 63 additions & 62 deletions test/distributed/checkpoint/test_checkpoint.py
@@ -2,9 +2,9 @@

import sys
from typing import Optional, List, cast
-from torch.distributed._shard.checkpoint.storage import WriteResult
+from torch.distributed.checkpoint.storage import WriteResult

-from torch.distributed._shard.checkpoint import (
+from torch.distributed.checkpoint import (
StorageReader,
StorageWriter,
CheckpointException,
@@ -63,6 +63,7 @@
)
sys.exit(0)


class TestModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
@@ -121,34 +122,44 @@ def test_default_metadata(self) -> None:
)

state_dict = {
-'sharded': sharded_tensor.rand(spec, (10, 10, )),
-'replicated': torch.rand(4, device=device),
-'bytes': [1, 2, 3, 4],
+"sharded": sharded_tensor.rand(
+spec,
+(
+10,
+10,
+),
+),
+"replicated": torch.rand(4, device=device),
+"bytes": [1, 2, 3, 4],
}

metadata = _create_default_local_metadata(state_dict)
-self.assertTrue('bytes' in metadata.state_dict_metadata)
-self.assertIsInstance(metadata.state_dict_metadata['bytes'], BytesStorageMetadata)
+self.assertTrue("bytes" in metadata.state_dict_metadata)
+self.assertIsInstance(
+metadata.state_dict_metadata["bytes"], BytesStorageMetadata
+)

-self.assertTrue('replicated' in metadata.state_dict_metadata)
-self.assertIsInstance(metadata.state_dict_metadata['replicated'], TensorStorageMetadata)
-md = metadata.state_dict_metadata['replicated']
-self.assertEqual(md.size, state_dict['replicated'].size())
+self.assertTrue("replicated" in metadata.state_dict_metadata)
+self.assertIsInstance(
+metadata.state_dict_metadata["replicated"], TensorStorageMetadata
+)
+md = metadata.state_dict_metadata["replicated"]
+self.assertEqual(md.size, state_dict["replicated"].size())
self.assertEqual(md.properties.dtype, torch.float32)
self.assertEqual(1, len(md.chunks))

-self.assertTrue('sharded' in metadata.state_dict_metadata)
-self.assertIsInstance(metadata.state_dict_metadata['sharded'], TensorStorageMetadata)
-md = metadata.state_dict_metadata['sharded']
+self.assertTrue("sharded" in metadata.state_dict_metadata)
+self.assertIsInstance(
+metadata.state_dict_metadata["sharded"], TensorStorageMetadata
+)
+md = metadata.state_dict_metadata["sharded"]
self.assertEqual(md.properties.dtype, torch.float32)
-self.assertEqual(md.size, state_dict['sharded'].size())
+self.assertEqual(md.size, state_dict["sharded"].size())
self.assertEqual(2, len(md.chunks))


class TestStorageBase:
-def __init__(
-self,
-fail_conf
-):
+def __init__(self, fail_conf):
self.fail_conf = fail_conf
self.rank = 0 if not dist.is_initialized() else dist.get_rank()

@@ -164,16 +175,16 @@ def _fail_rank_async(self, name, result=None):
ranks = self._get_ranks(name)
fut = Future()
if ranks is not None and self.rank in ranks:
-fut.set_exception(ValueError(f"async rank fail {self.rank} for {name}"))
+fut.set_exception(
+ValueError(f"async rank fail {self.rank} for {name}")
+)
else:
fut.set_result(result)
return fut


class FaultyStorageWriter(TestStorageBase, StorageWriter):
-def __init__(
-self,
-fail_conf
-):
+def __init__(self, fail_conf):
super(FaultyStorageWriter, self).__init__(fail_conf)

def init(self, is_coordinator: bool) -> None:
@@ -188,23 +199,19 @@ def prepare_global_plan(self, plans: List[SavePlan]) -> List[SavePlan]:
return plans

def write_data(
-self,
-plan: SavePlan,
-planner: SavePlanner
+self, plan: SavePlan, planner: SavePlanner
) -> Future[List[WriteResult]]:
self._fail_rank("fail_write_data")
return self._fail_rank_async("fail_write_data_async", [])

-def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
+def finish(
+self, metadata: Metadata, results: List[List[WriteResult]]
+) -> None:
self._fail_rank("fail_finish")


class FaultyStorageReader(TestStorageBase, StorageReader):
-def __init__(
-self,
-metadata,
-fail_conf
-):
+def __init__(self, metadata, fail_conf):
super(FaultyStorageReader, self).__init__(fail_conf)
self.metadata = metadata

@@ -219,35 +226,32 @@ def prepare_global_plan(self, plans: List[LoadPlan]) -> List[LoadPlan]:
self._fail_rank("fail_prepare_global_plan")
return plans

-def read_data(
-self,
-plan: LoadPlan,
-planner: LoadPlanner
-) -> Future[None]:
+def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
self._fail_rank("fail_read_data")
return self._fail_rank_async("fail_read_data_async")

def read_metadata(self) -> Metadata:
self._fail_rank("fail_read_metadata")
return self.metadata


class TestDistributedFailure(ShardedTensorTestBase):
def get_spec(self):
return ChunkShardingSpec(
dim=0,
placements=[
f"rank:{r}/cuda:{r}" for r in range(dist.get_world_size())
-]
+],
)

@with_comms(init_rpc=False)
@skip_if_lt_x_gpu(2)
@requires_nccl()
def test_dummy_writer_works(self) -> None:
state_dict = {
-'sharded': sharded_tensor.rand(self.get_spec(), 20, 20),
-'replicated': torch.rand(10, 10),
-'bytes': [1, 2, 3, 4]
+"sharded": sharded_tensor.rand(self.get_spec(), 20, 20),
+"replicated": torch.rand(10, 10),
+"bytes": [1, 2, 3, 4],
}

save_state_dict(state_dict, FaultyStorageWriter({}))
@@ -257,9 +261,9 @@ def test_dummy_writer_works(self) -> None:
@requires_nccl()
def test_dummy_reader_works(self) -> None:
state_dict = {
-'sharded': sharded_tensor.rand(self.get_spec(), 20, 20),
-'replicated': torch.rand(10, 10),
-'bytes': [1, 2, 3, 4]
+"sharded": sharded_tensor.rand(self.get_spec(), 20, 20),
+"replicated": torch.rand(10, 10),
+"bytes": [1, 2, 3, 4],
}
metadata = _create_default_local_metadata(state_dict)

@@ -283,8 +287,10 @@ def _test_dist_failure(self, callback, kwargs):

failed_ranks = e.failures.keys()
for rank in bad_ranks:
-self.assertTrue(rank in failed_ranks, msg=f"{rank} was supposed to fail was fine")

+self.assertTrue(
+rank in failed_ranks,
+msg=f"{rank} was supposed to fail was fine",
+)

def _test_save(self, state_dict, coordinator=0, **kwargs):
no_dist = not dist.is_initialized()
@@ -296,6 +302,7 @@ def _save():
coordinator_rank=coordinator,
no_dist=no_dist,
)

self._test_dist_failure(_save, kwargs)

def _test_load(self, state_dict, coordinator=0, **kwargs):
@@ -317,9 +324,9 @@ def _load():
@requires_nccl()
def test_save_error_handling(self) -> None:
state_dict = {
-'sharded': sharded_tensor.rand(self.get_spec(), 20, 20),
-'replicated': torch.rand(10, 10),
-'bytes': [1, 2, 3, 4]
+"sharded": sharded_tensor.rand(self.get_spec(), 20, 20),
+"replicated": torch.rand(10, 10),
+"bytes": [1, 2, 3, 4],
}

self._test_save(state_dict, fail_init=[0])
@@ -334,10 +341,7 @@ def test_save_error_handling(self) -> None:
self._test_save(state_dict, coordinator=1, fail_finish=[1])

def test_save_error_handling_no_dist(self) -> None:
-state_dict = {
-'replicated': torch.rand(10, 10),
-'bytes': [1, 2, 3, 4]
-}
+state_dict = {"replicated": torch.rand(10, 10), "bytes": [1, 2, 3, 4]}

self.assertFalse(dist.is_initialized())

@@ -354,9 +358,9 @@ def test_save_error_handling_no_dist(self) -> None:
@requires_nccl()
def test_load_error_handling(self) -> None:
state_dict = {
-'sharded': sharded_tensor.rand(self.get_spec(), 20, 20),
-'replicated': torch.rand(10, 10),
-'bytes': [1, 2, 3, 4]
+"sharded": sharded_tensor.rand(self.get_spec(), 20, 20),
+"replicated": torch.rand(10, 10),
+"bytes": [1, 2, 3, 4],
}

self._test_load(state_dict)
@@ -373,12 +377,8 @@ def test_load_error_handling(self) -> None:
self._test_load(state_dict, coordinator=3, fail_read_data_async=[2])
self._test_load(state_dict, coordinator=1, fail_prepare_global_plan=[1])


def test_load_error_handling_no_dist(self) -> None:
-state_dict = {
-'replicated': torch.rand(10, 10),
-'bytes': [1, 2, 3, 4]
-}
+state_dict = {"replicated": torch.rand(10, 10), "bytes": [1, 2, 3, 4]}
self._test_load(state_dict)
self._test_load(state_dict, fail_init=[0])
self._test_load(state_dict, fail_read_metadata=[0])
@@ -387,5 +387,6 @@ def test_load_error_handling_no_dist(self) -> None:
self._test_load(state_dict, fail_read_data=[0])
self._test_load(state_dict, fail_read_data_async=[0])


if __name__ == "__main__":
run_tests()
2 changes: 1 addition & 1 deletion test/distributed/fsdp/test_distributed_checkpoint.py
@@ -5,7 +5,7 @@

import torch
from torch import distributed as dist
-from torch.distributed._shard.checkpoint import (
+from torch.distributed.checkpoint import (
FileSystemReader,
FileSystemWriter,
load_state_dict,
4 changes: 2 additions & 2 deletions torch/distributed/checkpoint/state_dict_loader.py
@@ -59,9 +59,9 @@ def load_state_dict(
>>> my_model = MyModule()
>>> optimizer = Adagrad(my_model.parameters())
>>> model_state_dict = my_model.state_dict()
->>> fs_storage_loader = torch.distributed._shard.checkpoint.FileSystemLoader("/checkpoint/1")
+>>> fs_storage_loader = torch.distributed.checkpoint.FileSystemLoader("/checkpoint/1")
->>> torch.distributed._shard.checkpoint.load_state_dict(
+>>> torch.distributed.checkpoint.load_state_dict(
>>> state_dict=model_state_dict,
>>> storage_reader=fs_storage_loader,
>>> )
4 changes: 2 additions & 2 deletions torch/distributed/checkpoint/state_dict_saver.py
@@ -59,8 +59,8 @@ def save_state_dict(
>>> model_state_dict = my_model.state_dict()
->>> fs_storage_writer = torch.distributed._shard.checkpoint.FileSystemWriter("/checkpoint/1")
->>> torch.distributed._shard.checkpoint.save_state_dict(
+>>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter("/checkpoint/1")
+>>> torch.distributed.checkpoint.save_state_dict(
>>> state_dict=model_state_dict,
>>> storage_writer=fs_stroage_writer,
>>> )
