Commit 00e4de7

fix: fix a precision bug in the context_flashattention (#743)
1 parent 9313a08 commit 00e4de7
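
With prompt caching, a request's queries are packed without the cached prefix tokens while its keys/values still cover the full sequence, so the packed Q and K/V tensors start at different offsets. The kernel previously indexed Q, K/V, and the output with the single b_start_loc tensor; this commit threads a separate b_kv_start_loc through so the K/V loads use their own start offsets. Below is a minimal sketch of the offset bookkeeping, following the same recipe as the self-test added at the bottom of the kernel file (the example lengths are illustrative, not from the commit):

import torch

# Per-request lengths: b_seq_len counts cached prefix + new tokens,
# b_prompt_cache_len counts only the prefix that already has KV entries.
b_seq_len = torch.tensor([1024, 1024], dtype=torch.int32)
b_prompt_cache_len = torch.tensor([300, 0], dtype=torch.int32)

# Queries exist only for the non-cached tokens, K/V for the full sequence,
# so the packed tensors need two different sets of start offsets.
q_lens = b_seq_len - b_prompt_cache_len
b_start_loc = q_lens.cumsum(0) - q_lens           # exclusive prefix sum over Q lengths
b_kv_start_loc = b_seq_len.cumsum(0) - b_seq_len  # exclusive prefix sum over KV lengths

print(b_start_loc.tolist())     # [0, 724]
print(b_kv_start_loc.tolist())  # [0, 1024]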

File tree

2 files changed: +88 -7 lines changed

lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py
lightllm/models/deepseek2/triton_kernel/context_flashattention_nopad_with_v.py

lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py

Lines changed: 2 additions & 0 deletions
@@ -227,6 +227,7 @@ def _context_attention_kernel_with_CC(
             v,
             o_tensor.view(-1, self.tp_q_head_num_, q_nope.shape[-1]),
             infer_state.b_start_loc,
+            infer_state.b_kv_start_loc,
             infer_state.b_seq_len,
             infer_state.b_ready_cache_len,
             infer_state.max_len_in_batch,
@@ -253,6 +254,7 @@ def _context_attention_kernel_with_CC_fp8(
             v,
             o_tensor.view(-1, self.tp_q_head_num_, q_nope.shape[-1]),
             infer_state.b_start_loc,
+            infer_state.b_kv_start_loc,
             infer_state.b_seq_len,
             infer_state.b_ready_cache_len,
             infer_state.max_len_in_batch,

lightllm/models/deepseek2/triton_kernel/context_flashattention_nopad_with_v.py

Lines changed: 86 additions & 7 deletions
@@ -18,6 +18,7 @@ def _fwd_kernel_with_v(
     V,
     sm_scale,
     B_Start_Loc,
+    B_Kv_Start_Loc,
     B_Seqlen,  # B_LOC records the real start position of each batch's input; B_SEQ_len records the real length of the current input
     Out,
     stride_q_bs,
@@ -44,7 +45,8 @@ def _fwd_kernel_with_v(

     cur_k_head = cur_head

-    cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
+    cur_batch_in_q_start_index = tl.load(B_Start_Loc + cur_batch)
+    cur_batch_in_kv_start_index = tl.load(B_Kv_Start_Loc + cur_batch)
     prompt_cache_len = tl.load(b_prompt_cache_len + cur_batch)
     cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - prompt_cache_len

@@ -55,9 +57,9 @@ def _fwd_kernel_with_v(
     offs_d = tl.arange(0, BLOCK_DMODEL)
     offs_rope_d = tl.arange(0, BLOCK_ROPE_DMODEL)
     offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
-    off_q = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_q_bs + cur_head * stride_q_h + offs_d[None, :]
+    off_q = (cur_batch_in_q_start_index + offs_m[:, None]) * stride_q_bs + cur_head * stride_q_h + offs_d[None, :]
     off_q_rope = (
-        (cur_batch_in_all_start_index + offs_m[:, None]) * stride_q_rope_bs
+        (cur_batch_in_q_start_index + offs_m[:, None]) * stride_q_rope_bs
         + cur_head * stride_q_rope_h
         + offs_rope_d[None, :]
     )
@@ -84,12 +86,12 @@ def _fwd_kernel_with_v(
         start_n = tl.multiple_of(start_n, BLOCK_N)
         # -- compute qk ----
         k = tl.load(
-            k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_k_bs,
+            k_ptrs + (cur_batch_in_kv_start_index + start_n) * stride_k_bs,
             mask=(start_n + offs_n[None, :]) < block_end_loc,
             other=0.0,
         )
         k_rope = tl.load(
-            k_rope_ptrs + (cur_batch_in_all_start_index + start_n) * stride_k_rope_bs,
+            k_rope_ptrs + (cur_batch_in_kv_start_index + start_n) * stride_k_rope_bs,
             mask=(start_n + offs_n[None, :]) < block_end_loc,
             other=0.0,
         )
@@ -119,7 +121,7 @@ def _fwd_kernel_with_v(
         acc = acc * acc_scale[:, None]
         # update acc
         v = tl.load(
-            v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
+            v_ptrs + (cur_batch_in_kv_start_index + start_n) * stride_vbs,
             mask=(start_n + offs_n[:, None]) < block_end_loc,
             other=0.0,
         )
@@ -129,7 +131,7 @@ def _fwd_kernel_with_v(
         l_i = l_i_new
         m_i = m_i_new
     # initialize pointers to output
-    off_o = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :]
+    off_o = (cur_batch_in_q_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :]
     out_ptrs = Out + off_o
     tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
     return
@@ -144,6 +146,7 @@ def context_attention_fwd_with_v(
     v,
     o,
     b_start_loc,
+    b_kv_start_loc,
     b_seq_len,
     b_prompt_cache_len,
     max_input_len,
@@ -181,6 +184,7 @@ def context_attention_fwd_with_v(
         v,
         sm_scale,
         b_start_loc,
+        b_kv_start_loc,
         b_seq_len,
         o,
         q_nope.stride(0),
@@ -204,3 +208,78 @@ def context_attention_fwd_with_v(
         num_stages=1,
     )
     return
+
+
+if __name__ == "__main__":
+    import torch
+    import torch.nn.functional as F  # needed for the cosine-similarity check below
+    import flashinfer
+
+    Z, N_CTX, H, D_HEAD, ROPE_HEAD = 32, 1024, 16, 128, 64
+    dtype = torch.bfloat16
+
+    k_nope = torch.randn((Z * N_CTX, H, D_HEAD), dtype=dtype, device="cuda")
+    k_rope = torch.randn((Z * N_CTX, 1, ROPE_HEAD), dtype=dtype, device="cuda")
+    k = torch.cat([k_nope, torch.repeat_interleave(k_rope, H, dim=-2)], dim=-1)
+    v = torch.randn((Z * N_CTX, H, D_HEAD), dtype=dtype, device="cuda")
+
+    max_input_len = Z * N_CTX
+    softmax_scale = 0.117
+    b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") * N_CTX
+    b_prompt_cache_len = torch.zeros_like(b_seq_len, dtype=torch.int32, device="cuda")
+    b_prompt_cache_len = torch.randint_like(b_seq_len, high=N_CTX - 1, dtype=torch.int32, device="cuda")
+    q_lens = b_seq_len - b_prompt_cache_len
+    q_start_loc = q_lens.cumsum(0) - q_lens
+    kv_start_loc = b_seq_len.cumsum(0) - b_seq_len
+
+    q_nope = torch.randn((q_lens.sum(), H, D_HEAD), dtype=dtype, device="cuda")
+    q_rope = torch.randn((q_lens.sum(), H, ROPE_HEAD), dtype=dtype, device="cuda")
+    q = torch.cat([q_nope, q_rope], dim=-1)
+
+    o = torch.empty((q_lens.sum(), H, D_HEAD), dtype=dtype, device="cuda")
+    o1 = torch.empty((q_lens.sum(), H, D_HEAD), dtype=dtype, device="cuda")
+    o2 = torch.empty((q_lens.sum(), H, D_HEAD), dtype=dtype, device="cuda")
+
+    fn1 = lambda: context_attention_fwd_with_v(
+        q_nope,
+        q_rope,
+        k_nope,
+        k_rope,
+        v,
+        o,
+        q_start_loc,
+        kv_start_loc,
+        b_seq_len,
+        b_prompt_cache_len,
+        max_input_len,
+        softmax_scale,
+    )
+
+    q_starts = torch.zeros((Z + 1,)).int().cuda()
+    q_starts[1:] = torch.cumsum(b_seq_len - b_prompt_cache_len, dim=0)
+    kv_starts = torch.zeros_like(q_starts)
+    kv_starts[1:] = torch.cumsum(b_seq_len, dim=0)
+    kv_layout = "NHD"
+    batch_size = Z
+    q_indptr = q_starts
+    kv_indptr = kv_starts
+    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0)
+    wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper(workspace_buffer, kv_layout)
+    wrapper.plan(
+        qo_indptr=q_indptr,
+        kv_indptr=kv_indptr,
+        num_qo_heads=H,
+        num_kv_heads=H,
+        head_dim_qk=D_HEAD + ROPE_HEAD,
+        head_dim_vo=D_HEAD,
+        q_data_type=dtype,
+        causal=True,
+        sm_scale=softmax_scale,
+    )
+    fn2 = lambda: wrapper.run(q, k, v, out=o1)
+
+    ms1 = triton.testing.do_bench(fn1)
+    ms2 = triton.testing.do_bench(fn2)
+    cos_sim1 = F.cosine_similarity(o, o1).mean()
+    print(cos_sim1)
+    print(ms1)
+    print(ms2)
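
The self-test above benchmarks the Triton kernel against flashinfer's ragged-KV prefill and reports a cosine similarity between the two outputs. For environments without flashinfer, a dependency-free check can be built directly from the two offset tensors. The sketch below is not part of the commit (the helper name ref_context_attention is made up); it recomputes the same prefill attention per sequence in plain PyTorch, slicing the packed Q with q_start_loc and the packed K/V with kv_start_loc, and letting each query attend causally to the cached prefix plus its own position:

import torch

def ref_context_attention(q, k, v, q_start_loc, kv_start_loc, b_seq_len, b_prompt_cache_len, sm_scale):
    # q: (total_q, H, D_qk), k: (total_kv, H, D_qk), v: (total_kv, H, D_v), packed over the batch.
    out = torch.empty((q.shape[0], q.shape[1], v.shape[-1]), dtype=q.dtype, device=q.device)
    for i in range(b_seq_len.shape[0]):
        cache_len = int(b_prompt_cache_len[i])
        kv_len = int(b_seq_len[i])
        q_len = kv_len - cache_len
        qs, ks = int(q_start_loc[i]), int(kv_start_loc[i])
        qi = q[qs : qs + q_len].float()   # (q_len, H, D_qk)
        ki = k[ks : ks + kv_len].float()  # (kv_len, H, D_qk)
        vi = v[ks : ks + kv_len].float()  # (kv_len, H, D_v)
        scores = torch.einsum("qhd,khd->hqk", qi, ki) * sm_scale
        # Query j (0-based, after the cached prefix) may attend to kv positions 0 .. cache_len + j.
        kv_pos = torch.arange(kv_len, device=q.device)
        q_pos = cache_len + torch.arange(q_len, device=q.device)
        scores = scores.masked_fill((kv_pos[None, :] > q_pos[:, None])[None], float("-inf"))
        probs = torch.softmax(scores, dim=-1)
        out[qs : qs + q_len] = torch.einsum("hqk,khd->qhd", probs, vi).to(q.dtype)
    return out

# Usage with the tensors built in the self-test (cosine similarity against o should be ~1.0):
# ref = ref_context_attention(q, k, v, q_start_loc, kv_start_loc, b_seq_len, b_prompt_cache_len, softmax_scale)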
