
Commit 9af2b0c

niushengxiao committed

feat: 1. fuse fp8 quant into kv copying in deepseek2
      2. add flashinfer prefill and decode MLA operators in deepseek2
1 parent c483b1e commit 9af2b0c

File tree: 13 files changed (+817 −61 lines)

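Both new attention paths are opt-in, gated by environment variables that the diff reads through lightllm.utils.envs_utils.enable_env_vars. A minimal launch-side sketch, assuming enable_env_vars accepts the same truthy values ("ON"/"TRUE"/"1") as the inline check it replaces; this snippet is illustrative and not part of the commit:

import os

# Hypothetical toggle snippet, not shipped with the commit itself.
os.environ["ENABLE_FLASHINFER_PREFILLED"] = "1"   # flashinfer ragged prefill for MLA context attention
os.environ["ENABLE_FLASHINFER_DECODE_MLA"] = "1"  # flashinfer BatchMLAPagedAttentionWrapper for decode
# fp8 KV caching is still selected through the existing "triton_fp8kv" mode flag,
# which now routes into the fused destindex_copy_kv_fp8 kernel.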

lightllm/common/basemodel/basemodel.py

Lines changed: 4 additions & 1 deletion
@@ -36,6 +36,7 @@ class TpPartBaseModel:
     infer_state_class = InferStateInfo
 
     def __init__(self, kvargs):
+        self.infer_state = self.infer_state_class()
         self.run_mode = kvargs["run_mode"]
         self.tp_rank_ = kvargs["tp_rank"]
         self.world_size_ = kvargs["world_size"]
@@ -330,7 +331,9 @@ def _decode(
         b_seq_len,
         multimodal_params,
     ):
-        infer_state = self.infer_state_class()
+        infer_state = self.infer_state
+        if self.graph is None or self.graph.need_capture(batch_size) or infer_state.is_prefill:
+            infer_state = self.infer_state_class()
         infer_state.is_prefill = False
         infer_state.batch_size = batch_size
         infer_state.total_token_num = total_token_num
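The _decode change keeps one long-lived infer state on the model and reuses it whenever a CUDA graph for this batch size already exists, building a fresh InferStateInfo only when there is no graph, a new capture is needed, or the cached state was last used for prefill. A stripped-down sketch of the reuse pattern (simplified names, not the real lightllm API):

class TinyState:
    def __init__(self):
        self.is_prefill = True

class TinyDecoder:
    def __init__(self, state_cls=TinyState):
        self.state_cls = state_cls
        self.infer_state = state_cls()  # persistent state, created once in __init__
        self.graph = None               # stands in for the captured CUDA graph

    def pick_decode_state(self, batch_size):
        infer_state = self.infer_state
        # Reuse the persistent state only when a captured graph can replay with it.
        if self.graph is None or self.graph.need_capture(batch_size) or infer_state.is_prefill:
            infer_state = self.state_cls()
        infer_state.is_prefill = False
        return infer_state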

lightllm/models/deepseek2/flashinfer_struct.py

Lines changed: 113 additions & 0 deletions

@@ -0,0 +1,113 @@
+import os
+import torch
+import numpy as np
+import torch.distributed as dist
+from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo
+from lightllm.utils.envs_utils import enable_env_vars
+import flashinfer
+from lightllm.models.deepseek2.triton_kernel.repack_kv_index import repack_kv_index
+
+
+class Deepseek2FlashInferStateInfo(Deepseek2InferStateInfo):
+    def __init__(self):
+        super().__init__()
+        self.prefill_wrapper = None
+        self.decode_wrapper = None
+
+    def init_some_extra_state(self, model, input_ids: torch.Tensor):
+        super().init_some_extra_state(model, input_ids)
+
+        if not self.is_prefill:
+            if enable_env_vars("ENABLE_FLASHINFER_DECODE_MLA"):
+                self.tp_q_head_num = model.flashinfer_state.tp_q_head_num
+                self.kv_lora_rank = model.flashinfer_state.kv_lora_rank
+                self.qk_rope_head_dim = model.flashinfer_state.qk_rope_head_dim
+                self.qk_nope_head_dim = model.flashinfer_state.qk_nope_head_dim
+                self.softmax_scale = model.flashinfer_state.softmax_scale
+                self.q_data_type = model.flashinfer_state.data_type
+                self.kv_data_type = model.flashinfer_state.data_type
+
+                self.q_indptr = torch.arange(self.batch_size + 1, dtype=torch.int32).to(input_ids.device)
+                self.kv_indices = torch.empty(
+                    self.batch_size * model.flashinfer_state.max_seq_length, dtype=torch.int32
+                ).to(input_ids.device)
+                repack_kv_index(
+                    self.req_manager.req_to_token_indexs,
+                    self.b_req_idx,
+                    self.b_seq_len,
+                    self.b_start_loc,
+                    self.max_len_in_batch,
+                    self.kv_indices,
+                )
+                if self.decode_wrapper is None:
+                    self.decode_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
+                        model.flashinfer_state.workspace_buffer,
+                        use_cuda_graph=True,
+                        qo_indptr=self.q_indptr,
+                        kv_indices=self.kv_indices,
+                        kv_indptr=self.kv_starts,
+                        kv_len_arr=self.b_seq_len,
+                    )
+                    self.decode_wrapper.plan(
+                        self.q_indptr,
+                        self.kv_starts,
+                        self.kv_indices,
+                        self.b_seq_len,
+                        self.tp_q_head_num,
+                        self.kv_lora_rank,
+                        self.qk_rope_head_dim,
+                        1,
+                        False,  # causal
+                        self.softmax_scale,
+                        self.q_data_type,
+                        self.kv_data_type,
+                    )
+        else:
+            if enable_env_vars("ENABLE_FLASHINFER_PREFILLED"):
+                self.tp_q_head_num = model.flashinfer_state.tp_q_head_num
+                self.qk_rope_head_dim = model.flashinfer_state.qk_rope_head_dim
+                self.qk_nope_head_dim = model.flashinfer_state.qk_nope_head_dim
+                self.softmax_scale = model.flashinfer_state.softmax_scale
+                self.q_data_type = model.flashinfer_state.data_type
+
+                q_starts = torch.cat(
+                    [self.b_start_loc, self.b_start_loc[-1:] + (self.b_seq_len - self.b_ready_cache_len)[-1:]], dim=0
+                ).int()
+                kv_starts = torch.cat(
+                    [self.b_kv_start_loc, self.b_kv_start_loc[-1:] + self.b_seq_len[-1:]], dim=0
+                ).int()
+                if self.prefill_wrapper is None:
+                    self.prefill_wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper(
+                        model.flashinfer_state.workspace_buffer, "NHD"
+                    )
+                self.prefill_wrapper.plan(
+                    qo_indptr=q_starts,
+                    kv_indptr=kv_starts,
+                    num_qo_heads=self.tp_q_head_num,
+                    num_kv_heads=self.tp_q_head_num,
+                    head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
+                    head_dim_vo=self.qk_nope_head_dim,
+                    q_data_type=self.q_data_type,
+                    causal=True,
+                    sm_scale=self.softmax_scale,
+                )
+        return
+
+    def copy_for_cuda_graph(self, new_infer_state):
+        super().copy_for_cuda_graph(new_infer_state)
+        if enable_env_vars("ENABLE_FLASHINFER_DECODE_MLA") and not self.is_prefill:
+            self.decode_wrapper.plan(
+                new_infer_state.q_indptr,
+                new_infer_state.kv_starts,
+                new_infer_state.kv_indices,
+                new_infer_state.b_seq_len,
+                new_infer_state.tp_q_head_num,
+                new_infer_state.kv_lora_rank,
+                new_infer_state.qk_rope_head_dim,
+                1,
+                False,  # causal
+                new_infer_state.softmax_scale,
+                new_infer_state.q_data_type,
+                new_infer_state.kv_data_type,
+            )
+        return
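For the decode path, the wrapper is driven by three ragged-index tensors: q_indptr (one query token per request, so simply arange(batch_size + 1)), kv_starts (per-request offsets into a flat KV listing), and kv_indices (the cache-slot ids of every cached token, gathered from req_to_token_indexs by the repack_kv_index Triton kernel). A toy illustration of that layout in plain PyTorch, with made-up slot numbers and a Python loop standing in for repack_kv_index:

import torch

b_seq_len = torch.tensor([3, 5], dtype=torch.int32)                 # cached tokens per request
q_indptr = torch.arange(len(b_seq_len) + 1, dtype=torch.int32)      # [0, 1, 2]: one new token each
kv_starts = torch.nn.functional.pad(torch.cumsum(b_seq_len, 0), (1, 0)).int()  # [0, 3, 8]

# req_to_token maps (request, position) -> slot in the global kv_buffer; -1 marks padding.
req_to_token = torch.tensor([[10, 11, 12, -1, -1],
                             [20, 21, 22, 23, 24]], dtype=torch.int32)
kv_indices = torch.cat([req_to_token[i, : int(b_seq_len[i])] for i in range(len(b_seq_len))])
print(q_indptr.tolist(), kv_starts.tolist(), kv_indices.tolist())

The prefill branch instead plans a ragged (non-paged) BatchPrefillWithRaggedKVCacheWrapper over the decompressed K/V, with head_dim_qk = qk_nope_head_dim + qk_rope_head_dim and head_dim_vo = qk_nope_head_dim.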

lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py

Lines changed: 85 additions & 22 deletions
@@ -5,6 +5,7 @@
 import numpy as np
 from lightllm.models.deepseek2.layer_weights.transformer_layer_weight import Deepseek2TransformerLayerWeight
 from lightllm.models.deepseek2.triton_kernel.destindex_copy_kv import destindex_copy_kv
+from lightllm.models.deepseek2.triton_kernel.destindex_copy_kv_fp8 import destindex_copy_kv_fp8
 from lightllm.models.deepseek2.triton_kernel.context_flashattention_nopad import (
     context_attention_fwd,
     context_attention_fwd_no_prompt_cache,
@@ -23,7 +24,7 @@
 from functools import partial
 from lightllm.models.llama.yarn_rotary_utils import get_deepseek_mscale
 import os
-from lightllm.common.quantization import vLLMFP8w8a8QuantizationMethod
+from lightllm.utils.envs_utils import enable_env_vars
 
 
 class Deepseek2TransformerLayerInfer(LlamaTransformerLayerInfer):
@@ -67,7 +68,6 @@ def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]):
         self.tp_o_head_num_ = self.tp_q_head_num_
         self.num_heads = network_config["num_attention_heads"]
         self.num_kv_heads = network_config["num_key_value_heads"]
-        self.enable_opt_decoding_mha = os.getenv("ENABLE_OPT_DECODE_MHA", "False").upper() in ["ON", "TRUE", "1"]
         return
 
     def _bind_func(self):
@@ -96,18 +96,33 @@ def _bind_attention(self):
             )
         else:
             self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
-        self._token_attention_kernel = partial(
-            Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashdecoding, self
-        )
-        if self.enable_cc_method:
-            if "triton_fp8kv" in self.mode:
-                self._context_attention_kernel = partial(
-                    Deepseek2TransformerLayerInfer._context_attention_kernel_with_CC_fp8, self
-                )
-            else:
-                self._context_attention_kernel = partial(
-                    Deepseek2TransformerLayerInfer._context_attention_kernel_with_CC, self
-                )
+        if enable_env_vars("ENABLE_FLASHINFER_DECODE_MLA"):
+            self._token_attention_kernel = partial(
+                Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashinfer, self
+            )
+        else:
+            self._token_attention_kernel = partial(
+                Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashdecoding, self
+            )
+        if self.enable_cc_method:
+            if "triton_fp8kv" in self.mode:
+                if enable_env_vars("ENABLE_FLASHINFER_PREFILLED"):
+                    self._context_attention_kernel = partial(
+                        Deepseek2TransformerLayerInfer._context_attention_flashinfer_kernel_with_CC_fp8, self
+                    )
+                else:
+                    self._context_attention_kernel = partial(
+                        Deepseek2TransformerLayerInfer._context_attention_kernel_with_CC_fp8, self
+                    )
+            else:
+                if enable_env_vars("ENABLE_FLASHINFER_PREFILLED"):
+                    self._context_attention_kernel = partial(
+                        Deepseek2TransformerLayerInfer._context_attention_flashinfer_kernel_with_CC, self
+                    )
+                else:
+                    self._context_attention_kernel = partial(
+                        Deepseek2TransformerLayerInfer._context_attention_kernel_with_CC, self
+                    )
         else:
             if "triton_fp8kv" in self.mode:
                 self._context_attention_kernel = partial(
@@ -205,6 +220,38 @@ def _decompress_kv(
         k_nope, v = torch.split(kv_nope, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
         return k_nope, k_rope, v
 
+    def _context_attention_flashinfer_kernel_with_CC(
+        self,
+        q: torch.Tensor,
+        kv,
+        infer_state: Deepseek2InferStateInfo,
+        layer_weight: Deepseek2TransformerLayerWeight,
+        out=None,
+    ) -> torch.Tensor:
+        k_nope, k_rope, v = self._decompress_kv(kv, infer_state, layer_weight, False)
+        o_tensor = (
+            self.alloc_tensor((q.shape[0], q.shape[1], self.qk_nope_head_dim), dtype=q.dtype) if out is None else out
+        )
+        k = torch.cat([k_nope, torch.repeat_interleave(k_rope, self.tp_q_head_num_, dim=-2)], dim=-1)
+        infer_state.prefill_wrapper.run(q, k, v, out=o_tensor)
+        return o_tensor
+
+    def _context_attention_flashinfer_kernel_with_CC_fp8(
+        self,
+        q: torch.Tensor,
+        kv,
+        infer_state: Deepseek2InferStateInfo,
+        layer_weight: Deepseek2TransformerLayerWeight,
+        out=None,
+    ) -> torch.Tensor:
+        k_nope, k_rope, v = self._decompress_kv(kv, infer_state, layer_weight, True)
+        o_tensor = (
+            self.alloc_tensor((q.shape[0], q.shape[1], self.qk_nope_head_dim), dtype=q.dtype) if out is None else out
+        )
+        k = torch.cat([k_nope, torch.repeat_interleave(k_rope, self.tp_q_head_num_, dim=-2)], dim=-1)
+        infer_state.prefill_wrapper.run(q, k, v, out=o_tensor)
+        return o_tensor
+
     def _context_attention_kernel_with_CC(
         self,
         q: torch.Tensor,
@@ -345,6 +392,25 @@ def _context_attention_kernel_origin_fp8(
 
         return o_tensor
 
+    def _token_gqa_decode_attention_flashinfer(
+        self, q, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
+    ):
+        q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
+        q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1)
+
+        kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
+        o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype)
+
+        infer_state.decode_wrapper.run(
+            q_nope,
+            q_rope,
+            kv[:, :, : -self.qk_rope_head_dim],
+            kv[:, :, -self.qk_rope_head_dim :],
+            out=o_tensor,
+            return_lse=False,
+        )
+        return o_tensor
+
     def _token_gqa_decode_attention_flashdecoding(
         self, q, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
     ):
@@ -354,7 +420,7 @@ def _token_gqa_decode_attention_flashdecoding(
         kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
         o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype)
 
-        if self.enable_opt_decoding_mha:
+        if enable_env_vars("ENABLE_OPT_DECODE_MHA"):
             q = torch.cat([q_nope, q_rope], dim=-1)
             q_nope, q_rope = None, None
             import lightllm_ppl_mla
@@ -368,7 +434,7 @@ def _token_gqa_decode_attention_flashdecoding(
                 infer_state.b_req_idx,
                 self.softmax_scale,
                 q.shape[-1],
-                q_nope.shape[-1],
+                self.kv_lora_rank,
             )
             return o_tensor
         else:
@@ -421,16 +487,13 @@ def _copy_kv_to_mem_cache_normal(self, buffer, mem_index, mem_manager):
         return
 
     def _copy_kv_to_mem_cache_fp8(self, buffer, mem_index, mem_manager):
-        quant_method = vLLMFP8w8a8QuantizationMethod()
-        quant, scale = quant_method.quantize_scaled_mm_fp8(buffer.reshape(-1, buffer.shape[-1]))
-        destindex_copy_kv(
-            quant.T.unsqueeze(1)[:, :, : self.kv_lora_rank].view(torch.uint8),
-            quant.T.unsqueeze(1)[:, :, self.kv_lora_rank :].view(torch.uint8),
+        destindex_copy_kv_fp8(
+            buffer[:, :, : self.kv_lora_rank],
+            buffer[:, :, self.kv_lora_rank :],
             mem_index,
-            mem_manager.kv_buffer[self.layer_num_][:, :, : self.kv_lora_rank],
-            mem_manager.kv_buffer[self.layer_num_][:, :, self.kv_lora_rank : -2],
-            mem_manager.kv_buffer[self.layer_num_][:, :, -2:],
-            scale.to(buffer.dtype).view(torch.uint8),
+            mem_manager.kv_buffer[self.layer_num_][:, :, : self.kv_lora_rank].view(torch.float8_e4m3fn),
+            mem_manager.kv_buffer[self.layer_num_][:, :, self.kv_lora_rank : -2].view(torch.float8_e4m3fn),
+            mem_manager.kv_buffer[self.layer_num_][:, :, -2:].view(buffer.dtype),
         )
         return
 
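The rewritten _copy_kv_to_mem_cache_fp8 replaces the separate vLLM quantize-then-copy pair with one fused Triton kernel, destindex_copy_kv_fp8, which quantizes the compressed KV latent to float8_e4m3fn while scattering it into the cache slots and keeps the per-token scale in the last two bytes of each slot. The kernel itself is not in this hunk; the plain-PyTorch reference below is only a sketch of the intended semantics, assuming per-token absmax scaling (the scheme the removed vLLM path used) and a buffer shaped [token_num, 1, kv_lora_rank + qk_rope_head_dim]:

import torch

def copy_kv_fp8_reference(buffer, mem_index, nope_cache, rope_cache, scale_cache, kv_lora_rank):
    # Per-token dynamic scale so each token's largest value maps to the fp8 e4m3 max.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    scale = buffer.abs().amax(dim=(1, 2), keepdim=True).clamp(min=1e-12) / fp8_max
    quant = (buffer / scale).to(torch.float8_e4m3fn)
    # Scatter the quantized latent and rope parts, plus the scale, into the destination slots.
    nope_cache[mem_index] = quant[:, :, :kv_lora_rank]
    rope_cache[mem_index] = quant[:, :, kv_lora_rank:]
    scale_cache[mem_index] = scale.to(scale_cache.dtype)  # stored in the slot's trailing 2 bytes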

lightllm/models/deepseek2/model.py

Lines changed: 30 additions & 1 deletion
@@ -2,26 +2,53 @@
 from lightllm.models.deepseek2.layer_infer.transformer_layer_infer import Deepseek2TransformerLayerInfer
 from lightllm.models.deepseek2.layer_weights.transformer_layer_weight import Deepseek2TransformerLayerWeight
 from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo
+from lightllm.models.deepseek2.flashinfer_struct import Deepseek2FlashInferStateInfo
 from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights
 
 from lightllm.models.llama.model import LlamaTpPartModel
 from lightllm.common.deepseek2_mem_manager import Deepseek2MemoryManager
 from lightllm.common.deepseek2_fp8kv_mem_manager import Deepseek2FP8KVMemoryManager
 from lightllm.utils.log_utils import init_logger
+from lightllm.models.llama.yarn_rotary_utils import get_deepseek_mscale
+from lightllm.utils.envs_utils import enable_env_vars
 
 
 logger = init_logger(__name__)
 
 
+class FlashInferStateExtraInfo:
+    def __init__(self, model):
+        num_heads = model.config["num_attention_heads"]
+        self.tp_q_head_num = num_heads if enable_env_vars("ENABLE_DP") else num_heads // model.world_size_
+        self.qk_nope_head_dim = model.qk_nope_head_dim
+        self.qk_rope_head_dim = model.qk_rope_head_dim
+        self.kv_lora_rank = model.kv_lora_rank
+        self.data_type = model.data_type
+        self.workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(model.tp_rank_)
+        self.max_seq_length = model.max_seq_length
+        self.softmax_scale = (self.qk_nope_head_dim + self.qk_rope_head_dim) ** (-0.5)
+        if model.config["rope_scaling"] is not None:
+            rope_scaling = model.config["rope_scaling"]
+            mscale_all_dim = rope_scaling.get("mscale_all_dim", 0)
+            scaling_factor = rope_scaling["factor"]
+            if mscale_all_dim:
+                mscale = get_deepseek_mscale(scaling_factor, mscale_all_dim)
+                self.softmax_scale = self.softmax_scale * mscale * mscale
+
+
 class Deepseek2TpPartModel(LlamaTpPartModel):
     # weight class
     transformer_weight_class = Deepseek2TransformerLayerWeight
 
     # infer class
     transformer_layer_infer_class = Deepseek2TransformerLayerInfer
 
+    enable_flashinfer = enable_env_vars("ENABLE_FLASHINFER_PREFILLED") or enable_env_vars(
+        "ENABLE_FLASHINFER_DECODE_MLA"
+    )
+
     # infer state class
-    infer_state_class = Deepseek2InferStateInfo
+    infer_state_class = Deepseek2FlashInferStateInfo if enable_flashinfer else Deepseek2InferStateInfo
 
     def __init__(self, kvargs):
         super().__init__(kvargs)
@@ -37,6 +64,8 @@ def _init_some_value(self):
         self.q_lora_rank = self.config["q_lora_rank"]
        self.kv_lora_rank = self.config["kv_lora_rank"]
         self.head_dim_ = self.kv_lora_rank + self.qk_rope_head_dim
+        if self.enable_flashinfer:
+            self.flashinfer_state = FlashInferStateExtraInfo(self)
 
     def _init_custom(self):
         self._init_to_get_yarn_rotary()
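FlashInferStateExtraInfo exists so the per-step infer state does not recompute static MLA parameters: head counts, head dims, the shared 128 MB flashinfer workspace buffer, and the softmax scale. The scale is 1/sqrt(qk_nope_head_dim + qk_rope_head_dim), multiplied by mscale^2 when yarn rope scaling sets mscale_all_dim. A worked example with illustrative DeepSeek-V2-style numbers (the helper below is assumed to match get_deepseek_mscale's standard formula):

import math

def deepseek_yarn_mscale(scale, mscale_all_dim):
    # Standard DeepSeek yarn mscale; assumed equivalent to get_deepseek_mscale.
    if scale <= 1:
        return 1.0
    return 0.1 * mscale_all_dim * math.log(scale) + 1.0

qk_nope_head_dim, qk_rope_head_dim = 128, 64                       # illustrative values
softmax_scale = (qk_nope_head_dim + qk_rope_head_dim) ** (-0.5)    # 1/sqrt(192) ~= 0.0722
mscale = deepseek_yarn_mscale(scale=40, mscale_all_dim=1.0)        # ~1.369 for factor 40
softmax_scale *= mscale * mscale                                   # ~0.135
print(softmax_scale)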
