@@ -325,7 +325,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
         batch_size = len(seqlens_in_batch)
         device = seqlens_in_batch.device
 
-        if forward_batch.forward_mode.is_decode():
+        if forward_batch.forward_mode.is_decode_or_idle():
             # Draft Decode
             if forward_batch.spec_info is not None:
                 metadata.cache_seqlens_int32 = (
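The same substitution of is_decode() with is_decode_or_idle() appears again in the CUDA-graph capture and replay hunks further down, presumably so that idle batches (which carry no tokens) are handled by the decode metadata path rather than falling through to the extend path. A minimal sketch of the distinction, assuming ForwardMode behaves like an enum with DECODE and IDLE members; the member names and helper below are illustrative, not the actual sglang implementation:

from enum import IntEnum, auto

class ForwardMode(IntEnum):
    """Illustrative stand-in for the real ForwardMode enum (names assumed)."""
    EXTEND = auto()
    DECODE = auto()
    IDLE = auto()  # e.g. a rank that has no requests to run this step

    def is_decode(self) -> bool:
        return self == ForwardMode.DECODE

    def is_decode_or_idle(self) -> bool:
        # Treat an empty (idle) batch like a decode batch so that metadata
        # initialization and CUDA-graph capture/replay share one code path.
        return self in (ForwardMode.DECODE, ForwardMode.IDLE)

assert ForwardMode.IDLE.is_decode_or_idle() and not ForwardMode.IDLE.is_decode()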
@@ -527,7 +527,9 @@ def forward_extend(
             else (-1, -1)
         )
         k_descale, v_descale = None, None
-        if self.kv_cache_dtype_str != "auto":
+        # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
+        # has corresponding quantization method so that layer.k_scale is not None
+        if self.kv_cache_dtype_str != "auto" and layer.k_scale is not None:
             descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
             k_descale = layer.k_scale.expand(descale_shape)
             v_descale = layer.v_scale.expand(descale_shape)
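The extra condition added above (and mirrored in forward_decode below) matters because a model can be served with an FP8 KV cache even when its quantization method never attaches k_scale/v_scale to the attention layer; in that case the attribute is None and calling .expand() on it raises. A self-contained sketch of the intended behavior under that assumption, using a hypothetical _Layer stand-in for RadixAttention:

import torch

class _Layer:
    """Hypothetical stand-in for RadixAttention, for illustration only."""
    def __init__(self, k_scale=None, v_scale=None, tp_k_head_num=8):
        self.k_scale = k_scale
        self.v_scale = v_scale
        self.tp_k_head_num = tp_k_head_num

def get_kv_descale(layer, batch_size, kv_cache_dtype_str):
    # Build descale tensors only when an FP8 KV cache is requested *and* the
    # quantization method attached per-layer scales; otherwise expanding a
    # None scale would raise an AttributeError.
    k_descale = v_descale = None
    if kv_cache_dtype_str != "auto" and layer.k_scale is not None:
        descale_shape = (batch_size, layer.tp_k_head_num)
        k_descale = layer.k_scale.expand(descale_shape)
        v_descale = layer.v_scale.expand(descale_shape)
    return k_descale, v_descale

# Layer without attached scales: returns (None, None) instead of crashing.
print(get_kv_descale(_Layer(), batch_size=4, kv_cache_dtype_str="fp8_e4m3"))
# Layer with scalar scales: returns (4, 8) broadcast views of k_scale/v_scale.
scaled = _Layer(k_scale=torch.tensor(0.5), v_scale=torch.tensor(0.25))
print(get_kv_descale(scaled, batch_size=4, kv_cache_dtype_str="fp8_e4m3"))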
@@ -670,10 +672,13 @@ def forward_decode(
         causal = not layer.is_cross_attention
 
         k_descale, v_descale = None, None
+        # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
+        # has corresponding quantization method so that layer.k_scale is not None
         if self.kv_cache_dtype_str != "auto":
-            descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
-            k_descale = layer.k_scale.expand(descale_shape)
-            v_descale = layer.v_scale.expand(descale_shape)
+            if layer.k_scale is not None:
+                descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
+                k_descale = layer.k_scale.expand(descale_shape)
+                v_descale = layer.v_scale.expand(descale_shape)
         q = q.to(self.kv_cache_dtype)
 
         if not self.use_mla:
@@ -834,7 +839,7 @@ def init_forward_metadata_capture_cuda_graph(
         """Initialize forward metadata for capturing CUDA graph."""
         metadata = FlashAttentionMetadata()
         device = seq_lens.device
-        if forward_mode.is_decode():
+        if forward_mode.is_decode_or_idle():
             if spec_info is not None:
                 # Draft Decode
                 metadata.cache_seqlens_int32 = self.decode_cuda_graph_metadata[
@@ -937,7 +942,7 @@ def init_forward_metadata_replay_cuda_graph(
         seq_lens = seq_lens[:bs]
         seq_lens_cpu = seq_lens_cpu[:bs]
         req_pool_indices = req_pool_indices[:bs]
-        if forward_mode.is_decode():
+        if forward_mode.is_decode_or_idle():
             metadata = self.decode_cuda_graph_metadata[bs]
 
             if spec_info is not None: