Experimental support for fairscale ShardedDDP #9139
Conversation
wrt your notes on GPU memory consumption improvements: in my experience, checking GPU allocation often doesn't show the real difference, since pytorch tends to use more memory than it strictly needs when spare memory is available, and can get by with less when memory is tight. To get more precise improvement stats, it's best to push the batch size up until it OOMs and compare the maximum batch sizes instead, which usually gives more precise numbers than comparing memory allocation. All I'm saying is that the improvements are probably even better than they seem.
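The probing strategy described above can be sketched as a doubling-then-bisecting search. The `train_step` callable and the `MemoryError` stand-in below are hypothetical; in a real run the step would do an actual forward/backward pass and you would catch a CUDA out-of-memory error instead.

```python
# Sketch of the "push the batch size until it OOMs" probing idea.
# `train_step` is a hypothetical callable that raises MemoryError
# (standing in for a CUDA OOM) when the batch does not fit.

def max_batch_size(train_step, start=1, limit=65536):
    """Find the largest batch size that does not OOM."""
    # Phase 1: double until we hit the first failure.
    bs = start
    while bs <= limit:
        try:
            train_step(bs)
        except MemoryError:
            break
        bs *= 2
    else:
        return limit  # never failed within the search limit
    # Phase 2: bisect between the last success and the failure.
    lo, hi = bs // 2, bs  # lo fit, hi did not
    while hi - lo > 1:
        mid = (lo + hi) // 2
        try:
            train_step(mid)
            lo = mid
        except MemoryError:
            hi = mid
    return lo

# Toy usage: pretend anything above 48 samples OOMs.
def fake_step(bs):
    if bs > 48:
        raise MemoryError("CUDA out of memory (simulated)")

print(max_batch_size(fake_step))  # → 48
```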
That's cool, very clean!
finetune_trainer crashes with this option:
could probably extend
Oh it's just because it overrides the
OK, next we have this:
Coincidentally, I just had the same issue with the deepspeed integration when I enable its internal fp16 handling. I haven't gotten to the root of it yet. Note: I'm switching to deepspeed's fp16 handling there...
Is it FP16 with AMP or with apex? I don't believe fairscale is compatible with apex.
Native AMP. See the command line I'm testing with at:
@@ -215,6 +215,9 @@ class TrainingArguments:
        The backend to use for mixed precision training. Must be one of :obj:`"auto"`, :obj:`"amp"` or
        :obj:`"apex"`. :obj:`"auto"` will use AMP or APEX depending on the PyTorch version detected, while the
        other choices will force the requested backend.
    sharded_ddp (:obj:`bool`, `optional`, defaults to :obj:`False`):
        Use Sharded DDP training from `FairScale <https://github.com/facebookresearch/fairscale>`__ (in distributed
        training only). This is an experimental feature.
Nice!
@@ -35,6 +36,10 @@
from transformers.training_args import ParallelMode

if is_fairscale_available():
    from fairscale.optim import OSS
OSS is a bit cryptic to me, but I think it's still better to use the "real" name instead of import OSS as OptimizerStateSharding, so good for me!
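For readers to whom the name is equally cryptic: OSS stands for Optimizer State Sharding, i.e. each rank keeps optimizer state (momentum buffers, etc.) only for its own slice of the parameters. A toy, fairscale-free sketch of the partitioning idea — the greedy size-balancing scheme below is purely illustrative, not fairscale's actual implementation:

```python
# Toy illustration of optimizer state sharding: parameters are
# partitioned across ranks, and each rank allocates optimizer state
# only for its own partition. Greedy balancing by parameter size;
# this is NOT fairscale's actual algorithm, just the core idea.

def shard_params(param_sizes, world_size):
    """Assign each parameter (by index) to the currently lightest rank."""
    shards = [[] for _ in range(world_size)]
    loads = [0] * world_size
    # Placing the largest tensors first keeps shards roughly balanced.
    for idx in sorted(range(len(param_sizes)),
                      key=lambda i: param_sizes[i], reverse=True):
        rank = loads.index(min(loads))
        shards[rank].append(idx)
        loads[rank] += param_sizes[idx]
    return shards

sizes = [300, 120, 80, 60, 40]  # e.g. numel of each parameter tensor
print(shard_params(sizes, 2))   # → [[0], [1, 2, 3, 4]]
```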
Yeah, I'm using the same convention they do, so as not to surprise any user.
Clean!
Hey there, a bit late, but I'm one of the fairscale/ShardedDDP authors. The issue with the Apex (and vanilla Torch) grad scaler is that it does not know about the gradient sharding, so not all the ranks will have the same behaviour. Torch AMP is supported though; you just have to pass in the ShardedGradScaler as defined here: https://github.com/facebookresearch/fairscale/blob/master/fairscale/optim/grad_scaler.py#L24
Yes, we're passing that scaler :-) The issue was with AMP, not Apex. It looks like there is a problem, with or without FP16, with one of the models.
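The fix described above amounts to swapping the scaler class when sharded DDP is active. `ShardedGradScaler` and torch's `GradScaler` are real class names from fairscale and PyTorch respectively, but this small selection helper is a hypothetical sketch, with import fallbacks so it degrades gracefully when either library is absent:

```python
def pick_grad_scaler(sharded_ddp: bool):
    """Return the grad-scaler class to use for native AMP, or None.

    With ShardedDDP each rank only holds a shard of the gradients, so
    the vanilla torch scaler would make its inf/nan skip decision from
    partial information; fairscale's ShardedGradScaler coordinates
    that decision across ranks.
    """
    if sharded_ddp:
        try:
            from fairscale.optim.grad_scaler import ShardedGradScaler
            return ShardedGradScaler
        except ImportError:  # fairscale not installed
            return None
    try:
        from torch.cuda.amp import GradScaler
        return GradScaler
    except ImportError:  # torch not installed
        return None
```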
* Experimental support for fairscale ShardedDDP
* Add import error if fairscale not available
* Address review comments
* Fix seq2seq trainer
What does this PR do?
This PR adds support for FairScale's sharded DDP training to save GPU memory when training distributed models. Initial tests see a nice reduction of GPU memory used indeed!

This follows the steps of the main example provided on the FairScale repo, integrating them in our Trainer API. To activate training with sharded DDP, one must pass along the flag --sharded_ddp in a distributed launch command.

Benchmarks tried:
* bert_base_uncased -> goes from 5GB per GPU to 4GB per GPU with no hurt on accuracy
* xlnet_large-cased -> goes from 11.5GB per GPU to 8GB per GPU (didn't go until the end so didn't check if the accuracy was the same. Training loss seemed equivalent.)
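For reference, activating the flag in a distributed launch might look like the following. The script name, GPU count, and other arguments are illustrative (the thread mentions testing with finetune_trainer); only --sharded_ddp is the part this PR adds:

```shell
# Illustrative distributed launch; adjust nproc_per_node, the script,
# and its arguments to your setup. Only --sharded_ddp is new here.
python -m torch.distributed.launch --nproc_per_node 2 \
    examples/seq2seq/finetune_trainer.py \
    --fp16 \
    --sharded_ddp \
    --output_dir output
```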