Skip to content

Commit 35ad7e6

Browse files
fzyzcjylifuhuang
authored and committed
Tiny refactor ModelConfig.from_server_args (sgl-project#5219)
1 parent c0bcbd2 commit 35ad7e6

File tree

6 files changed

+23
-53
lines changed

6 files changed

+23
-53
lines changed

python/sglang/bench_one_batch.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -137,17 +137,7 @@ def load_model(server_args, port_args, tp_rank):
137137
suppress_other_loggers()
138138
rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
139139

140-
model_config = ModelConfig(
141-
server_args.model_path,
142-
trust_remote_code=server_args.trust_remote_code,
143-
revision=server_args.revision,
144-
context_length=server_args.context_length,
145-
model_override_args=server_args.json_model_override_args,
146-
is_embedding=server_args.is_embedding,
147-
enable_multimodal=server_args.enable_multimodal,
148-
dtype=server_args.dtype,
149-
quantization=server_args.quantization,
150-
)
140+
model_config = ModelConfig.from_server_args(server_args)
151141
model_runner = ModelRunner(
152142
model_config=model_config,
153143
mem_fraction_static=server_args.mem_fraction_static,

python/sglang/srt/configs/model_config.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
from sglang.srt.hf_transformers_utils import get_config, get_context_length
2626
from sglang.srt.layers.quantization import QUANTIZATION_METHODS
27+
from sglang.srt.server_args import ServerArgs
2728
from sglang.srt.utils import get_bool_env_var, is_hip
2829

2930
logger = logging.getLogger(__name__)
@@ -210,6 +211,21 @@ def __init__(
210211
self.hf_eos_token_id = self.get_hf_eos_token_id()
211212
self.image_token_id = getattr(self.hf_config, "image_token_id", None)
212213

214+
@staticmethod
215+
def from_server_args(server_args: ServerArgs, model_path: str = None, **kwargs):
216+
return ModelConfig(
217+
model_path=model_path or server_args.model_path,
218+
trust_remote_code=server_args.trust_remote_code,
219+
revision=server_args.revision,
220+
context_length=server_args.context_length,
221+
model_override_args=server_args.json_model_override_args,
222+
is_embedding=server_args.is_embedding,
223+
enable_multimodal=server_args.enable_multimodal,
224+
dtype=server_args.dtype,
225+
quantization=server_args.quantization,
226+
**kwargs,
227+
)
228+
213229
# adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
214230
def get_total_num_kv_heads(self) -> int:
215231
"""Returns the total number of KV heads."""

python/sglang/srt/managers/scheduler.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -455,17 +455,7 @@ def __init__(
455455
def init_tokenizer(self):
456456
server_args = self.server_args
457457

458-
self.model_config = ModelConfig(
459-
server_args.model_path,
460-
trust_remote_code=server_args.trust_remote_code,
461-
revision=server_args.revision,
462-
context_length=server_args.context_length,
463-
model_override_args=server_args.json_model_override_args,
464-
is_embedding=server_args.is_embedding,
465-
enable_multimodal=server_args.enable_multimodal,
466-
dtype=server_args.dtype,
467-
quantization=server_args.quantization,
468-
)
458+
self.model_config = ModelConfig.from_server_args(server_args)
469459
self.is_generation = self.model_config.is_generation
470460

471461
if server_args.skip_tokenizer_init:

python/sglang/srt/managers/tokenizer_manager.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -165,17 +165,7 @@ def __init__(
165165
# Read model args
166166
self.model_path = server_args.model_path
167167
self.served_model_name = server_args.served_model_name
168-
self.model_config = ModelConfig(
169-
server_args.model_path,
170-
trust_remote_code=server_args.trust_remote_code,
171-
revision=server_args.revision,
172-
context_length=server_args.context_length,
173-
model_override_args=server_args.json_model_override_args,
174-
is_embedding=server_args.is_embedding,
175-
enable_multimodal=server_args.enable_multimodal,
176-
dtype=server_args.dtype,
177-
quantization=server_args.quantization,
178-
)
168+
self.model_config = ModelConfig.from_server_args(server_args)
179169

180170
self.is_generation = self.model_config.is_generation
181171
self.is_image_gen = self.model_config.is_image_gen

python/sglang/srt/managers/tp_worker.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -65,20 +65,13 @@ def __init__(
6565
self.pp_rank = pp_rank
6666

6767
# Init model and tokenizer
68-
self.model_config = ModelConfig(
69-
(
68+
self.model_config = ModelConfig.from_server_args(
69+
server_args,
70+
model_path=(
7071
server_args.model_path
7172
if not is_draft_worker
7273
else server_args.speculative_draft_model_path
7374
),
74-
trust_remote_code=server_args.trust_remote_code,
75-
revision=server_args.revision,
76-
context_length=server_args.context_length,
77-
model_override_args=server_args.json_model_override_args,
78-
is_embedding=server_args.is_embedding,
79-
enable_multimodal=server_args.enable_multimodal,
80-
dtype=server_args.dtype,
81-
quantization=server_args.quantization,
8275
is_draft_model=is_draft_worker,
8376
)
8477

test/srt/test_gptqmodel_dynamic.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -43,16 +43,7 @@ def check_quant_method(model_path: str, use_marlin_kernel: bool):
4343
pass
4444

4545
server_args = ServerArgs(model_path=model_path, dtype=torch.float16)
46-
model_config = ModelConfig(
47-
server_args.model_path,
48-
trust_remote_code=server_args.trust_remote_code,
49-
revision=server_args.revision,
50-
context_length=server_args.context_length,
51-
model_override_args=server_args.json_model_override_args,
52-
is_embedding=server_args.is_embedding,
53-
dtype=server_args.dtype,
54-
quantization=server_args.quantization,
55-
)
46+
model_config = ModelConfig.from_server_args(server_args)
5647

5748
load_config = LoadConfig()
5849
device_config = DeviceConfig("cuda")

0 commit comments

Comments (0)