
Commit 7f00a46

merrymercy authored and jimoosciuc committed

Clean up imports (sgl-project#5467)
1 parent 97f12dd commit 7f00a46


51 files changed: +375 −572 lines

python/sglang/__init__.py

Lines changed: 2 additions & 4 deletions
@@ -24,24 +24,22 @@
     user_end,
     video,
 )
+from sglang.global_config import global_config
 from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.lang.choices import (
     greedy_token_selection,
     token_length_normalized,
     unconditional_likelihood_normalized,
 )
 from sglang.utils import LazyImport
+from sglang.version import __version__

 ServerArgs = LazyImport("sglang.srt.server_args", "ServerArgs")
 Anthropic = LazyImport("sglang.lang.backend.anthropic", "Anthropic")
 LiteLLM = LazyImport("sglang.lang.backend.litellm", "LiteLLM")
 OpenAI = LazyImport("sglang.lang.backend.openai", "OpenAI")
 VertexAI = LazyImport("sglang.lang.backend.vertexai", "VertexAI")

-# Other configs
-from sglang.global_config import global_config
-from sglang.version import __version__
-
 __all__ = [
     "Engine",
     "Runtime",
python/sglang/bench_serving.py

Lines changed: 0 additions & 4 deletions
@@ -707,10 +707,6 @@ def sample_random_requests(

     # Download sharegpt if necessary
     if not os.path.isfile(dataset_path):
-        print(
-            "If you do not want to randomly sample from a dataset,"
-            " please use --dataset-name random-ids."
-        )
        dataset_path = download_and_cache_file(SHAREGPT_URL)

     # Load the dataset.

python/sglang/lang/__init__.py

Whitespace-only changes.

python/sglang/lang/backend/anthropic.py

Lines changed: 0 additions & 4 deletions
@@ -1,7 +1,3 @@
-from typing import List, Optional, Union
-
-import numpy as np
-
 from sglang.lang.backend.base_backend import BaseBackend
 from sglang.lang.chat_template import get_chat_template
 from sglang.lang.interpreter import StreamExecutor

python/sglang/lang/backend/base_backend.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-from typing import Callable, List, Optional, Union
+from typing import List, Optional, Union

 from sglang.lang.chat_template import get_chat_template
 from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod

python/sglang/lang/backend/openai.py

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 import logging
 import time
 import warnings
-from typing import Callable, List, Optional, Union
+from typing import List, Optional, Union

 import numpy as np
python/sglang/lang/backend/vertexai.py

Lines changed: 0 additions & 1 deletion
@@ -1,6 +1,5 @@
 import os
 import warnings
-from typing import Optional

 from sglang.lang.backend.base_backend import BaseBackend
 from sglang.lang.chat_template import get_chat_template
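
The four removals above (anthropic.py, base_backend.py, openai.py, vertexai.py) all drop typing and numpy imports that were no longer referenced. As a hedged aside, such findings can be reproduced mechanically; a minimal sketch using pyflakes (an external tool, assumed installed via pip install pyflakes, and not part of sglang):

    import sys

    from pyflakes.api import checkPath
    from pyflakes.reporter import Reporter

    # Print unused-import warnings (and other pyflakes findings) for one file.
    reporter = Reporter(sys.stdout, sys.stderr)
    num_findings = checkPath("python/sglang/lang/backend/vertexai.py", reporter)
    print(f"pyflakes reported {num_findings} finding(s)")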

python/sglang/lang/compiler.py

Lines changed: 1 addition & 7 deletions
@@ -5,13 +5,7 @@

 from sglang.global_config import global_config
 from sglang.lang.interpreter import ProgramState, StreamExecutor, cache_program
-from sglang.lang.ir import (
-    SglArgument,
-    SglConstantText,
-    SglExpr,
-    SglSamplingParams,
-    SglVariable,
-)
+from sglang.lang.ir import SglArgument, SglExpr, SglSamplingParams, SglVariable


 def compile_func(function, backend):

python/sglang/lang/tracer.py

Lines changed: 3 additions & 7 deletions
@@ -1,20 +1,16 @@
 """Tracing a program."""

 import uuid
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional

-from sglang.global_config import global_config
 from sglang.lang.backend.base_backend import BaseBackend
 from sglang.lang.interpreter import ProgramState, ProgramStateGroup
 from sglang.lang.ir import (
     SglArgument,
-    SglCommitLazy,
-    SglConcateAndAppend,
     SglConstantText,
     SglExpr,
     SglExprList,
     SglFork,
-    SglFunction,
     SglGen,
     SglGetForkItem,
     SglRoleBegin,
@@ -230,8 +226,8 @@ def _execute_role_end(self, expr: SglRoleEnd):
         self.cur_role = None

     def _execute_var_scope_end(self, expr: SglVarScopeEnd):
-        new_node = SglVariable(name, source=self.last_node)
-        self.variables[name] = new_node
+        new_node = SglVariable(expr.name, source=self.last_node)
+        self.variables[expr.name] = new_node

     def get_var(self, name):
         ret = self.arguments.get(name, None)
python/sglang/srt/_custom_ops.py

Lines changed: 0 additions & 2 deletions
@@ -1,10 +1,8 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/_custom_ops.py
 import logging
-import os
 from typing import List, Tuple

 import torch
-import torch.library

 from sglang.srt.utils import get_bool_env_var, is_hip, is_hpu

python/sglang/srt/custom_op.py

Lines changed: 0 additions & 62 deletions
@@ -42,65 +42,3 @@ def dispatch_forward(self):
             return self.forward_hip
         else:
             return self.forward_native
-
-
-if _is_cuda:
-    from sgl_kernel import sgl_per_tensor_quant_fp8, sgl_per_token_quant_fp8
-
-    def scaled_fp8_quant(
-        input: torch.Tensor,
-        scale: Optional[torch.Tensor] = None,
-        num_token_padding: Optional[int] = None,
-        use_per_token_if_dynamic: bool = False,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-        """
-        Quantize input tensor to FP8 (8-bit floating point) format.
-
-        Args:
-            input (torch.Tensor): Input tensor to be quantized
-            scale (Optional[torch.Tensor]): Pre-computed scaling factor for static quantization.
-                If None, scales will be computed dynamically.
-            num_token_padding (Optional[int]): If specified, pad the first dimension
-                of the output to at least this value.
-            use_per_token_if_dynamic (bool): When using dynamic scaling (scale=None),
-                determines the quantization granularity:
-                - True: compute scale per token
-                - False: compute single scale per tensor
-
-        Returns:
-            Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
-                - quantized_tensor: The FP8 quantized version of input
-                - scale_tensor: The scaling factors used for quantization
-
-        Raises:
-            AssertionError: If input is not 2D or if static scale's numel != 1
-        """
-        assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
-        shape = input.shape
-        out_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
-        if num_token_padding:
-            shape = (max(num_token_padding, input.shape[0]), shape[1])
-        output = torch.empty(shape, device=input.device, dtype=out_dtype)
-
-        if scale is None:
-            # Dynamic scaling
-            if use_per_token_if_dynamic:
-                scale = torch.empty(
-                    (shape[0], 1), device=input.device, dtype=torch.float32
-                )
-                sgl_per_token_quant_fp8(input, output, scale)
-            else:
-                scale = torch.zeros(1, device=input.device, dtype=torch.float32)
-                sgl_per_tensor_quant_fp8(
-                    input, output, scale, is_static=False
-                )  # False for dynamic
-        else:
-            # Static scaling
-            assert (
-                scale.numel() == 1
-            ), f"Expected scalar scale, got numel={scale.numel()}"
-            sgl_per_tensor_quant_fp8(
-                input, output, scale, is_static=True
-            )  # True for static
-
-        return output, scale
python/sglang/srt/entrypoints/verl_engine.py

Lines changed: 1 addition & 2 deletions
@@ -19,11 +19,10 @@
 from PIL.Image import Image
 from torch.distributed.tensor import DeviceMesh, DTensor

+from sglang.srt.entrypoints.engine import Engine
 from sglang.srt.entrypoints.http_server_engine import HttpServerEngineAdapter
 from sglang.srt.model_executor.model_runner import LocalSerializedTensor
 from sglang.srt.patch_torch import monkey_patch_torch_reductions
-from sglang.srt.server import Engine
-from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj
python/sglang/srt/layers/activation.py

Lines changed: 6 additions & 8 deletions
@@ -21,21 +21,19 @@
 import torch.nn as nn
 import torch.nn.functional as F

-from sglang.srt.utils import is_cuda_available
-
-_is_cuda = is_cuda_available()
-
-if _is_cuda:
-    from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
-
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import (
     divide,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.utils import set_weight_attrs
+from sglang.srt.utils import is_cuda_available, set_weight_attrs
+
+_is_cuda = is_cuda_available()
+
+if _is_cuda:
+    from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul

 logger = logging.getLogger(__name__)
python/sglang/srt/layers/layernorm.py

Lines changed: 1 addition & 1 deletion
@@ -19,6 +19,7 @@
 import torch
 import torch.nn as nn

+from sglang.srt.custom_op import CustomOp
 from sglang.srt.utils import is_cuda_available

 _is_cuda = is_cuda_available()
@@ -31,7 +32,6 @@
     rmsnorm,
 )

-from sglang.srt.custom_op import CustomOp

 logger = logging.getLogger(__name__)
python/sglang/srt/layers/moe/ep_moe/layer.py

Lines changed: 12 additions & 26 deletions
@@ -2,6 +2,7 @@
 from typing import Callable, List, Optional, Tuple

 import torch
+from torch.nn import Module

 try:
     from deep_gemm import (
@@ -13,8 +14,6 @@
 except ImportError:
     use_deep_gemm = False

-from torch.nn import Module
-
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
@@ -37,22 +36,17 @@
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
+from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
-from sglang.srt.utils import DeepEPMode, is_cuda, is_hip, set_weight_attrs
+from sglang.srt.utils import DeepEPMode, is_hip, set_weight_attrs

-_is_cuda = is_cuda()
+_is_hip = is_hip()

-if _is_cuda:
-    from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
-else:
-    from vllm import _custom_ops as vllm_ops
+if _is_hip:
+    from vllm._custom_ops import scaled_fp8_quant

 logger = logging.getLogger(__name__)

-_is_hip = is_hip()
-
-_buffer = None
-

 class GroupedGemmRunner(torch.nn.Module):
     flashinfer_gemm_warpper = None
@@ -740,20 +734,12 @@ def process_weights_after_loading(self, layer: Module) -> None:
         )

         for expert in range(layer.num_experts_per_partition):
-            if _is_cuda:
-                w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
-                    sgl_scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
-                )
-                w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
-                    sgl_scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
-                )
-            else:
-                w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
-                    vllm_ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
-                )
-                w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
-                    vllm_ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
-                )
+            w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
+                scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
+            )
+            w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
+                scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
+            )
         layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
         layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
         return
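
For reference, the docstring removed from custom_op.py above documents the scaled_fp8_quant signature: it takes a 2D tensor plus an optional precomputed scale and returns the quantized tensor together with the scale actually used. A hedged usage sketch (shapes, device, and the fp8_kernel import path are assumptions based on this diff):

    import torch

    from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant

    x = torch.randn(4, 128, device="cuda", dtype=torch.float16)

    # Dynamic per-tensor scaling: one scale for the whole tensor.
    q, s = scaled_fp8_quant(x)

    # Dynamic per-token scaling: one scale per row of x.
    q_tok, s_tok = scaled_fp8_quant(x, use_per_token_if_dynamic=True)

    # Static scaling with a precomputed scalar scale.
    scale = torch.tensor([0.5], device="cuda", dtype=torch.float32)
    q_static, _ = scaled_fp8_quant(x, scale)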

python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py

Lines changed: 12 additions & 19 deletions
@@ -13,6 +13,7 @@
 import triton.language as tl

 from sglang.srt.layers.moe.topk import select_experts
+from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
 from sglang.srt.utils import (
     direct_register_custom_op,
     get_bool_env_var,
@@ -22,28 +23,25 @@
 )

 _is_hip = is_hip()
-
-
-logger = logging.getLogger(__name__)
-padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
-
-enable_moe_align_block_size_triton = bool(
-    int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
-)
-
 _is_cuda = is_cuda()

 if _is_cuda:
     from sgl_kernel import gelu_and_mul, silu_and_mul
-
-    from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
 else:
     from vllm import _custom_ops as vllm_ops
+    from vllm._custom_ops import scaled_fp8_quant

 if _is_cuda or _is_hip:
     from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size


+logger = logging.getLogger(__name__)
+padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
+enable_moe_align_block_size_triton = bool(
+    int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
+)
+
+
 @triton.jit
 def write_zeros_to_output(
     c_ptr,
@@ -770,14 +768,9 @@ def invoke_fused_moe_kernel(
             # activation tensor-wise fp8 quantization, dynamic or static
             padded_size = padding_size
             # activations apply per-token quantization when weights apply per-channel quantization by default
-            if _is_cuda:
-                A, A_scale = sgl_scaled_fp8_quant(
-                    A, A_scale, use_per_token_if_dynamic=per_channel_quant
-                )
-            else:
-                A, A_scale = vllm_ops.scaled_fp8_quant(
-                    A, A_scale, use_per_token_if_dynamic=per_channel_quant
-                )
+            A, A_scale = scaled_fp8_quant(
+                A, A_scale, use_per_token_if_dynamic=per_channel_quant
+            )
         else:
             # activation block-wise fp8 quantization
             assert len(block_shape) == 2
