@@ -342,6 +342,7 @@ def fused_moe_kernel(
     use_fp8_w8a8: tl.constexpr,
     use_int8_w8a8: tl.constexpr,
     use_int8_w8a16: tl.constexpr,
+    per_channel_quant: tl.constexpr,
     even_Ks: tl.constexpr,
 ):
     """
@@ -416,20 +417,7 @@ def fused_moe_kernel(
         )
         b_scale = tl.load(b_scale_ptrs)

-    if use_fp8_w8a8:
-        # block-wise
-        if group_k > 0 and group_n > 0:
-            a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
-            offs_bsn = offs_bn // group_n
-            b_scale_ptrs = (
-                b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
-            )
-        # tensor-wise
-        else:
-            a_scale = tl.load(a_scale_ptr)
-            b_scale = tl.load(b_scale_ptr + off_experts)
-
-    if use_int8_w8a8:
+    if use_fp8_w8a8 or use_int8_w8a8:
         # block-wise
         if group_k > 0 and group_n > 0:
             a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
@@ -438,15 +426,18 @@ def fused_moe_kernel(
                 b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
             )
         # channel-wise
-        else:
-            # Load per-column scale for weights
+        elif per_channel_quant:
             b_scale_ptrs = (
                 b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
             )
             b_scale = tl.load(b_scale_ptrs)
             # Load per-token scale for activations
             a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
             a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:, None]
+        # tensor-wise
+        else:
+            a_scale = tl.load(a_scale_ptr)
+            b_scale = tl.load(b_scale_ptr + off_experts)

     # -----------------------------------------------------------
     # Iterate to compute a block of the C matrix.
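
With the branch rewritten, the merged fp8/int8 path distinguishes three scale granularities: block-wise (group_k/group_n > 0), channel-wise (per-output-channel weight scales paired with per-token activation scales), and tensor-wise. A rough PyTorch reference for the intended dequantization semantics; all names and scale layouts here are assumptions for illustration, not the kernel's actual buffer contract:

import torch

def moe_matmul_reference(a_q, b_q, a_s, b_s, group_k=0, group_n=0, per_channel=False):
    # a_q: [M, K] quantized activations routed to one expert; b_q: [K, N] quantized weights
    a, b = a_q.float(), b_q.float()
    if group_k > 0 and group_n > 0:
        # block-wise: assume a_s is [M, K // group_k], b_s is [K // group_k, N // group_n]
        out = torch.zeros(a.shape[0], b.shape[1], device=a.device)
        for k0 in range(0, a.shape[1], group_k):
            partial = a[:, k0 : k0 + group_k] @ b[k0 : k0 + group_k, :]
            col_scale = b_s[k0 // group_k].repeat_interleave(group_n)  # expand to [N]
            out += partial * a_s[:, k0 // group_k, None] * col_scale[None, :]
        return out
    if per_channel:
        # channel-wise weights x per-token activations: a_s is [M], b_s is [N]
        return (a @ b) * a_s[:, None] * b_s[None, :]
    # tensor-wise: scalar a_s and b_s
    return (a @ b) * a_s * b_s
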
@@ -753,6 +744,7 @@ def invoke_fused_moe_kernel(
     use_int8_w8a8: bool,
     use_int8_w8a16: bool,
     use_int4_w4a16: bool,
+    per_channel_quant: bool,
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
 ) -> None:
@@ -777,10 +769,15 @@ def invoke_fused_moe_kernel(
         if block_shape is None:
             # activation tensor-wise fp8 quantization, dynamic or static
             padded_size = padding_size
+            # by default, activations use per-token quantization when weights use per-channel quantization
             if _is_cuda:
-                A, A_scale = sgl_scaled_fp8_quant(A, A_scale)
+                A, A_scale = sgl_scaled_fp8_quant(
+                    A, A_scale, use_per_token_if_dynamic=per_channel_quant
+                )
             else:
-                A, A_scale = vllm_ops.scaled_fp8_quant(A, A_scale)
+                A, A_scale = vllm_ops.scaled_fp8_quant(
+                    A, A_scale, use_per_token_if_dynamic=per_channel_quant
+                )
         else:
             # activation block-wise fp8 quantization
             assert len(block_shape) == 2
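
use_per_token_if_dynamic only matters when A_scale is None (dynamic quantization): it switches the dynamically computed scale from one scalar per tensor to one scale per token. Approximately, assuming fp8 e4m3 with a max representable value of 448 (an illustrative helper, not the sgl_scaled_fp8_quant/vllm implementation):

import torch

FP8_MAX = 448.0  # torch.finfo(torch.float8_e4m3fn).max

def dynamic_fp8_quant_sketch(x: torch.Tensor, use_per_token_if_dynamic: bool):
    if use_per_token_if_dynamic:
        # one scale per row/token, so outlier tokens don't wash out the rest
        scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / FP8_MAX
    else:
        scale = x.abs().amax().clamp(min=1e-12) / FP8_MAX  # scalar, tensor-wise
    x_q = (x / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return x_q, scale
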
@@ -796,6 +793,9 @@ def invoke_fused_moe_kernel(
         assert B_scale is not None
         if block_shape is None:
             # activation channel-wise int8 quantization
+            assert (
+                per_channel_quant
+            ), "non-block-wise int8 quantization only supports channel-wise quantization"
             A, A_scale = per_token_quant_int8(A)
         else:
             # activation block-wise int8 quantization
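
The assertion pins down the int8 contract: without block_shape, activations always receive per-token scales via per_token_quant_int8, which only makes sense when paired with per-channel weight scales. Its semantics are roughly (a sketch, not the actual helper):

import torch

def per_token_quant_int8_sketch(x: torch.Tensor):
    # symmetric per-row (per-token) quantization into int8
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / 127.0
    x_q = torch.round(x / scale).clamp(-127, 127).to(torch.int8)
    return x_q, scale
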
@@ -904,6 +904,7 @@ def invoke_fused_moe_kernel(
             use_fp8_w8a8=use_fp8_w8a8,
             use_int8_w8a8=use_int8_w8a8,
             use_int8_w8a16=use_int8_w8a16,
+            per_channel_quant=per_channel_quant,
             even_Ks=even_Ks,
             **config,
         )
@@ -1086,6 +1087,7 @@ def inplace_fused_experts(
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     use_int4_w4a16: bool = False,
+    per_channel_quant: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     w1_zp: Optional[torch.Tensor] = None,
@@ -1107,6 +1109,7 @@ def inplace_fused_experts(
         use_int8_w8a8,
         use_int8_w8a16,
         use_int4_w4a16,
+        per_channel_quant,
         w1_scale,
         w2_scale,
         w1_zp,
@@ -1129,6 +1132,7 @@ def inplace_fused_experts_fake(
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     use_int4_w4a16: bool = False,
+    per_channel_quant: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     w1_zp: Optional[torch.Tensor] = None,
@@ -1160,6 +1164,7 @@ def outplace_fused_experts(
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     use_int4_w4a16: bool = False,
+    per_channel_quant: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     w1_zp: Optional[torch.Tensor] = None,
@@ -1182,6 +1187,7 @@ def outplace_fused_experts(
         use_int8_w8a8,
         use_int8_w8a16,
         use_int4_w4a16,
+        per_channel_quant,
         w1_scale,
         w2_scale,
         w1_zp,
@@ -1205,6 +1211,7 @@ def outplace_fused_experts_fake(
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     use_int4_w4a16: bool = False,
+    per_channel_quant: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     w1_zp: Optional[torch.Tensor] = None,
@@ -1238,6 +1245,7 @@ def fused_experts(
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     use_int4_w4a16: bool = False,
+    per_channel_quant: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     w1_zp: Optional[torch.Tensor] = None,
@@ -1261,6 +1269,7 @@ def fused_experts(
         use_int8_w8a8,
         use_int8_w8a16,
         use_int4_w4a16,
+        per_channel_quant,
         w1_scale,
         w2_scale,
         w1_zp,
@@ -1283,6 +1292,7 @@ def fused_experts(
         use_int8_w8a8,
         use_int8_w8a16,
         use_int4_w4a16,
+        per_channel_quant,
         w1_scale,
         w2_scale,
         w1_zp,
@@ -1307,6 +1317,7 @@ def fused_experts_impl(
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     use_int4_w4a16: bool = False,
+    per_channel_quant: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     w1_zp: Optional[torch.Tensor] = None,
@@ -1443,6 +1454,7 @@ def fused_experts_impl(
             use_int8_w8a8=use_int8_w8a8,
             use_int8_w8a16=use_int8_w8a16,
             use_int4_w4a16=use_int4_w4a16,
+            per_channel_quant=per_channel_quant,
             block_shape=block_shape,
         )
         if activation == "silu":
@@ -1486,6 +1498,7 @@ def fused_experts_impl(
             use_int8_w8a8=use_int8_w8a8,
             use_int8_w8a16=use_int8_w8a16,
             use_int4_w4a16=use_int4_w4a16,
+            per_channel_quant=per_channel_quant,
             block_shape=block_shape,
         )

@@ -1532,6 +1545,7 @@ def fused_moe(
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     use_int4_w4a16: bool = False,
+    per_channel_quant: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     w1_zp: Optional[torch.Tensor] = None,
@@ -1608,6 +1622,7 @@ def fused_moe(
         use_int8_w8a8=use_int8_w8a8,
         use_int8_w8a16=use_int8_w8a16,
         use_int4_w4a16=use_int4_w4a16,
+        per_channel_quant=per_channel_quant,
         w1_scale=w1_scale,
         w2_scale=w2_scale,
         w1_zp=w1_zp,
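
Putting it together, a caller opts into the new behavior by passing per_channel_quant=True alongside matching scale layouts. A hypothetical int8 w8a8 invocation; shapes, dtypes, and the import path are assumptions based on this diff, not a tested example:

import torch
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe

E, M, K, N, top_k = 8, 32, 256, 512, 2
hidden = torch.randn(M, K, device="cuda", dtype=torch.half)
w1 = torch.randint(-127, 128, (E, 2 * N, K), device="cuda", dtype=torch.int8)
w2 = torch.randint(-127, 128, (E, K, N), device="cuda", dtype=torch.int8)
gating = torch.randn(M, E, device="cuda", dtype=torch.half)

out = fused_moe(
    hidden, w1, w2, gating, top_k,
    renormalize=True,
    use_int8_w8a8=True,
    per_channel_quant=True,  # the new flag
    w1_scale=torch.rand(E, 2 * N, device="cuda"),  # one scale per output channel
    w2_scale=torch.rand(E, K, device="cuda"),
)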