
[Core] Pipeline Parallel Support #4412

Open

wants to merge 12 commits into base: main

Conversation

andoorve
Contributor

@andoorve andoorve commented Apr 27, 2024

Adds initial pipeline parallelism support to vLLM.

ToDo:

Milestone 1: POC Prototype

Milestone 2: Mergeable

  • Fix incorrect LLaMa outputs (bug filed against PyTorch: [Distributed] P2P Operations on NCCL do not respect tag pytorch/pytorch#125079, and worked around)
  • Refactor to move the sending and receiving code out of the models.
  • Check if there's a simpler way to do weight loading
  • Enable multi-node
  • Add RFC for community benefit
  • Add some testing
  • Raise assertion errors for models that are not yet supported, as well as for the LLMEngine entry point.
  • Check if any PyNCCL changes are necessary
  • Rebase on latest
  • Tests passing

FIX #4461

Goals for this PR:

  • Functional eager-mode PP
  • Support AsyncLLMEngine
  • Support RayGPUExecutor
  • Support LLaMa/GPT2
  • Support chunked prefill

Non-goals for this PR (To be covered in future PRs)

  • Be fully optimized
  • Support LLMEngine (this may be removed in the future)
  • Support any other distributed backend
  • Support models other than LLaMa/GPT2
  • Support CUDA graphs (this already works in this PR, but CUDA-graph issues should not block the merge)

cc: @zhuohan123 @WoosukKwon @simon-mo @youkaichao

@robertgshaw2-neuralmagic
Collaborator

@andoorve - Exciting!!!

@youkaichao
Member

@andoorve thanks for the effort! Can you write an RFC to describe the overall design so that people can easily understand it? example rfcs: https://github.com/vllm-project/vllm/issues?q=label%3ARFC+sort%3Aupdated-desc

@andoorve
Contributor Author

@youkaichao Yes for sure, it is one of the TODO items above

@@ -746,7 +763,8 @@ def execute_model(
         logits = self.model.compute_logits(hidden_states, sampling_metadata)

         # Only perform sampling in the driver worker.
-        if not self.is_driver_worker:
+        if (not (is_pipeline_model_parallel_last_rank()
Collaborator

So for TP, the first rank (the driver) performs sampling, and for PP, the last rank (the last worker in the last PP stage's TP group) performs sampling. Is this correct?

Contributor Author

It's the first worker of the last PP stage's TP group.
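
For illustration, here is a tiny sketch of the assumed rank layout (PP=2, TP=2; illustrative only, not the PR's exact code), where sampling happens on the first rank of the last stage's TP group:

# Assumed Megatron-style layout for PP=2, TP=2:
# stage 0 owns ranks [0, 1], stage 1 owns ranks [2, 3].
PP_SIZE, TP_SIZE = 2, 2
tp_groups = [list(range(pp * TP_SIZE, (pp + 1) * TP_SIZE)) for pp in range(PP_SIZE)]
sampler_rank = tp_groups[-1][0]  # -> rank 2 performs sampling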

@andoorve
Contributor Author

Updated the RFC here: #4461 @youkaichao

Let me know if anything needs further elaboration

@andoorve
Contributor Author

FYI, I'm pretty sure PyTorch has a bug; filed here: pytorch/pytorch#125079

Worked around it last week by making the send/receive phase for each model atomic: the residuals and hidden states are concatenated into a single tensor.
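
For anyone curious, a minimal sketch of that workaround (helper names, shapes, and group handling here are assumptions, using plain torch.distributed rather than the PR's wrappers):

import torch
import torch.distributed as dist

def send_hidden_and_residual(hidden_states: torch.Tensor,
                             residual: torch.Tensor, dst: int) -> None:
    # One combined send instead of two tagged sends, sidestepping
    # pytorch/pytorch#125079 (tags not respected for NCCL P2P).
    combined = torch.cat([hidden_states, residual], dim=0)
    dist.send(combined, dst)

def recv_hidden_and_residual(shape, dtype, device, src: int):
    # Receive the combined tensor and split it back into its two halves.
    combined = torch.empty((2 * shape[0],) + tuple(shape[1:]),
                           dtype=dtype, device=device)
    dist.recv(combined, src)
    hidden_states, residual = torch.chunk(combined, 2, dim=0)
    return hidden_states, residual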

@youkaichao
Member

@andoorve hi, I already made the change to pynccl to support multiple groups in #4512 . The first rank can be read from the group argument directly.

@andoorve
Contributor Author

andoorve commented May 1, 2024

Sounds good @youkaichao, I can update mine once that's merged.

Will you also include the change to create the multiple CPU TP groups or should I create a separate PR?

@youkaichao
Member

Will you also include the change to create the multiple CPU TP groups or should I create a separate PR?

Yes, that's also in my plan. I will break #4460 down into small pieces to be merged, ETA this week.

@andoorve
Contributor Author

andoorve commented May 1, 2024

Sounds good - I'll revert the PyNCCL changes on this PR and wait for that to be merged before adding them back in.

@GindaChen
Contributor

GindaChen commented May 1, 2024

Hey @andoorve - This is super exciting!

I'm trying to run a simple example with PP = 2, but encountered an error at runtime. I wrote my own example based on the simple script examples/offline_inference.py and added pipeline_parallel_size=2 to the arguments.

- llm = LLM(model="facebook/opt-125m", load_format="dummy")
+ llm = LLM(model="facebook/opt-2.7b", pipeline_parallel_size=2, load_format="dummy")

This is the error I hit: error.txt. It seems to be complaining that a kv_caches list item was not found (the list is probably empty?)

ERROR 05-01 20:45:18 worker_base.py:147] Traceback (most recent call last):
ERROR 05-01 20:45:18 worker_base.py:147]   File "/workspace/vllm/vllm/worker/worker_base.py", line 139, in execute_method
ERROR 05-01 20:45:18 worker_base.py:147]     return executor(*args, **kwargs)
ERROR 05-01 20:45:18 worker_base.py:147]   File "/workspace/vllm-pp-venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
ERROR 05-01 20:45:18 worker_base.py:147]     return func(*args, **kwargs)
ERROR 05-01 20:45:18 worker_base.py:147]   File "/workspace/vllm/vllm/worker/worker.py", line 140, in determine_num_available_blocks
ERROR 05-01 20:45:18 worker_base.py:147]     self.model_runner.profile_run()
ERROR 05-01 20:45:18 worker_base.py:147]   File "/workspace/vllm-pp-venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
ERROR 05-01 20:45:18 worker_base.py:147]     return func(*args, **kwargs)
ERROR 05-01 20:45:18 worker_base.py:147]   File "/workspace/vllm/vllm/worker/model_runner.py", line 844, in profile_run
ERROR 05-01 20:45:18 worker_base.py:147]     self.execute_model(seqs, kv_caches)
ERROR 05-01 20:45:18 worker_base.py:147]   File "/workspace/vllm-pp-venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
ERROR 05-01 20:45:18 worker_base.py:147]     return func(*args, **kwargs)
ERROR 05-01 20:45:18 worker_base.py:147]   File "/workspace/vllm/vllm/worker/model_runner.py", line 763, in execute_model
ERROR 05-01 20:45:18 worker_base.py:147]     hidden_states = model_executable(**execute_model_kwargs)
ERROR 05-01 20:45:18 worker_base.py:147]   File "/workspace/vllm-pp-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
ERROR 05-01 20:45:18 worker_base.py:147]     return self._call_impl(*args, **kwargs)
ERROR 05-01 20:45:18 worker_base.py:147]   File "/workspace/vllm-pp-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
ERROR 05-01 20:45:18 worker_base.py:147]     return forward_call(*args, **kwargs)
ERROR 05-01 20:45:18 worker_base.py:147]   File "/workspace/vllm/vllm/model_executor/models/opt.py", line 300, in forward
ERROR 05-01 20:45:18 worker_base.py:147]     hidden_states = self.model(input_ids, positions, kv_caches,
ERROR 05-01 20:45:18 worker_base.py:147]   File "/workspace/vllm-pp-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
ERROR 05-01 20:45:18 worker_base.py:147]     return self._call_impl(*args, **kwargs)
ERROR 05-01 20:45:18 worker_base.py:147]   File "/workspace/vllm-pp-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
ERROR 05-01 20:45:18 worker_base.py:147]     return forward_call(*args, **kwargs)
ERROR 05-01 20:45:18 worker_base.py:147]   File "/workspace/vllm/vllm/model_executor/models/opt.py", line 275, in forward
ERROR 05-01 20:45:18 worker_base.py:147]     return self.decoder(input_ids, positions, kv_caches, attn_metadata)
ERROR 05-01 20:45:18 worker_base.py:147]   File "/workspace/vllm-pp-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
ERROR 05-01 20:45:18 worker_base.py:147]     return self._call_impl(*args, **kwargs)
ERROR 05-01 20:45:18 worker_base.py:147]   File "/workspace/vllm-pp-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
ERROR 05-01 20:45:18 worker_base.py:147]     return forward_call(*args, **kwargs)
ERROR 05-01 20:45:18 worker_base.py:147]   File "/workspace/vllm/vllm/model_executor/models/opt.py", line 249, in forward
ERROR 05-01 20:45:18 worker_base.py:147]     hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)
ERROR 05-01 20:45:18 worker_base.py:147] IndexError: list index out of range

I haven't dug into the code deeply enough, and I'm curious what the best way is to test and play around with it. If you can point me to a potential starting point, that would be awesome. Thanks!

@andoorve
Contributor Author

andoorve commented May 1, 2024

Hey @GindaChen, there are a couple of things here:

We haven't added OPT support yet, and the LLMEngine entry point won't work. We're only supporting AsyncLLMEngine right now.

@andoorve
Contributor Author

andoorve commented May 1, 2024

I'd recommend trying the online serving entrypoint with the LLaMa model. That'd be the best way to start playing around with it.

@GindaChen

@youkaichao
Member

@andoorve FYI: pynccl with multiple groups is landed at #4512 .

@youkaichao
Member

Will you also include the change to create the multiple CPU TP groups or should I create a separate PR?

@andoorve please check out #4566 and see if you need anything else.

@andoorve
Contributor Author

andoorve commented May 2, 2024

LGTM - I guess one thing we can add is PP PyNCCL group

@youkaichao
Member

LGTM - I guess one thing we can add is PP PyNCCL group

That's in my plan. Which operation do you need for pp? allreduce? gather? or anything else?

@andoorve
Contributor Author

andoorve commented May 2, 2024

We only need point-to-point: blocking send and blocking recv. It's not critical, though, unless the torch.distributed.* ops don't work well with CUDA graphs.

@SolitaryThinker

Hi @andoorve,

While benchmarking using your PR, I've consistently encountered engine timeouts with smaller models on setups far below total VRAM capacity, which might relate to the issues you've linked (e.g., [Bug]: Engine iteration timed out #4293, #4430, #4135). I'm using commit 9d698fa.

Setup and Reproduction:
Models and Hardware:

  • Llama-2-7b-hf on 2x A100s
  • llama-160m on 2x RTX A4000s:
python -m vllm.entrypoints.openai.api_server --model JackFram/llama-160m \
--swap-space 16 \
--disable-log-requests \
--pipeline-parallel-size 2
python benchmarks/benchmark_serving.py --backend vllm --model JackFram/llama-160m \
--dataset-name sharegpt \
--dataset-path /workspace/sharegpt.json \
--num-prompts 3

Observation:
The engine hangs almost immediately with 3 running prompts; I see similar issues with larger models at a non-infinite --request-rate.

Proposed Solution:

I traced the issue to asyncio.gather(*coros) in ray_gpu_executor.py returning prematurely because it does not block on ray.ObjectRefs. Inserting ray.wait(coros[1:]) before the gather aligns with the intended code semantics and resolves the hanging.

Branch with fix: https://github.com/SolitaryThinker/vllm/tree/pipeline-parallel-fix

I noticed a new commit from you regarding TP+PP fix, but it didn’t resolve the issue in my environment. Could it be due to missing the latest pynccl changes with groups #4512?

This is my first time working with vLLM and Ray, so any insights or corrections on my understanding or approach would be greatly appreciated.

Additional technical details:
After some digging, I realized that asyncio.gather(*coros) is returning before the worker threads have finished. The cause is that coros consists of both futures and ray.ObjectRefs, the latter of which asyncio.gather does not appear to block on. Thus, back in run_engine_loop, the VE that is assumed to have finished executing after this call:

 done, _ = await asyncio.wait(requests_in_progress, return_when=asyncio.FIRST_COMPLETED)

could still have workers running when a new engine_step task for the VE is created. I'm not sure of the exact interaction that causes the hanging, but inserting a ray.wait(coros[1:]) before the gather seems to actually respect the intended semantics of the code: wait for the materialization of the ray.ObjectRefs.
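
For context, a rough sketch of the proposed workaround (the function shape below is hypothetical, not the actual ray_gpu_executor.py code):

import asyncio
import ray

async def _run_workers_async(driver_coro, workers, method, *args, **kwargs):
    # coros[0] is the local driver coroutine; coros[1:] are ray.ObjectRefs
    # returned by the remote workers.
    coros = [driver_coro]
    coros += [w.execute_method.remote(method, *args, **kwargs) for w in workers]

    # Proposed fix: block until the remote ObjectRefs have materialized,
    # since asyncio.gather alone did not appear to wait on them here.
    ray.wait(coros[1:], num_returns=len(coros) - 1)

    return await asyncio.gather(*coros)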

Thanks
-will

@andoorve
Contributor Author

andoorve commented May 6, 2024

@SolitaryThinker

Thanks for the thorough investigation and the fix!

It's indeed true that there are existing issues with hanging on the current vLLM mainline, and I have not rebased on the latest PyNCCL changes yet. I'm also unable to reproduce this issue easily with GPT2 in my own testing. For these reasons I haven't investigated as deeply yet. I'll give your setup and fix a try once I check whether multi-node is functional.

I wonder if this is a similar reason why the TP-only cases hang in the issues mentioned above, since there is no such ray.wait in that situation either. In the meantime, @rkooo567, maybe you have some comments?

@youkaichao
Member

FYI: I recently found that the cleanup logic is prone to hanging, and this is "fixed" in #4508.

@andoorve
Contributor Author

andoorve commented May 6, 2024

@SolitaryThinker I tried the model/commands above that are giving you issues. I was unable to reproduce on my setup.

My Setup

Started a fresh instance with the following:

GCP g2-standard-48 (4 x NVIDIA L4)
Image: Google Deep Learning VM (M120, Debian 11, Python 3.10) with CUDA 12.1 preinstalled.
vLLM install @ 04b5fe9

Experiments

Started vLLM with

python -m vllm.entrypoints.openai.api_server --model JackFram/llama-160m \
--swap-space 16 \
--disable-log-requests \
--pipeline-parallel-size 2

Ran the below 3 times:

python benchmarks/benchmark_serving.py --backend vllm --model JackFram/llama-160m \
--dataset-name sharegpt \
--dataset-path ~/sharegpt.json \
--num-prompts 3

Killed vLLM server then repeated the above experiment 2 more times for a total of 3 separate serving instances, 9 benchmark tries, and 27 total requests sent.

Saw the expected benchmark results each time:

Traffic request rate: inf
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:06<00:00,  2.18s/it]
============ Serving Benchmark Result ============
Successful requests:                     3
Benchmark duration (s):                  6.55
Total input tokens:                      72
Total generated tokens:                  1380
Request throughput (req/s):              0.46
Input token throughput (tok/s):          10.99
Output token throughput (tok/s):         210.70
---------------Time to First Token----------------
Mean TTFT (ms):                          29.55
Median TTFT (ms):                        27.27
P99 TTFT (ms):                           34.69
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          7.42
Median TPOT (ms):                        7.77
P99 TPOT (ms):                           7.80

I wonder if it might only be reproducible on other instances... needs further investigation though.

@zhengxingmao

zhengxingmao commented May 7, 2024

A very useful feature!
Hi @andoorve, I ran a verification based on your PR; the service starts normally, but an error occurs when processing requests.
My env:
RTX-4090 2 nodes
vLLM install @ 04b5fe9

Here is the command:

python3 -m vllm.entrypoints.openai.api_server --trust-remote-code --model /data/llvm/llama_weight --gpu-memory-utilization 0.60 --pipeline-parallel-size 2 --port 8000 --host 0.0.0.0 --enforce-eager

And here is the error stack:

ERROR 05-07 16:55:03 async_llm_engine.py:43] Engine background task failed
ERROR 05-07 16:55:03 async_llm_engine.py:43] Traceback (most recent call last):
ERROR 05-07 16:55:03 async_llm_engine.py:43]   File "python/ray/_raylet.pyx", line 902, in ray._raylet.prepare_args_internal
ERROR 05-07 16:55:03 async_llm_engine.py:43]   File "/usr/local/lib/python3.10/dist-packages/ray/_private/serialization.py", line 494, in serialize
ERROR 05-07 16:55:03 async_llm_engine.py:43]     return self._serialize_to_msgpack(value)
ERROR 05-07 16:55:03 async_llm_engine.py:43]   File "/usr/local/lib/python3.10/dist-packages/ray/_private/serialization.py", line 472, in _serialize_to_msgpack
ERROR 05-07 16:55:03 async_llm_engine.py:43]     pickle5_serialized_object = self._serialize_to_pickle5(
ERROR 05-07 16:55:03 async_llm_engine.py:43]   File "/usr/local/lib/python3.10/dist-packages/ray/_private/serialization.py", line 425, in _serialize_to_pickle5
ERROR 05-07 16:55:03 async_llm_engine.py:43]     raise e
ERROR 05-07 16:55:03 async_llm_engine.py:43]   File "/usr/local/lib/python3.10/dist-packages/ray/_private/serialization.py", line 420, in _serialize_to_pickle5
ERROR 05-07 16:55:03 async_llm_engine.py:43]     inband = pickle.dumps(
ERROR 05-07 16:55:03 async_llm_engine.py:43]   File "/usr/local/lib/python3.10/dist-packages/ray/cloudpickle/cloudpickle_fast.py", line 88, in dumps
ERROR 05-07 16:55:03 async_llm_engine.py:43]     cp.dump(obj)
ERROR 05-07 16:55:03 async_llm_engine.py:43]   File "/usr/local/lib/python3.10/dist-packages/ray/cloudpickle/cloudpickle_fast.py", line 733, in dump
ERROR 05-07 16:55:03 async_llm_engine.py:43]     return Pickler.dump(self, obj)
ERROR 05-07 16:55:03 async_llm_engine.py:43] TypeError: cannot pickle 'torch._C.Generator' object

@andoorve
Contributor Author

andoorve commented May 7, 2024

@zhengxingmao Thanks for reporting this! Does this happen without PP? If not, I think it could be some interaction with the following flags with PP.
--trust-remote-code --model /data/llvm/llama_weight --gpu-memory-utilization 0.60

Can you try without these flags and use a model directly from HF? (LLaMa)

@andoorve
Contributor Author

andoorve commented May 7, 2024

@SolitaryThinker

I did some investigation into what you were saying. I think there are real hangs. I tried LLaMa 3 8B with an effectively infinite request rate on 2 L4s and saw hangs; not sure if this is the same situation you ran into. Strangely, if I sent a warm-up request first, the hang went away.

The ray.wait solution doesn't help, and it's not intended for async contexts. See here https://docs.ray.io/en/latest/ray-core/api/doc/ray.wait.html:

This method will issue a warning if it’s running inside an async context. Instead of ray.wait(ray_waitables), you can use await asyncio.wait(ray_waitables).

Also from here, asyncio methods such as asyncio.wait and asyncio.gather should be sufficient:
https://docs.ray.io/en/latest/ray-core/actors/async_api.html
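
For reference, a minimal standalone example (toy actor for illustration only, not vLLM code) of awaiting Ray ObjectRefs from an async context as the docs suggest:

import asyncio
import ray

@ray.remote
class Worker:
    def execute_method(self, x: int) -> int:
        return 2 * x

async def run_all(workers, x: int):
    # ObjectRefs returned by .remote() are awaitable, so asyncio.gather
    # (or asyncio.wait) blocks on them correctly inside an async context.
    refs = [w.execute_method.remote(x) for w in workers]
    return await asyncio.gather(*refs)

# Usage sketch:
# ray.init()
# workers = [Worker.remote() for _ in range(2)]
# print(asyncio.run(run_all(workers, 21)))  # [42, 42]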

I resolved a hang on my end with:
df9b0c4

Maybe this helps for you?

@andoorve andoorve marked this pull request as ready for review May 7, 2024 20:46
@andakai

andakai commented May 13, 2024

Hi, @andoorve. I tried the code on the pipeline-parallel branch of https://github.com/andoorve/vllm/tree/pipeline-parallel.

I ran the server on 2 x 4090:

export CUDA_VISIBLE_DEVICES=0,1
python -m vllm.entrypoints.openai.api_server    \
        --model /home/models/Llama-2-7b-hf  \
        --dtype auto    \
        --pipeline-parallel-size 2    \
        --disable-log-requests

The server runs successfully. Then I ran the client as:

python /home/vllm-pipeline-parallel/my_benchmark/benchmark_serving.py \
        --backend vllm \
        --model /home/models/Llama-2-7b-hf \
        --tokenizer /home/models/Llama-2-7b-hf  \
        --request-rate inf  \
        --trust-remote-code  \
        --num-prompts=2 \
        --host localhost    \
        --port 8000 \
        --endpoint /v1/completions  \
        --input-len 512 \
        --output-len 64    # the prompt is randomly generated

And the error occurs:

INFO:     127.0.0.1:40278 - "POST /v1/completions HTTP/1.1" 200 OK                                                                                                                                                                                                                                                                                                                                                        
INFO:     127.0.0.1:40294 - "POST /v1/completions HTTP/1.1" 200 OK                                                                                                                                                                                                                                                                                                                                                        
ERROR 05-13 14:57:34 async_llm_engine.py:43] Engine background task failed                                                                                                                                                                                                                                                                                                                                                
ERROR 05-13 14:57:34 async_llm_engine.py:43] Traceback (most recent call last):                                                                                                                                                                                                                                                                                                                                           
ERROR 05-13 14:57:34 async_llm_engine.py:43]   File "/home/vllm-pipeline-parallel/vllm/engine/async_llm_engine.py", line 38, in _raise_exception_on_finish                                                                                                                                                                                                                                                        
ERROR 05-13 14:57:34 async_llm_engine.py:43]     task.result()                                                                                                                                                                                                                                                                                                                                                            
ERROR 05-13 14:57:34 async_llm_engine.py:43]   File "/home/vllm-pipeline-parallel/vllm/engine/async_llm_engine.py", line 513, in run_engine_loop                                                                                                                                                                                                                                                                  
ERROR 05-13 14:57:34 async_llm_engine.py:43]     result = task.result()                                                                                                                                                                                                                                                                                                                                                   
ERROR 05-13 14:57:34 async_llm_engine.py:43]   File "/home/anaconda/anaconda3/envs/vllm_pp/lib/python3.9/asyncio/tasks.py", line 479, in wait_for                                                                                                                                                                                                                                                                 
ERROR 05-13 14:57:34 async_llm_engine.py:43]     return fut.result()                                                                                                                                                                                                                                                                                                                                                      
ERROR 05-13 14:57:34 async_llm_engine.py:43]   File "/home/vllm-pipeline-parallel/vllm/engine/async_llm_engine.py", line 472, in engine_step                                                                                                                                                                                                                                                                      
ERROR 05-13 14:57:34 async_llm_engine.py:43]     request_outputs = await self.engine.step_async(virtual_engine)                                                                                                                                                                                                                                                                                                           
ERROR 05-13 14:57:34 async_llm_engine.py:43]   File "/home/vllm-pipeline-parallel/vllm/engine/async_llm_engine.py", line 214, in step_async                                                                                                                                                                                                                                                                       
ERROR 05-13 14:57:34 async_llm_engine.py:43]     output = await self.model_executor.execute_model_async(                                                                                                                                                                                                                                     
ERROR 05-13 14:57:34 async_llm_engine.py:43]   File "/home/vllm-pipeline-parallel/vllm/executor/distributed_gpu_executor.py", line 110, in execute_model_async                                       
ERROR 05-13 14:57:34 async_llm_engine.py:43]     all_outputs = await self._run_workers_async("execute_model",                                                                                                                                                                                                                                
ERROR 05-13 14:57:34 async_llm_engine.py:43]   File "/home/vllm-pipeline-parallel/vllm/executor/ray_gpu_executor.py", line 349, in _run_workers_async                                                                                                                                                                                
ERROR 05-13 14:57:34 async_llm_engine.py:43]     async with self.pp_locks[pp_rank]:                                                                                                                          
ERROR 05-13 14:57:34 async_llm_engine.py:43]   File "/home/anaconda/anaconda3/envs/vllm_pp/lib/python3.9/asyncio/locks.py", line 14, in __aenter__                                                   
ERROR 05-13 14:57:34 async_llm_engine.py:43]     await self.acquire()                                                                                                                                        
ERROR 05-13 14:57:34 async_llm_engine.py:43]   File "/home/anaconda/anaconda3/envs/vllm_pp/lib/python3.9/asyncio/locks.py", line 120, in acquire                                                     
ERROR 05-13 14:57:34 async_llm_engine.py:43]     await fut                                            
ERROR 05-13 14:57:34 async_llm_engine.py:43] RuntimeError: Task <Task pending name='Task-816' coro=<AsyncLLMEngine.engine_step() running at /home/vllm-pipeline-parallel/vllm/engine/async_llm_engine.py:472> cb=[_release_waiter(<Future pendi...7958b2820>()]>)() at /home/anaconda/anaconda3/envs/vllm_pp/lib/python3.9/asyncio/tasks.py:416]> got Future <Future pending> attached to a different loop
Exception in callback functools.partial(<function _raise_exception_on_finish at 0x7fd7752c09d0>, error_callback=<bound method AsyncLLMEngine._error_callback of <vllm.engine.async_llm_engine.AsyncLLMEngine object at 0x7fd817218220>>)                                                                                                                                                                                  
handle: <Handle functools.partial(<function _raise_exception_on_finish at 0x7fd7752c09d0>, error_callback=<bound method AsyncLLMEngine._error_callback of <vllm.engine.async_llm_engine.AsyncLLMEngine object at 0x7fd817218220>>)>                                                                                                                                                                                       
Traceback (most recent call last):                                                                    
  File "/home/vllm-pipeline-parallel/vllm/engine/async_llm_engine.py", line 38, in _raise_exception_on_finish                                                                                        
    task.result()                                                                                     
  File "/home/vllm-pipeline-parallel/vllm/engine/async_llm_engine.py", line 513, in run_engine_loop                                                                                                  
    result = task.result()                                                                            
  File "/home/anaconda/anaconda3/envs/vllm_pp/lib/python3.9/asyncio/tasks.py", line 479, in wait_for                                                                                                 
    return fut.result()                                                                               
  File "/home/vllm-pipeline-parallel/vllm/engine/async_llm_engine.py", line 472, in engine_step                                                                                                      
    request_outputs = await self.engine.step_async(virtual_engine)                                                                                                                                           
  File "/home/vllm-pipeline-parallel/vllm/engine/async_llm_engine.py", line 214, in step_async                                                                                                       
    output = await self.model_executor.execute_model_async(                                           
  File "/home/vllm-pipeline-parallel/vllm/executor/distributed_gpu_executor.py", line 110, in execute_model_async                                                                                    
    all_outputs = await self._run_workers_async("execute_model",                                                                                                                                             
  File "/home/vllm-pipeline-parallel/vllm/executor/ray_gpu_executor.py", line 349, in _run_workers_async                                                                                             
    async with self.pp_locks[pp_rank]:                                                                
  File "/home/anaconda/anaconda3/envs/vllm_pp/lib/python3.9/asyncio/locks.py", line 14, in __aenter__                                                                                                
    await self.acquire()                                                                              
  File "/home/anaconda/anaconda3/envs/vllm_pp/lib/python3.9/asyncio/locks.py", line 120, in acquire                                                                                                  
    await fut                                                                                         
RuntimeError: Task <Task pending name='Task-816' coro=<AsyncLLMEngine.engine_step() running at /home/vllm-pipeline-parallel/vllm/engine/async_llm_engine.py:472> cb=[_release_waiter(<Future pendi...7958b2820>()]>)() at /home/anaconda/anaconda3/envs/vllm_pp/lib/python3.9/asyncio/tasks.py:416]> got Future <Future pending> attached to a different loop

The above exception was the direct cause of the following exception:

Traceback (most recent call last):                                                                                                                                                                                                                                                                                                                                                                                        
  File "uvloop/cbhandles.pyx", line 63, in uvloop.loop.Handle._run                                                                                                                                                                                                                                                                                                                                                        
  File "/home/vllm-pipeline-parallel/vllm/engine/async_llm_engine.py", line 45, in _raise_exception_on_finish                                                                                                                                                                                                                                                                                                     
    raise AsyncEngineDeadError(                                                                                                                                                                                                                                                                                                                                                                                           
vllm.engine.async_llm_engine.AsyncEngineDeadError: Task finished unexpectedly. This should never happen! Please open an issue on Github. See stack trace above for the actual cause.                                                                                                                                                                                                                                      
ERROR:    Exception in ASGI application

And the log below shows that the run hangs. I am not sure if this is the same problem @SolitaryThinker hit.

INFO 05-13 15:12:01 metrics.py:229] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 0.0 tokens/s, Running: 2 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 1.8%, CPU KV cache usage: 0.0%

By the way, if I change num-prompts to 1, it serves successfully. The tp=2 configuration works well too.

@andoorve
Contributor Author

Hey @darrenglow

Thanks for trying it out. I tried as well and was not able to reproduce on 2x L4. Can you try it with Python 3.10 if possible?

https://stackoverflow.com/questions/55918048/asyncio-semaphore-runtimeerror-task-got-future-attached-to-a-different-loop

@SolitaryThinker

SolitaryThinker commented May 13, 2024

Hi

@andoorve, thanks for the feedback. I was also having the same issue as @darrenglow, and using Python 3.10 fixed both the hanging and the crashing I was experiencing.

However, vLLM hangs during CUDA graph capture when enabling PP+TP together without the --enforce-eager flag.

Example command and output:

python -m vllm.entrypoints.openai.api_server   \ 
        --model meta-llama/Llama-2-13b-hf \
        --swap-space 16   \
        --disable-log-requests   \  
        --tensor-parallel-size 2 \
        --pipeline-parallel-size 2

Output:

INFO 05-13 17:53:55 api_server.py:151] vLLM API server version 0.4.1
INFO 05-13 17:53:55 api_server.py:152] args: Namespace(host=None, port=8000, uvicorn_log_level='info', allow_credentials=False, allowed_origins=['*'], allowed_methods=['*'], allowed_headers=['*'], api_key=None, served_model_name=None, lora_modules=None, chat_template=None, response_role='assistant', ssl_keyfile=None, ssl_certfile=None, ssl_ca_certs=None, ssl_cert_reqs=0, root_path=None, middleware=[], model='meta-llama/Llama-2-13b-hf', tokenizer=None, skip_tokenizer_init=False, revision=None, code_revision=None, tokenizer_revision=None, tokenizer_mode='auto', trust_remote_code=False, download_dir=None, load_format='auto', dtype='auto', kv_cache_dtype='auto', quantization_param_path=None, max_model_len=None, guided_decoding_backend='outlines', worker_use_ray=False, pipeline_parallel_size=2, tensor_parallel_size=2, max_parallel_loading_workers=None, ray_workers_use_nsight=False, block_size=16, enable_prefix_caching=False, use_v2_block_manager=False, num_lookahead_slots=0, seed=0, swap_space=16, gpu_memory_utilization=0.9, num_gpu_blocks_override=None, max_num_batched_tokens=None, max_num_seqs=256, max_logprobs=5, disable_log_stats=False, quantization=None, enforce_eager=False, max_context_len_to_capture=8192, disable_custom_all_reduce=False, tokenizer_pool_size=0, tokenizer_pool_type='ray', tokenizer_pool_extra_config=None, enable_lora=False, max_loras=1, max_lora_rank=16, lora_extra_vocab_size=256, lora_dtype='auto', max_cpu_loras=None, fully_sharded_loras=False, device='auto', image_input_type=None, image_token_id=None, image_input_shape=None, image_feature_size=None, scheduler_delay_factor=0.0, enable_chunked_prefill=False, speculative_model=None, num_speculative_tokens=None, speculative_max_model_len=None, model_loader_extra_config=None, engine_use_ray=False, disable_log_requests=True, max_log_len=None)
/home/haozhang/anaconda3/bin/vllm10/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
INFO 05-13 17:53:56 config.py:549] Disabled the custom all-reduce kernel because it is not supported with pipeline parallelism.
2024-05-13 17:53:58,773	INFO worker.py:1749 -- Started a local Ray instance.
INFO 05-13 17:54:00 llm_engine.py:98] Initializing an LLM engine (v0.4.1) with config: model='meta-llama/Llama-2-13b-hf', speculative_config=None, tokenizer='meta-llama/Llama-2-13b-hf', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=4096, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=2, pipeline_parallel_size=2, disable_custom_all_reduce=True, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), seed=0)
INFO 05-13 17:54:10 utils.py:613] Found nccl from library /home/haozhang/.config/vllm/nccl/cu12/libnccl.so.2.18.1
(RayWorkerWrapper pid=1346071) INFO 05-13 17:54:10 utils.py:613] Found nccl from library /home/haozhang/.config/vllm/nccl/cu12/libnccl.so.2.18.1
(RayWorkerWrapper pid=1346450) INFO 05-13 17:54:10 utils.py:613] Found nccl from library /home/haozhang/.config/vllm/nccl/cu12/libnccl.so.2.18.1
(RayWorkerWrapper pid=1346571) INFO 05-13 17:54:10 utils.py:613] Found nccl from library /home/haozhang/.config/vllm/nccl/cu12/libnccl.so.2.18.1
INFO 05-13 17:54:11 selector.py:28] Using FlashAttention-2 backend.
(RayWorkerWrapper pid=1346071) INFO 05-13 17:54:11 selector.py:28] Using FlashAttention-2 backend.
(RayWorkerWrapper pid=1346450) INFO 05-13 17:54:11 selector.py:28] Using FlashAttention-2 backend.
(RayWorkerWrapper pid=1346571) INFO 05-13 17:54:11 selector.py:28] Using FlashAttention-2 backend.
INFO 05-13 17:54:13 pynccl_utils.py:43] vLLM is using nccl==2.18.1
(RayWorkerWrapper pid=1346071) INFO 05-13 17:54:13 pynccl_utils.py:43] vLLM is using nccl==2.18.1
(RayWorkerWrapper pid=1346450) INFO 05-13 17:54:13 pynccl_utils.py:43] vLLM is using nccl==2.18.1
(RayWorkerWrapper pid=1346571) INFO 05-13 17:54:13 pynccl_utils.py:43] vLLM is using nccl==2.18.1
INFO 05-13 17:54:17 weight_utils.py:193] Using model weights format ['*.safetensors']
(RayWorkerWrapper pid=1346450) INFO 05-13 17:54:17 weight_utils.py:193] Using model weights format ['*.safetensors']
(RayWorkerWrapper pid=1346571) INFO 05-13 17:54:17 weight_utils.py:193] Using model weights format ['*.safetensors']
(RayWorkerWrapper pid=1346071) INFO 05-13 17:54:18 weight_utils.py:193] Using model weights format ['*.safetensors']
INFO 05-13 17:54:18 model_runner.py:178] Loading model weights took 6.2831 GB
(RayWorkerWrapper pid=1346450) INFO 05-13 17:54:20 model_runner.py:178] Loading model weights took 6.2831 GB
(RayWorkerWrapper pid=1346571) INFO 05-13 17:54:20 model_runner.py:178] Loading model weights took 6.2831 GB
(RayWorkerWrapper pid=1346071) INFO 05-13 17:54:21 model_runner.py:178] Loading model weights took 6.2831 GB
(RayWorkerWrapper pid=1346450) [rank2]:[W ProcessGroupNCCL.cpp:2291] Warning: TORCH_NCCL_AVOID_RECORD_STREAMS=1 has no effect for point-to-point collectives. (function operator())
(RayWorkerWrapper pid=1346571) [rank3]:[W ProcessGroupNCCL.cpp:2291] Warning: TORCH_NCCL_AVOID_RECORD_STREAMS=1 has no effect for point-to-point collectives. (function operator())
[rank0]:[W ProcessGroupNCCL.cpp:2291] Warning: TORCH_NCCL_AVOID_RECORD_STREAMS=1 has no effect for point-to-point collectives. (function operator())
(RayWorkerWrapper pid=1346071) [rank1]:[W ProcessGroupNCCL.cpp:2291] Warning: TORCH_NCCL_AVOID_RECORD_STREAMS=1 has no effect for point-to-point collectives. (function operator())
INFO 05-13 17:54:23 distributed_gpu_executor.py:46] # GPU blocks: 8586, # CPU blocks: 5242
INFO 05-13 17:54:29 model_runner.py:888] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
INFO 05-13 17:54:29 model_runner.py:892] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
(RayWorkerWrapper pid=1346071) INFO 05-13 17:54:30 model_runner.py:888] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
(RayWorkerWrapper pid=1346071) INFO 05-13 17:54:30 model_runner.py:892] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
(RayWorkerWrapper pid=1346450) INFO 05-13 17:54:30 model_runner.py:888] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
(RayWorkerWrapper pid=1346450) INFO 05-13 17:54:30 model_runner.py:892] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
(RayWorkerWrapper pid=1346571) INFO 05-13 17:54:30 model_runner.py:888] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
(RayWorkerWrapper pid=1346571) INFO 05-13 17:54:30 model_runner.py:892] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.

Thanks
-will

@andakai

andakai commented May 14, 2024

Hey @darrenglow

Thanks for trying it out. I tried as well and was not able to reproduce on 2x L4. Can you try it with Python 3.10 if possible?

https://stackoverflow.com/questions/55918048/asyncio-semaphore-runtimeerror-task-got-future-attached-to-a-different-loop

Thanks for your reply. Switching to Python 3.10 did solve the problem. Now I'm also hitting the same problem @SolitaryThinker points out:

However, vLLM hangs during CUDA graph capture when enabling PP+TP together without the --enforce-eager flag.

@andoorve
Contributor Author

Hi @SolitaryThinker, @darrenglow

Thanks for your comments. My best guess right now is that it's related to this PR not having been rebased on the latest distributed/PyNCCL changes, which we need. However, CUDA graphs are tricky to work with in general, so we can't be sure.

Currently I'm waiting on reviews before rebasing, at which point we can try again. We may still merge without CUDA graph support for TP + PP, especially since chunked prefill runs eagerly, so that we have something functional in mainline as soon as possible; this is TBA though.

    coros.append(
        worker.execute_method.remote(
            method, *args, **kwargs))
all_outputs = await asyncio.gather(*coros)
Collaborator

I am a bit confused by this loop. Please help with my understanding, thanks! @andoorve

Say we have pp_size=2 and tp_size=1; we'll iterate over the two ranks from 0 to 1 (the outer loop; the inner loop has a single iteration).

On pp_rank=0, we will launch the execution of the first stage on its corresponding GPU (L355) and await its completion (L368). Once L368 returns, we proceed to rank 1. The part I am confused about: if, in rank 0's execution, we launch an NCCL send, this send won't return unless rank 1 launches its corresponding recv, right? In that case, the await at L368 will never return and hence blocks the loop from launching the corresponding recv?

Please correct me if I am wrong; I appreciate your help, @andoorve!

Collaborator

@zhisbug zhisbug May 17, 2024

Follow-up:

I debugged a bit and did a simple experiment:

I filed a single request to a model served with pp=2:

  • On rank=0, I do not change any model code and just launch the NCCL send as your code does.
  • On rank=1, I change your code to not launch a recv (using fake values for hidden_states instead).

I added a few prints in your code and found that the send on rank=0 still passes! Given that your send_to_next_rank and recv_from_prev_rank are indeed using the synchronous send/recv API, this is extremely suspicious: the current code might have messed up some send/recv pairings...

Contributor Author

Hi @zhisbug

Thanks for your thorough investigation! I was also initially puzzled by this and did not expect this code to work: send/recv should be blocking, so we should hang while waiting at the L368 await. At the time, I thought this might be due to a Ray quirk, i.e. returning early somehow.

At the time I dismissed it because:

a) The output we get is correct.
b) When I printed the sent/recv'd tensors, those appeared to be correct.
c) When I checked the trace with nsys, I saw send and recv matching up there.

However, what you are saying is true - I do see the send ending before the recv begins when I print timestamps. It needs a more in-depth look.

I tried your debugging method, but I do see hangs when I do it this way, which is what we would expect:

  • On rank=0, I do not change any model code and just launch the NCCL send as your code does.
  • On rank=1, I change your code to not launch a recv (using fake values for hidden_states instead).

May I know your exact modifications?

Collaborator

If you change your LLaMa forward function to the following, I let the first profile_run complete a full send/recv by special-casing on the shape (in fact, I also found the profile run won't go through your RayGPUExecutor code), but make the following send/recv only partial: only send, no recv.

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if is_pipeline_model_parallel_first_rank():
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            if inputs_embeds is not None:
                sizes = list(inputs_embeds.size())
            else:
                sizes = list(input_ids.size()) + [self.config.hidden_size]
            print(f"{sizes}")
            if sizes[0] == 2048:
                hidden_states, residual = recv_prev_rank(
                    2, torch.Size(sizes), self.embed_tokens.weight.dtype,
                    self.embed_tokens.weight.device)
            else:
                if inputs_embeds is not None:
                    hidden_states = inputs_embeds
                else:
                    hidden_states = self.get_input_embeddings(input_ids)
                residual = None

        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i - self.start_layer],
                attn_metadata,
                residual,
            )
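        # (remainder of the original forward is omitted from this snippet)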

Then you can launch the openai server and send one prompt using the following two commands:

CUDA_LAUNCH_BLOCKING=1 python -m vllm.entrypoints.openai.api_server --model JackFram/llama-160m  --swap-space 16 --pipeline-parallel-size 2 --enforce-eager
python benchmark_serving.py --backend vllm --model JackFram/llama-160m --dataset-name sharegpt --dataset-path ~/sharegpt.json --num-prompts 1 --sharegpt-output-len 5

If you add a few prints at the end of the send API:

def send_next_rank(tensors: List[torch.Tensor]) -> None:
    """Send the tensors to the next pipeline model parallel rank."""
    print(f"global rank {torch.distributed.get_rank()} sending...")
    combined_tensor = torch.cat(tensors, dim=0)
    torch.distributed.send(combined_tensor,
                           get_pipeline_model_parallel_next_rank(),
                           get_pipeline_model_parallel_group())
    print(f"global rank {torch.distributed.get_rank()} sent done")

You will observe that the first send passes even though I do not launch a recv for it! After the first send/recv, the server will hang at the line where I print "global rank {torch.distributed.get_rank()} sending...".

This is extremely strange because we do not launch a corresponding recv for the first send; how could it pass?

Contributor Author

Ok, I think I have a good idea of what is going on here, and why we are still seeing the correct output even though send and recv are not operating as we expect; it's harder to figure out why that's happening.

My hypothesis is that torch.distributed.send is enqueuing the correct send operation on the relevant CUDA stream. However, instead of blocking until that NCCL operation completes, it returns immediately from the Python function, i.e. it is operating similarly to the non-blocking isend function, which is not what we expect.

This is why you see the first send "go through" (it is not actually sending, just returning from the Python function) and block on the second send.

I modified communication_op.py to include some print statements and time.sleeps.

def send_next_rank(tensors: List[torch.Tensor]) -> None:
    """Send the tensors to the next pipeline model parallel rank."""
    combined_tensor = torch.cat(tensors, dim=0)
    print (f'SEND STARTING {time.time()}', flush=True)
    torch.distributed.send(combined_tensor,
                           get_pipeline_model_parallel_next_rank(),
                           get_pipeline_model_parallel_group())
    print(f'SEND SUM: {combined_tensor.sum()}', flush=True)
    print (f'SEND COMPLETED {time.time()}', flush=True)
    time.sleep(5)


def recv_prev_rank(num_tensors: int, sizes: torch.Size, dtype: torch.dtype,
                   device: torch.device) -> List[torch.Tensor]:
    """Receive tensors from the previous pipeline model parallel rank."""
    sizes = list(sizes)
    combined_tensor = torch.empty([sizes[0] * num_tensors] + sizes[1:],
                                  dtype=dtype,
                                  device=device)
    time.sleep(5)
    print (f'RECV STARTING {time.time()}', flush=True)
    torch.distributed.recv(combined_tensor,
                           get_pipeline_model_parallel_prev_rank(),
                           get_pipeline_model_parallel_group())
    print(f'RECV SUM: {combined_tensor.sum()}', flush=True)
    print (f'RECV COMPLETED {time.time()}', flush=True)
    return torch.chunk(combined_tensor, num_tensors, dim=0)

This gives the following output (I also include print statements for entering and exiting an outer loop iteration in ray_gpu_executor.py).

PP RANK 0 STARTED! at 1715983628.996715
SEND STARTING 1715983629.1642857
SEND SUM: -248.0
SEND COMPLETED 1715983629.1664162
PP RANK 0 DONE! at 1715983634.1697052
PP RANK 1 STARTED! at 1715983634.169776
(RayWorkerWrapper pid=2335407) RECV STARTING 1715983639.196885
(RayWorkerWrapper pid=2335407) RECV SUM: -248.0
(RayWorkerWrapper pid=2335407) RECV COMPLETED 1715983639.1983254
PP RANK 1 DONE! at 1715983639.2434692

Here, the time in the Python portion of the send function is completely disjoint from the time spent in the recv portion (and PP rank 0 is done before the subsequent RECV is started). However, the checksums are consistent. This is consistent with what you saw in your experiment, where the second send was blocked by the first (since there was no matching recv for the first send) even though the Python function completed.

I tried to gather an nsys trace here for one more piece of evidence, but unfortunately was not able to get both GPU traces for some reason.

Of course, although in this case this behaviour seems to work in our favour, it is not at all consistent with what we expect of torch.distributed.send and torch.distributed.recv, which is to block the Python thread.
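
One way to test this hypothesis (a sketch, not something from the PR): force a device synchronization right after the send; if dist.send merely enqueues the NCCL kernel, the synchronize should be what actually blocks until the peer posts its recv.

import torch
import torch.distributed as dist

def send_and_force_sync(tensor: torch.Tensor, dst: int) -> None:
    dist.send(tensor, dst)
    # If the hypothesis above is right, this is where the Python thread
    # actually waits for the matching recv on the peer rank.
    torch.cuda.synchronize()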

Contributor Author

To continue, I also tried a small unit test, and here the semantics are respected. That is, send blocks its process until the recv is completed.

import time
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def run(rank, size):
    """ Distributed function to be implemented later. """
    device = torch.device(f"cuda:{rank}")
    if rank == 0:
        tens = torch.ones([33, 4096], dtype=torch.bfloat16, device=device)
        dist.send(tens, dst=1)
        print(f'SEND COMPLETED {time.ctime()}', flush=True)
    else:
        tens = torch.empty([33, 4096], dtype=torch.bfloat16, device=device)
        time.sleep(10)
        print(f'ABOUT TO BEGIN RECV {time.ctime()}', flush=True)
        dist.recv(tens, src=0)
        print (f'{tens}', flush=True)

def init_process(rank, size, fn, backend='nccl'):
    """ Initialize the distributed environment. """
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29501'
    dist.init_process_group(backend, rank=rank, world_size=size)
    fn(rank, size)


if __name__ == "__main__":
    size = 2
    processes = []
    mp.set_start_method("spawn")
    for rank in range(size):
        p = mp.Process(target=init_process, args=(rank, size, run, 'nccl'))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()

I have been trying to see if there are some env variables that perhaps Ray or vLLM might have set to change PyTorch's behaviour, but I have not been successful so far.

@andoorve
Contributor Author

@SolitaryThinker @andakai

This is rebased now; you can try it.

@SolitaryThinker

SolitaryThinker commented May 22, 2024

@andoorve thank you for rebasing!

I am still seeing the same error from CUDA graph capture when using both PP and PP+TP. What is your PyTorch version? I did a clean install using pip install -e .
I've included my environment info below.

PyTorch version: 2.3.0+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.29.3
Libc version: glibc-2.35

Python version: 3.10.14 (main, May  6 2024, 19:42:50) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.4.0-167-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.1.105
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA RTX A5000
GPU 1: NVIDIA RTX A5000
GPU 2: NVIDIA RTX A5000
GPU 3: NVIDIA RTX A5000

Nvidia driver version: 535.129.03
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Address sizes:                      52 bits physical, 57 bits virtual
Byte Order:                         Little Endian
CPU(s):                             96
On-line CPU(s) list:                0-95
Vendor ID:                          GenuineIntel
Model name:                         Intel(R) Xeon(R) Gold 6342 CPU @ 2.80GHz
CPU family:                         6
Model:                              106
Thread(s) per core:                 2
Core(s) per socket:                 24
Socket(s):                          2
Stepping:                           6
Frequency boost:                    enabled
CPU max MHz:                        2801.0000
CPU min MHz:                        800.0000
BogoMIPS:                           5600.00
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 invpcid_single ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local wbnoinvd dtherm ida arat pln pts avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq rdpid md_clear pconfig flush_l1d arch_capabilities
Virtualization:                     VT-x
L1d cache:                          2.3 MiB (48 instances)
L1i cache:                          1.5 MiB (48 instances)
L2 cache:                           60 MiB (48 instances)
L3 cache:                           72 MiB (2 instances)
NUMA node(s):                       2
NUMA node0 CPU(s):                  0-23,48-71
NUMA node1 CPU(s):                  24-47,72-95
Vulnerability Gather data sampling: Mitigation; Microcode
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Retbleed:             Not affected
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] nvidia-nccl-cu12==2.20.5
[pip3] torch==2.3.0
[pip3] triton==2.3.0
[pip3] vllm_nccl_cu12==2.18.1.0.4.0
[conda] numpy                     1.26.4                   pypi_0    pypi
[conda] nvidia-nccl-cu12          2.20.5                   pypi_0    pypi
[conda] torch                     2.3.0                    pypi_0    pypi
[conda] triton                    2.3.0                    pypi_0    pypi
[conda] vllm-nccl-cu12            2.18.1.0.4.0             pypi_0    pypi
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.4.2
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
GPU0	GPU1	GPU2	GPU3	NIC0	NIC1	CPU Affinity	NUMA Affinity	GPU NUMA ID
GPU0	 X 	PXB	PXB	SYS	SYS	SYS	0-23,48-71	0		N/A
GPU1	PXB	 X 	PIX	SYS	SYS	SYS	0-23,48-71	0		N/A
GPU2	PXB	PIX	 X 	SYS	SYS	SYS	0-23,48-71	0		N/A
GPU3	SYS	SYS	SYS	 X 	SYS	SYS	24-47,72-95	1		N/A
NIC0	SYS	SYS	SYS	SYS	 X 	PIX				
NIC1	SYS	SYS	SYS	SYS	PIX	 X 				

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

NIC Legend:

  NIC0: mlx5_0
  NIC1: mlx5_1

Please see the following error (output from the other ranks is omitted, but they all fail at CUDA graph capture):

ERROR 05-22 06:39:07 worker_base.py:145] Error executing method initialize_cache. This might cause deadlock in distributed execution.
ERROR 05-22 06:39:07 worker_base.py:145] Traceback (most recent call last):
ERROR 05-22 06:39:07 worker_base.py:145]   File "/root/anaconda3/envs/vllm10-and/lib/python3.10/site-packages/torch/distributed/c10d_logger.py", line 75, in wrapper
ERROR 05-22 06:39:07 worker_base.py:145]     return func(*args, **kwargs)
ERROR 05-22 06:39:07 worker_base.py:145]   File "/root/anaconda3/envs/vllm10-and/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 1886, in send
ERROR 05-22 06:39:07 worker_base.py:145]     group.send([tensor], group_dst_rank, tag).wait()
ERROR 05-22 06:39:07 worker_base.py:145] AttributeError: 'NoneType' object has no attribute 'wait'
ERROR 05-22 06:39:07 worker_base.py:145] 
ERROR 05-22 06:39:07 worker_base.py:145] During handling of the above exception, another exception occurred:
ERROR 05-22 06:39:07 worker_base.py:145] 
ERROR 05-22 06:39:07 worker_base.py:145] Traceback (most recent call last):
ERROR 05-22 06:39:07 worker_base.py:145]   File "/workspace/will/vllm/vllm/worker/model_runner.py", line 962, in capture
ERROR 05-22 06:39:07 worker_base.py:145]     hidden_states = self.model(
ERROR 05-22 06:39:07 worker_base.py:145]   File "/root/anaconda3/envs/vllm10-and/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
ERROR 05-22 06:39:07 worker_base.py:145]     return self._call_impl(*args, **kwargs)
ERROR 05-22 06:39:07 worker_base.py:145]   File "/root/anaconda3/envs/vllm10-and/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
ERROR 05-22 06:39:07 worker_base.py:145]     return forward_call(*args, **kwargs)
ERROR 05-22 06:39:07 worker_base.py:145]   File "/workspace/will/vllm/vllm/model_executor/models/llama.py", line 399, in forward
ERROR 05-22 06:39:07 worker_base.py:145]     hidden_states = self.model(input_ids, positions, kv_caches,
ERROR 05-22 06:39:07 worker_base.py:145]   File "/root/anaconda3/envs/vllm10-and/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
ERROR 05-22 06:39:07 worker_base.py:145]     return self._call_impl(*args, **kwargs)
ERROR 05-22 06:39:07 worker_base.py:145]   File "/root/anaconda3/envs/vllm10-and/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
ERROR 05-22 06:39:07 worker_base.py:145]     return forward_call(*args, **kwargs)
ERROR 05-22 06:39:07 worker_base.py:145]   File "/workspace/will/vllm/vllm/model_executor/models/llama.py", line 331, in forward
ERROR 05-22 06:39:07 worker_base.py:145]     send_next_rank([hidden_states, residual])
ERROR 05-22 06:39:07 worker_base.py:145]   File "/workspace/will/vllm/vllm/distributed/communication_op.py", line 321, in send_next_rank
ERROR 05-22 06:39:07 worker_base.py:145]     torch.distributed.send(combined_tensor,
ERROR 05-22 06:39:07 worker_base.py:145]   File "/root/anaconda3/envs/vllm10-and/lib/python3.10/site-packages/torch/distributed/c10d_logger.py", line 77, in wrapper
ERROR 05-22 06:39:07 worker_base.py:145]     msg_dict = _get_msg_dict(func.__name__, *args, **kwargs)
ERROR 05-22 06:39:07 worker_base.py:145]   File "/root/anaconda3/envs/vllm10-and/lib/python3.10/site-packages/torch/distributed/c10d_logger.py", line 50, in _get_msg_dict
ERROR 05-22 06:39:07 worker_base.py:145]     "args": f"{args}, {kwargs}",
ERROR 05-22 06:39:07 worker_base.py:145]   File "/root/anaconda3/envs/vllm10-and/lib/python3.10/site-packages/torch/_tensor.py", line 464, in __repr__
ERROR 05-22 06:39:07 worker_base.py:145]     return torch._tensor_str._str(self, tensor_contents=tensor_contents)
ERROR 05-22 06:39:07 worker_base.py:145]   File "/root/anaconda3/envs/vllm10-and/lib/python3.10/site-packages/torch/_tensor_str.py", line 697, in _str
ERROR 05-22 06:39:07 worker_base.py:145]     return _str_intern(self, tensor_contents=tensor_contents)
ERROR 05-22 06:39:07 worker_base.py:145]   File "/root/anaconda3/envs/vllm10-and/lib/python3.10/site-packages/torch/_tensor_str.py", line 617, in _str_intern
ERROR 05-22 06:39:07 worker_base.py:145]     tensor_str = _tensor_str(self, indent)
ERROR 05-22 06:39:07 worker_base.py:145]   File "/root/anaconda3/envs/vllm10-and/lib/python3.10/site-packages/torch/_tensor_str.py", line 349, in _tensor_str
ERROR 05-22 06:39:07 worker_base.py:145]     formatter = _Formatter(get_summarized_data(self) if summarize else self)
ERROR 05-22 06:39:07 worker_base.py:145]   File "/root/anaconda3/envs/vllm10-and/lib/python3.10/site-packages/torch/_tensor_str.py", line 137, in __init__
ERROR 05-22 06:39:07 worker_base.py:145]     nonzero_finite_vals = torch.masked_select(
ERROR 05-22 06:39:07 worker_base.py:145] RuntimeError: CUDA error: operation not permitted when stream is capturing
ERROR 05-22 06:39:07 worker_base.py:145] CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
ERROR 05-22 06:39:07 worker_base.py:145] For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
ERROR 05-22 06:39:07 worker_base.py:145] Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


@andoorve
Copy link
Contributor Author

This is due to PyTorch changes in 2.3.0, see: pytorch/pytorch#120270

You can work around this with a quick fix by changing send to isend and recv to irecv. Alternatively, you can pull the latest changes I provide here to use PyNCCL instead.
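
For anyone hitting this before pulling the latest commits, here is a minimal sketch of the isend/irecv workaround, shaped like the send_next_rank helper in the traceback above. The names and signatures are illustrative, not the exact code in this branch.

import torch
import torch.distributed as dist

def send_next_rank(tensors, dst, group=None):
    # Concatenate hidden states and residual so the transfer is a single
    # NCCL operation (the "atomic" send/recv discussed earlier).
    combined_tensor = torch.cat(tensors, dim=0)
    # isend returns the Work handle (which can be None, e.g. while a CUDA
    # graph is being captured) instead of unconditionally calling .wait()
    # on it the way dist.send does in the traceback above.
    work = dist.isend(combined_tensor, dst=dst, group=group)
    if work is not None:
        work.wait()

def recv_prev_rank(num_tensors, shape, dtype, device, src, group=None):
    # Receive the combined tensor and split it back into its components.
    combined_tensor = torch.empty(shape, dtype=dtype, device=device)
    work = dist.irecv(combined_tensor, src=src, group=group)
    if work is not None:
        work.wait()
    return torch.chunk(combined_tensor, num_tensors, dim=0)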

@SolitaryThinker
Copy link

@andoorve Great, thanks for the quick fix. Your PyNCCL changes are identical to mine, so that's reassuring.

commit 921bb1a014d435089db634fea9451b8c9f945459
Author: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
Date:   Wed May 22 02:28:47 2024 +0000

    Add back driver worker arg

    Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>

commit 39c6019865192737ce3cd09c50d13db2a32e1ca5
Author: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
Date:   Thu May 9 00:22:12 2024 +0000

    Test fix

    Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>

commit b60f7ea8779ae5e35c68868f327569df2167b88f
Author: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
Date:   Wed May 8 04:54:52 2024 +0000

    Refactoring and test fixes

    Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>

commit 7e993601f47e68afe31b30ac66f9252956ce58c9
Author: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
Date:   Wed May 8 00:22:33 2024 +0000

    Formatting

    Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>

commit 2091dd91d06070d1db0f82670e82120d5f7ad5f4
Author: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
Date:   Tue May 7 21:48:21 2024 +0000

    Basic PP tests

    Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>

commit 016e25664434dc6f63eed9526e5982048757d7a2
Author: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
Date:   Tue May 7 20:40:54 2024 +0000

    Formatting

    Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>

commit ee86cd204666eab815e42be703c5f434c41af255
Author: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
Date:   Tue May 7 20:40:36 2024 +0000

    Fix condition for PP support

    Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>

commit df9b0c45cee14395b2b2dff9c4e3343ab2a019a1
Author: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
Date:   Tue May 7 18:16:55 2024 +0000

    Fix hangs

    Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>

commit 2180531ed5592d49cfa7492cebc92269693094ee
Author: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
Date:   Tue May 7 04:01:29 2024 +0000

    Fix typo

    Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>

commit a17fcfe02c820f7b83bdcc3704059fcb35a231b8
Author: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
Date:   Tue May 7 01:50:24 2024 +0000

    Assert out model architectures that are unsupported

    Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>

commit f784fda224144f82065c19c643912390ab29b849
Author: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
Date:   Tue May 7 01:17:33 2024 +0000

    More test fixes

    Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>

commit 04b5fe903ac4598b5337d457afd684426e384690
Author: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
Date:   Sun May 5 17:28:42 2024 +0000

    Change condition for prepare_input_tensors to broadcast

    Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>

commit 526bade032dbeba73f6523009701f8a5f4b222f9
Author: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
Date:   Sun May 5 17:14:48 2024 +0000

    Fixed bug with TP + PP execution

    Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>

commit 9d698fa
Author: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
Date:   Thu May 2 18:38:41 2024 +0000

    Format and test changes

    Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>

commit 16a5aac
Author: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
Date:   Thu May 2 18:30:46 2024 +0000

    Format and test changes

    Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>

commit 65a5300
Author: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
Date:   Wed May 1 06:42:13 2024 +0000

    Simplify weight loading logic

    Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>

commit daddc19
Author: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
Date:   Wed May 1 05:56:53 2024 +0000

    Formatting

    Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>

commit 1be32c8
Author: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
Date:   Wed May 1 05:55:12 2024 +0000

    Revert "PyNCCL changes"

    This reverts commit 99bb187.

commit 99bb187
Author: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
Date:   Wed May 1 05:29:42 2024 +0000

    PyNCCL changes

    Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>

commit bd12e70
Author: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
Date:   Tue Apr 30 22:46:12 2024 +0000

    Fixed testing errors

    Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>

commit fbb2b2e
Author: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
Date:   Sat Apr 27 08:48:36 2024 +0000

    Formatting

    Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>

commit 06609d9
Author: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
Date:   Sat Apr 27 08:39:03 2024 +0000

    Pipeline Parallel

    Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
@SolitaryThinker
Copy link

SolitaryThinker commented May 22, 2024

@andoorve

Not sure if you are aware, but it seems that one of your pushes today broke PP. I am seeing empty responses from the api_server, with or without CUDA graphs, for both TP-only and TP+PP.

The last time that your branch worked for me was when you sent this message above.

This is due to PyTorch changes in 2.3.0, see: pytorch/pytorch#120270
You can workaround this with a quick fix by changing send to isend and recv to irecv. Or, you can pull the latest changes I provide here to use PyNCCL instead.

Below is a small script using the OpenAI client that I have been using to check for correctness. Hopefully it will be of use.

from openai import OpenAI

# Modify OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"

client = OpenAI(
    # defaults to os.environ.get("OPENAI_API_KEY")
    api_key=openai_api_key,
    base_url=openai_api_base,
)

models = client.models.list()
model = models.data[0].id

ps = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
] * 10

# Completion API
stream = False
completion = client.completions.create(
    model=model,
    prompt=ps,
    echo=False,
    temperature=0,
    stream=stream)

print("Completion results:")
if stream:
    for c in completion:
        print(c)
else:
    print(completion)
for idx, c in enumerate(completion.choices):
    print('Prompt:', ps[idx])
    print('Decode:', c.text)

Output on latest commit (meta-llama/Llama-2-13b-hf):

Prompt: Hello, my name is
Decode: 
Prompt: The president of the United States is
Decode: 
Prompt: The capital of France is
Decode: 
Prompt: The future of AI is
Decode: 
...

expected output (meta-llama/Llama-2-13b-hf):

Prompt: Hello, my name is
Decode:  Katie and I am a recovering perfectionist.
I’ve
Prompt: The president of the United States is
Decode:  the head of state and head of government of the United States, indirectly elected
Prompt: The capital of France is
Decode:  a city of contrasts. It is a city of history, culture, and
Prompt: The future of AI is
Decode:  in the hands of the people who build it.
The future of AI
...

@andoorve
Copy link
Contributor Author

@SolitaryThinker thanks for letting me know - let me take a look.

@andoorve
Copy link
Contributor Author

It was passing for GPT2 but not LLaMa. I missed a line when rebasing a LLaMa change, which I have now added back. It should pass your script now, @SolitaryThinker.
