
Commit 8d2c308

yundai424 authored and qingquansong committed
Fix loading KV quantization scale; Enable modelopt kv cache (sgl-project#4686)
Co-authored-by: qingquansong <[email protected]>
1 parent 8c75966 commit 8d2c308
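
For orientation, a hedged end-to-end sketch of what this fix enables: serving a ModelOpt FP8 checkpoint (the repo id below is the example cited in the model_config.py diff) with an FP8 KV cache, so the checkpoint's k/v scales are loaded and applied instead of being silently dropped. The Engine keyword arguments mirror SGLang server arguments; treat the exact invocation as an assumption rather than the commit's own test.

import sglang as sgl

llm = sgl.Engine(
    model_path="nvidia/Llama-3.1-8B-Instruct-FP8",  # ships a standalone hf_quant_config.json
    kv_cache_dtype="fp8_e4m3",                      # "auto" keeps the unscaled default path
)
print(llm.generate("The capital of France is"))
llm.shutdown()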

38 files changed (+153 additions, −78 deletions)

python/sglang/srt/configs/model_config.py

Lines changed: 1 addition & 1 deletion
@@ -239,7 +239,7 @@ def _parse_quant_hf_config(self):
         # check if is modelopt model -- modelopt doesn't have corresponding field
         # in hf `config.json` but has a standalone `hf_quant_config.json` in the root directory
         # example: https://huggingface.co/nvidia/Llama-3.1-8B-Instruct-FP8/tree/main
-        is_local = os.path.isdir(self.model_path)
+        is_local = os.path.exists(self.model_path)
         modelopt_quant_config = {"quant_method": "modelopt"}
         if not is_local:
             from huggingface_hub import HfApi
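
For orientation, a minimal sketch of the detection this hunk relaxes: os.path.exists also accepts local model paths that are not directories, which os.path.isdir would misclassify as remote repo ids. The helper name is hypothetical; the Hub lookup mirrors the surrounding code's use of HfApi.

import os

from huggingface_hub import HfApi


def has_standalone_hf_quant_config(model_path: str) -> bool:
    if os.path.exists(model_path):
        # Local checkout: look for the standalone ModelOpt config next to the weights.
        return os.path.isfile(os.path.join(model_path, "hf_quant_config.json"))
    # Otherwise treat model_path as a Hugging Face repo id and query the Hub listing.
    return "hf_quant_config.json" in HfApi().list_repo_files(model_path)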

python/sglang/srt/layers/attention/flashattention_backend.py

Lines changed: 23 additions & 8 deletions
@@ -292,6 +292,8 @@ def __init__(
         self.decode_cuda_graph_metadata = {}
         self.target_verify_metadata = {}
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.kv_cache_dtype = model_runner.kv_cache_dtype
+        self.kv_cache_dtype_str = model_runner.server_args.kv_cache_dtype
         self.page_size = model_runner.page_size
         self.use_mla = (
             model_runner.model_config.attention_arch == AttentionArch.MLA
@@ -520,6 +522,12 @@ def forward_extend(
             if layer.sliding_window_size is not None
             else (-1, -1)
         )
+        k_descale, v_descale = None, None
+        if self.kv_cache_dtype_str != "auto":
+            descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
+            k_descale = layer.k_scale.expand(descale_shape)
+            v_descale = layer.v_scale.expand(descale_shape)
+            q = q.to(self.kv_cache_dtype)
         causal = not layer.is_cross_attention

         # Check if we should use local attention
@@ -576,8 +584,8 @@ def forward_extend(
                 causal=causal,
                 window_size=window_size,
                 softcap=layer.logit_cap,
-                k_descale=layer.k_scale,
-                v_descale=layer.v_scale,
+                k_descale=k_descale,
+                v_descale=v_descale,
             )
         else:
             # Do absorbed multi-latent attention
@@ -609,8 +617,8 @@ def forward_extend(
                 softmax_scale=layer.scaling,
                 causal=True,
                 softcap=layer.logit_cap,
-                k_descale=layer.k_scale,
-                v_descale=layer.v_scale,
+                k_descale=k_descale,
+                v_descale=v_descale,
             )

         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
@@ -657,6 +665,13 @@ def forward_decode(
         )
         causal = not layer.is_cross_attention

+        k_descale, v_descale = None, None
+        if self.kv_cache_dtype_str != "auto":
+            descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
+            k_descale = layer.k_scale.expand(descale_shape)
+            v_descale = layer.v_scale.expand(descale_shape)
+            q = q.to(self.kv_cache_dtype)
+
         if not self.use_mla:
             # Do multi-head attention

@@ -694,8 +709,8 @@ def forward_decode(
                 causal=causal,
                 window_size=window_size,
                 softcap=layer.logit_cap,
-                k_descale=layer.k_scale,
-                v_descale=layer.v_scale,
+                k_descale=k_descale,
+                v_descale=v_descale,
             )
         else:
             # Do absorbed multi-latent attention
@@ -729,8 +744,8 @@ def forward_decode(
                 softmax_scale=layer.scaling,
                 causal=True,
                 softcap=layer.logit_cap,
-                k_descale=layer.k_scale,
-                v_descale=layer.v_scale,
+                k_descale=k_descale,
+                v_descale=v_descale,
             )
         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
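A toy illustration (shapes and values assumed) of the descale handling above: the per-tensor k_scale stored on the layer is a 0-d float32 parameter, and expand broadcasts it without copying to the (batch_size, num_kv_heads) layout that the FlashAttention descale arguments expect.

import torch

batch_size, num_kv_heads = 4, 8
k_scale = torch.nn.Parameter(torch.tensor(0.023, dtype=torch.float32), requires_grad=False)

descale_shape = (batch_size, num_kv_heads)
k_descale = k_scale.expand(descale_shape)  # broadcast view: every (batch, head) shares one scale
assert k_descale.shape == descale_shape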

python/sglang/srt/layers/attention/flashinfer_backend.py

Lines changed: 13 additions & 7 deletions
@@ -82,6 +82,8 @@ def __init__(
         self.max_context_len = model_runner.model_config.context_len
         self.skip_prefill = skip_prefill
         self.is_multimodal = model_runner.model_config.is_multimodal
+        self.kv_cache_dtype = model_runner.kv_cache_dtype
+        self.kv_cache_dtype_str = model_runner.server_args.kv_cache_dtype

         assert not (
             model_runner.sliding_window_size is not None
@@ -391,6 +393,8 @@ def forward_extend(
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
+        k_scale = layer.k_scale_float if self.kv_cache_dtype_str != "auto" else None
+        v_scale = layer.v_scale_float if self.kv_cache_dtype_str != "auto" else None
         prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
             self._get_wrapper_idx(layer)
         ]
@@ -407,7 +411,7 @@ def forward_extend(
                 assert v is not None
                 if save_kv_cache:
                     forward_batch.token_to_kv_pool.set_kv_buffer(
-                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                        layer, cache_loc, k, v, k_scale, v_scale
                     )

             o = prefill_wrapper_paged.forward(
@@ -417,8 +421,8 @@ def forward_extend(
                 sm_scale=layer.scaling,
                 window_left=layer.sliding_window_size,
                 logits_soft_cap=logits_soft_cap,
-                k_scale=layer.k_scale,
-                v_scale=layer.v_scale,
+                k_scale=k_scale,
+                v_scale=v_scale,
             )
         else:
             o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
@@ -445,7 +449,7 @@ def forward_extend(

             if save_kv_cache:
                 forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                    layer, cache_loc, k, v, k_scale, v_scale
                 )

         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
@@ -459,6 +463,8 @@ def forward_decode(
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
+        k_scale = layer.k_scale_float if self.kv_cache_dtype_str != "auto" else None
+        v_scale = layer.v_scale_float if self.kv_cache_dtype_str != "auto" else None
         decode_wrapper = self.forward_metadata.decode_wrappers[
             self._get_wrapper_idx(layer)
         ]
@@ -472,16 +478,16 @@ def forward_decode(
             assert v is not None
             if save_kv_cache:
                 forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                    layer, cache_loc, k, v, k_scale, v_scale
                 )

         o = decode_wrapper.forward(
             q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
             forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
             sm_scale=layer.scaling,
             logits_soft_cap=layer.logit_cap,
-            k_scale=layer.k_scale,
-            v_scale=layer.v_scale,
+            k_scale=k_scale,
+            v_scale=v_scale,
         )

         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
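
As a side note on the scalar scales used here: a toy numeric example (values and shapes assumed, not the SGLang kernel path) of what a per-tensor k_scale means for an fp8_e4m3 KV cache — the cache stores roughly k / k_scale, and the scale multiplies the values back at attention time.

import torch

k = torch.randn(16, 8, 128, dtype=torch.bfloat16)        # [tokens, kv_heads, head_dim]
k_scale = k.float().abs().max() / 448.0                   # 448 ~ max normal value of float8_e4m3fn
k_fp8 = (k.float() / k_scale).to(torch.float8_e4m3fn)     # roughly what the KV pool stores
k_restored = k_fp8.float() * k_scale                      # what passing k_scale recovers
print((k.float() - k_restored).abs().max())               # small quantization error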

python/sglang/srt/layers/quantization/kv_cache.py

Lines changed: 43 additions & 52 deletions
@@ -8,6 +8,7 @@
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.utils import is_hip

 _is_hip = is_hip()
@@ -17,7 +18,7 @@

 class BaseKVCacheMethod(QuantizeMethodBase):
     """
-    Quant method that adds `_k_scale` and `_v_scale` attributes to the
+    Quant method that adds `k_scale` and `v_scale` attributes to the
     Attention layer to support loading those scaling factors from checkpoints.
     The k/v_scale will be used to:
     - quantize k/v_cache entries before saving them to the cache
@@ -36,8 +37,12 @@ def create_weights(self, layer: torch.nn.Module):
         # Initialize the KV cache scales to -1.0, which is an invalid value.
         # If the k/v_scale appears in the checkpoint, it will be
         # overwritten when loading weights.
-        layer.k_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)
-        layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)
+        layer.k_scale = torch.nn.Parameter(
+            torch.tensor(-1.0, dtype=torch.float32), requires_grad=False
+        )
+        layer.v_scale = torch.nn.Parameter(
+            torch.tensor(-1.0, dtype=torch.float32), requires_grad=False
+        )

     @classmethod
     def is_fp8_fnuz(cls) -> bool:
@@ -47,52 +52,38 @@ def is_fp8_fnuz(cls) -> bool:
     def apply(self, layer: torch.nn.Module) -> torch.Tensor:
         raise RuntimeError(f"{self.__class__.__name__}.apply should not be called.")

-    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        # If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0
-        # regardless whether the kv-scale is available in the checkpoint.
-        # No need to process kv scales after loading if we are going to
-        # calculate them on the fly.
-        if layer.kv_cache_dtype != "auto" and not layer.calculate_kv_scales:
-            if layer.k_scale > 0.0 and layer.v_scale > 0.0:
-                # We prefer to use separate k_scale and v_scale if present
-                k_scale = layer.k_scale.to("cpu").tolist()
-                v_scale = layer.v_scale.to("cpu").tolist()
-                if _is_hip and self.is_fp8_fnuz():
-                    k_scale *= 2
-                    v_scale *= 2
-            elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
-                # If no scales were loaded (both scales are invalid negative
-                # values), use the default value of 1.0
-                k_scale = 1.0
-                v_scale = 1.0
-            else:
-                # If we find a single kv_scale in the checkpoint, we remap
-                # kv_scale to k_scale during weight loading, and duplicate
-                # k_scale to v_scale here
-                assert layer.k_scale > 0.0
-                scale_to_duplicate = max(layer.k_scale, layer.v_scale)
-                k_scale = scale_to_duplicate.to("cpu").tolist()
-                v_scale = scale_to_duplicate.to("cpu").tolist()
-                if _is_hip and self.is_fp8_fnuz():
-                    k_scale *= 2
-                    v_scale *= 2
-
-            if not isinstance(k_scale, float) or not isinstance(v_scale, float):
-                raise ValueError(
-                    "Only support per-tensor scaling factor " "for fp8 KV cache"
-                )
-
-            # These are used in the final Attention.forward()
-            layer._k_scale.copy_(k_scale)
-            layer._v_scale.copy_(v_scale)
-            layer._k_scale_float = k_scale
-            layer._v_scale_float = v_scale
-            if k_scale == 1.0 and v_scale == 1.0 and "e5m2" not in layer.kv_cache_dtype:
-                logger.warning(
-                    "Using KV cache scaling factor 1.0 for fp8_e4m3. This "
-                    "may cause accuracy issues. Please make sure k/v_scale "
-                    "scaling factors are available in the fp8 checkpoint."
-                )
-
-        del layer.k_scale
-        del layer.v_scale
+    def process_weights_after_loading(self, layer: RadixAttention) -> None:
+        if layer.k_scale > 0.0 and layer.v_scale > 0.0:
+            # We prefer to use separate k_scale and v_scale if present
+            k_scale = layer.k_scale.to("cpu").tolist()
+            v_scale = layer.v_scale.to("cpu").tolist()
+            if _is_hip and self.is_fp8_fnuz():
+                k_scale *= 2
+                v_scale *= 2
+        elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
+            # If no scales were loaded (both scales are invalid negative
+            # values), use the default value of 1.0
+            k_scale = 1.0
+            v_scale = 1.0
+        else:
+            # If we find a single kv_scale in the checkpoint, we remap
+            # kv_scale to k_scale during weight loading, and duplicate
+            # k_scale to v_scale here
+            assert layer.k_scale > 0.0
+            scale_to_duplicate = max(layer.k_scale, layer.v_scale)
+            k_scale = scale_to_duplicate.to("cpu").tolist()
+            v_scale = scale_to_duplicate.to("cpu").tolist()
+            if _is_hip and self.is_fp8_fnuz():
+                k_scale *= 2
+                v_scale *= 2
+
+        if not isinstance(k_scale, float) or not isinstance(v_scale, float):
+            raise ValueError(
+                "Only support per-tensor scaling factor " "for fp8 KV cache"
+            )
+
+        # These are used in the final Attention.forward()
+        layer.k_scale.copy_(k_scale)
+        layer.v_scale.copy_(v_scale)
+        layer.k_scale_float = k_scale
+        layer.v_scale_float = v_scale
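
A rough usage sketch of the lifecycle above, with a dummy module standing in for the attention layer and made-up scale values: create_weights seeds k_scale/v_scale with -1.0, the weight loader overwrites them from the checkpoint, and process_weights_after_loading folds them into plain floats on the layer.

import torch


class _DummyAttn(torch.nn.Module):
    pass


layer = _DummyAttn()
# What create_weights does: placeholder scales that mean "not loaded yet".
layer.k_scale = torch.nn.Parameter(torch.tensor(-1.0, dtype=torch.float32), requires_grad=False)
layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0, dtype=torch.float32), requires_grad=False)

# The checkpoint loader would copy real per-tensor scales in, e.g.:
layer.k_scale.data.fill_(0.023)
layer.v_scale.data.fill_(0.031)

# What process_weights_after_loading then does: expose them as Python floats.
layer.k_scale_float = layer.k_scale.to("cpu").tolist()  # tolist() on a 0-d tensor returns a float
layer.v_scale_float = layer.v_scale.to("cpu").tolist()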

python/sglang/srt/layers/quantization/modelopt_quant.py

Lines changed: 25 additions & 4 deletions
@@ -6,7 +6,6 @@
 import torch
 from torch.nn.parameter import Parameter

-from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.linear import LinearBase, LinearMethodBase
 from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
 from sglang.srt.layers.quantization.base_config import (
@@ -22,6 +21,7 @@
     convert_to_channelwise,
     requantize_with_max_scale,
 )
+from sglang.srt.layers.radix_attention import RadixAttention

 # Initialize logger for the module
 logger = logging.getLogger(__name__)
@@ -33,12 +33,19 @@
 class ModelOptFp8Config(QuantizationConfig):
     """Configuration for ModelOpt FP8 quantization, including serialization and compatibility checks."""

-    def __init__(self, is_checkpoint_fp8_serialized: bool = False) -> None:
+    def __init__(
+        self,
+        is_checkpoint_fp8_serialized: bool = False,
+        kv_cache_quant_method: Optional[str] = None,
+        exclude_modules: Optional[List[str]] = None,
+    ) -> None:
         """
         Args:
             is_checkpoint_fp8_serialized (bool): Indicates if the checkpoint uses serialized FP8 format.
         """
         self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
+        self.kv_cache_quant_method = kv_cache_quant_method
+        self.exclude_modules = exclude_modules
         if is_checkpoint_fp8_serialized:
             logger.warning(
                 "Detected ModelOpt FP8 checkpoint. The format is experimental and subject to change."
@@ -63,22 +70,36 @@ def get_config_filenames(cls) -> List[str]:
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp8Config":
         quant_method = cls.get_from_keys(config, ["quantization"]).get("quant_algo")
+        kv_cache_quant_method = cls.get_from_keys(config, ["quantization"]).get(
+            "kv_cache_quant_algo"
+        )
+        exclude_modules = cls.get_from_keys(config, ["quantization"]).get(
+            "exclude_modules"
+        )

         if "FP8" not in quant_method:
             raise ValueError(
                 "ModelOpt only supports static FP8 quantization in SGLang. "
                 "Check the `hf_quant_config.json` file for your model's configuration."
             )

-        return cls(is_checkpoint_fp8_serialized=True)
+        return cls(
+            is_checkpoint_fp8_serialized=True,
+            kv_cache_quant_method=kv_cache_quant_method,
+            exclude_modules=exclude_modules,
+        )

     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["QuantizeMethodBase"]:
+        if self.exclude_modules and any(
+            module in prefix for module in self.exclude_modules
+        ):
+            return None

         if isinstance(layer, LinearBase):
             return ModelOptFp8LinearMethod(self)
-        if isinstance(layer, AttentionBackend):
+        if self.kv_cache_quant_method and isinstance(layer, RadixAttention):
             return ModelOptFp8KVCacheMethod(self)

         return None
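
For reference, an illustrative shape of the ModelOpt hf_quant_config.json contents that from_config now reads; the field values below are assumptions, not copied from a specific checkpoint.

from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config

hf_quant_config = {
    "quantization": {
        "quant_algo": "FP8",
        "kv_cache_quant_algo": "FP8",    # enables ModelOptFp8KVCacheMethod on RadixAttention
        "exclude_modules": ["lm_head"],  # layers whose prefix matches get no quant method
    }
}

cfg = ModelOptFp8Config.from_config(hf_quant_config)
assert cfg.kv_cache_quant_method == "FP8"
assert cfg.exclude_modules == ["lm_head"]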

python/sglang/srt/layers/radix_attention.py

Lines changed: 13 additions & 1 deletion
@@ -13,8 +13,12 @@
 # ==============================================================================
 """Radix attention."""

+from typing import Optional
+
 from torch import nn

+from sglang.srt.layers.linear import UnquantizedLinearMethod
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch


@@ -34,6 +38,7 @@ def __init__(
         v_head_dim: int = -1,
         sliding_window_size: int = -1,
         is_cross_attention: bool = False,
+        quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         use_irope: bool = False,
     ):
@@ -49,9 +54,16 @@ def __init__(
         self.logit_cap = logit_cap
         self.sliding_window_size = sliding_window_size or -1
         self.is_cross_attention = is_cross_attention
+        self.use_irope = use_irope
         self.k_scale = None
         self.v_scale = None
-        self.use_irope = use_irope
+        self.k_scale_float = None
+        self.v_scale_float = None
+        self.quant_method = None
+        if quant_config is not None:
+            self.quant_method = quant_config.get_quant_method(self, prefix=prefix)
+            if self.quant_method is not None:
+                self.quant_method.create_weights(self)

     def forward(
         self,
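
A hedged construction sketch of the new hook: passing a quant_config lets the layer fetch a KV-cache quant method and create its k_scale/v_scale placeholders. Only the keyword arguments shown in the hunk above come from this diff; the leading arguments, their values, and the quant_config variable (e.g. the model's ModelOptFp8Config) follow typical SGLang model code and are assumptions.

attn = RadixAttention(
    num_heads=32,
    head_dim=128,
    scaling=128 ** -0.5,
    num_kv_heads=8,
    layer_id=0,
    quant_config=quant_config,  # e.g. a ModelOptFp8Config; None keeps the old behavior
    prefix="model.layers.0.self_attn.attn",
)
# With a ModelOpt FP8 config whose kv_cache_quant_algo is set, get_quant_method
# returns ModelOptFp8KVCacheMethod, and create_weights attaches k_scale / v_scale
# parameters initialized to -1.0, to be overwritten by the checkpoint loader.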
