@@ -1325,6 +1325,70 @@ def cuda(self) -> bool:
         return True
 
 
+@register_quantize_op
+class F8I4ShuffledGroupedGemm(QuantizeOpBase):
+    """
+    FP8 x Int4 mixed dtype grouped gemm with preshuffling.
+    """
+
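+    # One-time preprocessing: quantize and preshuffle the weights, flatten the
+    # grouped activations, and record where each group's rows end.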
+    def preprocess(self, x, w):
+        assert isinstance(x, list) and isinstance(
+            w, list
+        ), "Only supported for grouped inputs."
+        m_values = [i.shape[0] for i in x]
+        # Convert m_values into end-row offsets into the grouped tensor.
+        m_offsets = torch.tensor(np.cumsum(m_values)).to(
+            dtype=torch.int64, device=x[0].device
+        )
+        # Quantize weights.
+        # TODO: Only rowwise scaling is currently supported. This needs to be fixed.
+        K = x[0].shape[-1]
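+        # group_size=K places each row in a single scale group, which is
+        # effectively rowwise scaling.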
+        wq, row_scale, group_scale = zip(
+            *[quantize_int4_preshuffle(i, group_size=K) for i in w]
+        )
+        # Stack weights and scales from all groups into single tensors.
+        wq = torch.stack(wq, dim=0).contiguous()
+        row_scale = torch.stack(row_scale, dim=0).contiguous()
+        group_scale = torch.stack(group_scale, dim=0).contiguous()
+        # Also flatten the input activations into one tensor.
+        x = torch.concat(x, dim=0).contiguous()
+        # Return processed tensors.
+        return x, wq, row_scale, group_scale, m_offsets
+
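+    # Weights were already quantized in preprocess, so only the FP8 rowwise
+    # quantization of the activations happens here.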
+    def quantize(self, x, wq, row_scale, group_scale, m_offsets):
+        B = x.shape[0]
+        xq, x_scale = triton_quantize_fp8_row(x)
+        x_scale = x_scale.view(B, -1)
+        return xq, wq, x_scale, row_scale, group_scale, m_offsets
+
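+    # Single fused grouped GEMM over all groups; m_offsets tells the kernel
+    # where each group's rows end.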
+    def compute(self, xq, wq, x_scale, row_scale, group_scale, m_offsets):
+        out = torch.ops.fbgemm.f8i4bf16_shuffled_grouped(
+            xq, wq, x_scale, row_scale, group_scale, m_offsets
+        )
+        return out
+
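+    # Runs activation quantization and the GEMM back to back.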
+    def quantize_and_compute(self, x, wq, row_scale, group_scale, m_offsets):
+        xq, wq, x_scale, row_scale, group_scale, m_offsets = self.quantize(
+            x, wq, row_scale, group_scale, m_offsets
+        )
+        return self.compute(xq, wq, x_scale, row_scale, group_scale, m_offsets)
+
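+    # torch.version.cuda distinguishes CUDA from ROCm builds; hip is False
+    # below, so only the CUTLASS path is exercised for now.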
+    @property
+    def name(self) -> str:
+        if torch.version.cuda:
+            return "cutlass_f8i4_grouped_preshuffle"
+        else:
+            return "ck_f8i4_grouped_preshuffle"
+
+    @property
+    def hip(self) -> bool:
+        return False
+
+    @property
+    def cuda(self) -> bool:
+        return True
+
+
 @register_quantize_op
 class BF16I4RowwiseGemm(F8I4RowwiseGemm):
     """