Commit 645b0c2

inkcherry authored and jimoosciuc committed
fix dummy-load deepseekv2 (sgl-project#4535)
1 parent 8f4359b commit 645b0c2

File tree

2 files changed

+87 -73 lines changed

python/sglang/srt/model_loader/loader.py

Lines changed: 8 additions & 0 deletions
@@ -489,6 +489,14 @@ def load_model(
             # NOTE(woosuk): For accurate performance evaluation, we assign
             # random values to the weights.
             initialize_dummy_weights(model)
+
+            # Model weight loading consists of two stages:
+            # 1. Initial weight loading.
+            # 2. Post-processing of weights, including assigning specific member variables.
+            # For `dummy_init`, only the second stage is required.
+            if hasattr(model, "post_load_weights"):
+                model.post_load_weights()
+
         return model.eval()
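Why the hook matters: in the dummy-load path the loader fills the model with random weights and never calls the model's load_weights, so any post-processing that used to live at the end of load_weights was silently skipped. The snippet below is a minimal, self-contained sketch of the two-stage contract the loader now relies on; TinyModel and its attributes are made up for illustration and are not part of sglang.

from torch import nn


class TinyModel(nn.Module):
    """Hypothetical model illustrating the two-stage loading contract."""

    def __init__(self) -> None:
        super().__init__()
        self.kv_b_proj = nn.Linear(16, 32, bias=False)
        self.w_kc = None  # derived tensor, only set by post_load_weights()

    def load_weights(self, weights):
        # Stage 1: copy raw checkpoint tensors into the parameters.
        state = dict(weights)
        self.kv_b_proj.weight.data.copy_(state["kv_b_proj.weight"])
        # Stage 2 is factored out so the dummy-load path can call it too.
        self.post_load_weights()

    def post_load_weights(self):
        # Stage 2: derive member variables from whatever weights are present.
        self.w_kc = self.kv_b_proj.weight.detach().T.contiguous()


# Dummy-load path: weights stay random, but stage 2 must still run,
# which is what the loader.py change above guarantees.
model = TinyModel()
if hasattr(model, "post_load_weights"):
    model.post_load_weights()
print(model.w_kc.shape)  # torch.Size([16, 32])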
python/sglang/srt/models/deepseek_v2.py

Lines changed: 79 additions & 73 deletions
@@ -1380,6 +1380,84 @@ def forward(
             input_ids, hidden_states, self.lm_head, forward_batch
         )

+    def post_load_weights(self):
+
+        # Perform post-processing after loading weights
+
+        if not global_server_args_dict["disable_mla"]:
+            for layer_id in range(self.config.num_hidden_layers):
+                self_attn = self.model.layers[layer_id].self_attn
+                if hasattr(self_attn.kv_b_proj, "qweight"):
+                    # AWQ compatible
+                    if _is_cuda:
+                        w = awq_dequantize(
+                            self_attn.kv_b_proj.qweight,
+                            self_attn.kv_b_proj.scales,
+                            self_attn.kv_b_proj.qzeros,
+                        ).T
+                    else:
+                        w = ops.awq_dequantize(
+                            self_attn.kv_b_proj.qweight,
+                            self_attn.kv_b_proj.scales,
+                            self_attn.kv_b_proj.qzeros,
+                            0,
+                            0,
+                            0,
+                        ).T
+                else:
+                    w = self_attn.kv_b_proj.weight
+                # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
+                # This may affect the accuracy of fp8 model.
+                if hasattr(self.quant_config, "weight_block_size") and w.dtype in (
+                    torch.float8_e4m3fn,
+                    torch.float8_e4m3fnuz,
+                ):
+                    weight_block_size = self.quant_config.weight_block_size
+                    if weight_block_size is not None:
+                        assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
+                        if _is_hip:
+                            weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                                weight=w,
+                                weight_scale=self_attn.kv_b_proj.weight_scale_inv,
+                                input_scale=None,
+                            )
+                        else:
+                            weight = w
+                            weight_scale = self_attn.kv_b_proj.weight_scale_inv
+
+                        w, scale = block_quant_to_tensor_quant(
+                            weight, weight_scale, weight_block_size
+                        )
+                        self_attn.w_scale = scale
+                if w.dtype == torch.int8:
+                    if hasattr(self.quant_config, "weight_block_size"):
+                        # block-wise int8 need it
+                        weight_block_size = self.quant_config.weight_block_size
+                        if weight_block_size is not None:
+                            assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
+                            weight = w
+                            weight_scale = self_attn.kv_b_proj.weight_scale_inv
+                            w = int8_block_dequant(
+                                weight, weight_scale, weight_block_size
+                            ).to(torch.bfloat16)
+                    else:
+                        # channel-wise int8 need it
+                        w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
+                            torch.bfloat16
+                        )
+                w_kc, w_vc = w.unflatten(
+                    0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
+                ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
+                self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
+                self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
+                if (
+                    hasattr(self_attn.kv_b_proj, "weight_scale")
+                    and self_attn.w_scale is None
+                ):
+                    self_attn.w_scale = self_attn.kv_b_proj.weight_scale
+                    if _is_hip:
+                        self_attn.w_scale *= 2.0
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
@@ -1504,79 +1582,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 )
                 weight_loader(param, loaded_weight)

-        if not global_server_args_dict["disable_mla"]:
-            for layer_id in range(self.config.num_hidden_layers):
-                self_attn = self.model.layers[layer_id].self_attn
-                if hasattr(self_attn.kv_b_proj, "qweight"):
-                    # AWQ compatible
-                    if _is_cuda:
-                        w = awq_dequantize(
-                            self_attn.kv_b_proj.qweight,
-                            self_attn.kv_b_proj.scales,
-                            self_attn.kv_b_proj.qzeros,
-                        ).T
-                    else:
-                        w = ops.awq_dequantize(
-                            self_attn.kv_b_proj.qweight,
-                            self_attn.kv_b_proj.scales,
-                            self_attn.kv_b_proj.qzeros,
-                            0,
-                            0,
-                            0,
-                        ).T
-                else:
-                    w = self_attn.kv_b_proj.weight
-                # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
-                # This may affect the accuracy of fp8 model.
-                if hasattr(self.quant_config, "weight_block_size") and w.dtype in (
-                    torch.float8_e4m3fn,
-                    torch.float8_e4m3fnuz,
-                ):
-                    weight_block_size = self.quant_config.weight_block_size
-                    if weight_block_size is not None:
-                        assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
-                        if _is_hip:
-                            weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
-                                weight=w,
-                                weight_scale=self_attn.kv_b_proj.weight_scale_inv,
-                                input_scale=None,
-                            )
-                        else:
-                            weight = w
-                            weight_scale = self_attn.kv_b_proj.weight_scale_inv
-
-                        w, scale = block_quant_to_tensor_quant(
-                            weight, weight_scale, weight_block_size
-                        )
-                        self_attn.w_scale = scale
-                if w.dtype == torch.int8:
-                    if hasattr(self.quant_config, "weight_block_size"):
-                        # block-wise int8 need it
-                        weight_block_size = self.quant_config.weight_block_size
-                        if weight_block_size is not None:
-                            assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
-                            weight = w
-                            weight_scale = self_attn.kv_b_proj.weight_scale_inv
-                            w = int8_block_dequant(
-                                weight, weight_scale, weight_block_size
-                            ).to(torch.bfloat16)
-                    else:
-                        # channel-wise int8 need it
-                        w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
-                            torch.bfloat16
-                        )
-                w_kc, w_vc = w.unflatten(
-                    0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
-                ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
-                self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
-                self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
-                if (
-                    hasattr(self_attn.kv_b_proj, "weight_scale")
-                    and self_attn.w_scale is None
-                ):
-                    self_attn.w_scale = self_attn.kv_b_proj.weight_scale
-                    if _is_hip:
-                        self_attn.w_scale *= 2.0
+        self.post_load_weights()

     def get_embed_and_head(self):
         return self.model.embed_tokens.weight, self.lm_head.weight
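What the relocated block computes: for MLA, kv_b_proj maps the compressed KV latent (kv_lora_rank) to per-head concatenated no-PE-K and V features, and post_load_weights splits its (dequantized) weight into the w_kc and w_vc blocks the attention kernel uses. Below is a standalone shape sketch of the unflatten/split/transpose step; the head counts and dimensions are made up for illustration and are not DeepSeek-V2's real config.

import torch

# Illustrative sizes only, not DeepSeek-V2's real config.
num_heads, qk_nope_head_dim, v_head_dim, kv_lora_rank = 4, 8, 6, 16

# kv_b_proj's weight is [num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank].
w = torch.randn(num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank)

# Recover a per-head axis, then split each head's rows into the K-nope part
# and the V part, mirroring the diff above.
w_kc, w_vc = w.unflatten(0, (-1, qk_nope_head_dim + v_head_dim)).split(
    [qk_nope_head_dim, v_head_dim], dim=1
)
print(w_kc.shape)  # torch.Size([4, 8, 16])
print(w_vc.shape)  # torch.Size([4, 6, 16])

# As in the diff: w_kc keeps its shape but ends up stored in a transposed
# memory layout, while w_vc is transposed to [num_heads, kv_lora_rank, v_head_dim].
w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
w_vc = w_vc.contiguous().transpose(1, 2)
print(w_kc.shape)  # torch.Size([4, 8, 16])
print(w_vc.shape)  # torch.Size([4, 16, 6])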
