Commit

WIP
carmocca committed Feb 19, 2024
1 parent 1b1f03e commit 97469de
Showing 1 changed file with 23 additions and 23 deletions.
46 changes: 23 additions & 23 deletions src/lightning/fabric/strategies/fsdp.py
@@ -915,6 +915,24 @@ def _save_state_dict(
    return converted_state, metadata


+def _distributed_checkpoint_load(module_state: Dict[str, Any], path: Path) -> None:
+    if _TORCH_GREATER_EQUAL_2_3:
+        from torch.distributed.checkpoint import load
+
+        # let torch automatically infer the reader to use. This might also support fsspec paths in the future
+        # https://github.com/pytorch/pytorch/issues/118036
+        load(module_state, checkpoint_id=path)  # type: ignore[call-arg]
+    else:  # deprecated
+        from torch.distributed.checkpoint import FileSystemReader
+
+        if _TORCH_GREATER_EQUAL_2_2:
+            from torch.distributed.checkpoint import load
+        else:
+            from torch.distributed.checkpoint import load_state_dict as load
+        reader = FileSystemReader(path=path)
+        load(module_state, reader)


def _load_state_dict(
    module: Module,
    module_key: str,
@@ -931,14 +949,14 @@ def _load_state_dict(
            set_optimizer_state_dict,
        )

-        options = StateDictOptions(full_state_dict=state_dict_type == "full", cpu_offload=False)
+        options = StateDictOptions(full_state_dict=state_dict_type == "full", cpu_offload=False, strict=strict)
        module_state = {module_key: module.state_dict()}
        _distributed_checkpoint_load(module_state, path)
-        set_model_state_dict(module, module_state, options=options)  # type: ignore[arg-type]
-        for key, optimizer in optimizers.items():
-            optimizer_state = {key: optimizer.state_dict()}
+        set_model_state_dict(module, module_state[module_key], options=options)  # type: ignore[arg-type]
+        for optim_key, optim in optimizers.items():
+            optimizer_state = {optim_key: optim.state_dict()}
            _distributed_checkpoint_load(optimizer_state, path)
-            set_optimizer_state_dict(module, optimizer, optim_state_dict=optimizer_state, options=options)
+            set_optimizer_state_dict(module, optim, optim_state_dict=optimizer_state[optim_key], options=options)
    else:
        if state_dict_type == "sharded":
            state_dict_ctx = _get_sharded_state_dict_context(module)
@@ -998,21 +1016,3 @@ def _load_state_dict(
                optim.load_state_dict(optim_state_dict)

    return checkpoint


-def _distributed_checkpoint_load(module_state: Dict[str, Any], path: Path) -> None:
-    if _TORCH_GREATER_EQUAL_2_3:
-        from torch.distributed.checkpoint import load
-
-        # let torch automatically infer the reader to use. This might also support fsspec paths in the future
-        # https://github.com/pytorch/pytorch/issues/118036
-        load(module_state, checkpoint_id=path)  # type: ignore[call-arg]
-    else:  # deprecated
-        from torch.distributed.checkpoint import FileSystemReader
-
-        if _TORCH_GREATER_EQUAL_2_2:
-            from torch.distributed.checkpoint import load
-        else:
-            from torch.distributed.checkpoint import load_state_dict as load
-        reader = FileSystemReader(path=path)
-        load(module_state, reader)
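
For reference, here is a minimal, self-contained sketch (not part of this commit) of the torch.distributed.checkpoint API that _distributed_checkpoint_load wraps. It assumes torch >= 2.3, so that save/load accept checkpoint_id and pick a filesystem writer/reader automatically; the "my_module" key, the tiny Linear module, and the "checkpoint_dir" path are illustrative only.

from pathlib import Path

import torch
from torch.distributed.checkpoint import load, save

path = Path("checkpoint_dir")  # illustrative target directory for the checkpoint files

module = torch.nn.Linear(8, 8)
# Without an initialized process group, recent torch versions fall back to a single-process save/load.
save({"my_module": module.state_dict()}, checkpoint_id=path)

# Loading is in-place: load() fills the tensors of the state dict passed to it,
# which is why the diff seeds module_state with module.state_dict() before calling the helper.
module_state = {"my_module": module.state_dict()}
load(module_state, checkpoint_id=path)
module.load_state_dict(module_state["my_module"])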

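A second sketch (also not from the commit) of the torch.distributed.checkpoint.state_dict setters used in the new _load_state_dict branch. It assumes torch >= 2.2; the SGD optimizer, the dummy training step, and the option values are illustrative, and with a plain (non-FSDP) module these calls reduce to an ordinary state-dict round trip.

import torch
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_model_state_dict,
    get_optimizer_state_dict,
    set_model_state_dict,
    set_optimizer_state_dict,
)

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# One dummy step so the optimizer has state worth capturing.
model(torch.randn(2, 8)).sum().backward()
optimizer.step()

options = StateDictOptions(full_state_dict=True, cpu_offload=False, strict=True)

# The get_* functions return FQN-keyed state dicts in the format the set_* functions expect.
model_state = get_model_state_dict(model, options=options)
optim_state = get_optimizer_state_dict(model, optimizer, options=options)

# Restore them later; with FSDP-wrapped modules the same calls also handle sharded parameters.
set_model_state_dict(model, model_state, options=options)
set_optimizer_state_dict(model, optimizer, optim_state_dict=optim_state, options=options)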