diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py
index c7012015b..8913bcd71 100644
--- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py
@@ -227,6 +227,7 @@ def _context_attention_kernel_with_CC(
             v,
             o_tensor.view(-1, self.tp_q_head_num_, q_nope.shape[-1]),
             infer_state.b_start_loc,
+            infer_state.b_kv_start_loc,
             infer_state.b_seq_len,
             infer_state.b_ready_cache_len,
             infer_state.max_len_in_batch,
@@ -253,6 +254,7 @@ def _context_attention_kernel_with_CC_fp8(
             v,
             o_tensor.view(-1, self.tp_q_head_num_, q_nope.shape[-1]),
             infer_state.b_start_loc,
+            infer_state.b_kv_start_loc,
             infer_state.b_seq_len,
             infer_state.b_ready_cache_len,
             infer_state.max_len_in_batch,
diff --git a/lightllm/models/deepseek2/triton_kernel/context_flashattention_nopad_with_v.py b/lightllm/models/deepseek2/triton_kernel/context_flashattention_nopad_with_v.py
index 04e6facbb..0efb77a68 100644
--- a/lightllm/models/deepseek2/triton_kernel/context_flashattention_nopad_with_v.py
+++ b/lightllm/models/deepseek2/triton_kernel/context_flashattention_nopad_with_v.py
@@ -18,6 +18,7 @@ def _fwd_kernel_with_v(
     V,
     sm_scale,
     B_Start_Loc,
+    B_Kv_Start_Loc,
     B_Seqlen,  # B_LOC records each batch input's real start position; B_SEQ_len records the real length of the current input
     Out,
     stride_q_bs,
@@ -44,7 +45,8 @@ def _fwd_kernel_with_v(
 
     cur_k_head = cur_head
 
-    cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
+    cur_batch_in_q_start_index = tl.load(B_Start_Loc + cur_batch)
+    cur_batch_in_kv_start_index = tl.load(B_Kv_Start_Loc + cur_batch)
     prompt_cache_len = tl.load(b_prompt_cache_len + cur_batch)
     cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - prompt_cache_len
 
@@ -55,9 +57,9 @@ def _fwd_kernel_with_v(
     offs_d = tl.arange(0, BLOCK_DMODEL)
     offs_rope_d = tl.arange(0, BLOCK_ROPE_DMODEL)
     offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
-    off_q = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_q_bs + cur_head * stride_q_h + offs_d[None, :]
+    off_q = (cur_batch_in_q_start_index + offs_m[:, None]) * stride_q_bs + cur_head * stride_q_h + offs_d[None, :]
     off_q_rope = (
-        (cur_batch_in_all_start_index + offs_m[:, None]) * stride_q_rope_bs
+        (cur_batch_in_q_start_index + offs_m[:, None]) * stride_q_rope_bs
         + cur_head * stride_q_rope_h
         + offs_rope_d[None, :]
     )
@@ -84,12 +86,12 @@ def _fwd_kernel_with_v(
         start_n = tl.multiple_of(start_n, BLOCK_N)
         # -- compute qk ----
         k = tl.load(
-            k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_k_bs,
+            k_ptrs + (cur_batch_in_kv_start_index + start_n) * stride_k_bs,
             mask=(start_n + offs_n[None, :]) < block_end_loc,
             other=0.0,
         )
         k_rope = tl.load(
-            k_rope_ptrs + (cur_batch_in_all_start_index + start_n) * stride_k_rope_bs,
+            k_rope_ptrs + (cur_batch_in_kv_start_index + start_n) * stride_k_rope_bs,
             mask=(start_n + offs_n[None, :]) < block_end_loc,
             other=0.0,
         )
@@ -119,7 +121,7 @@ def _fwd_kernel_with_v(
         acc = acc * acc_scale[:, None]
         # update acc
         v = tl.load(
-            v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
+            v_ptrs + (cur_batch_in_kv_start_index + start_n) * stride_vbs,
             mask=(start_n + offs_n[:, None]) < block_end_loc,
             other=0.0,
         )
@@ -129,7 +131,7 @@ def _fwd_kernel_with_v(
         l_i = l_i_new
         m_i = m_i_new
     # initialize pointers to output
-    off_o = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :]
+    off_o = (cur_batch_in_q_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :]
     out_ptrs = Out + off_o
     tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
     return
@@ -144,6 +146,7 @@ def context_attention_fwd_with_v(
     v,
     o,
     b_start_loc,
+    b_kv_start_loc,
     b_seq_len,
     b_prompt_cache_len,
     max_input_len,
@@ -181,6 +184,7 @@ def context_attention_fwd_with_v(
         v,
         sm_scale,
         b_start_loc,
+        b_kv_start_loc,
         b_seq_len,
         o,
         q_nope.stride(0),
@@ -204,3 +208,79 @@ def context_attention_fwd_with_v(
         num_stages=1,
     )
     return
+
+
+if __name__ == "__main__":
+    import torch
+    import torch.nn.functional as F  # needed for the cosine-similarity check below
+    import flashinfer
+
+    Z, N_CTX, H, D_HEAD, ROPE_HEAD = 32, 1024, 16, 128, 64
+    dtype = torch.bfloat16
+
+    k_nope = torch.randn((Z * N_CTX, H, D_HEAD), dtype=dtype, device="cuda")
+    k_rope = torch.randn((Z * N_CTX, 1, ROPE_HEAD), dtype=dtype, device="cuda")
+    k = torch.cat([k_nope, torch.repeat_interleave(k_rope, H, dim=-2)], dim=-1)
+    v = torch.randn((Z * N_CTX, H, D_HEAD), dtype=dtype, device="cuda")
+
+    max_input_len = Z * N_CTX
+    softmax_scale = 0.117
+    b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") * N_CTX
+    b_prompt_cache_len = torch.zeros_like(b_seq_len, dtype=torch.int32, device="cuda")
+    b_prompt_cache_len = torch.randint_like(b_seq_len, high=N_CTX - 1, dtype=torch.int32, device="cuda")
+    q_lens = b_seq_len - b_prompt_cache_len
+    q_start_loc = q_lens.cumsum(0) - q_lens
+    kv_start_loc = b_seq_len.cumsum(0) - b_seq_len
+
+    q_nope = torch.randn((q_lens.sum(), H, D_HEAD), dtype=dtype, device="cuda")
+    q_rope = torch.randn((q_lens.sum(), H, ROPE_HEAD), dtype=dtype, device="cuda")
+    q = torch.cat([q_nope, q_rope], dim=-1)
+
+    o = torch.empty((q_lens.sum(), H, D_HEAD), dtype=dtype, device="cuda")
+    o1 = torch.empty((q_lens.sum(), H, D_HEAD), dtype=dtype, device="cuda")
+    o2 = torch.empty((q_lens.sum(), H, D_HEAD), dtype=dtype, device="cuda")
+
+    fn1 = lambda: context_attention_fwd_with_v(
+        q_nope,
+        q_rope,
+        k_nope,
+        k_rope,
+        v,
+        o,
+        q_start_loc,
+        kv_start_loc,
+        b_seq_len,
+        b_prompt_cache_len,
+        max_input_len,
+        softmax_scale,
+    )
+
+    q_starts = torch.zeros((Z + 1,)).int().cuda()
+    q_starts[1:] = torch.cumsum(b_seq_len - b_prompt_cache_len, dim=0)
+    kv_starts = torch.zeros_like(q_starts)
+    kv_starts[1:] = torch.cumsum(b_seq_len, dim=0)
+    kv_layout = "NHD"
+    batch_size = Z
+    q_indptr = q_starts
+    kv_indptr = kv_starts
+    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0)
+    wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper(workspace_buffer, kv_layout)
+    wrapper.plan(
+        qo_indptr=q_indptr,
+        kv_indptr=kv_indptr,
+        num_qo_heads=H,
+        num_kv_heads=H,
+        head_dim_qk=D_HEAD + ROPE_HEAD,
+        head_dim_vo=D_HEAD,
+        q_data_type=dtype,
+        causal=True,
+        sm_scale=softmax_scale,
+    )
+    fn2 = lambda: wrapper.run(q, k, v, out=o1)
+
+    ms1 = triton.testing.do_bench(fn1)
+    ms2 = triton.testing.do_bench(fn2)
+    cos_sim1 = F.cosine_similarity(o, o1).mean()
+    print(cos_sim1)
+    print(ms1)
+    print(ms2)
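
Note (not part of the patch): a minimal sketch of how the two offset tensors the kernel now receives differ once a prompt cache is present. b_start_loc indexes the packed q/output tensors by the per-request query lengths (b_seq_len - b_prompt_cache_len), while the new b_kv_start_loc indexes the packed k/v tensors by the full per-request lengths (b_seq_len). Names mirror the test harness above; the values are made up for illustration.

    import torch

    b_seq_len = torch.tensor([8, 6, 10], dtype=torch.int32)          # full kv length per request
    b_prompt_cache_len = torch.tensor([3, 0, 5], dtype=torch.int32)  # tokens already in the prompt cache
    q_lens = b_seq_len - b_prompt_cache_len                          # new tokens to prefill -> [5, 6, 5]

    b_start_loc = q_lens.cumsum(0) - q_lens                          # q/output start offsets -> [0, 5, 11]
    b_kv_start_loc = b_seq_len.cumsum(0) - b_seq_len                 # k/v start offsets      -> [0, 8, 14]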