fix: fix a precision bug in the context_flashattention #743

Merged: 1 commit, Feb 21, 2025
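
What the diff below changes, in short: `_fwd_kernel_with_v` previously loaded a single per-batch offset, `cur_batch_in_all_start_index`, from `B_Start_Loc` and used it to address both the packed q/output tensors and the packed k/k_rope/v tensors. When part of the prompt is already cached, q is packed per batch by `b_seq_len - b_prompt_cache_len` while k/v are packed by the full `b_seq_len`, so the two packings have different start offsets, and sharing one of them points the k/v loads at the wrong rows. The patch threads a separate `b_kv_start_loc` through `context_attention_fwd_with_v` and its callers into the kernel as `B_Kv_Start_Loc`, keeping the q offsets for q/output addressing and the kv offsets for the k/k_rope/v loads.

A minimal sketch of the offset bookkeeping, mirroring the test appended at the end of the kernel file (the concrete numbers are illustrative only, not part of the patch):

# Sketch only: how the q and kv packings diverge once a prompt cache exists.
import torch

b_seq_len = torch.tensor([1024, 1024, 1024], dtype=torch.int32)      # total kv length per batch
b_prompt_cache_len = torch.tensor([0, 300, 700], dtype=torch.int32)  # cached prefix per batch

q_lens = b_seq_len - b_prompt_cache_len           # only the uncached tokens carry queries
b_start_loc = q_lens.cumsum(0) - q_lens           # start of each batch in the packed q / output
b_kv_start_loc = b_seq_len.cumsum(0) - b_seq_len  # start of each batch in the packed k / v

print(b_start_loc.tolist())     # [0, 1024, 1748]
print(b_kv_start_loc.tolist())  # [0, 1024, 2048]
# The two offset sets coincide only when every prompt_cache_len is 0; the old kernel used
# b_start_loc for the k/v loads, which mis-addresses k/v once any earlier batch has a cached prefix.
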
@@ -227,6 +227,7 @@ def _context_attention_kernel_with_CC(
v,
o_tensor.view(-1, self.tp_q_head_num_, q_nope.shape[-1]),
infer_state.b_start_loc,
+ infer_state.b_kv_start_loc,
infer_state.b_seq_len,
infer_state.b_ready_cache_len,
infer_state.max_len_in_batch,
@@ -253,6 +254,7 @@ def _context_attention_kernel_with_CC_fp8(
v,
o_tensor.view(-1, self.tp_q_head_num_, q_nope.shape[-1]),
infer_state.b_start_loc,
+ infer_state.b_kv_start_loc,
infer_state.b_seq_len,
infer_state.b_ready_cache_len,
infer_state.max_len_in_batch,
@@ -18,6 +18,7 @@ def _fwd_kernel_with_v(
V,
sm_scale,
B_Start_Loc,
+ B_Kv_Start_Loc,
B_Seqlen,  # B_Start_Loc records each batch's real start position in the packed input; B_Seqlen records the real length of the current input
Out,
stride_q_bs,
@@ -44,7 +45,8 @@ def _fwd_kernel_with_v(

cur_k_head = cur_head

- cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
+ cur_batch_in_q_start_index = tl.load(B_Start_Loc + cur_batch)
+ cur_batch_in_kv_start_index = tl.load(B_Kv_Start_Loc + cur_batch)
prompt_cache_len = tl.load(b_prompt_cache_len + cur_batch)
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - prompt_cache_len

@@ -55,9 +57,9 @@
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_rope_d = tl.arange(0, BLOCK_ROPE_DMODEL)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
- off_q = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_q_bs + cur_head * stride_q_h + offs_d[None, :]
+ off_q = (cur_batch_in_q_start_index + offs_m[:, None]) * stride_q_bs + cur_head * stride_q_h + offs_d[None, :]
off_q_rope = (
- (cur_batch_in_all_start_index + offs_m[:, None]) * stride_q_rope_bs
+ (cur_batch_in_q_start_index + offs_m[:, None]) * stride_q_rope_bs
+ cur_head * stride_q_rope_h
+ offs_rope_d[None, :]
)
@@ -84,12 +86,12 @@
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(
- k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_k_bs,
+ k_ptrs + (cur_batch_in_kv_start_index + start_n) * stride_k_bs,
mask=(start_n + offs_n[None, :]) < block_end_loc,
other=0.0,
)
k_rope = tl.load(
- k_rope_ptrs + (cur_batch_in_all_start_index + start_n) * stride_k_rope_bs,
+ k_rope_ptrs + (cur_batch_in_kv_start_index + start_n) * stride_k_rope_bs,
mask=(start_n + offs_n[None, :]) < block_end_loc,
other=0.0,
)
@@ -119,7 +121,7 @@
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(
- v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
+ v_ptrs + (cur_batch_in_kv_start_index + start_n) * stride_vbs,
mask=(start_n + offs_n[:, None]) < block_end_loc,
other=0.0,
)
@@ -129,7 +131,7 @@
l_i = l_i_new
m_i = m_i_new
# initialize pointers to output
- off_o = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :]
+ off_o = (cur_batch_in_q_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :]
out_ptrs = Out + off_o
tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
return
@@ -144,6 +146,7 @@ def context_attention_fwd_with_v(
v,
o,
b_start_loc,
+ b_kv_start_loc,
b_seq_len,
b_prompt_cache_len,
max_input_len,
@@ -181,6 +184,7 @@ def context_attention_fwd_with_v(
v,
sm_scale,
b_start_loc,
+ b_kv_start_loc,
b_seq_len,
o,
q_nope.stride(0),
@@ -204,3 +208,78 @@ def context_attention_fwd_with_v(
num_stages=1,
)
return


+if __name__ == "__main__":
+    import torch
+    import torch.nn.functional as F
+    import flashinfer
+
+    Z, N_CTX, H, D_HEAD, ROPE_HEAD = 32, 1024, 16, 128, 64
+    dtype = torch.bfloat16
+
+    k_nope = torch.randn((Z * N_CTX, H, D_HEAD), dtype=dtype, device="cuda")
+    k_rope = torch.randn((Z * N_CTX, 1, ROPE_HEAD), dtype=dtype, device="cuda")
+    k = torch.cat([k_nope, torch.repeat_interleave(k_rope, H, dim=-2)], dim=-1)
+    v = torch.randn((Z * N_CTX, H, D_HEAD), dtype=dtype, device="cuda")
+
+    max_input_len = Z * N_CTX
+    softmax_scale = 0.117
+    b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") * N_CTX
+    b_prompt_cache_len = torch.zeros_like(b_seq_len, dtype=torch.int32, device="cuda")
+    b_prompt_cache_len = torch.randint_like(b_seq_len, high=N_CTX - 1, dtype=torch.int32, device="cuda")
+    q_lens = b_seq_len - b_prompt_cache_len
+    # q is packed by (seq_len - prompt_cache_len) per batch, k/v by the full seq_len,
+    # so the two packings get different per-batch start offsets.
+    q_start_loc = q_lens.cumsum(0) - q_lens
+    kv_start_loc = b_seq_len.cumsum(0) - b_seq_len
+
+    q_nope = torch.randn((q_lens.sum(), H, D_HEAD), dtype=dtype, device="cuda")
+    q_rope = torch.randn((q_lens.sum(), H, ROPE_HEAD), dtype=dtype, device="cuda")
+    q = torch.cat([q_nope, q_rope], dim=-1)
+
+    o = torch.empty((q_lens.sum(), H, D_HEAD), dtype=dtype, device="cuda")
+    o1 = torch.empty((q_lens.sum(), H, D_HEAD), dtype=dtype, device="cuda")
+    o2 = torch.empty((q_lens.sum(), H, D_HEAD), dtype=dtype, device="cuda")
+
+    fn1 = lambda: context_attention_fwd_with_v(
+        q_nope,
+        q_rope,
+        k_nope,
+        k_rope,
+        v,
+        o,
+        q_start_loc,
+        kv_start_loc,
+        b_seq_len,
+        b_prompt_cache_len,
+        max_input_len,
+        softmax_scale,
+    )
+
+    # flashinfer ragged prefill serves as the reference implementation
+    q_starts = torch.zeros((Z + 1,)).int().cuda()
+    q_starts[1:] = torch.cumsum(b_seq_len - b_prompt_cache_len, dim=0)
+    kv_starts = torch.zeros_like(q_starts)
+    kv_starts[1:] = torch.cumsum(b_seq_len, dim=0)
+    kv_layout = "NHD"
+    batch_size = Z
+    q_indptr = q_starts
+    kv_indptr = kv_starts
+    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0)
+    wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper(workspace_buffer, kv_layout)
+    wrapper.plan(
+        qo_indptr=q_indptr,
+        kv_indptr=kv_indptr,
+        num_qo_heads=H,
+        num_kv_heads=H,
+        head_dim_qk=D_HEAD + ROPE_HEAD,
+        head_dim_vo=D_HEAD,
+        q_data_type=dtype,
+        causal=True,
+        sm_scale=softmax_scale,
+    )
+    fn2 = lambda: wrapper.run(q, k, v, out=o1)
+
+    ms1 = triton.testing.do_bench(fn1)
+    ms2 = triton.testing.do_bench(fn2)
+    cos_sim1 = F.cosine_similarity(o, o1).mean()
+    print(cos_sim1)
+    print(ms1)
+    print(ms2)
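
The test above compares the Triton output against flashinfer only through an averaged cosine similarity, which is fairly insensitive to localized errors; for a precision fix, an elementwise comparison is a stricter complement. A possible addition (sketch only, tolerances are illustrative and not part of the patch):

# Sketch, not part of the patch: stricter elementwise comparison of the two outputs.
max_abs_err = (o.float() - o1.float()).abs().max()
print(max_abs_err)
print(torch.allclose(o.float(), o1.float(), atol=1e-2, rtol=1e-2))
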