@@ -270,6 +270,71 @@ Tensor asynchronous_exclusive_cumsum_gpu(const Tensor& t_in) {
   return t_out;
 }
 
+template <
+    typename scalar_t,
+    int ITEMS_PER_THREAD,
+    int NUM_THREADS_PER_BLOCK,
+    int MAX_ENTRIES_PER_BLOCK>
+__global__
+__launch_bounds__(NUM_THREADS_PER_BLOCK) void batched_complete_cumsum_kernel(
+    const scalar_t* __restrict__ input,
+    const int32_t num_entries,
+    const int32_t last_block_num_entries,
+    const int32_t padded_num_entries_per_block,
+    const int32_t num_blocks,
+    int32_t* __restrict__ block_flags,
+    scalar_t* __restrict__ block_sums,
+    scalar_t* __restrict__ output) {
+  typedef cub::BlockScan<scalar_t, NUM_THREADS_PER_BLOCK> BlockScan;
+  __shared__ typename BlockScan::TempStorage bs_temp_storage;
+  __shared__ scalar_t block_prev;
+
+  scalar_t arr[ITEMS_PER_THREAD];
+
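+  // Grid layout: gridDim.x == num_blocks * num_vecs, so each thread block
+  // scans one MAX_ENTRIES_PER_BLOCK-wide chunk (block_id) of one input row
+  // (vec_id).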
+  const int32_t block_id = blockIdx.x % num_blocks;
+  const int32_t vec_id = blockIdx.x / num_blocks;
+
+  const int num_entries_per_block = block_id == num_blocks - 1
+      ? last_block_num_entries
+      : MAX_ENTRIES_PER_BLOCK;
+  const int input_offset = vec_id * num_entries;
+  const int output_offset = vec_id * (num_entries + 1);
+  const int flag_offset = vec_id * num_blocks;
+  const int block_offset = block_id * padded_num_entries_per_block;
+  const bool is_multi_block = num_blocks > 1;
+  const int section_offset = ITEMS_PER_THREAD * threadIdx.x;
+
+  // Load input entries into array
+  for (int i = 0;
+       i < ITEMS_PER_THREAD && section_offset + i < num_entries_per_block;
+       i++) {
+    arr[i] = input[input_offset + block_offset + section_offset + i];
+  }
+
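+  // Block-wide inclusive scan of arr. In the multi-block case, per-row slices
+  // of block_flags/block_sums plus the shared block_prev presumably let each
+  // block publish its total and pick up the running prefix of the preceding
+  // blocks (see inclusive_sum_scan_kernel for the exact handshake).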
+  inclusive_sum_scan_kernel<scalar_t, ITEMS_PER_THREAD, NUM_THREADS_PER_BLOCK>(
+      arr,
+      bs_temp_storage,
+      is_multi_block ? block_flags + flag_offset : nullptr,
+      is_multi_block ? block_sums + flag_offset : nullptr,
+      is_multi_block ? &block_prev : nullptr,
+      num_entries_per_block,
+      block_id,
+      is_multi_block,
+      /*signal=*/1);
+
+  // Write zero to the first entry of each vector
+  if (block_id == 0 && threadIdx.x == 0) {
+    output[output_offset] = 0;
+  }
+
+  // Write results to output
+  for (int i = 0;
+       i < ITEMS_PER_THREAD && section_offset + i < num_entries_per_block;
+       i++) {
+    output[output_offset + block_offset + section_offset + i + 1] = arr[i];
+  }
+}
+
 Tensor asynchronous_complete_cumsum_gpu(const Tensor& t_in) {
   TENSOR_ON_CUDA_GPU(t_in);
 
@@ -278,35 +343,114 @@ Tensor asynchronous_complete_cumsum_gpu(const Tensor& t_in) {
   size_t temp_storage_bytes = 0;
   TORCH_CHECK(t_in.is_contiguous());
   TORCH_CHECK(t_in.dtype() == at::kInt || t_in.dtype() == at::kLong);
-  // CUB only handles up to INT_MAX elements.
-  TORCH_CHECK(t_in.numel() < std::numeric_limits<int32_t>::max());
-  TORCH_CHECK(t_in.dim() == 1);
-  auto t_out = at::empty({t_in.numel() + 1}, t_in.options());
-  t_out[0].zero_();
-  AT_DISPATCH_INDEX_TYPES(
-      t_in.scalar_type(), "cub_inclusive_sum_wrapper1", [&] {
-        AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceScan::InclusiveSum(
-            nullptr,
-            temp_storage_bytes,
-            t_in.data_ptr<index_t>(),
-            t_out.data_ptr<index_t>() + 1,
-            t_in.numel(),
-            at::cuda::getCurrentCUDAStream()));
-      });
-  auto temp_storage = at::empty(
-      {static_cast<int64_t>(temp_storage_bytes)},
-      t_in.options().dtype(at::kByte));
-  AT_DISPATCH_INDEX_TYPES(
-      t_in.scalar_type(), "cub_inclusive_sum_wrapper2", [&] {
-        AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceScan::InclusiveSum(
-            temp_storage.data_ptr(),
-            temp_storage_bytes,
-            t_in.data_ptr<index_t>(),
-            t_out.data_ptr<index_t>() + 1,
-            t_in.numel(),
-            at::cuda::getCurrentCUDAStream()));
-      });
-  return t_out;
+  TORCH_CHECK(t_in.dim() == 1 || t_in.dim() == 2);
+  if (t_in.dim() == 1) {
+    // CUB only handles up to INT_MAX elements.
+    TORCH_CHECK(t_in.numel() < std::numeric_limits<int32_t>::max());
+    auto t_out = at::empty({t_in.numel() + 1}, t_in.options());
+    t_out[0].zero_();
+    AT_DISPATCH_INDEX_TYPES(
+        t_in.scalar_type(), "cub_inclusive_sum_wrapper1", [&] {
+          AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceScan::InclusiveSum(
+              nullptr,
+              temp_storage_bytes,
+              t_in.data_ptr<index_t>(),
+              t_out.data_ptr<index_t>() + 1,
+              t_in.numel(),
+              at::cuda::getCurrentCUDAStream()));
+        });
+    auto temp_storage = at::empty(
+        {static_cast<int64_t>(temp_storage_bytes)},
+        t_in.options().dtype(at::kByte));
+    AT_DISPATCH_INDEX_TYPES(
+        t_in.scalar_type(), "cub_inclusive_sum_wrapper2", [&] {
+          AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceScan::InclusiveSum(
+              temp_storage.data_ptr(),
+              temp_storage_bytes,
+              t_in.data_ptr<index_t>(),
+              t_out.data_ptr<index_t>() + 1,
+              t_in.numel(),
+              at::cuda::getCurrentCUDAStream()));
+        });
+    return t_out;
+  } else {
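+    // 2-D input of shape [num_vecs, num_entries]: each output row holds
+    // num_entries + 1 values, starting at 0 and ending with that row's total.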
+    // NUM_THREADS_PER_BLOCK is fixed because CUB's BlockScan takes the block
+    // size as a compile-time template parameter.
+    constexpr int32_t MAX_ENTRIES_PER_BLOCK = 512;
+    constexpr int32_t NUM_THREADS_PER_BLOCK = 256;
+    const int32_t LOG_NUM_THREADS = std::log2(NUM_THREADS_PER_BLOCK);
+
+    // Enforce the same element-count constraint as CUB
+    const auto num_vecs = t_in.size(0);
+    const auto num_entries = t_in.size(1);
+    TORCH_CHECK(num_entries < std::numeric_limits<int32_t>::max());
+
+    auto t_out = at::empty({num_vecs, num_entries + 1}, t_in.options());
+
+    const auto num_blocks = div_round_up(num_entries, MAX_ENTRIES_PER_BLOCK);
+    const int num_entries_per_block =
+        num_blocks > 1 ? MAX_ENTRIES_PER_BLOCK : num_entries;
+    // rounded_num_entries_per_block is 0, 256, or 512
+    const int rounded_num_entries_per_block =
+        (num_entries_per_block >> LOG_NUM_THREADS) << LOG_NUM_THREADS;
+    // padded_num_entries_per_block is either 256 or 512
+    const int padded_num_entries_per_block = rounded_num_entries_per_block +
+        (rounded_num_entries_per_block != num_entries_per_block
+             ? NUM_THREADS_PER_BLOCK
+             : 0);
+    const int items_per_thread =
+        padded_num_entries_per_block / NUM_THREADS_PER_BLOCK;
+    const int last_block_num_entries =
+        num_entries - ((num_blocks - 1) * MAX_ENTRIES_PER_BLOCK);
+    const auto grid_size = num_blocks * num_vecs;
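+    // Worked example (illustrative): num_entries = 600 gives num_blocks = 2,
+    // num_entries_per_block = 512, padded_num_entries_per_block = 512,
+    // items_per_thread = 2, last_block_num_entries = 88; num_entries = 100
+    // gives num_blocks = 1, padded_num_entries_per_block = 256,
+    // items_per_thread = 1, last_block_num_entries = 100.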
+
+    at::Tensor block_flags;
+    at::Tensor block_sums;
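+    // block_flags is zero-initialized so a value of 0 presumably means "block
+    // total not yet published" (the kernel passes /*signal=*/1 as the ready
+    // value); block_sums just needs space for each block's total.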
+    if (num_blocks > 1) {
+      block_flags = at::zeros({grid_size}, t_in.options().dtype(at::kInt));
+      block_sums = at::empty({grid_size}, t_out.options());
+    }
+
+    auto max_smem_size =
+        at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock;
+
+#define INVOKE_BATCHED_COMPLETE_CUMSUM_KERNEL(ITEMS_PER_THREAD)       \
+  batched_complete_cumsum_kernel<                                     \
+      index_t,                                                        \
+      ITEMS_PER_THREAD,                                               \
+      NUM_THREADS_PER_BLOCK,                                          \
+      MAX_ENTRIES_PER_BLOCK>                                          \
+      <<<grid_size,                                                   \
+         NUM_THREADS_PER_BLOCK,                                       \
+         0,                                                           \
+         at::cuda::getCurrentCUDAStream()>>>(                         \
+          t_in.data_ptr<index_t>(),                                   \
+          num_entries,                                                \
+          last_block_num_entries,                                     \
+          padded_num_entries_per_block,                               \
+          num_blocks,                                                 \
+          num_blocks > 1 ? block_flags.data_ptr<int32_t>() : nullptr, \
+          num_blocks > 1 ? block_sums.data_ptr<index_t>() : nullptr,  \
+          t_out.data_ptr<index_t>())
+
+    AT_DISPATCH_INDEX_TYPES(
+        t_in.scalar_type(), "batched_complete_cumsum_kernel_wrapper", [&] {
+          typedef cub::BlockScan<index_t, NUM_THREADS_PER_BLOCK> BlockScan;
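+          // Static shared memory used by the kernel: the BlockScan temp
+          // storage plus one index_t for block_prev; verify it fits in the
+          // device's per-block shared memory before launching.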
+          TORCH_CHECK(
+              sizeof(BlockScan::TempStorage) + sizeof(index_t) <=
+              max_smem_size);
+          TORCH_CHECK(items_per_thread == 1 || items_per_thread == 2);
+          if (items_per_thread == 1) {
+            INVOKE_BATCHED_COMPLETE_CUMSUM_KERNEL(1);
+          } else {
+            INVOKE_BATCHED_COMPLETE_CUMSUM_KERNEL(2);
+          }
+          C10_CUDA_KERNEL_LAUNCH_CHECK();
+        });
+
+#undef INVOKE_BATCHED_COMPLETE_CUMSUM_KERNEL
+
+    return t_out;
+  }
 }
 
 // Kernel for permuting the indices and weights. Used for permutation of sparse