
Fix sort_waits to move wait closer to its consumer (#277) #383

Open
kiya00 wants to merge 3 commits into main
Conversation

@kiya00 (Collaborator) commented May 8, 2024

Before submitting
  • Was this discussed/approved via a Github issue? (no need for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?

What does this PR do?

Fixes #277.

For more details, please see: #277 (comment)
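
As a rough illustration of the intent (toy op names only, not Thunder's actual trace format), the transform should turn a schedule in which each wait immediately follows its all-gather into one in which each wait sits right before the op that consumes its result, leaving room for communication/computation overlap:

    # before: each wait directly follows its all-gather, so the compute stream
    # blocks long before the gathered tensor is actually needed
    before = ["allgather_w0", "wait_w0", "allgather_w1", "wait_w1", "block_0", "block_1"]

    # after: the all-gathers are still issued early, but each wait is delayed
    # until just before its consumer
    after = ["allgather_w0", "allgather_w1", "wait_w0", "block_0", "wait_w1", "block_1"]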

cc @carmocca @awaelchli @crcrpar

@kiya00 (Collaborator, Author) commented May 8, 2024

Unfortunately this change doesn't help the performance on Llama-2-7b-hf.

Benchmark command: torchrun --nproc_per_node=8 --nnodes=1 thunder/benchmarks/benchmark_litgpt.py --compile=thunder --distributed_mode=fsdp --micro_batch_size=2 --global_batch_size=16 --model_name=Llama-2-7b-hf --return_metrics_as_json=True --json_path=benchmark_litgpt_datanew.json

Env: H100 80GB * 8, nvfuser: 0.2.3+git729f36c

Before this commit (on 7cff363):

    "average_iter_time": 787.7525961026549,
    "model_flops": 377527625318400,
    "model_flop_per_sec": 3834022026881291.5,
    "tokens_per_sec": 83194.7289094999,
    "memory_used_GB": 65.789196288,
    "model_name": "Llama-2-7b-hf",
    "Num GPUS": 8,
    "Seq Len": 4096,
    "Micro BS": 2,
    "Global BS": 16,
    "GA": 1,
    "Distributed Mode": "fsdp_zero2_none_bucketing",
    "Sharding Size": null,
    "compiler": "thunder"

This PR:

    "average_iter_time": 793.2260634377599,
    "model_flops": 377527625318400,
    "model_flop_per_sec": 3809612930806172.5,
    "tokens_per_sec": 82665.07411965841,
    "memory_used_GB": 65.790769152,
    "model_name": "Llama-2-7b-hf",
    "Num GPUS": 8,
    "Seq Len": 4096,
    "Micro BS": 2,
    "Global BS": 16,
    "GA": 1,
    "Distributed Mode": "fsdp_zero2_none_bucketing",
    "Sharding Size": null,
    "compiler": "thunder"

cc: @IvanYashchuk

@IvanYashchuk (Collaborator) commented:

> Unfortunately this change doesn't help the performance on Llama-2-7b-hf

Is the wait operation now inserted at what seems like the right place to allow computation-communication overlap? What does Nsight Systems profiling tell about the overlap?

@kiya00 (Collaborator, Author) commented May 14, 2024

> Is the wait operation now inserted at what seems like the right place to allow computation-communication overlap? What does Nsight Systems profiling tell about the overlap?

Before this commit:
[Nsight Systems timeline screenshot]
This PR:
[Nsight Systems timeline screenshot]
Although the performance doesn't change, after this PR the allgathers in stream 22 overlap with the computation in stream 7.

@IvanYashchuk left a review comment on thunder/tests/distributed/test_ddp.py (outdated; resolved).
@kiya00 (Collaborator, Author) commented May 15, 2024

Hi @t-vi @carmocca, I think it's ready to merge.

@kiya00 enabled auto-merge (squash) on May 15, 2024 17:07
@kiya00 disabled auto-merge on May 15, 2024 17:35
@parthmannan (Collaborator) commented:
Just tried this branch, and while I do see some small overlap for the first few layers, the majority is still not overlapped and the AllGathers are launched much ahead of the compute.

[Nsight Systems timeline screenshot, 2024-05-15]

@crcrpar (Collaborator) commented May 16, 2024

I think it's expected that all the allgathers are launched at the beginning. To reduce the overhead, the rate limiting we do for zero3 could be needed for zero2 as well, which #259 is addressing by adding an argument.

@IvanYashchuk (Collaborator) commented:
> Just tried this branch, and while I do see some small overlap for the first few layers, the majority is still not overlapped and the AllGathers are launched much ahead of the compute.
>
> [Nsight Systems timeline screenshot, 2024-05-15]

Are stream barriers visible in the profile? They would be a good indicator of whether we're doing the correct thing on the trace level.
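
For context, the kind of stream barrier being asked about is typically a CUDA event recorded on the communication stream that the compute stream waits on. A generic PyTorch sketch of the pattern (assumes an initialized process group; this is not Thunder's executor code):

    import torch
    import torch.distributed as dist

    comm_stream = torch.cuda.Stream()
    shard = torch.randn(1024, device="cuda")
    unsharded = torch.empty(1024 * dist.get_world_size(), device="cuda")

    # launch the all-gather on the communication stream and record a barrier event after it
    with torch.cuda.stream(comm_stream):
        dist.all_gather_into_tensor(unsharded, shard)
        barrier = torch.cuda.Event()
        barrier.record(comm_stream)

    # the default (compute) stream only blocks where the gathered tensor is consumed;
    # in an Nsight Systems timeline this shows up as a stream wait right before the consuming kernel
    torch.cuda.current_stream().wait_event(barrier)
    result = unsharded.sum()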

@IvanYashchuk (Collaborator) commented:

> I think it's expected that all the allgathers are launched at the beginning. To reduce the overhead, the rate limiting we do for zero3 could be needed for zero2 as well, which #259 is addressing by adding an argument.

For zero3, rate limiting is needed to limit the peak allocated memory, because every call to allgather allocates the unsharded output tensor. We need to limit the number of active allgathers until their results are consumed and freed.

For zero2 we don't care about the memory consumption, because all the unsharded tensors are saved for backward. Are there other effects of limiting the number of allgathers besides peak memory allocation?
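
For illustration, a minimal sketch of the in-flight limiting idea being discussed (generic PyTorch with an assumed initialized process group; the function name and structure are hypothetical, not Thunder's implementation). At most "limit" unsharded outputs are alive at once, which bounds peak memory while still overlapping communication with the consumption of earlier results:

    import collections

    import torch
    import torch.distributed as dist

    def allgather_rate_limited(shards, consume, limit=4):
        # shards: sharded parameters in consumption order
        # consume: callback that uses the unsharded tensor so it can be freed afterwards
        inflight = collections.deque()
        for shard in shards:
            if len(inflight) >= limit:
                # drain the oldest all-gather before issuing a new one, so its
                # unsharded output can be consumed (and freed) first
                work, out = inflight.popleft()
                work.wait()
                consume(out)
            out = torch.empty(shard.numel() * dist.get_world_size(),
                              dtype=shard.dtype, device=shard.device)
            work = dist.all_gather_into_tensor(out, shard, async_op=True)
            inflight.append((work, out))
        while inflight:
            work, out = inflight.popleft()
            work.wait()
            consume(out)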

@crcrpar (Collaborator) commented May 16, 2024

In zero2 the forward trace would consist of a sequence of all-gathers followed by a sequence of computations, which would explain the long idle time in the compute stream (at least on paper).
To shorten the initial sequence of all-gathers, rate limiting should help: it reorders the sequence of bsyms so that a few all-gathers are followed by some compute that uses those all-gathers' outputs.
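
Concretely (a toy schedule only, with a hypothetical limit of 2 in-flight all-gathers), the reordering described above would change the bsym order roughly as follows:

    # zero2 today: every all-gather is issued up front, so the compute stream
    # idles until the first wait returns
    unlimited = ["ag_0", "ag_1", "ag_2", "ag_3",
                 "wait_0", "block_0", "wait_1", "block_1",
                 "wait_2", "block_2", "wait_3", "block_3"]

    # with rate limiting (limit=2): later all-gathers are issued while earlier
    # blocks compute, so most of the communication hides behind computation
    limited = ["ag_0", "ag_1",
               "wait_0", "block_0", "ag_2",
               "wait_1", "block_1", "ag_3",
               "wait_2", "block_2",
               "wait_3", "block_3"]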

@kiya00 (Collaborator, Author) commented May 22, 2024

For #277, after bisecting I found that a76beb6 is the first bad commit. Comparing the traces before and after that commit, the order of the allgathers changes: before the commit, the allgather whose result is consumed first is also the first to appear in the trace, and so on. So, as @t-vi suggested, I experimented with an early transform that reorders the parameters specifically for Llama-2-7b-hf; with it, the allgathers overlap better with the computation:
[Nsight Systems timeline screenshot]
The next step will be to rearrange all the allgathers into the order of consumption.

Patch specific to Llama-2-7b-hf:
diff --git a/thunder/__init__.py b/thunder/__init__.py
index adc6d692..a82d0d1d 100644
--- a/thunder/__init__.py
+++ b/thunder/__init__.py
@@ -301,7 +301,8 @@ def jit(
         interpreter = _general_frontend
 
     if early_transforms is None:
-        early_transforms = []
+        # early_transforms = []
+        early_transforms = [tmp_trans, ]
 
     if additional_transforms is None:
         additional_transforms = []
@@ -938,3 +939,43 @@ def grad(fn):
         return original_result, original_trace
 
     return _fn
+def tmp_trans(prologue_trc, computation_trc, epilogue_trc, executors_list):
+    old_order = ("tos1","sin","transformer_wte_weight","transformer_h_0_norm_1_weight","transformer_h_0_attn_attn_weight","transformer_h_0_attn_proj_weight","transformer_h_0_norm_2_weight","transformer_h_0_mlp_fc_1_weight","transformer_h_0_mlp_fc_2_weight","transformer_h_0_mlp_proj_weight","transformer_h_1_norm_1_weight","transformer_h_1_attn_attn_weight","transformer_h_1_attn_proj_weight","transformer_h_1_norm_2_weight","transformer_h_1_mlp_fc_1_weight","transformer_h_1_mlp_fc_2_weight","transformer_h_1_mlp_proj_weight","transformer_h_2_norm_1_weight","transformer_h_2_attn_attn_weight","transformer_h_2_attn_proj_weight","transformer_h_2_norm_2_weight","transformer_h_2_mlp_fc_1_weight","transformer_h_2_mlp_fc_2_weight","transformer_h_2_mlp_proj_weight","transformer_h_3_norm_1_weight","transformer_h_3_attn_attn_weight","transformer_h_3_attn_proj_weight","transformer_h_3_norm_2_weight","transformer_h_3_mlp_fc_1_weight","transformer_h_3_mlp_fc_2_weight","transformer_h_3_mlp_proj_weight","transformer_h_4_norm_1_weight","transformer_h_4_attn_attn_weight","transformer_h_4_attn_proj_weight","transformer_h_4_norm_2_weight","transformer_h_4_mlp_fc_1_weight","transformer_h_4_mlp_fc_2_weight","transformer_h_4_mlp_proj_weight","transformer_h_5_norm_1_weight","transformer_h_5_attn_attn_weight","transformer_h_5_attn_proj_weight","transformer_h_5_norm_2_weight","transformer_h_5_mlp_fc_1_weight","transformer_h_5_mlp_fc_2_weight","transformer_h_5_mlp_proj_weight","transformer_h_6_norm_1_weight","transformer_h_6_attn_attn_weight","transformer_h_6_attn_proj_weight","transformer_h_6_norm_2_weight","transformer_h_6_mlp_fc_1_weight","transformer_h_6_mlp_fc_2_weight","transformer_h_6_mlp_proj_weight","transformer_h_7_norm_1_weight","transformer_h_7_attn_attn_weight","transformer_h_7_attn_proj_weight","transformer_h_7_norm_2_weight","transformer_h_7_mlp_fc_1_weight","transformer_h_7_mlp_fc_2_weight","transformer_h_7_mlp_proj_weight","transformer_h_8_norm_1_weight","transformer_h_8_attn_attn_weight","transformer_h_8_attn_proj_weight","transformer_h_8_norm_2_weight","transformer_h_8_mlp_fc_1_weight","transformer_h_8_mlp_fc_2_weight","transformer_h_8_mlp_proj_weight","transformer_h_9_norm_1_weight","transformer_h_9_attn_attn_weight","transformer_h_9_attn_proj_weight","transformer_h_9_norm_2_weight","transformer_h_9_mlp_fc_1_weight","transformer_h_9_mlp_fc_2_weight","transformer_h_9_mlp_proj_weight","transformer_h_10_norm_1_weight","transformer_h_10_attn_attn_weight","transformer_h_10_attn_proj_weight","transformer_h_10_norm_2_weight","transformer_h_10_mlp_fc_1_weight","transformer_h_10_mlp_fc_2_weight","transformer_h_10_mlp_proj_weight","transformer_h_11_norm_1_weight","transformer_h_11_attn_attn_weight","transformer_h_11_attn_proj_weight","transformer_h_11_norm_2_weight","transformer_h_11_mlp_fc_1_weight","transformer_h_11_mlp_fc_2_weight","transformer_h_11_mlp_proj_weight","transformer_h_12_norm_1_weight","transformer_h_12_attn_attn_weight","transformer_h_12_attn_proj_weight","transformer_h_12_norm_2_weight","transformer_h_12_mlp_fc_1_weight","transformer_h_12_mlp_fc_2_weight","transformer_h_12_mlp_proj_weight","transformer_h_13_norm_1_weight","transformer_h_13_attn_attn_weight","transformer_h_13_attn_proj_weight","transformer_h_13_norm_2_weight","transformer_h_13_mlp_fc_1_weight","transformer_h_13_mlp_fc_2_weight","transformer_h_13_mlp_proj_weight","transformer_h_14_norm_1_weight","transformer_h_14_attn_attn_weight","transformer_h_14_attn_proj_weight","transformer_h_14_norm_2_weight","transformer_h_14_mlp_fc_1_we
ight","transformer_h_14_mlp_fc_2_weight","transformer_h_14_mlp_proj_weight","transformer_h_15_norm_1_weight","transformer_h_15_attn_attn_weight","transformer_h_15_attn_proj_weight","transformer_h_15_norm_2_weight","transformer_h_15_mlp_fc_1_weight","transformer_h_15_mlp_fc_2_weight","transformer_h_15_mlp_proj_weight","transformer_h_16_norm_1_weight","transformer_h_16_attn_attn_weight","transformer_h_16_attn_proj_weight","transformer_h_16_norm_2_weight","transformer_h_16_mlp_fc_1_weight","transformer_h_16_mlp_fc_2_weight","transformer_h_16_mlp_proj_weight","transformer_h_17_norm_1_weight","transformer_h_17_attn_attn_weight","transformer_h_17_attn_proj_weight","transformer_h_17_norm_2_weight","transformer_h_17_mlp_fc_1_weight","transformer_h_17_mlp_fc_2_weight","transformer_h_17_mlp_proj_weight","transformer_h_18_norm_1_weight","transformer_h_18_attn_attn_weight","transformer_h_18_attn_proj_weight","transformer_h_18_norm_2_weight","transformer_h_18_mlp_fc_1_weight","transformer_h_18_mlp_fc_2_weight","transformer_h_18_mlp_proj_weight","transformer_h_19_norm_1_weight","transformer_h_19_attn_attn_weight","transformer_h_19_attn_proj_weight","transformer_h_19_norm_2_weight","transformer_h_19_mlp_fc_1_weight","transformer_h_19_mlp_fc_2_weight","transformer_h_19_mlp_proj_weight","transformer_h_20_norm_1_weight","transformer_h_20_attn_attn_weight","transformer_h_20_attn_proj_weight","transformer_h_20_norm_2_weight","transformer_h_20_mlp_fc_1_weight","transformer_h_20_mlp_fc_2_weight","transformer_h_20_mlp_proj_weight","transformer_h_21_norm_1_weight","transformer_h_21_attn_attn_weight","transformer_h_21_attn_proj_weight","transformer_h_21_norm_2_weight","transformer_h_21_mlp_fc_1_weight","transformer_h_21_mlp_fc_2_weight","transformer_h_21_mlp_proj_weight","transformer_h_22_norm_1_weight","transformer_h_22_attn_attn_weight","transformer_h_22_attn_proj_weight","transformer_h_22_norm_2_weight","transformer_h_22_mlp_fc_1_weight","transformer_h_22_mlp_fc_2_weight","transformer_h_22_mlp_proj_weight","transformer_h_23_norm_1_weight","transformer_h_23_attn_attn_weight","transformer_h_23_attn_proj_weight","transformer_h_23_norm_2_weight","transformer_h_23_mlp_fc_1_weight","transformer_h_23_mlp_fc_2_weight","transformer_h_23_mlp_proj_weight","transformer_h_24_norm_1_weight","transformer_h_24_attn_attn_weight","transformer_h_24_attn_proj_weight","transformer_h_24_norm_2_weight","transformer_h_24_mlp_fc_1_weight","transformer_h_24_mlp_fc_2_weight","transformer_h_24_mlp_proj_weight","transformer_h_25_norm_1_weight","transformer_h_25_attn_attn_weight","transformer_h_25_attn_proj_weight","transformer_h_25_norm_2_weight","transformer_h_25_mlp_fc_1_weight","transformer_h_25_mlp_fc_2_weight","transformer_h_25_mlp_proj_weight","transformer_h_26_norm_1_weight","transformer_h_26_attn_attn_weight","transformer_h_26_attn_proj_weight","transformer_h_26_norm_2_weight","transformer_h_26_mlp_fc_1_weight","transformer_h_26_mlp_fc_2_weight","transformer_h_26_mlp_proj_weight","transformer_h_27_norm_1_weight","transformer_h_27_attn_attn_weight","transformer_h_27_attn_proj_weight","transformer_h_27_norm_2_weight","transformer_h_27_mlp_fc_1_weight","transformer_h_27_mlp_fc_2_weight","transformer_h_27_mlp_proj_weight","transformer_h_28_norm_1_weight","transformer_h_28_attn_attn_weight","transformer_h_28_attn_proj_weight","transformer_h_28_norm_2_weight","transformer_h_28_mlp_fc_1_weight","transformer_h_28_mlp_fc_2_weight","transformer_h_28_mlp_proj_weight","transformer_h_29_norm_1_weight","transformer_h_29_attn_attn_weight","transfo
rmer_h_29_attn_proj_weight","transformer_h_29_norm_2_weight","transformer_h_29_mlp_fc_1_weight","transformer_h_29_mlp_fc_2_weight","transformer_h_29_mlp_proj_weight","transformer_h_30_norm_1_weight","transformer_h_30_attn_attn_weight","transformer_h_30_attn_proj_weight","transformer_h_30_norm_2_weight","transformer_h_30_mlp_fc_1_weight","transformer_h_30_mlp_fc_2_weight","transformer_h_30_mlp_proj_weight","transformer_h_31_norm_1_weight","transformer_h_31_attn_attn_weight","transformer_h_31_attn_proj_weight","transformer_h_31_norm_2_weight","transformer_h_31_mlp_fc_1_weight","transformer_h_31_mlp_fc_2_weight","transformer_h_31_mlp_proj_weight","transformer_ln_f_weight","lm_head_weight","idx")
+    ret = prologue_trc.bound_symbols[-1]
+    assert(ret.sym.id == prims.PrimIDs.RETURN)
+
+    def sort_func(x):
+        if x.name in old_order:
+            return old_order.index(x.name)
+        assert(x.name[2:] in old_order)
+        return old_order.index(x.name[2:])
+    new_ret_args = tuple(sorted(ret.args[0], key=sort_func))
+    from dataclasses import dataclass, replace
+    new_ret = replace(ret, args=(new_ret_args,))
+    new_ret = replace(new_ret, output=(new_ret_args,))
+    prologue_trc.bound_symbols[-1]=new_ret
+
+    new_args = tuple(sorted(computation_trc.args, key=sort_func))
+    computation_trc.args=new_args
+    new_list = []
+    old_list = computation_trc.bound_symbols[0:len(old_order)]
+
+    for n in old_order:
+        tmp = [p for p in old_list if p.output.name == n or p.output.name == "t_"+n]
+        assert(len(tmp) == 1)
+        new_list.append(tmp[0])
+    assert(len(new_list)==len(old_order))
+    computation_trc.bound_symbols[0:len(old_order)] = new_list
+
+    siginfo_args = []
+    for bsym in new_list:
+        siginfo_args.append((bsym.output.name, None))
+    computation_trc._siginfo.args = siginfo_args
+    return prologue_trc, computation_trc, epilogue_trc
diff --git a/thunder/benchmarks/benchmark_litgpt.py b/thunder/benchmarks/benchmark_litgpt.py
index 4b553c2f..816c6f11 100644
--- a/thunder/benchmarks/benchmark_litgpt.py
+++ b/thunder/benchmarks/benchmark_litgpt.py
@@ -403,6 +403,11 @@ class Benchmark_litGPT:
             input_ids = input_ids.to(device=self.device)
             targets = targets.to(device=self.device)
             loss = run_fwd_bwd_one_microbatch(self.model, input_ids, targets, self.gradient_accumulation_steps)
+            # if i==0 and global_rank==0:
+            #     # with open('old_trace','w') as f:
+            #     with open('reordertrace','w') as f:
+            #         f.write(str(thunder.last_traces(self.model)[-1]))
+            #         f.write(str(thunder.last_backward_traces(self.model)[-1]))
 
             # Simple Gradient Accumulation Implementation
             self.optimizer.step()

@kiya00 marked this pull request as draft on May 22, 2024 08:47
@kiya00 (Collaborator, Author) commented May 23, 2024

Hi @IvanYashchuk @crcrpar, using sort_wait_zero3 (which sorts each allgather+wait to just before its consumer) together with an unlimited number of in-flight allgathers (which pushes the allgathers to the beginning of the trace) fixes the problem.
It sorts the allgathers into their consumption order and lists them at the beginning of the trace, while the corresponding waits are placed right before their consumers.
Nsight Systems results for Llama-2-7b-hf:
[Nsight Systems timeline screenshot]
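
A minimal sketch of that ordering idea (toy op records rather than Thunder BoundSymbols; names and structure are illustrative, not the actual sort_waits/sort_wait_zero3 code): front-load the all-gathers in the order their results are consumed and sink each wait to just before the first consumer of its output.

    from dataclasses import dataclass

    @dataclass
    class Op:
        name: str
        kind: str        # "allgather", "wait", or "compute"
        inputs: tuple    # names of the values this op reads
        output: str      # name of the value this op produces

    def front_load_comm_and_sink_waits(ops):
        allgathers = [op for op in ops if op.kind == "allgather"]
        waits = {op.inputs[0]: op for op in ops if op.kind == "wait"}   # keyed by comm handle
        computes = [op for op in ops if op.kind == "compute"]

        def first_use(ag):
            waited_name = waits[ag.output].output
            return next((i for i, c in enumerate(computes) if waited_name in c.inputs),
                        len(computes))

        # 1. issue every all-gather up front, ordered by when its result is first consumed
        new_order = sorted(allgathers, key=first_use)

        # 2. interleave the compute, inserting each wait right before its first consumer
        pending = dict(waits)
        for c in computes:
            for handle, wait in list(pending.items()):
                if wait.output in c.inputs:
                    new_order.append(wait)
                    del pending[handle]
            new_order.append(c)
        new_order.extend(pending.values())   # waits whose results are never consumed
        return new_order

    ops = [
        Op("ag_w1", "allgather", ("w1_shard",), "h1"),
        Op("wait_w1", "wait", ("h1",), "w1"),
        Op("ag_w0", "allgather", ("w0_shard",), "h0"),
        Op("wait_w0", "wait", ("h0",), "w0"),
        Op("block_0", "compute", ("x", "w0"), "y0"),
        Op("block_1", "compute", ("y0", "w1"), "y1"),
    ]
    print([op.name for op in front_load_comm_and_sink_waits(ops)])
    # ['ag_w0', 'ag_w1', 'wait_w0', 'block_0', 'wait_w1', 'block_1']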

@kiya00 marked this pull request as ready for review on May 27, 2024 12:47

Successfully merging this pull request may close these issues:
  • thunder.distributed.utils.sort_waits is broken

5 participants