thunder.distributed.utils.sort_waits is broken #277
The following change seems to fix the forward trace sorting. I still need to check how it affects everything else.
diff --git a/thunder/distributed/utils.py b/thunder/distributed/utils.py
index 48cc66d4..540d117d 100644
--- a/thunder/distributed/utils.py
+++ b/thunder/distributed/utils.py
@@ -150,7 +150,7 @@ def sort_waits(execution_trace):
# Prefer nodes that are earlier in the trace
return order_in_trace[node.bsym]
- return min(range(len(eligible_nodes)), key=lambda i: key(eligible_nodes[i]))
+ return max(range(len(eligible_nodes)), key=lambda i: key(eligible_nodes[i]))
new_execution_trace = from_trace(execution_trace)
@@ -160,8 +160,8 @@ def sort_waits(execution_trace):
lambda: "Cannot sort execution trace with del nodes",
)
new_execution_trace.bound_symbols = toposort_bsym_dag(
- bsym_list_to_dag(execution_trace.bound_symbols)[0],
- TOPOSORT_ORDER.TOP_DOWN,
+ bsym_list_to_dag(execution_trace.bound_symbols)[1],
+ TOPOSORT_ORDER.BOTTOM_UP,
selector=prefer_comm_over_other_over_wait,
)
return new_execution_trace
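The idea behind the patch above can be illustrated with a toy bottom-up list scheduler. This is a minimal sketch, not thunder's implementation: `toposort_bottom_up`, the op names, and the tuple key are hypothetical simplifications (the real selector uses `order_in_trace` and the communication size). Scheduling from the sinks upward and always picking the highest-key eligible node places each wait immediately before its consumer and pushes the communication ops toward the start of the trace.

```python
# Toy trace (hypothetical): op name -> names of the ops whose outputs it consumes.
# Insertion order plays the role of the original trace order.
TRACE = {
    "all_gather_w1": [],
    "wait_w1": ["all_gather_w1"],
    "all_gather_w2": [],
    "wait_w2": ["all_gather_w2"],
    "matmul1": ["wait_w1"],
    "matmul2": ["matmul1", "wait_w2"],
}


def toposort_bottom_up(trace):
    order_in_trace = {name: i for i, name in enumerate(trace)}

    def key(name):
        # Simplified stand-in for the selector key in the issue:
        # waits are picked first (so they land right before their consumer),
        # communication ops last (so they land as early as possible).
        if name.startswith("wait"):
            priority = 2
        elif name.startswith("all_gather"):
            priority = 0
        else:
            priority = 1
        return (priority, order_in_trace[name])

    # Number of not-yet-scheduled consumers for every op.
    pending = {name: 0 for name in trace}
    for inputs in trace.values():
        for producer in inputs:
            pending[producer] += 1

    # Start from the sinks and build the schedule in reverse.
    eligible = [name for name, count in pending.items() if count == 0]
    reversed_order = []
    while eligible:
        chosen = max(eligible, key=key)
        eligible.remove(chosen)
        reversed_order.append(chosen)
        for producer in trace[chosen]:
            pending[producer] -= 1
            if pending[producer] == 0:
                eligible.append(producer)
    return reversed_order[::-1]


sorted_ops = toposort_bottom_up(TRACE)
print(sorted_ops)
```

In this toy trace both all-gathers are issued first and `wait_w2` ends up directly before `matmul2`, so the second all-gather can overlap with `matmul1`.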
After bisecting, I found that a76beb6 is the first bad commit. It can be reproduced with the following git diff applied:
diff --git a/thunder/benchmarks/distributed.py b/thunder/benchmarks/distributed.py
index 6c9e5e88..adf16bcf 100644
--- a/thunder/benchmarks/distributed.py
+++ b/thunder/benchmarks/distributed.py
@@ -232,6 +232,7 @@ if __name__ == "__main__":
else:
kwargs["batchdims"] = (1,)
config = LitGPTConfig.from_name(args.model)
+ config.n_layer=2
b = LitGPTBenchmark(config, dtype=torch_dtype, **kwargs)
results: list[dict[str, int | float | str]] = []
diff --git a/thunder/executors/torch_autograd.py b/thunder/executors/torch_autograd.py
index 2eb3f1f5..c380e7e4 100644
--- a/thunder/executors/torch_autograd.py
+++ b/thunder/executors/torch_autograd.py
@@ -305,7 +305,11 @@ def split_forward_backward(func, compile_data, compile_stats, /, *args, **kwargs
compile_data.fn.bucketing_strategy != FSDPBucketingStrategy.NONE,
)
if getattr(compile_data.fn, "sharding_strategy") == FSDPType.ZERO2:
+ with open('bef_bisect9','w') as f:
+ f.write(str(fw_extrace))
fw_extrace = sort_waits(fw_extrace)
+ with open('aft_bisect9','w') as f:
+ f.write(str(fw_extrace))
bw_extrace = sort_waits(bw_extrace)
if getattr(compile_data.fn, "use_ddp", False):
bw_extrace = sort_waits(bw_extrace)
The reason is:
Detailed explanation:
Solution:
🐛 Bug
thunder.distributed.utils.sort_waits is broken: it does not move the waits close to their consumers. As a result, communication and computation do not overlap.
Steps to reproduce:
lightning-thunder/thunder/executors/torch_autograd.py
Line 192 in 4d9fa60
torchrun --nproc_per_node=2 --nnodes=1 thunder/benchmarks/benchmark_litgpt.py --compile=thunder --distributed_mode=fsdp --micro_batch_size=1 --global_batch_size=2 --model_name=Llama-2-7b-hf
Inspect the fw_extrace before and after sort_waits.
.cc @carmocca @awaelchli @crcrpar