@@ -20,6 +20,17 @@ using Tensor = at::Tensor;
 
 namespace fbgemm_gpu {
 
+Tensor asynchronous_complete_cumsum_meta(const Tensor& t_in) {
+  const auto num_dims = t_in.dim();
+  TORCH_CHECK(num_dims == 1 || num_dims == 2);
+
+  auto output = num_dims == 1
+      ? at::zeros_symint({t_in.sym_numel() + 1}, t_in.options())
+      : at::zeros_symint(
+            {t_in.sym_size(0), t_in.sym_size(1) + 1}, t_in.options());
+  return output;
+}
+
 namespace {
 
 Tensor pack_segments_forward_meta(
@@ -62,13 +73,27 @@ Tensor batched_unary_embeddings_forward_meta(
   return at::empty_symint({N, B, T}, weight.options());
 }
 
+Tensor asynchronous_inclusive_cumsum_meta(const Tensor& t_in) {
+  return at::empty_symint(t_in.sym_sizes(), t_in.options());
+}
+
+Tensor asynchronous_exclusive_cumsum_meta(const Tensor& t_in) {
+  return at::empty_symint(t_in.sym_sizes(), t_in.options());
+}
+
 } // namespace
 
 } // namespace fbgemm_gpu
 
 TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
   m.impl("pack_segments", TORCH_FN(fbgemm_gpu::pack_segments_forward_meta));
   m.impl("unpack_segments", TORCH_FN(fbgemm_gpu::pack_segments_backward_meta));
+  m.impl(
+      "asynchronous_inclusive_cumsum",
+      TORCH_FN(fbgemm_gpu::asynchronous_inclusive_cumsum_meta));
+  m.impl(
+      "asynchronous_exclusive_cumsum",
+      TORCH_FN(fbgemm_gpu::asynchronous_exclusive_cumsum_meta));
   m.impl(
       "asynchronous_complete_cumsum",
       TORCH_FN(fbgemm_gpu::asynchronous_complete_cumsum_meta));
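For context beyond the diff itself: these Meta-dispatch-key kernels compute only output shapes and dtypes (note the sym_numel/sym_size calls, which keep shapes symbolic), so shape propagation through fbgemm ops can happen without launching any GPU work. The sketch below is a hypothetical standalone harness, not part of this commit; it assumes fbgemm_gpu is linked so the TORCH_LIBRARY_IMPL registrations above have executed, and it exercises one of the new kernels by calling the op through the PyTorch dispatcher on a meta tensor.

#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <iostream>

int main() {
  // A meta tensor carries shape/dtype metadata but no storage.
  auto t = at::empty(
      {8}, at::TensorOptions().dtype(at::kLong).device(c10::DeviceType::Meta));

  // Look up the op by name; with a meta input the dispatcher routes the
  // call to the asynchronous_complete_cumsum_meta kernel registered above.
  auto op = c10::Dispatcher::singleton()
                .findSchemaOrThrow("fbgemm::asynchronous_complete_cumsum", "")
                .typed<at::Tensor(const at::Tensor&)>();
  at::Tensor out = op.call(t);

  // Per the meta kernel, a 1-D input with numel N yields shape {N + 1}.
  std::cout << out.sizes() << std::endl;  // prints [9]
  return 0;
}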