
Commit aec40d1

jwfromm authored and facebook-github-bot committed
FBGEMM Add Columnwise Weight Scaling to F8I4 GEMM (pytorch#3766)
Summary:
X-link: facebookresearch/FBGEMM#847

One of the interesting new changes in the preshuffled F8I4 kernel is that group scales are downcast to FP8. This risks dynamic range issues that can impact accuracy. We mitigate that risk by adding FP32 columnwise scaling to the output. Fortunately, we can do this using EVT, so the performance impact is negligible.

Reviewed By: jiawenliu64

Differential Revision: D70587477
1 parent b44b473 commit aec40d1

File tree

3 files changed: +77 -39 lines changed
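For context, here is a minimal sketch of the two-level weight scaling the summary describes: a per-row FP32 scale is factored out first (and later applied to the GEMM output columnwise via EVT), so the per-group scales that get stored in FP8 stay within its dynamic range. The helper name, group size of 128, and use of 448 as the E4M3 maximum are illustrative assumptions, not the FBGEMM implementation; it assumes PyTorch >= 2.1 for torch.float8_e4m3fn.

import torch

# Toy sketch only (assumed names and shapes; not the FBGEMM implementation).
def toy_two_level_quantize(w: torch.Tensor, group_size: int = 128):
    # Step 1: per-row FP32 scale, analogous to quantize_fp8_row. Keeping this
    # scale in FP32 preserves dynamic range; the kernel applies it to the
    # output in the epilogue.
    w_scale = w.abs().amax(dim=1, keepdim=True).float().clamp_min(1e-12) / 448.0
    w_fp8 = (w.float() / w_scale).to(torch.float8_e4m3fn).float()

    # Step 2: int4 group quantization of the already-normalized weights.
    # The group scales are now narrow enough to store safely in FP8.
    groups = w_fp8.reshape(w.shape[0], -1, group_size)
    w_scale_group = groups.abs().amax(dim=-1, keepdim=True).clamp_min(1e-12) / 7.0
    wq = (groups / w_scale_group).round().clamp(-8, 7).to(torch.int8)

    return (
        wq.reshape(w.shape),                    # int4 values stored in int8
        w_scale.squeeze(1),                     # FP32, one per weight row
        w_scale_group.to(torch.float8_e4m3fn),  # FP8, one per group
    )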

fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py

Lines changed: 14 additions & 8 deletions
@@ -1297,27 +1297,33 @@ def _int4_row_quantize(
         out = out.to(dtype=torch.int8).reshape(x.shape)
 
         # Scales should be in [num_groups, N] layout.
-        scales = scales.view(x.shape[0], -1).t().contiguous()
+        scales = scales.view(x.shape[0], -1).t().contiguous().to(torch.float8_e4m3fn)
 
         return out, scales
 
     def quantize(self, x, w):
         # Quantize both input tensors.
         xq, x_scale = quantize_fp8_row(x)
-        wq, w_scale = self._int4_row_quantize(w)
+        # Weight quantization happens in two steps. First we quantize to fp8
+        # then to int4.
+        wq, w_scale = quantize_fp8_row(w)
+        # Now quantize to int4 with group scaling.
+        wq, w_scale_group = self._int4_row_quantize(wq)
         # Pack int4 values together.
         wq = self._pack_int4(wq)
         # Shuffle weights and scales for faster compute.
-        wq, w_scale = torch.ops.fbgemm.preshuffle_i4(wq, w_scale)
-        return xq, wq, x_scale, w_scale
+        wq, w_scale_group = torch.ops.fbgemm.preshuffle_i4(wq, w_scale_group)
+        return xq, wq, x_scale, w_scale, w_scale_group
 
-    def compute(self, xq, wq, x_scale, w_scale):
-        out = torch.ops.fbgemm.f8i4bf16_shuffled(xq, wq, x_scale, w_scale)
+    def compute(self, xq, wq, x_scale, w_scale, w_scale_group):
+        out = torch.ops.fbgemm.f8i4bf16_shuffled(
+            xq, wq, x_scale, w_scale, w_scale_group
+        )
         return out
 
     def quantize_and_compute(self, x, w):
-        xq, wq, x_scale, w_scale = self.quantize(x, w)
-        return self.compute(xq, wq, x_scale, w_scale)
+        xq, wq, x_scale, w_scale, w_scale_group = self.quantize(x, w)
+        return self.compute(xq, wq, x_scale, w_scale, w_scale_group)
 
     @property
     def name(self) -> str:
fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8i4bf16_shuffled.cu

Lines changed: 60 additions & 29 deletions
@@ -33,16 +33,17 @@ at::Tensor _f8i4bf16_shuffled(
     at::Tensor XQ,
     at::Tensor WQ,
     at::Tensor x_scale,
-    at::Tensor w_scale) {
+    at::Tensor w_scale,
+    at::Tensor w_scale_group) {
   // Get shape information from input tensors.
   int M = XQ.size(0);
   int K = XQ.size(1);
   int N = WQ.size(0);
-  // Make sure w_scale is in proper format.
+  // Make sure w_scale_group is in proper format.
   TORCH_CHECK(
-      w_scale.size(1) == 8,
-      "Weights and scales must be prepacked with preshuffle_i4.");
-  int num_groups = w_scale.size(0);
+      w_scale_group.size(1) == 8,
+      "Weights and group scales must be prepacked with preshuffle_i4.");
+  int num_groups = w_scale_group.size(0);
   int group_size = K / num_groups;
   // Allocate output.
   at::Tensor Y = at::empty({M, N}, XQ.options().dtype(at::kBFloat16));
@@ -108,7 +109,15 @@ at::Tensor _f8i4bf16_shuffled(
   using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
 
   // Define EVT for rowwise scaling.
-  using XScale = cutlass::epilogue::fusion::Sm90RowBroadcast<
+  // Implement rowwise scaling epilogue.
+  using XScale = cutlass::epilogue::fusion::Sm90ColBroadcast<
+      0,
+      TileShape,
+      ElementAccumulator,
+      ElementAccumulator,
+      cute::Stride<cute::Int<1>, cute::Int<0>, cute::Int<0>>>;
+
+  using WScale = cutlass::epilogue::fusion::Sm90RowBroadcast<
       0,
       TileShape,
       ElementAccumulator,
@@ -119,12 +128,21 @@ at::Tensor _f8i4bf16_shuffled(
 
   using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
       cutlass::multiplies,
-      ElementC, // First stage output type.
+      ElementAccumulator, // First stage output type.
       ElementAccumulator, // First stage input types.
       cutlass::FloatRoundStyle::round_to_nearest>;
 
+  using EVTCompute0 =
+      cutlass::epilogue::fusion::Sm90EVT<Compute0, WScale, Accum>;
+
+  using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
+      cutlass::multiplies,
+      ElementC,
+      ElementAccumulator, // Second stage input types.
+      cutlass::FloatRoundStyle::round_to_nearest>;
+
   using EpilogueEVT =
-      cutlass::epilogue::fusion::Sm90EVT<Compute0, XScale, Accum>;
+      cutlass::epilogue::fusion::Sm90EVT<Compute1, XScale, EVTCompute0>;
 
   using CollectiveEpilogue =
       typename cutlass::epilogue::collective::CollectiveBuilder<
@@ -192,7 +210,8 @@ at::Tensor _f8i4bf16_shuffled(
       layout_B_reordered,
       reinterpret_cast<ElementA*>(XQ.data_ptr()),
       stride_A,
-      reinterpret_cast<cutlass::Array<ElementScale, 8>*>(w_scale.data_ptr()),
+      reinterpret_cast<cutlass::Array<ElementScale, 8>*>(
+          w_scale_group.data_ptr()),
       stride_S,
       group_size},
      {{},
@@ -202,8 +221,14 @@ at::Tensor _f8i4bf16_shuffled(
       stride_C}};
 
   arguments.epilogue.thread = {
-      {reinterpret_cast<ElementAccumulator*>(x_scale.data_ptr())}, // x_scale
-      {}, // Accumulator
+      {reinterpret_cast<ElementAccumulator*>(w_scale.data_ptr())}, // w_scale
+      // compute_0
+      {
+          {reinterpret_cast<ElementAccumulator*>(
+              x_scale.data_ptr())}, // x_scale
+          {}, // Accumulator
+          {} // Multiplies
+      },
       {}, // Multiplies
   };
 
@@ -212,10 +237,11 @@ at::Tensor _f8i4bf16_shuffled(
 
   // Using the arguments, query for extra workspace required for matrix
   // multiplication computation
-  size_t workspace_size = GemmShuffled::get_workspace_size(arguments);
+  int workspace_size = GemmShuffled::get_workspace_size(arguments);
 
   // Allocate workspace memory
-  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
+  at::Tensor workspace =
+      at::empty(workspace_size, XQ.options().dtype(at::kByte));
 
   // Check the problem size is supported or not
   cutlass::Status status = gemm.can_implement(arguments);
@@ -224,7 +250,7 @@ at::Tensor _f8i4bf16_shuffled(
   }
 
   // Initialize CUTLASS kernel with arguments and workspace pointer
-  status = gemm.initialize(arguments, workspace.get());
+  status = gemm.initialize(arguments, workspace.data_ptr());
   if (status != cutlass::Status::kSuccess) {
     throw std::runtime_error("cutlass cannot initialize");
   }
@@ -245,54 +271,58 @@ at::Tensor f8i4bf16_shuffled(
     at::Tensor XQ,
     at::Tensor WQ,
     at::Tensor x_scale,
-    at::Tensor w_scale) {
+    at::Tensor w_scale,
+    at::Tensor w_scale_group) {
   int M = XQ.size(0);
   int K = XQ.size(1);
   int N = WQ.size(0);
   // Use shape heuristics to dispatch to optimized kernel configuration.
   if (M <= 16) {
-    return _f8i4bf16_shuffled<64, 16, 2, 1, 1, false>(XQ, WQ, x_scale, w_scale);
+    return _f8i4bf16_shuffled<64, 16, 2, 1, 1, false>(
+        XQ, WQ, x_scale, w_scale, w_scale_group);
   } else if (M <= 32) {
-    return _f8i4bf16_shuffled<64, 32, 2, 1, 1, false>(XQ, WQ, x_scale, w_scale);
+    return _f8i4bf16_shuffled<64, 32, 2, 1, 1, false>(
+        XQ, WQ, x_scale, w_scale, w_scale_group);
   } else if (M <= 64) {
-    return _f8i4bf16_shuffled<64, 64, 2, 1, 1, false>(XQ, WQ, x_scale, w_scale);
+    return _f8i4bf16_shuffled<64, 64, 2, 1, 1, false>(
+        XQ, WQ, x_scale, w_scale, w_scale_group);
   } else if (M <= 128) {
     return _f8i4bf16_shuffled<64, 128, 2, 1, 1, false>(
-        XQ, WQ, x_scale, w_scale);
+        XQ, WQ, x_scale, w_scale, w_scale_group);
   } else if (M <= 256) {
     if (N <= 4096) {
       return _f8i4bf16_shuffled<64, 128, 2, 1, 1, false>(
-          XQ, WQ, x_scale, w_scale);
+          XQ, WQ, x_scale, w_scale, w_scale_group);
     } else {
       return _f8i4bf16_shuffled<64, 256, 1, 1, 1, false>(
-          XQ, WQ, x_scale, w_scale);
+          XQ, WQ, x_scale, w_scale, w_scale_group);
     }
   } else if (M <= 512) {
     if (N <= 4096) {
       return _f8i4bf16_shuffled<64, 256, 2, 1, 1, false>(
-          XQ, WQ, x_scale, w_scale);
+          XQ, WQ, x_scale, w_scale, w_scale_group);
     } else {
       return _f8i4bf16_shuffled<128, 256, 2, 1, 1, true>(
-          XQ, WQ, x_scale, w_scale);
+          XQ, WQ, x_scale, w_scale, w_scale_group);
     }
   } else if (M <= 1024) {
     if (N <= 1024) {
       return _f8i4bf16_shuffled<64, 128, 2, 1, 1, false>(
-          XQ, WQ, x_scale, w_scale);
+          XQ, WQ, x_scale, w_scale, w_scale_group);
    } else if (N <= 2048) {
      return _f8i4bf16_shuffled<64, 256, 2, 1, 1, false>(
-          XQ, WQ, x_scale, w_scale);
+          XQ, WQ, x_scale, w_scale, w_scale_group);
    } else {
      return _f8i4bf16_shuffled<128, 256, 2, 1, 1, true>(
-          XQ, WQ, x_scale, w_scale);
+          XQ, WQ, x_scale, w_scale, w_scale_group);
    }
  } else {
    if (N <= 1024) {
      return _f8i4bf16_shuffled<64, 256, 2, 1, 1, false>(
-          XQ, WQ, x_scale, w_scale);
+          XQ, WQ, x_scale, w_scale, w_scale_group);
    } else {
      return _f8i4bf16_shuffled<128, 256, 2, 1, 1, true>(
-          XQ, WQ, x_scale, w_scale);
+          XQ, WQ, x_scale, w_scale, w_scale_group);
    }
  }
}
@@ -303,7 +333,8 @@ at::Tensor f8i4bf16_shuffled(
     at::Tensor XQ,
     at::Tensor WQ,
     at::Tensor x_scale,
-    at::Tensor w_scale) {
+    at::Tensor w_scale,
+    at::Tensor w_scale_group) {
   throw std::runtime_error(
       "CUDA version is older than 12.0"); // requires CUDA>=12
 }
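Reading the EVT hunks together: the accumulator (which already has the FP8 group scales folded in by the mainloop) is multiplied first by the per-output-column weight scale (Sm90RowBroadcast) and then by the per-output-row activation scale (Sm90ColBroadcast), all in FP32, before the final cast to BF16. In index notation (notation mine, not from the source):

Y_{m,n} = \mathrm{bf16}\!\left(\text{x\_scale}_m \cdot \left(\text{w\_scale}_n \cdot \text{acc}_{m,n}\right)\right)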

fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp

Lines changed: 3 additions & 2 deletions
@@ -136,7 +136,8 @@ at::Tensor f8i4bf16_shuffled(
     at::Tensor XQ,
     at::Tensor WQ,
     at::Tensor x_scale,
-    at::Tensor w_scale);
+    at::Tensor w_scale,
+    at::Tensor w_scale_group);
 std::tuple<at::Tensor, at::Tensor> preshuffle_i4(
     at::Tensor WQ,
     at::Tensor w_scale);
@@ -198,7 +199,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   m.def(
       "f8i4bf16_rowwise(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale, Tensor w_zp) -> Tensor");
   m.def(
-      "f8i4bf16_shuffled(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale) -> Tensor");
+      "f8i4bf16_shuffled(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale, Tensor w_scale_group) -> Tensor");
   m.def("preshuffle_i4(Tensor WQ, Tensor w_scale) -> (Tensor, Tensor)");
   m.def("bf16_fast_gemv(Tensor X, Tensor W) -> Tensor");
   m.def("bf16fp8bf16_fast_gemv(Tensor X, Tensor W, Tensor w_scale) -> Tensor");
