Commit 0efbeed

AGZain authored and facebook-github-bot committed
Add permute_duplicate_pooled_embeddings op for CPU (pytorch#1939)
Summary: Pull Request resolved: pytorch#1939

This diff builds on top of the previous diff and adds support for permute_duplicate_pooled_embeddings for CPU.

# Background

Currently, permute_pooled_embs_gpu does not support duplicates in a permutation, which poses a problem when passing the same embeddings to multiple modules. This doc proposes a solution that allows duplicate subsets in the resultant permutation.

# Details

The required implementation of permute_duplicate_pooled_embs_gpu should support a subset being repeated, which is represented by duplicates in the permute list. This also means the output list can be larger than the input list. In the example below (the input is shown as indices 0-9 so the output can be read directly), subset 3 spans indices 6-9 and appears twice:

Input: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
Offset_dims: [0, 2, 5, 6, 10]
Permute: [3, 0, 2, 1, 3]
Output: [6, 7, 8, 9, 0, 1, 5, 2, 3, 4, 6, 7, 8, 9]

Differential Revision: D48305145

fbshipit-source-id: 0e6e325eab8f1907991c22594a32e8f0937a914f
1 parent f61798e commit 0efbeed
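
The subset arithmetic in the example above can be modeled in a few lines of Python. This is an illustrative sketch only (`permute_with_duplicates` is a hypothetical helper, not an FBGEMM API, and the real ops operate on dim 1 of a 2-D pooled-embeddings tensor), but it reproduces the Output list from the Input, Offset_dims, and Permute lists in the summary:

```python
# Illustrative sketch of the duplicate-aware permutation described in the
# summary. `permute_with_duplicates` is a hypothetical helper, not an
# FBGEMM API.
def permute_with_duplicates(values, offset_dims, permute):
    # Subset i covers values[offset_dims[i]:offset_dims[i + 1]].
    out = []
    for p in permute:
        out.extend(values[offset_dims[p]:offset_dims[p + 1]])
    return out

values = list(range(10))        # input, shown as indices 0-9
offset_dims = [0, 2, 5, 6, 10]  # subset boundaries
permute = [3, 0, 2, 1, 3]       # subset 3 is duplicated
assert permute_with_duplicates(values, offset_dims, permute) == [
    6, 7, 8, 9, 0, 1, 5, 2, 3, 4, 6, 7, 8, 9
]
```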

File tree

3 files changed: +87 −7 lines changed


fbgemm_gpu/include/fbgemm_gpu/permute_pooled_embedding_ops.h

Lines changed: 17 additions & 2 deletions

```diff
@@ -11,13 +11,28 @@
 #include <ATen/ATen.h>

 namespace fbgemm_gpu {
-at::Tensor permute_pooled_embs_cpu(
+
+at::Tensor permute_pooled_embs_cpu_impl(
     const at::Tensor& pooled_embs, // [B_local][Sum_T_global(D)]
     const at::Tensor& offset_dim_list,
     const at::Tensor& permute_list,
     const at::Tensor& inv_offset_dim_list,
     const at::Tensor& inv_permute_list,
-    const bool&);
+    const bool& allow_duplicates);
+
+at::Tensor permute_duplicate_pooled_embs_cpu(
+    const at::Tensor& pooled_embs, // [B_local][Sum_T_global(D)]
+    const at::Tensor& offset_dim_list,
+    const at::Tensor& permute_list,
+    const at::Tensor& inv_offset_dim_list,
+    const at::Tensor& inv_permute_list);
+
+at::Tensor permute_pooled_embs_cpu(
+    const at::Tensor& pooled_embs, // [B_local][Sum_T_global(D)]
+    const at::Tensor& offset_dim_list,
+    const at::Tensor& permute_list,
+    const at::Tensor& inv_offset_dim_list,
+    const at::Tensor& inv_permute_list);

 at::Tensor permute_duplicate_pooled_embs_gpu(
     const at::Tensor& pooled_embs, // [B_local][Sum_T_global(D)]
```

fbgemm_gpu/src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_gpu.cpp

Lines changed: 57 additions & 5 deletions

```diff
@@ -22,13 +22,14 @@ using Tensor = at::Tensor;

 namespace fbgemm_gpu {

-///@ingroup permute-pooled-embs-cpu
-Tensor permute_pooled_embs_cpu(
+///@ingroup permute-pooled-embs-cpu-impl
+Tensor permute_pooled_embs_cpu_impl(
     const Tensor& pooled_embs, // [B_local][Sum_T_global(D)]
     const Tensor& offset_dim_list,
     const Tensor& permute_list,
     const Tensor& inv_offset_dim_list,
-    const Tensor& inv_permute_list) {
+    const Tensor& inv_permute_list,
+    const bool& allow_duplicates) {
   TORCH_CHECK(
       offset_dim_list.scalar_type() == at::ScalarType::Long,
       "offset_dim_list needs to have long/int64 type")
@@ -37,9 +38,10 @@ Tensor permute_pooled_embs_cpu(
       "permute_list needs to have long/int64 type")
   auto permute = permute_list.data_ptr<int64_t>();
   const auto n = permute_list.numel();
+  const auto dims_size = allow_duplicates ? offset_dim_list.numel() : n;
   std::vector<int64_t> dims;
-  dims.reserve(n - 1);
-  for (const auto i : c10::irange(1, n)) {
+  dims.reserve(dims_size - 1);
+  for (const auto i : c10::irange(1, dims_size)) {
     dims.push_back(offset_dim_list[i].item<int64_t>());
   }
   auto ts = pooled_embs.tensor_split(dims, 1);
@@ -51,6 +53,38 @@
   return at::cat(permuted_ts, 1);
 }

+///@ingroup permute-pooled-embs-cpu
+at::Tensor permute_pooled_embs_cpu(
+    const at::Tensor& pooled_embs, // [B_local][Sum_T_global(D)]
+    const at::Tensor& offset_dim_list,
+    const at::Tensor& permute_list,
+    const at::Tensor& inv_offset_dim_list,
+    const at::Tensor& inv_permute_list) {
+  return permute_pooled_embs_cpu_impl(
+      pooled_embs,
+      offset_dim_list,
+      permute_list,
+      inv_offset_dim_list,
+      inv_permute_list,
+      false);
+}
+
+///@ingroup permute-duplicate-pooled-embs-cpu
+at::Tensor permute_duplicate_pooled_embs_cpu(
+    const at::Tensor& pooled_embs, // [B_local][Sum_T_global(D)]
+    const at::Tensor& offset_dim_list,
+    const at::Tensor& permute_list,
+    const at::Tensor& inv_offset_dim_list,
+    const at::Tensor& inv_permute_list) {
+  return permute_pooled_embs_cpu_impl(
+      pooled_embs,
+      offset_dim_list,
+      permute_list,
+      inv_offset_dim_list,
+      inv_permute_list,
+      true);
+}
+
 using torch::autograd::AutogradContext;
 using torch::autograd::Variable;
 using torch::autograd::variable_list;
@@ -163,6 +197,21 @@ Tensor permute_duplicate_pooled_embs_auto_grad_gpu(
       inv_offset_dim_list,
       inv_permute_list);
 }
+
+///@ingroup permute-duplicate-pooled-embs-cpu
+Tensor permute_duplicate_pooled_embs_auto_grad_cpu(
+    const Tensor& pooled_embs,
+    const Tensor& offset_dim_list,
+    const Tensor& permute_list,
+    const Tensor& inv_offset_dim_list,
+    const Tensor& inv_permute_list) {
+  return PermutePooledEmbsFunction<permute_duplicate_pooled_embs_cpu>::apply(
+      pooled_embs,
+      offset_dim_list,
+      permute_list,
+      inv_offset_dim_list,
+      inv_permute_list);
+}
 } // namespace fbgemm_gpu

 TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
@@ -185,4 +234,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   DISPATCH_TO_CUDA(
       "permute_duplicate_pooled_embs_auto_grad",
       fbgemm_gpu::permute_duplicate_pooled_embs_auto_grad_gpu);
+  DISPATCH_TO_CPU(
+      "permute_duplicate_pooled_embs_auto_grad",
+      fbgemm_gpu::permute_duplicate_pooled_embs_auto_grad_cpu);
 }
```

(Note: the two Doxygen `///@ingroup` labels on the new CPU wrappers were swapped in the original hunk; they are shown here attached to the functions they describe.)
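
The CPU implementation above relies on tensor_split plus cat rather than a custom kernel. The following is a hedged sketch of that strategy using PyTorch's Python API (illustrative only, not the FBGEMM code path): the tensor is split at the interior offset boundaries, and the pieces are concatenated in permuted order, repeating pieces as needed.

```python
import torch

# Sketch of the tensor_split + cat strategy used by
# permute_pooled_embs_cpu_impl above (illustrative only; not the actual
# FBGEMM code path).
pooled = torch.arange(10, dtype=torch.float32).unsqueeze(0)  # [1, 10]
offset_dim_list = [0, 2, 5, 6, 10]
permute = [3, 0, 2, 1, 3]

# With allow_duplicates=True, dims takes every boundary after the first;
# the final boundary at the tensor edge yields an empty trailing slice
# that the permute list never selects.
dims = offset_dim_list[1:]  # [2, 5, 6, 10]
ts = pooled.tensor_split(dims, dim=1)
out = torch.cat([ts[p] for p in permute], dim=1)
assert out.squeeze(0).tolist() == [
    6.0, 7.0, 8.0, 9.0, 0.0, 1.0, 5.0, 2.0, 3.0, 4.0, 6.0, 7.0, 8.0, 9.0
]
```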

fbgemm_gpu/test/permute_pooled_embedding_test.py

Lines changed: 13 additions & 0 deletions

```diff
@@ -209,6 +209,19 @@ def test_duplicate_permutations(self) -> None:
             expected_result,
         )

+        input = input.to(device="cpu")
+        result = torch.ops.fbgemm.permute_duplicate_pooled_embs_auto_grad(
+            input,
+            _offset_dim_list.to(device=input.device),
+            _permute.to(device=input.device),
+            _inv_offset_dim_list.to(device=input.device),
+            _inv_permute.to(device=input.device),
+        )
+        self.assertEqual(
+            result.view(16).tolist(),
+            expected_result,
+        )
+

 if __name__ == "__main__":
     unittest.main()
```
