Skip to content

Commit 3995cc6

Browse files
Shen Li authored and facebook-github-bot committed
Allow jagged_index_select backward to accept pre-computed output shape
Summary: Save `num_dense_output_rows`, computed during the forward pass, and use it during the backward pass to avoid a blocking `.item()` call.
Reviewed By: sryap
Differential Revision: D54173841
fbshipit-source-id: 113c035d6462963d00df7545dd54ce4dd15ed753
1 parent ad70943 commit 3995cc6

File tree

3 files changed

+23
-7
lines changed

3 files changed

+23
-7
lines changed

fbgemm_gpu/fbgemm_gpu/sparse_ops.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -226,6 +226,7 @@ def jagged_index_add_2d_forward_v2_abstract(
226226
input_offsets: Tensor,
227227
output_offsets: Tensor,
228228
num_output_rows: int,
229+
num_dense_input_rows: Optional[int] = None,
229230
) -> Tensor:
230231
torch._check(values.device == indices.device)
231232
torch._check(values.device == input_offsets.device)

fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_autograd.cpp

Lines changed: 12 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -616,6 +616,7 @@ class JaggedIndexSelect2dOp
616616

617617
ctx->save_for_backward({indices, output_offsets, input_offsets});
618618
ctx->saved_data["num_input_rows"] = values.sym_size(0);
619+
ctx->saved_data["num_dense_output_rows"] = num_dense_output_rows;
619620

620621
static auto op =
621622
c10::Dispatcher::singleton()
@@ -652,6 +653,8 @@ class JaggedIndexSelect2dOp
652653
TENSORS_ON_SAME_DEVICE(grad, indices);
653654

654655
auto num_output_rows = ctx->saved_data["num_input_rows"].toSymInt();
656+
auto num_dense_input_rows =
657+
ctx->saved_data["num_dense_output_rows"].toOptional<int64_t>();
655658

656659
static auto op =
657660
c10::Dispatcher::singleton()
@@ -661,10 +664,17 @@ class JaggedIndexSelect2dOp
661664
const Tensor& indices,
662665
const Tensor& input_offsets,
663666
const Tensor& output_offsets,
664-
c10::SymInt num_output_rows)>();
667+
c10::SymInt num_output_rows,
668+
const c10::optional<int64_t> optional_num_dense_input_rows)>();
665669

666670
return {
667-
op.call(grad, indices, grad_offsets, output_offsets, num_output_rows),
671+
op.call(
672+
grad,
673+
indices,
674+
grad_offsets,
675+
output_offsets,
676+
num_output_rows,
677+
num_dense_input_rows),
668678
torch::autograd::Variable(), // lengths
669679
torch::autograd::Variable(), // indices
670680
torch::autograd::Variable() // num_dense_output_rows

fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp

Lines changed: 10 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1189,9 +1189,14 @@ Tensor jagged_index_add_2d_forward_v2_impl(
11891189
const Tensor& indices,
11901190
const Tensor& input_offsets,
11911191
const Tensor& output_offsets,
1192-
const int64_t num_output_rows) {
1193-
int64_t num_dense_output_rows =
1194-
input_offsets[input_offsets.numel() - 1].item<int64_t>();
1192+
const int64_t num_output_rows,
1193+
const c10::optional<int64_t> optional_num_dense_input_rows) {
1194+
// Intentionally not using optional::value_or here to avoid materializing
1195+
// .item() call when possible.
1196+
int64_t num_dense_input_rows = optional_num_dense_input_rows.has_value()
1197+
? optional_num_dense_input_rows.value()
1198+
: input_offsets[input_offsets.numel() - 1].item<int64_t>();
1199+
11951200
static auto v1_op =
11961201
c10::Dispatcher::singleton()
11971202
.findSchemaOrThrow("fbgemm::jagged_index_add_2d_forward", "")
@@ -1207,7 +1212,7 @@ Tensor jagged_index_add_2d_forward_v2_impl(
12071212
indices,
12081213
input_offsets,
12091214
output_offsets,
1210-
num_dense_output_rows,
1215+
num_dense_input_rows,
12111216
num_output_rows);
12121217
}
12131218

@@ -1730,7 +1735,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
17301735
m.def(
17311736
"jagged_index_add_2d_forward(Tensor values, Tensor indices, Tensor input_offsets, Tensor output_offsets, int num_dense_input_rows, int num_output_rows) -> Tensor");
17321737
m.def(
1733-
"jagged_index_add_2d_forward_v2(Tensor values, Tensor indices, Tensor input_offsets, Tensor output_offsets, SymInt num_output_rows) -> Tensor",
1738+
"jagged_index_add_2d_forward_v2(Tensor values, Tensor indices, Tensor input_offsets, Tensor output_offsets, SymInt num_output_rows, int? num_dense_input_rows) -> Tensor",
17341739
{PT2_COMPLIANT_TAG});
17351740
m.def(
17361741
"jagged_1d_to_truncated_values(Tensor values, Tensor lengths, int max_truncated_length) -> Tensor");

0 commit comments

Comments
 (0)