Commit 3c51248

Author: niushengxiao (committed)
feat: add flashinfer decode mla operator in the attention module
1 parent a07dabe commit 3c51248

File tree: 5 files changed, +193 -11 lines changed
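Note (not part of the commit): the new decode path is opt-in, gated by the ENABLE_FLASHINFER_DECODE_MLA environment variable read in the files below. A minimal way to turn it on before the process imports these modules:

import os

# Accepted truthy spellings, per the checks in the diff: "ON", "TRUE", "1".
os.environ["ENABLE_FLASHINFER_DECODE_MLA"] = "1"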


lightllm/models/deepseek2/infer_struct.py

Lines changed: 69 additions & 0 deletions
@@ -3,20 +3,70 @@
 import numpy as np
 import torch.distributed as dist
 from lightllm.models.llama.infer_struct import LlamaInferStateInfo
+from lightllm.models.deepseek2.triton_kernel.repack_kv_index import repack_kv_index
+import flashinfer


 class Deepseek2InferStateInfo(LlamaInferStateInfo):
     def __init__(self):
         super().__init__()
         self.kv_starts = None
         self.enable_dp = os.getenv("ENABLE_DP", "0").upper() in ["ON", "TRUE", "1"]
+        self.enable_flashinfer_decode_mla = os.getenv("ENABLE_FLASHINFER_DECODE_MLA", "False").upper() in [
+            "ON",
+            "TRUE",
+            "1",
+        ]

     def init_some_extra_state(self, model, input_ids: torch.Tensor):
         super().init_some_extra_state(model, input_ids)
         # This management variable only exists when the decode stage uses the ppl optimized operator
         if not self.is_prefill:
             self.kv_starts = torch.cat([self.b_start_loc, self.b_start_loc[-1:] + self.b_seq_len[-1:]], dim=0)
             self.total_token_num_tensor = torch.sum(self.b_seq_len)
+            if self.enable_flashinfer_decode_mla:
+                self.workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(input_ids.device)
+                self.q_indptr = torch.arange(self.batch_size + 1, dtype=torch.int32).to(input_ids.device)
+                self.kv_indices = torch.empty(self.batch_size * model.max_seq_length, dtype=torch.int32).to(
+                    input_ids.device
+                )
+                repack_kv_index(
+                    self.req_manager.req_to_token_indexs,
+                    self.b_req_idx,
+                    self.b_seq_len,
+                    self.b_start_loc,
+                    self.max_len_in_batch,
+                    self.kv_indices,
+                )
+                self.wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
+                    self.workspace_buffer,
+                    backend="fa2",
+                    use_cuda_graph=True,
+                    qo_indptr=self.q_indptr,
+                    kv_indices=self.kv_indices,
+                    kv_indptr=self.kv_starts,
+                    kv_len_arr=self.b_seq_len,
+                )
+                self.head_num = model.tp_q_head_num_ * model.world_size_ if self.enable_dp else model.tp_q_head_num_
+                self.kv_lora_rank = model.kv_lora_rank
+                self.qk_rope_head_dim = model.qk_rope_head_dim
+                self.softmax_scale = model.softmax_scale
+                self.q_data_type = model.data_type
+                self.kv_data_type = model.data_type
+                self.wrapper.plan(
+                    self.q_indptr,
+                    self.kv_starts,
+                    self.kv_indices,
+                    self.b_seq_len,
+                    self.head_num,
+                    self.kv_lora_rank,
+                    self.qk_rope_head_dim,
+                    1,
+                    False,  # causal
+                    self.softmax_scale,
+                    self.q_data_type,
+                    self.kv_data_type,
+                )

         if self.is_prefill:
             self.b_kv_start_loc = self.b_seq_len.cumsum(dim=0) - self.b_seq_len
@@ -36,3 +86,22 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
             self.end_idx = self.all_end_idx[rank]

         return
+
+    def copy_for_cuda_graph(self, new_infer_state):
+        super().copy_for_cuda_graph(new_infer_state)
+        if self.enable_flashinfer_decode_mla:
+            self.wrapper.plan(
+                self.q_indptr,
+                self.kv_starts,
+                self.kv_indices,
+                self.b_seq_len,
+                self.head_num,
+                self.kv_lora_rank,
+                self.qk_rope_head_dim,
+                1,
+                False,  # causal
+                self.softmax_scale,
+                self.q_data_type,
+                self.kv_data_type,
+            )
+        return
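Note (not part of the commit): a minimal sketch of the index buffers that wrapper.plan() consumes during decode, using an assumed toy batch. q_indptr advances by one per request (one new query token each step), kv_starts is effectively the exclusive prefix sum of the per-request kv lengths, and kv_indices holds each request's cache-slot ids laid out contiguously (what repack_kv_index fills from req_to_token_indexs).

import torch

# Toy decode batch of 3 requests with kv lengths 3, 5 and 2 (assumed values).
b_seq_len = torch.tensor([3, 5, 2], dtype=torch.int32)
batch_size = b_seq_len.numel()

# One query token per request during decode.
q_indptr = torch.arange(batch_size + 1, dtype=torch.int32)         # tensor([0, 1, 2, 3])

# Exclusive prefix sum of kv lengths, mirroring self.kv_starts above.
kv_starts = torch.zeros(batch_size + 1, dtype=torch.int32)
kv_starts[1:] = torch.cumsum(b_seq_len, dim=0)                      # tensor([0, 3, 8, 10])

# Per-token cache-slot ids, one contiguous run per request; repack_kv_index
# fills this on the GPU from req_to_token_indexs.
kv_indices = torch.empty(int(b_seq_len.sum()), dtype=torch.int32)   # 10 slots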

lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py

Lines changed: 16 additions & 1 deletion
@@ -69,6 +69,11 @@ def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]):
         self.num_heads = network_config["num_attention_heads"]
         self.num_kv_heads = network_config["num_key_value_heads"]
         self.enable_opt_decoding_mha = os.getenv("ENABLE_OPT_DECODE_MHA", "False").upper() in ["ON", "TRUE", "1"]
+        self.enable_flashinfer_decode_mla = os.getenv("ENABLE_FLASHINFER_DECODE_MLA", "False").upper() in [
+            "ON",
+            "TRUE",
+            "1",
+        ]
         return

     def _bind_func(self):
@@ -369,7 +374,17 @@ def _token_gqa_decode_attention_flashdecoding(
                 infer_state.b_req_idx,
                 self.softmax_scale,
                 q.shape[-1],
-                q_nope.shape[-1],
+                self.kv_lora_rank,
+            )
+            return o_tensor
+        elif self.enable_flashinfer_decode_mla:
+            infer_state.wrapper.run(
+                q_nope,
+                q_rope,
+                kv[:, :, : -self.qk_rope_head_dim],
+                kv[:, :, -self.qk_rope_head_dim :],
+                out=o_tensor,
+                return_lse=False,
             )
             return o_tensor
         else:
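Note (not part of the commit): the two slices passed to infer_state.wrapper.run() simply split the last dimension of the paged kv buffer into the compressed-latent part and the rope part. A minimal sketch using the sizes that appear in the benchmark later in this diff (512 and 64):

import torch

kv_lora_rank, qk_rope_head_dim = 512, 64
kv = torch.randn(10, 1, kv_lora_rank + qk_rope_head_dim)   # (tokens, kv heads, ckv + kpe)

kv_nope = kv[:, :, :-qk_rope_head_dim]   # compressed latent ("ckv"), width kv_lora_rank
kv_rope = kv[:, :, -qk_rope_head_dim:]   # rotary positional part ("kpe"), width qk_rope_head_dim
assert kv_nope.shape[-1] == kv_lora_rank and kv_rope.shape[-1] == qk_rope_head_dim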

lightllm/models/deepseek2/model.py

Lines changed: 10 additions & 0 deletions
@@ -8,6 +8,7 @@
 from lightllm.common.deepseek2_mem_manager import Deepseek2MemoryManager
 from lightllm.common.deepseek2_fp8kv_mem_manager import Deepseek2FP8KVMemoryManager
 from lightllm.utils.log_utils import init_logger
+from lightllm.models.llama.yarn_rotary_utils import get_deepseek_mscale


 logger = init_logger(__name__)
@@ -37,6 +38,15 @@ def _init_some_value(self):
         self.q_lora_rank = self.config["q_lora_rank"]
         self.kv_lora_rank = self.config["kv_lora_rank"]
         self.head_dim_ = self.kv_lora_rank + self.qk_rope_head_dim
+        self.tp_q_head_num_ = self.config["num_attention_heads"] // self.world_size_
+        self.softmax_scale = (self.qk_nope_head_dim + self.qk_rope_head_dim) ** (-0.5)
+        if self.config["rope_scaling"] is not None:
+            rope_scaling = self.config["rope_scaling"]
+            mscale_all_dim = rope_scaling.get("mscale_all_dim", 0)
+            scaling_factor = rope_scaling["factor"]
+            if mscale_all_dim:
+                mscale = get_deepseek_mscale(scaling_factor, mscale_all_dim)
+                self.softmax_scale = self.softmax_scale * mscale * mscale

     def _init_custom(self):
         self._init_to_get_yarn_rotary()
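Note (not part of the commit): a worked sketch of the new softmax_scale logic. The get_deepseek_mscale body below is an assumption based on the usual yarn mscale rule (0.1 * mscale_all_dim * ln(factor) + 1.0); the commit itself imports the real helper from lightllm.models.llama.yarn_rotary_utils.

import math

def get_deepseek_mscale(scaling_factor, mscale_all_dim):
    # Assumed yarn-style mscale; stand-in for the imported helper.
    if scaling_factor <= 1:
        return 1.0
    return 0.1 * mscale_all_dim * math.log(scaling_factor) + 1.0

qk_nope_head_dim, qk_rope_head_dim = 128, 64                      # DeepSeek-V2 head dims (assumed)
softmax_scale = (qk_nope_head_dim + qk_rope_head_dim) ** (-0.5)   # 1 / sqrt(192) ~= 0.0722

rope_scaling = {"factor": 40, "mscale_all_dim": 1.0}              # example yarn config (assumed)
mscale = get_deepseek_mscale(rope_scaling["factor"], rope_scaling["mscale_all_dim"])
softmax_scale = softmax_scale * mscale * mscale                   # ~= 0.0722 * 1.369 ** 2 ~= 0.135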

lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding.py

Lines changed: 10 additions & 10 deletions
@@ -165,7 +165,7 @@ def _fwd_kernel_calcu_index_and_block_seq(
     import flashinfer
     import lightllm_ppl_mla

-    Z, N_CTX, H, D_HEAD, ROPE_HEAD = 200, 16384, 16, 512, 64
+    Z, N_CTX, H, D_HEAD, ROPE_HEAD = 10, 1024, 16, 512, 64
     dtype = torch.bfloat16
     sm_scale = 1.0 / ((D_HEAD + ROPE_HEAD) ** 0.5)
     q_nope = torch.randn((Z, H, D_HEAD), dtype=dtype, device="cuda")
@@ -181,11 +181,11 @@ def _fwd_kernel_calcu_index_and_block_seq(
     b_start_loc = torch.arange(Z).cuda().int() * N_CTX
     b_start_loc[0] = 0
     b_req_idx = torch.arange(Z).cuda().int()
-    req_to_token_indexs = torch.arange(Z * N_CTX, dtype=torch.int32).cuda().view(req_to_token_indexs.shape)
     kv_starts = torch.cat([b_start_loc, b_start_loc[-1:] + b_seq_len[-1:]], dim=0)

     o = torch.zeros((Z, H, D_HEAD), dtype=dtype, device="cuda")
     o1 = torch.zeros((Z, H, D_HEAD), dtype=dtype, device="cuda")
+    o2 = torch.zeros((Z, H, D_HEAD), dtype=dtype, device="cuda")

     infer_state = Deepseek2InferStateInfo()
     infer_state.batch_size = Z
@@ -212,7 +212,6 @@ def _fwd_kernel_calcu_index_and_block_seq(
         sm_scale,
         o,
     )
-    fn1()

     q = torch.cat([q_nope, q_rope], dim=-1)
     fn2 = lambda: lightllm_ppl_mla.decode_mla(
@@ -226,17 +225,18 @@ def _fwd_kernel_calcu_index_and_block_seq(
         D_HEAD + ROPE_HEAD,
         D_HEAD,
     )
-    fn2()

     batch_size = Z
     head_dim_ckv = D_HEAD
     head_dim_kpe = ROPE_HEAD
     num_heads = H
     page_size = 1
     workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0)
-    q_indptr = torch.arange(0, batch_size + 1).to(0).int()
+    q_indptr = torch.arange(batch_size + 1).to(0).int()
     kv_indptr = infer_state.kv_starts
     kv_indices = torch.arange(Z * N_CTX).cuda().int()
+    for b, sl, start in zip(b_req_idx, b_seq_len, b_start_loc):
+        kv_indices[start : start + sl] = req_to_token_indexs[b][:sl]
     kv_lens = b_seq_len
     wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
         workspace_buffer,
@@ -261,15 +261,15 @@ def _fwd_kernel_calcu_index_and_block_seq(
         q_nope.dtype,
         kv.dtype,
     )
-    o2 = wrapper.run(q_nope, q_rope, kv_nope, kv_rope, return_lse=False)
-    fn3 = lambda: wrapper.run(q_nope, q_rope, kv_nope, kv_rope, return_lse=False)
+    fn3 = lambda: wrapper.run(q_nope, q_rope, kv_nope, kv_rope, out=o2, return_lse=False)

-    cos_sim1 = F.cosine_similarity(o, o1).mean()
-    cos_sim2 = F.cosine_similarity(o, o2).mean()
-    print(cos_sim1, cos_sim2)
     ms1 = triton.testing.do_bench_cudagraph(fn1)
     ms2 = triton.testing.do_bench_cudagraph(fn2)
     ms3 = triton.testing.do_bench_cudagraph(fn3)
     print(ms1)
     print(ms2)
     print(ms3)
+
+    cos_sim1 = F.cosine_similarity(o, o1).mean()
+    cos_sim2 = F.cosine_similarity(o, o2).mean()
+    print(cos_sim1, cos_sim2)
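Note (not part of the commit): the benchmark wraps each kernel in a zero-argument lambda that writes into a preallocated output (presumably why the flashinfer call now passes out=o2), because triton.testing.do_bench_cudagraph captures the callable into a CUDA graph and replays it. A minimal sketch of that pattern with a stand-in workload:

import torch
import triton

x = torch.randn(4096, 4096, device="cuda")
out = torch.empty_like(x)

# Zero-arg callable writing into a preallocated buffer, suitable for graph capture.
fn = lambda: torch.add(x, 1.0, out=out)
ms = triton.testing.do_bench_cudagraph(fn)
print(ms)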
lightllm/models/deepseek2/triton_kernel/repack_kv_index.py

Lines changed: 88 additions & 0 deletions
@@ -0,0 +1,88 @@
+import torch
+
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _fwd_kernel_repack_kv_index(
+    kv_index,
+    req_index,
+    out_kv_index,
+    seq_len,
+    start_loc,
+    kv_stride_h,
+    SEQ_BLOCK: tl.constexpr,
+):
+    cur_batch = tl.program_id(0)
+    start_seq_n = tl.program_id(1)
+
+    cur_batch_seq_len = tl.load(seq_len + cur_batch)
+    cur_batch_req_idx = tl.load(req_index + cur_batch)
+    cur_batch_start_loc = tl.load(start_loc + cur_batch)
+
+    offs_seq = start_seq_n * SEQ_BLOCK + tl.arange(0, SEQ_BLOCK)
+    block_end_loc = tl.minimum((start_seq_n + 1) * SEQ_BLOCK, cur_batch_seq_len)
+    kv_index_data = tl.load(
+        kv_index + kv_stride_h * cur_batch_req_idx + offs_seq,
+        mask=offs_seq < block_end_loc,
+        other=0,
+    )
+    out_kv_index_ptr = out_kv_index + cur_batch_start_loc + offs_seq
+    tl.store(out_kv_index_ptr, kv_index_data, mask=offs_seq < block_end_loc)
+    return
+
+
+@torch.no_grad()
+def repack_kv_index(kv_index, req_index, seq_len, start_loc, max_seq_len, out_kv_index):
+    batch_size = req_index.shape[0]
+    BLOCK = 64
+    grid = (
+        batch_size,
+        triton.cdiv(max_seq_len, BLOCK),
+    )
+
+    _fwd_kernel_repack_kv_index[grid](
+        kv_index,
+        req_index,
+        out_kv_index,
+        seq_len,
+        start_loc,
+        kv_index.stride(0),
+        SEQ_BLOCK=BLOCK,
+        num_warps=8,
+        num_stages=1,
+    )
+    return
+
+
+def repack_kv_ref(req_to_token_indexs, b_req_idx, b_seq_len, b_start_loc, output):
+    for b, sl, start in zip(b_req_idx, b_seq_len, b_start_loc):
+        output[start : start + sl] = req_to_token_indexs[b][:sl]
+
+
+if __name__ == "__main__":
+    import torch.nn.functional as F
+
+    BATCH, MAX_SEQ_LEN = 10, 1024
+    rand_idx = torch.randperm(2 * MAX_SEQ_LEN * BATCH).cuda().int()
+    b_req_idx = torch.randperm(BATCH).cuda().int()
+    b_seq_len = torch.randint(1, MAX_SEQ_LEN, (BATCH,)).cuda().int()
+    req_to_token_indexs = torch.zeros((2 * BATCH, 2 * MAX_SEQ_LEN)).cuda().int()
+    b_start_loc = (
+        torch.cat([torch.zeros([1], device=b_seq_len.device, dtype=b_seq_len.dtype), b_seq_len[0:-1].cumsum(0)])
+        .cuda()
+        .int()
+    )
+
+    output = torch.zeros((b_seq_len.sum(),)).cuda().int()
+    ref = torch.zeros((b_seq_len.sum(),)).cuda().int()
+    for b, sl, start in zip(b_req_idx, b_seq_len, b_start_loc):
+        req_to_token_indexs[b][:sl] = rand_idx[start : start + sl]
+
+    fn1 = lambda: repack_kv_ref(req_to_token_indexs, b_req_idx, b_seq_len, b_start_loc, ref)
+    fn2 = lambda: repack_kv_index(req_to_token_indexs, b_req_idx, b_seq_len, b_start_loc, MAX_SEQ_LEN, output)
+    ms1 = triton.testing.do_bench(fn1)
+    ms2 = triton.testing.do_bench_cudagraph(fn2)
+    print(ms1, ms2)
+    assert torch.allclose(output.float(), ref.float())
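Note (not part of the commit): repack_kv_index gathers each request's cache-slot ids from its row of req_to_token_indexs into one contiguous output buffer, starting at that request's b_start_loc. A tiny CPU illustration with assumed values:

import torch

req_to_token_indexs = torch.tensor([[10, 11, 12, 0], [20, 21, 0, 0]], dtype=torch.int32)
b_req_idx = torch.tensor([1, 0], dtype=torch.int32)    # request 1 comes first in the batch
b_seq_len = torch.tensor([2, 3], dtype=torch.int32)
b_start_loc = torch.tensor([0, 2], dtype=torch.int32)

out = torch.zeros(int(b_seq_len.sum()), dtype=torch.int32)
for req, sl, start in zip(b_req_idx, b_seq_len, b_start_loc):
    out[start : start + sl] = req_to_token_indexs[req][:sl]
print(out)  # tensor([20, 21, 10, 11, 12], dtype=torch.int32)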
