@@ -36,6 +36,7 @@ def _test_grouped_gemm_fp8_rowwise(
             shape: Tuple[int, int, int, int],
             device: torch.device,
             fast_accu: bool,
+            use_warp_specialization: bool,
         ) -> None:
             G, M, N, K = shape
             a = torch.randn(M, K, dtype=torch.bfloat16, device=device)
@@ -62,6 +63,7 @@ def _test_grouped_gemm_fp8_rowwise(
                 a_scale,
                 b_scale,
                 use_fast_accum=fast_accu,
+                _use_warp_specialization=use_warp_specialization,
             )
             self.assertTrue(result.shape == (M, N))
 
@@ -85,17 +87,22 @@ def _test_grouped_gemm_fp8_rowwise(
         for G in (1, 4, 16):
             for M in (64, 512):
                 for fast_accu in (True, False):
-                    logging.info(
-                        f"Testing FP8 GMM with G={G}, M={M}, FastAccu={fast_accu}"
-                    )
-                    _test_grouped_gemm_fp8_rowwise(
-                        (G, M, 256, 256), torch.device("cuda"), fast_accu=fast_accu
-                    )
+                    for ws in (True, False):
+                        logging.info(
+                            f"Testing FP8 GMM with G={G}, M={M}, FastAccu={fast_accu}"
+                        )
+                        _test_grouped_gemm_fp8_rowwise(
+                            (G, M, 256, 256),
+                            torch.device("cuda"),
+                            fast_accu=fast_accu,
+                            use_warp_specialization=ws,
+                        )
 
     def test_grouped_gemm_bf16(self) -> None:
         def _test_grouped_gemm_bf16(
             shape: Tuple[int, int, int, int],
             device: torch.device,
+            use_warp_specialization: bool,
         ) -> None:
             G, M, N, K = shape
             a = torch.randn(M, K, dtype=torch.bfloat16, device=device)
@@ -116,6 +123,7 @@ def _test_grouped_gemm_bf16(
                 a,
                 b,
                 m_sizes,
+                _use_warp_specialization=use_warp_specialization,
             )
             self.assertTrue(result.shape == (M, N))
 
@@ -131,5 +139,10 @@ def _test_grouped_gemm_bf16(
 
         for G in (1, 4, 16):
             for M in (64, 512):
-                logging.info(f"Testing BF16 GMM with G={G}, M={M}")
-                _test_grouped_gemm_bf16((G, M, 256, 256), torch.device("cuda"))
+                for ws in (True, False):
+                    logging.info(f"Testing BF16 GMM with G={G}, M={M}")
+                    _test_grouped_gemm_bf16(
+                        (G, M, 256, 256),
+                        torch.device("cuda"),
+                        use_warp_specialization=ws,
+                    )