diff --git a/lightllm/models/deepseek2/flashinfer_struct.py b/lightllm/models/deepseek2/flashinfer_struct.py new file mode 100644 index 000000000..7aae30ed1 --- /dev/null +++ b/lightllm/models/deepseek2/flashinfer_struct.py @@ -0,0 +1,103 @@ +import os +import torch +import numpy as np +import torch.distributed as dist +from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo +from lightllm.utils.envs_utils import enable_env_vars +from lightllm.models.deepseek2.triton_kernel.repack_kv_index import repack_kv_index + + +class Deepseek2FlashInferStateInfo(Deepseek2InferStateInfo): + def __init__(self): + super().__init__() + self.prefill_wrapper = None + self.decode_wrapper = None + self.flashinfer_extra_state = None + + def init_some_extra_state(self, model, input_ids: torch.Tensor): + super().init_some_extra_state(model, input_ids) + self.flashinfer_extra_state = model.flashinfer_extra_state + + import flashinfer + + if not self.is_prefill: + if enable_env_vars("ENABLE_FLASHINFER_DECODE_MLA"): + self.q_indptr = torch.arange(self.batch_size + 1, dtype=torch.int32).to(input_ids.device) + self.kv_indices = torch.empty( + self.batch_size * self.flashinfer_extra_state.max_seq_length, dtype=torch.int32 + ).to(input_ids.device) + repack_kv_index( + self.req_manager.req_to_token_indexs, + self.b_req_idx, + self.b_seq_len, + self.b_start_loc, + self.max_len_in_batch, + self.kv_indices, + ) + if self.decode_wrapper is None: + self.decode_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper( + self.flashinfer_extra_state.workspace_buffer, + use_cuda_graph=True, + qo_indptr=self.q_indptr, + kv_indices=self.kv_indices, + kv_indptr=self.kv_starts, + kv_len_arr=self.b_seq_len, + ) + self.decode_wrapper.plan( + self.q_indptr, + self.kv_starts, + self.kv_indices, + self.b_seq_len, + self.flashinfer_extra_state.tp_q_head_num, + self.flashinfer_extra_state.kv_lora_rank, + self.flashinfer_extra_state.qk_rope_head_dim, + 1, + False, # causal + self.flashinfer_extra_state.softmax_scale, + self.flashinfer_extra_state.q_data_type, + self.flashinfer_extra_state.kv_data_type, + ) + else: + if enable_env_vars("ENABLE_FLASHINFER_PREFILLED"): + q_starts = torch.cat( + [self.b_start_loc, self.b_start_loc[-1:] + (self.b_seq_len - self.b_ready_cache_len)[-1:]], dim=0 + ).int() + kv_starts = torch.cat( + [self.b_kv_start_loc, self.b_kv_start_loc[-1:] + self.b_seq_len[-1:]], dim=0 + ).int() + if self.prefill_wrapper is None: + self.prefill_wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper( + self.flashinfer_extra_state.workspace_buffer, "NHD" + ) + self.prefill_wrapper.plan( + qo_indptr=q_starts, + kv_indptr=kv_starts, + num_qo_heads=self.flashinfer_extra_state.tp_q_head_num, + num_kv_heads=self.flashinfer_extra_state.tp_q_head_num, + head_dim_qk=self.flashinfer_extra_state.qk_nope_head_dim + + self.flashinfer_extra_state.qk_rope_head_dim, + head_dim_vo=self.flashinfer_extra_state.qk_nope_head_dim, + q_data_type=self.flashinfer_extra_state.q_data_type, + causal=True, + sm_scale=self.flashinfer_extra_state.softmax_scale, + ) + return + + def copy_for_cuda_graph(self, new_infer_state): + super().copy_for_cuda_graph(new_infer_state) + if enable_env_vars("ENABLE_FLASHINFER_DECODE_MLA") and not self.is_prefill: + self.decode_wrapper.plan( + new_infer_state.q_indptr, + new_infer_state.kv_starts, + new_infer_state.kv_indices, + new_infer_state.b_seq_len, + new_infer_state.flashinfer_extra_state.tp_q_head_num, + new_infer_state.flashinfer_extra_state.kv_lora_rank, + 
new_infer_state.flashinfer_extra_state.qk_rope_head_dim, + 1, + False, # causal + new_infer_state.flashinfer_extra_state.softmax_scale, + new_infer_state.flashinfer_extra_state.q_data_type, + new_infer_state.flashinfer_extra_state.kv_data_type, + ) + return diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py index c0e66cc88..b39e784cd 100644 --- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py @@ -5,6 +5,7 @@ import numpy as np from lightllm.models.deepseek2.layer_weights.transformer_layer_weight import Deepseek2TransformerLayerWeight from lightllm.models.deepseek2.triton_kernel.destindex_copy_kv import destindex_copy_kv +from lightllm.models.deepseek2.triton_kernel.destindex_copy_kv_fp8 import destindex_copy_kv_fp8 from lightllm.models.deepseek2.triton_kernel.context_flashattention_nopad import ( context_attention_fwd, context_attention_fwd_no_prompt_cache, @@ -20,10 +21,11 @@ from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo +from lightllm.models.deepseek2.flashinfer_struct import Deepseek2FlashInferStateInfo from functools import partial from lightllm.models.llama.yarn_rotary_utils import get_deepseek_mscale import os -from lightllm.common.quantization import vLLMFP8w8a8QuantizationMethod +from lightllm.utils.envs_utils import enable_env_vars class Deepseek2TransformerLayerInfer(LlamaTransformerLayerInfer): @@ -67,7 +69,6 @@ def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]): self.tp_o_head_num_ = self.tp_q_head_num_ self.num_heads = network_config["num_attention_heads"] self.num_kv_heads = network_config["num_key_value_heads"] - self.enable_opt_decoding_mha = os.getenv("ENABLE_OPT_DECODE_MHA", "False").upper() in ["ON", "TRUE", "1"] return def _bind_func(self): @@ -96,18 +97,33 @@ def _bind_attention(self): ) else: self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_normal, self) - self._token_attention_kernel = partial( - Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashdecoding, self - ) - if self.enable_cc_method: - if "triton_fp8kv" in self.mode: - self._context_attention_kernel = partial( - Deepseek2TransformerLayerInfer._context_attention_kernel_with_CC_fp8, self + if enable_env_vars("ENABLE_FLASHINFER_DECODE_MLA"): + self._token_attention_kernel = partial( + Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashinfer, self ) else: - self._context_attention_kernel = partial( - Deepseek2TransformerLayerInfer._context_attention_kernel_with_CC, self + self._token_attention_kernel = partial( + Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashdecoding, self ) + if self.enable_cc_method: + if "triton_fp8kv" in self.mode: + if enable_env_vars("ENABLE_FLASHINFER_PREFILLED"): + self._context_attention_kernel = partial( + Deepseek2TransformerLayerInfer._context_attention_flashinfer_kernel_with_CC_fp8, self + ) + else: + self._context_attention_kernel = partial( + Deepseek2TransformerLayerInfer._context_attention_kernel_with_CC_fp8, self + ) + else: + if enable_env_vars("ENABLE_FLASHINFER_PREFILLED"): + self._context_attention_kernel = partial( + Deepseek2TransformerLayerInfer._context_attention_flashinfer_kernel_with_CC, self + ) + 
else: + self._context_attention_kernel = partial( + Deepseek2TransformerLayerInfer._context_attention_kernel_with_CC, self + ) else: if "triton_fp8kv" in self.mode: self._context_attention_kernel = partial( @@ -205,6 +221,38 @@ def _decompress_kv( k_nope, v = torch.split(kv_nope, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) return k_nope, k_rope, v + def _context_attention_flashinfer_kernel_with_CC( + self, + q: torch.Tensor, + kv, + infer_state: Deepseek2FlashInferStateInfo, + layer_weight: Deepseek2TransformerLayerWeight, + out=None, + ) -> torch.Tensor: + k_nope, k_rope, v = self._decompress_kv(kv, infer_state, layer_weight, False) + o_tensor = ( + self.alloc_tensor((q.shape[0], q.shape[1], self.qk_nope_head_dim), dtype=q.dtype) if out is None else out + ) + k = torch.cat([k_nope, torch.repeat_interleave(k_rope, self.tp_q_head_num_, dim=-2)], dim=-1) + infer_state.prefill_wrapper.run(q, k, v, out=o_tensor) + return o_tensor + + def _context_attention_flashinfer_kernel_with_CC_fp8( + self, + q: torch.Tensor, + kv, + infer_state: Deepseek2FlashInferStateInfo, + layer_weight: Deepseek2TransformerLayerWeight, + out=None, + ) -> torch.Tensor: + k_nope, k_rope, v = self._decompress_kv(kv, infer_state, layer_weight, True) + o_tensor = ( + self.alloc_tensor((q.shape[0], q.shape[1], self.qk_nope_head_dim), dtype=q.dtype) if out is None else out + ) + k = torch.cat([k_nope, torch.repeat_interleave(k_rope, self.tp_q_head_num_, dim=-2)], dim=-1) + infer_state.prefill_wrapper.run(q, k, v, out=o_tensor) + return o_tensor + def _context_attention_kernel_with_CC( self, q: torch.Tensor, @@ -345,6 +393,25 @@ def _context_attention_kernel_origin_fp8( return o_tensor + def _token_gqa_decode_attention_flashinfer( + self, q, infer_state: Deepseek2FlashInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None + ): + q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] + q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) + + kv = infer_state.mem_manager.kv_buffer[self.layer_num_] + o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) + + infer_state.decode_wrapper.run( + q_nope, + q_rope, + kv[:, :, : -self.qk_rope_head_dim], + kv[:, :, -self.qk_rope_head_dim :], + out=o_tensor, + return_lse=False, + ) + return o_tensor + def _token_gqa_decode_attention_flashdecoding( self, q, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None ): @@ -354,7 +421,7 @@ def _token_gqa_decode_attention_flashdecoding( kv = infer_state.mem_manager.kv_buffer[self.layer_num_] o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) - if self.enable_opt_decoding_mha: + if enable_env_vars("ENABLE_OPT_DECODE_MHA"): q = torch.cat([q_nope, q_rope], dim=-1) q_nope, q_rope = None, None import lightllm_ppl_mla @@ -368,7 +435,7 @@ def _token_gqa_decode_attention_flashdecoding( infer_state.b_req_idx, self.softmax_scale, q.shape[-1], - q_nope.shape[-1], + self.kv_lora_rank, ) return o_tensor else: @@ -421,16 +488,13 @@ def _copy_kv_to_mem_cache_normal(self, buffer, mem_index, mem_manager): return def _copy_kv_to_mem_cache_fp8(self, buffer, mem_index, mem_manager): - quant_method = vLLMFP8w8a8QuantizationMethod() - quant, scale = quant_method.quantize_scaled_mm_fp8(buffer.reshape(-1, buffer.shape[-1])) - destindex_copy_kv( - quant.T.unsqueeze(1)[:, :, : self.kv_lora_rank].view(torch.uint8), - quant.T.unsqueeze(1)[:, :, self.kv_lora_rank :].view(torch.uint8), + destindex_copy_kv_fp8( + buffer[:, :, : 
self.kv_lora_rank], + buffer[:, :, self.kv_lora_rank :], mem_index, - mem_manager.kv_buffer[self.layer_num_][:, :, : self.kv_lora_rank], - mem_manager.kv_buffer[self.layer_num_][:, :, self.kv_lora_rank : -2], - mem_manager.kv_buffer[self.layer_num_][:, :, -2:], - scale.to(buffer.dtype).view(torch.uint8), + mem_manager.kv_buffer[self.layer_num_][:, :, : self.kv_lora_rank].view(torch.float8_e4m3fn), + mem_manager.kv_buffer[self.layer_num_][:, :, self.kv_lora_rank : -2].view(torch.float8_e4m3fn), + mem_manager.kv_buffer[self.layer_num_][:, :, -2:].view(buffer.dtype), ) return diff --git a/lightllm/models/deepseek2/model.py b/lightllm/models/deepseek2/model.py index d7ffec407..cc8d81e45 100644 --- a/lightllm/models/deepseek2/model.py +++ b/lightllm/models/deepseek2/model.py @@ -2,17 +2,41 @@ from lightllm.models.deepseek2.layer_infer.transformer_layer_infer import Deepseek2TransformerLayerInfer from lightllm.models.deepseek2.layer_weights.transformer_layer_weight import Deepseek2TransformerLayerWeight from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo +from lightllm.models.deepseek2.flashinfer_struct import Deepseek2FlashInferStateInfo from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights from lightllm.models.llama.model import LlamaTpPartModel from lightllm.common.deepseek2_mem_manager import Deepseek2MemoryManager from lightllm.common.deepseek2_fp8kv_mem_manager import Deepseek2FP8KVMemoryManager from lightllm.utils.log_utils import init_logger +from lightllm.models.llama.yarn_rotary_utils import get_deepseek_mscale +from lightllm.utils.envs_utils import enable_env_vars logger = init_logger(__name__) +class FlashInferStateExtraInfo: + def __init__(self, model): + num_heads = model.config["num_attention_heads"] + self.tp_q_head_num = num_heads if enable_env_vars("ENABLE_DP") else num_heads // model.world_size_ + self.qk_nope_head_dim = model.qk_nope_head_dim + self.qk_rope_head_dim = model.qk_rope_head_dim + self.kv_lora_rank = model.kv_lora_rank + self.q_data_type = model.data_type + self.kv_data_type = model.data_type + self.workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(model.tp_rank_) + self.max_seq_length = model.max_seq_length + self.softmax_scale = (self.qk_nope_head_dim + self.qk_rope_head_dim) ** (-0.5) + if model.config["rope_scaling"] is not None: + rope_scaling = model.config["rope_scaling"] + mscale_all_dim = rope_scaling.get("mscale_all_dim", 0) + scaling_factor = rope_scaling["factor"] + if mscale_all_dim: + mscale = get_deepseek_mscale(scaling_factor, mscale_all_dim) + self.softmax_scale = self.softmax_scale * mscale * mscale + + class Deepseek2TpPartModel(LlamaTpPartModel): # weight class transformer_weight_class = Deepseek2TransformerLayerWeight @@ -20,8 +44,12 @@ class Deepseek2TpPartModel(LlamaTpPartModel): # infer class transformer_layer_infer_class = Deepseek2TransformerLayerInfer + enable_flashinfer = enable_env_vars("ENABLE_FLASHINFER_PREFILLED") or enable_env_vars( + "ENABLE_FLASHINFER_DECODE_MLA" + ) + # infer state class - infer_state_class = Deepseek2InferStateInfo + infer_state_class = Deepseek2FlashInferStateInfo if enable_flashinfer else Deepseek2InferStateInfo def __init__(self, kvargs): super().__init__(kvargs) @@ -37,6 +65,8 @@ def _init_some_value(self): self.q_lora_rank = self.config["q_lora_rank"] self.kv_lora_rank = self.config["kv_lora_rank"] self.head_dim_ = self.kv_lora_rank + self.qk_rope_head_dim + if self.enable_flashinfer: + self.flashinfer_extra_state = 
FlashInferStateExtraInfo(self) def _init_custom(self): self._init_to_get_yarn_rotary() diff --git a/lightllm/models/deepseek2/triton_kernel/destindex_copy_kv.py b/lightllm/models/deepseek2/triton_kernel/destindex_copy_kv.py index 9295ae5e3..5b922604d 100644 --- a/lightllm/models/deepseek2/triton_kernel/destindex_copy_kv.py +++ b/lightllm/models/deepseek2/triton_kernel/destindex_copy_kv.py @@ -12,11 +12,9 @@ def _is_power_of_two(n): def _fwd_kernel_destindex_copy_kv( KV_nope, KV_rope, - KV_scale, Dest_loc, O_nope, O_rope, - O_scale, stride_kv_nope_bs, stride_kv_nope_h, stride_kv_nope_d, @@ -29,9 +27,6 @@ def _fwd_kernel_destindex_copy_kv( stride_o_rope_bs, stride_o_rope_h, stride_o_rope_d, - kv_nope_head_num, - kv_rope_head_num, - HAS_SCALE: tl.constexpr, BLOCK_DMODEL_NOPE: tl.constexpr, BLOCK_DMODEL_ROPE: tl.constexpr, ): @@ -50,24 +45,14 @@ def _fwd_kernel_destindex_copy_kv( kv_nope = tl.load(kv_nope_ptrs) kv_rope = tl.load(kv_rope_ptrs) - if HAS_SCALE: - offs_d_scale = tl.arange(0, 2) - o_scale_ptrs = O_scale + dest_index * stride_o_rope_bs + stride_o_rope_d * offs_d_scale[None, :] - kv_scale_ptrs = KV_scale + cur_index * 2 + offs_d_scale[None, :] - kv_scale = tl.load(kv_scale_ptrs) - tl.store(o_scale_ptrs, kv_scale) - tl.store(o_nope_ptrs, kv_nope) tl.store(o_rope_ptrs, kv_rope) return @torch.no_grad() -def destindex_copy_kv(KV_nope, KV_rope, DestLoc, O_nope, O_rope, O_scale=None, KV_scale=None): +def destindex_copy_kv(KV_nope, KV_rope, DestLoc, O_nope, O_rope): seq_len = DestLoc.shape[0] - kv_nope_head_num = KV_nope.shape[1] - kv_rope_head_num = KV_rope.shape[1] - kv_nope_head_dim = KV_nope.shape[2] kv_rope_head_dim = KV_rope.shape[2] @@ -81,11 +66,9 @@ def destindex_copy_kv(KV_nope, KV_rope, DestLoc, O_nope, O_rope, O_scale=None, K _fwd_kernel_destindex_copy_kv[grid]( KV_nope, KV_rope, - KV_scale, DestLoc, O_nope, O_rope, - O_scale, KV_nope.stride(0), KV_nope.stride(1), KV_nope.stride(2), @@ -98,9 +81,6 @@ def destindex_copy_kv(KV_nope, KV_rope, DestLoc, O_nope, O_rope, O_scale=None, K O_rope.stride(0), O_rope.stride(1), O_rope.stride(2), - kv_nope_head_num, - kv_rope_head_num, - HAS_SCALE=KV_scale is not None, BLOCK_DMODEL_NOPE=kv_nope_head_dim, BLOCK_DMODEL_ROPE=kv_rope_head_dim, num_warps=num_warps, @@ -109,23 +89,19 @@ def destindex_copy_kv(KV_nope, KV_rope, DestLoc, O_nope, O_rope, O_scale=None, K return -def test1(): - B, N_CTX, H, H1, D, D1 = 32, 1024, 12, 1, 128, 64 - KV_nope = torch.randn((B * N_CTX, H, D), dtype=torch.float16).cuda() - KV_rope = torch.randn((B * N_CTX, H1, D1), dtype=torch.float16).cuda() - dest_loc = torch.arange(0, B * N_CTX, dtype=torch.int32, device="cuda") - O_nope = torch.randn((B * N_CTX, H, D), dtype=torch.float16).cuda() - O_rope = torch.randn((B * N_CTX, H1, D1), dtype=torch.float16).cuda() - - destindex_copy_kv(KV_nope, KV_rope, dest_loc, O_nope, O_rope) +if __name__ == "__main__": + import torch.nn.functional as F - print("max ", torch.max(torch.abs(O_nope - KV_nope))) - print("mean ", torch.mean(torch.abs(O_nope - KV_nope))) - assert torch.allclose(O_nope, KV_nope, atol=1e-2, rtol=0) - print("max ", torch.max(torch.abs(O_rope - KV_rope))) - print("mean ", torch.mean(torch.abs(O_rope - KV_rope))) - assert torch.allclose(O_rope, KV_rope, atol=1e-2, rtol=0) + B, N_CTX, H, NOPE_HEAD, ROPE_HEAD = 32, 1024, 1, 512, 64 + dtype = torch.bfloat16 + dest_loc = torch.randint(0, 100, (50,), device="cuda").unique() + kv = torch.randn((len(dest_loc), H, NOPE_HEAD + ROPE_HEAD), dtype=dtype).cuda() + O_nope = torch.zeros((B * N_CTX, H, NOPE_HEAD), 
dtype=dtype).cuda() + O_rope = torch.zeros((B * N_CTX, H, ROPE_HEAD), dtype=dtype).cuda() + kv_nope = kv[:, :, :NOPE_HEAD] + kv_rope = kv[:, :, NOPE_HEAD:] + destindex_copy_kv(kv_nope, kv_rope, dest_loc, O_nope, O_rope) -if __name__ == "__main__": - test1() + assert torch.allclose(O_nope[dest_loc], kv_nope, atol=1e-2, rtol=0) + assert torch.allclose(O_rope[dest_loc], kv_rope, atol=1e-2, rtol=0) diff --git a/lightllm/models/deepseek2/triton_kernel/destindex_copy_kv_fp8.py b/lightllm/models/deepseek2/triton_kernel/destindex_copy_kv_fp8.py new file mode 100644 index 000000000..d0f676a93 --- /dev/null +++ b/lightllm/models/deepseek2/triton_kernel/destindex_copy_kv_fp8.py @@ -0,0 +1,152 @@ +import torch + +import triton +import triton.language as tl + + +@triton.jit +def _fwd_kernel_destindex_copy_kv_fp8( + KV_nope, + KV_rope, + Dest_loc, + O_nope, + O_rope, + O_scale, + stride_kv_nope_bs, + stride_kv_nope_h, + stride_kv_nope_d, + stride_kv_rope_bs, + stride_kv_rope_h, + stride_kv_rope_d, + stride_o_nope_bs, + stride_o_nope_h, + stride_o_nope_d, + stride_o_rope_bs, + stride_o_rope_h, + stride_o_rope_d, + stride_o_scale_bs, + stride_o_scale_h, + stride_o_scale_d, + BLOCK_DMODEL_NOPE: tl.constexpr, + BLOCK_DMODEL_ROPE: tl.constexpr, + FP8_MIN: tl.constexpr, + FP8_MAX: tl.constexpr, +): + cur_index = tl.program_id(0) + offs_d_nope = tl.arange(0, BLOCK_DMODEL_NOPE) + offs_d_rope = tl.arange(0, BLOCK_DMODEL_ROPE) + + dest_index = tl.load(Dest_loc + cur_index) + + kv_nope_ptrs = KV_nope + cur_index * stride_kv_nope_bs + stride_kv_nope_d * offs_d_nope[None, :] + kv_rope_ptrs = KV_rope + cur_index * stride_kv_rope_bs + stride_kv_rope_d * offs_d_rope[None, :] + + o_nope_ptrs = O_nope + dest_index * stride_o_nope_bs + stride_o_nope_d * offs_d_nope[None, :] + o_rope_ptrs = O_rope + dest_index * stride_o_rope_bs + stride_o_rope_d * offs_d_rope[None, :] + + # to fp8 + kv_nope = tl.load(kv_nope_ptrs) + kv_rope = tl.load(kv_rope_ptrs) + max_nope = tl.max(tl.abs(kv_nope), axis=1) + max_rope = tl.max(tl.abs(kv_rope), axis=1) + max_kv = tl.maximum(tl.maximum(max_nope, max_rope), 1e-12) + kv_scale = (max_kv / FP8_MAX).to(kv_nope.dtype) + kv_nope_fp8 = tl.clamp(kv_nope / kv_scale, min=FP8_MIN, max=FP8_MAX).to(tl.float8e4nv) + kv_rope_fp8 = tl.clamp(kv_rope / kv_scale, min=FP8_MIN, max=FP8_MAX).to(tl.float8e4nv) + + # save kv_scale + offs_d_scale = tl.arange(0, 1) + o_scale_ptrs = O_scale + dest_index * stride_o_scale_bs + stride_o_scale_d * offs_d_scale + tl.store(o_scale_ptrs, kv_scale) + + # save fp8 kv + tl.store(o_nope_ptrs, kv_nope_fp8) + tl.store(o_rope_ptrs, kv_rope_fp8) + return + + +@torch.no_grad() +def destindex_copy_kv_fp8(KV_nope, KV_rope, DestLoc, O_nope, O_rope, O_scale): + seq_len = DestLoc.shape[0] + kv_nope_head_dim = KV_nope.shape[2] + kv_rope_head_dim = KV_rope.shape[2] + + assert KV_nope.shape[1] == O_nope.shape[1] + assert KV_nope.shape[2] == O_nope.shape[2] + assert KV_rope.shape[1] == O_rope.shape[1] + assert KV_rope.shape[2] == O_rope.shape[2] + grid = (seq_len,) + num_warps = 1 + + _fwd_kernel_destindex_copy_kv_fp8[grid]( + KV_nope, + KV_rope, + DestLoc, + O_nope, + O_rope, + O_scale, + KV_nope.stride(0), + KV_nope.stride(1), + KV_nope.stride(2), + KV_rope.stride(0), + KV_rope.stride(1), + KV_rope.stride(2), + O_nope.stride(0), + O_nope.stride(1), + O_nope.stride(2), + O_rope.stride(0), + O_rope.stride(1), + O_rope.stride(2), + O_scale.stride(0), + O_scale.stride(1), + O_scale.stride(2), + BLOCK_DMODEL_NOPE=kv_nope_head_dim, + BLOCK_DMODEL_ROPE=kv_rope_head_dim, + 
FP8_MIN=torch.finfo(torch.float8_e4m3fn).min, + FP8_MAX=torch.finfo(torch.float8_e4m3fn).max, + num_warps=num_warps, + num_stages=1, + ) + return + + +def test(): + import torch.nn.functional as F + from .destindex_copy_kv import destindex_copy_kv + + B, N_CTX, H, NOPE_HEAD, ROPE_HEAD = 32, 1024, 1, 512, 64 + dtype = torch.bfloat16 + dest_loc = torch.randint(0, 100, (50,), device="cuda").unique() + kv = torch.randn((len(dest_loc), H, NOPE_HEAD + ROPE_HEAD), dtype=dtype).cuda() + O_nope = torch.zeros((B * N_CTX, H, NOPE_HEAD), dtype=dtype).cuda() + O_rope = torch.zeros((B * N_CTX, H, ROPE_HEAD), dtype=dtype).cuda() + + kv_nope = kv[:, :, :NOPE_HEAD] + kv_rope = kv[:, :, NOPE_HEAD:] + destindex_copy_kv(kv_nope, kv_rope, dest_loc, O_nope, O_rope) + + assert torch.allclose(O_nope[dest_loc], kv_nope, atol=1e-2, rtol=0) + assert torch.allclose(O_rope[dest_loc], kv_rope, atol=1e-2, rtol=0) + + +if __name__ == "__main__": + import torch.nn.functional as F + + B, N_CTX, H, NOPE_HEAD, ROPE_HEAD = 32, 1024, 1, 512, 64 + dtype = torch.bfloat16 + NUM = 20 + dest_loc = torch.arange(NUM).cuda() + kv = torch.randn((len(dest_loc), H, NOPE_HEAD + ROPE_HEAD), dtype=dtype).cuda() + out = torch.zeros((B * N_CTX, H, NOPE_HEAD + ROPE_HEAD + 2), dtype=torch.uint8).cuda() + + fp8_type = torch.float8_e4m3fn + kv_nope = kv[:, :, :NOPE_HEAD] + kv_rope = kv[:, :, NOPE_HEAD:] + O_nope = out[:, :, :NOPE_HEAD].view(fp8_type) + O_rope = out[:, :, NOPE_HEAD:-2].view(fp8_type) + O_scale = out[:, :, -2:].view(dtype) + destindex_copy_kv_fp8(kv_nope, kv_rope, dest_loc, O_nope, O_rope, O_scale) + + cos1 = F.cosine_similarity(O_nope[:NUM].to(dtype) * O_scale[:NUM], kv_nope).mean() + cos2 = F.cosine_similarity(O_rope[:NUM].to(dtype) * O_scale[:NUM], kv_rope).mean() + print(cos1, cos2) diff --git a/lightllm/models/deepseek2/triton_kernel/repack_kv_index.py b/lightllm/models/deepseek2/triton_kernel/repack_kv_index.py new file mode 100644 index 000000000..c218d15e0 --- /dev/null +++ b/lightllm/models/deepseek2/triton_kernel/repack_kv_index.py @@ -0,0 +1,88 @@ +import torch + +import triton +import triton.language as tl + + +@triton.jit +def _fwd_kernel_repack_kv_index( + kv_index, + req_index, + out_kv_index, + seq_len, + start_loc, + kv_stride_h, + SEQ_BLOCK: tl.constexpr, +): + cur_batch = tl.program_id(0) + start_seq_n = tl.program_id(1) + + cur_batch_seq_len = tl.load(seq_len + cur_batch) + cur_batch_req_idx = tl.load(req_index + cur_batch) + cur_batch_start_loc = tl.load(start_loc + cur_batch) + + offs_seq = start_seq_n * SEQ_BLOCK + tl.arange(0, SEQ_BLOCK) + block_end_loc = tl.minimum((start_seq_n + 1) * SEQ_BLOCK, cur_batch_seq_len) + kv_index_data = tl.load( + kv_index + kv_stride_h * cur_batch_req_idx + offs_seq, + mask=offs_seq < block_end_loc, + other=0, + ) + out_kv_index_ptr = out_kv_index + cur_batch_start_loc + offs_seq + tl.store(out_kv_index_ptr, kv_index_data, mask=offs_seq < block_end_loc) + return + + +@torch.no_grad() +def repack_kv_index(kv_index, req_index, seq_len, start_loc, max_seq_len, out_kv_index): + batch_size = req_index.shape[0] + BLOCK = 64 + grid = ( + batch_size, + triton.cdiv(max_seq_len, BLOCK), + ) + + _fwd_kernel_repack_kv_index[grid]( + kv_index, + req_index, + out_kv_index, + seq_len, + start_loc, + kv_index.stride(0), + SEQ_BLOCK=BLOCK, + num_warps=8, + num_stages=1, + ) + return + + +def repack_kv_ref(req_to_token_indexs, b_req_idx, b_seq_len, b_start_loc, output): + for b, sl, start in zip(b_req_idx, b_seq_len, b_start_loc): + output[start : start + sl] = req_to_token_indexs[b][:sl] + + 
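The pure-PyTorch reference above (repack_kv_ref) is what the Triton kernel reproduces block by block: for each request in the batch it gathers that request's row of req_to_token_indexs and writes it into a flat kv_indices array at b_start_loc, which is the layout Deepseek2FlashInferStateInfo hands to the FlashInfer MLA wrapper. A toy-sized illustration of that gather (the sizes and index values below are made up for the example):

import torch

# Toy sizes: 2 requests, a token table with room for 8 tokens per request.
req_to_token_indexs = torch.tensor(
    [[10, 11, 12, 13, 0, 0, 0, 0],
     [20, 21, 22, 0, 0, 0, 0, 0]], dtype=torch.int32
)
b_req_idx = torch.tensor([1, 0], dtype=torch.int32)     # request order in this batch
b_seq_len = torch.tensor([3, 4], dtype=torch.int32)     # tokens per request
b_start_loc = torch.tensor([0, 3], dtype=torch.int32)   # exclusive prefix sum of b_seq_len

out = torch.zeros(int(b_seq_len.sum()), dtype=torch.int32)
for req, sl, start in zip(b_req_idx.tolist(), b_seq_len.tolist(), b_start_loc.tolist()):
    out[start:start + sl] = req_to_token_indexs[req, :sl]

print(out)  # tensor([20, 21, 22, 10, 11, 12, 13], dtype=torch.int32)

The __main__ block that follows benchmarks the Triton kernel against this loop with do_bench / do_bench_cudagraph.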
+if __name__ == "__main__": + import torch.nn.functional as F + + BATCH, MAX_SEQ_LEN = 10, 1024 + rand_idx = torch.randperm(2 * MAX_SEQ_LEN * BATCH).cuda().int() + b_req_idx = torch.randperm(BATCH).cuda().int() + b_seq_len = torch.randint(1, MAX_SEQ_LEN, (BATCH,)).cuda().int() + req_to_token_indexs = torch.zeros((2 * BATCH, 2 * MAX_SEQ_LEN)).cuda().int() + b_start_loc = ( + torch.cat([torch.zeros([1], device=b_seq_len.device, dtype=b_seq_len.dtype), b_seq_len[0:-1].cumsum(0)]) + .cuda() + .int() + ) + + output = torch.zeros((b_seq_len.sum(),)).cuda().int() + ref = torch.zeros((b_seq_len.sum(),)).cuda().int() + for b, sl, start in zip(b_req_idx, b_seq_len, b_start_loc): + req_to_token_indexs[b][:sl] = rand_idx[start : start + sl] + + fn1 = lambda: repack_kv_ref(req_to_token_indexs, b_req_idx, b_seq_len, b_start_loc, ref) + fn2 = lambda: repack_kv_index(req_to_token_indexs, b_req_idx, b_seq_len, b_start_loc, MAX_SEQ_LEN, output) + ms1 = triton.testing.do_bench(fn1) + ms2 = triton.testing.do_bench_cudagraph(fn2) + print(ms1, ms2) + assert torch.allclose(output.float(), ref.float()) diff --git a/lightllm/utils/envs_utils.py b/lightllm/utils/envs_utils.py index d4f63860f..c606fd238 100644 --- a/lightllm/utils/envs_utils.py +++ b/lightllm/utils/envs_utils.py @@ -32,3 +32,8 @@ def get_env_start_args(): start_args: StartArgs = json.loads(os.environ["LIGHTLLM_START_ARGS"]) start_args: StartArgs = EasyDict(start_args) return start_args + + +@lru_cache(maxsize=None) +def enable_env_vars(args): + return os.getenv(args, "False").upper() in ["ON", "TRUE", "1"] diff --git a/unit_tests/models/deepseek2/test_destindex_copy_kv.py b/unit_tests/models/deepseek2/test_destindex_copy_kv.py new file mode 100644 index 000000000..1379dc72d --- /dev/null +++ b/unit_tests/models/deepseek2/test_destindex_copy_kv.py @@ -0,0 +1,42 @@ +import torch +import pytest +from lightllm.models.deepseek2.triton_kernel.destindex_copy_kv import destindex_copy_kv +from lightllm.utils.log_utils import init_logger +import torch.nn.functional as F + +logger = init_logger(__name__) + +seed = 42 +torch.manual_seed(seed) + +if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +@pytest.mark.parametrize( + "batch, seqlen, heads, nope_head, rope_head, copy_len", + [ + (a, b, c, d, e, f) + for a in [1, 16, 32, 128, 512] + for b in [1024, 2048] + for c in [1] + for d in [512] + for e in [64] + for f in [10, 20, 100, 1024] + ], +) +def test_destindex_copy_kv(batch, seqlen, heads, nope_head, rope_head, copy_len): + B, N_CTX, H, NOPE_HEAD, ROPE_HEAD, COPY_LEN = batch, seqlen, heads, nope_head, rope_head, copy_len + dtype = torch.bfloat16 + dest_loc = torch.randperm(COPY_LEN).cuda() + kv = torch.randn((len(dest_loc), H, NOPE_HEAD + ROPE_HEAD), dtype=dtype).cuda() + O_nope = torch.zeros((B * N_CTX, H, NOPE_HEAD), dtype=dtype).cuda() + O_rope = torch.zeros((B * N_CTX, H, ROPE_HEAD), dtype=dtype).cuda() + + kv_nope = kv[:, :, :NOPE_HEAD] + kv_rope = kv[:, :, NOPE_HEAD:] + destindex_copy_kv(kv_nope, kv_rope, dest_loc, O_nope, O_rope) + + assert torch.allclose(O_nope[dest_loc], kv_nope, atol=1e-2, rtol=0) + assert torch.allclose(O_rope[dest_loc], kv_rope, atol=1e-2, rtol=0) diff --git a/unit_tests/models/deepseek2/test_destindex_copy_kv_fp8.py b/unit_tests/models/deepseek2/test_destindex_copy_kv_fp8.py new file mode 100644 index 000000000..4f9c0a337 --- /dev/null +++ b/unit_tests/models/deepseek2/test_destindex_copy_kv_fp8.py @@ -0,0 +1,48 @@ +import torch +import pytest +from 
lightllm.models.deepseek2.triton_kernel.destindex_copy_kv_fp8 import destindex_copy_kv_fp8 +from lightllm.utils.log_utils import init_logger +import torch.nn.functional as F + +logger = init_logger(__name__) + +seed = 42 +torch.manual_seed(seed) + +if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +@pytest.mark.parametrize( + "batch, seqlen, heads, nope_head, rope_head, copy_len", + [ + (a, b, c, d, e, f) + for a in [1, 16, 32, 128, 512] + for b in [1024, 2048] + for c in [1] + for d in [512] + for e in [64] + for f in [10, 20, 100, 1024] + ], +) +def test_destindex_copy_kv_fp8(batch, seqlen, heads, nope_head, rope_head, copy_len): + B, N_CTX, H, NOPE_HEAD, ROPE_HEAD, COPY_LEN = batch, seqlen, heads, nope_head, rope_head, copy_len + dtype = torch.bfloat16 + NUM = COPY_LEN + dest_loc = torch.arange(NUM).cuda() + kv = torch.randn((len(dest_loc), H, NOPE_HEAD + ROPE_HEAD), dtype=dtype).cuda() + out = torch.zeros((B * N_CTX, H, NOPE_HEAD + ROPE_HEAD + 2), dtype=torch.uint8).cuda() + + fp8_type = torch.float8_e4m3fn + kv_nope = kv[:, :, :NOPE_HEAD] + kv_rope = kv[:, :, NOPE_HEAD:] + O_nope = out[:, :, :NOPE_HEAD].view(fp8_type) + O_rope = out[:, :, NOPE_HEAD:-2].view(fp8_type) + O_scale = out[:, :, -2:].view(dtype) + destindex_copy_kv_fp8(kv_nope, kv_rope, dest_loc, O_nope, O_rope, O_scale) + + cos1 = F.cosine_similarity(O_nope[:NUM].to(dtype) * O_scale[:NUM], kv_nope).mean() + cos2 = F.cosine_similarity(O_rope[:NUM].to(dtype) * O_scale[:NUM], kv_rope).mean() + assert cos1 > 0.98 + assert cos2 > 0.98 diff --git a/unit_tests/models/deepseek2/test_gqa_flash_decoding.py b/unit_tests/models/deepseek2/test_gqa_flash_decoding.py new file mode 100644 index 000000000..d0bc670ec --- /dev/null +++ b/unit_tests/models/deepseek2/test_gqa_flash_decoding.py @@ -0,0 +1,115 @@ +import torch +import time +import pytest +import numpy as np +import torch.nn.functional as F +import flashinfer +from lightllm.utils.log_utils import init_logger +from lightllm.models.deepseek2.triton_kernel.gqa_flash_decoding import gqa_token_decode_attention_flash_decoding +from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo +from lightllm.common.req_manager import ReqManager + +logger = init_logger(__name__) + +seed = 42 +torch.manual_seed(seed) + +if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +@pytest.mark.parametrize( + "batch, seqlen, heads, nope_head, rope_head", + [ + (a, b, c, d, e) + for a in [1, 16, 32, 128, 512] + for b in [16, 32, 512, 2048] + for c in [16] + for d in [512] + for e in [64] + ], +) +def test_gqa_flash_decoding(batch, seqlen, heads, nope_head, rope_head): + Z, N_CTX, H, D_HEAD, ROPE_HEAD = batch, seqlen, heads, nope_head, rope_head + dtype = torch.bfloat16 + sm_scale = 1.0 / ((D_HEAD + ROPE_HEAD) ** 0.5) + q_nope = torch.randn((Z, H, D_HEAD), dtype=dtype, device="cuda") + q_rope = torch.randn((Z, H, ROPE_HEAD), dtype=dtype, device="cuda") + + kv = torch.randn((Z * N_CTX, 1, D_HEAD + ROPE_HEAD), dtype=dtype, device="cuda") + + max_input_len = Z * N_CTX + req_to_token_indexs = torch.randperm(max_input_len, dtype=torch.int32).cuda().view(Z, N_CTX) + b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") * N_CTX + b_start_loc = torch.arange(Z).cuda().int() * N_CTX + b_req_idx = torch.randperm(Z, dtype=torch.int32).cuda() + kv_starts = torch.cat([b_start_loc, b_start_loc[-1:] + b_seq_len[-1:]], dim=0) + + o = torch.zeros((Z, H, D_HEAD), dtype=dtype, device="cuda") + o1 = 
torch.zeros((Z, H, D_HEAD), dtype=dtype, device="cuda") + + infer_state = Deepseek2InferStateInfo() + infer_state.batch_size = Z + infer_state.max_len_in_batch = N_CTX + infer_state.total_token_num = Z * N_CTX + infer_state.req_manager = ReqManager(Z, N_CTX, None) + infer_state.req_manager.req_to_token_indexs = req_to_token_indexs + infer_state.b_req_idx = b_req_idx + infer_state.b_seq_len = b_seq_len + infer_state.kv_starts = kv_starts + + kv_nope = kv[:, :, :D_HEAD] + kv_rope = kv[:, :, D_HEAD:] + gqa_token_decode_attention_flash_decoding( + q_nope, + q_rope, + kv_nope, + kv_rope, + infer_state, + H, + D_HEAD, + ROPE_HEAD, + D_HEAD, + sm_scale, + o, + ) + + batch_size = Z + head_dim_ckv = D_HEAD + head_dim_kpe = ROPE_HEAD + num_heads = H + page_size = 1 + workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0) + q_indptr = torch.arange(batch_size + 1).to(0).int() + kv_indptr = infer_state.kv_starts + kv_indices = torch.arange(Z * N_CTX).cuda().int() + for b, sl, start in zip(b_req_idx, b_seq_len, b_start_loc): + kv_indices[start : start + sl] = req_to_token_indexs[b][:sl] + kv_lens = b_seq_len + wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper( + workspace_buffer, + use_cuda_graph=True, + qo_indptr=q_indptr, + kv_indices=kv_indices, + kv_indptr=kv_indptr, + kv_len_arr=kv_lens, + ) + wrapper.plan( + q_indptr, + kv_indptr, + kv_indices, + kv_lens, + num_heads, + head_dim_ckv, + head_dim_kpe, + page_size, + False, # causal + sm_scale, + q_nope.dtype, + kv.dtype, + ) + wrapper.run(q_nope, q_rope, kv_nope, kv_rope, out=o1, return_lse=False) + + cos_sim1 = F.cosine_similarity(o, o1).mean() + assert cos_sim1 == 1.0 diff --git a/unit_tests/models/deepseek2/test_gqa_flash_decoding_fp8.py b/unit_tests/models/deepseek2/test_gqa_flash_decoding_fp8.py new file mode 100644 index 000000000..72d9d9acc --- /dev/null +++ b/unit_tests/models/deepseek2/test_gqa_flash_decoding_fp8.py @@ -0,0 +1,79 @@ +import torch +import pytest +import numpy as np +import torch.nn.functional as F +from lightllm.utils.log_utils import init_logger +from lightllm.models.deepseek2.triton_kernel.gqa_flash_decoding import gqa_token_decode_attention_flash_decoding +from lightllm.models.deepseek2.triton_kernel.gqa_flash_decoding_fp8 import gqa_token_decode_attention_flash_decoding_fp8 +from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo +from lightllm.common.req_manager import ReqManager + +logger = init_logger(__name__) + +seed = 42 +torch.manual_seed(seed) + +if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +@pytest.mark.parametrize( + "batch, seqlen, heads, nope_head, rope_head", + [(a, b, c, d, e) for a in [1, 16, 32, 128] for b in [16, 32, 512, 2048] for c in [16] for d in [512] for e in [64]], +) +def test_gqa_flash_decoding_fp8(batch, seqlen, heads, nope_head, rope_head): + Z, N_CTX, H, D_HEAD, ROPE_HEAD = batch, seqlen, heads, nope_head, rope_head + dtype = torch.bfloat16 + sm_scale = 1.0 / ((D_HEAD + ROPE_HEAD) ** 0.5) + q = torch.randn((Z, H, D_HEAD), dtype=dtype, device="cuda") + q_rope = torch.randn((Z, H, ROPE_HEAD), dtype=dtype, device="cuda") + + kv = torch.randn((Z * N_CTX, 1, D_HEAD + ROPE_HEAD), dtype=dtype, device="cuda") + kv_scale = torch.randn((Z * N_CTX, 1, 1), dtype=dtype, device="cuda") + kv_fp8 = kv.to(torch.float8_e4m3fn) + + req_to_token_indexs = torch.zeros((10, Z * N_CTX), dtype=torch.int32, device="cuda") + b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") + b_req_idx = torch.ones((Z,), 
dtype=torch.int32, device="cuda") + + b_seq_len[0] = N_CTX + b_req_idx[0] = 0 + req_to_token_indexs[0][:N_CTX] = torch.tensor(np.arange(N_CTX), dtype=torch.int32).cuda() + + o = torch.empty((Z, H, D_HEAD), dtype=dtype, device="cuda") + o1 = torch.empty((Z, H, D_HEAD), dtype=dtype, device="cuda") + + infer_state = Deepseek2InferStateInfo() + infer_state.batch_size = Z + infer_state.max_len_in_batch = N_CTX + infer_state.total_token_num = Z * N_CTX + infer_state.req_manager = ReqManager(Z, N_CTX, None) + infer_state.req_manager.req_to_token_indexs = req_to_token_indexs + infer_state.b_req_idx = b_req_idx + infer_state.b_seq_len = b_seq_len + + kv_nope = kv_fp8[:, :, :D_HEAD].to(dtype) * kv_scale + kv_rope = kv_fp8[:, :, D_HEAD:].to(dtype) * kv_scale + gqa_token_decode_attention_flash_decoding( + q, + q_rope, + kv_nope, + kv_rope, + infer_state, + H, + D_HEAD, + ROPE_HEAD, + D_HEAD, + sm_scale, + o, + ) + + kv_nope_fp8 = kv_fp8[:, :, :D_HEAD] + kv_rope_fp8 = kv_fp8[:, :, D_HEAD:] + gqa_token_decode_attention_flash_decoding_fp8( + q, q_rope, kv_nope_fp8, kv_rope_fp8, kv_scale, infer_state, H, D_HEAD, ROPE_HEAD, D_HEAD, sm_scale, o1 + ) + + cos_sim = F.cosine_similarity(o, o1).mean() + assert cos_sim > 0.99 diff --git a/unit_tests/models/deepseek2/test_repack_kv_index.py b/unit_tests/models/deepseek2/test_repack_kv_index.py new file mode 100644 index 000000000..f9e5928a9 --- /dev/null +++ b/unit_tests/models/deepseek2/test_repack_kv_index.py @@ -0,0 +1,43 @@ +import torch +import pytest +from lightllm.utils.log_utils import init_logger +from lightllm.models.deepseek2.triton_kernel.repack_kv_index import repack_kv_index + +logger = init_logger(__name__) + +seed = 42 +torch.manual_seed(seed) + +if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +@pytest.mark.parametrize( + "batch, max_seq_len", + [(a, b) for a in [1, 16, 32, 128, 512] for b in [16, 32, 512, 2048]], +) +def test_repack_kv_index(batch, max_seq_len): + def repack_kv_ref(req_to_token_indexs, b_req_idx, b_seq_len, b_start_loc, output): + for b, sl, start in zip(b_req_idx, b_seq_len, b_start_loc): + output[start : start + sl] = req_to_token_indexs[b][:sl] + + BATCH, MAX_SEQ_LEN = batch, max_seq_len + rand_idx = torch.randperm(2 * MAX_SEQ_LEN * BATCH).cuda().int() + b_req_idx = torch.randperm(BATCH).cuda().int() + b_seq_len = torch.randint(1, MAX_SEQ_LEN, (BATCH,)).cuda().int() + req_to_token_indexs = torch.zeros((2 * BATCH, 2 * MAX_SEQ_LEN)).cuda().int() + b_start_loc = ( + torch.cat([torch.zeros([1], device=b_seq_len.device, dtype=b_seq_len.dtype), b_seq_len[0:-1].cumsum(0)]) + .cuda() + .int() + ) + + output = torch.zeros((b_seq_len.sum(),)).cuda().int() + ref = torch.zeros((b_seq_len.sum(),)).cuda().int() + for b, sl, start in zip(b_req_idx, b_seq_len, b_start_loc): + req_to_token_indexs[b][:sl] = rand_idx[start : start + sl] + + repack_kv_ref(req_to_token_indexs, b_req_idx, b_seq_len, b_start_loc, ref) + repack_kv_index(req_to_token_indexs, b_req_idx, b_seq_len, b_start_loc, MAX_SEQ_LEN, output) + assert torch.allclose(output.float(), ref.float())
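A few reference notes for reviewers. The new FlashInfer paths are opt-in and are gated by environment variables read through the new enable_env_vars helper, which treats "ON", "TRUE" and "1" (case-insensitive) as enabled and anything else, including unset, as disabled. The flags used in this patch are ENABLE_FLASHINFER_PREFILLED (prefill via BatchPrefillWithRaggedKVCacheWrapper), ENABLE_FLASHINFER_DECODE_MLA (decode via BatchMLAPagedAttentionWrapper), plus the pre-existing ENABLE_OPT_DECODE_MHA and ENABLE_DP. A standalone sketch of the same parsing logic (re-implemented here rather than imported, so it runs without lightllm installed):

import os
from functools import lru_cache

@lru_cache(maxsize=None)
def enable_env_vars(name: str) -> bool:
    # Same truthy set as lightllm.utils.envs_utils.enable_env_vars.
    return os.getenv(name, "False").upper() in ["ON", "TRUE", "1"]

os.environ["ENABLE_FLASHINFER_DECODE_MLA"] = "on"
print(enable_env_vars("ENABLE_FLASHINFER_DECODE_MLA"))  # True
print(enable_env_vars("ENABLE_FLASHINFER_PREFILLED"))   # False (unset)

Because the helper is wrapped in lru_cache, each flag is effectively read once per process, so the variables have to be set before the server starts.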
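On the decode side, _token_gqa_decode_attention_flashinfer hands the MLA wrapper a query already split into its nope and rope parts, with the nope part projected through k_b_proj (weight absorption) into the kv_lora_rank space; attention then runs against the compressed cache (ckv plus the shared rope keys), and the wrapper's output stays in that compressed space to be expanded by the value projection afterwards. The single-request, pure-PyTorch sketch below is only meant to make the shapes concrete; it sketches the absorbed-attention math under assumed weight layout, not the FlashInfer kernel, and omits the YaRN mscale factor folded into softmax_scale:

import torch

H, D_NOPE, D_ROPE, KV_LORA, N = 16, 128, 64, 512, 32   # heads, head dims, cached tokens
sm_scale = (D_NOPE + D_ROPE) ** -0.5

q_nope = torch.randn(H, D_NOPE)
q_rope = torch.randn(H, D_ROPE)
w_uk = torch.randn(H, D_NOPE, KV_LORA)                  # per-head k_b_proj weight (assumed layout)
ckv = torch.randn(N, KV_LORA)                           # compressed KV cache entries
kpe = torch.randn(N, D_ROPE)                            # shared rope keys

# weight absorption: fold W_UK into the query instead of expanding the cache per head
q_absorbed = torch.einsum("hd,hdk->hk", q_nope, w_uk)   # (H, KV_LORA)

# scores against the compressed cache plus the rope part, one shared KV "head"
scores = (q_absorbed @ ckv.T + q_rope @ kpe.T) * sm_scale   # (H, N)
attn = torch.softmax(scores, dim=-1)

# output stays in the compressed kv_lora_rank space, like the wrapper's o_tensor
out = attn @ ckv                                        # (H, KV_LORA)
print(out.shape)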
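For the triton_fp8kv mode, each token's compressed KV entry is stored as kv_lora_rank + qk_rope_head_dim fp8 (e4m3) bytes followed by a 2-byte per-token scale kept in the activation dtype, which is what _copy_kv_to_mem_cache_fp8 now writes via destindex_copy_kv_fp8: the scale is max(|nope|, |rope|, 1e-12) / 448, values are divided by it, clamped to the e4m3 range, and cast. A per-token PyTorch reference of that quantize/pack/dequantize round trip (shapes are illustrative; it needs a torch build with torch.float8_e4m3fn, and the casts may need a GPU on older builds):

import torch
import torch.nn.functional as F

fp8 = torch.float8_e4m3fn
FP8_MAX = torch.finfo(fp8).max            # 448.0 for e4m3fn
dtype = torch.bfloat16
NOPE, ROPE = 512, 64                      # kv_lora_rank and qk_rope_head_dim in this patch

kv = torch.randn(4, 1, NOPE + ROPE, dtype=dtype)          # (tokens, heads, dim)

# one scale per (token, head): max |value| over nope+rope, floored at 1e-12, divided by 448
scale = (kv.abs().amax(dim=-1, keepdim=True).clamp_min(1e-12) / FP8_MAX).to(dtype)
kv_fp8 = (kv / scale).clamp(-FP8_MAX, FP8_MAX).to(fp8)

# packed per-token layout used by the fp8 KV buffer: [nope fp8 | rope fp8 | 2-byte scale]
packed = torch.zeros(4, 1, NOPE + ROPE + 2, dtype=torch.uint8)
packed[..., : NOPE + ROPE] = kv_fp8.view(torch.uint8)
packed[..., NOPE + ROPE :] = scale.view(torch.uint8)

# dequantize and check the round trip, mirroring the unit test's cosine check
deq = packed[..., : NOPE + ROPE].view(fp8).to(dtype) * packed[..., NOPE + ROPE :].view(dtype)
print(F.cosine_similarity(deq, kv, dim=-1).mean())        # close to 1.0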
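Finally, FlashInferStateExtraInfo derives the attention scale from the full per-head QK width (qk_nope_head_dim + qk_rope_head_dim) and, when the model config carries rope_scaling with a non-zero mscale_all_dim, multiplies it by the squared YaRN mscale from get_deepseek_mscale. The sketch below mirrors that computation; the local yarn_mscale helper is an assumption matching the usual DeepSeek-V2/YaRN formula (0.1 * mscale_all_dim * ln(factor) + 1 for factor > 1), not an import from lightllm, and the config numbers are illustrative:

import math

def yarn_mscale(scale: float, mscale_all_dim: float) -> float:
    # Assumed to match lightllm's get_deepseek_mscale / the DeepSeek-V2 reference code.
    if scale <= 1.0:
        return 1.0
    return 0.1 * mscale_all_dim * math.log(scale) + 1.0

qk_nope_head_dim, qk_rope_head_dim = 128, 64          # per-head dims, as elsewhere in the patch
rope_scaling = {"factor": 40, "mscale_all_dim": 1.0}  # illustrative config values

softmax_scale = (qk_nope_head_dim + qk_rope_head_dim) ** -0.5
mscale = yarn_mscale(rope_scaling["factor"], rope_scaling["mscale_all_dim"])
softmax_scale *= mscale * mscale
print(softmax_scale)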