@@ -1380,6 +1380,84 @@ def forward(
            input_ids, hidden_states, self.lm_head, forward_batch
        )

+    def post_load_weights(self):
+
+        # Perform post-processing after loading weights
+
+        if not global_server_args_dict["disable_mla"]:
+            for layer_id in range(self.config.num_hidden_layers):
+                self_attn = self.model.layers[layer_id].self_attn
+                if hasattr(self_attn.kv_b_proj, "qweight"):
+                    # AWQ compatible
+                    if _is_cuda:
+                        w = awq_dequantize(
+                            self_attn.kv_b_proj.qweight,
+                            self_attn.kv_b_proj.scales,
+                            self_attn.kv_b_proj.qzeros,
+                        ).T
+                    else:
+                        w = ops.awq_dequantize(
+                            self_attn.kv_b_proj.qweight,
+                            self_attn.kv_b_proj.scales,
+                            self_attn.kv_b_proj.qzeros,
+                            0,
+                            0,
+                            0,
+                        ).T
+                else:
+                    w = self_attn.kv_b_proj.weight
+                # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
+                # This may affect the accuracy of the fp8 model.
+                if hasattr(self.quant_config, "weight_block_size") and w.dtype in (
+                    torch.float8_e4m3fn,
+                    torch.float8_e4m3fnuz,
+                ):
+                    weight_block_size = self.quant_config.weight_block_size
+                    if weight_block_size is not None:
+                        assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
+                        if _is_hip:
+                            weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                                weight=w,
+                                weight_scale=self_attn.kv_b_proj.weight_scale_inv,
+                                input_scale=None,
+                            )
+                        else:
+                            weight = w
+                            weight_scale = self_attn.kv_b_proj.weight_scale_inv
+
+                        w, scale = block_quant_to_tensor_quant(
+                            weight, weight_scale, weight_block_size
+                        )
+                        self_attn.w_scale = scale
+                if w.dtype == torch.int8:
+                    if hasattr(self.quant_config, "weight_block_size"):
+                        # block-wise int8 needs it
+                        weight_block_size = self.quant_config.weight_block_size
+                        if weight_block_size is not None:
+                            assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
+                            weight = w
+                            weight_scale = self_attn.kv_b_proj.weight_scale_inv
+                            w = int8_block_dequant(
+                                weight, weight_scale, weight_block_size
+                            ).to(torch.bfloat16)
+                    else:
+                        # channel-wise int8 needs it
+                        w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
+                            torch.bfloat16
+                        )
+                w_kc, w_vc = w.unflatten(
+                    0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
+                ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
+                self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
+                self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
+                if (
+                    hasattr(self_attn.kv_b_proj, "weight_scale")
+                    and self_attn.w_scale is None
+                ):
+                    self_attn.w_scale = self_attn.kv_b_proj.weight_scale
+                    if _is_hip:
+                        self_attn.w_scale *= 2.0
+
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
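
The heart of `post_load_weights` is the reshape near the end of the loop: once `kv_b_proj` has been dequantized to a plain tensor `w` of shape `[num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank]`, the `unflatten`/`split` pair recovers the per-head K (nope) and V projection blocks that MLA's absorbed attention later consumes as `w_kc` and `w_vc`. A minimal sketch of that shape manipulation, using illustrative dimensions rather than values taken from any model config:

```python
import torch

# Illustrative dimensions only; the real values come from the model config.
num_heads, qk_nope_head_dim, v_head_dim, kv_lora_rank = 4, 128, 128, 512

# kv_b_proj maps the compressed KV latent (kv_lora_rank) to per-head K-nope and V,
# so its weight is laid out as [num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank].
w = torch.randn(num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank)

# Recover the head dimension, then split the K (nope) rows from the V rows per head.
w_kc, w_vc = w.unflatten(0, (-1, qk_nope_head_dim + v_head_dim)).split(
    [qk_nope_head_dim, v_head_dim], dim=1
)

print(w_kc.shape)  # torch.Size([4, 128, 512])
print(w_vc.shape)  # torch.Size([4, 128, 512])
```

The subsequent `transpose(1, 2).contiguous().transpose(1, 2)` on `w_kc` only changes the memory layout in preparation for the later batched matmul; the logical shape is unchanged.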
@@ -1504,79 +1582,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                )
                weight_loader(param, loaded_weight)

-        if not global_server_args_dict["disable_mla"]:
-            for layer_id in range(self.config.num_hidden_layers):
-                self_attn = self.model.layers[layer_id].self_attn
-                if hasattr(self_attn.kv_b_proj, "qweight"):
-                    # AWQ compatible
-                    if _is_cuda:
-                        w = awq_dequantize(
-                            self_attn.kv_b_proj.qweight,
-                            self_attn.kv_b_proj.scales,
-                            self_attn.kv_b_proj.qzeros,
-                        ).T
-                    else:
-                        w = ops.awq_dequantize(
-                            self_attn.kv_b_proj.qweight,
-                            self_attn.kv_b_proj.scales,
-                            self_attn.kv_b_proj.qzeros,
-                            0,
-                            0,
-                            0,
-                        ).T
-                else:
-                    w = self_attn.kv_b_proj.weight
-                # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
-                # This may affect the accuracy of the fp8 model.
-                if hasattr(self.quant_config, "weight_block_size") and w.dtype in (
-                    torch.float8_e4m3fn,
-                    torch.float8_e4m3fnuz,
-                ):
-                    weight_block_size = self.quant_config.weight_block_size
-                    if weight_block_size is not None:
-                        assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
-                        if _is_hip:
-                            weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
-                                weight=w,
-                                weight_scale=self_attn.kv_b_proj.weight_scale_inv,
-                                input_scale=None,
-                            )
-                        else:
-                            weight = w
-                            weight_scale = self_attn.kv_b_proj.weight_scale_inv
-
-                        w, scale = block_quant_to_tensor_quant(
-                            weight, weight_scale, weight_block_size
-                        )
-                        self_attn.w_scale = scale
-                if w.dtype == torch.int8:
-                    if hasattr(self.quant_config, "weight_block_size"):
-                        # block-wise int8 needs it
-                        weight_block_size = self.quant_config.weight_block_size
-                        if weight_block_size is not None:
-                            assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
-                            weight = w
-                            weight_scale = self_attn.kv_b_proj.weight_scale_inv
-                            w = int8_block_dequant(
-                                weight, weight_scale, weight_block_size
-                            ).to(torch.bfloat16)
-                    else:
-                        # channel-wise int8 needs it
-                        w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
-                            torch.bfloat16
-                        )
-                w_kc, w_vc = w.unflatten(
-                    0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
-                ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
-                self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
-                self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
-                if (
-                    hasattr(self_attn.kv_b_proj, "weight_scale")
-                    and self_attn.w_scale is None
-                ):
-                    self_attn.w_scale = self_attn.kv_b_proj.weight_scale
-                    if _is_hip:
-                        self_attn.w_scale *= 2.0
+        self.post_load_weights()

    def get_embed_and_head(self):
        return self.model.embed_tokens.weight, self.lm_head.weight
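
The second hunk replaces the inline post-processing in `load_weights` with a call to the new `post_load_weights`, so a normal weight load behaves exactly as before while the MLA post-processing becomes callable on its own. A hedged sketch of how a caller might use that split; `WeightSyncExample`, `full_reload`, and `refresh` are hypothetical names used only for illustration and are not part of this change:

```python
from typing import Iterable, Tuple

import torch


class WeightSyncExample:
    """Hypothetical wrapper; assumes `model` exposes the two methods from the diff."""

    def __init__(self, model):
        self.model = model

    def full_reload(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        # load_weights() now ends with post_load_weights(), so the derived
        # MLA tensors (w_kc, w_vc, w_scale) are rebuilt automatically here.
        self.model.load_weights(weights)

    def refresh(self):
        # If parameters were updated in place without going through
        # load_weights(), the post-processing can now be re-run by itself.
        self.model.post_load_weights()
```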