Auto set draft model path for MTP #5793

Merged · 10 commits · Apr 29, 2025
Changes from all commits
7 changes: 7 additions & 0 deletions python/sglang/srt/configs/model_config.py
@@ -47,6 +47,7 @@ def __init__(
dtype: str = "auto",
quantization: Optional[str] = None,
override_config_file: Optional[str] = None,
is_draft_model: bool = False,
) -> None:

self.model_path = model_path
@@ -85,6 +86,12 @@ def __init__(
else:
enable_multimodal = True

if (
is_draft_model
and self.hf_config.architectures[0] == "DeepseekV3ForCausalLM"
):
self.hf_config.architectures[0] = "DeepseekV3ForCausalLMNextN"

# Check model type
self.is_generation = is_generation_model(
self.hf_config.architectures, is_embedding
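For reference, a minimal standalone sketch of what the override above does (illustrative code, not part of sglang): when the config is built for a draft worker and the checkpoint's config.json reports DeepseekV3ForCausalLM, the MTP (NextN) architecture is substituted, so the draft module can be loaded from the same model path as the target model.

def resolve_architecture(architectures, is_draft_model):
    # Mirror of the hunk above: swap in the MTP draft architecture
    # only when this config belongs to a draft worker.
    if is_draft_model and architectures[0] == "DeepseekV3ForCausalLM":
        return ["DeepseekV3ForCausalLMNextN"]
    return architectures

print(resolve_architecture(["DeepseekV3ForCausalLM"], is_draft_model=True))   # ['DeepseekV3ForCausalLMNextN']
print(resolve_architecture(["DeepseekV3ForCausalLM"], is_draft_model=False))  # ['DeepseekV3ForCausalLM']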
1 change: 1 addition & 0 deletions python/sglang/srt/managers/tp_worker.py
@@ -71,6 +71,7 @@ def __init__(
enable_multimodal=server_args.enable_multimodal,
dtype=server_args.dtype,
quantization=server_args.quantization,
is_draft_model=is_draft_worker,
)
self.model_runner = ModelRunner(
model_config=self.model_config,
13 changes: 11 additions & 2 deletions python/sglang/srt/model_executor/model_runner.py
@@ -692,9 +692,14 @@ def profile_max_num_token(self, total_gpu_memory: int):
self.device, self.gpu_id, distributed=self.tp_size > 1
)
if self.use_mla_backend:
num_layers = (
self.model_config.num_hidden_layers
if not self.is_draft_worker
else self.model_config.hf_config.num_nextn_predict_layers
)
cell_size = (
(self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
* self.model_config.num_hidden_layers
* num_layers
* torch._utils._element_size(self.kv_cache_dtype)
)
else:
@@ -809,7 +814,11 @@ def init_memory_pool(
dtype=self.kv_cache_dtype,
kv_lora_rank=self.model_config.kv_lora_rank,
qk_rope_head_dim=self.model_config.qk_rope_head_dim,
layer_num=self.model_config.num_hidden_layers,
layer_num=(
self.model_config.num_hidden_layers
if not self.is_draft_worker
else self.model_config.hf_config.num_nextn_predict_layers
),
device=self.device,
enable_memory_saver=self.server_args.enable_memory_saver,
)
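These two hunks make the draft worker size its MLA KV cache by num_nextn_predict_layers instead of the target model's num_hidden_layers. A back-of-the-envelope sketch of the cell_size expression, using assumed DeepSeek-V3-like config values (check them against the actual config.json):

kv_lora_rank = 512              # assumed DeepSeek-V3 value
qk_rope_head_dim = 64           # assumed DeepSeek-V3 value
num_hidden_layers = 61          # assumed target-model depth
num_nextn_predict_layers = 1    # assumed MTP draft depth
element_size = 2                # bytes per element, e.g. a bfloat16 KV cache

def mla_cell_size(num_layers):
    # Per-token KV-cache bytes across num_layers MLA layers,
    # matching the cell_size formula in profile_max_num_token.
    return (kv_lora_rank + qk_rope_head_dim) * num_layers * element_size

print(mla_cell_size(num_hidden_layers))         # 70272 bytes/token for the target model
print(mla_cell_size(num_nextn_predict_layers))  # 1152 bytes/token for the draft worker

Since the draft worker now shares the target model's config, the guard keeps it from budgeting a full-depth cache (and a full-depth memory pool in init_memory_pool) for a single-layer MTP module.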
258 changes: 1 addition & 257 deletions python/sglang/srt/models/deepseek_nextn.py
@@ -177,263 +177,7 @@ def forward(
)

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
if hasattr(self.config, "num_nextn_predict_layers"):
num_nextn_layers = self.config.num_nextn_predict_layers
assert num_nextn_layers == 1, "Only 1 nextn layer is supported"
assert num_nextn_layers == self.config.num_hidden_layers
else:
raise ValueError("num_nextn_predict_layers is not in the config")

stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
if self.n_share_experts_fusion > 0:
logger.info(
f"Cloning {self.n_share_experts_fusion} "
"replicas of the shared expert into MoE for DeepseekV3ForCausalLMNextN"
)
weights_list = list(weights)
weights_dict = dict(weights_list)
if self.quant_config is None or self.quant_config.get_name() == "w8a8_int8":
suffix_list = [
"down_proj.weight",
"down_proj.weight_scale",
"gate_proj.weight",
"gate_proj.weight_scale",
"up_proj.weight",
"up_proj.weight_scale",
]
else:
suffix_list = [
"down_proj.weight",
"down_proj.weight_scale_inv",
"gate_proj.weight",
"gate_proj.weight_scale_inv",
"up_proj.weight",
"up_proj.weight_scale_inv",
]
names_to_remove = []
for suffix in suffix_list:
shared_expert_weight_name = (
f"model.layers.0.mlp.shared_experts.{suffix}"
)
for num_repeat in range(self.n_share_experts_fusion):
weights_list.append(
(
f"model.layers.0."
f"mlp.experts."
f"{self.config.n_routed_experts + num_repeat}"
f".{suffix}",
weights_dict[shared_expert_weight_name],
)
)
names_to_remove += [shared_expert_weight_name]
weights = [w for w in weights_list if w[0] not in names_to_remove]

# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
expert_params_mapping = MoEImpl.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.n_routed_experts + self.n_share_experts_fusion,
)

# Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
self.config.q_lora_rank is not None
)
cached_a_proj = {} if fuse_qkv_a_proj else None

nextn_layer_prefix = "model.layers.0"
nextn_spec_weight_names = [
"shared_head.norm",
"eh_proj",
"enorm",
"hnorm",
]

params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if not name.startswith(nextn_layer_prefix):
continue

# Use shared head and embed weights from target model
if "shared_head.head" in name or "embed_tokens" in name:
continue

is_decoder = True
# For nextn specific weights
for weight_name in nextn_spec_weight_names:
if weight_name in name:
name = name.replace(nextn_layer_prefix, "model")
is_decoder = False
break
# For decoder layer weights
if is_decoder:
name = name.replace(nextn_layer_prefix, "model.decoder")

if "rotary_emb.inv_freq" in name:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
continue
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if ("mlp.experts." in name) and name not in params_dict:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(
param,
loaded_weight,
name,
shard_id=shard_id,
expert_id=expert_id,
)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue

# Handle fused_qkv_a_proj
if fuse_qkv_a_proj and (
"q_a_proj" in name or "kv_a_proj_with_mqa" in name
):
cached_a_proj[name] = loaded_weight
q_a_proj_name = (
name
if "q_a_proj" in name
else name.replace("kv_a_proj_with_mqa", "q_a_proj")
)
kv_a_proj_name = (
name
if "kv_a_proj_with_mqa" in name
else name.replace("q_a_proj", "kv_a_proj_with_mqa")
)

# When both q_a_proj and kv_a_proj_with_mqa have been cached, load the fused weight into the parameter
if (
q_a_proj_name in cached_a_proj
and kv_a_proj_name in cached_a_proj
):

q_a_proj_weight = cached_a_proj[q_a_proj_name]
kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
fused_weight = torch.cat(
[q_a_proj_weight, kv_a_proj_weight], dim=0
)

param_name = name.replace(
"q_a_proj", "fused_qkv_a_proj_with_mqa"
)
param = params_dict[param_name]

weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, fused_weight)
cached_a_proj.pop(q_a_proj_name)
cached_a_proj.pop(kv_a_proj_name)
else:
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)

self_attn = self.model.decoder.self_attn
if hasattr(self_attn.kv_b_proj, "qweight"):
# AWQ compatible
if _is_cuda:
w = awq_dequantize(
self_attn.kv_b_proj.qweight,
self_attn.kv_b_proj.scales,
self_attn.kv_b_proj.qzeros,
).T
else:
w = awq_dequantize(
self_attn.kv_b_proj.qweight,
self_attn.kv_b_proj.scales,
self_attn.kv_b_proj.qzeros,
0,
0,
0,
).T
else:
w = self_attn.kv_b_proj.weight
# NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
# This may affect the accuracy of fp8 model.
if hasattr(self.quant_config, "weight_block_size") and w.dtype in (
torch.float8_e4m3fn,
torch.float8_e4m3fnuz,
):
weight_block_size = self.quant_config.weight_block_size
if weight_block_size is not None:
assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
if _is_hip:
weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
weight=w,
weight_scale=self_attn.kv_b_proj.weight_scale_inv,
input_scale=None,
)
else:
weight = w
weight_scale = self_attn.kv_b_proj.weight_scale_inv

w, scale = block_quant_to_tensor_quant(
weight, weight_scale, weight_block_size
)
self_attn.w_scale = scale
if w.dtype == torch.int8:
if hasattr(self.quant_config, "weight_block_size"):
# block-wise int8 need it
weight_block_size = self.quant_config.weight_block_size
if weight_block_size is not None:
assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
weight = w
weight_scale = self_attn.kv_b_proj.weight_scale_inv
w = int8_block_dequant(weight, weight_scale, weight_block_size).to(
torch.bfloat16
)
else:
# channel-wise int8 need it
assert hasattr(self_attn.kv_b_proj, "weight_scale")
w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
torch.bfloat16
)
w_kc, w_vc = w.unflatten(
0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
if hasattr(self_attn.kv_b_proj, "weight_scale") and self_attn.w_scale is None:
self_attn.w_scale = self_attn.kv_b_proj.weight_scale
if _is_hip:
self_attn.w_scale *= 2.0
super().load_weights(weights, is_nextn=True)


EntryClass = [DeepseekV3ForCausalLMNextN]
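The deletion above drops a loader that largely duplicated DeepseekV3ForCausalLM.load_weights; the NextN model now delegates to the shared loader via super().load_weights(weights, is_nextn=True). Together with the ModelConfig change, the end-to-end effect is roughly the following (a hedged sketch: the repo id is only an example, and constructor arguments other than model_path and is_draft_model are assumed to keep their defaults):

from sglang.srt.configs.model_config import ModelConfig

# Target and draft configs built from the same checkpoint path.
target = ModelConfig("deepseek-ai/DeepSeek-V3")
draft = ModelConfig("deepseek-ai/DeepSeek-V3", is_draft_model=True)

print(target.hf_config.architectures[0])  # DeepseekV3ForCausalLM
print(draft.hf_config.architectures[0])   # DeepseekV3ForCausalLMNextN

This is what lets the draft model path default to the target model path for MTP, instead of requiring a separately exported NextN checkpoint.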