Skip to content

Commit 86c8cf0

Browse files
ezyangfacebook-github-bot
authored andcommitted
impl_abstract for permute_1D_sparse_data (pytorch#2087)
Summary: Pull Request resolved: pytorch#2087 Reviewed By: zou3519 Differential Revision: D50584541
1 parent f94254d commit 86c8cf0

File tree

3 files changed

+26
-12
lines changed

3 files changed

+26
-12
lines changed

fbgemm_gpu/fbgemm_gpu/sparse_operators.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,29 @@ def permute_2D_sparse_data_meta(
4545
# pyre-fixme
4646
permuted_weights = weights.new_empty(permuted_indices_size)
4747
return permuted_lengths, permuted_indices, permuted_weights
48+
49+
50+
@torch.library.impl_abstract("fbgemm::permute_1D_sparse_data")
51+
def permute_1D_sparse_data_meta(
52+
permute: Tensor,
53+
lengths: Tensor,
54+
values: Tensor,
55+
weights: Optional[Tensor] = None,
56+
permuted_lengths_sum: Optional[int] = None,
57+
):
58+
indices = values
59+
permuted_lengths_size = permute.numel()
60+
permuted_lengths = lengths.new_empty([permuted_lengths_size])
61+
permuted_indices_size = 0
62+
if permuted_lengths_sum is not None:
63+
permuted_indices_size = permuted_lengths_sum
64+
else:
65+
ctx = torch._custom_op.impl.get_ctx()
66+
permuted_indices_size = ctx.new_dynamic_size()
67+
# pyre-fixme
68+
permuted_indices = indices.new_empty(permuted_indices_size)
69+
permuted_weights = None
70+
if weights is not None:
71+
# pyre-fixme
72+
permuted_weights = weights.new_empty(permuted_indices_size)
73+
return permuted_lengths, permuted_indices, permuted_weights

fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2686,10 +2686,6 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
26862686
"permute_sparse_data(Tensor permute, Tensor lengths, Tensor values, Tensor? weights=None, int? permuted_lengths_sum=None) -> (Tensor, Tensor, Tensor?)");
26872687
m.def(
26882688
"permute_2D_sparse_data(Tensor permute, Tensor lengths, Tensor values, Tensor? weights=None, int? permuted_lengths_sum=None) -> (Tensor, Tensor, Tensor?)");
2689-
m.impl_abstract_pystub(
2690-
"permute_2D_sparse_data",
2691-
"fbgemm_gpu.operators",
2692-
"//deeplearning/fbgemm/fbgemm_gpu:operators");
26932689
m.def(
26942690
"permute_1D_sparse_data(Tensor permute, Tensor lengths, Tensor values, Tensor? weights=None, int? permuted_lengths_sum=None) -> (Tensor, Tensor, Tensor?)");
26952691
m.def("invert_permute(Tensor permute) -> Tensor");

fbgemm_gpu/test/failures_dict.json

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -507,18 +507,10 @@
507507
}
508508
},
509509
"fbgemm::permute_1D_sparse_data": {
510-
"SparseOpsTest.test_aot_dispatch_dynamic__test_permute_indices": {
511-
"comment": "",
512-
"status": "xfail"
513-
},
514510
"SparseOpsTest.test_aot_dispatch_static__test_permute_indices": {
515511
"comment": "",
516512
"status": "xfail"
517513
},
518-
"SparseOpsTest.test_faketensor__test_permute_indices": {
519-
"comment": "",
520-
"status": "xfail"
521-
},
522514
"SparseOpsTest.test_schema__test_permute_indices": {
523515
"comment": "flaky",
524516
"status": "skip"

0 commit comments

Comments
 (0)