 #include <ATen/core/dispatch/Dispatcher.h>
 #include <torch/csrc/autograd/custom_function.h>
 #include <torch/library.h>
+#include <torch/torch.h>

 #include "ATen/TensorUtils.h"
 #include "fbgemm_gpu/sparse_ops.h"
@@ -35,17 +36,18 @@ class JaggedToPaddedDenseOp
       const std::vector<int64_t>& max_lengths,
       const double padding_value) {
     ctx->save_for_backward(offsets);
-    ctx->saved_data["total_L"] = values.size(0);
+    ctx->saved_data["total_L"] = values.sym_size(0);

     static auto op =
         c10::Dispatcher::singleton()
             .findSchemaOrThrow("fbgemm::jagged_to_padded_dense_forward", "")
             .typed<at::Tensor(
                 const Tensor& values,
                 const std::vector<Tensor>& offsets,
-                const std::vector<int64_t>& max_lengths,
+                const at::ArrayRef<at::SymInt>& max_lengths,
                 const double padding_value)>();
-    Tensor padded_values = op.call(values, offsets, max_lengths, padding_value);
+    Tensor padded_values = op.call(
+        values, offsets, c10::fromIntArrayRefSlow(max_lengths), padding_value);

     return {padded_values};
   }
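The dispatcher schema now takes at::ArrayRef<at::SymInt> for max_lengths, so the forward has to widen the caller-facing std::vector<int64_t> before calling the op; the patch uses c10::fromIntArrayRefSlow for that. A minimal sketch of the same widening written out by hand, with a hypothetical helper name not taken from the patch:

// Minimal sketch, not part of the diff: widen plain int64_t lengths to
// c10::SymInt so they can bind to an at::ArrayRef<at::SymInt> parameter.
// The patch itself uses c10::fromIntArrayRefSlow() for this conversion.
#include <c10/core/SymInt.h>

#include <cstdint>
#include <vector>

std::vector<c10::SymInt> to_symint_lengths(
    const std::vector<int64_t>& max_lengths) {
  std::vector<c10::SymInt> out;
  out.reserve(max_lengths.size());
  for (const int64_t len : max_lengths) {
    out.emplace_back(len); // SymInt is constructible from a concrete int64_t
  }
  return out;
}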
@@ -54,7 +56,7 @@ class JaggedToPaddedDenseOp
       torch::autograd::AutogradContext* ctx,
       torch::autograd::variable_list grad_outputs) {
     auto offsets = ctx->get_saved_variables();
-    int32_t total_L = ctx->saved_data["total_L"].toInt();
+    at::SymInt total_L = ctx->saved_data["total_L"].toSymInt();
     TORCH_CHECK(grad_outputs.size() == 1);

     TORCH_CHECK(total_L >= 0);
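Because total_L is now stored as a symbolic integer, the forward saves it with sym_size(0) and the backward reads it back with toSymInt(); the TORCH_CHECK comparison then goes through SymInt's comparison operators. A small self-contained sketch of that round trip; the free functions are illustrative, not from the patch:

#include <ATen/ATen.h>
#include <torch/csrc/autograd/custom_function.h>

// Illustrative helpers only: mirror how the patch round-trips total_L
// through the autograd context's saved_data map.
void save_total_length(
    torch::autograd::AutogradContext* ctx,
    const at::Tensor& values) {
  // sym_size(0) keeps the length symbolic under dynamic shapes, whereas
  // size(0) would force a concrete integer.
  ctx->saved_data["total_L"] = values.sym_size(0);
}

at::SymInt load_total_length(torch::autograd::AutogradContext* ctx) {
  return ctx->saved_data["total_L"].toSymInt();
}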
@@ -64,7 +66,7 @@ class JaggedToPaddedDenseOp
             .typed<at::Tensor(
                 const Tensor& grad_output,
                 const std::vector<Tensor>& offsets,
-                const int64_t total_L)>();
+                const at::SymInt& total_L)>();
     auto grad_values = op.call(grad_outputs[0], {offsets}, total_L);

     return {
@@ -86,7 +88,15 @@ class JaggedDenseDenseAddJaggedOutputOp
       const Tensor& dense_0,
       const Tensor& dense_1) {
     ctx->save_for_backward(offsets);
+#if TORCH_VERSION_MAJOR > 2 || \
+    (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 1)
+    // toSymIntVector support is from a recent PR
+    // https://github.com/pytorch/pytorch/pull/101056,
+    // so protect it under a version guard for compatibility
+    ctx->saved_data["dense_shape"] = dense_0.sym_sizes();
+#else
     ctx->saved_data["dense_shape"] = dense_0.sizes();
+#endif

     static auto op =
         c10::Dispatcher::singleton()
@@ -107,7 +117,12 @@ class JaggedDenseDenseAddJaggedOutputOp
       torch::autograd::AutogradContext* ctx,
       torch::autograd::variable_list grad_outputs) {
     auto offsets = ctx->get_saved_variables();
+#if TORCH_VERSION_MAJOR > 2 || \
+    (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 1)
+    auto dense_shape = ctx->saved_data["dense_shape"].toSymIntVector();
+#else
     auto dense_shape = ctx->saved_data["dense_shape"].toIntVector();
+#endif
     TORCH_CHECK(grad_outputs.size() == 1);

     static auto op =
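The same #if guard is repeated wherever the saved shape is read back, because IValue::toSymIntVector() only exists in newer libtorch (per the linked PyTorch PR). One way to picture the pattern as a self-contained helper; the helper name is illustrative, and the assumption is that torch/torch.h exposes the TORCH_VERSION_* macros, which is why the patch adds that include:

#include <torch/csrc/autograd/custom_function.h>
#include <torch/torch.h> // assumed to provide TORCH_VERSION_MAJOR / _MINOR

#include <cstdint>
#include <vector>

#if TORCH_VERSION_MAJOR > 2 || \
    (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 1)
using SavedShape = std::vector<at::SymInt>; // symbolic sizes on >= 2.1
#else
using SavedShape = std::vector<int64_t>; // concrete sizes on older builds
#endif

// Illustrative helper: read the shape saved by the forward pass, picking
// whichever accessor the linked libtorch version actually provides.
SavedShape load_dense_shape(torch::autograd::AutogradContext* ctx) {
#if TORCH_VERSION_MAJOR > 2 || \
    (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 1)
  return ctx->saved_data["dense_shape"].toSymIntVector();
#else
  return ctx->saved_data["dense_shape"].toIntVector();
#endif
}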
@@ -116,12 +131,12 @@ class JaggedDenseDenseAddJaggedOutputOp
             .typed<at::Tensor(
                 const Tensor& values,
                 const std::vector<Tensor>& offsets,
-                const std::vector<int64_t>& max_lengths,
+                const at::ArrayRef<at::SymInt>& max_lengths,
                 const double padding_value)>();
     Tensor dense_values_grad_0 = op.call(
         grad_outputs[0],
         offsets,
-        std::vector<int64_t>(dense_shape.begin() + 1, dense_shape.end() - 1),
+        std::vector<at::SymInt>(dense_shape.begin() + 1, dense_shape.end() - 1),
         /*padding_value=*/0);
     Tensor dense_values_grad_1 = dense_values_grad_0;

@@ -249,19 +264,27 @@ class DenseToJaggedOp : public torch::autograd::Function<DenseToJaggedOp> {
       torch::autograd::AutogradContext* ctx,
       const Tensor& dense,
       const std::vector<Tensor>& offsets,
-      const c10::optional<int64_t>& total_L) {
+      const c10::optional<at::SymInt>& total_L) {
     ctx->save_for_backward(offsets);

     // dims of dense tensor: <batch, [maxlen0, maxlen1, ...], embedding_dim>
+#if TORCH_VERSION_MAJOR > 2 || \
+    (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 1)
+    // toSymIntVector support is from a recent PR
+    // https://github.com/pytorch/pytorch/pull/101056,
+    // so protect it under a version guard for compatibility
+    ctx->saved_data["dense_shape"] = dense.sym_sizes();
+#else
     ctx->saved_data["dense_shape"] = dense.sizes();
+#endif

     static auto op =
         c10::Dispatcher::singleton()
             .findSchemaOrThrow("fbgemm::dense_to_jagged_forward", "")
             .typed<Tensor(
                 const Tensor& dense,
                 const std::vector<Tensor>& offsets,
-                const c10::optional<int64_t>& total_L)>();
+                const c10::optional<at::SymInt>& total_L)>();
     auto output = op.call(dense, offsets, total_L);

     return {output};
@@ -271,7 +294,12 @@ class DenseToJaggedOp : public torch::autograd::Function<DenseToJaggedOp> {
       torch::autograd::AutogradContext* ctx,
       torch::autograd::variable_list grad_outputs) {
     auto offsets = ctx->get_saved_variables();
+#if TORCH_VERSION_MAJOR > 2 || \
+    (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 1)
+    auto dense_shape = ctx->saved_data["dense_shape"].toSymIntVector();
+#else
     auto dense_shape = ctx->saved_data["dense_shape"].toIntVector();
+#endif
     TORCH_CHECK(grad_outputs.size() == 1);

     static auto op =
@@ -280,15 +308,20 @@ class DenseToJaggedOp : public torch::autograd::Function<DenseToJaggedOp> {
             .typed<Tensor(
                 const Tensor& values,
                 const std::vector<Tensor>& offsets,
-                const std::vector<int64_t>& max_lengths,
+                const at::ArrayRef<at::SymInt>& max_lengths,
                 const double padding_value)>();
     auto dense_values_grad = op.call(
         grad_outputs[0],
         offsets,
-        std::vector<int64_t>(dense_shape.begin() + 1, dense_shape.end() - 1),
+        std::vector<at::SymInt>(dense_shape.begin() + 1, dense_shape.end() - 1),
         /*padding_value=*/0);

+#if TORCH_VERSION_MAJOR > 2 || \
+    (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 1)
+    TORCH_CHECK(dense_values_grad.sym_sizes() == dense_shape);
+#else
     TORCH_CHECK(dense_values_grad.sizes() == dense_shape);
+#endif

     return {
         dense_values_grad,
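As in the jagged-dense-add backward above, the max_lengths passed back to jagged_to_padded_dense_forward are simply the saved dense shape with the leading batch dimension and trailing embedding dimension stripped. A hypothetical stand-alone helper, not from the patch, showing that slice:

#include <ATen/ATen.h>

#include <vector>

// Hypothetical helper, not part of the diff: derive max_lengths from the
// saved dense shape <batch, maxlen0, maxlen1, ..., embedding_dim>.
std::vector<at::SymInt> max_lengths_from_dense_shape(
    const std::vector<at::SymInt>& dense_shape) {
  TORCH_CHECK(dense_shape.size() >= 2);
  // Drop the batch dim (front) and the embedding dim (back).
  return std::vector<at::SymInt>(
      dense_shape.begin() + 1, dense_shape.end() - 1);
}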
@@ -730,7 +763,7 @@ Tensor batched_dense_vec_jagged_2d_mul(
 std::tuple<Tensor, std::vector<Tensor>> dense_to_jagged(
     const Tensor& dense,
     const std::vector<Tensor>& offsets,
-    const c10::optional<int64_t>& total_L) {
+    const c10::optional<at::SymInt>& total_L) {
   return {DenseToJaggedOp::apply(dense, offsets, total_L)[0], offsets};
 }
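Since int64_t converts implicitly to at::SymInt, existing call sites that pass a plain integer (or c10::nullopt) for total_L keep working unchanged. A small illustrative sketch with a hypothetical helper:

#include <ATen/ATen.h>

#include <cstdint>

// Hypothetical helper: build the optional total_L argument. Plain integers
// widen to at::SymInt; c10::nullopt means "total length unknown".
c10::optional<at::SymInt> make_total_L(bool known, int64_t value) {
  if (!known) {
    return c10::nullopt;
  }
  return at::SymInt(value);
}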