Skip to content

[Feature] support sequence parallelism using compilation pass #16155

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 24 commits into from
Apr 27, 2025

Conversation

cascade812
Copy link
Contributor

@cascade812 cascade812 commented Apr 7, 2025

This PR support sequence parallelism using below compilation config

config = CompilationConfig(
    level=3,
    custom_ops=["+rms_norm"],
    compile_sizes=[4, 8, 16],
    splitting_ops=[],
)
config.pass_config.enable_sequence_parallelism= True

llm = LLM(model="llama/Llama-3.2-1B-Instruct",
          enforce_eager=False,
          tensor_parallel_size=2,
          dtype=torch.float16,
          max_num_batched_tokens=2048,
          compilation_config=config)

Copy link

github-actions bot commented Apr 7, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@cascade812 cascade812 marked this pull request as draft April 7, 2025 04:04
@robertgshaw2-redhat
Copy link
Collaborator

nice!

Signed-off-by: cascade812 <[email protected]>
@cascade812 cascade812 marked this pull request as ready for review April 10, 2025 02:43
Signed-off-by: cascade812 <[email protected]>
@mergify mergify bot added the ci/build label Apr 10, 2025
@yaochengji yaochengji requested a review from tlrmchlsmth April 11, 2025 03:53
@ProExpertProg
Copy link
Collaborator

ProExpertProg commented Apr 13, 2025

I will take a closer look on Tuesday but my initial thoughts are that I agree we should improve the pass interface. At the same time, I like that currently all passes conform to the PyTorch CustomGraphPass interface (__call__(graph: fx.Graph)) and so I would prefer to pass the additional information through a context like in this PR. Also, just checking if a shape is "supported" is not enough; sometimes the pass needs to know the actual runtime shape

I also think we should better separate the vllm.config.CompilationConfig and the torch._inductor.config. In my attention fusion pass I will need access to the full CompilationConfig (to access attention layers) inside a pass (and not just PassConfig), which would recreate a cycle.

We should decide if we want to re-instantiate passes for different runtime shapes. It could be good if we want to register shape-dependent patterns upon initialization. But it might also mean a bunch of redundant work. Perhaps the best solution is to instantiate each pass per-shape and then at least reuse it in piecewise compilation.

Signed-off-by: cascade812 <[email protected]>
Copy link
Collaborator

@yaochengji yaochengji left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks for the contribution!

Signed-off-by: cascade812 <[email protected]>
Copy link
Collaborator

@tlrmchlsmth tlrmchlsmth left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM when green, thanks for the excellent work

Copy link
Collaborator

@ProExpertProg ProExpertProg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A few notes on certain edge cases. We should at least add TODOs if we don't address them here

Copy link
Collaborator

@ProExpertProg ProExpertProg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey, sorry for the kinda-late review, I've been head-down on the attention-quant fusion pass.

This is a great PR! I had a few minor notes - I've submitted some above, and I had a few more notes on testing I thought of so submitting them now.

I'm happy to provide more explanation for the comments & discuss if I misunderstood something. Feel free to ping me on the vLLM Slack or here whether for discussion or if you've addressed the comments and want me to take another look.

Signed-off-by: cascade812 <[email protected]>
Signed-off-by: cascade812 <[email protected]>
Signed-off-by: cascade812 <[email protected]>
Copy link
Collaborator

@ProExpertProg ProExpertProg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Had a few more comments and a bunch more nits, thanks for addressing all of the comments so far.

A big question is whether the tp_world_size can be different for two LLM instances, because this pass being a singleton means that it'll only use the tp_size that was present during the first initialization. @tlrmchlsmth do you know the answer to this?

I think we should be able to resolve this by checking the tp_world_size during the instance() call, and then we can register more patterns on the same instance or maybe even return a new instance with the new tp_size (if that works). Or, because we clear _seen_patterns, we could get rid of the singleton structure anyway (you'd have to verify it works).

@cascade812
Copy link
Contributor Author

cascade812 commented Apr 26, 2025

Had a few more comments and a bunch more nits, thanks for addressing all of the comments so far.

A big question is whether the tp_world_size can be different for two LLM instances, because this pass being a singleton means that it'll only use the tp_size that was present during the first initialization. @tlrmchlsmth do you know the answer to this?

I think we should be able to resolve this by checking the tp_world_size during the instance() call, and then we can register more patterns on the same instance or maybe even return a new instance with the new tp_size (if that works). Or, because we clear _seen_patterns, we could get rid of the singleton structure anyway (you'd have to verify it works).

Good point. I tested two llm instances with both singleton and non-singleton setups, and both worked. That said, I agree non-singleton makes more sense here, so I’ve removed the singleton.

Signed-off-by: cascade812 <[email protected]>
@tlrmchlsmth tlrmchlsmth enabled auto-merge (squash) April 26, 2025 19:09
@vllm-bot vllm-bot merged commit 690fe01 into vllm-project:main Apr 27, 2025
71 of 73 checks passed
jikunshang pushed a commit to jikunshang/vllm that referenced this pull request Apr 29, 2025
…roject#16155)

Signed-off-by: cascade812 <[email protected]>
Signed-off-by: Tyler Michael Smith <[email protected]>
Co-authored-by: Tyler Michael Smith <[email protected]>
lk-chen pushed a commit to lk-chen/vllm that referenced this pull request Apr 29, 2025
…roject#16155)

Signed-off-by: cascade812 <[email protected]>
Signed-off-by: Tyler Michael Smith <[email protected]>
Co-authored-by: Tyler Michael Smith <[email protected]>
adobrzyn pushed a commit to HabanaAI/vllm-fork that referenced this pull request Apr 30, 2025
…roject#16155)

Signed-off-by: cascade812 <[email protected]>
Signed-off-by: Tyler Michael Smith <[email protected]>
Co-authored-by: Tyler Michael Smith <[email protected]>
Signed-off-by: Agata Dobrzyniewicz <[email protected]>
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
…roject#16155)

Signed-off-by: cascade812 <[email protected]>
Signed-off-by: Tyler Michael Smith <[email protected]>
Co-authored-by: Tyler Michael Smith <[email protected]>
Signed-off-by: Mu Huai <[email protected]>
zzzyq pushed a commit to zzzyq/vllm that referenced this pull request May 24, 2025
…roject#16155)

Signed-off-by: cascade812 <[email protected]>
Signed-off-by: Tyler Michael Smith <[email protected]>
Co-authored-by: Tyler Michael Smith <[email protected]>
Signed-off-by: Yuqi Zhang <[email protected]>
@Juelianqvq
Copy link
Contributor

Thanks for the great work! @cascade812 Just a simple question, I've noticed that you initialized the paramater rms_norm_weight with an empty 2D tensor during the middleRMSNorm and the lastRMSNorm SP pass, but on the CUDA kernel side, the function accepts it as a 1D tensor. Why wasn't it conflicted with each other?

@cascade812
Copy link
Contributor Author

Thanks for the great work! @cascade812 Just a simple question, I've noticed that you initialized the paramater rms_norm_weight with an empty 2D tensor during the middleRMSNorm and the lastRMSNorm SP pass, but on the CUDA kernel side, the function accepts it as a 1D tensor. Why wasn't it conflicted with each other?

I think it's because pattern matcher is more flexible, it doesn't validate kernel implementation details - it just ensures the operation pattern is recognizable and shapes are broadcastable.

minpeter pushed a commit to minpeter/vllm that referenced this pull request Jun 24, 2025
…roject#16155)

Signed-off-by: cascade812 <[email protected]>
Signed-off-by: Tyler Michael Smith <[email protected]>
Co-authored-by: Tyler Michael Smith <[email protected]>
Signed-off-by: minpeter <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ci/build ready ONLY add when PR is ready to merge/full CI is needed v1
Projects
None yet
Development

Successfully merging this pull request may close these issues.

10 participants