 namespace fbgemm_gpu {
 
+using SizeType32 = std::size_t;
+
 // The heuristics are derived by sweeping over 4 different
 // problem sizes we care about and selected the best elapsed time/bw
 // combination. See more in
 // deeplearning/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/fast_gemv/sweep_utils.py
 namespace {
 dim3 get_best_block_dim(int m, int n, int k) {
   if (m == 1 && n == 1280 && k == 8192) {
-    return dim3(128, 1);
+    return dim3(256, 1);
   } else if (m == 1 && n == 8192 && k == 1024) {
-    return dim3(32, 32);
-  } else if (m == 1 && n == 7168 && k == 8192) {
     return dim3(128, 1);
+  } else if (m == 1 && n == 7168 && k == 8192) {
+    return dim3(256, 1);
   } else if (m == 1 && n == 8192 && k == 3584) {
-    return dim3(64, 2);
+    return dim3(128, 1);
+  } else if (m == 2 && n == 1280 && k == 8192) {
+    return dim3(128, 1);
+  } else if (m == 2 && n == 8192 && k == 1024) {
+    return dim3(64, 1);
+  } else if (m == 2 && n == 7168 && k == 8192) {
+    return dim3(256, 1);
+  } else if (m == 2 && n == 8192 && k == 3584) {
+    return dim3(128, 1);
+  } else if (m == 3 && n == 1280 && k == 8192) {
+    return dim3(128, 1);
+  } else if (m == 3 && n == 8192 && k == 1024) {
+    return dim3(64, 1);
+  } else if (m == 3 && n == 7168 && k == 8192) {
+    return dim3(128, 1);
+  } else if (m == 3 && n == 8192 && k == 3584) {
+    return dim3(128, 1);
+  } else if (m == 4 && n == 1280 && k == 8192) {
+    return dim3(128, 1);
+  } else if (m == 4 && n == 8192 && k == 1024) {
+    return dim3(64, 1);
+  } else if (m == 4 && n == 7168 && k == 8192) {
+    return dim3(128, 1);
+  } else if (m == 4 && n == 8192 && k == 3584) {
+    return dim3(128, 1);
   } else {
     // Default block dimensions
-    return dim3(32, 4);
+    return dim3(32, 1);
   }
 }
 } // namespace
 
-at::Tensor fp8fp8bf16_fast_gemv(at::Tensor X, at::Tensor W, at::Tensor scale) {
-  // X: M x K
-  // W: N x K
-  auto m = X.size(0);
-  auto n = W.size(0);
-  auto k = W.size(1);
+template <SizeType32 TILE_M, SizeType32 TILE_N>
+void fp8fp8FastGemvKernel(
+    cutlass::float_e4m3_t* mat,
+    cutlass::float_e4m3_t* vec,
+    __nv_bfloat16* res,
+    const unsigned int k,
+    const unsigned int m,
+    const unsigned int n,
+    float const* scale) {
+  // Each threadblock handles TILE_M * TILE_N dot products in the resulting
+  // matrix.
+  // block_size is represented as (block_dim.x, block_dim.y); grid_dim is
+  // calculated accordingly, based on the number of threadblocks needed to
+  // cover the given problem size.
+  dim3 block_dim = get_best_block_dim(m, n, k);
+  dim3 grid_dim(m / TILE_M, n / TILE_N * block_dim.y);
+  // total number of memory loads needed per thread
+  unsigned int num_iter_per_thread = ((k >> 4) + block_dim.x - 1) / block_dim.x;
+
+  check_if_valid_input_dimensions_fp8fp8(m, n, k, TILE_N, block_dim);
 
-  TORCH_CHECK(X.is_cuda() && X.is_contiguous());
-  TORCH_CHECK(W.is_cuda() && W.is_contiguous());
+  auto stream = at::cuda::getCurrentCUDAStream();
 
-  dim3 block_dim = get_best_block_dim(m, n, k);
+  if (block_dim.x == 128) {
+    gemv_quantized_fp8_fp8<TILE_M, TILE_N, 128>
+        <<<grid_dim, block_dim, 0, stream>>>(
+            mat, vec, res, k, m, n, scale, num_iter_per_thread);
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
+  } else if (block_dim.x == 64) {
+    gemv_quantized_fp8_fp8<TILE_M, TILE_N, 64>
+        <<<grid_dim, block_dim, 0, stream>>>(
+            mat, vec, res, k, m, n, scale, num_iter_per_thread);
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
+  } else if (block_dim.x == 256) {
+    gemv_quantized_fp8_fp8<TILE_M, TILE_N, 256>
+        <<<grid_dim, block_dim, 0, stream>>>(
+            mat, vec, res, k, m, n, scale, num_iter_per_thread);
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
+  } else {
+    gemv_quantized_fp8_fp8<TILE_M, TILE_N, 32>
+        <<<grid_dim, block_dim, 0, stream>>>(
+            mat, vec, res, k, m, n, scale, num_iter_per_thread);
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
+  }
+}
+
+template <SizeType32 TILE_M, SizeType32 TILE_N>
+bool fastGemvTemplateCaller(
+    cutlass::float_e4m3_t* mat,
+    cutlass::float_e4m3_t* vec,
+    __nv_bfloat16* res,
+    const unsigned int k,
+    const unsigned int m,
+    const unsigned int n,
+    float const* scale) {
+  if (m == TILE_M) {
+    fp8fp8FastGemvKernel<TILE_M, TILE_N>(mat, vec, res, k, m, n, scale);
+    return true;
+  }
 
-  check_if_valid_block_dimensions(m, n, k, block_dim);
+  if constexpr (TILE_M < MAX_M_SIZE) {
+    return fastGemvTemplateCaller<TILE_M + 1, TILE_N>(
+        mat, vec, res, k, m, n, scale);
+  }
+  return false;
+}
 
-  dim3 grid_dim(1, n / block_dim.y);
-  unsigned int num_per_thread = k / block_dim.x;
+bool fastGemvLauncher(
+    cutlass::float_e4m3_t* mat,
+    cutlass::float_e4m3_t* vec,
+    __nv_bfloat16* res,
+    const unsigned int k,
+    const unsigned int m,
+    const unsigned int n,
+    float const* scale) {
+  // Note: based on the sweep results, the heuristic TILE_N = 2 gives better
+  // performance than larger TILE_N values, potentially because a smaller
+  // TILE_N yields more threadblocks and thus increases block concurrency.
+  return fastGemvTemplateCaller</* TILE_M=*/1, /* TILE_N=*/2>(
+      mat, vec, res, k, m, n, scale);
+}
 
-  auto stream = at::cuda::getCurrentCUDAStream();
+at::Tensor
+fp8fp8bf16_fast_gemv(at::Tensor XQ, at::Tensor WQ, at::Tensor scale) {
+  const unsigned int m = XQ.size(0);
+  const unsigned int n = WQ.size(0);
+  const unsigned int k = WQ.size(1);
+
+  TORCH_CHECK(XQ.is_cuda() && XQ.is_contiguous());
+  TORCH_CHECK(WQ.is_cuda() && WQ.is_contiguous());
+  TORCH_CHECK(XQ.size(-1) == k);
 
-  auto Y = at::empty({m, n}, X.options().dtype(at::kBFloat16));
+  auto Y = at::empty({m, n}, XQ.options().dtype(at::kBFloat16));
 
-  gemv_quantized_fp8_fp8<<<grid_dim, block_dim, 0, stream>>>(
-      reinterpret_cast<cutlass::float_e4m3_t*>(W.data_ptr()), // mat
-      reinterpret_cast<cutlass::float_e4m3_t*>(X.data_ptr()), // vec
+  bool dispatched = fastGemvLauncher(
+      reinterpret_cast<cutlass::float_e4m3_t*>(WQ.data_ptr()), // mat
+      reinterpret_cast<cutlass::float_e4m3_t*>(XQ.data_ptr()), // vec
       reinterpret_cast<__nv_bfloat16*>(Y.data_ptr()), // res
       k,
-      reinterpret_cast<float const*>(scale.data_ptr()),
-      num_per_thread);
+      m,
+      n,
+      reinterpret_cast<float const*>(scale.data_ptr()));
 
-  C10_CUDA_KERNEL_LAUNCH_CHECK();
+  if (!dispatched) {
+    throw std::runtime_error("f8f8bf16_fast_gemv cannot run.");
+  }
 
   return Y;
 }
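To make the launch geometry concrete, the small standalone sketch below (not part of the diff) reproduces the arithmetic in fp8fp8FastGemvKernel for one of the swept shapes, m = 1, n = 7168, k = 8192, assuming the TILE_M = 1 / TILE_N = 2 defaults chosen in fastGemvLauncher; the block size comes from get_best_block_dim after this change, and the 16-element chunking is inferred from the k >> 4 term.

#include <cstdio>

int main() {
  const unsigned int m = 1, n = 7168, k = 8192;
  const unsigned int tile_m = 1, tile_n = 2;
  // get_best_block_dim(1, 7168, 8192) returns dim3(256, 1) after this change.
  const unsigned int block_x = 256, block_y = 1;
  // grid_dim(m / TILE_M, n / TILE_N * block_dim.y) -> (1, 3584) threadblocks.
  const unsigned int grid_x = m / tile_m;
  const unsigned int grid_y = n / tile_n * block_y;
  // k >> 4 counts 16-element fp8 chunks, so each thread needs
  // ceil((k / 16) / block_dim.x) = 2 load iterations here.
  const unsigned int num_iter = ((k >> 4) + block_x - 1) / block_x;
  std::printf(
      "grid=(%u,%u) block=(%u,%u) iters/thread=%u\n",
      grid_x, grid_y, block_x, block_y, num_iter);
  return 0;
}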
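The fastGemvTemplateCaller recursion resolves the runtime m to a compile-time TILE_M without listing every instantiation by hand. A minimal, self-contained analogue of that pattern is sketched below; kMaxM and the printf body are stand-ins for MAX_M_SIZE and the real kernel launch, so this only illustrates the dispatch mechanism, not the actual kernels.

#include <cstdio>

constexpr unsigned int kMaxM = 4; // stand-in for MAX_M_SIZE in the diff

template <unsigned int TILE_M>
bool dispatch(unsigned int m) {
  if (m == TILE_M) {
    // In the real code this is where fp8fp8FastGemvKernel<TILE_M, TILE_N>
    // would be launched with the matching compile-time tile size.
    std::printf("dispatching with TILE_M = %u\n", TILE_M);
    return true;
  }
  if constexpr (TILE_M < kMaxM) {
    return dispatch<TILE_M + 1>(m); // try the next compile-time tile size
  }
  return false; // m exceeds every supported tile size
}

int main() {
  for (unsigned int m = 1; m <= 5; ++m) {
    if (!dispatch<1>(m)) {
      std::printf("m = %u is not supported\n", m);
    }
  }
  return 0;
}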
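For reference, a hypothetical C++ caller might look like the sketch below. It assumes a PyTorch build that exposes at::kFloat8_e4m3fn, that fp8fp8bf16_fast_gemv is declared as in this diff, and that a single-element float tensor is the intended scale layout; the op registration and Python bindings are not part of this change.

#include <ATen/ATen.h>

namespace fbgemm_gpu {
// Declaration mirroring the definition above; in a real build this would
// come from the gen_ai headers.
at::Tensor fp8fp8bf16_fast_gemv(at::Tensor XQ, at::Tensor WQ, at::Tensor scale);
} // namespace fbgemm_gpu

at::Tensor run_fast_gemv_example() {
  const int64_t m = 1, n = 7168, k = 8192; // one of the swept shapes
  // Row-major activations (m x k) and weights (n x k), quantized to e4m3.
  auto XQ = at::randn({m, k}, at::device(at::kCUDA).dtype(at::kFloat))
                .to(at::kFloat8_e4m3fn);
  auto WQ = at::randn({n, k}, at::device(at::kCUDA).dtype(at::kFloat))
                .to(at::kFloat8_e4m3fn);
  // Single-element float scale (assumed layout; the kernel reads it as
  // float const*).
  auto scale = at::ones({1}, at::device(at::kCUDA).dtype(at::kFloat));
  // Output is a bf16 tensor of shape (m, n).
  return fbgemm_gpu::fp8fp8bf16_fast_gemv(XQ, WQ, scale);
}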