Skip to content

Commit 9548cac

Browse files
aakhundov authored and facebook-github-bot committed
Add meta-functions for asynchronous_*_cumsum ops (pytorch#2028)
Summary: ATT Reviewed By: xw285cornell Differential Revision: D49467255
1 parent b568c53 commit 9548cac

File tree

2 files changed

+25
-11
lines changed

2 files changed

+25
-11
lines changed

fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1126,17 +1126,6 @@ Tensor asynchronous_complete_cumsum_cpu(const Tensor& t_in) {
11261126
return output;
11271127
}
11281128

1129-
// Meta kernel for asynchronous_complete_cumsum: performs shape propagation
// only — no prefix sum is actually computed on the meta backend.
// A 1-D input of numel N produces a 1-D output of N + 1 elements; a 2-D
// input of shape (B, M) produces an output of shape (B, M + 1).
Tensor asynchronous_complete_cumsum_meta(const Tensor& t_in) {
  const auto ndim = t_in.dim();
  TORCH_CHECK(ndim == 1 || ndim == 2);
  if (ndim == 1) {
    // Complete cumsum of a length-N vector has N + 1 entries (leading zero).
    return at::zeros_symint({t_in.sym_numel() + 1}, t_in.options());
  }
  // 2-D case: one extra column per row.
  return at::zeros_symint(
      {t_in.sym_size(0), t_in.sym_size(1) + 1}, t_in.options());
}
1139-
11401129
template <typename index_t, typename scalar_t>
11411130
void reorder_batched_ad_lengths_(
11421131
const Tensor& cat_ad_lengths,

fbgemm_gpu/src/sparse_ops/sparse_ops_meta.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,17 @@ using Tensor = at::Tensor;
2020

2121
namespace fbgemm_gpu {
2222

23+
// Meta implementation of asynchronous_complete_cumsum: produces an output
// tensor with the correct symbolic shape without running the actual cumsum
// (meta tensors carry no data, so zeros vs. empty is immaterial here).
// A 1-D input of numel N yields a 1-D output of N + 1 elements; a 2-D input
// of shape (B, M) yields an output of shape (B, M + 1).
Tensor asynchronous_complete_cumsum_meta(const Tensor& t_in) {
  const auto num_dims = t_in.dim();
  // Include a diagnostic message so a bad rank fails with context instead of
  // a bare condition string.
  TORCH_CHECK(
      num_dims == 1 || num_dims == 2,
      "asynchronous_complete_cumsum expects a 1D or 2D input, got ",
      num_dims,
      "D");

  auto output = num_dims == 1
      ? at::zeros_symint({t_in.sym_numel() + 1}, t_in.options())
      : at::zeros_symint(
            {t_in.sym_size(0), t_in.sym_size(1) + 1}, t_in.options());
  return output;
}
33+
2334
namespace {
2435

2536
Tensor pack_segments_forward_meta(
@@ -62,6 +73,14 @@ Tensor batched_unary_embeddings_forward_meta(
6273
return at::empty_symint({N, B, T}, weight.options());
6374
}
6475

76+
// Meta kernel for asynchronous_inclusive_cumsum: the inclusive cumsum
// preserves the input shape, so just allocate an uninitialized meta tensor
// with the same symbolic sizes and options.
Tensor asynchronous_inclusive_cumsum_meta(const Tensor& t_in) {
  const auto out_sizes = t_in.sym_sizes();
  return at::empty_symint(out_sizes, t_in.options());
}
79+
80+
// Meta kernel for asynchronous_exclusive_cumsum: output shape matches the
// input exactly; no values are computed on the meta backend.
Tensor asynchronous_exclusive_cumsum_meta(const Tensor& t_in) {
  const auto& opts = t_in.options();
  return at::empty_symint(t_in.sym_sizes(), opts);
}
83+
6584
} // namespace
6685

6786
} // namespace fbgemm_gpu
@@ -71,6 +90,12 @@ TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
7190
m.impl(
7291
"pack_segments_backward",
7392
TORCH_FN(fbgemm_gpu::pack_segments_backward_meta));
93+
m.impl(
94+
"asynchronous_inclusive_cumsum",
95+
TORCH_FN(fbgemm_gpu::asynchronous_inclusive_cumsum_meta));
96+
m.impl(
97+
"asynchronous_exclusive_cumsum",
98+
TORCH_FN(fbgemm_gpu::asynchronous_exclusive_cumsum_meta));
7499
m.impl(
75100
"asynchronous_complete_cumsum",
76101
TORCH_FN(fbgemm_gpu::asynchronous_complete_cumsum_meta));

0 commit comments

Comments
 (0)