
Commit 7eeaee8

YUNQIUGUO authored and facebook-github-bot committed
Add small M support (pytorch#3682)
Summary:
Pull Request resolved: pytorch#3682

X-link: facebookresearch/FBGEMM#758

Add small m (m = 2, 3, 4) support for fast gemv:
- bf16_fast_gemv [+]
- bf16fp8bf16_fast_gemv [+]
- fp8fp8bf16_fast_gemv [+]

**(v20 perf analysis from quantize_bench)**

| B | M | N | K | Kernel Name | Elapsed Time (ms) | TFLOPS | Bandwidth (GB/s) |
|---|---|---|---|-------------|-------------------|--------|------------------|
| 1 | 1 | 8192 | 1024 | bf16_baseline | 0.017 | 0.973 | 973.581 |
| 1 | 1 | 8192 | 1024 | fp8fp8_oss_fast_gemv | 0.013 | 1.251 | 626.711 |
| 1 | 1 | 8192 | 1024 | cuda_lite | 0.014 | 1.205 | 603.859 |
| 1 | 1 | 8192 | 1024 | marlin_bf16i4 | 0.014 | 1.189 | 298.625 |
| 1 | 1 | 8192 | 1024 | machete_bf16i4 | 0.014 | 1.189 | 298.669 |
| 1 | 2 | 8192 | 1024 | bf16_baseline | 0.017 | 1.963 | 983.820 |
| 1 | 2 | 8192 | 1024 | fp8fp8_oss_fast_gemv | 0.014 | 2.414 | 605.920 |
| 1 | 2 | 8192 | 1024 | cuda_lite | 0.014 | 2.379 | 597.311 |
| 1 | 2 | 8192 | 1024 | marlin_bf16i4 | 0.014 | 2.322 | 292.742 |
| 1 | 2 | 8192 | 1024 | machete_bf16i4 | 0.014 | 2.345 | 295.741 |
| 1 | 3 | 8192 | 1024 | bf16_baseline | 0.017 | 3.006 | 1005.276 |
| 1 | 3 | 8192 | 1024 | fp8fp8_oss_fast_gemv | 0.014 | 3.513 | 589.214 |
| 1 | 3 | 8192 | 1024 | cuda_lite | 0.015 | 3.381 | 566.948 |
| 1 | 3 | 8192 | 1024 | marlin_bf16i4 | 0.014 | 3.474 | 293.277 |
| 1 | 3 | 8192 | 1024 | machete_bf16i4 | 0.014 | 3.513 | 296.593 |
| 1 | 4 | 8192 | 1024 | bf16_baseline | 0.017 | 3.920 | 984.419 |
| 1 | 4 | 8192 | 1024 | fp8fp8_oss_fast_gemv | 0.015 | 4.466 | 562.896 |
| 1 | 4 | 8192 | 1024 | cuda_lite | 0.016 | 4.100 | 516.728 |
| 1 | 4 | 8192 | 1024 | marlin_bf16i4 | 0.014 | 4.629 | 294.426 |
| 1 | 4 | 8192 | 1024 | machete_bf16i4 | 0.014 | 4.792 | 304.764 |
| 1 | 1 | 8192 | 3584 | bf16_baseline | 0.044 | 1.327 | 1327.169 |
| 1 | 1 | 8192 | 3584 | fp8fp8_oss_fast_gemv | 0.026 | 2.283 | 1142.422 |
| 1 | 1 | 8192 | 3584 | cuda_lite | 0.026 | 2.298 | 1149.878 |
| 1 | 1 | 8192 | 3584 | marlin_bf16i4 | 0.020 | 2.877 | 720.408 |
| 1 | 1 | 8192 | 3584 | machete_bf16i4 | 0.024 | 2.468 | 617.894 |
| 1 | 2 | 8192 | 3584 | bf16_baseline | 0.044 | 2.675 | 1338.550 |
| 1 | 2 | 8192 | 3584 | fp8fp8_oss_fast_gemv | 0.026 | 4.512 | 1129.580 |
| 1 | 2 | 8192 | 3584 | cuda_lite | 0.026 | 4.515 | 1130.280 |
| 1 | 2 | 8192 | 3584 | marlin_bf16i4 | 0.020 | 5.743 | 720.190 |
| 1 | 2 | 8192 | 3584 | machete_bf16i4 | 0.024 | 4.911 | 615.829 |
| 1 | 3 | 8192 | 3584 | bf16_baseline | 0.044 | 4.014 | 1339.480 |
| 1 | 3 | 8192 | 3584 | fp8fp8_oss_fast_gemv | 0.028 | 6.391 | 1067.367 |
| 1 | 3 | 8192 | 3584 | cuda_lite | 0.027 | 6.471 | 1080.655 |
| 1 | 3 | 8192 | 3584 | marlin_bf16i4 | 0.020 | 8.606 | 720.622 |
| 1 | 3 | 8192 | 3584 | machete_bf16i4 | 0.024 | 7.366 | 616.763 |
| 1 | 4 | 8192 | 3584 | bf16_baseline | 0.044 | 5.350 | 1339.637 |
| 1 | 4 | 8192 | 3584 | fp8fp8_oss_fast_gemv | 0.028 | 8.275 | 1037.158 |
| 1 | 4 | 8192 | 3584 | cuda_lite | 0.029 | 8.063 | 1010.621 |
| 1 | 4 | 8192 | 3584 | marlin_bf16i4 | 0.020 | 11.460 | 720.846 |
| 1 | 4 | 8192 | 3584 | machete_bf16i4 | 0.024 | 9.911 | 623.402 |
| 1 | 1 | 1280 | 8192 | bf16_baseline | 0.024 | 0.872 | 872.425 |
| 1 | 1 | 1280 | 8192 | fp8fp8_oss_fast_gemv | 0.015 | 1.403 | 702.176 |
| 1 | 1 | 1280 | 8192 | cuda_lite | 0.015 | 1.421 | 711.264 |
| 1 | 1 | 1280 | 8192 | marlin_bf16i4 | 0.027 | 0.779 | 195.420 |
| 1 | 1 | 1280 | 8192 | machete_bf16i4 | 0.025 | 0.837 | 209.928 |
| 1 | 2 | 1280 | 8192 | bf16_baseline | 0.024 | 1.737 | 870.022 |
| 1 | 2 | 1280 | 8192 | fp8fp8_oss_fast_gemv | 0.015 | 2.760 | 691.374 |
| 1 | 2 | 1280 | 8192 | cuda_lite | 0.015 | 2.836 | 710.432 |
| 1 | 2 | 1280 | 8192 | marlin_bf16i4 | 0.027 | 1.558 | 196.179 |
| 1 | 2 | 1280 | 8192 | machete_bf16i4 | 0.026 | 1.624 | 204.431 |
| 1 | 3 | 1280 | 8192 | bf16_baseline | 0.024 | 2.594 | 866.953 |
| 1 | 3 | 1280 | 8192 | fp8fp8_oss_fast_gemv | 0.015 | 4.094 | 684.375 |
| 1 | 3 | 1280 | 8192 | cuda_lite | 0.015 | 4.167 | 696.571 |
| 1 | 3 | 1280 | 8192 | marlin_bf16i4 | 0.027 | 2.327 | 196.054 |
| 1 | 3 | 1280 | 8192 | machete_bf16i4 | 0.026 | 2.458 | 207.069 |
| 1 | 4 | 1280 | 8192 | bf16_baseline | 0.024 | 3.458 | 867.559 |
| 1 | 4 | 1280 | 8192 | fp8fp8_oss_fast_gemv | 0.015 | 5.414 | 679.479 |
| 1 | 4 | 1280 | 8192 | cuda_lite | 0.016 | 5.408 | 678.758 |
| 1 | 4 | 1280 | 8192 | marlin_bf16i4 | 0.027 | 3.069 | 194.570 |
| 1 | 4 | 1280 | 8192 | machete_bf16i4 | 0.025 | 3.321 | 210.571 |
| 1 | 1 | 7168 | 8192 | bf16_baseline | 0.073 | 1.612 | 1612.302 |
| 1 | 1 | 7168 | 8192 | fp8fp8_oss_fast_gemv | 0.043 | 2.752 | 1376.396 |
| 1 | 1 | 7168 | 8192 | cuda_lite | 0.044 | 2.685 | 1342.856 |
| 1 | 1 | 7168 | 8192 | marlin_bf16i4 | 0.033 | 3.580 | 896.051 |
| 1 | 1 | 7168 | 8192 | machete_bf16i4 | 0.039 | 3.019 | 755.510 |
| 1 | 2 | 7168 | 8192 | bf16_baseline | 0.073 | 3.227 | 1614.307 |
| 1 | 2 | 7168 | 8192 | fp8fp8_oss_fast_gemv | 0.043 | 5.430 | 1358.541 |
| 1 | 2 | 7168 | 8192 | cuda_lite | 0.044 | 5.324 | 1332.114 |
| 1 | 2 | 7168 | 8192 | marlin_bf16i4 | 0.033 | 7.214 | 903.651 |
| 1 | 2 | 7168 | 8192 | machete_bf16i4 | 0.039 | 6.091 | 763.029 |
| 1 | 3 | 7168 | 8192 | bf16_baseline | 0.072 | 4.863 | 1622.296 |
| 1 | 3 | 7168 | 8192 | fp8fp8_oss_fast_gemv | 0.044 | 7.949 | 1326.423 |
| 1 | 3 | 7168 | 8192 | cuda_lite | 0.044 | 7.954 | 1327.215 |
| 1 | 3 | 7168 | 8192 | marlin_bf16i4 | 0.033 | 10.816 | 904.127 |
| 1 | 3 | 7168 | 8192 | machete_bf16i4 | 0.038 | 9.172 | 766.770 |
| 1 | 4 | 7168 | 8192 | bf16_baseline | 0.073 | 6.452 | 1614.684 |
| 1 | 4 | 7168 | 8192 | fp8fp8_oss_fast_gemv | 0.046 | 10.219 | 1279.299 |
| 1 | 4 | 7168 | 8192 | cuda_lite | 0.047 | 9.904 | 1239.944 |
| 1 | 4 | 7168 | 8192 | marlin_bf16i4 | 0.033 | 14.345 | 900.287 |
| 1 | 4 | 7168 | 8192 | machete_bf16i4 | 0.039 | 12.128 | 761.147 |

Reviewed By: ipiszy

Differential Revision: D69492556

fbshipit-source-id: 408b8325469f6993e7861064ecccce44bd5e3cf7
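For readers cross-checking the table above, here is a minimal sketch of how the reported figures can be reproduced from a row's shape and elapsed time. It assumes the benchmark counts 2·M·N·K FLOPs per GEMV and that traffic is dominated by the N×K weight matrix (2 bytes per element for the bf16 baseline); these formulas and the helper are illustrative assumptions, not the quantize_bench implementation.

```cpp
#include <cstdio>

// Hypothetical helper: recompute TFLOPS and GB/s for one benchmark row.
// Assumes 2*M*N*K FLOPs and weight-dominated traffic (2 bytes/element for
// bf16); quantized kernels move fewer bytes per element, so their bandwidth
// column is correspondingly lower.
void report(double m, double n, double k, double elapsed_ms) {
  const double seconds = elapsed_ms * 1e-3;
  const double tflops = 2.0 * m * n * k / seconds / 1e12;
  const double gbps = n * k * 2.0 / seconds / 1e9;  // bf16 weights only
  std::printf("TFLOPS=%.3f  BW=%.1f GB/s\n", tflops, gbps);
}

int main() {
  // Row: B=1, M=1, N=8192, K=1024, bf16_baseline, 0.017 ms
  // -> roughly 0.99 TFLOPS and ~990 GB/s, in line with the table's
  //    0.973 TFLOPS / 973.6 GB/s (the printed time is rounded).
  report(1, 8192, 1024, 0.017);
  return 0;
}
```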
1 parent 49d6314 commit 7eeaee8

7 files changed: +587 −267 lines changed

fbgemm_gpu/experimental/gen_ai/src/quantize/fast_gemv/bf16_fast_gemv.cu

Lines changed: 31 additions & 3 deletions
@@ -21,20 +21,46 @@ namespace fbgemm_gpu {
 // problem sizes we care about and selected the best elapsed time/bw
 // combination. See more in
 // deeplearning/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/fast_gemv/sweep_utils.py
+namespace {
 dim3 get_best_block_dim(int m, int n, int k) {
   if (m == 1 && n == 1280 && k == 8192) {
-    return dim3(128, 4);
+    return dim3(128, 2);
   } else if (m == 1 && n == 8192 && k == 1024) {
-    return dim3(32, 8);
+    return dim3(64, 2);
   } else if (m == 1 && n == 7168 && k == 8192) {
-    return dim3(256, 1);
+    return dim3(128, 1);
   } else if (m == 1 && n == 8192 && k == 3584) {
     return dim3(64, 2);
+  } else if (m == 2 && n == 1280 && k == 8192) {
+    return dim3(256, 1);
+  } else if (m == 2 && n == 8192 && k == 1024) {
+    return dim3(64, 2);
+  } else if (m == 2 && n == 7168 && k == 8192) {
+    return dim3(256, 1);
+  } else if (m == 2 && n == 8192 && k == 3584) {
+    return dim3(64, 2);
+  } else if (m == 3 && n == 1280 && k == 8192) {
+    return dim3(256, 1);
+  } else if (m == 3 && n == 8192 && k == 1024) {
+    return dim3(64, 2);
+  } else if (m == 3 && n == 7168 && k == 8192) {
+    return dim3(256, 1);
+  } else if (m == 3 && n == 8192 && k == 3584) {
+    return dim3(64, 2);
+  } else if (m == 4 && n == 1280 && k == 8192) {
+    return dim3(256, 1);
+  } else if (m == 4 && n == 8192 && k == 1024) {
+    return dim3(64, 2);
+  } else if (m == 4 && n == 7168 && k == 8192) {
+    return dim3(128, 1);
+  } else if (m == 4 && n == 8192 && k == 3584) {
+    return dim3(64, 2);
   } else {
     // Default block dimensions
     return dim3(32, 4);
   }
 }
+} // namespace

 at::Tensor bf16_fast_gemv(at::Tensor X, at::Tensor W) {
   // X: M x K
@@ -62,6 +88,8 @@ at::Tensor bf16_fast_gemv(at::Tensor X, at::Tensor W) {
       reinterpret_cast<__nv_bfloat16*>(X.data_ptr()), // vec
       reinterpret_cast<__nv_bfloat16*>(Y.data_ptr()), // res
       k,
+      m,
+      n,
       num_per_thread);

   C10_CUDA_KERNEL_LAUNCH_CHECK();

fbgemm_gpu/experimental/gen_ai/src/quantize/fast_gemv/bf16fp8bf16_fast_gemv.cu

Lines changed: 28 additions & 2 deletions
@@ -24,13 +24,37 @@ namespace fbgemm_gpu {
 namespace {
 dim3 get_best_block_dim(int m, int n, int k) {
   if (m == 1 && n == 1280 && k == 8192) {
-    return dim3(128, 1);
+    return dim3(128, 2);
   } else if (m == 1 && n == 8192 && k == 1024) {
-    return dim3(32, 4);
+    return dim3(32, 8);
   } else if (m == 1 && n == 7168 && k == 8192) {
     return dim3(128, 1);
   } else if (m == 1 && n == 8192 && k == 3584) {
     return dim3(64, 2);
+  } else if (m == 2 && n == 1280 && k == 8192) {
+    return dim3(128, 1);
+  } else if (m == 2 && n == 8192 && k == 1024) {
+    return dim3(32, 8);
+  } else if (m == 2 && n == 7168 && k == 8192) {
+    return dim3(128, 1);
+  } else if (m == 2 && n == 8192 && k == 3584) {
+    return dim3(64, 2);
+  } else if (m == 3 && n == 1280 && k == 8192) {
+    return dim3(128, 2);
+  } else if (m == 3 && n == 8192 && k == 1024) {
+    return dim3(32, 8);
+  } else if (m == 3 && n == 7168 && k == 8192) {
+    return dim3(128, 1);
+  } else if (m == 3 && n == 8192 && k == 3584) {
+    return dim3(64, 2);
+  } else if (m == 4 && n == 1280 && k == 8192) {
+    return dim3(128, 2);
+  } else if (m == 4 && n == 8192 && k == 1024) {
+    return dim3(32, 8);
+  } else if (m == 4 && n == 7168 && k == 8192) {
+    return dim3(128, 1);
+  } else if (m == 4 && n == 8192 && k == 3584) {
+    return dim3(64, 2);
   } else {
     // Default block dimensions
     return dim3(32, 4);
@@ -65,6 +89,8 @@ bf16fp8bf16_fast_gemv(at::Tensor X, at::Tensor W, at::Tensor w_scale) {
       reinterpret_cast<__nv_bfloat16*>(X.data_ptr()), // vec
       reinterpret_cast<__nv_bfloat16*>(Y.data_ptr()), // res
       k,
+      m,
+      n,
       reinterpret_cast<float const*>(w_scale.data_ptr()),
       num_per_thread);

fbgemm_gpu/experimental/gen_ai/src/quantize/fast_gemv/fp8fp8bf16_fast_gemv.cu

Lines changed: 127 additions & 25 deletions
@@ -17,57 +17,159 @@

 namespace fbgemm_gpu {

+using SizeType32 = std::size_t;
+
 // The heuristics are derived by sweeping over 4 different
 // problem sizes we care about and selected the best elapsed time/bw
 // combination. See more in
 // deeplearning/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/fast_gemv/sweep_utils.py
 namespace {
 dim3 get_best_block_dim(int m, int n, int k) {
   if (m == 1 && n == 1280 && k == 8192) {
-    return dim3(128, 1);
+    return dim3(256, 1);
   } else if (m == 1 && n == 8192 && k == 1024) {
-    return dim3(32, 32);
-  } else if (m == 1 && n == 7168 && k == 8192) {
     return dim3(128, 1);
+  } else if (m == 1 && n == 7168 && k == 8192) {
+    return dim3(256, 1);
   } else if (m == 1 && n == 8192 && k == 3584) {
-    return dim3(64, 2);
+    return dim3(128, 1);
+  } else if (m == 2 && n == 1280 && k == 8192) {
+    return dim3(128, 1);
+  } else if (m == 2 && n == 8192 && k == 1024) {
+    return dim3(64, 1);
+  } else if (m == 2 && n == 7168 && k == 8192) {
+    return dim3(256, 1);
+  } else if (m == 2 && n == 8192 && k == 3584) {
+    return dim3(128, 1);
+  } else if (m == 3 && n == 1280 && k == 8192) {
+    return dim3(128, 1);
+  } else if (m == 3 && n == 8192 && k == 1024) {
+    return dim3(64, 1);
+  } else if (m == 3 && n == 7168 && k == 8192) {
+    return dim3(128, 1);
+  } else if (m == 3 && n == 8192 && k == 3584) {
+    return dim3(128, 1);
+  } else if (m == 4 && n == 1280 && k == 8192) {
+    return dim3(128, 1);
+  } else if (m == 4 && n == 8192 && k == 1024) {
+    return dim3(64, 1);
+  } else if (m == 4 && n == 7168 && k == 8192) {
+    return dim3(128, 1);
+  } else if (m == 4 && n == 8192 && k == 3584) {
+    return dim3(128, 1);
   } else {
     // Default block dimensions
-    return dim3(32, 4);
+    return dim3(32, 1);
   }
 }
 } // namespace

-at::Tensor fp8fp8bf16_fast_gemv(at::Tensor X, at::Tensor W, at::Tensor scale) {
-  // X: M x K
-  // W: N x K
-  auto m = X.size(0);
-  auto n = W.size(0);
-  auto k = W.size(1);
+template <SizeType32 TILE_M, SizeType32 TILE_N>
+void fp8fp8FastGemvKernel(
+    cutlass::float_e4m3_t* mat,
+    cutlass::float_e4m3_t* vec,
+    __nv_bfloat16* res,
+    const unsigned int k,
+    const unsigned int m,
+    const unsigned int n,
+    float const* scale) {
+  // each threadblock handles TILE_M * TILE_N dot products in the resulting
+  // matrix.
+  // block_size is represented as (block_dim.x, block_dim.y).
+  // grid_dim is accordingly calculated based on the number of threadblocks
+  // needed to cover the given problem size
+  dim3 block_dim = get_best_block_dim(m, n, k);
+  dim3 grid_dim(m / TILE_M, n / TILE_N * block_dim.y);
+  // total number of memory loads needed per thread
+  unsigned int num_iter_per_thread = ((k >> 4) + block_dim.x - 1) / block_dim.x;
+
+  check_if_valid_input_dimensions_fp8fp8(m, n, k, TILE_N, block_dim);

-  TORCH_CHECK(X.is_cuda() && X.is_contiguous());
-  TORCH_CHECK(W.is_cuda() && W.is_contiguous());
+  auto stream = at::cuda::getCurrentCUDAStream();

-  dim3 block_dim = get_best_block_dim(m, n, k);
+  if (block_dim.x == 128) {
+    gemv_quantized_fp8_fp8<TILE_M, TILE_N, 128>
+        <<<grid_dim, block_dim, 0, stream>>>(
+            mat, vec, res, k, m, n, scale, num_iter_per_thread);
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
+  } else if (block_dim.x == 64) {
+    gemv_quantized_fp8_fp8<TILE_M, TILE_N, 64>
+        <<<grid_dim, block_dim, 0, stream>>>(
+            mat, vec, res, k, m, n, scale, num_iter_per_thread);
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
+  } else if (block_dim.x == 256) {
+    gemv_quantized_fp8_fp8<TILE_M, TILE_N, 256>
+        <<<grid_dim, block_dim, 0, stream>>>(
+            mat, vec, res, k, m, n, scale, num_iter_per_thread);
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
+  } else {
+    gemv_quantized_fp8_fp8<TILE_M, TILE_N, 32>
+        <<<grid_dim, block_dim, 0, stream>>>(
+            mat, vec, res, k, m, n, scale, num_iter_per_thread);
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
+  }
+}
+
+template <SizeType32 TILE_M, SizeType32 TILE_N>
+bool fastGemvTemplateCaller(
+    cutlass::float_e4m3_t* mat,
+    cutlass::float_e4m3_t* vec,
+    __nv_bfloat16* res,
+    const unsigned int k,
+    const unsigned int m,
+    const unsigned int n,
+    float const* scale) {
+  if (m == TILE_M) {
+    fp8fp8FastGemvKernel<TILE_M, TILE_N>(mat, vec, res, k, m, n, scale);
+    return true;
+  }

-  check_if_valid_block_dimensions(m, n, k, block_dim);
+  if constexpr (TILE_M < MAX_M_SIZE) {
+    return fastGemvTemplateCaller<TILE_M + 1, TILE_N>(
+        mat, vec, res, k, m, n, scale);
+  }
+  return false;
+}

-  dim3 grid_dim(1, n / block_dim.y);
-  unsigned int num_per_thread = k / block_dim.x;
+bool fastGemvLauncher(
+    cutlass::float_e4m3_t* mat,
+    cutlass::float_e4m3_t* vec,
+    __nv_bfloat16* res,
+    const unsigned int k,
+    const unsigned int m,
+    const unsigned int n,
+    float const* scale) {
+  // Note: based on sweeping result, heuristic TILE_N = 2 here gives best
+  // performance over larger TILE_N value. this is potentially because smaller
+  // tile_n leads to more threadblocks and thus increase the block concurrency.
+  return fastGemvTemplateCaller</* TILE_M=*/1, /* TILE_N=*/2>(
+      mat, vec, res, k, m, n, scale);
+}

-  auto stream = at::cuda::getCurrentCUDAStream();
+at::Tensor
+fp8fp8bf16_fast_gemv(at::Tensor XQ, at::Tensor WQ, at::Tensor scale) {
+  const unsigned int m = XQ.size(0);
+  const unsigned int n = WQ.size(0);
+  const unsigned int k = WQ.size(1);
+
+  TORCH_CHECK(XQ.is_cuda() && XQ.is_contiguous());
+  TORCH_CHECK(WQ.is_cuda() && WQ.is_contiguous());
+  TORCH_CHECK(XQ.size(-1) == k);

-  auto Y = at::empty({m, n}, X.options().dtype(at::kBFloat16));
+  auto Y = at::empty({m, n}, XQ.options().dtype(at::kBFloat16));

-  gemv_quantized_fp8_fp8<<<grid_dim, block_dim, 0, stream>>>(
-      reinterpret_cast<cutlass::float_e4m3_t*>(W.data_ptr()), // mat
-      reinterpret_cast<cutlass::float_e4m3_t*>(X.data_ptr()), // vec
+  bool dispatched = fastGemvLauncher(
+      reinterpret_cast<cutlass::float_e4m3_t*>(WQ.data_ptr()), // mat
+      reinterpret_cast<cutlass::float_e4m3_t*>(XQ.data_ptr()), // vec
       reinterpret_cast<__nv_bfloat16*>(Y.data_ptr()), // res
       k,
-      reinterpret_cast<float const*>(scale.data_ptr()),
-      num_per_thread);
+      m,
+      n,
+      reinterpret_cast<float const*>(scale.data_ptr()));

-  C10_CUDA_KERNEL_LAUNCH_CHECK();
+  if (!dispatched) {
+    throw std::runtime_error("f8f8bf16_fast_gemv cannot run.");
+  }

   return Y;
 }

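The new fastGemvTemplateCaller resolves the runtime m into a compile-time TILE_M by recursing with `if constexpr` until the two match (or the recursion runs past MAX_M_SIZE). Below is a stripped-down, self-contained sketch of that dispatch pattern; `launch` is a stand-in for the real kernel launch and `kMaxM = 4` is an assumption standing in for MAX_M_SIZE (the validation in common_utils.h caps m at 4).

```cpp
#include <cstddef>
#include <cstdio>

constexpr std::size_t kMaxM = 4;  // stand-in for MAX_M_SIZE in the real code

// Stand-in for fp8fp8FastGemvKernel: just records which TILE_M was
// instantiated at compile time.
template <std::size_t TILE_M, std::size_t TILE_N>
void launch(unsigned m, unsigned n, unsigned k) {
  std::printf("launch<TILE_M=%zu, TILE_N=%zu> for m=%u n=%u k=%u\n",
              TILE_M, TILE_N, m, n, k);
}

// Same shape as fastGemvTemplateCaller: use TILE_M if it equals m, otherwise
// try the next TILE_M; `if constexpr` stops the template recursion at kMaxM.
template <std::size_t TILE_M, std::size_t TILE_N>
bool dispatch(unsigned m, unsigned n, unsigned k) {
  if (m == TILE_M) {
    launch<TILE_M, TILE_N>(m, n, k);
    return true;
  }
  if constexpr (TILE_M < kMaxM) {
    return dispatch<TILE_M + 1, TILE_N>(m, n, k);
  }
  return false;  // m > kMaxM: no instantiation matches
}

int main() {
  for (unsigned m = 1; m <= 5; ++m) {
    if (!dispatch</*TILE_M=*/1, /*TILE_N=*/2>(m, 8192, 1024)) {
      std::printf("m=%u not supported (must be <= %zu)\n", m, kMaxM);
    }
  }
  return 0;
}
```

Because TILE_M equals m at the point of launch, `grid_dim.x = m / TILE_M` in fp8fp8FastGemvKernel collapses to 1, so the small-M batching is absorbed into each threadblock tile rather than into extra grid dimensions.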
fbgemm_gpu/experimental/gen_ai/src/quantize/fast_gemv/include/common_utils.h

Lines changed: 79 additions & 0 deletions
@@ -12,6 +12,8 @@ namespace fbgemm_gpu {

 namespace {

+using SizeType32 = std::size_t;
+
 void check_if_valid_block_dimensions(int m, int n, int k, dim3 block_dim) {
   TORCH_CHECK(
       n % block_dim.y == 0,
@@ -82,5 +84,82 @@ void check_if_valid_block_dimensions(int m, int n, int k, dim3 block_dim) {
       block_dim.y,
       ".");
 }
+void check_if_valid_input_dimensions_fp8fp8(
+    int m,
+    int n,
+    int k,
+    SizeType32 TILE_N,
+    dim3 block_dim) {
+  TORCH_CHECK(
+      m <= 4,
+      "Invalid value for m: m (",
+      m,
+      ") must not be greater than 4. The kernel cannot be run with the current value of m."
+      " Please use an `m` smaller or equal to 4.");
+  TORCH_CHECK(
+      k % 16 == 0,
+      "Invalid value for k: (",
+      k,
+      ") must be divisible by 16.",
+      " Please use a `k` that is divisble by 16, "
+      " All current params - m: ",
+      m,
+      ", n: ",
+      n,
+      ", k: ",
+      k,
+      ", block_dim.x: ",
+      block_dim.x,
+      ", block_dim.y: ",
+      block_dim.y,
+      ".");
+  TORCH_CHECK(
+      k % block_dim.x == 0,
+      "Invalid block dimensions: k (",
+      k,
+      ") must be divisible by block_dim.x (",
+      block_dim.x,
+      "). Received k: ",
+      k,
+      ", block_dim.x: ",
+      block_dim.x,
+      " Please either use a `k` which is divisible by `block_dim.x`, or update "
+      "`get_best_block_dim()` heuristics to choose another `block_dim.x`."
+      " All current params - m: ",
+      m,
+      ", n: ",
+      n,
+      ", k: ",
+      k,
+      ", block_dim.x: ",
+      block_dim.x,
+      ", block_dim.y: ",
+      block_dim.y,
+      ".");
+  TORCH_CHECK(
+      n % (TILE_N * block_dim.y) == 0,
+      "Invalid block dimensions: n (",
+      n,
+      ") must be divisible by TILE_N * block_dim.y (",
+      TILE_N * block_dim.y,
+      "). Received n: ",
+      n,
+      ", block_dim.y: ",
+      block_dim.y,
+      ", TILE_N: ",
+      TILE_N,
+      " Please use a `n` which is divisible by `TILE_N * block_dim.y`,"
+      " All current params - m: ",
+      m,
+      ", n: ",
+      n,
+      ", k: ",
+      k,
+      ", block_dim.x: ",
+      block_dim.x,
+      ", block_dim.y: ",
+      block_dim.y,
+      ".");
+}
 } // namespace
 } // namespace fbgemm_gpu

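The new check_if_valid_input_dimensions_fp8fp8 gate means only certain shapes ever reach the fp8 kernel. A small host-side mirror of those conditions is sketched below; it is illustrative only (the real checks are the TORCH_CHECK calls above, and block_dim and TILE_N come from the heuristics and launcher, not from user code).

```cpp
#include <cstdio>

// Mirrors the conditions enforced by check_if_valid_input_dimensions_fp8fp8:
// m <= 4, k divisible by 16, k divisible by block_dim.x, and n divisible by
// TILE_N * block_dim.y. Illustrative only; not part of the committed code.
bool shape_is_valid(int m, int n, int k,
                    unsigned block_dim_x, unsigned block_dim_y,
                    unsigned tile_n) {
  return m <= 4 &&
         k % 16 == 0 &&
         k % static_cast<int>(block_dim_x) == 0 &&
         n % static_cast<int>(tile_n * block_dim_y) == 0;
}

int main() {
  // m=4, n=8192, k=3584 with the heuristic block_dim (128, 1) and TILE_N=2.
  std::printf("%s\n", shape_is_valid(4, 8192, 3584, 128, 1, 2) ? "ok" : "rejected");
  // k not divisible by 16 is rejected regardless of the block shape.
  std::printf("%s\n", shape_is_valid(2, 8192, 1000, 64, 1, 2) ? "ok" : "rejected");
  return 0;
}
```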