
Commit d654659

add entrance and solve conflict
1 parent 47e1b84 commit d654659

File tree

3 files changed: +52 -10 lines changed


python/sglang/srt/layers/attention/flashattention_backend.py

Lines changed: 43 additions & 9 deletions
@@ -59,6 +59,7 @@ def __init__(
         self.device = model_runner.device
         self.decode_cuda_graph_metadata = {}
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.page_size = model_runner.page_size
         self.use_mla = (
             model_runner.model_config.attention_arch == AttentionArch.MLA
         ) and (not global_server_args_dict["disable_mla"])
@@ -83,6 +84,17 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
         metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
             forward_batch.req_pool_indices, : metadata.max_seq_len_k
         ]
+
+        # Precompute strided indices
+        # [0, page_size, 2 * page_size, ...]
+        if self.page_size > 1:
+            self.strided_indices = torch.arange(
+                0, metadata.page_table.shape[1], self.page_size, device=self.device
+            )
+            metadata.page_table = (
+                metadata.page_table[:, self.strided_indices] // self.page_size
+            )
+
         if forward_batch.forward_mode == ForwardMode.DECODE:
             # Precompute cumulative sequence lengths
             metadata.cu_seqlens_q = torch.arange(
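
A note on the hunk above: req_to_token maps every token position of a request to a flat KV slot index, so once pages hold more than one token the table repeats the same page page_size times. Keeping every page_size-th column and integer-dividing by page_size collapses it to one page index per page, which is the layout the paged kernel expects. A minimal standalone sketch with toy values (assumed for illustration, not code from this commit):

import torch

# Toy sketch, not part of the commit: page_size = 4, one request whose 8 token
# slots live in KV pages 5 and 9 (flat slot indices 20-23 and 36-39).
page_size = 4
page_table = torch.tensor([[20, 21, 22, 23, 36, 37, 38, 39]])

# Keep every page_size-th column, then convert slot indices to page indices.
strided_indices = torch.arange(0, page_table.shape[1], page_size)  # tensor([0, 4])
paged_table = page_table[:, strided_indices] // page_size
print(paged_table)  # tensor([[5, 9]]) -- one entry per page instead of per token
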
@@ -141,16 +153,24 @@ def forward_extend(
             else (-1, -1)
         )
 
+        page_table = metadata.page_table
+
         # # Use Flash Attention for prefill
         if not self.use_mla:
             # Do multi-head attention
             kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
             key_cache, value_cache = kv_cache[0], kv_cache[1]
+            key_cache = key_cache.view(
+                -1, self.page_size, layer.tp_k_head_num, layer.head_dim
+            )
+            value_cache = value_cache.view(
+                -1, self.page_size, layer.tp_v_head_num, layer.head_dim
+            )
             o = flash_attn_with_kvcache(
                 q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-                k_cache=key_cache.unsqueeze(1),
-                v_cache=value_cache.unsqueeze(1),
-                page_table=metadata.page_table,
+                k_cache=key_cache,
+                v_cache=value_cache,
+                page_table=page_table,
                 cache_seqlens=metadata.cache_seqlens_int32,
                 cu_seqlens_q=metadata.cu_seqlens_q,
                 cu_seqlens_k_new=metadata.cu_seqlens_k,
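
The unsqueeze(1) calls removed in this hunk gave the KV cache a dummy page axis of length 1; with a real page size, the flat per-slot buffer is instead viewed as (num_pages, page_size, num_kv_heads, head_dim). A shape-only sketch with made-up dimensions (illustration, not code from the commit):

import torch

# Made-up sizes for illustration only.
num_token_slots, num_kv_heads, head_dim, page_size = 64, 2, 8, 4
key_cache = torch.randn(num_token_slots, num_kv_heads, head_dim)

# Old layout (page_size == 1): a dummy page axis of length 1.
print(key_cache.unsqueeze(1).shape)  # torch.Size([64, 1, 2, 8])

# New layout: group consecutive slots into pages of page_size tokens.
print(key_cache.view(-1, page_size, num_kv_heads, head_dim).shape)  # torch.Size([16, 4, 2, 8])
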
@@ -176,7 +196,7 @@ def forward_extend(
                 k_cache=k_rope.unsqueeze(1),
                 v_cache=c_kv.unsqueeze(1),
                 qv=q_nope,
-                page_table=metadata.page_table,
+                page_table=page_table,
                 cache_seqlens=metadata.cache_seqlens_int32,
                 cu_seqlens_q=metadata.cu_seqlens_q,
                 cu_seqlens_k_new=metadata.cu_seqlens_k,
@@ -231,22 +251,30 @@ def forward_decode(
             else (-1, -1)
         )
 
+        page_table = metadata.page_table
+
         if not self.use_mla:
             # Do multi-head attention
 
             # Get KV cache
             kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
             key_cache, value_cache = kv_cache[0], kv_cache[1]
+            key_cache = key_cache.view(
+                -1, self.page_size, layer.tp_k_head_num, layer.head_dim
+            )
+            value_cache = value_cache.view(
+                -1, self.page_size, layer.tp_v_head_num, layer.head_dim
+            )
 
             # Pre-reshape query tensor
             q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
 
             # Run attention with precomputed values
             o = flash_attn_with_kvcache(
                 q=q_reshaped,
-                k_cache=key_cache.unsqueeze(1),
-                v_cache=value_cache.unsqueeze(1),
-                page_table=metadata.page_table,
+                k_cache=key_cache,
+                v_cache=value_cache,
+                page_table=page_table,
                 cache_seqlens=metadata.cache_seqlens_int32,
                 cu_seqlens_q=metadata.cu_seqlens_q,
                 cu_seqlens_k_new=metadata.cu_seqlens_k,
@@ -273,7 +301,7 @@ def forward_decode(
                 k_cache=k_rope.unsqueeze(1),
                 v_cache=c_kv.unsqueeze(1),
                 qv=q_nope,
-                page_table=metadata.page_table,
+                page_table=page_table,
                 cache_seqlens=metadata.cache_seqlens_int32,
                 cu_seqlens_q=metadata.cu_seqlens_q,
                 cu_seqlens_k_new=metadata.cu_seqlens_k,
@@ -300,7 +328,13 @@ def init_cuda_graph_state(self, max_bs: int):
         self.decode_cuda_graph_metadata = {
             # Page table for token mapping (batch_size, max_context_len)
             "page_table": torch.zeros(
-                max_bs, self.max_context_len, dtype=torch.int32, device=self.device
+                max_bs,
+                (self.max_context_len + self.page_size - 1) // self.page_size,
+                dtype=torch.int32,
+                device=self.device,
+            ),
+            "strided_indices": torch.arange(
+                0, self.max_context_len, self.page_size, device=self.device
             ),
         }
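
The CUDA-graph buffers above are sized per page rather than per token: the page table width is the ceiling of max_context_len / page_size, written with integer arithmetic, and strided_indices has that same length. A quick check with assumed toy numbers (not from the commit):

import torch

# Assumed toy values.
max_context_len, page_size = 10, 4

num_pages = (max_context_len + page_size - 1) // page_size
strided_indices = torch.arange(0, max_context_len, page_size)

print(num_pages)        # 3 -- 10 tokens need 3 pages of 4 slots each
print(strided_indices)  # tensor([0, 4, 8]) -- also 3 entries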

python/sglang/srt/model_executor/model_runner.py

Lines changed: 5 additions & 1 deletion
@@ -230,6 +230,10 @@ def model_specific_adjustment(self):
         elif server_args.enable_flashmla:
             logger.info("MLA optimization is turned on. Use flashmla decode.")
             server_args.attention_backend = "flashmla"
+        elif server_args.attention_backend == "fa3":
+            logger.info(
+                f"MLA optimization is turned on. Use flash attention 3 backend."
+            )
         else:
             logger.info("MLA optimization is turned on. Use triton backend.")
             server_args.attention_backend = "triton"
@@ -879,7 +883,7 @@ def init_attention_backend(self):
                 "Please use `--attention-backend flashinfer`."
             )
             logger.warning(
-                "FlashAttention v3 Backend is in Beta. Multimodal, Page > 1, FP8, MLA and Speculative Decoding are not supported."
+                "FlashAttention v3 Backend is in Beta. Multimodal, FP8, and Speculative Decoding are not supported."
             )
             from sglang.srt.layers.attention.flashattention_backend import (
                 FlashAttentionBackend,

python/sglang/srt/models/deepseek_v2.py

Lines changed: 4 additions & 0 deletions
@@ -655,6 +655,7 @@ def __init__(
         self.flashinfer_mla_disable_ragged = global_server_args_dict[
             "flashinfer_mla_disable_ragged"
         ]
+        self.attention_backend = global_server_args_dict["attention_backend"]
         self.rocm_fused_decode_mla = os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"
 
     def no_absorb(self, forward_batch: ForwardBatch) -> bool:
@@ -667,6 +668,9 @@ def no_absorb(self, forward_batch: ForwardBatch) -> bool:
                 and not forward_batch.forward_mode.is_draft_extend()
                 and sum(forward_batch.extend_prefix_lens_cpu) == 0
             )
+        elif self.attention_backend == "fa3":
+            # Flash Attention: Keep absorbing for all extend/decode
+            return False
         else:
             # Triton: Use normal computation for prefill and use weight absorption for extend/decode
             return (
