
Commit cc1e0f7

intervitens authored and dbyoung18 committed
[Bugfix] Fix GLM4 model (vllm-project#16618)
Signed-off-by: intervitens <[email protected]>
1 parent 2e5d913 commit cc1e0f7

File tree

3 files changed (+6 −6 lines)


docs/source/models/supported_models.md

Lines changed: 1 addition & 1 deletion
@@ -338,7 +338,7 @@ See [this page](#generative-models) for more information on how to use generative models.
   * ✅︎
 - * `Glm4ForCausalLM`
   * GLM-4-0414
-  * `THUDM/GLM-4-32B-Chat-0414`, etc.
+  * `THUDM/GLM-4-32B-0414`, etc.
   * ✅︎
   * ✅︎
 - * `GPT2LMHeadModel`
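
For reference, a minimal sketch of running the renamed checkpoint with vLLM's offline LLM API; the model name comes from the updated docs row above, while the prompt and sampling settings are purely illustrative:

from vllm import LLM, SamplingParams

# Load GLM-4-0414 under its corrected Hugging Face repo name.
llm = LLM(model="THUDM/GLM-4-32B-0414")

# Arbitrary prompt and sampling settings, just to exercise the model.
params = SamplingParams(temperature=0.7, max_tokens=64)
outputs = llm.generate(["Give a one-line summary of GLM-4-0414."], params)
print(outputs[0].outputs[0].text)

The same name also works with the OpenAI-compatible server, e.g. `vllm serve THUDM/GLM-4-32B-0414`.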

tests/models/registry.py

Lines changed: 1 addition & 1 deletion
@@ -147,7 +147,7 @@ def check_available_online(
                                       min_transformers_version="4.50"),
     "GlmForCausalLM": _HfExamplesInfo("THUDM/glm-4-9b-chat-hf"),
     "Glm4ForCausalLM": _HfExamplesInfo(
-        "THUDM/GLM-4-32B-Chat-0414",
+        "THUDM/GLM-4-32B-0414",
         is_available_online=False,
         min_transformers_version="4.52.dev0"
     ),
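
The registry entry above pins `min_transformers_version="4.52.dev0"` for `Glm4ForCausalLM`. A small illustrative guard for that requirement before loading the model (the version string comes from the diff; the check itself is only a sketch, not part of vLLM):

from packaging import version

import transformers

REQUIRED = "4.52.dev0"  # minimum pinned in tests/models/registry.py

if version.parse(transformers.__version__) < version.parse(REQUIRED):
    raise RuntimeError(
        f"Glm4ForCausalLM needs transformers >= {REQUIRED}, "
        f"but found {transformers.__version__}")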

vllm/model_executor/models/glm4.py

Lines changed: 4 additions & 4 deletions
@@ -82,7 +82,7 @@ def __init__(self,
         partial_rotary_factor = getattr(config, "partial_rotary_factor", 0.5)
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
         self.head_dim = head_dim or hidden_size // self.total_num_heads
-        self.rotary_dim = int(partial_rotary_factor * self.head_dim)
+        self.rotary_dim = self.head_dim
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
@@ -110,6 +110,7 @@ def __init__(self,
             base=self.rope_theta,
             rope_scaling=rope_scaling,
             partial_rotary_factor=partial_rotary_factor,
+            is_neox_style=False,
         )
         self.attn = Attention(self.num_heads,
                               self.head_dim,
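
For context on the `rotary_dim` change: `get_rope` is already handed `partial_rotary_factor` and, assuming it scales the rotary dimension by that factor internally (as vLLM's helper does), pre-scaling `rotary_dim` here applied the factor twice. A rough arithmetic sketch of the effect, with an illustrative `head_dim`:

head_dim = 128
partial_rotary_factor = 0.5  # default used in this module

# Before: rotary_dim was pre-scaled, then scaled again inside get_rope,
# so only head_dim * factor**2 dimensions were actually rotated.
old_rotary_dim = int(partial_rotary_factor * head_dim)        # 64
old_effective = int(partial_rotary_factor * old_rotary_dim)   # 32 (too small)

# After: pass the full head_dim and let get_rope apply the factor once.
new_rotary_dim = head_dim                                      # 128
new_effective = int(partial_rotary_factor * new_rotary_dim)    # 64

print(old_effective, new_effective)

The added `is_neox_style=False` selects the interleaved (GPT-J style) pairing of rotated dimensions rather than the NeoX half-split layout, which is presumably the layout the upstream GLM-4-0414 checkpoints were trained with.
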
@@ -197,13 +198,12 @@ def forward(
         )
 
         hidden_states = self.post_self_attn_layernorm(hidden_states)
-        hidden_states = residual + hidden_states
 
         # Fully Connected
-        hidden_states = self.post_attention_layernorm(hidden_states, residual)
+        hidden_states, residual = self.post_attention_layernorm(
+            hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
         hidden_states = self.post_mlp_layernorm(hidden_states)
-        hidden_states = residual + hidden_states
 
         return hidden_states, residual
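
For context on the `forward` changes: vLLM's RMSNorm layers fuse the residual add into the norm when a residual tensor is passed, returning both the normed activations and the updated residual. A simplified stand-in for that calling convention (not vLLM's actual kernel; the eps value is illustrative):

import torch
import torch.nn as nn

class FusedAddRMSNorm(nn.Module):
    """Simplified sketch of an RMSNorm that optionally fuses a residual add."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor, residual: torch.Tensor = None):
        if residual is not None:
            # Add the residual first, keep the sum as the new residual,
            # then normalize the sum.
            x = x + residual
            residual = x
        variance = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.eps) * self.weight
        return x if residual is None else (x, residual)

Read against that convention, the old code both added `residual` manually and then passed it into `post_attention_layernorm`, which adds it again, and it never unpacked the tuple that call returns; the rewritten lines rely on the fused call alone.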
