
Commit 6faaa4c

xw285cornell authored and facebook-github-bot committed
Symbolic shape tracing on jagged op (pytorch#1758)
Summary:
Pull Request resolved: pytorch#1758

fbgemm's jagged ops don't have SymInt support in their meta functions, so every shape gets specialized: a materialized integer is baked into the model instead of a symbolic shape. This diff fixes that.

Differential Revision: D44736488

Privacy Context Container: L1156430

fbshipit-source-id: b7ed2e2f8d3fa20adafebc290eab2de79992f3f6
1 parent dc0c29b commit 6faaa4c
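For context, "shape specialization" here means that tracing records a concrete integer wherever the op reads a size, so the exported graph only works for that exact input length. The sketch below is illustrative only and not part of this commit (the helper names are made up); it contrasts the two behaviours the summary describes: reading size() and allocating with at::empty() freezes the dimension, while reading sym_size() and allocating with at::empty_symint() keeps it symbolic.

#include <ATen/ATen.h>

// Illustrative sketch, not from this commit: the "specialized" path bakes a
// concrete integer into any traced graph, the symbolic path does not.
// Both helpers assume a 2-D input.
at::Tensor output_like_specialized(const at::Tensor& values) {
  int64_t total_L = values.size(0); // concrete integer, shape is frozen
  return at::empty({total_L, values.size(1)}, values.options());
}

at::Tensor output_like_symbolic(const at::Tensor& values) {
  c10::SymInt total_L = values.sym_size(0); // stays symbolic under tracing
  return at::empty_symint({total_L, values.sym_size(1)}, values.options());
}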

File tree

8 files changed: +408, -68 lines


fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h

Lines changed: 1 addition & 1 deletion
@@ -431,7 +431,7 @@ std::tuple<at::Tensor, std::vector<at::Tensor>> jagged_dense_elementwise_mul(
 std::tuple<at::Tensor, std::vector<at::Tensor>> dense_to_jagged(
     const at::Tensor& dense,
     const std::vector<at::Tensor>& offsets,
-    const c10::optional<int64_t>& total_L);
+    const c10::optional<at::SymInt>& total_L);

 std::tuple<at::Tensor, std::vector<at::Tensor>>
 jagged_dense_elementwise_add_jagged_output(
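The header change above is the whole interface story for dense_to_jagged: total_L goes from c10::optional<int64_t> to c10::optional<at::SymInt>. Since at::SymInt is implicitly constructible from int64_t, eager callers that pass a plain integer (or c10::nullopt) keep compiling unchanged; only tracing callers ever hand in a genuinely symbolic value. A minimal, assumed caller sketch (not from the patch):

#include <ATen/ATen.h>

// Both initializations satisfy the new c10::optional<at::SymInt> parameter.
void dense_to_jagged_callers() {
  c10::optional<at::SymInt> concrete = at::SymInt(128); // eager-mode caller
  c10::optional<at::SymInt> absent = c10::nullopt;      // let the op compute total_L
  (void)concrete;
  (void)absent;
}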

fbgemm_gpu/src/jagged_tensor_ops/dense_to_jagged_forward.cu

Lines changed: 3 additions & 3 deletions
@@ -15,18 +15,18 @@ namespace fbgemm_gpu {
 Tensor dense_to_jagged_forward(
     const Tensor& dense,
     const std::vector<Tensor>& offsets,
-    const c10::optional<int64_t>& total_L) {
+    const c10::optional<at::SymInt>& total_L) {
   // D is the embedding dimension
   auto D = dense.size(-1);

   // If total_L is not given then compute it
-  int64_t total_L_computed;
+  at::SymInt total_L_computed;
   if (total_L.has_value()) {
     total_L_computed = total_L.value();
   } else {
     total_L_computed = (int64_t)offsets.back().max().item<int64_t>();
   }
-  auto values = at::empty({total_L_computed, D}, dense.options());
+  auto values = at::empty_symint({total_L_computed, D}, dense.options());
   auto output = at::empty_like(values);

   at::cuda::OptionalCUDAGuard device_guard;
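Two coordinated changes in this kernel wrapper: total_L_computed becomes at::SymInt so a caller-supplied symbolic length flows through untouched, and the allocation switches to at::empty_symint, whose size argument is a SymIntArrayRef rather than an IntArrayRef. The fallback branch still computes a concrete int64_t from the offsets, which is fine because a plain integer converts into a SymInt implicitly. The pattern in isolation (the helper name is illustrative, not from the patch):

#include <ATen/ATen.h>

// Once any element of a size list may be symbolic, allocation has to go
// through the *_symint factory overload; concrete ints are wrapped into
// SymInt implicitly when the list is built.
at::Tensor alloc_values(
    const at::SymInt& total_L,
    int64_t D,
    const at::TensorOptions& opts) {
  return at::empty_symint({total_L, at::SymInt(D)}, opts);
}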

fbgemm_gpu/src/jagged_tensor_ops/jagged_to_padded_dense_backward.cu

Lines changed: 3 additions & 2 deletions
@@ -15,7 +15,7 @@ namespace fbgemm_gpu {
 at::Tensor jagged_to_padded_dense_backward(
     const Tensor& grad_output,
     const std::vector<Tensor>& offsets,
-    const int64_t total_L) {
+    const at::SymInt& total_L) {
   auto grad_padded_values = grad_output;
   at::cuda::OptionalCUDAGuard device_guard;
   device_guard.set_index(grad_padded_values.get_device());
@@ -29,7 +29,8 @@ at::Tensor jagged_to_padded_dense_backward(

   // Initialize with zeros so output will be zero for the portion truncated
   // in forward.
-  auto grad_values = at::zeros({total_L, D}, grad_padded_values.options());
+  auto grad_values =
+      at::zeros_symint({total_L, D}, grad_padded_values.options());

   AT_DISPATCH_FLOATING_TYPES_AND2(
       at::ScalarType::Half,
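The backward mirrors the forward: total_L now arrives as a SymInt (saved symbolically by the autograd wrapper, see jagged_tensor_ops_autograd.cpp below), and the gradient buffer is allocated with at::zeros_symint, the symbolic counterpart of the empty_symint call in the previous file. Zero initialization still matters for the same reason as before: rows truncated in the forward pass must receive a zero gradient. Just the allocation step, as an illustrative helper that is not from the patch:

#include <ATen/ATen.h>

// Nothing about total_L has to be materialized to build the zero-filled
// gradient buffer.
at::Tensor alloc_grad_values(
    const at::SymInt& total_L,
    int64_t D,
    const at::TensorOptions& opts) {
  return at::zeros_symint({total_L, at::SymInt(D)}, opts);
}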

fbgemm_gpu/src/jagged_tensor_ops/jagged_to_padded_dense_forward.cu

Lines changed: 12 additions & 7 deletions
@@ -21,7 +21,7 @@ namespace fbgemm_gpu {
 at::Tensor jagged_to_padded_dense_forward(
     const Tensor& values,
     const std::vector<Tensor>& offsets,
-    const std::vector<int64_t>& max_lengths,
+    const at::ArrayRef<at::SymInt>& max_lengths,
     const double padding_value) {
   const size_t num_jagged_dim = offsets.size();
   TORCH_CHECK(
@@ -40,7 +40,7 @@ at::Tensor jagged_to_padded_dense_forward(
       values.sizes().end(),
       1,
       std::multiplies<size_t>())});
-  at::DimVector padded_values_shape({offsets[0].size(0) - 1});
+  at::SymDimVector padded_values_shape({at::SymInt(offsets[0].size(0) - 1)});
   padded_values_shape.insert(
       padded_values_shape.end(), max_lengths.begin(), max_lengths.end());

@@ -50,7 +50,8 @@ at::Tensor jagged_to_padded_dense_forward(
   if (!D_folded) {
     padded_values_shape.push_back(values.size(-1));
   }
-  Tensor padded_values = at::empty(padded_values_shape, values.options());
+  Tensor padded_values =
+      at::empty_symint(padded_values_shape, values.options());
   Tensor padded_values_view =
       D_folded ? padded_values.unsqueeze(-1) : padded_values;

@@ -121,7 +122,7 @@ std::vector<Tensor> stacked_jagged_1d_to_dense_gpu(
     padded_values_per_key.push_back(jagged_to_padded_dense_forward(
         values.slice(0, offset_per_key[t], offset_per_key[t + 1]),
         {offsets},
-        {max_L},
+        at::ArrayRef<at::SymInt>({max_L}),
         padding_value));
   }
   return padded_values_per_key;
@@ -179,7 +180,7 @@ stacked_jagged_2d_to_dense_forward_cuda(
     padded_values_per_key.push_back(jagged_to_padded_dense_forward(
         values.slice(0, offset_per_key[t], offset_per_key[t + 1]),
         {offsets},
-        {max_L},
+        at::ArrayRef<at::SymInt>({max_L}),
         padding_value));
   }

@@ -301,7 +302,10 @@ Tensor jagged_2d_to_dense_gpu_forward(
     Tensor offsets,
     int64_t max_sequence_length) {
   return jagged_to_padded_dense_forward(
-      values, {offsets}, {max_sequence_length}, /*padding_value=*/0);
+      values,
+      {offsets},
+      c10::ArrayRef<c10::SymInt>({max_sequence_length}),
+      /*padding_value=*/0);
 }

 namespace {
@@ -369,7 +373,8 @@ class JaggedDenseAddJaggedOutputGPUOp
     Tensor dense_values_grad = jagged_to_padded_dense_forward(
         grad_outputs[0],
         offsets,
-        std::vector<int64_t>(dense_shape.begin() + 1, dense_shape.end() - 1),
+        c10::fromIntArrayRefKnownNonNegative(std::vector<int64_t>(
+            dense_shape.begin() + 1, dense_shape.end() - 1)),
         /*padding_value=*/0);
     TORCH_CHECK(dense_values_grad.sizes() == dense_shape);
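Several related techniques appear in this file. The public max_lengths parameter becomes at::ArrayRef<at::SymInt> (a SymIntArrayRef), the output shape is accumulated in an at::SymDimVector instead of an at::DimVector so symbolic lengths can be appended next to the concrete batch dimension, and the final allocation goes through at::empty_symint. Call sites that only have concrete values wrap them, either with an explicit ArrayRef<SymInt> as in the stacked_jagged_*_to_dense paths, or with c10::fromIntArrayRefKnownNonNegative, which (as the name suggests) reinterprets an IntArrayRef as SymInts without the per-element range check of the slower conversion; sizes are non-negative, so that is safe here. A condensed sketch of the shape-building pattern (the helper and its arguments are illustrative, not the patch's):

#include <ATen/ATen.h>

// The batch dimension is concrete (derived from the offsets tensor); the
// per-dim max lengths may be symbolic, so the shape vector holds SymInt.
at::Tensor alloc_padded(
    const at::Tensor& values,
    int64_t batch_size,
    c10::SymIntArrayRef max_lengths) {
  at::SymDimVector shape({at::SymInt(batch_size)});
  shape.insert(shape.end(), max_lengths.begin(), max_lengths.end());
  shape.push_back(values.sym_sizes().back()); // trailing embedding dimension
  return at::empty_symint(shape, values.options());
}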

fbgemm_gpu/src/jagged_tensor_ops_autograd.cpp

Lines changed: 45 additions & 12 deletions
@@ -11,6 +11,7 @@
 #include <ATen/core/dispatch/Dispatcher.h>
 #include <torch/csrc/autograd/custom_function.h>
 #include <torch/library.h>
+#include <torch/torch.h>

 #include "ATen/TensorUtils.h"
 #include "fbgemm_gpu/sparse_ops.h"
@@ -35,17 +36,18 @@ class JaggedToPaddedDenseOp
       const std::vector<int64_t>& max_lengths,
       const double padding_value) {
     ctx->save_for_backward(offsets);
-    ctx->saved_data["total_L"] = values.size(0);
+    ctx->saved_data["total_L"] = values.sym_size(0);

     static auto op =
         c10::Dispatcher::singleton()
             .findSchemaOrThrow("fbgemm::jagged_to_padded_dense_forward", "")
             .typed<at::Tensor(
                 const Tensor& values,
                 const std::vector<Tensor>& offsets,
-                const std::vector<int64_t>& max_lengths,
+                const at::ArrayRef<at::SymInt>& max_lengths,
                 const double padding_value)>();
-    Tensor padded_values = op.call(values, offsets, max_lengths, padding_value);
+    Tensor padded_values = op.call(
+        values, offsets, c10::fromIntArrayRefSlow(max_lengths), padding_value);

     return {padded_values};
   }
@@ -54,7 +56,7 @@ class JaggedToPaddedDenseOp
       torch::autograd::AutogradContext* ctx,
      torch::autograd::variable_list grad_outputs) {
     auto offsets = ctx->get_saved_variables();
-    int32_t total_L = ctx->saved_data["total_L"].toInt();
+    at::SymInt total_L = ctx->saved_data["total_L"].toSymInt();
     TORCH_CHECK(grad_outputs.size() == 1);

     TORCH_CHECK(total_L >= 0);
@@ -64,7 +66,7 @@ class JaggedToPaddedDenseOp
             .typed<at::Tensor(
                 const Tensor& grad_output,
                 const std::vector<Tensor>& offsets,
-                const int64_t total_L)>();
+                const at::SymInt& total_L)>();
     auto grad_values = op.call(grad_outputs[0], {offsets}, total_L);

     return {
@@ -86,7 +88,15 @@ class JaggedDenseDenseAddJaggedOutputOp
       const Tensor& dense_0,
       const Tensor& dense_1) {
     ctx->save_for_backward(offsets);
+#if TORCH_VERSION_MAJOR > 2 || \
+    (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 1)
+    // toSymIntVector support is from a recent PR
+    // https://github.com/pytorch/pytorch/pull/101056,
+    // so protect it under a version guard for compatibility
+    ctx->saved_data["dense_shape"] = dense_0.sym_sizes();
+#else
     ctx->saved_data["dense_shape"] = dense_0.sizes();
+#endif

     static auto op =
         c10::Dispatcher::singleton()
@@ -107,7 +117,12 @@ class JaggedDenseDenseAddJaggedOutputOp
       torch::autograd::AutogradContext* ctx,
       torch::autograd::variable_list grad_outputs) {
     auto offsets = ctx->get_saved_variables();
+#if TORCH_VERSION_MAJOR > 2 || \
+    (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 1)
+    auto dense_shape = ctx->saved_data["dense_shape"].toSymIntVector();
+#else
     auto dense_shape = ctx->saved_data["dense_shape"].toIntVector();
+#endif
     TORCH_CHECK(grad_outputs.size() == 1);

     static auto op =
@@ -116,12 +131,12 @@ class JaggedDenseDenseAddJaggedOutputOp
             .typed<at::Tensor(
                 const Tensor& values,
                 const std::vector<Tensor>& offsets,
-                const std::vector<int64_t>& max_lengths,
+                const at::ArrayRef<at::SymInt>& max_lengths,
                 const double padding_value)>();
     Tensor dense_values_grad_0 = op.call(
         grad_outputs[0],
         offsets,
-        std::vector<int64_t>(dense_shape.begin() + 1, dense_shape.end() - 1),
+        std::vector<at::SymInt>(dense_shape.begin() + 1, dense_shape.end() - 1),
         /*padding_value=*/0);
     Tensor dense_values_grad_1 = dense_values_grad_0;

@@ -249,19 +264,27 @@ class DenseToJaggedOp : public torch::autograd::Function<DenseToJaggedOp> {
       torch::autograd::AutogradContext* ctx,
       const Tensor& dense,
       const std::vector<Tensor>& offsets,
-      const c10::optional<int64_t>& total_L) {
+      const c10::optional<at::SymInt>& total_L) {
     ctx->save_for_backward(offsets);

     // dims of dense tensor: <batch, [maxlen0, maxlen1, ...], embedding_dim>
+#if TORCH_VERSION_MAJOR > 2 || \
+    (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 1)
+    // toSymIntVector support is from a recent PR
+    // https://github.com/pytorch/pytorch/pull/101056,
+    // so protect it under a version guard for compatibility
+    ctx->saved_data["dense_shape"] = dense.sym_sizes();
+#else
     ctx->saved_data["dense_shape"] = dense.sizes();
+#endif

     static auto op =
         c10::Dispatcher::singleton()
             .findSchemaOrThrow("fbgemm::dense_to_jagged_forward", "")
             .typed<Tensor(
                 const Tensor& dense,
                 const std::vector<Tensor>& offsets,
-                const c10::optional<int64_t>& total_L)>();
+                const c10::optional<at::SymInt>& total_L)>();
     auto output = op.call(dense, offsets, total_L);

     return {output};
@@ -271,7 +294,12 @@ class DenseToJaggedOp : public torch::autograd::Function<DenseToJaggedOp> {
       torch::autograd::AutogradContext* ctx,
       torch::autograd::variable_list grad_outputs) {
     auto offsets = ctx->get_saved_variables();
+#if TORCH_VERSION_MAJOR > 2 || \
+    (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 1)
+    auto dense_shape = ctx->saved_data["dense_shape"].toSymIntVector();
+#else
     auto dense_shape = ctx->saved_data["dense_shape"].toIntVector();
+#endif
     TORCH_CHECK(grad_outputs.size() == 1);

     static auto op =
@@ -280,15 +308,20 @@ class DenseToJaggedOp : public torch::autograd::Function<DenseToJaggedOp> {
             .typed<Tensor(
                 const Tensor& values,
                 const std::vector<Tensor>& offsets,
-                const std::vector<int64_t>& max_lengths,
+                const at::ArrayRef<at::SymInt>& max_lengths,
                 const double padding_value)>();
     auto dense_values_grad = op.call(
         grad_outputs[0],
         offsets,
-        std::vector<int64_t>(dense_shape.begin() + 1, dense_shape.end() - 1),
+        std::vector<at::SymInt>(dense_shape.begin() + 1, dense_shape.end() - 1),
         /*padding_value=*/0);

+#if TORCH_VERSION_MAJOR > 2 || \
+    (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 1)
+    TORCH_CHECK(dense_values_grad.sym_sizes() == dense_shape);
+#else
     TORCH_CHECK(dense_values_grad.sizes() == dense_shape);
+#endif

     return {
         dense_values_grad,
@@ -730,7 +763,7 @@ Tensor batched_dense_vec_jagged_2d_mul(
 std::tuple<Tensor, std::vector<Tensor>> dense_to_jagged(
     const Tensor& dense,
     const std::vector<Tensor>& offsets,
-    const c10::optional<int64_t>& total_L) {
+    const c10::optional<at::SymInt>& total_L) {
   return {DenseToJaggedOp::apply(dense, offsets, total_L)[0], offsets};
 }
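The autograd wrapper is where shapes cross the eager/tracing boundary, and it uses three ingredients: the forward saves the symbolic shape (values.sym_size(0), dense.sym_sizes()) into ctx->saved_data, the backward reads it back with IValue::toSymInt() / toSymIntVector(), and the save/load pair is wrapped in a TORCH_VERSION guard because toSymIntVector only exists from pytorch/pytorch#101056 onward; older libtorch builds fall back to concrete sizes(). c10::fromIntArrayRefSlow converts the user-facing std::vector<int64_t> max_lengths into the SymIntArrayRef the dispatched op now expects, validating each element on the way. A condensed, self-contained sketch of the save/restore pattern (the class, its math, and the 1-D shape are all illustrative, not the patch's):

#include <torch/torch.h>

// Illustrative autograd::Function: forward stashes a symbolic length in
// saved_data, backward reads it back and allocates the gradient buffer with
// a *_symint factory, so tracing never bakes in a concrete integer.
struct SymLenDemoOp : public torch::autograd::Function<SymLenDemoOp> {
  static torch::autograd::variable_list forward(
      torch::autograd::AutogradContext* ctx,
      const at::Tensor& values) { // assumed 1-D for simplicity
    ctx->saved_data["total_L"] = values.sym_size(0); // IValue holds a SymInt
    return {values * 2};
  }

  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::variable_list grad_outputs) {
    at::SymInt total_L = ctx->saved_data["total_L"].toSymInt();
    // Allocate the gradient buffer symbolically; a real backward kernel
    // would fill it (it is left at zero here as a placeholder).
    auto grad = at::zeros_symint({total_L}, grad_outputs[0].options());
    return {grad};
  }
};

The version guard in the patch simply switches between this symbolic path and the old sizes()/toIntVector() path at compile time.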
