-
Notifications
You must be signed in to change notification settings - Fork 58
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix sort_waits
to move wait
closer to its consumer (#277)
#383
base: main
Are you sure you want to change the base?
Conversation
Unfortunately this change doesn't help the performance on Llama-2-7b-hf Run benchmark with Env: before this commit(on 7cff363):
This PR:
cc: @IvanYashchuk |
Is the wait operation now inserted at what seems like the right place to allow computation-communication overlap? What does Nsight Systems profiling tell about the overlap? |
I think it's expected that all the allgathers are launched in the beginning. To reduce the overhead, rate limiting we do for zero3 could be needed as well for zero2, which #259 is addressing by adding an argument |
For zero3 rate limiting is needed to limit the peak allocated memory because every call to allgather allocates the output unsharded tensor. We need to limit the number of active allgathers until the result is consumed and freed. For zero2 we don't care about the memory consumption because all the unsharded tensors are saved for backward. Are there other effects of limiting the number of allgathers besides peak memory allocation? |
In zero2 the forward trace would consists of a sequence of all-gather's followed by another of computations, which would explain the long idling in the compute stream (at least on paper). |
For #277, after bisect I found a76beb6 is the first bad commit. After comparing the trace before/after this commit, I found the order of allgathers changes. Before the commit, the first consumed allgather is always the first to appear in the trace, and so on. So I did some experiment on early transform to specifically reorder the parameters for Llama-2-7b-hf as @t-vi suggested, the allgathers can overlap better with the computation: Patch specific for Llama-2-7b-hf
|
for more information, see https://pre-commit.ci
Hi @IvanYashchuk @crcrpar , use the sort_wait_zero3(sort the allgather+wait just before consumer) + unlimited number of inflight allgather(push allgathers to the beginning of the trace) can fix the problem. |
Before submitting
What does this PR do?
Fixes #277.
For more details, please see: #277 (comment)
cc @carmocca @awaelchli @crcrpar