
Commit 0bc0bf5

gemma3: impl get_attention_sliding_window_size for attn init (#4823)
1 parent f60f293 · commit 0bc0bf5

2 files changed: +18 −2 lines changed

python/sglang/srt/models/gemma3_causal.py

Lines changed: 12 additions & 2 deletions
@@ -47,6 +47,12 @@
 from sglang.srt.utils import add_prefix, make_layers
 
 
+# Aligned with HF's implementation, using sliding window inclusive with the last token
+# SGLang assumes exclusive
+def get_attention_sliding_window_size(config):
+    return config.sliding_window - 1
+
+
 # Adapted from:
 # https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/gemma3.py
 def extract_layer_index(prefix: str) -> int:
@@ -170,7 +176,7 @@ def __init__(
             self.rope_scaling = {"rope_type": "default"}
             # FIXME(mick): idk why vllm does this
             # self.sliding_window = config.interleaved_sliding_window
-            self.sliding_window = config.sliding_window
+            self.sliding_window = get_attention_sliding_window_size(config)
         else:
             # Global attention. Use the values in config.json.
             self.rope_theta = config.rope_theta
@@ -184,6 +190,8 @@ def __init__(
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
             logit_cap=getattr(self.config, "attn_logit_softcapping", None),
+            # Module must also define `get_attention_sliding_window_size` to correctly initialize
+            # attention backend in `ForwardBatch`.
             sliding_window_size=self.sliding_window,
             prefix=add_prefix("attn", prefix),
         )
@@ -609,6 +617,9 @@ def __init__(
     def get_input_embeddings(self) -> nn.Embedding:
         return self.model.embed_tokens
 
+    def get_attention_sliding_window_size(self):
+        return get_attention_sliding_window_size(self.config)
+
     def dtype(self) -> torch.dtype:
         return next(self.parameters()).dtype
 
@@ -621,7 +632,6 @@ def forward(
         input_embeds: torch.Tensor = None,
         **kwargs,
     ) -> LogitsProcessor:
-
         hidden_states = self.model(
            input_ids, positions, forward_batch, input_embeds, **kwargs
        )
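
The only behavioral change above is the `- 1`: Hugging Face's Gemma 3 config treats `sliding_window` as a count that includes the current token, while the size handed to SGLang's attention backend counts only the earlier tokens a query may still see. The sketch below is illustrative only (the mask helpers are not code from this commit) and shows why subtracting one makes the two conventions agree:

```python
# Minimal sketch, not SGLang code: compare the two sliding-window conventions.
import torch


def hf_mask(seq_len: int, sliding_window: int) -> torch.Tensor:
    # HF convention: the window of `sliding_window` tokens includes the current
    # token, so query i attends to keys j with i - j < sliding_window.
    i = torch.arange(seq_len).unsqueeze(1)
    j = torch.arange(seq_len).unsqueeze(0)
    return (j <= i) & (i - j < sliding_window)


def exclusive_mask(seq_len: int, sliding_window_size: int) -> torch.Tensor:
    # Exclusive convention: query i attends to itself plus the previous
    # `sliding_window_size` tokens, i.e. i - j <= sliding_window_size.
    i = torch.arange(seq_len).unsqueeze(1)
    j = torch.arange(seq_len).unsqueeze(0)
    return (j <= i) & (i - j <= sliding_window_size)


if __name__ == "__main__":
    W = 4  # HF-style config.sliding_window
    assert torch.equal(hf_mask(16, W), exclusive_mask(16, W - 1))
    print("passing config.sliding_window - 1 reproduces the HF mask")
```

With `W = 4`, a token at position 10 attends to positions 7–10 under both formulations once the exclusive size is set to 3.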

python/sglang/srt/models/gemma3_mm.py

Lines changed: 6 additions & 0 deletions
@@ -268,6 +268,12 @@ def prepare_attn_masks(
     def get_input_embeddings(self) -> nn.Embedding:
         return self.language_model.get_input_embeddings()
 
+    def get_attention_sliding_window_size(self):
+        """
+        This value is used to initialize attention backends in `ForwardBatch`.
+        """
+        return self.language_model.get_attention_sliding_window_size()
+
     def get_image_feature(self, image_input: MultimodalInputs):
         """
         Projects the last hidden state from the vision model into language model space.
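
The multimodal wrapper simply forwards the query to its language model, so whichever object the runtime holds exposes the same hook. Below is a hypothetical sketch of how a runner might probe for it when setting up the attention backend; the helper and the dummy class are illustrative, not SGLang's actual `ForwardBatch` initialization:

```python
# Hypothetical sketch of consuming the hook; not SGLang's real runner code.
from typing import Optional


def resolve_sliding_window_size(model) -> Optional[int]:
    """Return the model's sliding-window size, or None if it does not define the hook."""
    hook = getattr(model, "get_attention_sliding_window_size", None)
    return hook() if callable(hook) else None


class _DummyGemma3:
    # Stand-in for the real model classes; the real ones delegate to their config.
    def get_attention_sliding_window_size(self):
        return 4096 - 1  # e.g. config.sliding_window - 1


if __name__ == "__main__":
    print(resolve_sliding_window_size(_DummyGemma3()))  # 4095
    print(resolve_sliding_window_size(object()))        # None
```

Falling back to `None` keeps models that do not define the hook on the default, non-sliding attention path.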
