
Commit 7880eea

jwfromm authored and facebook-github-bot committed
Use Int64 Indexing in Grouped Gemm
Summary: For very large sequence-length workloads, it's possible for int32 arithmetic to overflow, especially since we often use M*N-sized tensors in grouped GEMM. This diff replaces int32 indexing with int64 to avoid this problem.

Differential Revision: D72465728
1 parent b25dec3 commit 7880eea
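For context, a minimal standalone C++ sketch (not part of this commit; the shapes below are hypothetical) of the failure mode the summary describes: an M*N element offset computed in int32 wraps once the product exceeds 2^31 - 1, while promoting to int64_t before multiplying keeps it exact.

// Sketch only: any product of int32 dimensions above INT32_MAX (2147483647)
// shows the same wraparound. Signed int32 overflow is technically undefined
// behavior; on common targets it wraps modulo 2^32 and yields garbage offsets.
#include <cstdint>
#include <iostream>

int main() {
  int M = 70000;  // hypothetical very long (stacked) sequence dimension
  int N = 40000;  // hypothetical output width
  int bad_offset = M * N;  // int32 math: 2.8e9 overflows
  int64_t good_offset = static_cast<int64_t>(M) * N;  // promote before multiplying
  std::cout << "int32 offset: " << bad_offset << "\n";   // negative / wrapped
  std::cout << "int64 offset: " << good_offset << "\n";  // 2800000000
  return 0;
}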

File tree

5 files changed: +277 -280 lines changed


fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/bf16_grouped_gemm.hip

Lines changed: 79 additions & 79 deletions
@@ -35,30 +35,30 @@ using CDataType = ck::bhalf_t;
 
 // Define a custom hash function for std::tuple<int, int, int>
 struct IntTupleHash {
-  size_t operator()(const std::tuple<int, int>& t) const {
-    auto hash1 = std::hash<int>{}(std::get<0>(t));
-    auto hash2 = std::hash<int>{}(std::get<1>(t));
+  size_t operator()(const std::tuple<int64_t, int64_t>& t) const {
+    auto hash1 = std::hash<int64_t>{}(std::get<0>(t));
+    auto hash2 = std::hash<int64_t>{}(std::get<1>(t));
     return hash1 ^ hash2;
   }
-  size_t operator()(const std::tuple<int, int, int>& t) const {
-    auto hash1 = std::hash<int>{}(std::get<0>(t));
-    auto hash2 = std::hash<int>{}(std::get<1>(t));
-    auto hash3 = std::hash<int>{}(std::get<2>(t));
+  size_t operator()(const std::tuple<int64_t, int64_t, int64_t>& t) const {
+    auto hash1 = std::hash<int64_t>{}(std::get<0>(t));
+    auto hash2 = std::hash<int64_t>{}(std::get<1>(t));
+    auto hash3 = std::hash<int64_t>{}(std::get<2>(t));
     return hash1 ^ hash2 ^ hash3;
   }
-  size_t operator()(const std::tuple<int, int, int, int>& t) const {
-    auto hash1 = std::hash<int>{}(std::get<0>(t));
-    auto hash2 = std::hash<int>{}(std::get<1>(t));
-    auto hash3 = std::hash<int>{}(std::get<2>(t));
-    auto hash4 = std::hash<int>{}(std::get<3>(t));
+  size_t operator()(const std::tuple<int64_t, int64_t, int64_t, int64_t>& t) const {
+    auto hash1 = std::hash<int64_t>{}(std::get<0>(t));
+    auto hash2 = std::hash<int64_t>{}(std::get<1>(t));
+    auto hash3 = std::hash<int64_t>{}(std::get<2>(t));
+    auto hash4 = std::hash<int64_t>{}(std::get<3>(t));
     return hash1 ^ hash2 ^ hash3 ^ hash4;
   }
 };
 
 // For certain high priority shapes, we directly map to the best kernel rather
 // than use heuristics.
 template <typename InputType, typename OutputType>
-static const std::unordered_map<std::tuple<int, int, int, int>, GroupedKernel<InputType, OutputType>, IntTupleHash> bf16_grouped_lookup_dispatch = {
+static const std::unordered_map<std::tuple<int64_t, int64_t, int64_t, int64_t>, GroupedKernel<InputType, OutputType>, IntTupleHash> bf16_grouped_lookup_dispatch = {
     {{16,16,2048,5120},bf16_grouped_128x16x64x128_16x16_1x2_16x8x1_16x8x1_1x16x1x8_8x8x1_1x2_intrawave_v2<InputType, OutputType>},
     {{16,16,5120,1024},bf16_grouped_64x16x16x128_16x16_1x1_16x4x1_16x4x1_1x16x1x4_4x4x1_1x1_interwave_v2<InputType, OutputType>},
     {{16,16,16384,5120},bf16_grouped_64x16x32x128_16x16_1x2_16x4x1_16x4x1_1x16x1x4_8x8x1_1x2_intrawave_v2<InputType, OutputType>},
@@ -132,20 +132,20 @@ static const std::unordered_map<std::tuple<int, int, int, int>, GroupedKernel<In
 
 
 // Helper function to return the next largest power of 2
-static constexpr int nextPow2(unsigned int num)
+static constexpr int64_t nextPow2(int64_t num)
 {
   if (num <= 1)
     return 1;
   return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
 }
 template <typename InputType, typename OutputType>
-GroupedKernel<InputType, OutputType> grouped_heuristic_dispatch(int G, int total_M, int N, int K) {
+GroupedKernel<InputType, OutputType> grouped_heuristic_dispatch(int64_t G, int64_t total_M, int64_t N, int64_t K) {
   // We use shape heuristics to find the best kernel.
   // To do this, we divide by the size of M and find the best
   // option within that grouping.
 
   // First check if this shape is available in the direct lookup.
-  int padded_m = nextPow2(total_M);
+  int64_t padded_m = nextPow2(total_M);
   padded_m = padded_m < G ? G : padded_m;
   padded_m = padded_m > 8192 ? 8192 : padded_m;
   auto it = bf16_grouped_lookup_dispatch<InputType, OutputType>.find({G, padded_m, N, K});
@@ -163,16 +163,16 @@ __global__ void set_kernel_args_kernel(
     ADataType* A,
     BDataType* B,
     CDataType* output,
-    int M,
-    int N,
-    int K) {
+    int64_t M,
+    int64_t N,
+    int64_t K) {
   int idx = blockIdx.x * blockDim.x + threadIdx.x;
   // Each kernel annoyingly can only set the kernel args for one group.
   // This could only be avoided with complicated memory management.
   if (idx == 0) {
     // Write kernel arguments directly to memory.
     KernelArguments kernel_group_args = {
-        A, B, {}, output, M, N, K, K, K, {}, N};
+        A, B, {}, output, int(M), int(N), int(K), int(K), int(K), {}, int(N)};
     kernel_args[0] = kernel_group_args;
   }
 }
@@ -184,32 +184,32 @@ void set_static_kernel_args(
     at::Tensor output) {
   // Get current cuda stream.
   auto stream = at::cuda::getCurrentHIPStream().stream();
-  int group_count = A.size();
+  int64_t group_count = A.size();
   // When group count is large, we can more efficiently initialize
   // by doing host setup and a memcpy. This is only viable if cuda
   // graphs arent being used.
-  int output_offset = 0;
+  int64_t output_offset = 0;
   if (group_count >= 16 && stream == 0) {
     std::vector<KernelArguments> ggemm_kargs;
     ggemm_kargs.reserve(group_count);
 
     // Iterate over inputs and get group information.
     for (int i = 0; i < group_count; i++) {
-      int M = A[i].size(0);
-      int K = A[i].size(1);
-      int N = B[i].size(0);
+      int64_t M = A[i].size(0);
+      int64_t K = A[i].size(1);
+      int64_t N = B[i].size(0);
       KernelArguments group_args = {
           reinterpret_cast<ADataType*>(A[i].data_ptr()),
           reinterpret_cast<BDataType*>(B[i].data_ptr()),
           {},
           reinterpret_cast<CDataType*>(output.data_ptr()) + output_offset,
-          M,
-          N,
-          K,
-          K,
-          K,
+          int(M),
+          int(N),
+          int(K),
+          int(K),
+          int(K),
           {},
-          N};
+          int(N)};
       output_offset += M * N;
       ggemm_kargs.push_back(group_args);
     }
@@ -224,9 +224,9 @@ void set_static_kernel_args(
     // Using multiple kernels this way allows us to support arbitrary M,N,K.
     // For some reason, this approach is faster than using hipmemcpy.
     for (int i = 0; i < group_count; i++) {
-      int M = A[i].size(0);
-      int K = A[i].size(1);
-      int N = B[i].size(0);
+      int64_t M = A[i].size(0);
+      int64_t K = A[i].size(1);
+      int64_t N = B[i].size(0);
       // Launch kernel to set kernel arguments.
       set_kernel_args_kernel<<<1, 1, 0, stream>>>(
           reinterpret_cast<KernelArguments*>(
@@ -249,27 +249,27 @@ __global__ void set_kernel_args_fixed_nk_kernel(
     BDataType* B,
     CDataType* output,
     int64_t* prepad_M,
-    int M,
-    int N,
-    int K,
-    int group_count) {
+    int64_t M,
+    int64_t N,
+    int64_t K,
+    int64_t group_count) {
   int group_idx = blockIdx.x * blockDim.x + threadIdx.x;
   // Each thread is responsible for setting up the arguments for one group.
   if (group_idx < group_count) {
     // Compute offsets for this group.
-    int group_M = prepad_M[group_idx];
+    int64_t group_M = prepad_M[group_idx];
     KernelArguments kernel_group_args = {
         A + (group_idx * M * K),
         B + (group_idx * N * K),
         {},
         output + (group_idx * M * N),
-        group_M,
-        N,
-        K,
-        K,
-        K,
+        int(group_M),
+        int(N),
+        int(K),
+        int(K),
+        int(K),
         {},
-        N};
+        int(N)};
     // Write kernel args to memory.
     kernel_args[group_idx] = kernel_group_args;
   }
@@ -281,16 +281,16 @@ __global__ void set_kernel_args_m_sizes_kernel(
     BDataType* B,
     CDataType* output,
     int64_t* M_sizes,
-    int M,
-    int N,
-    int K,
-    int group_count) {
+    int64_t M,
+    int64_t N,
+    int64_t K,
+    int64_t group_count) {
   int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
   // Each thread is responsible for setting up the arguments for one group.
   if (thread_idx < group_count) {
     // Get M information for this group.
-    int kernel_M = M_sizes[thread_idx];
-    int offset_M = 0;
+    int64_t kernel_M = M_sizes[thread_idx];
+    int64_t offset_M = 0;
     // Offset is computed by finding the sum of previous group Ms.
     for (int i = 0; i < thread_idx; i++) {
       offset_M += M_sizes[i];
@@ -300,13 +300,13 @@ __global__ void set_kernel_args_m_sizes_kernel(
         B + (thread_idx * N * K),
         {},
         output + (offset_M * N),
-        kernel_M,
-        N,
-        K,
-        K,
-        K,
+        int(kernel_M),
+        int(N),
+        int(K),
+        int(K),
+        int(K),
         {},
-        N};
+        int(N)};
     // Write kernel args to memory.
     kernel_args[thread_idx] = kernel_group_args;
   }
@@ -334,9 +334,9 @@ void set_dynamic_kernel_args(
 
   // We assume that M, N, and K are fixed across groups.
   // The actual m values are sstored in the passed M tensor.
-  int M = A.size(1);
-  int K = A.size(2);
-  int N = B.size(1);
+  int64_t M = A.size(1);
+  int64_t K = A.size(2);
+  int64_t N = B.size(1);
 
   // Launch a kernel that sets kernel argument memory.
   set_kernel_args_fixed_nk_kernel<<<1, group_count, 0, stream>>>(
@@ -365,9 +365,9 @@ at::Tensor get_stacked_kernel_args(
       {static_cast<long>(group_count * sizeof(KernelArguments))},
       A.options().dtype(at::kByte));
 
-  int M = A.size(A.dim() - 2);
-  int K = B.size(2);
-  int N = B.size(1);
+  int64_t M = A.size(A.dim() - 2);
+  int64_t K = B.size(2);
+  int64_t N = B.size(1);
 
   set_kernel_args_m_sizes_kernel<<<1, group_count, 0, stream>>>(
       reinterpret_cast<KernelArguments*>(kernel_args.data_ptr()),
@@ -408,8 +408,8 @@ OutputType _bf16bf16bf16_grouped(
   int64_t total_output_size = 0;
   int64_t total_M = 0;
   for (int i = 0; i < group_count; ++i) {
-    int M = A[i].size(0);
-    int N = B[i].size(0);
+    int64_t M = A[i].size(0);
+    int64_t N = B[i].size(0);
     total_M += M;
     const int64_t output_size = M * N;
     total_output_size += output_size;
@@ -428,9 +428,9 @@ OutputType _bf16bf16bf16_grouped(
 
   // Perform shape lookup to find best kernel.
   // We use the largest of each shape for heuristics.
-  int MaxM = 0;
-  int MaxN = 0;
-  int MaxK = 0;
+  int64_t MaxM = 0;
+  int64_t MaxN = 0;
+  int64_t MaxK = 0;
   for (int i = 0; i < group_count; i++) {
     MaxM = max(MaxM, A[i].size(0));
     MaxN = max(MaxN, B[i].size(0));
@@ -473,10 +473,10 @@ at::Tensor bf16bf16bf16_grouped_dynamic(
   // First confirm that there are the same number of groups in all inputs.
   TORCH_CHECK(
       A.size(0) == B.size(0), "A and B must have the same number of groups.");
-  int group_count = A.size(0);
-  int M = A.size(1);
-  int N = B.size(1);
-  int K = B.size(2);
+  int64_t group_count = A.size(0);
+  int64_t M = A.size(1);
+  int64_t N = B.size(1);
+  int64_t K = B.size(2);
   TORCH_CHECK(A.is_cuda() && A.is_contiguous());
   TORCH_CHECK(A.dim() == 3, "Inputs must be 3D [G, M, K].");
   TORCH_CHECK(A.dtype() == at::kBFloat16, "Inputs must be type bfloat16.");
@@ -499,9 +499,9 @@ at::Tensor bf16bf16bf16_grouped_dynamic(
 
   // Perform shape lookup to find best kernel.
   // We use the largest of each shape for heuristics.
-  int MaxM = 0;
-  int MaxN = 0;
-  int MaxK = 0;
+  int64_t MaxM = 0;
+  int64_t MaxN = 0;
+  int64_t MaxK = 0;
   for (int i = 0; i < group_count; i++) {
     MaxM = max(MaxM, A[i].size(0));
     MaxN = max(MaxN, B[i].size(0));
@@ -519,12 +519,12 @@ at::Tensor bf16bf16bf16_grouped_stacked(
     at::Tensor M_sizes) {
   // Check that input datatypes are valid.
   // First confirm that there are the same number of groups in all inputs.
-  int group_count = M_sizes.size(0);
+  int64_t group_count = M_sizes.size(0);
   // X is expected to be shape [total_M, K].
-  int total_M = X.size(0);
+  int64_t total_M = X.size(0);
   // W is expected to be shape [G, N, K].
-  int N = W.size(1);
-  int K = X.size(1);
+  int64_t N = W.size(1);
+  int64_t K = X.size(1);
   TORCH_CHECK(W.size(0) == group_count,
       "All inputs must have the same number of groups.");
 
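The pattern throughout the hunks above is to carry sizes and running offsets in int64_t and to narrow with int(...) only where the KernelArguments struct expects 32-bit fields, since each group's individual M, N, and K still fit in int32 even when the accumulated offsets do not. Below is a minimal host-side sketch of that pattern; GroupArgs and build_args are hypothetical stand-ins, not the real CK types.

// Sketch only: hypothetical 32-bit argument struct, analogous in spirit to a
// CK-style KernelArguments, used to illustrate where narrowing happens.
#include <cstdint>
#include <vector>

struct GroupArgs {
  const void* a_ptr;
  const void* b_ptr;
  void* c_ptr;
  int M;  // 32-bit per-group sizes, like the downstream kernel interface
  int N;
  int K;
};

std::vector<GroupArgs> build_args(
    const float* A,
    const float* B,
    float* C,
    const std::vector<int64_t>& Ms,  // per-group M values
    int64_t N,
    int64_t K) {
  std::vector<GroupArgs> args;
  int64_t a_off = 0;  // 64-bit running offsets stay exact even when
  int64_t c_off = 0;  // sum(M) * K or sum(M) * N exceeds INT32_MAX
  for (int64_t g = 0; g < static_cast<int64_t>(Ms.size()); ++g) {
    const int64_t M = Ms[g];
    args.push_back(
        {A + a_off, B + g * N * K, C + c_off, int(M), int(N), int(K)});
    a_off += M * K;  // all offset arithmetic performed in int64_t
    c_off += M * N;
  }
  return args;
}

The int(M), int(N), int(K) casts mirror the ones visible in the diff; only the running offsets need 64 bits, so the narrowing at the argument boundary is safe for realistic per-group shapes.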