Commit 19933b6 ("refactor")

1 parent: f107edd
6 files changed: 10 additions, 24 deletions

python/sglang/srt/layers/attention/flashinfer_backend.py (7 additions, 6 deletions)

@@ -6,7 +6,7 @@
 FlashInfer is faster and Triton is easier to customize.
 Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
 """
-import logging
+
 import os
 from dataclasses import dataclass
 from enum import Enum, auto
@@ -37,7 +37,7 @@
     from flashinfer.cascade import merge_state
     from flashinfer.decode import _get_range_buf, get_seq_lens

-logger = logging.getLogger(__name__)
+
 class WrapperDispatch(Enum):
     SLIDING_WINDOW = auto()
     CROSS_ATTENTION = auto()
@@ -83,6 +83,7 @@ def __init__(
         self.max_context_len = model_runner.model_config.context_len
         self.skip_prefill = skip_prefill
         self.is_multimodal = model_runner.model_config.is_multimodal
+        self.kv_cache_dtype = model_runner.kv_cache_dtype

         assert not (
             model_runner.sliding_window_size is not None
@@ -392,8 +393,8 @@ def forward_extend(
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
-        k_scale = layer.k_scale_float if layer.kv_cache_dtype != "auto" else None
-        v_scale = layer.v_scale_float if layer.kv_cache_dtype != "auto" else None
+        k_scale = layer.k_scale_float if self.kv_cache_dtype != "auto" else None
+        v_scale = layer.v_scale_float if self.kv_cache_dtype != "auto" else None
         prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
             self._get_wrapper_idx(layer)
         ]
@@ -462,8 +463,8 @@ def forward_decode(
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
-        k_scale = layer.k_scale_float if layer.kv_cache_dtype != "auto" else None
-        v_scale = layer.v_scale_float if layer.kv_cache_dtype != "auto" else None
+        k_scale = layer.k_scale_float if self.kv_cache_dtype != "auto" else None
+        v_scale = layer.v_scale_float if self.kv_cache_dtype != "auto" else None
         decode_wrapper = self.forward_metadata.decode_wrappers[
             self._get_wrapper_idx(layer)
         ]

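Net effect of this file: the backend caches the runner-level KV-cache dtype once at construction and consults it when deciding whether to hand FP8 dequantization scales to FlashInfer, instead of looking the dtype up on each RadixAttention layer. A minimal sketch of the resulting logic; only model_runner.kv_cache_dtype, layer.k_scale_float, and layer.v_scale_float come from the diff above, the class and helper names are illustrative:

from types import SimpleNamespace


class ScaleSelectionSketch:
    def __init__(self, model_runner):
        # Cached once at construction, per the new line in __init__ above.
        self.kv_cache_dtype = model_runner.kv_cache_dtype  # e.g. "auto" or "fp8_e4m3"

    def kv_scales(self, layer):
        # "auto" means the KV cache keeps the model dtype, so no scales are passed.
        if self.kv_cache_dtype != "auto":
            return layer.k_scale_float, layer.v_scale_float
        return None, None


# Example: an fp8_e4m3 cache forwards the per-layer scales, "auto" does not.
runner = SimpleNamespace(kv_cache_dtype="fp8_e4m3")
layer = SimpleNamespace(k_scale_float=0.5, v_scale_float=0.25)
print(ScaleSelectionSketch(runner).kv_scales(layer))  # (0.5, 0.25)
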
python/sglang/srt/layers/radix_attention.py (0 additions, 1 deletion)

@@ -56,7 +56,6 @@ def __init__(
         self.v_scale = None
         self.k_scale_float = None
         self.v_scale_float = None
-        self.kv_cache_dtype = "auto"
         if quant_config is not None:
             self.quant_method = quant_config.get_quant_method(self, prefix=prefix)

python/sglang/srt/model_executor/model_runner.py (0 additions, 6 deletions)

@@ -388,12 +388,6 @@ def load_model(self):
        monkey_patch_vllm_parallel_state(reverse=True)
        monkey_patch_isinstance_for_vllm_base_layer(reverse=True)

-        # Set KV cache dtype for RadixAttention if the model uses it
-        if hasattr(self.model, "set_kv_cache_dtype"):
-            self.model.set_kv_cache_dtype(self.server_args.kv_cache_dtype)
-            logger.info(
-                f"Set KV cache dtype to {self.server_args.kv_cache_dtype} for {type(self.model).__name__}"
-            )
        if self.server_args.kv_cache_dtype == "fp8_e4m3":
            if self.server_args.quantization_param_path is not None:
                if callable(getattr(self.model, "load_kv_cache_scales", None)):

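With the per-model hook gone, load_model no longer needs to know whether a model implements set_kv_cache_dtype; only the FP8 scale-loading branch shown in context remains. A hedged sketch of that remaining branch as a standalone helper; the guard conditions come from the diff, but the argument passed to load_kv_cache_scales is an assumption, since the call itself is outside the hunk:

def maybe_load_kv_cache_scales(model, server_args) -> None:
    # Mirrors the context lines kept above. The quantization_param_path
    # argument is assumed, not shown in this diff.
    if server_args.kv_cache_dtype == "fp8_e4m3":
        if server_args.quantization_param_path is not None:
            if callable(getattr(model, "load_kv_cache_scales", None)):
                model.load_kv_cache_scales(server_args.quantization_param_path)
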
python/sglang/srt/model_loader/loader.py (0 additions, 1 deletion)

@@ -373,7 +373,6 @@ def load_model(
            for _, module in model.named_modules():
                quant_method = getattr(module, "quant_method", None)
                if quant_method is not None:
-                    logger.warning(f"{module.__class__.__name__}, {quant_method}")
                    # When quant methods need to process weights after loading
                    # (for repacking, quantizing, etc), they expect parameters
                    # to be on the global target device. This scope is for the

python/sglang/srt/models/llama.py (0 additions, 10 deletions)

@@ -185,8 +185,6 @@ def forward(
         output, _ = self.o_proj(attn_output)
         return output

-    def set_kv_cache_dtype(self, kv_cache_dtype: str):
-        self.attn.kv_cache_dtype = kv_cache_dtype

 class LlamaDecoderLayer(nn.Module):
     def __init__(
@@ -238,9 +236,6 @@ def __init__(
         self.post_attention_layernorm = RMSNorm(
             config.hidden_size, eps=config.rms_norm_eps
         )
-
-    def set_kv_cache_dtype(self, kv_cache_dtype: str):
-        self.self_attn.set_kv_cache_dtype(kv_cache_dtype)

     def forward(
         self,
@@ -628,11 +623,6 @@ def set_eagle3_layers_to_capture(self):
         self.capture_aux_hidden_states = True
         num_layers = self.config.num_hidden_layers
         self.model.layers_to_capture = [2, num_layers // 2, num_layers - 3]
-
-    def set_kv_cache_dtype(self, kv_cache_dtype: str):
-        for layer in self.model.layers:
-            if hasattr(layer, "set_kv_cache_dtype"):
-                layer.set_kv_cache_dtype(kv_cache_dtype)


 class Phi3ForCausalLM(LlamaForCausalLM):

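Together with the backend change above, this removes the per-model plumbing: Llama no longer fans the dtype out from the model through each decoder layer into every attention module, because the FlashInfer backend now reads it once from the model runner (see the earlier sketch). A condensed, self-contained reconstruction of the removed fan-out pattern; the stub layer objects are illustrative, only the propagation shape comes from the deleted lines:

from types import SimpleNamespace


def fan_out_kv_cache_dtype(layers, kv_cache_dtype: str) -> None:
    # The old approach: push the dtype into each attention layer one by one.
    for layer in layers:
        if hasattr(layer, "kv_cache_dtype"):
            layer.kv_cache_dtype = kv_cache_dtype


layers = [SimpleNamespace(kv_cache_dtype="auto") for _ in range(4)]
fan_out_kv_cache_dtype(layers, "fp8_e4m3")
print({layer.kv_cache_dtype for layer in layers})  # {'fp8_e4m3'}
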
test/srt/test_eval_fp8_accuracy.py (3 additions, 0 deletions)

@@ -137,6 +137,9 @@ def _run_test(self, model, other_args, expected_score):
         finally:
             kill_process_tree(process.pid)

+    @unittest.skipIf(
+        torch.version.hip is not None, "modelopt quantization unsupported on ROCm"
+    )
     def test_mmlu_offline_only(self):
         """Test with offline quantization only."""
         self._run_test(

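The new decorator skips the offline-quantization test on ROCm builds, where torch.version.hip is a version string rather than None. A self-contained illustration of the same guard; the test class and body here are placeholders, not the repository's test:

import unittest

import torch


class SkipOnRocmExample(unittest.TestCase):
    @unittest.skipIf(
        torch.version.hip is not None, "modelopt quantization unsupported on ROCm"
    )
    def test_runs_only_on_non_rocm_builds(self):
        # torch.version.hip is None on CUDA/CPU wheels and a string like "6.2" on ROCm.
        self.assertIsNone(torch.version.hip)


if __name__ == "__main__":
    unittest.main()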