Skip to content

[Bugfix][Misc] Use TritonPlaceholderModule to defensively import triton #15099

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 25, 2025

Conversation

MengqingCao
Copy link
Contributor

@MengqingCao MengqingCao commented Mar 19, 2025

FIX #14888
FIX #16955

For devices not support triton, directly importing triton will break the inference. This pr offers a TritonPlaceholder and TritonLaguagePlaceholder module to hold the non-triton senario.

TritonPlaceholder and TritonLaguagePlaceholder will be injected to sys.modules to let system see the placeholder rather than the real triton. Thus during importing, it won't break the normal pipeline.

For the decorators like triton.jit, we implentment the func _dummy_decorator() to work as a dummy warpper, which do nothing inside.

1750250422557

cc @Hinsael plz try this pr and looking forward for your feedback! Thx!

Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the v1 label Mar 19, 2025
@@ -3,8 +3,11 @@

import torch
import torch.nn as nn
import triton
import triton.language as tl
from vllm.triton_utils.importing import HAS_TRITON
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm, this doesn't really. For no triton case, we will have use case of triton.jit, which will still fail, right?

Copy link
Contributor Author

@MengqingCao MengqingCao Mar 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your review! You're right for vllm/v1/sample/rejection_sampler.py, which only contains features dependent on triton.
But in some scripts, e.g., vllm/model_executor/layers/fused_moe/fused_moe.py, contain many features not using triton.

I'll filter out the scripts only for triton, and revert the modification of them
I found that even if it is a script that only contains triton functions, it will be referenced by more files, scattering in all the project. Modifying these files will introduce greater changes.

Therefore, I think it is better to do defensive import in the file that directly references triton. In order to prevent the decorator from running without triton, I added some dummy decorators, in which nothing will be done, but only returns the original function.

cc @houseroad plz review the latest code, thanks!

@MengqingCao MengqingCao marked this pull request as draft March 20, 2025 01:25
Copy link

mergify bot commented Mar 20, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @MengqingCao.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@MengqingCao MengqingCao marked this pull request as ready for review March 20, 2025 03:34
@MengqingCao
Copy link
Contributor Author

@robertgshaw2-redhat could you help take a look at this pr, thanks!

@DarkLight1337
Copy link
Member

cc @mgoin @Isotr0py can you help review?

Copy link
Collaborator

@Isotr0py Isotr0py left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm fine to add defensive check for triton. But can we simplify it without adding if HAS_TRITON: everywhere?

Comment on lines 7 to 9
if HAS_TRITON:
import triton
import triton.language as tl
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use a PlaceholderModule to simplify triton import defense, otherwise we need to add the if statement everywhere.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

placeholder module looks better, we can try to import triton in vllm/utils.py , and if it fails, we can set sys.modules['triton'] to a placeholder module, with some logging to tell users what's happening.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh, sorry, please ignore the above question, I think we could check pytorch-triton-xpu to replace the platform check

Copy link

mergify bot commented Apr 23, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @MengqingCao.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

Copy link
Member

@youkaichao youkaichao left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks much better now. Also cc @houseroad @Isotr0py if you have further comments.

Copy link
Collaborator

@Isotr0py Isotr0py left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM now! Thanks for this effort!

@Isotr0py Isotr0py enabled auto-merge (squash) April 24, 2025 03:02
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Apr 24, 2025
@MengqingCao
Copy link
Contributor Author

Not sure why CI failed with vllm.third_party.pynvml.NVMLError_InvalidArgument: Invalid Argument when getting device capability. I notice that there is only 1 card in the Nvidia-L4 mechine, and CUDA_VISIBLE_DEVICES not be set in workflow. Could you help figure this out? @Isotr0py

@Isotr0py
Copy link
Collaborator

Hmm, the failing spec decode tests passed on my side locally with this PR. Not sure why CI failed too...

Anyway, can you merge from main to update the kernels tests first? I think we have categorized the kernel test, which should avoid the CI timeout.

auto-merge was automatically disabled April 24, 2025 13:17

Head branch was pushed to by a user without write access

@MengqingCao
Copy link
Contributor Author

Hmm, the failing spec decode tests passed on my side locally with this PR. Not sure why CI failed too...

Anyway, can you merge from main to update the kernels tests first? I think we have categorized the kernel test, which should avoid the CI timeout.

Thanks for your help, now the rebase is done.

@MengqingCao MengqingCao changed the title [Bugfix][Misc] Add a defensive check before importing triton [Bugfix][Misc] Use TritonPlaceholderModule to defensively import triton Apr 25, 2025
@Isotr0py Isotr0py enabled auto-merge (squash) April 25, 2025 03:14
@vllm-bot vllm-bot merged commit 2f54045 into vllm-project:main Apr 25, 2025
45 checks passed
gshtras added a commit to ROCm/vllm that referenced this pull request Apr 25, 2025
* [BugFix] Remove default multiproc executor `collective_rpc` timeout (vllm-project#17000)

Signed-off-by: Nick Hill <[email protected]>

* [Core][V1][TPU] Enable structured decoding on TPU V1 (vllm-project#16499)

Signed-off-by: Chenyaaang <[email protected]>

* [Bugfix] validate urls object for multimodal content parts (vllm-project#16990)

Signed-off-by: Guillaume Calmettes <[email protected]>

* add Dockerfile build vllm against torch nightly (vllm-project#16936)

Signed-off-by: Yang Wang <[email protected]>

* [Kernel][ROCM] Upstream prefix prefill speed up for vLLM V1 (vllm-project#13305)

Signed-off-by: Sage Moore <[email protected]>
Signed-off-by: root <[email protected]>
Signed-off-by: Aleksandr Malyshev <[email protected]>
Signed-off-by: root <[email protected]>
Signed-off-by: maleksan85 <[email protected]>
Signed-off-by: <>
Co-authored-by: Sage Moore <[email protected]>
Co-authored-by: root <[email protected]>
Co-authored-by: Aleksandr Malyshev <[email protected]>
Co-authored-by: qli88 <[email protected]>
Co-authored-by: root <[email protected]>

* [V1][DP] More robust DP/EP dummy request coordination (vllm-project#16277)

Signed-off-by: Nick Hill <[email protected]>

* [BugFix] Revert ROCm Custom Paged Attention Env Flag Check (vllm-project#17022)

Signed-off-by: vllmellm <[email protected]>

* Revert "[Misc] Add S3 environment variables for better support of MinIO." (vllm-project#17021)

* [misc] tune some env vars for GB200 (vllm-project#16992)

Signed-off-by: youkaichao <[email protected]>

* [INTEL-HPU][v0] Port delayed sampling to upstream (vllm-project#16949)

Signed-off-by: Michal Adamczyk <[email protected]>
Signed-off-by: Chendi Xue <[email protected]>
Co-authored-by: Michal Adamczyk <[email protected]>

* [doc] add download path tips (vllm-project#17013)

Signed-off-by: reidliu41 <[email protected]>
Co-authored-by: reidliu41 <[email protected]>

* [Bugfix] Triton FA function takes no keyword arguments (vllm-project#16902)

Signed-off-by: vllmellm <[email protected]>

* [V1] Avoid socket errors during shutdown when requests are in in-flight (vllm-project#16807)

Signed-off-by: Nick Hill <[email protected]>

* [BugFix] llama4 fa3 fix - RuntimeError: scheduler_metadata must have shape (metadata_size) (vllm-project#16998)

Signed-off-by: Lucas Wilkinson <[email protected]>

* [Misc] Improve readability of get_open_port function. (vllm-project#17024)

Signed-off-by: gitover22 <[email protected]>

* [Bugfix] Fix AssertionError: skip_special_tokens=False is not supported for Mistral tokenizers (vllm-project#16964)

Signed-off-by: chaunceyjiang <[email protected]>

* [CI] Run v1/test_serial_utils.py in CI (vllm-project#16996)

Signed-off-by: Russell Bryant <[email protected]>

* Mistral-format support for compressed-tensors (vllm-project#16803)

Signed-off-by: mgoin <[email protected]>

* Categorize `tests/kernels/` based on kernel type (vllm-project#16799)

Signed-off-by: mgoin <[email protected]>

* [Doc] Add top anchor and a note to quantization/bitblas.md (vllm-project#17042)

Signed-off-by: windsonsea <[email protected]>

* Ensure that `pid` passed to `kill_process_tree` is `int` for `mypy` (vllm-project#17051)

Signed-off-by: Harry Mellor <[email protected]>

* [CI] Update structured-output label automation (vllm-project#17055)

Signed-off-by: Russell Bryant <[email protected]>

* Improve Transformers backend model loading QoL (vllm-project#17039)

Signed-off-by: Harry Mellor <[email protected]>

* `CacheConfig.block_size` should always be `int` when used (vllm-project#17052)

Signed-off-by: Harry Mellor <[email protected]>

* Use `@property` and private field for `data_parallel_rank_local` (vllm-project#17053)

Signed-off-by: Harry Mellor <[email protected]>

* [Frontend] Support guidance:no-additional-properties for compatibility with xgrammar (vllm-project#15949)

Signed-off-by: Travis Johnson <[email protected]>

* [BugFix][V1] Fix int32 token index overflow when preparing input ids (vllm-project#16806)

* [V1][Spec Decode] Always use argmax for sampling draft tokens  (vllm-project#16899)

Signed-off-by: Woosuk Kwon <[email protected]>

* [CI/Build] workaround for CI build failure (vllm-project#17070)

Signed-off-by: csy1204 <[email protected]>
Co-authored-by: Michael Goin <[email protected]>

* [Quantization]add prefix for commandA quantized model (vllm-project#17017)

* [Minor] Use larger batch sizes for A100/B100/B200/MI300x (vllm-project#17073)

Signed-off-by: Woosuk Kwon <[email protected]>

* [Bugfix] Enable V1 usage stats (vllm-project#16986)

Signed-off-by: mgoin <[email protected]>
Signed-off-by: Nick Hill <[email protected]>
Co-authored-by: Nick Hill <[email protected]>

* More informative error when using Transformers backend (vllm-project#16988)

Signed-off-by: Harry Mellor <[email protected]>

* Addendum Fix to support FIPS enabled machines with MD5 hashing (vllm-project#17043)

Signed-off-by: sydarb <[email protected]>

* [Bugfix][Core] add seq_id_to_seq_group clearing to avoid memory leak when s… (vllm-project#16472)

Signed-off-by: 开哲 <[email protected]>
Co-authored-by: 开哲 <[email protected]>

* [V1] Update structured output (vllm-project#16812)

Signed-off-by: reidliu41 <[email protected]>
Co-authored-by: reidliu41 <[email protected]>

* [doc] update to hyperlink (vllm-project#17096)

Signed-off-by: reidliu41 <[email protected]>
Co-authored-by: reidliu41 <[email protected]>

* Add docs for runai_streamer_sharded (vllm-project#17093)

Signed-off-by: Omer Dayan (SW-GPU) <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>

* [Chore] Remove Sampler from Model Code (vllm-project#17084)

Signed-off-by: Woosuk Kwon <[email protected]>

* Disable enforce_eager for V1 TPU sampler and structured output tests (vllm-project#17016)

Signed-off-by: mgoin <[email protected]>

* Simplify `TokenizerGroup` (vllm-project#16790)

Signed-off-by: Harry Mellor <[email protected]>

* Fix OOT registration test (vllm-project#17099)

Signed-off-by: Harry Mellor <[email protected]>

* [V1][PP] Optimization: continue scheduling prefill chunks (vllm-project#17080)

Signed-off-by: Rui Qiao <[email protected]>

* [Misc] Remove OLMo2 config copy (vllm-project#17066)

Signed-off-by: Isotr0py <[email protected]>

* Improve static type checking in `LoRAModelRunnerMixin` (vllm-project#17104)

Signed-off-by: Harry Mellor <[email protected]>

* [V1][Structured Output] Clear xgrammar compiler object when engine core shut down to avoid nanobind leaked warning (vllm-project#16954)

Signed-off-by: shen-shanshan <[email protected]>

* [Frontend] Using matryoshka_dimensions control the allowed output dimensions. (vllm-project#16970)

* Add missing rocm_skinny_gemms kernel test to CI (vllm-project#17060)

Signed-off-by: mgoin <[email protected]>

* [Misc] refactor example series - structured outputs (vllm-project#17040)

Signed-off-by: reidliu41 <[email protected]>
Co-authored-by: reidliu41 <[email protected]>

* [V1][Spec Decoding] Add num_drafts and num_accepted_tokens_per_position metrics (vllm-project#16665)

Signed-off-by: Mark McLoughlin <[email protected]>

* [CI] Add automation for the `tool-calling` github label (vllm-project#17118)

Signed-off-by: Russell Bryant <[email protected]>

* Updating builkite job for IBM Power  (vllm-project#17111)

Signed-off-by: Aaruni Aggarwal <[email protected]>

* existing torch installation pip command fix for docs (vllm-project#17059)

* Molmo Requirements (vllm-project#17026)

Signed-off-by: Eyshika Agarwal <[email protected]>
Signed-off-by: eyshika <[email protected]>

* Add `:markdownhelp:` to `EngineArgs` docs so markdown docstrings render properly (vllm-project#17124)

Signed-off-by: Harry Mellor <[email protected]>

* Improve configs - `LoRAConfig` + `PromptAdapterConfig` (vllm-project#16980)

Signed-off-by: Harry Mellor <[email protected]>

* [Docs] Generate correct github links for decorated functions (vllm-project#17125)

Signed-off-by: Russell Bryant <[email protected]>

* Add collective_rpc to llm engine (vllm-project#16999)

Signed-off-by: Yinghai Lu <[email protected]>

* Add chat template for Llama 4 models (vllm-project#16428)

Signed-off-by: Max de Bayser <[email protected]>

* [Misc] Add example to run DeepSeek with Ray Serve LLM (vllm-project#17134)

Signed-off-by: Rui Qiao <[email protected]>

* Better error message for missing mistral params.json (vllm-project#17132)

Signed-off-by: mgoin <[email protected]>

* Use custom address for listening socket (vllm-project#15988)

Signed-off-by: Jens Glaser <[email protected]>

* [FEAT] [ROCm]: AITER Fused MOE V1 Support (vllm-project#16752)

Signed-off-by: vllmellm <[email protected]>
Co-authored-by: tjtanaa <[email protected]>

* [Attention] FA3 decode perf improvement - single mma warp group support for head dim 128 (vllm-project#16864)

Signed-off-by: Lucas Wilkinson <[email protected]>

* fix float16 support for kimi-vl (vllm-project#17156)

Co-authored-by: zhouzaida <[email protected]>

* [Doc] V1 : Update LoRA status (vllm-project#17133)

Signed-off-by: varun sundar rabindranath <[email protected]>
Co-authored-by: varun sundar rabindranath <[email protected]>

* [Docs] Fix True->true in supported_models.md (vllm-project#17141)

* Move missed `SchedulerConfig` args into scheduler config group in `EngineArgs` (vllm-project#17131)

Signed-off-by: Harry Mellor <[email protected]>

* [Misc] Clean up redundant code in uniproc_executor.py (vllm-project#16762)

Signed-off-by: Lifu Huang <[email protected]>

* [Bugfix][Misc] Use TritonPlaceholderModule to defensively import triton (vllm-project#15099)

Signed-off-by: Mengqing Cao <[email protected]>

* [Misc] Benchmark Serving Script Support Appending Results (vllm-project#17028)

Signed-off-by: Lucas Wilkinson <[email protected]>

* [Perf]Optimize rotary_emb implementation to use Triton operator for improved inference performance (vllm-project#16457)

Signed-off-by: cynthieye <[email protected]>
Co-authored-by: MagnetoWang <[email protected]>

* [Bugfix] remove fallback in guided_json (int range, patterns) (vllm-project#16725)

Signed-off-by: csy1204 <[email protected]>
Co-authored-by: 조상연[플레이스 AI] <[email protected]>

* [Quantization][FP8] Add support for FP8 models with input_scale for output projection and QK quantization (vllm-project#15734)

Signed-off-by: Randall Smith <[email protected]>
Signed-off-by: Luka Govedič <[email protected]>
Co-authored-by: Luka Govedič <[email protected]>

* [Doc] Add headings to improve gptqmodel.md (vllm-project#17164)

Signed-off-by: windsonsea <[email protected]>

* Only turn on FastIncrementalDetokenizer when tokenizers >= 0.21.1 (vllm-project#17158)

* [Doc] Add two links to disagg_prefill.md (vllm-project#17168)

Signed-off-by: windsonsea <[email protected]>

* [Doc] Move todo out of beam search docstring (vllm-project#17183)

Signed-off-by: Alex-Brooks <[email protected]>

* [Bugfix] Fix mistral model tests (vllm-project#17181)

Signed-off-by: DarkLight1337 <[email protected]>

* [Bugfix] Fix Mistral ChatCompletionRequest Body Exception (vllm-project#16769)

Signed-off-by: Jasmond Loh <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>

* Fix API typo and remove FP8 on V1 restriction

---------

Signed-off-by: Nick Hill <[email protected]>
Signed-off-by: Chenyaaang <[email protected]>
Signed-off-by: Guillaume Calmettes <[email protected]>
Signed-off-by: Yang Wang <[email protected]>
Signed-off-by: Sage Moore <[email protected]>
Signed-off-by: root <[email protected]>
Signed-off-by: Aleksandr Malyshev <[email protected]>
Signed-off-by: root <[email protected]>
Signed-off-by: maleksan85 <[email protected]>
Signed-off-by: <>
Signed-off-by: vllmellm <[email protected]>
Signed-off-by: youkaichao <[email protected]>
Signed-off-by: Michal Adamczyk <[email protected]>
Signed-off-by: Chendi Xue <[email protected]>
Signed-off-by: reidliu41 <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: gitover22 <[email protected]>
Signed-off-by: chaunceyjiang <[email protected]>
Signed-off-by: Russell Bryant <[email protected]>
Signed-off-by: mgoin <[email protected]>
Signed-off-by: windsonsea <[email protected]>
Signed-off-by: Harry Mellor <[email protected]>
Signed-off-by: Travis Johnson <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: csy1204 <[email protected]>
Signed-off-by: sydarb <[email protected]>
Signed-off-by: 开哲 <[email protected]>
Signed-off-by: Omer Dayan (SW-GPU) <[email protected]>
Signed-off-by: Rui Qiao <[email protected]>
Signed-off-by: Isotr0py <[email protected]>
Signed-off-by: shen-shanshan <[email protected]>
Signed-off-by: Mark McLoughlin <[email protected]>
Signed-off-by: Aaruni Aggarwal <[email protected]>
Signed-off-by: Eyshika Agarwal <[email protected]>
Signed-off-by: eyshika <[email protected]>
Signed-off-by: Yinghai Lu <[email protected]>
Signed-off-by: Max de Bayser <[email protected]>
Signed-off-by: Jens Glaser <[email protected]>
Signed-off-by: varun sundar rabindranath <[email protected]>
Signed-off-by: Lifu Huang <[email protected]>
Signed-off-by: Mengqing Cao <[email protected]>
Signed-off-by: cynthieye <[email protected]>
Signed-off-by: Randall Smith <[email protected]>
Signed-off-by: Luka Govedič <[email protected]>
Signed-off-by: Alex-Brooks <[email protected]>
Signed-off-by: DarkLight1337 <[email protected]>
Signed-off-by: Jasmond Loh <[email protected]>
Co-authored-by: Nick Hill <[email protected]>
Co-authored-by: Chenyaaang <[email protected]>
Co-authored-by: Guillaume Calmettes <[email protected]>
Co-authored-by: Yang Wang <[email protected]>
Co-authored-by: Aleksandr Malyshev <[email protected]>
Co-authored-by: Sage Moore <[email protected]>
Co-authored-by: root <[email protected]>
Co-authored-by: Aleksandr Malyshev <[email protected]>
Co-authored-by: qli88 <[email protected]>
Co-authored-by: root <[email protected]>
Co-authored-by: vllmellm <[email protected]>
Co-authored-by: Chauncey <[email protected]>
Co-authored-by: youkaichao <[email protected]>
Co-authored-by: Chendi.Xue <[email protected]>
Co-authored-by: Michal Adamczyk <[email protected]>
Co-authored-by: Reid <[email protected]>
Co-authored-by: reidliu41 <[email protected]>
Co-authored-by: Lucas Wilkinson <[email protected]>
Co-authored-by: huafeng <[email protected]>
Co-authored-by: Russell Bryant <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
Co-authored-by: Michael Yao <[email protected]>
Co-authored-by: Harry Mellor <[email protected]>
Co-authored-by: Travis Johnson <[email protected]>
Co-authored-by: Yong Hoon Shin <[email protected]>
Co-authored-by: Woosuk Kwon <[email protected]>
Co-authored-by: Sangyeon Cho <[email protected]>
Co-authored-by: Chen Xia <[email protected]>
Co-authored-by: Areeb Syed <[email protected]>
Co-authored-by: 张宇 <[email protected]>
Co-authored-by: 开哲 <[email protected]>
Co-authored-by: omer-dayan <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
Co-authored-by: Rui Qiao <[email protected]>
Co-authored-by: Isotr0py <[email protected]>
Co-authored-by: Shanshan Shen <[email protected]>
Co-authored-by: wang.yuqi <[email protected]>
Co-authored-by: Mark McLoughlin <[email protected]>
Co-authored-by: Aaruni Aggarwal <[email protected]>
Co-authored-by: Atilla <[email protected]>
Co-authored-by: Eyshika Agarwal <[email protected]>
Co-authored-by: Yinghai Lu <[email protected]>
Co-authored-by: Maximilien de Bayser <[email protected]>
Co-authored-by: jglaser <[email protected]>
Co-authored-by: tjtanaa <[email protected]>
Co-authored-by: Zaida Zhou <[email protected]>
Co-authored-by: zhouzaida <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: varun sundar rabindranath <[email protected]>
Co-authored-by: Lifu Huang <[email protected]>
Co-authored-by: Mengqing Cao <[email protected]>
Co-authored-by: yexin(叶鑫) <[email protected]>
Co-authored-by: MagnetoWang <[email protected]>
Co-authored-by: 조상연[플레이스 AI] <[email protected]>
Co-authored-by: rasmith <[email protected]>
Co-authored-by: Luka Govedič <[email protected]>
Co-authored-by: Lu Fang <[email protected]>
Co-authored-by: Alex Brooks <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
Co-authored-by: Jasmond L <[email protected]>
jikunshang pushed a commit to jikunshang/vllm that referenced this pull request Apr 29, 2025
lk-chen pushed a commit to lk-chen/vllm that referenced this pull request Apr 29, 2025
adobrzyn pushed a commit to HabanaAI/vllm-fork that referenced this pull request Apr 30, 2025
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
minpeter pushed a commit to minpeter/vllm that referenced this pull request Jun 24, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed v1
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[Bug]: MiniCPM3 failed on ascend npu because of ModuleNotFoundError: No module named 'triton' [Usage]: ModuleNotFoundError: No module named 'triton'
6 participants