Skip to content

Commit 457bac0

Browse files
jianyuhfacebook-github-bot
authored andcommitted
Follow up on BC issue for open sourcing TBE inplace update op (pytorch#1492)
Summary: Pull Request resolved: pytorch#1492 Reviewed By: jspark1105 Differential Revision: D41717190 fbshipit-source-id: 818c54fb236e72b5816921e7d2c579843d346d47
1 parent 4cd267c commit 457bac0

9 files changed

+36
-19
lines changed

.github/workflows/fbgemmci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ jobs:
197197
runs-on: ${{ matrix.os }}
198198
strategy:
199199
matrix:
200-
os: [ubuntu-latest]
200+
os: [ubuntu-20.04]
201201
config: [[pip, 11.3], [pip, 11.5], [pip, 11.6], [pip, 11.7], [conda, 11.7]]
202202

203203
steps:

fbgemm_gpu/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,7 @@ set(codegen_dependencies
161161
${CMAKE_CURRENT_SOURCE_DIR}/include/fbgemm_gpu/dispatch_macros.h
162162
${CMAKE_CURRENT_SOURCE_DIR}/include/fbgemm_gpu/embedding_backward_template_helpers.cuh
163163
${CMAKE_CURRENT_SOURCE_DIR}/include/fbgemm_gpu/embedding_common.h
164+
${CMAKE_CURRENT_SOURCE_DIR}/include/fbgemm_gpu/embedding_inplace_update.h
164165
${CMAKE_CURRENT_SOURCE_DIR}/include/fbgemm_gpu/fbgemm_cuda_utils.cuh
165166
${CMAKE_CURRENT_SOURCE_DIR}/include/fbgemm_gpu/quantize_ops_utils.h
166167
${CMAKE_CURRENT_SOURCE_DIR}/include/fbgemm_gpu/split_embeddings_utils.cuh

fbgemm_gpu/codegen/split_embedding_codegen_lookup_invoker.template

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,13 @@ torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")
1818
torch.ops.load_library(
1919
"//deeplearning/fbgemm/fbgemm_gpu:split_table_batched_embeddings"
2020
)
21-
try:
22-
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:embedding_inplace_update")
23-
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:embedding_inplace_update_cpu")
24-
except OSError:
25-
# Keep for BC: will be deprecated soon.
26-
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/fb:embedding_inplace_update")
27-
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/fb:embedding_inplace_update_cpu")
21+
# try:
22+
# torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:embedding_inplace_update")
23+
# torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:embedding_inplace_update_cpu")
24+
# except OSError:
25+
# # Keep for BC: will be deprecated soon.
26+
# torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/fb:embedding_inplace_update")
27+
# torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/fb:embedding_inplace_update_cpu")
2828

2929
{% else %}
3030
#import os

fbgemm_gpu/src/embedding_inplace_update.h renamed to fbgemm_gpu/include/fbgemm_gpu/embedding_inplace_update.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,4 +71,22 @@ void embedding_inplace_update_cuda(
7171
c10::optional<Tensor> lxu_cache_weights = c10::nullopt,
7272
c10::optional<Tensor> lxu_cache_locations = c10::nullopt);
7373

74+
void embedding_inplace_update_cpu(
75+
Tensor dev_weights,
76+
Tensor uvm_weights,
77+
Tensor weights_placements,
78+
Tensor weights_offsets,
79+
Tensor weights_tys,
80+
Tensor D_offsets,
81+
Tensor update_weights,
82+
Tensor update_table_idx,
83+
Tensor update_row_idx,
84+
Tensor update_offsets,
85+
const int64_t row_alignment,
86+
c10::optional<Tensor> lxu_cache_weights =
87+
c10::nullopt, // Not used, to match cache interface for CUDA op
88+
c10::optional<Tensor> lxu_cache_locations =
89+
c10::nullopt // Not used, to match cache interface for CUDA op
90+
);
91+
7492
} // namespace fbgemm_gpu

fbgemm_gpu/src/embedding_inplace_update.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
#include <c10/cuda/CUDAGuard.h>
1212

13-
#include "embedding_inplace_update.h"
13+
#include "fbgemm_gpu/embedding_inplace_update.h"
1414
#include "fbgemm_gpu/fbgemm_cuda_utils.cuh"
1515

1616
using Tensor = at::Tensor;

fbgemm_gpu/src/embedding_inplace_update_cpu.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
#include <ATen/ATen.h>
1313
#include <torch/library.h>
1414

15-
#include "embedding_inplace_update.h"
15+
#include "fbgemm_gpu/embedding_inplace_update.h"
1616

1717
using Tensor = at::Tensor;
1818

@@ -72,11 +72,8 @@ void embedding_inplace_update_cpu(
7272
Tensor update_row_idx,
7373
Tensor update_offsets,
7474
const int64_t row_alignment,
75-
c10::optional<Tensor> lxu_cache_weights =
76-
c10::nullopt, // Not used, to match cache interface for CUDA op
77-
c10::optional<Tensor> lxu_cache_locations =
78-
c10::nullopt // Not used, to match cache interface for CUDA op
79-
) {
75+
c10::optional<Tensor> lxu_cache_weights,
76+
c10::optional<Tensor> lxu_cache_locations) {
8077
TENSOR_ON_CPU(dev_weights);
8178
TENSOR_ON_CPU(uvm_weights);
8279
TENSOR_ON_CPU(weights_placements);

fbgemm_gpu/src/embedding_inplace_update_gpu.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
#include <ATen/core/op_registration/op_registration.h>
99
#include <ATen/cuda/CUDAContext.h>
1010
#include <torch/library.h>
11-
#include "embedding_inplace_update.h"
11+
#include "fbgemm_gpu/embedding_inplace_update.h"
1212

1313
TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
1414
DISPATCH_TO_CUDA(

fbgemm_gpu/test/embedding_inplace_update_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
*/
77
#include <folly/Random.h>
88
#include <gtest/gtest.h>
9-
#include "deeplearning/fbgemm/fbgemm_gpu/src/embedding_inplace_update.h"
9+
#include "fbgemm_gpu/embedding_inplace_update.h"
1010

1111
using namespace ::testing;
1212
using namespace fbgemm_gpu;

fbgemm_gpu/test/split_table_batched_embeddings_test.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4685,8 +4685,9 @@ def test_embedding_inplace_update(
46854685
)
46864686

46874687
weights_ty_list = [weights_ty] * T
4688-
if open_source:
4689-
test_internal = False
4688+
# if open_source:
4689+
# test_internal = False
4690+
test_internal = False
46904691

46914692
# create two embedding bag op with random weights
46924693
locations = [location] * T

0 commit comments

Comments
 (0)