
No clear way to load models #78

Open · stephenroller opened this issue May 10, 2022 · 6 comments
Labels: better-eng (Things that can help make things sane), bug (Something isn't working)

stephenroller commented May 10, 2022

🚀 Feature Request

Loading models is a bit of a pain right now. It's done differently across multiple scripts (including our internal eval scripts), and not every method is compatible with every checkpoint format.

Loading typically requires setting a TON of command line args based on what the model checkpoint needs (--model-parallel, --ddp-backend fully_sharded, --distributed-port, etc.). Many of these args could be picked up just by inspecting the checkpoint files.

Afterwards, we should refactor the various scripts to use this One True Method.
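
As a sketch of what that could look like (hypothetical names only, not an existing metaseq API):

```python
# Hypothetical "One True Method" sketch; it only illustrates the desired
# interface shape. Nothing here exists in metaseq today.

def load_model(checkpoint_path: str):
    """Load a model from any checkpoint layout without extra CLI args.

    Should handle singleton, model-parallel, and FSDP-sharded checkpoints
    alike, inferring settings such as model-parallel size and the FSDP
    backend from the files on disk instead of from flags.
    """
    raise NotImplementedError  # sketch only

# Desired call sites, identical regardless of checkpoint layout:
# model = load_model("/checkpoints/355m/reshard.pt")
# model = load_model("/checkpoints/175b/reshard-model_part-0-shard0.pt")
```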

stephenroller added the bug label May 10, 2022
patrickvonplaten commented

Any chance you could share the different eval scripts? :-)

Is this related to #73?

nickums commented May 12, 2022

I cannot find metaseq-api-local.py anywhere in OPT/

punitkoura commented

From #277:

We should make model loading "just work". I shouldn't need to pass so many args to get it to find the right checkpoint.
I should be able to specify sharded checkpoints by pointing to the shard0-rank0 .pt file.

punitkoura self-assigned this Aug 1, 2022
punitkoura commented Aug 4, 2022

Types of model checkpoints

We currently have three types of model checkpoints:
1. Singleton checkpoint - for example, the 355M checkpoint. The file name looks like reshard.pt.
2. Unsharded model parallel checkpoint - the file names look like reshard-model_part-*.pt, where * goes from 0 to number_of_model_parts - 1.
3. Sharded model parallel checkpoint - the file names look like reshard-model_part-0-shard0.pt, where the model part and shard numbers range over the number of model parallel parts and fully sharded data parallel shards respectively.

Here, the name "reshard" is just a convention; it can be any prefix, for example "125m-model_part-0-shard0.pt".
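
Given these conventions, the checkpoint type can plausibly be inferred from the filenames alone. A minimal sketch (hypothetical helper, not metaseq code; assumes exactly the naming patterns above):

```python
import glob
import os
import re

def detect_checkpoint_type(checkpoint_path: str) -> str:
    """Guess the checkpoint layout from sibling .pt filenames.

    Hypothetical helper: since the "reshard" prefix is just a convention,
    we key off the -model_part-*/-shard* suffixes instead.
    """
    directory = os.path.dirname(checkpoint_path) or "."
    names = [os.path.basename(p) for p in glob.glob(os.path.join(directory, "*.pt"))]
    if any(re.search(r"-model_part-\d+-shard\d+\.pt$", n) for n in names):
        return "sharded_model_parallel"    # e.g. reshard-model_part-0-shard0.pt
    if any(re.search(r"-model_part-\d+\.pt$", n) for n in names):
        return "unsharded_model_parallel"  # e.g. reshard-model_part-1.pt
    return "singleton"                     # e.g. reshard.pt
```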

punitkoura commented Aug 4, 2022

How do we determine the type of model checkpoint?

cfg.common.model_parallel_size - determines the model parallel size. If this is 1, we can infer that the model is not model parallel; however, it might still be sharded through FSDP.

cfg.checkpoint.checkpoint_shard_count - determines the number of FSDP shards for the model. For model parallel models, each model part has this many shards.

If both of these parameters are 1, we have a singleton model.

Both of these config values can be read from the model checkpoint itself.
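
Since both values are serialized inside the checkpoint, a loader could read them up front. A rough sketch (assumes a metaseq-style checkpoint dict containing a "cfg" entry whose fields match the names above):

```python
import torch

def infer_parallelism(checkpoint_path: str):
    """Read model-parallel size and FSDP shard count from a checkpoint.

    Sketch only: assumes torch.load returns a dict holding an
    OmegaConf-style "cfg" with the two fields quoted above.
    """
    state = torch.load(checkpoint_path, map_location="cpu")
    cfg = state["cfg"]
    mp_size = cfg.common.model_parallel_size
    shard_count = cfg.checkpoint.checkpoint_shard_count
    if mp_size == 1 and shard_count == 1:
        return "singleton", mp_size, shard_count
    if shard_count == 1:
        return "unsharded_model_parallel", mp_size, shard_count
    return "sharded_model_parallel", mp_size, shard_count
```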

punitkoura commented

cfg.distributed_training.use_sharded_state - if True, then state_dict will return FSDP.local_state_dict and load_state_dict will call FSDP.load_local_state_dict. Otherwise, state_dict will return the full model weights on data parallel rank 0 (empty on other ranks) and load_state_dict will broadcast model weights from rank 0 to other ranks.

From metaseq/distributed/fully_sharded_data_parallel.py
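
In code, the toggle amounts to the following branch (a minimal sketch following the docstring above; assumes a fairscale-style FSDP-wrapped model):

```python
def get_model_state(fsdp_model, use_sharded_state: bool):
    # Sketch of the behavior described in the docstring above.
    if use_sharded_state:
        # Each data parallel rank returns only its own shard of the
        # flattened parameters.
        return fsdp_model.local_state_dict()
    # Full weights are gathered on data parallel rank 0; the returned
    # dict is empty on every other rank.
    return fsdp_model.state_dict()

def set_model_state(fsdp_model, state, use_sharded_state: bool):
    if use_sharded_state:
        fsdp_model.load_local_state_dict(state)
    else:
        # Weights are broadcast from rank 0 to all other ranks.
        fsdp_model.load_state_dict(state)
```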
