thunder.distributed.utils.sort_waits is broken #277

Open · IvanYashchuk opened this issue Apr 25, 2024 · 2 comments · May be fixed by #383

Labels: bug (Something isn't working), distributed

IvanYashchuk (Collaborator) commented Apr 25, 2024

🐛 Bug

thunder.distributed.utils.sort_waits is broken in that it does not move the wait operations close to their consumers. As a result, communication is not overlapped with computation.
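
To see why the placement matters, here is a self-contained toy model (plain Python, not thunder code; the op names are invented): a topological sort that keeps the trace order schedules the second wait before any compute, while one that defers waits until their results are needed lets the second all-gather overlap with the first layer's compute.

# Toy model (not thunder's API; op names are invented): Kahn-style topological
# sort over a tiny FSDP-like op graph with a pluggable tie-break among the
# currently ready ops.
deps = {                        # op -> ops it depends on, in "trace" order
    "all_gather_1": [],
    "wait_1": ["all_gather_1"],
    "all_gather_2": [],
    "wait_2": ["all_gather_2"],
    "compute_1": ["wait_1"],
    "compute_2": ["wait_2", "compute_1"],
}
trace_pos = {op: i for i, op in enumerate(deps)}

def toposort(key):
    remaining = {op: set(d) for op, d in deps.items()}
    order = []
    while remaining:
        ready = [op for op, d in remaining.items() if not d]
        op = min(ready, key=key)            # tie-break among ready ops
        order.append(op)
        del remaining[op]
        for d in remaining.values():
            d.discard(op)
    return order

# Keeping trace order: wait_2 is scheduled before any compute, so both
# all-gathers must finish before computation starts (no overlap).
print(toposort(key=lambda op: trace_pos[op]))
# ['all_gather_1', 'wait_1', 'all_gather_2', 'wait_2', 'compute_1', 'compute_2']

# Deferring waits: wait_2 lands right before compute_2, so all_gather_2's
# communication overlaps with compute_1.
print(toposort(key=lambda op: (op.startswith("wait"), trace_pos[op])))
# ['all_gather_1', 'all_gather_2', 'wait_1', 'compute_1', 'wait_2', 'compute_2']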

Steps to reproduce:

  1. Set a breakpoint at the line
    fw_extrace = sort_waits(fw_extrace)
    (in thunder/executors/torch_autograd.py).
  2. Apply the following patch (for shorter traces):
diff --git a/thunder/benchmarks/benchmark_litgpt.py b/thunder/benchmarks/benchmark_litgpt.py
index 770bca9b..014a49ed 100644
--- a/thunder/benchmarks/benchmark_litgpt.py
+++ b/thunder/benchmarks/benchmark_litgpt.py
@@ -71,6 +71,7 @@ class Benchmark_litGPT:
         self.device = device
         self.model_name = model_name
         self.config = Config.from_name(self.model_name)
+        self.config.n_layer = 2
         self.compile = compile
         self.dynamic = dynamic
         self.distributed_mode = distributed_mode
  3. Run on a 2-GPU system: torchrun --nproc_per_node=2 --nnodes=1 thunder/benchmarks/benchmark_litgpt.py --compile=thunder --distributed_mode=fsdp --micro_batch_size=1 --global_batch_size=2 --model_name=Llama-2-7b-hf

Inspect the fw_extrace before and after sort_waits.
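
At the breakpoint, the easiest way to compare is to dump the trace to a file on each side of the call, for example (file names are arbitrary):

# Run at the breakpoint: dump the trace before and after sorting so the
# placement of the torch_wait_prim_impl symbols can be diffed by hand
# (file names are arbitrary).
with open("fw_extrace_before.txt", "w") as f:
    f.write(str(fw_extrace))
fw_extrace = sort_waits(fw_extrace)
with open("fw_extrace_after.txt", "w") as f:
    f.write(str(fw_extrace))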

cc @carmocca @awaelchli @crcrpar

IvanYashchuk added the bug (Something isn't working) and distributed labels on Apr 25, 2024
IvanYashchuk (Collaborator, Author) commented:

The following seems to fix the forward trace sorting; I still need to check how it affects everything else.

diff --git a/thunder/distributed/utils.py b/thunder/distributed/utils.py
index 48cc66d4..540d117d 100644
--- a/thunder/distributed/utils.py
+++ b/thunder/distributed/utils.py
@@ -150,7 +150,7 @@ def sort_waits(execution_trace):
                     # Prefer nodes that are earlier in the trace
                     return order_in_trace[node.bsym]
 
-        return min(range(len(eligible_nodes)), key=lambda i: key(eligible_nodes[i]))
+        return max(range(len(eligible_nodes)), key=lambda i: key(eligible_nodes[i]))
 
     new_execution_trace = from_trace(execution_trace)
 
@@ -160,8 +160,8 @@ def sort_waits(execution_trace):
         lambda: "Cannot sort execution trace with del nodes",
     )
     new_execution_trace.bound_symbols = toposort_bsym_dag(
-        bsym_list_to_dag(execution_trace.bound_symbols)[0],
-        TOPOSORT_ORDER.TOP_DOWN,
+        bsym_list_to_dag(execution_trace.bound_symbols)[1],
+        TOPOSORT_ORDER.BOTTOM_UP,
         selector=prefer_comm_over_other_over_wait,
     )
     return new_execution_trace

kiya00 (Collaborator) commented May 8, 2024

After bisecting, I found that a76beb6 is the first bad commit.

It can be reproduced by running python ../thunder/benchmarks/distributed.py --world-size 2 --model Llama-2-7b-hf -D fsdp --bucketing-strategies none --sharding-strategies zero2 --skip-torch with the following patch applied (the benchmark_litgpt.py mentioned above did not exist at that commit):

git diff
diff --git a/thunder/benchmarks/distributed.py b/thunder/benchmarks/distributed.py
index 6c9e5e88..adf16bcf 100644
--- a/thunder/benchmarks/distributed.py
+++ b/thunder/benchmarks/distributed.py
@@ -232,6 +232,7 @@ if __name__ == "__main__":
         else:
             kwargs["batchdims"] = (1,)
         config = LitGPTConfig.from_name(args.model)
+        config.n_layer=2
         b = LitGPTBenchmark(config, dtype=torch_dtype, **kwargs)
 
     results: list[dict[str, int | float | str]] = []
diff --git a/thunder/executors/torch_autograd.py b/thunder/executors/torch_autograd.py
index 2eb3f1f5..c380e7e4 100644
--- a/thunder/executors/torch_autograd.py
+++ b/thunder/executors/torch_autograd.py
@@ -305,7 +305,11 @@ def split_forward_backward(func, compile_data, compile_stats, /, *args, **kwargs
                 compile_data.fn.bucketing_strategy != FSDPBucketingStrategy.NONE,
             )
         if getattr(compile_data.fn, "sharding_strategy") == FSDPType.ZERO2:
+            with open('bef_bisect9','w') as f:
+                f.write(str(fw_extrace))
             fw_extrace = sort_waits(fw_extrace)
+            with open('aft_bisect9','w') as f:
+                f.write(str(fw_extrace))
             bw_extrace = sort_waits(bw_extrace)
     if getattr(compile_data.fn, "use_ddp", False):
         bw_extrace = sort_waits(bw_extrace)

The reason is:
In short: the relative order of topologically equivalent torch_wait_prim_impl nodes depends on the order in which they appear in the trace, and the jit changes the order of the input parameters, which in turn changes the order of the torch_wait_prim_impl nodes.

Detailed explanation:
The multiple torch_wait_prim_impl nodes are topologically equivalent, so the key used in sort_waits is the same for all of them (len(order_in_trace)), and min picks the first one (return min(range(len(eligible_nodes)), key=lambda i: key(eligible_nodes[i]))).
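
A standalone illustration of that tie-breaking behaviour (plain Python, nothing thunder-specific; the node names are made up): when every eligible node maps to the same key, min over the index range returns the first index, so the outcome is decided entirely by the order of eligible_nodes, i.e. by the trace order.

# Standalone illustration: with equal keys, min over the indices returns the
# first index, so ties among eligible nodes are broken purely by list order.
def key(node):
    return 3                       # same key for every topologically-equal wait

eligible_nodes = ["wait_a", "wait_b", "wait_c"]       # hypothetical waits
picked = min(range(len(eligible_nodes)), key=lambda i: key(eligible_nodes[i]))
print(eligible_nodes[picked])      # wait_a

# If the jit reorders the inputs, a different wait is picked first even though
# the dependency structure is unchanged.
eligible_nodes = ["wait_b", "wait_a", "wait_c"]
picked = min(range(len(eligible_nodes)), key=lambda i: key(eligible_nodes[i]))
print(eligible_nodes[picked])      # wait_b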

Solution:
I've had a similar problem with Zero3 (https://github.com/Lightning-AI/lit-thunder-LEGACY/pull/2140#pullrequestreview-1888218006) and fixed it with a similar patch in sort_waits_for_zero3. I think the same approach should fix the current problem, at least until we have stronger requirements on how topologically equivalent nodes are ordered in sort_waits.
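
For the record, a sketch of the kind of tie-break I have in mind (purely illustrative; this is not the actual code from that PR or from sort_waits_for_zero3): rank topologically equivalent waits by where their first consumer sits in the original trace, so the result no longer depends on how the jit happens to arrange the input parameters.

# Illustrative sketch only (not the actual sort_waits / sort_waits_for_zero3
# patch): break ties among equally-ranked waits by the position of each wait's
# first consumer in the original trace rather than by the waits' own positions.
def tie_break_key(wait, order_in_trace, consumers):
    # order_in_trace: op -> position in the original trace
    # consumers: op -> ops that use this op's output
    first_consumer = min(order_in_trace[c] for c in consumers[wait])
    return (first_consumer, order_in_trace[wait])

# Hypothetical data: wait_2's result is needed first, wait_1's much later.
order_in_trace = {"wait_1": 3, "wait_2": 4, "compute_1": 10, "compute_2": 20}
consumers = {"wait_1": ["compute_2"], "wait_2": ["compute_1"]}
print(sorted(["wait_1", "wait_2"],
             key=lambda w: tie_break_key(w, order_in_trace, consumers)))
# ['wait_2', 'wait_1']: wait_2 is scheduled first because compute_1 needs it
# first; wait_1 stays pending, so its communication keeps overlapping compute_1.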
