Skip to content

Commit 54a56f0

Browse files
houseroad authored and facebook-github-bot committed
Make auto select for embedding_ops and fix duplicate cu files (pytorch#2269)
Summary: Pull Request resolved: pytorch#2269 Both "embedding_ops" and "index_select_ops" include "gen_batch_index_select_dim0_backward_kernel_warp.cu", which caused the duplicate symbol issue. ~~So remove the "gen_embedding_backward_split_grad.cu" from the source of "index_select_ops", and include "embedding_ops" as its deps.~~ The linker cannot link to the "gen_embedding_backward_split_grad.cu" in the "embedding_ops" object, so we have to include this cu file explicitly. So introduce namespace to this file, and generate different files. Reviewed By: jiaqizhai Differential Revision: D52790445 fbshipit-source-id: a6779ebdf2028155ad20fbbf17588f635ecd6564
1 parent 85cd858 commit 54a56f0

File tree

4 files changed

+35
-3
lines changed

4 files changed

+35
-3
lines changed

fbgemm_gpu/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,8 @@ set(gen_gpu_kernel_source_files
228228
"gen_batch_index_select_dim0_backward_codegen_cuda.cu"
229229
"gen_batch_index_select_dim0_backward_kernel_cta.cu"
230230
"gen_batch_index_select_dim0_backward_kernel_warp.cu"
231-
"gen_embedding_backward_split_grad.cu"
231+
"gen_embedding_backward_split_grad_embedding_ops.cu"
232+
"gen_embedding_backward_split_grad_index_select.cu"
232233
)
233234

234235
if(NOT USE_ROCM)

fbgemm_gpu/codegen/embedding_backward_code_generator.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,10 @@ def index_select() -> None:
319319
)
320320

321321
template = env.get_template("embedding_backward_split_grad_template.cu")
322-
write("gen_embedding_backward_split_grad.cu", template.render())
322+
write(
323+
"gen_embedding_backward_split_grad_index_select.cu",
324+
template.render(is_index_select=True),
325+
)
323326

324327

325328
def forward_quantized() -> None:
@@ -461,7 +464,10 @@ class elem_type:
461464
def backward_grad() -> None:
462465
# Generate the common grad functions
463466
template = env.get_template("embedding_backward_split_grad_template.cu")
464-
write("gen_embedding_backward_split_grad.cu", template.render())
467+
write(
468+
"gen_embedding_backward_split_grad_embedding_ops.cu",
469+
template.render(is_index_select=False),
470+
)
465471

466472

467473
def backward_indices() -> None:

fbgemm_gpu/codegen/embedding_backward_split_grad_template.cu

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,16 @@
1212
#include "fbgemm_gpu/split_embeddings_utils.cuh"
1313

1414
using Tensor = at::Tensor;
15+
1516
using namespace fbgemm_gpu;
1617

18+
{% if is_index_select %}
19+
namespace index_select {
20+
{% else %}
21+
namespace embedding_ops {
22+
{% endif %}
23+
24+
1725
__global__ __launch_bounds__(kMaxThreads) void
1826
split_embedding_backward_codegen_find_long_segments(
1927
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
@@ -225,4 +233,6 @@ void grad_mean{{ vdesc }}_kernel
225233
{% endfor %} // for grad_type in ['at::Half', 'float']
226234
{% endfor %} // for vbe in [True, False]
227235

236+
}
237+
228238
// clang-format on

fbgemm_gpu/codegen/embedding_backward_split_template.cu

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,13 @@ split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc
175175
{%- endif %}
176176
);
177177

178+
{% if is_index_select %}
179+
namespace index_select {
180+
{% else %}
181+
namespace embedding_ops {
182+
{% endif %}
183+
184+
178185
__global__ __launch_bounds__(kMaxThreads) void
179186
split_embedding_backward_codegen_find_long_segments(
180187
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_linear_indices_num_runs,
@@ -222,6 +229,14 @@ split_embedding_backward_count_unique_indices_kernel(
222229
const int info_B_num_bits
223230
);
224231

232+
}
233+
234+
{% if is_index_select %}
235+
using namespace index_select;
236+
{% else %}
237+
using namespace embedding_ops;
238+
{% endif %}
239+
225240
////////////////////////////////////////////////////////////////////////////////
226241
// Utility Macros
227242
////////////////////////////////////////////////////////////////////////////////

0 commit comments

Comments
 (0)