@@ -726,7 +726,8 @@ def fp8_loopover_bmm(
        torch.testing.assert_close(y_ref, y_fp8, atol=8.0e-2, rtol=8.0e-2)

    @unittest.skipIf(
-        not torch.version.cuda, "Skip on AMD: GMM ops are not yet suported."
+        not torch.version.cuda and torch.version.hip < "6.2",
+        "Skip on AMD with ROCm < 6.2",
    )
    @settings(deadline=None)
    @given(
@@ -805,63 +806,39 @@ def test_fp8_grouped_gemm(
        w_scale_group = torch.unbind(torch.stack(w_scale_group, dim=0).contiguous())

        # FP8 grouped gemm kernel
+        fp8_args = (
+            [
+                xq_group,
+                wq_group,
+                x_scale_group,
+                w_scale_group,
+                zero_start_index_M if use_padding_zeros else None,
+            ]
+            if use_dynamic
+            else [xq_group, wq_group, x_scale_group, w_scale_group]
+        )
+        fp8_op = (
+            torch.ops.fbgemm.f8f8bf16_rowwise_grouped_dynamic
+            if use_dynamic
+            else torch.ops.fbgemm.f8f8bf16_rowwise_grouped
+        )
        if use_cudagraph:
-            if use_padding_zeros:
-                # warmup
-                torch.ops.fbgemm.f8f8bf16_rowwise_grouped_dynamic(
-                    xq_group,
-                    wq_group,
-                    x_scale_group,
-                    w_scale_group,
-                    zero_start_index_M,
-                )
-                # With cudagraph
-                g = torch.cuda.CUDAGraph()
-                with torch.cuda.graph(g):
-                    y_fp8_group = torch.ops.fbgemm.f8f8bf16_rowwise_grouped_dynamic(
-                        xq_group,
-                        wq_group,
-                        x_scale_group,
-                        w_scale_group,
-                        zero_start_index_M,
-                    )
-                g.replay()
-                y_fp8_group = y_fp8_group.unbind(dim=0)
-            else:
-                # warmup
-                torch.ops.fbgemm.f8f8bf16_rowwise_grouped(
-                    xq_group,
-                    wq_group,
-                    x_scale_group,
-                    w_scale_group,
-                )
-                # With cudagraph
-                g = torch.cuda.CUDAGraph()
-                with torch.cuda.graph(g):
-                    y_fp8_group = torch.ops.fbgemm.f8f8bf16_rowwise_grouped(
-                        xq_group,
-                        wq_group,
-                        x_scale_group,
-                        w_scale_group,
-                    )
-                g.replay()
+            # warmup
+            fp8_op(*fp8_args)
+            # With cudagraph
+            g = torch.cuda.CUDAGraph()
+            with torch.cuda.graph(g):
+                y_fp8_group = fp8_op(*fp8_args)
+            g.replay()
        else:
-            if use_padding_zeros:
-                y_fp8_group = torch.ops.fbgemm.f8f8bf16_rowwise_grouped_dynamic(
-                    xq_group,
-                    wq_group,
-                    x_scale_group,
-                    w_scale_group,
-                    zero_start_index_M,
-                )
-                y_fp8_group = y_fp8_group.unbind(dim=0)
+            y_fp8_group = fp8_op(*fp8_args)
+
+        # Massage output into proper format.
+        if not isinstance(y_fp8_group, (tuple, list)):
+            if y_fp8_group.ndim == 2:
+                y_fp8_group = torch.split(y_fp8_group, tuple(ms.tolist()), dim=0)
            else:
-                y_fp8_group = torch.ops.fbgemm.f8f8bf16_rowwise_grouped(
-                    xq_group,
-                    wq_group,
-                    x_scale_group,
-                    w_scale_group,
-                )
+                y_fp8_group = torch.unbind(y_fp8_group)

        # BF16 grouped gemm kernel
        bf16_args = (
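For context, the `use_cudagraph` branch above relies on the standard warmup / capture / replay idiom. Below is a minimal, self-contained sketch of that idiom; the helper name `run_with_cudagraph` and the stand-in matmul are illustrative assumptions, not part of this commit, which captures the fbgemm grouped-gemm op instead.

# Minimal sketch of the CUDA-graph warmup / capture / replay pattern
# used by the test above (illustrative only, not part of this diff).
import torch

def run_with_cudagraph(op, *args):
    # Warm up once outside the graph so one-time work (lazy init,
    # autotuning, workspace allocation) is not baked into the capture.
    op(*args)
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        out = op(*args)
    # Replay re-launches the captured kernels against the same buffers.
    g.replay()
    return out

if torch.cuda.is_available():
    a = torch.randn(64, 64, device="cuda")
    b = torch.randn(64, 64, device="cuda")
    y = run_with_cudagraph(torch.matmul, a, b)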