
Commit 41b4c9a

AGZain authored and facebook-github-bot committed
Add permute_duplicate_pooled_embeddings op for CPU (pytorch#1939)
Summary: Pull Request resolved: pytorch#1939

This diff builds on top of the previous diff and adds support for permute_duplicate_pooled_embeddings on CPU.

# Background

Currently, permute_pooled_embs_gpu does not support duplicates in a permutation, which poses a problem when the same embeddings must be passed to multiple modules. This diff allows duplicate subsets in the resultant permutation.

# Details

permute_duplicate_pooled_embs should support a subset being repeated, represented by duplicate entries in the permute list. As a result, the output list can be larger than the input list.

Input: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
Offset_dims: [0, 2, 5, 6, 10]
Permute: [3, 0, 2, 1, 3]
Output: [6, 7, 8, 9, 0, 1, 5, 2, 3, 4, 6, 7, 8, 9]

Reviewed By: sryap

Differential Revision: D48305145

fbshipit-source-id: a984bebb9f8974015f3d2a4f6a806d7e6d391275
1 parent f2185d6 · commit 41b4c9a
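The Details example maps directly onto stock PyTorch ops. The sketch below reproduces it with plain torch.tensor_split and torch.cat (in 1-D for clarity; the actual operators act along dim 1 of a [B_local][Sum_T_global(D)] tensor) and is purely illustrative, not the fbgemm operator itself:

```python
import torch

# Worked example from the summary: offset_dims defines the segments, and
# the permute list may repeat a segment index.
values = torch.arange(10)          # [0, 1, ..., 9]
offset_dims = [0, 2, 5, 6, 10]     # segment boundaries
permute = [3, 0, 2, 1, 3]          # segment 3 appears twice

# Split at the interior boundaries: segments [0:2], [2:5], [5:6], [6:10].
segments = torch.tensor_split(values, offset_dims[1:-1])

# With duplicates allowed, the output is simply the concatenation of the
# selected segments, so it can be longer than the input.
output = torch.cat([segments[p] for p in permute])
print(output.tolist())  # [6, 7, 8, 9, 0, 1, 5, 2, 3, 4, 6, 7, 8, 9]
```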

File tree

- fbgemm_gpu/include/fbgemm_gpu/permute_pooled_embedding_ops.h
- fbgemm_gpu/src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_gpu.cpp
- fbgemm_gpu/test/permute_pooled_embedding_test.py

3 files changed: +90 −5 lines changed

fbgemm_gpu/include/fbgemm_gpu/permute_pooled_embedding_ops.h

Lines changed: 16 additions & 0 deletions
```diff
@@ -11,6 +11,22 @@
 #include <ATen/ATen.h>
 
 namespace fbgemm_gpu {
+
+at::Tensor permute_pooled_embs_cpu_impl(
+    const at::Tensor& pooled_embs, // [B_local][Sum_T_global(D)]
+    const at::Tensor& offset_dim_list,
+    const at::Tensor& permute_list,
+    const at::Tensor& inv_offset_dim_list,
+    const at::Tensor& inv_permute_list,
+    const bool& allow_duplicates);
+
+at::Tensor permute_duplicate_pooled_embs_cpu(
+    const at::Tensor& pooled_embs, // [B_local][Sum_T_global(D)]
+    const at::Tensor& offset_dim_list,
+    const at::Tensor& permute_list,
+    const at::Tensor& inv_offset_dim_list,
+    const at::Tensor& inv_permute_list);
+
 at::Tensor permute_pooled_embs_cpu(
     const at::Tensor& pooled_embs, // [B_local][Sum_T_global(D)]
     const at::Tensor& offset_dim_list,
```

fbgemm_gpu/src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_gpu.cpp

Lines changed: 61 additions & 5 deletions
```diff
@@ -22,13 +22,14 @@ using Tensor = at::Tensor;
 
 namespace fbgemm_gpu {
 
-///@ingroup permute-pooled-embs-cpu
-Tensor permute_pooled_embs_cpu(
+///@ingroup permute-pooled-embs-cpu-impl
+Tensor permute_pooled_embs_cpu_impl(
     const Tensor& pooled_embs, // [B_local][Sum_T_global(D)]
     const Tensor& offset_dim_list,
     const Tensor& permute_list,
     const Tensor& inv_offset_dim_list,
-    const Tensor& inv_permute_list) {
+    const Tensor& inv_permute_list,
+    const bool& allow_duplicates) {
   TORCH_CHECK(
       offset_dim_list.scalar_type() == at::ScalarType::Long,
       "offset_dim_list needs to have long/int64 type")
@@ -37,9 +38,10 @@ Tensor permute_pooled_embs_cpu(
       "permute_list needs to have long/int64 type")
   auto permute = permute_list.data_ptr<int64_t>();
   const auto n = permute_list.numel();
+  const auto dims_size = allow_duplicates ? offset_dim_list.numel() : n;
   std::vector<int64_t> dims;
-  dims.reserve(n - 1);
-  for (const auto i : c10::irange(1, n)) {
+  dims.reserve(dims_size - 1);
+  for (const auto i : c10::irange(1, dims_size)) {
     dims.push_back(offset_dim_list[i].item<int64_t>());
   }
   auto ts = pooled_embs.tensor_split(dims, 1);
@@ -51,6 +53,38 @@
   return at::cat(permuted_ts, 1);
 }
 
+///@ingroup permute-pooled-embs-cpu
+at::Tensor permute_pooled_embs_cpu(
+    const at::Tensor& pooled_embs, // [B_local][Sum_T_global(D)]
+    const at::Tensor& offset_dim_list,
+    const at::Tensor& permute_list,
+    const at::Tensor& inv_offset_dim_list,
+    const at::Tensor& inv_permute_list) {
+  return permute_pooled_embs_cpu_impl(
+      pooled_embs,
+      offset_dim_list,
+      permute_list,
+      inv_offset_dim_list,
+      inv_permute_list,
+      false);
+}
+
+///@ingroup permute-duplicate-pooled-embs-cpu
+at::Tensor permute_duplicate_pooled_embs_cpu(
+    const at::Tensor& pooled_embs, // [B_local][Sum_T_global(D)]
+    const at::Tensor& offset_dim_list,
+    const at::Tensor& permute_list,
+    const at::Tensor& inv_offset_dim_list,
+    const at::Tensor& inv_permute_list) {
+  return permute_pooled_embs_cpu_impl(
+      pooled_embs,
+      offset_dim_list,
+      permute_list,
+      inv_offset_dim_list,
+      inv_permute_list,
+      true);
+}
+
 using torch::autograd::AutogradContext;
 using torch::autograd::Variable;
 using torch::autograd::variable_list;
@@ -201,6 +235,22 @@ Tensor permute_duplicate_pooled_embs_auto_grad_gpu(
       inv_permute_list,
       true);
 }
+
+///@ingroup permute-duplicate-pooled-embs-cpu
+Tensor permute_duplicate_pooled_embs_auto_grad_cpu(
+    const Tensor& pooled_embs,
+    const Tensor& offset_dim_list,
+    const Tensor& permute_list,
+    const Tensor& inv_offset_dim_list,
+    const Tensor& inv_permute_list) {
+  return PermutePooledEmbsFunction::apply(
+      pooled_embs,
+      offset_dim_list,
+      permute_list,
+      inv_offset_dim_list,
+      inv_permute_list,
+      true);
+}
 } // namespace fbgemm_gpu
 
 TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
@@ -228,9 +278,15 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   DISPATCH_TO_CUDA(
       "permute_duplicate_pooled_embs",
       fbgemm_gpu::permute_duplicate_pooled_embs_gpu);
+  DISPATCH_TO_CPU(
+      "permute_duplicate_pooled_embs",
+      fbgemm_gpu::permute_duplicate_pooled_embs_cpu);
   m.def(
       "permute_duplicate_pooled_embs_auto_grad(Tensor pooled_embs, Tensor offset_dim_list, Tensor permute_list, Tensor inv_offset_dim_list, Tensor inv_permute_list) -> Tensor");
   DISPATCH_TO_CUDA(
       "permute_duplicate_pooled_embs_auto_grad",
       fbgemm_gpu::permute_duplicate_pooled_embs_auto_grad_gpu);
+  DISPATCH_TO_CPU(
+      "permute_duplicate_pooled_embs_auto_grad",
+      fbgemm_gpu::permute_duplicate_pooled_embs_auto_grad_cpu);
 }
```
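The one subtlety in permute_pooled_embs_cpu_impl is where dims_size comes from: without duplicates the permute list is a true permutation, so its length n doubles as the segment count, but with duplicates n can exceed the number of segments, so the number of split boundaries must be derived from offset_dim_list instead. A hedged pure-Python rendering of that logic (split_segments is an illustrative helper, not part of the fbgemm API):

```python
import torch

def split_segments(pooled_embs, offset_dim_list, permute_list, allow_duplicates):
    # Mirrors the C++ locals: n is the permute length; dims_size takes the
    # boundary count from offset_dim_list when duplicates are allowed.
    n = len(permute_list)
    dims_size = len(offset_dim_list) if allow_duplicates else n
    dims = [offset_dim_list[i] for i in range(1, dims_size)]
    # In the duplicate case dims also contains the total width (10 here),
    # which only adds an empty trailing split that no permute index selects.
    return pooled_embs.tensor_split(dims, dim=1)

x = torch.arange(10).unsqueeze(0)   # one row: [B=1][sum of segment widths]
permute = [3, 0, 2, 1, 3]           # 5 entries, but only 4 segments
segments = split_segments(x, [0, 2, 5, 6, 10], permute, allow_duplicates=True)
out = torch.cat([segments[p] for p in permute], dim=1)
print(out.tolist())  # [[6, 7, 8, 9, 0, 1, 5, 2, 3, 4, 6, 7, 8, 9]]
```

With the DISPATCH_TO_CPU registrations above in place, the op is reachable from Python as torch.ops.fbgemm.permute_duplicate_pooled_embs (and its _auto_grad variant) with the five-tensor signature given in m.def, as the updated test below exercises.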

fbgemm_gpu/test/permute_pooled_embedding_test.py

Lines changed: 13 additions & 0 deletions
```diff
@@ -209,6 +209,19 @@ def test_duplicate_permutations(self) -> None:
             expected_result,
         )
 
+        input = input.to(device="cpu")
+        result = torch.ops.fbgemm.permute_duplicate_pooled_embs_auto_grad(
+            input,
+            _offset_dim_list.to(device=input.device),
+            _permute.to(device=input.device),
+            _inv_offset_dim_list.to(device=input.device),
+            _inv_permute.to(device=input.device),
+        )
+        self.assertEqual(
+            result.view(16).tolist(),
+            expected_result,
+        )
+
 
 if __name__ == "__main__":
     unittest.main()
```
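The new test covers the CPU forward path. For the autograd variant, the interesting property of duplicates lies in the backward pass: a segment that appears twice in the output must receive the sum of both gradient contributions. A quick sanity sketch using the plain tensor_split/cat formulation as a stand-in (not the fbgemm op itself):

```python
import torch

x = torch.arange(10.0, requires_grad=True)
segments = x.tensor_split([2, 5, 6])
out = torch.cat([segments[p] for p in [3, 0, 2, 1, 3]])
out.sum().backward()
print(x.grad.tolist())
# [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0]
# Segment 3 (elements 6..9) is used twice, so its gradient is doubled.
```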
