Distributed and Bucketing Performance Improvements #348

parthmannan opened this issue May 2, 2024 · 1 comment


parthmannan commented May 2, 2024

🐛 Bug

This is a lengthy issue detailing my observations on our distributed and bucketing performance. Some of these are actionable items and some are just observations to be aware of.

FSDP ZeRO2 (Bucketing=None)

Screenshot 2024-05-02 at 1 20 32 PM

AllGather operations during the forward pass are launched before the computation begins. This is because the Thunder trace schedules all of the AllGathers before the computation and also calls the wait operators before any compute begins.
The long line of operations in stream22 is all AG kernels. This is bad for performance because:

  • There is no overlap with the forward-pass compute (see the sketch below for an interleaved alternative).
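
A minimal sketch of the kind of interleaved schedule that would allow the AllGathers to overlap with forward compute, assuming torch.distributed with an initialized process group. The `layers`, `shards` (per-layer local parameter shards), and `full_params` (preallocated gathered-parameter buffers) containers are hypothetical, not Thunder's:

```python
import torch.distributed as dist

def forward_with_overlap(layers, shards, full_params, x):
    # Prefetch the first layer's parameters.
    work = dist.all_gather_into_tensor(full_params[0], shards[0], async_op=True)
    for i, layer in enumerate(layers):
        # Launch the next layer's AllGather before waiting on the current one,
        # so the collective runs concurrently with this layer's compute.
        next_work = None
        if i + 1 < len(layers):
            next_work = dist.all_gather_into_tensor(
                full_params[i + 1], shards[i + 1], async_op=True
            )
        work.wait()                    # block only when the params are needed
        x = layer(x, full_params[i])   # compute overlaps with the next AllGather
        work = next_work
    return x
```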

FSDP ZeRO2 (Bucketing=Block)

Screenshot 2024-05-02 at 1 34 15 PM
  • Suffers from the same problem as no bucketing: all AllGather operations are launched before computation.
  • Launches Concat CatArrayBatchedCopy_contig kernels for each AllGather operation, which would prevent overlap with compute even if the AG launch schedule were interleaved.
  • Launches a lot of direct_copy_kernel kernels as well when using larger buckets. This would also prevent any overlap with compute and degrade performance. My understanding is that this might be because we concatenate the parameters and copy them into a buffer before communication.

Is there a better way of allocating these buffers only in the first iteration and using portions of these buffers for the computation, instead of concat+copy every iteration?

For example, below is the execution timeline for TorchInductor
Screenshot 2024-05-02 at 1 40 57 PM
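
On the buffer question above, one possible direction is sketched below, assuming torch.distributed and parameters with a uniform dtype/device. The `FlatBucket` class and its fields are hypothetical, not Thunder's API: the flat buffer is allocated once, the parameters are rebound as views into it, and every later AllGather writes straight into that buffer, so no concat or staging copies are needed after the first iteration.

```python
import torch
import torch.distributed as dist

class FlatBucket:
    def __init__(self, params, group=None):
        self.group = group
        world = dist.get_world_size(group)
        numel = sum(p.numel() for p in params)
        padded = (numel + world - 1) // world * world  # pad to divide evenly
        self.full = torch.empty(padded, dtype=params[0].dtype,
                                device=params[0].device)
        self.shard = torch.empty(padded // world, dtype=params[0].dtype,
                                 device=params[0].device)
        # One-time copy: pack the parameter data into the flat buffer and
        # rebind each parameter as a view into it.
        offset = 0
        for p in params:
            n = p.numel()
            self.full[offset:offset + n].copy_(p.detach().reshape(-1))
            p.data = self.full[offset:offset + n].view_as(p)
            offset += n
        rank = dist.get_rank(group)
        start = rank * self.shard.numel()
        self.shard.copy_(self.full[start:start + self.shard.numel()])

    def all_gather(self, async_op=False):
        # Gather shards directly into `full`, which the parameters alias;
        # no per-iteration CatArrayBatchedCopy / direct_copy_kernel launches.
        return dist.all_gather_into_tensor(self.full, self.shard,
                                           group=self.group, async_op=async_op)
```

With this layout, the concat and direct-copy kernels would only appear once at setup; per-iteration communication is a single AllGather into the persistent buffer.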

FSDP ZeRO3 (Bucketing=None)

Screenshot 2024-05-02 at 1 43 05 PM

When using ZeRO3, the schedule is as expected, with the AG kernels and compute kernels interleaved. However, due to launch overheads and the small message sizes without bucketing, there are many gaps where compute and communication are not overlapped.
There is probably room to improve the launch overheads (and maybe even the schedule), but there is no fundamental bug here. This is just an observation.

FSDP ZeRO3 (Bucketing=Block)

Screenshot 2024-05-02 at 1 54 03 PM
  • Suffers from the same issue as ZeRO2 Block bucketing: many direct_copy_kernel kernels are launched alongside the AG, adding overhead to the compute stream.
  • [See screenshot below] Another issue is that in the backward pass most of the AllGather kernels are launched well before any ReduceScatter kernel. While I haven't seen major performance degradation from this, because most communication is still hidden by compute, some communication is exposed because of it. There may also be performance issues from so much communication being pushed to the end of the iteration. Additionally, delaying the RS could worsen memory usage and in turn affect performance (a sketch of an eager RS schedule follows the screenshot below).
Screenshot 2024-05-02 at 2 14 06 PM
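
A rough sketch (not Thunder's implementation) of launching the ReduceScatter for a gradient bucket as soon as its gradients are ready during backward, instead of after all of the AllGathers. It assumes PyTorch >= 2.1 for register_post_accumulate_grad_hook, and that the bucket's gradients already alias a flat buffer `flat_grads` (in the spirit of the parameter sketch above); `grad_shard` is a preallocated output shard for this rank. All names are hypothetical:

```python
import torch.distributed as dist

def install_eager_reduce_scatter(bucket_params, flat_grads, grad_shard, group=None):
    state = {"remaining": len(bucket_params)}
    handles = []

    def on_grad_ready(_param):
        state["remaining"] -= 1
        if state["remaining"] == 0:
            # All grads in this bucket are accumulated: launch RS now so it
            # overlaps with the backward compute of earlier layers.
            handles.append(
                dist.reduce_scatter_tensor(grad_shard, flat_grads,
                                           group=group, async_op=True)
            )
            state["remaining"] = len(bucket_params)  # reset for the next step

    for p in bucket_params:
        # Fires after p.grad has been accumulated for this backward pass.
        p.register_post_accumulate_grad_hook(on_grad_ready)
    return handles
```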

I am writing all of this here to have an easy comparison of all the options tried and facilitate discussion. Please let me know if some of these require individual issues to track and I can create those.

cc @carmocca @awaelchli @crcrpar @IvanYashchuk @mruberry @t-vi @tfogal

@parthmannan added the bug, enhancement, performance, and distributed labels on May 2, 2024
@IvanYashchuk (Collaborator) commented:

Thank you, Parth, for this excellent analysis and accompanying screenshots!

AllGather operations during the forward pass are launched before the computation begins.

At some point our sorting broke and we need to restore the intended functionality; here's the issue tracking this: #277

Is there a better way of allocating these buffers only in the first iteration and using portion of these buffers for the computation instead of concat+copy every iteration?

Yes, there's a better way: if we used a special interleaving copy, it should be possible to do fewer copies and more views. We don't have an issue tracking this, but creating microbenchmarks for bucketing is in our plans.
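
A minimal sketch of the kind of bucketing microbenchmark mentioned here (hypothetical helper names; it assumes an initialized torch.distributed process group on CUDA). It isolates the per-iteration concat/copy cost by comparing concatenating the local shards before every AllGather against reusing a flat buffer built once:

```python
import time
import torch
import torch.distributed as dist

def bench(fn, iters=50, warmup=5):
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters

def run(local_shards, group=None):
    world = dist.get_world_size(group)
    shard_numel = sum(p.numel() for p in local_shards)
    full = torch.empty(shard_numel * world, dtype=local_shards[0].dtype,
                       device=local_shards[0].device)

    def concat_every_iter():
        shard = torch.cat([p.reshape(-1) for p in local_shards])   # copy per iter
        dist.all_gather_into_tensor(full, shard, group=group)

    flat_shard = torch.cat([p.reshape(-1) for p in local_shards])  # built once
    def reuse_flat_buffer():
        dist.all_gather_into_tensor(full, flat_shard, group=group)

    print("concat every iteration:", bench(concat_every_iter))
    print("reuse flat buffer:     ", bench(reuse_flat_buffer))
```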
