Commit 47e1b84

implement mla for fa3
1 parent 6dea5c9 commit 47e1b84

File tree: 1 file changed, +133 -91 lines changed

python/sglang/srt/layers/attention/flashattention_backend.py (133 additions, 91 deletions)
@@ -13,7 +13,9 @@

 import torch

+from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode

 if TYPE_CHECKING:
@@ -57,7 +59,9 @@ def __init__(
         self.device = model_runner.device
         self.decode_cuda_graph_metadata = {}
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
-        self.page_size = model_runner.page_size
+        self.use_mla = (
+            model_runner.model_config.attention_arch == AttentionArch.MLA
+        ) and (not global_server_args_dict["disable_mla"])

     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Initialize forward metadata to cache repetitive calculations."""
@@ -79,17 +83,6 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
         metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
             forward_batch.req_pool_indices, : metadata.max_seq_len_k
         ]
-
-        # Precompute strided indices
-        # [0, page_size, 2 * page_size, ...]
-        if self.page_size > 1:
-            self.strided_indices = torch.arange(
-                0, metadata.page_table.shape[1], self.page_size, device=self.device
-            )
-            metadata.page_table = (
-                metadata.page_table[:, self.strided_indices] // self.page_size
-            )
-
         if forward_batch.forward_mode == ForwardMode.DECODE:
             # Precompute cumulative sequence lengths
             metadata.cu_seqlens_q = torch.arange(
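
The strided-indices logic removed in this hunk converted the per-token page table into per-page entries when page_size > 1. After this change the backend effectively runs the FA3 kernel with a page size of 1: the page table keeps raw token slots, and the unsqueeze(1) calls in the forward_extend/forward_decode hunks below reshape the flat KV buffers into the (num_pages, page_size, heads, head_dim) layout with page_size == 1. A minimal shape-only sketch of that layout, with made-up sizes (not taken from the diff):

import torch

# Hypothetical sizes, for illustration only.
num_slots, num_heads, head_dim = 1024, 8, 128
batch_size, max_seq_len_k = 2, 64

# Flat per-token KV buffer, one slot per cached token.
key_cache = torch.zeros(num_slots, num_heads, head_dim)

# Treat every slot as its own page so the paged kernel can index tokens directly.
key_cache_paged = key_cache.unsqueeze(1)  # (num_slots, 1, num_heads, head_dim)

# With page_size == 1 the page table holds token slot indices as-is,
# so no strided "// page_size" rewrite is needed.
page_table = torch.randint(0, num_slots, (batch_size, max_seq_len_k), dtype=torch.int32)

print(key_cache_paged.shape)  # torch.Size([1024, 1, 8, 128])
print(page_table.shape)       # torch.Size([2, 64])
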
@@ -117,23 +110,28 @@ def forward_extend(
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
-        cache_loc = (
-            forward_batch.out_cache_loc
-            if not layer.is_cross_attention
-            else forward_batch.encoder_out_cache_loc
-        )

-        if k is not None:
-            assert v is not None
-            if save_kv_cache:
+        if k is not None and v is not None and save_kv_cache:
+            cache_loc = (
+                forward_batch.out_cache_loc
+                if not layer.is_cross_attention
+                else forward_batch.encoder_out_cache_loc
+            )
+            if not self.use_mla:
                 forward_batch.token_to_kv_pool.set_kv_buffer(
                     layer, cache_loc, k, v, layer.k_scale, layer.v_scale
                 )
+            else:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer,
+                    cache_loc,
+                    k,
+                    v,
+                )

         # Use precomputed metadata
         metadata = self.forward_metadata

-        # # Use Flash Attention for prefill
         # Calculate window size (can be moved to metadata if layer properties don't change)
         # we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1
         # here is two side inclusive
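
A small standalone sketch of the KV-write branching added above, using stand-in pool classes (hypothetical, not the real sglang classes; they only mirror the two call shapes in this hunk): the regular pool is called with separate K/V buffers plus scale factors, while the MLA-path call omits the scale arguments.

import torch

class MHAPoolStub:
    """Stand-in for a regular token-to-KV pool (illustrative only)."""
    def set_kv_buffer(self, layer, loc, k, v, k_scale=None, v_scale=None):
        print("MHA write:", k.shape, v.shape, k_scale, v_scale)

class MLAPoolStub:
    """Stand-in for an MLA latent-KV pool; call shape mirrors the else-branch above."""
    def set_kv_buffer(self, layer, loc, k, v):
        print("MLA write:", k.shape, v.shape)

use_mla = True
pool = MLAPoolStub() if use_mla else MHAPoolStub()
k = torch.zeros(4, 1, 576)  # e.g. latent plus rope slice per token (illustrative sizes)
v = torch.zeros(4, 1, 512)

if not use_mla:
    pool.set_kv_buffer(None, None, k, v, None, None)
else:
    pool.set_kv_buffer(None, None, k, v)
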
@@ -142,36 +140,55 @@ def forward_extend(
             if layer.sliding_window_size is not None
             else (-1, -1)
         )
-        kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
-        key_cache, value_cache = kv_cache[0], kv_cache[1]

-        key_cache = key_cache.view(
-            -1, self.page_size, layer.tp_k_head_num, layer.head_dim
-        )
-        value_cache = value_cache.view(
-            -1, self.page_size, layer.tp_v_head_num, layer.head_dim
-        )
-
-        page_table = metadata.page_table
-
-        o = flash_attn_with_kvcache(
-            q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-            k_cache=key_cache,
-            v_cache=value_cache,
-            page_table=page_table,
-            cache_seqlens=metadata.cache_seqlens_int32,
-            cu_seqlens_q=metadata.cu_seqlens_q,
-            cu_seqlens_k_new=metadata.cu_seqlens_k,
-            max_seqlen_q=metadata.max_seq_len_q,
-            softmax_scale=layer.scaling,
-            causal=True,
-            window_size=window_size,
-            softcap=layer.logit_cap,
-            k_descale=layer.k_scale,
-            v_descale=layer.v_scale,
-        )
+        # # Use Flash Attention for prefill
+        if not self.use_mla:
+            # Do multi-head attention
+            kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
+            key_cache, value_cache = kv_cache[0], kv_cache[1]
+            o = flash_attn_with_kvcache(
+                q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                k_cache=key_cache.unsqueeze(1),
+                v_cache=value_cache.unsqueeze(1),
+                page_table=metadata.page_table,
+                cache_seqlens=metadata.cache_seqlens_int32,
+                cu_seqlens_q=metadata.cu_seqlens_q,
+                cu_seqlens_k_new=metadata.cu_seqlens_k,
+                max_seqlen_q=metadata.max_seq_len_q,
+                softmax_scale=layer.scaling,
+                causal=True,
+                window_size=window_size,
+                softcap=layer.logit_cap,
+                k_descale=layer.k_scale,
+                v_descale=layer.v_scale,
+            )
+        else:
+            # Do absorbed multi-latent attention
+            kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            c_kv = kv_cache[:, :, : layer.v_head_dim]
+            k_rope = kv_cache[:, :, layer.v_head_dim :]
+
+            q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+            q_nope = q_all[:, :, : layer.v_head_dim]
+            q_rope = q_all[:, :, layer.v_head_dim :]
+            o = flash_attn_with_kvcache(
+                q=q_rope,
+                k_cache=k_rope.unsqueeze(1),
+                v_cache=c_kv.unsqueeze(1),
+                qv=q_nope,
+                page_table=metadata.page_table,
+                cache_seqlens=metadata.cache_seqlens_int32,
+                cu_seqlens_q=metadata.cu_seqlens_q,
+                cu_seqlens_k_new=metadata.cu_seqlens_k,
+                max_seqlen_q=metadata.max_seq_len_q,
+                softmax_scale=layer.scaling,
+                causal=True,
+                softcap=layer.logit_cap,
+                k_descale=layer.k_scale,
+                v_descale=layer.v_scale,
+            )

-        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

     def forward_decode(
         self,
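
A shape-only sketch of the absorbed multi-latent attention path in the hunk above, under assumed DeepSeek-like dimensions (the concrete numbers are illustrative, not from the diff): the cached latent of width v_head_dim serves as both the key content and the value, the remaining head_dim - v_head_dim channels carry the rotary part of the key, and the kernel's output therefore has v_head_dim channels per head, which is why forward_extend now returns o.view(-1, tp_q_head_num * v_head_dim).

import torch

# Illustrative dimensions only (roughly DeepSeek-style); not taken from the diff.
num_tokens, tp_q_head_num = 4, 16
v_head_dim, rope_dim = 512, 64
head_dim = v_head_dim + rope_dim              # 576 = latent width + rope width

# One latent KV entry per token: [compressed KV latent | rope slice of K].
kv_cache = torch.zeros(num_tokens, 1, head_dim)
c_kv = kv_cache[:, :, :v_head_dim]            # shared latent: K content and V
k_rope = kv_cache[:, :, v_head_dim:]          # positional (rope) part of K

# Queries are split the same way; q_nope attends against c_kv, q_rope against k_rope.
q_all = torch.zeros(num_tokens, tp_q_head_num, head_dim)
q_nope = q_all[:, :, :v_head_dim]             # passed to the kernel as qv=
q_rope = q_all[:, :, v_head_dim:]             # passed to the kernel as q=

# The attention output inherits the value width, i.e. v_head_dim per head.
o = torch.zeros(num_tokens, tp_q_head_num, v_head_dim)
print(o.view(-1, tp_q_head_num * v_head_dim).shape)  # torch.Size([4, 8192])
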
@@ -190,18 +207,21 @@ def forward_decode(
             if not layer.is_cross_attention
             else forward_batch.encoder_out_cache_loc
         )
-        forward_batch.token_to_kv_pool.set_kv_buffer(
-            layer, cache_loc, k, v, layer.k_scale, layer.v_scale
-        )
+        if not self.use_mla:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+            )
+        else:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer,
+                cache_loc,
+                k,
+                v,
+            )

-        # Get KV cache
-        kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
-        key_cache, value_cache = kv_cache[0], kv_cache[1]
         # Use precomputed metadata
         metadata = self.forward_metadata

-        # Pre-reshape query tensor
-        q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
         # Calculate window size (can be moved to metadata if layer properties don't change)
         # we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1
         # here is two side inclusive
@@ -210,33 +230,62 @@ def forward_decode(
             if layer.sliding_window_size is not None
             else (-1, -1)
         )
-        # Run attention with precomputed values
-        key_cache = key_cache.view(
-            -1, self.page_size, layer.tp_k_head_num, layer.head_dim
-        )
-        value_cache = value_cache.view(
-            -1, self.page_size, layer.tp_v_head_num, layer.head_dim
-        )

-        page_table = metadata.page_table
-
-        o = flash_attn_with_kvcache(
-            q=q_reshaped,
-            k_cache=key_cache,
-            v_cache=value_cache,
-            page_table=page_table,
-            cache_seqlens=metadata.cache_seqlens_int32,
-            cu_seqlens_q=metadata.cu_seqlens_q,
-            cu_seqlens_k_new=metadata.cu_seqlens_k,
-            max_seqlen_q=1,
-            softmax_scale=layer.scaling,
-            causal=True,
-            window_size=window_size,
-            softcap=layer.logit_cap,
-            k_descale=layer.k_scale,
-            v_descale=layer.v_scale,
-        )
-        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+        if not self.use_mla:
+            # Do multi-head attention
+
+            # Get KV cache
+            kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
+            key_cache, value_cache = kv_cache[0], kv_cache[1]
+
+            # Pre-reshape query tensor
+            q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+
+            # Run attention with precomputed values
+            o = flash_attn_with_kvcache(
+                q=q_reshaped,
+                k_cache=key_cache.unsqueeze(1),
+                v_cache=value_cache.unsqueeze(1),
+                page_table=metadata.page_table,
+                cache_seqlens=metadata.cache_seqlens_int32,
+                cu_seqlens_q=metadata.cu_seqlens_q,
+                cu_seqlens_k_new=metadata.cu_seqlens_k,
+                max_seqlen_q=1,
+                softmax_scale=layer.scaling,
+                causal=True,
+                window_size=window_size,
+                softcap=layer.logit_cap,
+                k_descale=layer.k_scale,
+                v_descale=layer.v_scale,
+            )
+        else:
+            # Do absorbed multi-latent attention
+            kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            c_kv = kv_cache[:, :, : layer.v_head_dim]
+            k_rope = kv_cache[:, :, layer.v_head_dim :]
+
+            q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+            q_nope = q_all[:, :, : layer.v_head_dim]
+            q_rope = q_all[:, :, layer.v_head_dim :]
+
+            o = flash_attn_with_kvcache(
+                q=q_rope,
+                k_cache=k_rope.unsqueeze(1),
+                v_cache=c_kv.unsqueeze(1),
+                qv=q_nope,
+                page_table=metadata.page_table,
+                cache_seqlens=metadata.cache_seqlens_int32,
+                cu_seqlens_q=metadata.cu_seqlens_q,
+                cu_seqlens_k_new=metadata.cu_seqlens_k,
+                max_seqlen_q=1,
+                softmax_scale=layer.scaling,
+                causal=True,
+                softcap=layer.logit_cap,
+                k_descale=layer.k_scale,
+                v_descale=layer.v_scale,
+            )
+
+        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

     def init_cuda_graph_state(self, max_bs: int):
         """Initialize CUDA graph state for the attention backend.
@@ -251,13 +300,7 @@ def init_cuda_graph_state(self, max_bs: int):
         self.decode_cuda_graph_metadata = {
             # Page table for token mapping (batch_size, max_context_len)
             "page_table": torch.zeros(
-                max_bs,
-                (self.max_context_len + self.page_size - 1) // self.page_size,
-                dtype=torch.int32,
-                device=self.device,
-            ),
-            "strided_indices": torch.arange(
-                0, self.max_context_len, self.page_size, device=self.device
+                max_bs, self.max_context_len, dtype=torch.int32, device=self.device
             ),
         }
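
With the page-size bucketing gone, the preallocated CUDA-graph page table now has max_context_len columns instead of ceil(max_context_len / page_size). A back-of-the-envelope estimate of its footprint under assumed values (purely illustrative, not from the diff):

# Hypothetical sizes: int32 page table preallocated once for CUDA graph capture.
max_bs = 160
max_context_len = 131072
bytes_per_entry = 4  # torch.int32

table_bytes = max_bs * max_context_len * bytes_per_entry
print(f"{table_bytes / 2**20:.0f} MiB")  # 160 * 131072 * 4 B = 80 MiB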

@@ -286,7 +329,6 @@ def init_forward_metadata_capture_cuda_graph(
         metadata.page_table = self.decode_cuda_graph_metadata["page_table"][
             req_pool_indices, :
         ]
-
         if forward_mode == ForwardMode.DECODE:
             # Precompute cumulative sequence lengths
             metadata.cu_seqlens_q = torch.arange(
