This is a lengthy issue/post detailing my observations with our distributed and bucketing performance. Some of these are actionable items and some are just observations to be aware of.
FSDP ZeRO2 (Bucketing=None)
AllGather operations during the forward pass are launched before the computation begins. This is because the Thunder trace schedules all the AllGather operations before the computation and also calls the wait operators before any compute begins.
The long line of operations in stream22 is all AG kernels. This is bad for performance because -
There is no overlap with forward pass compute
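To make the contrast concrete, here is a toy sketch of the schedule we would want instead: the all-gather for layer i+1 is launched on a separate "comm stream" (simulated here with a worker thread) while layer i computes, rather than launching every all-gather and wait up front. `fake_all_gather`, the shard lists, and the compute step are all made-up stand-ins, not Thunder or PyTorch APIs.

```python
# Toy simulation of an interleaved ZeRO2 forward schedule.
# A worker thread plays the role of the communication stream.
from concurrent.futures import ThreadPoolExecutor

def fake_all_gather(shard):
    # Stand-in for gathering a sharded parameter: pretend every rank
    # holds the same shard, so "gathering" just doubles its length.
    return shard + shard

def forward(shards):
    comm = ThreadPoolExecutor(max_workers=1)  # the "comm stream"
    handle = comm.submit(fake_all_gather, shards[0])  # prefetch layer 0
    x = 1
    for i in range(len(shards)):
        full_param = handle.result()      # wait only for the layer needed now
        if i + 1 < len(shards):
            # Launch the next layer's all-gather so it overlaps
            # with this layer's compute.
            handle = comm.submit(fake_all_gather, shards[i + 1])
        x = x * len(full_param)           # stand-in for the layer's compute
    comm.shutdown()
    return x

print(forward([[1], [2, 3], [4, 5, 6]]))  # 2 * 4 * 6 = 48
```

The point of the sketch is only the launch order: each wait happens just before the compute that needs it, and the next communication is already in flight, instead of all waits completing before any compute starts.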
FSDP ZeRO2 (Bucketing=Block)
Suffers the same problem as no bucketing with all AllGather operations being launched before computation.
Launches concat (CatArrayBatchedCopy_contig) kernels for each AllGather operation, which would prevent overlap with compute even if the AG launch schedule were interleaved.
Launches many direct_copy_kernel kernels as well when using larger buckets. These would also prevent any overlap with compute and degrade performance. My understanding is that this happens because we concatenate the parameters and copy them into a buffer before communication.
Is there a better way of allocating these buffers only in the first iteration and using portions of these buffers for the computation, instead of concat+copy every iteration?
For example, below is the execution timeline for TorchInductor
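One possible shape of the "allocate once, view thereafter" idea: keep the parameters of a bucket as views into one flat buffer, so the buffer is always packed and the collective can consume it directly, with no per-iteration concat or copy. This is a minimal NumPy sketch of the memory layout only; the shapes and the two-parameter bucket are illustrative assumptions, not how Thunder lays out buckets.

```python
# Sketch: parameters live as views into one flat bucket buffer that is
# allocated once, so no concat+copy is needed before each communication.
import numpy as np

shapes = [(2, 3), (4,)]                       # parameter shapes in this bucket
sizes = [int(np.prod(s)) for s in shapes]
bucket = np.zeros(sum(sizes))                 # allocated once, first iteration

# Carve the parameters out of the bucket as views (no copies).
params, offset = [], 0
for shape, size in zip(shapes, sizes):
    params.append(bucket[offset:offset + size].reshape(shape))
    offset += size

# Compute writes straight into the views, so the bucket is always
# "packed" and a collective could consume `bucket` directly.
params[0][:] = 1.0
params[1][:] = 2.0
assert np.shares_memory(params[0], bucket)
print(bucket)  # [1. 1. 1. 1. 1. 1. 2. 2. 2. 2.]
```

The open question from the issue would then reduce to whether the executor can keep these views stable across iterations (and whether strides allow it for every parameter), rather than rebuilding the buffer each time.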
FSDP ZeRO3 (Bucketing=None)
When using ZeRO3, the schedule is as expected with the AG kernels and compute kernels being interleaved. However, due to launch overheads and small message sizes without bucketing, there are many gaps where the compute is not being overlapped with communication.
There is probably room for improvement in the launch overheads (maybe the schedule even?) to improve performance but there is no fundamental bug here. This is just an observation.
FSDP ZeRO3 (Bucketing=Block)
Suffers the same issue as ZeRO2 Block bucketing, with many direct_copy_kernel kernels being launched alongside AG, which adds overhead to the compute stream.
[See screenshot below] Another issue here is that in the backward pass most of the AllGather kernels are launched well before any ReduceScatter kernel launches. While I haven't seen any major performance degradation from this, as most communication is still hidden by compute, there is some exposed communication because of it. There may also be performance issues due to the excess communication being pushed to the end of the iteration. Additionally, delaying RS could be responsible for worse memory usage, which in turn would affect performance.
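The memory argument above can be illustrated with a toy accounting model: a full gradient bucket stays alive until its reduce-scatter runs, so pushing all RS launches to the end of the iteration keeps every bucket live at once. The bucket sizes are made up, and this counts only gradient-bucket memory, nothing else.

```python
# Toy model: peak live gradient memory with eager vs. delayed reduce-scatter.
def peak_grad_memory(grad_sizes, eager_rs):
    live, peak = 0, 0
    for g in grad_sizes:
        live += g                # gradient bucket produced by backward
        peak = max(peak, live)
        if eager_rs:
            live -= g            # RS launched immediately frees the bucket
    if not eager_rs:
        live = 0                 # all RS deferred to the end of the iteration
    return peak

print(peak_grad_memory([4, 4, 4], eager_rs=True))   # 4
print(peak_grad_memory([4, 4, 4], eager_rs=False))  # 12
```

Under this (simplified) model, eager RS keeps peak gradient memory at one bucket, while delaying every RS makes it scale with the number of buckets, which is the memory concern raised above.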
I am writing all of this here to have an easy comparison of all the options tried and facilitate discussion. Please let me know if some of these require individual issues to track and I can create those.
Thank you, Parth, for this excellent analysis and accompanying screenshots!
AllGather operations during the forward pass are launched before the computation begins.
At some point our sorting broke, and we need to restore the intended functionality; here's the issue tracking this: #277
Is there a better way of allocating these buffers only in the first iteration and using portion of these buffers for the computation instead of concat+copy every iteration?
Yes, there's a better way: if we used a special interleaving copy, it should be possible to do fewer copies and more views. We don't have an issue tracking this yet, but creating microbenchmarks for bucketing is in our plans.
cc @carmocca @awaelchli @crcrpar @IvanYashchuk @mruberry @t-vi @tfogal