Commit

WIP
carmocca committed Feb 19, 2024
1 parent 1b1f03e commit fb62091
Showing 1 changed file with 30 additions and 36 deletions.
src/lightning/fabric/strategies/fsdp.py: 66 changes (30 additions & 36 deletions)
@@ -468,19 +468,16 @@ def save_checkpoint(
path.unlink()
path.mkdir(parents=True, exist_ok=True)

converted_state, metadata = _save_state_dict(
state, module, path, filter, self._state_dict_type, self.world_size
)
converted_state, metadata = _get_state_dict(state, module, filter, self._state_dict_type, self.world_size)
_distributed_checkpoint_save(converted_state, path)
if self.global_rank == 0:
torch.save(metadata, path / _METADATA_FILENAME)

elif self._state_dict_type == "full":
if _is_sharded_checkpoint(path):
shutil.rmtree(path)

converted_state, metadata = _save_state_dict(
state, module, path, filter, self._state_dict_type, self.world_size
)
converted_state, metadata = _get_state_dict(state, module, filter, self._state_dict_type, self.world_size)
converted_state.update(metadata)
if self.global_rank == 0:
torch.save(converted_state, path)
@@ -540,7 +537,7 @@ def load_checkpoint(
module_key, module = list(modules.items())[0]

if _is_sharded_checkpoint(path):
_load_state_dict(module, module_key, optimizers, path, "sharded", strict, self.world_size)
_set_state_dict(module, module_key, optimizers, path, "sharded", strict, self.world_size)

# Load metadata (anything not a module or optimizer)
metadata = torch.load(path / _METADATA_FILENAME)
@@ -554,7 +551,7 @@ def load_checkpoint(
return metadata

if _is_full_checkpoint(path):
checkpoint = _load_state_dict(module, module_key, optimizers, path, "full", strict, self.world_size)
checkpoint = _set_state_dict(module, module_key, optimizers, path, "full", strict, self.world_size)
assert checkpoint is not None

requested_metadata_keys = state.keys() - modules.keys() - optimizers.keys()
@@ -862,10 +859,9 @@ def _distributed_checkpoint_save(state_dict: Dict[str, Any], path: Path) -> None
save(state_dict, writer)


def _save_state_dict(
def _get_state_dict(
state: Dict[str, Any],
module: Module,
path: Path,
filter: Optional[Dict[str, Callable[[str, Any], bool]]],
state_dict_type: Literal["sharded", "full"],
world_size: int,
@@ -910,12 +906,28 @@ def _save_state_dict(
target_dict = metadata
_apply_filter(key, filter or {}, converted, target_dict)

_distributed_checkpoint_save(converted_state, path)

return converted_state, metadata


def _load_state_dict(
def _distributed_checkpoint_load(module_state: Dict[str, Any], path: Path) -> None:
if _TORCH_GREATER_EQUAL_2_3:
from torch.distributed.checkpoint import load

# let torch automatically infer the reader to use. This might also support fsspec paths in the future
# https://github.com/pytorch/pytorch/issues/118036
load(module_state, checkpoint_id=path) # type: ignore[call-arg]
else: # deprecated
from torch.distributed.checkpoint import FileSystemReader

if _TORCH_GREATER_EQUAL_2_2:
from torch.distributed.checkpoint import load
else:
from torch.distributed.checkpoint import load_state_dict as load
reader = FileSystemReader(path=path)
load(module_state, reader)


def _set_state_dict(
module: Module,
module_key: str,
optimizers: Dict[str, torch.optim.Optimizer],
@@ -931,14 +943,14 @@ def _load_state_dict(
set_optimizer_state_dict,
)

options = StateDictOptions(full_state_dict=state_dict_type == "full", cpu_offload=False)
options = StateDictOptions(full_state_dict=state_dict_type == "full", cpu_offload=False, strict=strict)
module_state = {module_key: module.state_dict()}
_distributed_checkpoint_load(module_state, path)
set_model_state_dict(module, module_state, options=options) # type: ignore[arg-type]
for key, optimizer in optimizers.items():
optimizer_state = {key: optimizer.state_dict()}
set_model_state_dict(module, module_state[module_key], options=options) # type: ignore[arg-type]
for optim_key, optim in optimizers.items():
optimizer_state = {optim_key: optim.state_dict()}
_distributed_checkpoint_load(optimizer_state, path)
set_optimizer_state_dict(module, optimizer, optim_state_dict=optimizer_state, options=options)
set_optimizer_state_dict(module, optim, optim_state_dict=optimizer_state[optim_key], options=options)
else:
if state_dict_type == "sharded":
state_dict_ctx = _get_sharded_state_dict_context(module)
@@ -998,21 +1010,3 @@ def _load_state_dict(
optim.load_state_dict(optim_state_dict)

return checkpoint


def _distributed_checkpoint_load(module_state: Dict[str, Any], path: Path) -> None:
if _TORCH_GREATER_EQUAL_2_3:
from torch.distributed.checkpoint import load

# let torch automatically infer the reader to use. This might also support fsspec paths in the future
# https://github.com/pytorch/pytorch/issues/118036
load(module_state, checkpoint_id=path) # type: ignore[call-arg]
else: # deprecated
from torch.distributed.checkpoint import FileSystemReader

if _TORCH_GREATER_EQUAL_2_2:
from torch.distributed.checkpoint import load
else:
from torch.distributed.checkpoint import load_state_dict as load
reader = FileSystemReader(path=path)
load(module_state, reader)
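
For reference, a minimal sketch (not Lightning's implementation) of the torch.distributed.checkpoint.state_dict setter API that the new _set_state_dict path builds on for torch >= 2.3. The restore_sharded helper, the model/optimizer arguments, and the "model"/"optimizer" keys are placeholders; in the diff the keys come from the checkpointed state, and strict is forwarded through StateDictOptions rather than to load:

import torch
from torch.distributed.checkpoint import load
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    set_model_state_dict,
    set_optimizer_state_dict,
)


def restore_sharded(model: torch.nn.Module, optimizer: torch.optim.Optimizer, path: str) -> None:
    # Placeholder keys; they must match the keys used when the checkpoint was written.
    options = StateDictOptions(full_state_dict=False, cpu_offload=False, strict=True)
    # Fill a state-dict "skeleton" in place from the sharded checkpoint, then apply it to the module.
    module_state = {"model": model.state_dict()}
    load(module_state, checkpoint_id=path)
    set_model_state_dict(model, module_state["model"], options=options)
    # Same pattern for the optimizer state.
    optim_state = {"optimizer": optimizer.state_dict()}
    load(optim_state, checkpoint_id=path)
    set_optimizer_state_dict(model, optimizer, optim_state_dict=optim_state["optimizer"], options=options)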

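Similarly, a rough round-trip sketch of the torch.distributed.checkpoint save/load entry points that _distributed_checkpoint_save and _distributed_checkpoint_load wrap, assuming torch >= 2.3 so checkpoint_id can be passed (older versions go through FileSystemWriter/FileSystemReader, as the deprecated branch above shows). The roundtrip helper and the "model" key are illustrative only:

from pathlib import Path

import torch
from torch.distributed.checkpoint import load, save


def roundtrip(model: torch.nn.Module, path: Path) -> None:
    # Requires an initialized process group in practice (e.g. launched via torchrun).
    # The storage writer/reader are inferred from the filesystem path.
    save({"model": model.state_dict()}, checkpoint_id=path)
    state = {"model": model.state_dict()}
    load(state, checkpoint_id=path)  # fills `state` in place
    model.load_state_dict(state["model"])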