Commit 2089ca0

sryap authored and facebook-github-bot committed
Update asynchronous_complete_cumsum to support 2D inputs (pytorch#1573)
Summary:
Pull Request resolved: pytorch#1573

Before this diff, `asynchronous_complete_cumsum` only supported 1D inputs. This diff adds 2D input support.

Reviewed By: yinghai, jianyuh

Differential Revision: D42956351

fbshipit-source-id: 5d8c6c600a95df572d8535175de131b493b48b43
1 parent 84fe62b commit 2089ca0
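
For orientation, here is a minimal usage sketch of the extended operator. It assumes fbgemm_gpu is installed (importing it registers the torch.ops.fbgemm namespace) and a CUDA device is available; the shapes and values are illustrative, not taken from this diff.

    import torch
    import fbgemm_gpu  # noqa: F401  (loads the torch.ops.fbgemm operators)

    # A batch of 4 length vectors with 8 entries each.
    lengths = torch.randint(low=0, high=100, size=(4, 8), dtype=torch.int64).cuda()

    # With this diff the op also accepts 2D inputs and returns a (4, 9) tensor of
    # complete (zero-prefixed) cumulative sums, one row per input vector.
    offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)

    # Cross-check against a plain torch.cumsum reference.
    ref = torch.cat([lengths.new_zeros(4, 1), torch.cumsum(lengths, dim=-1)], dim=-1)
    torch.testing.assert_close(offsets.cpu(), ref.cpu())

For 1D inputs the behavior is unchanged: the op still returns a length-(n + 1) tensor whose first element is zero.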

File tree

4 files changed: +274 −32 lines

fbgemm_gpu/bench/sparse_ops_benchmark.py
fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh
fbgemm_gpu/src/sparse_ops.cu
fbgemm_gpu/test/sparse_ops_test.py

fbgemm_gpu/bench/sparse_ops_benchmark.py

Lines changed: 41 additions & 0 deletions

@@ -312,5 +312,46 @@ def gen_inverse_index(curr_size: int, final_size: int) -> np.array:
     )
 
 
+@cli.command()
+@click.option("--num-vecs", default=2048)
+@click.option("--num-entries-per-vec", default=1024)
+@click.option("--dtype", type=str, default="long")
+def asynchronous_complete_cumsum_2d_bench(
+    num_vecs: int,
+    num_entries_per_vec: int,
+    dtype: str,
+) -> None:
+    # Reference code from TorchRec https://github.com/pytorch/torchrec/pull/332
+    @torch.jit.script
+    def asynchronous_complete_cumsum_2d_ref(lengths: torch.Tensor) -> torch.Tensor:
+        (f, b) = lengths.shape
+        offsets_0 = lengths.new_zeros((f, 1))
+        offsets_1 = torch.cumsum(lengths, dim=-1).to(lengths.dtype)
+        offsets = torch.cat([offsets_0, offsets_1], dim=-1)
+        return offsets
+
+    assert dtype == "int" or dtype == "long", "Only int and long are supported"
+    index_dtype = torch.int64 if dtype == "long" else torch.int32
+
+    x = torch.randint(low=0, high=100, size=(num_vecs, num_entries_per_vec)).type(
+        index_dtype
+    )
+    x = x.cuda()
+
+    time_ref, _ = benchmark_torch_function(
+        asynchronous_complete_cumsum_2d_ref, (x,), num_warmups=100, iters=1000
+    )
+
+    time, _ = benchmark_torch_function(
+        torch.ops.fbgemm.asynchronous_complete_cumsum, (x,), num_warmups=100, iters=1000
+    )
+
+    logging.info(
+        f"asynchronous_complete_cumsum_2d_bench: input shape {x.shape}, dtype {dtype}"
+    )
+    logging.info(f"ref time: {time_ref:.5f} sec")
+    logging.info(f"fbgemm_gpu time: {time:.5f} sec")
+
+
 if __name__ == "__main__":
     cli()
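
The new benchmark compares the TorchRec jit-scripted reference against the fused fbgemm_gpu op. If you want a rough standalone timing outside the click CLI and the repo's benchmark_torch_function helper, a minimal hand-rolled sketch could look like the following (the time_cuda helper, shapes, and iteration counts are my own illustrative choices, not fbgemm_gpu API):

    import torch
    import fbgemm_gpu  # noqa: F401

    def time_cuda(fn, arg, iters=1000, warmups=100):
        # Average seconds per call, measured with CUDA events.
        for _ in range(warmups):
            fn(arg)
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(iters):
            fn(arg)
        end.record()
        torch.cuda.synchronize()
        return start.elapsed_time(end) / 1000.0 / iters  # elapsed_time is in ms

    x = torch.randint(0, 100, (2048, 1024), dtype=torch.int64).cuda()
    print(f"fbgemm_gpu: {time_cuda(torch.ops.fbgemm.asynchronous_complete_cumsum, x):.5f} sec")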

fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh

Lines changed: 33 additions & 3 deletions

@@ -3054,14 +3054,44 @@ class FixedDivisor {
   int shift_;
 };
 
+/**
+ * inclusive_sum_scan_kernel performs intra- and inter-thread-block sum scan
+ * (i.e., prefix sum scan). We use cub::BlockScan to do the inclusive sum
+ * within a thread block and use a waterfall sync method to perform the prefix
+ * sum across thread blocks.
+ *
+ * @param arr an array of input values. Its length must be fixed to
+ *            ITEMS_PER_THREAD
+ * @param temp_storage a shared memory struct for cub::BlockScan
+ * @param block_flags a global flag buffer for inter-block sync (must be
+ *                    initialized with zeros)
+ * @param block_sums a global sum buffer for inter-block sync
+ * @param block_prev a shared memory pointer for sharing the sum from the
+ *                   previous block within a block
+ * @param num_entries_per_block the number of input entries for this block
+ * @param block_id a relative thread block ID (the first block that contains
+ *                 the first set of input entries has block_id = 0)
+ * @param is_multi_block a boolean indicating whether the inter-block sum scan
+ *                       has to be performed
+ * @param signal If the value of block_flags of the previous block is equal to
+ *               signal, it means that the previous block has written its sum
+ *               to block_sums. We have thread blocks increment the value of
+ *               block_flags by one after they write their sums to block_sums.
+ *               We increment the flag instead of setting the flag to a single
+ *               value to support multiple sequential inclusive_sum_scan_kernel
+ *               calls (e.g., in the AUC kernel). signal is the order in which
+ *               inclusive_sum_scan_kernel is called. Since we initialize
+ *               block_flags with zeros, the signal of the first call should be
+ *               one.
+ */
 template <typename scalar_t, int ITEMS_PER_THREAD, int NUM_THREADS_PER_BLOCK>
 __inline__ __device__ void inclusive_sum_scan_kernel(
     scalar_t (&arr)[ITEMS_PER_THREAD],
     typename cub::BlockScan<scalar_t, NUM_THREADS_PER_BLOCK>::TempStorage&
         temp_storage,
-    int* block_flags, // global flags for inter-block sync
-    scalar_t* block_sums, // global sums for inter-block sync
-    scalar_t* block_prev, // shared memory for previous sum sync within a block
+    int* block_flags,
+    scalar_t* block_sums,
+    scalar_t* block_prev,
     const int num_entries_per_block,
     const int block_id,
     const bool is_multi_block,
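
Before the CUDA changes below, a rough CPU model may help: plain NumPy, illustration only (nothing here is fbgemm_gpu API), of the per-vector, block-decomposed complete cumsum that the documented inclusive_sum_scan_kernel and the new batched_complete_cumsum_kernel in sparse_ops.cu implement. A sequential carry stands in for the block_flags/block_sums waterfall hand-off performed on the GPU.

    import numpy as np

    def batched_complete_cumsum_model(x: np.ndarray, max_entries_per_block: int = 512) -> np.ndarray:
        # x: (num_vecs, num_entries); returns (num_vecs, num_entries + 1) with a leading zero per row.
        num_vecs, num_entries = x.shape
        out = np.zeros((num_vecs, num_entries + 1), dtype=x.dtype)
        for v in range(num_vecs):
            carry = 0  # stands in for the sum handed off via block_sums / block_prev
            for start in range(0, num_entries, max_entries_per_block):
                block = x[v, start:start + max_entries_per_block]
                scanned = np.cumsum(block) + carry  # per-block inclusive scan (cub::BlockScan on the GPU)
                out[v, start + 1:start + 1 + block.size] = scanned
                carry = scanned[-1]  # the next block waits for this value (waterfall sync on the GPU)
        return out

On the GPU, one thread block is assigned to each (vector, block) pair (vec_id = blockIdx.x / num_blocks, block_id = blockIdx.x % num_blocks), and later blocks spin on block_flags until the previous block has published its sum instead of looping sequentially.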

fbgemm_gpu/src/sparse_ops.cu

Lines changed: 173 additions & 29 deletions

@@ -270,6 +270,71 @@ Tensor asynchronous_exclusive_cumsum_gpu(const Tensor& t_in) {
   return t_out;
 }
 
+template <
+    typename scalar_t,
+    int ITEMS_PER_THREAD,
+    int NUM_THREADS_PER_BLOCK,
+    int MAX_ENTRIES_PER_BLOCK>
+__global__
+__launch_bounds__(NUM_THREADS_PER_BLOCK) void batched_complete_cumsum_kernel(
+    const scalar_t* __restrict__ input,
+    const int32_t num_entries,
+    const int32_t last_block_num_entries,
+    const int32_t padded_num_entries_per_block,
+    const int32_t num_blocks,
+    int32_t* __restrict__ block_flags,
+    scalar_t* __restrict__ block_sums,
+    scalar_t* __restrict__ output) {
+  typedef cub::BlockScan<scalar_t, NUM_THREADS_PER_BLOCK> BlockScan;
+  __shared__ typename BlockScan::TempStorage bs_temp_storage;
+  __shared__ scalar_t block_prev;
+
+  scalar_t arr[ITEMS_PER_THREAD];
+
+  const int32_t block_id = blockIdx.x % num_blocks;
+  const int32_t vec_id = blockIdx.x / num_blocks;
+
+  const int num_entries_per_block = block_id == num_blocks - 1
+      ? last_block_num_entries
+      : MAX_ENTRIES_PER_BLOCK;
+  const int input_offset = vec_id * num_entries;
+  const int output_offset = vec_id * (num_entries + 1);
+  const int flag_offset = vec_id * num_blocks;
+  const int block_offset = block_id * padded_num_entries_per_block;
+  const bool is_multi_block = num_blocks > 1;
+  const int section_offset = ITEMS_PER_THREAD * threadIdx.x;
+
+  // Load input entries into array
+  for (int i = 0;
+       i < ITEMS_PER_THREAD && section_offset + i < num_entries_per_block;
+       i++) {
+    arr[i] = input[input_offset + block_offset + section_offset + i];
+  }
+
+  inclusive_sum_scan_kernel<scalar_t, ITEMS_PER_THREAD, NUM_THREADS_PER_BLOCK>(
+      arr,
+      bs_temp_storage,
+      is_multi_block ? block_flags + flag_offset : nullptr,
+      is_multi_block ? block_sums + flag_offset : nullptr,
+      is_multi_block ? &block_prev : nullptr,
+      num_entries_per_block,
+      block_id,
+      is_multi_block,
+      /*signal=*/1);
+
+  // Write zero to the first entry of each vector
+  if (block_id == 0 && threadIdx.x == 0) {
+    output[output_offset] = 0;
+  }
+
+  // Load results to output
+  for (int i = 0;
+       i < ITEMS_PER_THREAD && section_offset + i < num_entries_per_block;
+       i++) {
+    output[output_offset + block_offset + section_offset + i + 1] = arr[i];
+  }
+}
+
 Tensor asynchronous_complete_cumsum_gpu(const Tensor& t_in) {
   TENSOR_ON_CUDA_GPU(t_in);
 
@@ -278,35 +343,114 @@ Tensor asynchronous_complete_cumsum_gpu(const Tensor& t_in) {
   size_t temp_storage_bytes = 0;
   TORCH_CHECK(t_in.is_contiguous());
   TORCH_CHECK(t_in.dtype() == at::kInt || t_in.dtype() == at::kLong);
-  // CUB only handles up to INT_MAX elements.
-  TORCH_CHECK(t_in.numel() < std::numeric_limits<int32_t>::max());
-  TORCH_CHECK(t_in.dim() == 1);
-  auto t_out = at::empty({t_in.numel() + 1}, t_in.options());
-  t_out[0].zero_();
-  AT_DISPATCH_INDEX_TYPES(
-      t_in.scalar_type(), "cub_inclusive_sum_wrapper1", [&] {
-        AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceScan::InclusiveSum(
-            nullptr,
-            temp_storage_bytes,
-            t_in.data_ptr<index_t>(),
-            t_out.data_ptr<index_t>() + 1,
-            t_in.numel(),
-            at::cuda::getCurrentCUDAStream()));
-      });
-  auto temp_storage = at::empty(
-      {static_cast<int64_t>(temp_storage_bytes)},
-      t_in.options().dtype(at::kByte));
-  AT_DISPATCH_INDEX_TYPES(
-      t_in.scalar_type(), "cub_inclusive_sum_wrapper2", [&] {
-        AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceScan::InclusiveSum(
-            temp_storage.data_ptr(),
-            temp_storage_bytes,
-            t_in.data_ptr<index_t>(),
-            t_out.data_ptr<index_t>() + 1,
-            t_in.numel(),
-            at::cuda::getCurrentCUDAStream()));
-      });
-  return t_out;
+  TORCH_CHECK(t_in.dim() == 1 || t_in.dim() == 2);
+  if (t_in.dim() == 1) {
+    // CUB only handles up to INT_MAX elements.
+    TORCH_CHECK(t_in.numel() < std::numeric_limits<int32_t>::max());
+    auto t_out = at::empty({t_in.numel() + 1}, t_in.options());
+    t_out[0].zero_();
+    AT_DISPATCH_INDEX_TYPES(
+        t_in.scalar_type(), "cub_inclusive_sum_wrapper1", [&] {
+          AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceScan::InclusiveSum(
+              nullptr,
+              temp_storage_bytes,
+              t_in.data_ptr<index_t>(),
+              t_out.data_ptr<index_t>() + 1,
+              t_in.numel(),
+              at::cuda::getCurrentCUDAStream()));
+        });
+    auto temp_storage = at::empty(
+        {static_cast<int64_t>(temp_storage_bytes)},
+        t_in.options().dtype(at::kByte));
+    AT_DISPATCH_INDEX_TYPES(
+        t_in.scalar_type(), "cub_inclusive_sum_wrapper2", [&] {
+          AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceScan::InclusiveSum(
+              temp_storage.data_ptr(),
+              temp_storage_bytes,
+              t_in.data_ptr<index_t>(),
+              t_out.data_ptr<index_t>() + 1,
+              t_in.numel(),
+              at::cuda::getCurrentCUDAStream()));
+        });
+    return t_out;
+  } else {
+    // Fix NUM_THREADS_PER_BLOCK because of CUB
+    constexpr int32_t MAX_ENTRIES_PER_BLOCK = 512;
+    constexpr int32_t NUM_THREADS_PER_BLOCK = 256;
+    const int32_t LOG_NUM_THREADS = std::log2(NUM_THREADS_PER_BLOCK);
+
+    // Enforce the same constraint as CUB
+    const auto num_vecs = t_in.size(0);
+    const auto num_entries = t_in.size(1);
+    TORCH_CHECK(num_entries < std::numeric_limits<int32_t>::max());
+
+    auto t_out = at::empty({num_vecs, num_entries + 1}, t_in.options());
+
+    const auto num_blocks = div_round_up(num_entries, MAX_ENTRIES_PER_BLOCK);
+    const int num_entries_per_block =
+        num_blocks > 1 ? MAX_ENTRIES_PER_BLOCK : num_entries;
+    // rounded_num_entries_per_block is either 0 or 256
+    const int rounded_num_entries_per_block =
+        (num_entries_per_block >> LOG_NUM_THREADS) << LOG_NUM_THREADS;
+    // padded_num_entries_per_block is either 256 or 512
+    const int padded_num_entries_per_block = rounded_num_entries_per_block +
+        (rounded_num_entries_per_block != num_entries_per_block
+             ? NUM_THREADS_PER_BLOCK
+             : 0);
+    const int items_per_thread =
+        padded_num_entries_per_block / NUM_THREADS_PER_BLOCK;
+    const int last_block_num_entries =
+        num_entries - ((num_blocks - 1) * MAX_ENTRIES_PER_BLOCK);
+    const auto grid_size = num_blocks * num_vecs;
+
+    at::Tensor block_flags;
+    at::Tensor block_sums;
+    if (num_blocks > 1) {
+      block_flags = at::zeros({grid_size}, t_in.options().dtype(at::kInt));
+      block_sums = at::empty({grid_size}, t_out.options());
+    }
+
+    auto max_smem_size =
+        at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock;
+
+#define INVOKE_BATCHED_COMPLETE_CUMSUM_KERNEL(ITEMS_PER_THREAD)       \
+  batched_complete_cumsum_kernel<                                     \
+      index_t,                                                        \
+      ITEMS_PER_THREAD,                                               \
+      NUM_THREADS_PER_BLOCK,                                          \
+      MAX_ENTRIES_PER_BLOCK>                                          \
+      <<<grid_size,                                                   \
+         NUM_THREADS_PER_BLOCK,                                       \
+         0,                                                           \
+         at::cuda::getCurrentCUDAStream()>>>(                         \
+          t_in.data_ptr<index_t>(),                                   \
+          num_entries,                                                \
+          last_block_num_entries,                                     \
+          padded_num_entries_per_block,                               \
+          num_blocks,                                                 \
+          num_blocks > 1 ? block_flags.data_ptr<int32_t>() : nullptr, \
+          num_blocks > 1 ? block_sums.data_ptr<index_t>() : nullptr,  \
+          t_out.data_ptr<index_t>())
+
+    AT_DISPATCH_INDEX_TYPES(
+        t_in.scalar_type(), "batched_complete_cumsum_kernel_warpper", [&] {
+          typedef cub::BlockScan<index_t, NUM_THREADS_PER_BLOCK> BlockScan;
+          TORCH_CHECK(
+              sizeof(BlockScan::TempStorage) + sizeof(index_t) <=
+              max_smem_size);
+          TORCH_CHECK(items_per_thread == 1 || items_per_thread == 2)
+          if (items_per_thread == 1) {
+            INVOKE_BATCHED_COMPLETE_CUMSUM_KERNEL(1);
+          } else {
+            INVOKE_BATCHED_COMPLETE_CUMSUM_KERNEL(2);
+          }
+          C10_CUDA_KERNEL_LAUNCH_CHECK();
+        });
+
+#undef INVOKE_BATCHED_COMPLETE_CUMSUM_KERNEL
+
+    return t_out;
+  }
 }
 
 // Kernel for permuting the indices and weights. Used for permutation of sparse

fbgemm_gpu/test/sparse_ops_test.py

Lines changed: 27 additions & 0 deletions

@@ -558,6 +558,33 @@ def test_cumsum(self, n: int, long_index: bool) -> None:
             zc.cpu(),
         )
 
+    @unittest.skipIf(*gpu_unavailable)
+    # pyre-ignore [56]
+    @given(
+        n=st.integers(min_value=1, max_value=600),
+        b=st.integers(min_value=1, max_value=10),
+        long_index=st.booleans(),
+    )
+    @settings(verbosity=Verbosity.verbose, max_examples=20, deadline=None)
+    def test_asynchronous_complete_cumsum_2d(
+        self, n: int, b: int, long_index: bool
+    ) -> None:
+        index_dtype = torch.int64 if long_index else torch.int32
+        np_index_dtype = np.int64 if long_index else np.int32
+
+        x = torch.randint(low=0, high=100, size=(b, n)).type(index_dtype)
+        x = x.cuda()
+        zc = torch.ops.fbgemm.asynchronous_complete_cumsum(x)
+        zeros = torch.zeros(b, 1)
+        torch.testing.assert_close(
+            torch.from_numpy(
+                np.cumsum(torch.concat([zeros, x.cpu()], dim=1).numpy(), axis=1).astype(
+                    np_index_dtype
+                )
+            ),
+            zc.cpu(),
+        )
+
     # pyre-ignore [56]
     @given(
         N=st.integers(min_value=1, max_value=20),
