@@ -5,6 +5,7 @@
 import numpy as np
 from lightllm.models.deepseek2.layer_weights.transformer_layer_weight import Deepseek2TransformerLayerWeight
 from lightllm.models.deepseek2.triton_kernel.destindex_copy_kv import destindex_copy_kv
+from lightllm.models.deepseek2.triton_kernel.destindex_copy_kv_fp8 import destindex_copy_kv_fp8
 from lightllm.models.deepseek2.triton_kernel.context_flashattention_nopad import (
     context_attention_fwd,
     context_attention_fwd_no_prompt_cache,
@@ -23,7 +24,7 @@
 from functools import partial
 from lightllm.models.llama.yarn_rotary_utils import get_deepseek_mscale
 import os
-from lightllm.common.quantization import vLLMFP8w8a8QuantizationMethod
+from lightllm.utils.envs_utils import enable_env_vars


 class Deepseek2TransformerLayerInfer(LlamaTransformerLayerInfer):
@@ -67,7 +68,6 @@ def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]):
         self.tp_o_head_num_ = self.tp_q_head_num_
         self.num_heads = network_config["num_attention_heads"]
         self.num_kv_heads = network_config["num_key_value_heads"]
-        self.enable_opt_decoding_mha = os.getenv("ENABLE_OPT_DECODE_MHA", "False").upper() in ["ON", "TRUE", "1"]
         return

     def _bind_func(self):
@@ -96,18 +96,33 @@ def _bind_attention(self):
             )
         else:
             self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
-        self._token_attention_kernel = partial(
-            Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashdecoding, self
-        )
-        if self.enable_cc_method:
-            if "triton_fp8kv" in self.mode:
-                self._context_attention_kernel = partial(
-                    Deepseek2TransformerLayerInfer._context_attention_kernel_with_CC_fp8, self
+        if enable_env_vars("ENABLE_FLASHINFER_DECODE_MLA"):
+            self._token_attention_kernel = partial(
+                Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashinfer, self
             )
         else:
-                self._context_attention_kernel = partial(
-                    Deepseek2TransformerLayerInfer._context_attention_kernel_with_CC, self
+            self._token_attention_kernel = partial(
+                Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashdecoding, self
             )
+        if self.enable_cc_method:
+            if "triton_fp8kv" in self.mode:
+                if enable_env_vars("ENABLE_FLASHINFER_PREFILLED"):
+                    self._context_attention_kernel = partial(
+                        Deepseek2TransformerLayerInfer._context_attention_flashinfer_kernel_with_CC_fp8, self
+                    )
+                else:
+                    self._context_attention_kernel = partial(
+                        Deepseek2TransformerLayerInfer._context_attention_kernel_with_CC_fp8, self
+                    )
+            else:
+                if enable_env_vars("ENABLE_FLASHINFER_PREFILLED"):
+                    self._context_attention_kernel = partial(
+                        Deepseek2TransformerLayerInfer._context_attention_flashinfer_kernel_with_CC, self
+                    )
+                else:
+                    self._context_attention_kernel = partial(
+                        Deepseek2TransformerLayerInfer._context_attention_kernel_with_CC, self
+                    )
         else:
             if "triton_fp8kv" in self.mode:
                 self._context_attention_kernel = partial(
@@ -205,6 +220,38 @@ def _decompress_kv(
         k_nope, v = torch.split(kv_nope, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
         return k_nope, k_rope, v

+    def _context_attention_flashinfer_kernel_with_CC(
+        self,
+        q: torch.Tensor,
+        kv,
+        infer_state: Deepseek2InferStateInfo,
+        layer_weight: Deepseek2TransformerLayerWeight,
+        out=None,
+    ) -> torch.Tensor:
+        k_nope, k_rope, v = self._decompress_kv(kv, infer_state, layer_weight, False)
+        o_tensor = (
+            self.alloc_tensor((q.shape[0], q.shape[1], self.qk_nope_head_dim), dtype=q.dtype) if out is None else out
+        )
+        k = torch.cat([k_nope, torch.repeat_interleave(k_rope, self.tp_q_head_num_, dim=-2)], dim=-1)
+        infer_state.prefill_wrapper.run(q, k, v, out=o_tensor)
+        return o_tensor
+
+    def _context_attention_flashinfer_kernel_with_CC_fp8(
+        self,
+        q: torch.Tensor,
+        kv,
+        infer_state: Deepseek2InferStateInfo,
+        layer_weight: Deepseek2TransformerLayerWeight,
+        out=None,
+    ) -> torch.Tensor:
+        k_nope, k_rope, v = self._decompress_kv(kv, infer_state, layer_weight, True)
+        o_tensor = (
+            self.alloc_tensor((q.shape[0], q.shape[1], self.qk_nope_head_dim), dtype=q.dtype) if out is None else out
+        )
+        k = torch.cat([k_nope, torch.repeat_interleave(k_rope, self.tp_q_head_num_, dim=-2)], dim=-1)
+        infer_state.prefill_wrapper.run(q, k, v, out=o_tensor)
+        return o_tensor
+
     def _context_attention_kernel_with_CC(
         self,
         q: torch.Tensor,
@@ -345,6 +392,25 @@ def _context_attention_kernel_origin_fp8(
         return o_tensor

+    def _token_gqa_decode_attention_flashinfer(
+        self, q, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
+    ):
+        q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
+        q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1)
+
+        kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
+        o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype)
+
+        infer_state.decode_wrapper.run(
+            q_nope,
+            q_rope,
+            kv[:, :, : -self.qk_rope_head_dim],
+            kv[:, :, -self.qk_rope_head_dim :],
+            out=o_tensor,
+            return_lse=False,
+        )
+        return o_tensor
+
     def _token_gqa_decode_attention_flashdecoding(
         self, q, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
     ):
@@ -354,7 +420,7 @@ def _token_gqa_decode_attention_flashdecoding(
         kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
         o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype)

-        if self.enable_opt_decoding_mha:
+        if enable_env_vars("ENABLE_OPT_DECODE_MHA"):
             q = torch.cat([q_nope, q_rope], dim=-1)
             q_nope, q_rope = None, None
             import lightllm_ppl_mla
@@ -368,7 +434,7 @@ def _token_gqa_decode_attention_flashdecoding(
                 infer_state.b_req_idx,
                 self.softmax_scale,
                 q.shape[-1],
-                q_nope.shape[-1],
+                self.kv_lora_rank,
             )
             return o_tensor
         else:
@@ -421,16 +487,13 @@ def _copy_kv_to_mem_cache_normal(self, buffer, mem_index, mem_manager):
         return

     def _copy_kv_to_mem_cache_fp8(self, buffer, mem_index, mem_manager):
-        quant_method = vLLMFP8w8a8QuantizationMethod()
-        quant, scale = quant_method.quantize_scaled_mm_fp8(buffer.reshape(-1, buffer.shape[-1]))
-        destindex_copy_kv(
-            quant.T.unsqueeze(1)[:, :, : self.kv_lora_rank].view(torch.uint8),
-            quant.T.unsqueeze(1)[:, :, self.kv_lora_rank :].view(torch.uint8),
+        destindex_copy_kv_fp8(
+            buffer[:, :, : self.kv_lora_rank],
+            buffer[:, :, self.kv_lora_rank :],
             mem_index,
-            mem_manager.kv_buffer[self.layer_num_][:, :, : self.kv_lora_rank],
-            mem_manager.kv_buffer[self.layer_num_][:, :, self.kv_lora_rank : -2],
-            mem_manager.kv_buffer[self.layer_num_][:, :, -2:],
-            scale.to(buffer.dtype).view(torch.uint8),
+            mem_manager.kv_buffer[self.layer_num_][:, :, : self.kv_lora_rank].view(torch.float8_e4m3fn),
+            mem_manager.kv_buffer[self.layer_num_][:, :, self.kv_lora_rank : -2].view(torch.float8_e4m3fn),
+            mem_manager.kv_buffer[self.layer_num_][:, :, -2:].view(buffer.dtype),
         )
         return
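
Note on the env-var gating: kernel selection in _bind_attention (and the ENABLE_OPT_DECODE_MHA branch of _token_gqa_decode_attention_flashdecoding) is now driven entirely by enable_env_vars(...) from lightllm.utils.envs_utils, whose implementation is not part of this diff. A minimal sketch of the behaviour it presumably provides, inferred from the inline check this commit deletes (an assumption, not the actual helper):

import os

def enable_env_vars(name: str) -> bool:
    # Mirrors the deleted inline check: "ON", "TRUE" or "1" (case-insensitive)
    # enables the feature; unset or any other value disables it.
    return os.getenv(name, "False").upper() in ["ON", "TRUE", "1"]

Under that assumption, setting for example ENABLE_FLASHINFER_DECODE_MLA=1 routes decode through _token_gqa_decode_attention_flashinfer, and ENABLE_FLASHINFER_PREFILLED=1 routes the CC prefill path through the new _context_attention_flashinfer_kernel_with_CC / _with_CC_fp8 methods; with both unset, the previous Triton flash-decoding and CC kernels keep being used.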
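Note on the fp8 KV-cache path: the rewritten _copy_kv_to_mem_cache_fp8 now passes the unquantized buffer halves straight to the new destindex_copy_kv_fp8 Triton kernel (which presumably quantizes inside the kernel) and views the destination cache as float8_e4m3fn for the payload and as the activation dtype for the last two bytes. That implies each cached token slot holds the fp8 compressed latent (kv_lora_rank values), the fp8 rope part, and one 2-byte scale at the tail. A rough read-back sketch under those assumptions (byte-typed cache, a single shared per-slot scale, and the usual x ~ q * scale dequantization convention; none of this is shown in the diff itself):

import torch

def dequant_fp8_kv_layer(kv_buffer_layer: torch.Tensor, kv_lora_rank: int, act_dtype: torch.dtype = torch.bfloat16):
    # kv_buffer_layer: uint8 cache for one layer, shape [num_slots, head_num, kv_lora_rank + rope_dim + 2],
    # as implied by the .view(torch.float8_e4m3fn) / .view(buffer.dtype) destinations above.
    scale = kv_buffer_layer[:, :, -2:].view(act_dtype)  # last 2 bytes reinterpreted as one scale per slot
    nope = kv_buffer_layer[:, :, :kv_lora_rank].view(torch.float8_e4m3fn).to(act_dtype) * scale
    rope = kv_buffer_layer[:, :, kv_lora_rank:-2].view(torch.float8_e4m3fn).to(act_dtype) * scale
    return nope, rope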