
Commit

WIP
carmocca committed Feb 19, 2024
1 parent 1b1f03e commit bf021fe
Showing 2 changed files with 32 additions and 37 deletions.
67 changes: 31 additions & 36 deletions src/lightning/fabric/strategies/fsdp.py
@@ -439,6 +439,7 @@ def save_checkpoint(
             )
         if filter is not None and self._state_dict_type == "sharded":
             # https://github.com/pytorch/pytorch/issues/105379
+            # FIXME: revisit support with new APIs
             raise NotImplementedError(
                 "FSDP doesn't support loading sharded filtered checkpoints, so saving them is disabled."
             )
@@ -468,19 +469,16 @@ def save_checkpoint(
                 path.unlink()
             path.mkdir(parents=True, exist_ok=True)

-            converted_state, metadata = _save_state_dict(
-                state, module, path, filter, self._state_dict_type, self.world_size
-            )
+            converted_state, metadata = _get_state_dict(state, module, filter, self._state_dict_type, self.world_size)
+            _distributed_checkpoint_save(converted_state, path)
             if self.global_rank == 0:
                 torch.save(metadata, path / _METADATA_FILENAME)

         elif self._state_dict_type == "full":
             if _is_sharded_checkpoint(path):
                 shutil.rmtree(path)

-            converted_state, metadata = _save_state_dict(
-                state, module, path, filter, self._state_dict_type, self.world_size
-            )
+            converted_state, metadata = _get_state_dict(state, module, filter, self._state_dict_type, self.world_size)
             converted_state.update(metadata)
             if self.global_rank == 0:
                 torch.save(converted_state, path)
@@ -540,7 +538,7 @@ def load_checkpoint(
         module_key, module = list(modules.items())[0]

         if _is_sharded_checkpoint(path):
-            _load_state_dict(module, module_key, optimizers, path, "sharded", strict, self.world_size)
+            _set_state_dict(module, module_key, optimizers, path, "sharded", strict, self.world_size)

             # Load metadata (anything not a module or optimizer)
             metadata = torch.load(path / _METADATA_FILENAME)
@@ -554,7 +552,7 @@ def load_checkpoint(
             return metadata

         if _is_full_checkpoint(path):
-            checkpoint = _load_state_dict(module, module_key, optimizers, path, "full", strict, self.world_size)
+            checkpoint = _set_state_dict(module, module_key, optimizers, path, "full", strict, self.world_size)
             assert checkpoint is not None

             requested_metadata_keys = state.keys() - modules.keys() - optimizers.keys()
@@ -862,10 +860,9 @@ def _distributed_checkpoint_save(state_dict: Dict[str, Any], path: Path) -> None
         save(state_dict, writer)


-def _save_state_dict(
+def _get_state_dict(
     state: Dict[str, Any],
     module: Module,
-    path: Path,
     filter: Optional[Dict[str, Callable[[str, Any], bool]]],
     state_dict_type: Literal["sharded", "full"],
     world_size: int,
@@ -910,12 +907,28 @@ def _save_state_dict(
             target_dict = metadata
         _apply_filter(key, filter or {}, converted, target_dict)

-    _distributed_checkpoint_save(converted_state, path)
-
     return converted_state, metadata


-def _load_state_dict(
+def _distributed_checkpoint_load(module_state: Dict[str, Any], path: Path) -> None:
+    if _TORCH_GREATER_EQUAL_2_3:
+        from torch.distributed.checkpoint import load
+
+        # let torch automatically infer the reader to use. This might also support fsspec paths in the future
+        # https://github.com/pytorch/pytorch/issues/118036
+        load(module_state, checkpoint_id=path)  # type: ignore[call-arg]
+    else:  # deprecated
+        from torch.distributed.checkpoint import FileSystemReader
+
+        if _TORCH_GREATER_EQUAL_2_2:
+            from torch.distributed.checkpoint import load
+        else:
+            from torch.distributed.checkpoint import load_state_dict as load
+        reader = FileSystemReader(path=path)
+        load(module_state, reader)
+
+
+def _set_state_dict(
     module: Module,
     module_key: str,
     optimizers: Dict[str, torch.optim.Optimizer],
@@ -931,14 +944,14 @@ def _load_state_dict(
             set_optimizer_state_dict,
         )

-        options = StateDictOptions(full_state_dict=state_dict_type == "full", cpu_offload=False)
+        options = StateDictOptions(full_state_dict=state_dict_type == "full", cpu_offload=False, strict=strict)
         module_state = {module_key: module.state_dict()}
         _distributed_checkpoint_load(module_state, path)
-        set_model_state_dict(module, module_state, options=options)  # type: ignore[arg-type]
-        for key, optimizer in optimizers.items():
-            optimizer_state = {key: optimizer.state_dict()}
+        set_model_state_dict(module, module_state[module_key], options=options)  # type: ignore[arg-type]
+        for optim_key, optim in optimizers.items():
+            optimizer_state = {optim_key: optim.state_dict()}
             _distributed_checkpoint_load(optimizer_state, path)
-            set_optimizer_state_dict(module, optimizer, optim_state_dict=optimizer_state, options=options)
+            set_optimizer_state_dict(module, optim, optim_state_dict=optimizer_state[optim_key], options=options)
     else:
         if state_dict_type == "sharded":
             state_dict_ctx = _get_sharded_state_dict_context(module)
@@ -998,21 +1011,3 @@ def _load_state_dict(
             optim.load_state_dict(optim_state_dict)

     return checkpoint
-
-
-def _distributed_checkpoint_load(module_state: Dict[str, Any], path: Path) -> None:
-    if _TORCH_GREATER_EQUAL_2_3:
-        from torch.distributed.checkpoint import load
-
-        # let torch automatically infer the reader to use. This might also support fsspec paths in the future
-        # https://github.com/pytorch/pytorch/issues/118036
-        load(module_state, checkpoint_id=path)  # type: ignore[call-arg]
-    else:  # deprecated
-        from torch.distributed.checkpoint import FileSystemReader
-
-        if _TORCH_GREATER_EQUAL_2_2:
-            from torch.distributed.checkpoint import load
-        else:
-            from torch.distributed.checkpoint import load_state_dict as load
-        reader = FileSystemReader(path=path)
-        load(module_state, reader)
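
For orientation, here is a condensed sketch, not part of the commit, of how the refactored save path composes after this change: _get_state_dict now only assembles the converted state and metadata, and the caller decides how to persist them, using _distributed_checkpoint_save plus a rank-0 metadata file for sharded checkpoints, or a single rank-0 torch.save for full ones. The helper names and branch logic come from the diff above; the wrapper function name, the strategy stand-in, and the direct import of private helpers are assumptions made for illustration.

import torch

# Assumption: importing these private helpers directly only works on this branch,
# where they are defined in src/lightning/fabric/strategies/fsdp.py.
from lightning.fabric.strategies.fsdp import (
    _METADATA_FILENAME,
    _distributed_checkpoint_save,
    _get_state_dict,
)


def save_checkpoint_sketch(strategy, path, state, module, filter=None):
    # Build the (optionally filtered) converted state and metadata without touching disk.
    converted_state, metadata = _get_state_dict(state, module, filter, strategy._state_dict_type, strategy.world_size)
    if strategy._state_dict_type == "sharded":
        # All ranks write their shards collectively; only rank 0 writes the metadata file.
        _distributed_checkpoint_save(converted_state, path)
        if strategy.global_rank == 0:
            torch.save(metadata, path / _METADATA_FILENAME)
    else:  # "full": fold the metadata into the single file written by rank 0
        converted_state.update(metadata)
        if strategy.global_rank == 0:
            torch.save(converted_state, path)

Loading mirrors the same split: _distributed_checkpoint_load reads the serialized tensors back into a pre-populated state dict, and _set_state_dict applies it to the module and optimizers through set_model_state_dict and set_optimizer_state_dict.
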
2 changes: 1 addition & 1 deletion tests/tests_fabric/strategies/test_fsdp.py
@@ -244,7 +244,7 @@ def test_fsdp_save_checkpoint_storage_options(tmp_path):
 @mock.patch("lightning.fabric.strategies.fsdp.FSDPStrategy.broadcast", lambda _, x: x)
 @mock.patch("lightning.fabric.strategies.fsdp._get_full_state_dict_context")
 @mock.patch("lightning.fabric.strategies.fsdp._get_sharded_state_dict_context")
-@mock.patch("lightning.fabric.strategies.fsdp._save_state_dict", return_value=({}, {}))
+@mock.patch("lightning.fabric.strategies.fsdp._get_state_dict", return_value=({}, {}))
 @mock.patch("lightning.fabric.strategies.fsdp.torch.save")
 @mock.patch("lightning.fabric.strategies.fsdp.shutil")
 def test_fsdp_save_checkpoint_path_exists(shutil_mock, torch_save_mock, __, ___, ____, tmp_path):
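
Because the helper was renamed, anything that patches the old dotted path fails when the patch is applied, which is why the mock target above moves from _save_state_dict to _get_state_dict. The snippet below is a hypothetical illustration of that behavior; it assumes this branch is installed, and only the two patch targets are taken from the diff.

from unittest import mock

import lightning.fabric.strategies.fsdp as fsdp

# Patching the removed name raises at start() because the module no longer has that attribute.
try:
    mock.patch("lightning.fabric.strategies.fsdp._save_state_dict").start()
except AttributeError:
    pass  # expected after the rename

# Patching the renamed helper works and short-circuits it with an empty (state, metadata) pair.
patcher = mock.patch("lightning.fabric.strategies.fsdp._get_state_dict", return_value=({}, {}))
patcher.start()
assert fsdp._get_state_dict({}, None, None, "full", 1) == ({}, {})
patcher.stop()
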
