@@ -47,6 +47,12 @@
 from sglang.srt.utils import add_prefix, make_layers


+# Aligned with HF's implementation, using sliding window inclusive with the last token
+# SGLang assumes exclusive
+def get_attention_sliding_window_size(config):
+    return config.sliding_window - 1
+
+
 # Adapted from:
 # https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/gemma3.py
 def extract_layer_index(prefix: str) -> int:
@@ -170,7 +176,7 @@ def __init__(
             self.rope_scaling = {"rope_type": "default"}
             # FIXME(mick): idk why vllm does this
             # self.sliding_window = config.interleaved_sliding_window
-            self.sliding_window = config.sliding_window
+            self.sliding_window = get_attention_sliding_window_size(config)
         else:
             # Global attention. Use the values in config.json.
             self.rope_theta = config.rope_theta
@@ -184,6 +190,8 @@ def __init__(
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
             logit_cap=getattr(self.config, "attn_logit_softcapping", None),
+            # Module must also define `get_attention_sliding_window_size` to correctly initialize
+            # attention backend in `ForwardBatch`.
             sliding_window_size=self.sliding_window,
             prefix=add_prefix("attn", prefix),
         )
@@ -609,6 +617,9 @@ def __init__(
     def get_input_embeddings(self) -> nn.Embedding:
         return self.model.embed_tokens

+    def get_attention_sliding_window_size(self):
+        return get_attention_sliding_window_size(self.config)
+
     def dtype(self) -> torch.dtype:
         return next(self.parameters()).dtype

@@ -621,7 +632,6 @@ def forward(
         input_embeds: torch.Tensor = None,
         **kwargs,
     ) -> LogitsProcessor:
-
         hidden_states = self.model(
             input_ids, positions, forward_batch, input_embeds, **kwargs
         )
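The `- 1` in `get_attention_sliding_window_size` bridges two conventions: per the comment in the diff, HF counts the current token as part of `sliding_window`, while SGLang's `sliding_window_size` is taken to count only the preceding tokens. Below is a minimal sketch (not part of the diff) illustrating that equivalence; the helper names and the exact masking logic are illustrative assumptions, not SGLang or HF internals.

def hf_visible_positions(query_pos: int, sliding_window: int) -> range:
    # HF convention: the window of `sliding_window` tokens includes the query token itself.
    start = max(0, query_pos - sliding_window + 1)
    return range(start, query_pos + 1)


def sglang_visible_positions(query_pos: int, sliding_window_size: int) -> range:
    # Assumed SGLang convention: `sliding_window_size` counts only the tokens before
    # the query token; the query token is always visible in addition.
    start = max(0, query_pos - sliding_window_size)
    return range(start, query_pos + 1)


if __name__ == "__main__":
    hf_window = 4
    # Passing hf_window - 1 under the exclusive convention reproduces HF's inclusive mask.
    for pos in range(10):
        assert list(hf_visible_positions(pos, hf_window)) == list(
            sglang_visible_positions(pos, hf_window - 1)
        )
    print(f"HF window {hf_window} == exclusive window {hf_window - 1}")

Under these assumptions, passing `config.sliding_window - 1` to the attention backend yields the same local-attention mask that HF produces with `config.sliding_window`.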