 
 import torch
 
+from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 
 if TYPE_CHECKING:
@@ -57,7 +59,9 @@ def __init__(
         self.device = model_runner.device
         self.decode_cuda_graph_metadata = {}
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
-        self.page_size = model_runner.page_size
+        self.use_mla = (
+            model_runner.model_config.attention_arch == AttentionArch.MLA
+        ) and (not global_server_args_dict["disable_mla"])

     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Initialize forward metadata to cache repetitive calculations."""
@@ -79,17 +83,6 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
         metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
             forward_batch.req_pool_indices, : metadata.max_seq_len_k
         ]
-
-        # Precompute strided indices
-        # [0, page_size, 2 * page_size, ...]
-        if self.page_size > 1:
-            self.strided_indices = torch.arange(
-                0, metadata.page_table.shape[1], self.page_size, device=self.device
-            )
-            metadata.page_table = (
-                metadata.page_table[:, self.strided_indices] // self.page_size
-            )
-
         if forward_batch.forward_mode == ForwardMode.DECODE:
             # Precompute cumulative sequence lengths
             metadata.cu_seqlens_q = torch.arange(
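With the strided-index path removed, the backend effectively assumes a token-granular KV layout (page size 1): `req_to_token` is passed to the kernel directly as the page table, and the later kernel calls expose the flat cache as one-token pages via `unsqueeze(1)`. A shape-only sketch of that view, with illustrative sizes that are not taken from the diff:

```python
# Shape-only sketch: with a page size of 1, a flat KV cache of
# (num_tokens, num_heads, head_dim) becomes a paged cache of
# (num_pages, page_size=1, num_heads, head_dim) via unsqueeze, and the
# req_to_token table doubles as the page table without the strided
# downsampling that this hunk deletes.
import torch

num_tokens, num_heads, head_dim = 1024, 8, 128
key_cache = torch.randn(num_tokens, num_heads, head_dim)

paged_key_cache = key_cache.unsqueeze(1)
assert paged_key_cache.shape == (num_tokens, 1, num_heads, head_dim)

# Per-request token indices serve as page indices when page_size == 1.
max_seq_len_k = 16
req_to_token = torch.randint(0, num_tokens, (4, max_seq_len_k), dtype=torch.int32)
page_table = req_to_token  # no // page_size, no strided_indices
```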
@@ -117,23 +110,28 @@ def forward_extend(
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
-        cache_loc = (
-            forward_batch.out_cache_loc
-            if not layer.is_cross_attention
-            else forward_batch.encoder_out_cache_loc
-        )

-        if k is not None:
-            assert v is not None
-            if save_kv_cache:
+        if k is not None and v is not None and save_kv_cache:
+            cache_loc = (
+                forward_batch.out_cache_loc
+                if not layer.is_cross_attention
+                else forward_batch.encoder_out_cache_loc
+            )
+            if not self.use_mla:
                 forward_batch.token_to_kv_pool.set_kv_buffer(
                     layer, cache_loc, k, v, layer.k_scale, layer.v_scale
                 )
+            else:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer,
+                    cache_loc,
+                    k,
+                    v,
+                )

         # Use precomputed metadata
         metadata = self.forward_metadata

-        # # Use Flash Attention for prefill
         # Calculate window size (can be moved to metadata if layer properties don't change)
         # we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1
         # here is two side inclusive
@@ -142,36 +140,55 @@ def forward_extend(
             if layer.sliding_window_size is not None
             else (-1, -1)
         )
-        kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
-        key_cache, value_cache = kv_cache[0], kv_cache[1]

-        key_cache = key_cache.view(
-            -1, self.page_size, layer.tp_k_head_num, layer.head_dim
-        )
-        value_cache = value_cache.view(
-            -1, self.page_size, layer.tp_v_head_num, layer.head_dim
-        )
-
-        page_table = metadata.page_table
-
-        o = flash_attn_with_kvcache(
-            q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-            k_cache=key_cache,
-            v_cache=value_cache,
-            page_table=page_table,
-            cache_seqlens=metadata.cache_seqlens_int32,
-            cu_seqlens_q=metadata.cu_seqlens_q,
-            cu_seqlens_k_new=metadata.cu_seqlens_k,
-            max_seqlen_q=metadata.max_seq_len_q,
-            softmax_scale=layer.scaling,
-            causal=True,
-            window_size=window_size,
-            softcap=layer.logit_cap,
-            k_descale=layer.k_scale,
-            v_descale=layer.v_scale,
-        )
+        # # Use Flash Attention for prefill
+        if not self.use_mla:
+            # Do multi-head attention
+            kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
+            key_cache, value_cache = kv_cache[0], kv_cache[1]
+            o = flash_attn_with_kvcache(
+                q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                k_cache=key_cache.unsqueeze(1),
+                v_cache=value_cache.unsqueeze(1),
+                page_table=metadata.page_table,
+                cache_seqlens=metadata.cache_seqlens_int32,
+                cu_seqlens_q=metadata.cu_seqlens_q,
+                cu_seqlens_k_new=metadata.cu_seqlens_k,
+                max_seqlen_q=metadata.max_seq_len_q,
+                softmax_scale=layer.scaling,
+                causal=True,
+                window_size=window_size,
+                softcap=layer.logit_cap,
+                k_descale=layer.k_scale,
+                v_descale=layer.v_scale,
+            )
+        else:
+            # Do absorbed multi-latent attention
+            kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            c_kv = kv_cache[:, :, : layer.v_head_dim]
+            k_rope = kv_cache[:, :, layer.v_head_dim :]
+
+            q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+            q_nope = q_all[:, :, : layer.v_head_dim]
+            q_rope = q_all[:, :, layer.v_head_dim :]
+            o = flash_attn_with_kvcache(
+                q=q_rope,
+                k_cache=k_rope.unsqueeze(1),
+                v_cache=c_kv.unsqueeze(1),
+                qv=q_nope,
+                page_table=metadata.page_table,
+                cache_seqlens=metadata.cache_seqlens_int32,
+                cu_seqlens_q=metadata.cu_seqlens_q,
+                cu_seqlens_k_new=metadata.cu_seqlens_k,
+                max_seqlen_q=metadata.max_seq_len_q,
+                softmax_scale=layer.scaling,
+                causal=True,
+                softcap=layer.logit_cap,
+                k_descale=layer.k_scale,
+                v_descale=layer.v_scale,
+            )

-        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

     def forward_decode(
         self,
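The MLA branch above splits both the query and the single-head key buffer along the head dimension: the first `v_head_dim` channels form the compressed KV (`c_kv`, which also serves as the value), and the remainder is the rope part. A shape-only sketch of that split, with DeepSeek-style sizes assumed purely for illustration (`v_head_dim = 512`, rope width 64, so `head_dim = 576`):

```python
# Shape-only sketch of the head-dimension split used in the MLA branch.
# Sizes are assumptions (DeepSeek-style kv_lora_rank=512, rope dim=64),
# not values taken from the diff.
import torch

num_tokens, num_q_heads = 4, 16
v_head_dim, rope_dim = 512, 64
head_dim = v_head_dim + rope_dim  # 576

q = torch.randn(num_tokens, num_q_heads, head_dim)
q_nope, q_rope = q[:, :, :v_head_dim], q[:, :, v_head_dim:]

# The key buffer has a single latent head holding [c_kv | k_rope].
key_buffer = torch.randn(num_tokens, 1, head_dim)
c_kv, k_rope = key_buffer[:, :, :v_head_dim], key_buffer[:, :, v_head_dim:]

# q_rope is scored against k_rope, q_nope (passed as `qv`) against c_kv,
# and c_kv doubles as the value, so the per-head output width is v_head_dim.
assert q_nope.shape[-1] == c_kv.shape[-1] == v_head_dim
assert q_rope.shape[-1] == k_rope.shape[-1] == rope_dim
```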
@@ -190,18 +207,21 @@ def forward_decode(
             if not layer.is_cross_attention
             else forward_batch.encoder_out_cache_loc
         )
-        forward_batch.token_to_kv_pool.set_kv_buffer(
-            layer, cache_loc, k, v, layer.k_scale, layer.v_scale
-        )
+        if not self.use_mla:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+            )
+        else:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer,
+                cache_loc,
+                k,
+                v,
+            )

-        # Get KV cache
-        kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
-        key_cache, value_cache = kv_cache[0], kv_cache[1]
         # Use precomputed metadata
         metadata = self.forward_metadata

-        # Pre-reshape query tensor
-        q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
         # Calculate window size (can be moved to metadata if layer properties don't change)
         # we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1
         # here is two side inclusive
@@ -210,33 +230,62 @@ def forward_decode(
             if layer.sliding_window_size is not None
             else (-1, -1)
         )
-        # Run attention with precomputed values
-        key_cache = key_cache.view(
-            -1, self.page_size, layer.tp_k_head_num, layer.head_dim
-        )
-        value_cache = value_cache.view(
-            -1, self.page_size, layer.tp_v_head_num, layer.head_dim
-        )

-        page_table = metadata.page_table
-
-        o = flash_attn_with_kvcache(
-            q=q_reshaped,
-            k_cache=key_cache,
-            v_cache=value_cache,
-            page_table=page_table,
-            cache_seqlens=metadata.cache_seqlens_int32,
-            cu_seqlens_q=metadata.cu_seqlens_q,
-            cu_seqlens_k_new=metadata.cu_seqlens_k,
-            max_seqlen_q=1,
-            softmax_scale=layer.scaling,
-            causal=True,
-            window_size=window_size,
-            softcap=layer.logit_cap,
-            k_descale=layer.k_scale,
-            v_descale=layer.v_scale,
-        )
-        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+        if not self.use_mla:
+            # Do multi-head attention
+
+            # Get KV cache
+            kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
+            key_cache, value_cache = kv_cache[0], kv_cache[1]
+
+            # Pre-reshape query tensor
+            q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+
+            # Run attention with precomputed values
+            o = flash_attn_with_kvcache(
+                q=q_reshaped,
+                k_cache=key_cache.unsqueeze(1),
+                v_cache=value_cache.unsqueeze(1),
+                page_table=metadata.page_table,
+                cache_seqlens=metadata.cache_seqlens_int32,
+                cu_seqlens_q=metadata.cu_seqlens_q,
+                cu_seqlens_k_new=metadata.cu_seqlens_k,
+                max_seqlen_q=1,
+                softmax_scale=layer.scaling,
+                causal=True,
+                window_size=window_size,
+                softcap=layer.logit_cap,
+                k_descale=layer.k_scale,
+                v_descale=layer.v_scale,
+            )
+        else:
+            # Do absorbed multi-latent attention
+            kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            c_kv = kv_cache[:, :, : layer.v_head_dim]
+            k_rope = kv_cache[:, :, layer.v_head_dim :]
+
+            q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+            q_nope = q_all[:, :, : layer.v_head_dim]
+            q_rope = q_all[:, :, layer.v_head_dim :]
+
+            o = flash_attn_with_kvcache(
+                q=q_rope,
+                k_cache=k_rope.unsqueeze(1),
+                v_cache=c_kv.unsqueeze(1),
+                qv=q_nope,
+                page_table=metadata.page_table,
+                cache_seqlens=metadata.cache_seqlens_int32,
+                cu_seqlens_q=metadata.cu_seqlens_q,
+                cu_seqlens_k_new=metadata.cu_seqlens_k,
+                max_seqlen_q=1,
+                softmax_scale=layer.scaling,
+                causal=True,
+                softcap=layer.logit_cap,
+                k_descale=layer.k_scale,
+                v_descale=layer.v_scale,
+            )
+
+        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

     def init_cuda_graph_state(self, max_bs: int):
         """Initialize CUDA graph state for the attention backend.
@@ -251,13 +300,7 @@ def init_cuda_graph_state(self, max_bs: int):
         self.decode_cuda_graph_metadata = {
             # Page table for token mapping (batch_size, max_context_len)
             "page_table": torch.zeros(
-                max_bs,
-                (self.max_context_len + self.page_size - 1) // self.page_size,
-                dtype=torch.int32,
-                device=self.device,
-            ),
-            "strided_indices": torch.arange(
-                0, self.max_context_len, self.page_size, device=self.device
+                max_bs, self.max_context_len, dtype=torch.int32, device=self.device
             ),
         }
@@ -286,7 +329,6 @@ def init_forward_metadata_capture_cuda_graph(
         metadata.page_table = self.decode_cuda_graph_metadata["page_table"][
             req_pool_indices, :
         ]
-
         if forward_mode == ForwardMode.DECODE:
             # Precompute cumulative sequence lengths
             metadata.cu_seqlens_q = torch.arange(