@@ -33,16 +33,17 @@ at::Tensor _f8i4bf16_shuffled(
     at::Tensor XQ,
     at::Tensor WQ,
     at::Tensor x_scale,
-    at::Tensor w_scale) {
+    at::Tensor w_scale,
+    at::Tensor w_scale_group) {
   // Get shape information from input tensors.
   int M = XQ.size(0);
   int K = XQ.size(1);
   int N = WQ.size(0);
-  // Make sure w_scale is in proper format.
+  // Make sure w_scale_group is in proper format.
   TORCH_CHECK(
-      w_scale.size(1) == 8,
-      "Weights and scales must be prepacked with preshuffle_i4.");
-  int num_groups = w_scale.size(0);
+      w_scale_group.size(1) == 8,
+      "Weights and group scales must be prepacked with preshuffle_i4.");
+  int num_groups = w_scale_group.size(0);
   int group_size = K / num_groups;
   // Allocate output.
   at::Tensor Y = at::empty({M, N}, XQ.options().dtype(at::kBFloat16));
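For orientation, a small worked example of the group-scale bookkeeping introduced in this hunk; the concrete sizes are illustrative, not taken from the kernel:

```cpp
// Illustrative numbers only: a hypothetical K and w_scale_group shape.
int K = 4096;                     // reduction dimension, XQ.size(1)
int num_groups = 32;              // w_scale_group.size(0) in this example
int group_size = K / num_groups;  // = 128: K-elements sharing one dequant scale
// The TORCH_CHECK above additionally requires w_scale_group.size(1) == 8,
// i.e. the group scales were packed in chunks of 8 by the preshuffle step.
```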
@@ -108,7 +109,15 @@ at::Tensor _f8i4bf16_shuffled(
   using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;

   // Define EVT for rowwise scaling.
-  using XScale = cutlass::epilogue::fusion::Sm90RowBroadcast<
+  // Implement rowwise scaling epilogue.
+  using XScale = cutlass::epilogue::fusion::Sm90ColBroadcast<
+      0,
+      TileShape,
+      ElementAccumulator,
+      ElementAccumulator,
+      cute::Stride<cute::Int<1>, cute::Int<0>, cute::Int<0>>>;
+
+  using WScale = cutlass::epilogue::fusion::Sm90RowBroadcast<
       0,
       TileShape,
       ElementAccumulator,
@@ -119,12 +128,21 @@ at::Tensor _f8i4bf16_shuffled(
119
128
120
129
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
121
130
cutlass::multiplies,
122
- ElementC , // First stage output type.
131
+ ElementAccumulator , // First stage output type.
123
132
ElementAccumulator, // First stage input types.
124
133
cutlass::FloatRoundStyle::round_to_nearest>;
125
134
135
+ using EVTCompute0 =
136
+ cutlass::epilogue::fusion::Sm90EVT<Compute0, WScale, Accum>;
137
+
138
+ using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
139
+ cutlass::multiplies,
140
+ ElementC,
141
+ ElementAccumulator, // Second stage input types.
142
+ cutlass::FloatRoundStyle::round_to_nearest>;
143
+
126
144
using EpilogueEVT =
127
- cutlass::epilogue::fusion::Sm90EVT<Compute0 , XScale, Accum >;
145
+ cutlass::epilogue::fusion::Sm90EVT<Compute1 , XScale, EVTCompute0 >;
128
146
129
147
using CollectiveEpilogue =
130
148
typename cutlass::epilogue::collective::CollectiveBuilder<
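To make the intent of the restructured epilogue concrete: the inner EVT (EVTCompute0) folds one broadcast scale into the accumulator, and the outer EVT folds in the other before conversion to the output type, so each output element ends up multiplied by both the per-row activation scale and the per-column weight scale. A minimal scalar sketch of that net math (illustrative only; how the two logical indices map onto the kernel's internal, possibly transposed layout is an assumption, and the real broadcasts are handled by the Sm90ColBroadcast/Sm90RowBroadcast nodes above):

```cpp
// Scalar reference for the two-stage EVT: D[m][n] = x_scale[m] * w_scale[n] * acc.
float scale_epilogue_ref(float acc, float x_scale_m, float w_scale_n) {
  float stage0 = w_scale_n * acc;  // Compute0 / EVTCompute0, accumulator precision
  return x_scale_m * stage0;       // Compute1, converted to the bf16 output on store
}
```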
@@ -192,7 +210,8 @@ at::Tensor _f8i4bf16_shuffled(
        layout_B_reordered,
        reinterpret_cast<ElementA*>(XQ.data_ptr()),
        stride_A,
-       reinterpret_cast<cutlass::Array<ElementScale, 8>*>(w_scale.data_ptr()),
+       reinterpret_cast<cutlass::Array<ElementScale, 8>*>(
+           w_scale_group.data_ptr()),
        stride_S,
        group_size},
       {{},
@@ -202,8 +221,14 @@ at::Tensor _f8i4bf16_shuffled(
        stride_C}};

   arguments.epilogue.thread = {
-      {reinterpret_cast<ElementAccumulator*>(x_scale.data_ptr())}, // x_scale
-      {}, // Accumulator
+      {reinterpret_cast<ElementAccumulator*>(w_scale.data_ptr())}, // w_scale
+      // compute_0
+      {
+          {reinterpret_cast<ElementAccumulator*>(
+              x_scale.data_ptr())}, // x_scale
+          {}, // Accumulator
+          {} // Multiplies
+      },
       {}, // Multiplies
   };

@@ -212,10 +237,11 @@ at::Tensor _f8i4bf16_shuffled(

   // Using the arguments, query for extra workspace required for matrix
   // multiplication computation
-  size_t workspace_size = GemmShuffled::get_workspace_size(arguments);
+  int workspace_size = GemmShuffled::get_workspace_size(arguments);

   // Allocate workspace memory
-  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
+  at::Tensor workspace =
+      at::empty(workspace_size, XQ.options().dtype(at::kByte));

   // Check the problem size is supported or not
   cutlass::Status status = gemm.can_implement(arguments);
@@ -224,7 +250,7 @@ at::Tensor _f8i4bf16_shuffled(
   }

   // Initialize CUTLASS kernel with arguments and workspace pointer
-  status = gemm.initialize(arguments, workspace.get());
+  status = gemm.initialize(arguments, workspace.data_ptr());
   if (status != cutlass::Status::kSuccess) {
     throw std::runtime_error("cutlass cannot initialize");
   }
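A side note on the workspace change above: scratch memory is now requested through `at::empty` over bytes and handed to CUTLASS as a raw pointer, so it is served by PyTorch's caching allocator instead of a standalone `cutlass::device_memory` allocation. A minimal sketch of that pattern, under the assumption that `Gemm` is any CUTLASS 3.x device adapter with the usual `get_workspace_size` / `can_implement` / `initialize` interface (names here are placeholders, not this kernel's types):

```cpp
// Sketch only: "Gemm" and "arguments" stand in for the kernel's real types.
size_t workspace_bytes = Gemm::get_workspace_size(arguments);
at::Tensor workspace = at::empty(
    {static_cast<int64_t>(workspace_bytes)},
    XQ.options().dtype(at::kByte));  // byte buffer on the same device as XQ

Gemm gemm;
TORCH_CHECK(gemm.can_implement(arguments) == cutlass::Status::kSuccess);
TORCH_CHECK(
    gemm.initialize(arguments, workspace.data_ptr()) ==
    cutlass::Status::kSuccess);
```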
@@ -245,54 +271,58 @@ at::Tensor f8i4bf16_shuffled(
     at::Tensor XQ,
     at::Tensor WQ,
     at::Tensor x_scale,
-    at::Tensor w_scale) {
+    at::Tensor w_scale,
+    at::Tensor w_scale_group) {
   int M = XQ.size(0);
   int K = XQ.size(1);
   int N = WQ.size(0);
   // Use shape heuristics to dispatch to optimized kernel configuration.
   if (M <= 16) {
-    return _f8i4bf16_shuffled<64, 16, 2, 1, 1, false>(XQ, WQ, x_scale, w_scale);
+    return _f8i4bf16_shuffled<64, 16, 2, 1, 1, false>(
+        XQ, WQ, x_scale, w_scale, w_scale_group);
   } else if (M <= 32) {
-    return _f8i4bf16_shuffled<64, 32, 2, 1, 1, false>(XQ, WQ, x_scale, w_scale);
+    return _f8i4bf16_shuffled<64, 32, 2, 1, 1, false>(
+        XQ, WQ, x_scale, w_scale, w_scale_group);
   } else if (M <= 64) {
-    return _f8i4bf16_shuffled<64, 64, 2, 1, 1, false>(XQ, WQ, x_scale, w_scale);
+    return _f8i4bf16_shuffled<64, 64, 2, 1, 1, false>(
+        XQ, WQ, x_scale, w_scale, w_scale_group);
   } else if (M <= 128) {
     return _f8i4bf16_shuffled<64, 128, 2, 1, 1, false>(
-        XQ, WQ, x_scale, w_scale);
+        XQ, WQ, x_scale, w_scale, w_scale_group);
   } else if (M <= 256) {
     if (N <= 4096) {
       return _f8i4bf16_shuffled<64, 128, 2, 1, 1, false>(
-          XQ, WQ, x_scale, w_scale);
+          XQ, WQ, x_scale, w_scale, w_scale_group);
     } else {
       return _f8i4bf16_shuffled<64, 256, 1, 1, 1, false>(
-          XQ, WQ, x_scale, w_scale);
+          XQ, WQ, x_scale, w_scale, w_scale_group);
     }
   } else if (M <= 512) {
     if (N <= 4096) {
       return _f8i4bf16_shuffled<64, 256, 2, 1, 1, false>(
-          XQ, WQ, x_scale, w_scale);
+          XQ, WQ, x_scale, w_scale, w_scale_group);
     } else {
       return _f8i4bf16_shuffled<128, 256, 2, 1, 1, true>(
-          XQ, WQ, x_scale, w_scale);
+          XQ, WQ, x_scale, w_scale, w_scale_group);
     }
   } else if (M <= 1024) {
     if (N <= 1024) {
       return _f8i4bf16_shuffled<64, 128, 2, 1, 1, false>(
-          XQ, WQ, x_scale, w_scale);
+          XQ, WQ, x_scale, w_scale, w_scale_group);
     } else if (N <= 2048) {
       return _f8i4bf16_shuffled<64, 256, 2, 1, 1, false>(
-          XQ, WQ, x_scale, w_scale);
+          XQ, WQ, x_scale, w_scale, w_scale_group);
     } else {
       return _f8i4bf16_shuffled<128, 256, 2, 1, 1, true>(
-          XQ, WQ, x_scale, w_scale);
+          XQ, WQ, x_scale, w_scale, w_scale_group);
     }
   } else {
     if (N <= 1024) {
       return _f8i4bf16_shuffled<64, 256, 2, 1, 1, false>(
-          XQ, WQ, x_scale, w_scale);
+          XQ, WQ, x_scale, w_scale, w_scale_group);
     } else {
       return _f8i4bf16_shuffled<128, 256, 2, 1, 1, true>(
-          XQ, WQ, x_scale, w_scale);
+          XQ, WQ, x_scale, w_scale, w_scale_group);
     }
   }
 }
@@ -303,7 +333,8 @@ at::Tensor f8i4bf16_shuffled(
     at::Tensor XQ,
     at::Tensor WQ,
     at::Tensor x_scale,
-    at::Tensor w_scale) {
+    at::Tensor w_scale,
+    at::Tensor w_scale_group) {
   throw std::runtime_error(
       "CUDA version is older than 12.0"); // requires CUDA>=12
 }
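Finally, a hedged sketch of how a caller would use the updated five-argument entry point. The shapes and dtypes marked "assumed" come from the preshuffle_i4 packing step, which is outside this diff; only the M/K/N bookkeeping, the `size(1) == 8` check, and the bf16 output are taken from the code above.

```cpp
// Hypothetical wrapper illustrating the new signature; not a verified contract.
at::Tensor quantized_linear_example(
    at::Tensor XQ,            // [M, K] FP8 activations (assumed e4m3)
    at::Tensor WQ,            // preshuffled int4-packed weights, WQ.size(0) == N
    at::Tensor x_scale,       // per-row activation scale (assumed fp32, length M)
    at::Tensor w_scale,       // per-column weight scale (assumed fp32, length N)
    at::Tensor w_scale_group) // [num_groups, 8] group dequant scales (checked above)
{
  // Dispatches on M/N to a tuned tile configuration; returns bf16 [M, N].
  return f8i4bf16_shuffled(XQ, WQ, x_scale, w_scale, w_scale_group);
}
```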