Commit f94254d

ezyang authored and facebook-github-bot committed
Add scaffolding for Python impl_abstract in fbgemm, implement fbgemm.permute_1D_sparse_data (pytorch#2084)
Summary: This also fixes a minor bug in GPU permute_1D_sparse_data where we need to clone the zero-size tensors to correctly set up the (lack of) aliasing.

Pull Request resolved: pytorch#2084

Reviewed By: sryap

Differential Revision: D50563192

fbshipit-source-id: 1dc31580c54d8a0dfd3aadaf9b440636fd1a8550
1 parent b1049cf commit f94254d

File tree

6 files changed: +58 -28 lines changed
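
For readers new to this machinery, here is a minimal sketch (not part of the commit) of the pattern being scaffolded, using a hypothetical op mylib::filter_positive rather than the real fbgemm operator: an op whose output size depends on tensor data gets a Python "abstract" implementation (a meta/fake-tensor kernel) that describes output shapes without reading values, allocating an unbacked symbolic size where the true length is unknowable at trace time.

    import torch
    from torch import Tensor

    lib = torch.library.Library("mylib", "DEF")
    lib.define("filter_positive(Tensor x) -> Tensor")


    def filter_positive_cpu(x: Tensor) -> Tensor:
        # Real kernel: the output length depends on the tensor's contents.
        return x[x > 0]


    lib.impl("filter_positive", filter_positive_cpu, "CPU")


    @torch.library.impl_abstract("mylib::filter_positive")
    def filter_positive_abstract(x: Tensor) -> Tensor:
        # Abstract kernel: there is no data, so ask the tracing context for a
        # fresh unbacked symbolic size, just as permute_2D_sparse_data_meta
        # does below when permuted_lengths_sum is not supplied.
        ctx = torch._custom_op.impl.get_ctx()
        out_size = ctx.new_dynamic_size()
        return x.new_empty(out_size)

Under fake-tensor tracing (for example inside torch.compile), calling mylib::filter_positive now yields a tensor with a symbolic length instead of failing for lack of a meta kernel. The diffs below wire up exactly this for fbgemm::permute_2D_sparse_data.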

fbgemm_gpu/fbgemm_gpu/__init__.py

Lines changed: 2 additions & 1 deletion

@@ -19,7 +19,8 @@
 open_source: bool = True
 
 # Re-export docs
-from . import _fbgemm_gpu_docs  # noqa: F401, E402
+# Trigger meta registrations
+from . import _fbgemm_gpu_docs, sparse_operators  # noqa: F401, E402
 
 # Re-export the version string from the auto-generated version file
 from ._fbgemm_gpu_version import __version__  # noqa: F401, E402
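
With this change, importing the top-level fbgemm_gpu package is enough to pull in the Python meta registrations, so downstream tracing code does not need an extra import. A hypothetical smoke check (not from the commit), assuming an fbgemm_gpu build whose ops load on import; it exercises the static-size branch of the abstract impl by passing permuted_lengths_sum explicitly:

    import torch
    import fbgemm_gpu  # noqa: F401  # side effect: loads the ops and the meta registrations

    from torch._subclasses.fake_tensor import FakeTensorMode

    with FakeTensorMode():
        permute = torch.empty(2, dtype=torch.int64)
        lengths = torch.empty((2, 3), dtype=torch.int64)
        values = torch.empty(6, dtype=torch.int64)
        # permuted_lengths_sum=6 takes the static branch, so no symbolic sizes are needed.
        out_lengths, out_indices, _ = torch.ops.fbgemm.permute_2D_sparse_data(
            permute, lengths, values, None, 6
        )
        assert tuple(out_lengths.shape) == (2, 3)
        assert tuple(out_indices.shape) == (6,)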

(new file)

Lines changed: 47 additions & 0 deletions

@@ -0,0 +1,47 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Optional, Tuple
+
+import torch
+from torch import Tensor
+
+try:
+    # pyre-ignore
+    from fbgemm_gpu import open_source  # noqa: F401
+except Exception:
+    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
+    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")
+
+
+@torch.library.impl_abstract("fbgemm::permute_2D_sparse_data")
+def permute_2D_sparse_data_meta(
+    permute: Tensor,
+    lengths: Tensor,
+    values: Tensor,
+    weights: Optional[Tensor] = None,
+    permuted_lengths_sum: Optional[int] = None,
+) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
+    torch._check(
+        lengths.dim() == 2, lambda: f"expected lengths.dim() == 2, got {lengths.dim()}"
+    )
+    T = permute.numel()
+    B = lengths.size(1)
+    indices = values
+    permuted_lengths = lengths.new_empty([T, B])
+    permuted_indices_size = 0
+    if permuted_lengths_sum is not None:
+        permuted_indices_size = permuted_lengths_sum
+    else:
+        ctx = torch._custom_op.impl.get_ctx()
+        permuted_indices_size = ctx.new_dynamic_size()
+    # pyre-fixme
+    permuted_indices = indices.new_empty(permuted_indices_size)
+    permuted_weights = None
+    if weights is not None:
+        # pyre-fixme
+        permuted_weights = weights.new_empty(permuted_indices_size)
+    return permuted_lengths, permuted_indices, permuted_weights
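
A complementary sketch (again not from the commit) of the data-dependent branch: under a ShapeEnv-backed FakeTensorMode, omitting permuted_lengths_sum makes ctx.new_dynamic_size() hand back an unbacked SymInt for the permuted indices, so shape propagation works even though the true length is unknown until runtime.

    import torch
    import fbgemm_gpu  # noqa: F401  # loads the ops and the abstract impl above

    from torch._subclasses.fake_tensor import FakeTensorMode
    from torch.fx.experimental.symbolic_shapes import ShapeEnv

    with FakeTensorMode(shape_env=ShapeEnv()):
        permute = torch.empty(2, dtype=torch.int64)
        lengths = torch.empty((2, 3), dtype=torch.int64)
        values = torch.empty(10, dtype=torch.int64)
        out_lengths, out_indices, _ = torch.ops.fbgemm.permute_2D_sparse_data(
            permute, lengths, values
        )
        print(out_lengths.shape)  # torch.Size([2, 3])
        print(out_indices.shape)  # something like torch.Size([u0]), an unbacked size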

fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp

Lines changed: 4 additions & 0 deletions

@@ -2686,6 +2686,10 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
       "permute_sparse_data(Tensor permute, Tensor lengths, Tensor values, Tensor? weights=None, int? permuted_lengths_sum=None) -> (Tensor, Tensor, Tensor?)");
   m.def(
       "permute_2D_sparse_data(Tensor permute, Tensor lengths, Tensor values, Tensor? weights=None, int? permuted_lengths_sum=None) -> (Tensor, Tensor, Tensor?)");
+  m.impl_abstract_pystub(
+      "permute_2D_sparse_data",
+      "fbgemm_gpu.operators",
+      "//deeplearning/fbgemm/fbgemm_gpu:operators");
   m.def(
       "permute_1D_sparse_data(Tensor permute, Tensor lengths, Tensor values, Tensor? weights=None, int? permuted_lengths_sum=None) -> (Tensor, Tensor, Tensor?)");
   m.def("invert_permute(Tensor permute) -> Tensor");

fbgemm_gpu/src/sparse_ops/sparse_permute_2d.cu

Lines changed: 4 additions & 3 deletions

@@ -92,9 +92,10 @@ permute_2D_sparse_data_cuda(
     // When T = 0 or B = 0, permutation will not be performed. Return the
     // input tensors.
     return {
-        lengths,
-        indices,
-        weights,
+        lengths.clone(),
+        indices.clone(),
+        weights.has_value() ? c10::make_optional(weights->clone())
+                            : c10::nullopt,
     };
   }
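
The early-return path only fires when T == 0 or B == 0, but even then the outputs must not alias the inputs, because the schema registered above declares plain (non-aliasing) Tensor returns. A small illustration in plain PyTorch (not fbgemm code) of the difference the .clone() calls make:

    import torch

    x = torch.arange(4)

    alias = x          # returning an input as-is: the "output" shares its storage
    fresh = x.clone()  # returning a clone: the output owns fresh storage

    alias[0] = -1
    print(x[0].item())      # -1: mutating the aliased output changed the input
    print(fresh[0].item())  # 0: the cloned output is unaffected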

fbgemm_gpu/test/failures_dict.json

Lines changed: 0 additions & 24 deletions

@@ -525,18 +525,6 @@
       }
     },
     "fbgemm::permute_2D_sparse_data": {
-      "SparseOpsTest.test_aot_dispatch_dynamic__test_permute_embeddings": {
-        "comment": "",
-        "status": "xfail"
-      },
-      "SparseOpsTest.test_aot_dispatch_dynamic__test_permute_indices": {
-        "comment": "",
-        "status": "xfail"
-      },
-      "SparseOpsTest.test_aot_dispatch_dynamic__test_permute_indices_with_repeats": {
-        "comment": "",
-        "status": "xfail"
-      },
       "SparseOpsTest.test_aot_dispatch_static__test_permute_embeddings": {
         "comment": "",
         "status": "xfail"
@@ -549,18 +537,6 @@
         "comment": "",
         "status": "xfail"
       },
-      "SparseOpsTest.test_faketensor__test_permute_embeddings": {
-        "comment": "",
-        "status": "xfail"
-      },
-      "SparseOpsTest.test_faketensor__test_permute_indices": {
-        "comment": "",
-        "status": "xfail"
-      },
-      "SparseOpsTest.test_faketensor__test_permute_indices_with_repeats": {
-        "comment": "",
-        "status": "xfail"
-      },
       "SparseOpsTest.test_schema__test_permute_indices": {
         "comment": "flaky",
         "status": "skip"

fbgemm_gpu/test/sparse_ops_test.py

Lines changed: 1 addition & 0 deletions

@@ -36,6 +36,7 @@
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:index_select_ops")
+    import fbgemm_gpu.sparse_operators  # noqa: F401, E402
     from fbgemm_gpu.test.test_utils import gpu_available, gpu_unavailable, skipIfRocm
 
 suppressed_list: List[HealthCheck] = (
