@@ -35,30 +35,30 @@ using CDataType = ck::bhalf_t;

 // Define a custom hash function for std::tuple<int, int, int>
 struct IntTupleHash {
-  size_t operator()(const std::tuple<int, int>& t) const {
-    auto hash1 = std::hash<int>{}(std::get<0>(t));
-    auto hash2 = std::hash<int>{}(std::get<1>(t));
+  size_t operator()(const std::tuple<int64_t, int64_t>& t) const {
+    auto hash1 = std::hash<int64_t>{}(std::get<0>(t));
+    auto hash2 = std::hash<int64_t>{}(std::get<1>(t));
     return hash1 ^ hash2;
   }
-  size_t operator()(const std::tuple<int, int, int>& t) const {
-    auto hash1 = std::hash<int>{}(std::get<0>(t));
-    auto hash2 = std::hash<int>{}(std::get<1>(t));
-    auto hash3 = std::hash<int>{}(std::get<2>(t));
+  size_t operator()(const std::tuple<int64_t, int64_t, int64_t>& t) const {
+    auto hash1 = std::hash<int64_t>{}(std::get<0>(t));
+    auto hash2 = std::hash<int64_t>{}(std::get<1>(t));
+    auto hash3 = std::hash<int64_t>{}(std::get<2>(t));
     return hash1 ^ hash2 ^ hash3;
   }
-  size_t operator()(const std::tuple<int, int, int, int>& t) const {
-    auto hash1 = std::hash<int>{}(std::get<0>(t));
-    auto hash2 = std::hash<int>{}(std::get<1>(t));
-    auto hash3 = std::hash<int>{}(std::get<2>(t));
-    auto hash4 = std::hash<int>{}(std::get<3>(t));
+  size_t operator()(const std::tuple<int64_t, int64_t, int64_t, int64_t>& t) const {
+    auto hash1 = std::hash<int64_t>{}(std::get<0>(t));
+    auto hash2 = std::hash<int64_t>{}(std::get<1>(t));
+    auto hash3 = std::hash<int64_t>{}(std::get<2>(t));
+    auto hash4 = std::hash<int64_t>{}(std::get<3>(t));
     return hash1 ^ hash2 ^ hash3 ^ hash4;
   }
 };

 // For certain high priority shapes, we directly map to the best kernel rather
 // than use heuristics.
 template <typename InputType, typename OutputType>
-static const std::unordered_map<std::tuple<int, int, int, int>, GroupedKernel<InputType, OutputType>, IntTupleHash> bf16_grouped_lookup_dispatch = {
+static const std::unordered_map<std::tuple<int64_t, int64_t, int64_t, int64_t>, GroupedKernel<InputType, OutputType>, IntTupleHash> bf16_grouped_lookup_dispatch = {
     {{16, 16, 2048, 5120}, bf16_grouped_128x16x64x128_16x16_1x2_16x8x1_16x8x1_1x16x1x8_8x8x1_1x2_intrawave_v2<InputType, OutputType>},
     {{16, 16, 5120, 1024}, bf16_grouped_64x16x16x128_16x16_1x1_16x4x1_16x4x1_1x16x1x4_4x4x1_1x1_interwave_v2<InputType, OutputType>},
     {{16, 16, 16384, 5120}, bf16_grouped_64x16x32x128_16x16_1x2_16x4x1_16x4x1_1x16x1x4_8x8x1_1x2_intrawave_v2<InputType, OutputType>},
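For reference, the tuple-keyed lookup above can be exercised in isolation. The following is a minimal standalone sketch (hypothetical names and placeholder values, not part of this change) of querying an unordered_map keyed by a std::tuple of int64_t with the same XOR-combine hashing pattern as IntTupleHash:

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <string>
#include <tuple>
#include <unordered_map>

struct TupleHash {
  std::size_t operator()(const std::tuple<int64_t, int64_t, int64_t, int64_t>& t) const {
    // Same XOR-combine scheme as IntTupleHash above.
    return std::hash<int64_t>{}(std::get<0>(t)) ^ std::hash<int64_t>{}(std::get<1>(t)) ^
        std::hash<int64_t>{}(std::get<2>(t)) ^ std::hash<int64_t>{}(std::get<3>(t));
  }
};

int main() {
  // Keys mirror (G, padded_M, N, K); strings stand in for kernel function pointers.
  std::unordered_map<std::tuple<int64_t, int64_t, int64_t, int64_t>, std::string, TupleHash>
      table = {{{16, 16, 2048, 5120}, "kernel_a"}, {{16, 16, 5120, 1024}, "kernel_b"}};
  auto it = table.find({16, 16, 2048, 5120});
  if (it != table.end()) {
    std::cout << it->second << "\n";  // prints kernel_a
  }
  return 0;
}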
@@ -132,20 +132,20 @@ static const std::unordered_map<std::tuple<int, int, int, int>, GroupedKernel<In


 // Helper function to return the next largest power of 2
-static constexpr int nextPow2(unsigned int num)
+static constexpr int64_t nextPow2(int64_t num)
 {
   if (num <= 1)
     return 1;
   return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
 }
 template <typename InputType, typename OutputType>
-GroupedKernel<InputType, OutputType> grouped_heuristic_dispatch(int G, int total_M, int N, int K) {
+GroupedKernel<InputType, OutputType> grouped_heuristic_dispatch(int64_t G, int64_t total_M, int64_t N, int64_t K) {
   // We use shape heuristics to find the best kernel.
   // To do this, we divide by the size of M and find the best
   // option within that grouping.

   // First check if this shape is available in the direct lookup.
-  int padded_m = nextPow2(total_M);
+  int64_t padded_m = nextPow2(total_M);
   padded_m = padded_m < G ? G : padded_m;
   padded_m = padded_m > 8192 ? 8192 : padded_m;
   auto it = bf16_grouped_lookup_dispatch<InputType, OutputType>.find({G, padded_m, N, K});
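One note on the helper above, hedged as an observation rather than a required change: once num is an int64_t, CHAR_BIT * sizeof(num) evaluates to 64, while __builtin_clz counts leading zeros of an unsigned int. A variant that stays entirely in 64 bits (a sketch assuming GCC/Clang builtins, not the code in this diff) would use __builtin_clzll:

#include <climits>
#include <cstdint>

// Sketch only: next power of two computed with the 64-bit builtin.
static constexpr int64_t nextPow2_64(int64_t num) {
  if (num <= 1) {
    return 1;
  }
  // __builtin_clzll counts leading zeros of an unsigned long long value.
  return int64_t{1} << (CHAR_BIT * sizeof(uint64_t) -
                        __builtin_clzll(static_cast<uint64_t>(num - 1)));
}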
@@ -163,16 +163,16 @@ __global__ void set_kernel_args_kernel(
     ADataType* A,
     BDataType* B,
     CDataType* output,
-    int M,
-    int N,
-    int K) {
+    int64_t M,
+    int64_t N,
+    int64_t K) {
   int idx = blockIdx.x * blockDim.x + threadIdx.x;
   // Each kernel annoyingly can only set the kernel args for one group.
   // This could only be avoided with complicated memory management.
   if (idx == 0) {
     // Write kernel arguments directly to memory.
     KernelArguments kernel_group_args = {
-        A, B, {}, output, M, N, K, K, K, {}, N};
+        A, B, {}, output, int(M), int(N), int(K), int(K), int(K), {}, int(N)};
     kernel_args[0] = kernel_group_args;
   }
 }
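The host-side M, N, and K are widened to int64_t here, while the values written into KernelArguments are narrowed back with int(...), presumably because those fields remain 32-bit on the CK side. Purely as an illustration (a hypothetical helper, not part of this change), a checked narrowing could make an out-of-range dimension fail loudly instead of wrapping:

#include <cstdint>
#include <limits>
#include <stdexcept>

// Hypothetical helper: narrow an int64_t dimension to int, throwing if the
// value cannot be represented in a 32-bit field.
inline int checked_narrow(int64_t v) {
  if (v < 0 || v > static_cast<int64_t>(std::numeric_limits<int>::max())) {
    throw std::out_of_range("dimension does not fit in a 32-bit kernel argument");
  }
  return static_cast<int>(v);
}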
@@ -184,32 +184,32 @@ void set_static_kernel_args(
     at::Tensor output) {
   // Get current cuda stream.
   auto stream = at::cuda::getCurrentHIPStream().stream();
-  int group_count = A.size();
+  int64_t group_count = A.size();
   // When group count is large, we can more efficiently initialize
   // by doing host setup and a memcpy. This is only viable if cuda
   // graphs aren't being used.
-  int output_offset = 0;
+  int64_t output_offset = 0;
   if (group_count >= 16 && stream == 0) {
     std::vector<KernelArguments> ggemm_kargs;
     ggemm_kargs.reserve(group_count);

     // Iterate over inputs and get group information.
     for (int i = 0; i < group_count; i++) {
-      int M = A[i].size(0);
-      int K = A[i].size(1);
-      int N = B[i].size(0);
+      int64_t M = A[i].size(0);
+      int64_t K = A[i].size(1);
+      int64_t N = B[i].size(0);
       KernelArguments group_args = {
           reinterpret_cast<ADataType*>(A[i].data_ptr()),
           reinterpret_cast<BDataType*>(B[i].data_ptr()),
           {},
           reinterpret_cast<CDataType*>(output.data_ptr()) + output_offset,
-          M,
-          N,
-          K,
-          K,
-          K,
+          int(M),
+          int(N),
+          int(K),
+          int(K),
+          int(K),
           {},
-          N};
+          int(N)};
       output_offset += M * N;
       ggemm_kargs.push_back(group_args);
     }
@@ -224,9 +224,9 @@ void set_static_kernel_args(
     // Using multiple kernels this way allows us to support arbitrary M,N,K.
     // For some reason, this approach is faster than using hipmemcpy.
     for (int i = 0; i < group_count; i++) {
-      int M = A[i].size(0);
-      int K = A[i].size(1);
-      int N = B[i].size(0);
+      int64_t M = A[i].size(0);
+      int64_t K = A[i].size(1);
+      int64_t N = B[i].size(0);
       // Launch kernel to set kernel arguments.
       set_kernel_args_kernel<<<1, 1, 0, stream>>>(
           reinterpret_cast<KernelArguments*>(
@@ -249,27 +249,27 @@ __global__ void set_kernel_args_fixed_nk_kernel(
     BDataType* B,
     CDataType* output,
     int64_t* prepad_M,
-    int M,
-    int N,
-    int K,
-    int group_count) {
+    int64_t M,
+    int64_t N,
+    int64_t K,
+    int64_t group_count) {
   int group_idx = blockIdx.x * blockDim.x + threadIdx.x;
   // Each thread is responsible for setting up the arguments for one group.
   if (group_idx < group_count) {
     // Compute offsets for this group.
-    int group_M = prepad_M[group_idx];
+    int64_t group_M = prepad_M[group_idx];
     KernelArguments kernel_group_args = {
         A + (group_idx * M * K),
         B + (group_idx * N * K),
         {},
         output + (group_idx * M * N),
-        group_M,
-        N,
-        K,
-        K,
-        K,
+        int(group_M),
+        int(N),
+        int(K),
+        int(K),
+        int(K),
         {},
-        N};
+        int(N)};
     // Write kernel args to memory.
     kernel_args[group_idx] = kernel_group_args;
   }
@@ -281,16 +281,16 @@ __global__ void set_kernel_args_m_sizes_kernel(
     BDataType* B,
     CDataType* output,
     int64_t* M_sizes,
-    int M,
-    int N,
-    int K,
-    int group_count) {
+    int64_t M,
+    int64_t N,
+    int64_t K,
+    int64_t group_count) {
   int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
   // Each thread is responsible for setting up the arguments for one group.
   if (thread_idx < group_count) {
     // Get M information for this group.
-    int kernel_M = M_sizes[thread_idx];
-    int offset_M = 0;
+    int64_t kernel_M = M_sizes[thread_idx];
+    int64_t offset_M = 0;
     // Offset is computed by finding the sum of previous group Ms.
     for (int i = 0; i < thread_idx; i++) {
       offset_M += M_sizes[i];
@@ -300,13 +300,13 @@ __global__ void set_kernel_args_m_sizes_kernel(
         B + (thread_idx * N * K),
         {},
         output + (offset_M * N),
-        kernel_M,
-        N,
-        K,
-        K,
-        K,
+        int(kernel_M),
+        int(N),
+        int(K),
+        int(K),
+        int(K),
         {},
-        N};
+        int(N)};
     // Write kernel args to memory.
     kernel_args[thread_idx] = kernel_group_args;
   }
@@ -334,9 +334,9 @@ void set_dynamic_kernel_args(

   // We assume that M, N, and K are fixed across groups.
   // The actual m values are stored in the passed M tensor.
-  int M = A.size(1);
-  int K = A.size(2);
-  int N = B.size(1);
+  int64_t M = A.size(1);
+  int64_t K = A.size(2);
+  int64_t N = B.size(1);

   // Launch a kernel that sets kernel argument memory.
   set_kernel_args_fixed_nk_kernel<<<1, group_count, 0, stream>>>(
@@ -365,9 +365,9 @@ at::Tensor get_stacked_kernel_args(
       {static_cast<long>(group_count * sizeof(KernelArguments))},
       A.options().dtype(at::kByte));

-  int M = A.size(A.dim() - 2);
-  int K = B.size(2);
-  int N = B.size(1);
+  int64_t M = A.size(A.dim() - 2);
+  int64_t K = B.size(2);
+  int64_t N = B.size(1);

   set_kernel_args_m_sizes_kernel<<<1, group_count, 0, stream>>>(
       reinterpret_cast<KernelArguments*>(kernel_args.data_ptr()),
@@ -408,8 +408,8 @@ OutputType _bf16bf16bf16_grouped(
   int64_t total_output_size = 0;
   int64_t total_M = 0;
   for (int i = 0; i < group_count; ++i) {
-    int M = A[i].size(0);
-    int N = B[i].size(0);
+    int64_t M = A[i].size(0);
+    int64_t N = B[i].size(0);
     total_M += M;
     const int64_t output_size = M * N;
     total_output_size += output_size;
@@ -428,9 +428,9 @@ OutputType _bf16bf16bf16_grouped(

   // Perform shape lookup to find best kernel.
   // We use the largest of each shape for heuristics.
-  int MaxM = 0;
-  int MaxN = 0;
-  int MaxK = 0;
+  int64_t MaxM = 0;
+  int64_t MaxN = 0;
+  int64_t MaxK = 0;
   for (int i = 0; i < group_count; i++) {
     MaxM = max(MaxM, A[i].size(0));
     MaxN = max(MaxN, B[i].size(0));
@@ -473,10 +473,10 @@ at::Tensor bf16bf16bf16_grouped_dynamic(
   // First confirm that there are the same number of groups in all inputs.
   TORCH_CHECK(
       A.size(0) == B.size(0), "A and B must have the same number of groups.");
-  int group_count = A.size(0);
-  int M = A.size(1);
-  int N = B.size(1);
-  int K = B.size(2);
+  int64_t group_count = A.size(0);
+  int64_t M = A.size(1);
+  int64_t N = B.size(1);
+  int64_t K = B.size(2);
   TORCH_CHECK(A.is_cuda() && A.is_contiguous());
   TORCH_CHECK(A.dim() == 3, "Inputs must be 3D [G, M, K].");
   TORCH_CHECK(A.dtype() == at::kBFloat16, "Inputs must be type bfloat16.");
@@ -499,9 +499,9 @@ at::Tensor bf16bf16bf16_grouped_dynamic(

   // Perform shape lookup to find best kernel.
   // We use the largest of each shape for heuristics.
-  int MaxM = 0;
-  int MaxN = 0;
-  int MaxK = 0;
+  int64_t MaxM = 0;
+  int64_t MaxN = 0;
+  int64_t MaxK = 0;
   for (int i = 0; i < group_count; i++) {
     MaxM = max(MaxM, A[i].size(0));
     MaxN = max(MaxN, B[i].size(0));
@@ -519,12 +519,12 @@ at::Tensor bf16bf16bf16_grouped_stacked(
     at::Tensor M_sizes) {
   // Check that input datatypes are valid.
   // First confirm that there are the same number of groups in all inputs.
-  int group_count = M_sizes.size(0);
+  int64_t group_count = M_sizes.size(0);
   // X is expected to be shape [total_M, K].
-  int total_M = X.size(0);
+  int64_t total_M = X.size(0);
   // W is expected to be shape [G, N, K].
-  int N = W.size(1);
-  int K = X.size(1);
+  int64_t N = W.size(1);
+  int64_t K = X.size(1);
   TORCH_CHECK(W.size(0) == group_count,
       "All inputs must have the same number of groups.");