[Bugfix] Fix GLM4 model #16618

Merged: 6 commits into vllm-project:main on Apr 17, 2025

Conversation

intervitens
Contributor

@intervitens intervitens commented Apr 14, 2025

FIX #16617
FIX #16655
FIX #16687
FIX #16740
Currently, the GLM4 model does not work: it fails to load at all.
This PR enables the model to load and makes its outputs mostly identical to those from HF transformers.
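
For reference, a rough way to sanity-check the "mostly identical to HF transformers" claim is to greedy-decode the same prompt with both backends and compare. This is only a sketch; the model id, prompt, and generation settings below are placeholders, not part of the PR.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from vllm import LLM, SamplingParams

model_id = "THUDM/GLM-4-9B-0414"  # placeholder checkpoint
prompt = "Explain rotary position embeddings in one sentence."

# Greedy decode with HF transformers.
tok = AutoTokenizer.from_pretrained(model_id)
hf = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
hf_ids = hf.generate(**tok(prompt, return_tensors="pt").to(hf.device), max_new_tokens=64, do_sample=False)
print("HF:  ", tok.decode(hf_ids[0], skip_special_tokens=True))

# Greedy decode with vLLM.
llm = LLM(model=model_id, dtype="bfloat16")
out = llm.generate([prompt], SamplingParams(temperature=0.0, max_tokens=64))
print("vLLM:", out[0].outputs[0].text)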


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of CI tests to catch errors quickly. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@intervitens
Contributor Author

intervitens commented Apr 14, 2025

The model works with --enforce-eager; without it, the model loads but produces garbage outputs.
Edit: it also works fine without --enforce-eager when VLLM_USE_V1=0 is set.
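
For anyone reproducing this, the two workarounds look roughly like the following through the Python API; a minimal sketch under the assumption that the environment variable is set before the engine is created.

import os
os.environ["VLLM_USE_V1"] = "0"  # workaround 1: fall back to the V0 engine

from vllm import LLM
# workaround 2: disable CUDA graph capture (equivalent to --enforce-eager)
llm = LLM(model="THUDM/GLM-4-9B-0414", enforce_eager=True)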

Signed-off-by: intervitens <[email protected]>
Signed-off-by: intervitens <[email protected]>
@jeejeelee changed the title from "Fix GLM4 model" to "[Bugfix] Fix GLM4 model" on Apr 15, 2025
@DarkLight1337
Member

cc @zRzRzRzRzRzRzR can you check?

@zRzRzRzRzRzRzR
Contributor

yes, I will do this

@zRzRzRzRzRzRzR
Contributor

hidden_states = residual + hidden_states

This section should be retained. See here

@zRzRzRzRzRzRzR
Contributor

There seem to be some issues; I need to take a closer look. I found that the model cannot run normally now, although it could when I submitted my original PR. I will need to spend some time checking this.

@zRzRzRzRzRzRzR
Contributor

This PR caused the model output to be garbled. @intervitens, have you encountered this problem? I am using GLM-4-9B-0414.

@kalomaze

kalomaze commented Apr 15, 2025

This PR caused the model output to be garbled. @intervitens, have you encountered this problem? I am using GLM-4-9B-0414.

The 9B and 32B are not architecturally identical: the 9B seems to have attention biases, unlike the 32B.
Also, the KV head count for the reasoning 32B and the DeepResearch-style 32B (the Z1 and Z1-Rumination 32B models) seems, strangely enough, to be larger, if you check the configuration.
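
A quick way to check those configuration differences (attention bias and KV head count) is to read them straight from the HF configs. The repo ids are taken from this thread and the attribute names follow the usual transformers conventions, so treat both as assumptions:

from transformers import AutoConfig

for repo in ("THUDM/GLM-4-9B-0414", "THUDM/GLM-4-32B-0414", "THUDM/GLM-Z1-32B-0414"):
    cfg = AutoConfig.from_pretrained(repo)
    print(repo,
          "attention_bias =", getattr(cfg, "attention_bias", None),
          "num_key_value_heads =", getattr(cfg, "num_key_value_heads", None))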

@Chandler-Bing

This PR caused the model output to be garbled. @intervitens, have you encountered this problem? I am using GLM-4-9B-0414.

Adding --enforce-eager makes it output normally.

@zRzRzRzRzRzRzR
Contributor

This PR caused the model output to be garbled. @intervitens, have you encountered this problem? I am using GLM-4-9B-0414.

The 9B and 32B are not architecturally identical: the 9B seems to have attention biases, unlike the 32B. Also, the KV head count for the reasoning 32B and the DeepResearch-style 32B (the Z1 and Z1-Rumination 32B models) seems, strangely enough, to be larger, if you check the configuration.

I don’t think that’s the issue. The 9B and 32B models released by GLM do differ in bias (the 9B has bias, the 32B doesn’t), but this is already handled by the attention_bias setting in the configuration.

@zRzRzRzRzRzRzR
Contributor

zRzRzRzRzRzRzR commented Apr 15, 2025

vllm serve THUDM/GLM-4-9B-0414 --enforce-eager

was not successful for me; the error you mentioned does indeed exist.
The modification to

hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)

is correct, but strangely, I got completely different outputs from the same model compared to the PR I submitted back then.
This really puzzles me.
Your change to the dim is also correct.
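
For concreteness, the pattern under discussion is vLLM's fused residual-add RMSNorm: when the running residual is passed into the norm, the explicit hidden_states = residual + hidden_states add is folded into the norm call, which returns both the normalized hidden states and the updated residual. A minimal sketch of that pattern (an illustration, not the GLM4 code itself):

import torch
import torch.nn as nn

class FusedAddRMSNorm(nn.Module):
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor, residual: torch.Tensor | None = None):
        if residual is not None:
            x = x + residual  # the residual add is folded in here
        residual = x          # updated running residual (pre-norm value)
        var = x.float().pow(2).mean(-1, keepdim=True)
        x = (x.float() * torch.rsqrt(var + self.eps)).to(x.dtype)
        return x * self.weight, residual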

@zRzRzRzRzRzRzR
Contributor

zRzRzRzRzRzRzR commented Apr 15, 2025

I tried reinstalling vLLM from source, and the issue was resolved. Under the current circumstances, your PR works correctly.

is_neox_style=False,

is not necessary.

Also, could you change THUDM/GLM-4-32B-Chat-0414 to THUDM/GLM-4-32B-0414 in docs/source/models/supported_models.md? We renamed the model, and there is no -Chat variant anymore.

@zRzRzRzRzRzRzR
Contributor

zRzRzRzRzRzRzR commented Apr 15, 2025

cc @DarkLight1337 @intervitens Thank you so much for your support again.

Also:

vllm serve THUDM/GLM-4-9B-0414

without --enforce-eager is working.

Signed-off-by: intervitens <[email protected]>
@mergify mergify bot added the documentation Improvements or additions to documentation label Apr 15, 2025
@intervitens
Contributor Author

Removing

is_neox_style=False,

causes the model output to become significantly degraded and repetitive.

VLLM_USE_V1=1 vllm serve THUDM/GLM-4-9B-0414

still doesn't work for me. @zRzRzRzRzRzRzR, did you figure out any changes to the PR that fixed it for you?

@zRzRzRzRzRzRzR
Contributor

zRzRzRzRzRzRzR commented Apr 15, 2025

In my setup, both is_neox_style=False and is_neox_style=True run normally, but I think it's safer to keep it set to False.
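
For context on why this flag matters: is_neox_style selects between two rotary-embedding layouts, and applying the wrong layout to a checkpoint trained with the other one can produce exactly the degraded, repetitive output described above. A minimal sketch of the two conventions (an illustration, not vLLM's implementation; cos and sin have shape [..., d // 2]):

import torch

def rope_neox(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # NeoX / rotate-half style: element i is paired with element i + d/2.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)

def rope_gptj(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # GPT-J / interleaved style: adjacent elements (2i, 2i + 1) form each pair.
    x1, x2 = x[..., 0::2], x[..., 1::2]
    return torch.stack((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1).flatten(-2)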

@zRzRzRzRzRzRzR
Contributor

This might be related to the CUDA version. I tested on an H100 with CUDA 12.4, but I'm not sure whether that's the cause.

@ad1192214879

Has this issue been resolved?

Signed-off-by: intervitens <[email protected]>
@intervitens
Contributor Author

I fixed the error that made the model output garbage when running without eager mode or VLLM_USE_V1=0.
This should be ready to merge now.

@DarkLight1337
Member

Can you verify again @zRzRzRzRzRzRzR ?

@icelinks

icelinks commented Apr 17, 2025

GLM-Z1-9B-0414 is OK, but GLM-Z1-32B-0414 repeats !!!!!
Run with CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 vllm serve GLM-Z1-32B-0414 --dtype half --max-model-len 65536 --tensor-parallel-size 8 --port 10033 --enforce-eager --rope-scaling '{"rope_type": "yarn","factor": 4.0,"original_max_position_embeddings": 32768}' on Tesla T4.

@zRzRzRzRzRzRzR
Contributor

This is expected behavior on a T4, which does not support BF16:

  1. FP16 cannot run inference on this model reliably; BF16 is required.
  2. This exclamation-mark error also seems to occur in some specific situations, but there is still no stable reproduction, and it seems to be unrelated to this PR (the problem also occurs with other frameworks).

You can submit the prompts that trigger the "infinite !" output to the THUDM/GLM-4 repository, and the staff will record them and try to reproduce and find the cause of the problem.

@zRzRzRzRzRzRzR
Contributor

I fixed the error that made the model output garbage when running without eager mode or VLLM_USE_V1=0. This should be ready to merge now.

It is working for me

Member

@DarkLight1337 DarkLight1337 left a comment


Thanks for fixing!

@DarkLight1337 DarkLight1337 enabled auto-merge (squash) April 17, 2025 04:56
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Apr 17, 2025
@solrex

solrex commented Apr 17, 2025

I fixed the error that made the model output garbage when running without eager mode or VLLM_USE_V1=0. This should be ready to merge now.

Tested locally, works as expected. +1

@vllm-bot vllm-bot merged commit 5b1aca2 into vllm-project:main Apr 17, 2025
60 of 63 checks passed
lionelvillard pushed a commit to lionelvillard/vllm that referenced this pull request Apr 17, 2025
@icelinks

This is expected behavior on a T4, which does not support BF16:

  1. FP16 cannot run inference on this model reliably; BF16 is required.
  2. This exclamation-mark error also seems to occur in some specific situations, but there is still no stable reproduction, and it seems to be unrelated to this PR (the problem also occurs with other frameworks).

You can submit the prompts that trigger the "infinite !" output to the THUDM/GLM-4 repository, and the staff will record them and try to reproduce and find the cause of the problem.

OK, thanks. I'll try; at least GLM-Z1-9B-0414 is correct now.

@warlockedward

Also on V100: if dtype float16 is configured, all output from the model is !!!!!!!!!

@Curious-chen

Also on V100: if dtype float16 is configured, all output from the model is !!!!!!!!!

Yes, I also tried setting dtype float16 on an A6000 and it only outputs !!!!!!!!!

@icelinks

icelinks commented Apr 18, 2025

Also on V100: if dtype float16 is configured, all output from the model is !!!!!!!!!

Yes, I also tried setting dtype float16 on an A6000 and it only outputs !!!!!!!!!

It's weird; my colleague said it also appears with bf16, but only with GLM-Z1-32B-0414.

yangw-dev pushed a commit to yangw-dev/vllm that referenced this pull request Apr 21, 2025
Signed-off-by: intervitens <[email protected]>
Signed-off-by: Yang Wang <[email protected]>
@rangehow

rangehow commented Apr 22, 2025

It seems there is still a problem. I am using multiple large models for standalone generation, including Qwen, Mistral-Large, Llama 4, Command-A, and Gemma 3 27B. All of the above models run normally except GLM4-32B.

(VllmWorker rank=2 pid=6072) ERROR 04-22 14:56:29 [multiproc_executor.py:470] WorkerProc hit an exception.
Traceback (most recent call last):
  File "/mnt/dolphinfs/ssd_pool/docker/user/hadoop-aipnlp/INS/ruanjunhao04/miniforge3/envs/sglang/lib/python3.12/site-packages/vllm/v1/executor/multiproc_executor.py", line 465, in worker_busy_loop
    output = func(*args, **kwargs)
  File "/mnt/dolphinfs/ssd_pool/docker/user/hadoop-aipnlp/INS/ruanjunhao04/miniforge3/envs/sglang/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/mnt/dolphinfs/ssd_pool/docker/user/hadoop-aipnlp/INS/ruanjunhao04/miniforge3/envs/sglang/lib/python3.12/site-packages/vllm/v1/worker/gpu_worker.py", line 242, in execute_model
    output = self.model_runner.execute_model(scheduler_output)
  File "/mnt/dolphinfs/ssd_pool/docker/user/hadoop-aipnlp/INS/ruanjunhao04/miniforge3/envs/sglang/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/mnt/dolphinfs/ssd_pool/docker/user/hadoop-aipnlp/INS/ruanjunhao04/miniforge3/envs/sglang/lib/python3.12/site-packages/vllm/v1/worker/gpu_model_runner.py", line 1148, in execute_model
    valid_sampled_token_ids = sampled_token_ids.tolist()
RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

(The same error and traceback are also reported by worker ranks 3 and 0.)

One suspicious point is that I set a hyperparameter to allow the model to handle a longer context:

import os
from torch.cuda import device_count
from vllm import LLM

os.environ['VLLM_ALLOW_LONG_MAX_MODEL_LEN'] = '1'  # allow max_model_len above the model's default limit

llm = LLM(model=model_path, tensor_parallel_size=device_count(), enable_prefix_caching=True,
          task='generate', max_model_len=50000, dtype='bfloat16')

@darkness8i8

@yangw-dev @DarkLight1337 The !!!! output issue is happening with all models I run with dtype=torch.float16. I usually train with Llama 3.1 8B. Can you please look at this problem holistically? I don't believe it is model-specific.

@DarkLight1337
Member

If the model is originally trained on bfloat16, then there may be numerical stability issues when using float16 for inference due to the narrower range float16 supports.
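
In practice that means preferring bfloat16 on hardware that supports it, for example (a sketch; the THUDM/ prefix for the Z1 checkpoint and the tensor-parallel size are taken from this thread, not verified here):

from vllm import LLM

# bfloat16 keeps float32's exponent range; float16's narrower range is the suspected cause of the "!!!" output.
llm = LLM(model="THUDM/GLM-Z1-32B-0414", dtype="bfloat16", tensor_parallel_size=8)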

@darkness8i8

@DarkLight1337 That could be true, thanks. Please consider adding better error output than !!!; that would be super helpful.

jikunshang pushed a commit to jikunshang/vllm that referenced this pull request Apr 29, 2025
lk-chen pushed a commit to lk-chen/vllm that referenced this pull request Apr 29, 2025
adobrzyn pushed a commit to HabanaAI/vllm-fork that referenced this pull request Apr 30, 2025
Signed-off-by: intervitens <[email protected]>
Signed-off-by: Agata Dobrzyniewicz <[email protected]>
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
Signed-off-by: intervitens <[email protected]>
Signed-off-by: Mu Huai <[email protected]>