Implement generate_vbe_metadata cpu #3715

Closed
wants to merge 1 commit into from
fbgemm_gpu/cmake/TbeTraining.cmake: 2 additions & 0 deletions
@@ -158,6 +158,8 @@ gpu_cpp_library(
${gen_gpu_files_forward_split}
NVCC_FLAGS
${TORCH_CUDA_OPTIONS}
DEPS
fbgemm_gpu_tbe_common
DESTINATION
fbgemm_gpu)

@@ -42,10 +42,6 @@
#include "torch/csrc/autograd/record_function_ops.h"
#include "torch/csrc/autograd/record_function_ops.h"

{%- if has_vbe_support %}
#include "fbgemm_gpu/utils/pt2_autograd_utils.h"
{%- endif %}

#include "pt2_arg_utils.h"

using Tensor = at::Tensor;
@@ -124,6 +120,9 @@ enum SSDTensor {
const c10::SymInt /*vbe_output_size*/,
const int64_t /*info_B_num_bits*/,
const int64_t /*info_B_mask_int64*/,
const Tensor& /*vbe_B_offsets_rank_per_feature*/, // for reshaping vbe cpu offsets and output
const Tensor& /*vbe_output_offsets_feature_rank*/, // for reshaping vbe cpu output
const int64_t /*max_B_int*/, // for reshaping vbe cpu offsets
{%- endif %}
{%- if is_gwd %}
const Tensor& /*prev_iter_dev*/,
@@ -168,6 +167,9 @@ enum SSDTensor {
vbe_output_size,
info_B_num_bits,
info_B_mask_int64,
vbe_B_offsets_rank_per_feature_, // for reshaping vbe cpu offsets and output
vbe_output_offsets_feature_rank_, // for reshaping vbe cpu output
max_B_int, // for reshaping vbe cpu offsets
{%- endif %} {# /* if vbe */ #}
{%- if is_gwd %}
prev_iter_dev_,
@@ -244,6 +246,8 @@ enum SSDTensor {
const Tensor& /*B_offsets*/,
const Tensor& /*vbe_row_output_offsets*/,
const Tensor& /*vbe_b_t_map*/,
const Tensor& /*vbe_B_offsets_rank_per_feature*/, // for reshaping vbe cpu offsets and grad output
const int64_t /*max_B*/, // for reshaping vbe cpu offsets
{%- endif %}
const bool /*use_uniq_cache_locations_bwd*/,
const bool /*use_homogeneous_placements*/,
@@ -309,6 +313,8 @@ enum SSDTensor {
B_offsets,
vbe_row_output_offsets,
vbe_b_t_map,
vbe_B_offsets_rank_per_feature, // for reshaping vbe cpu offsets and grad output
max_B, // for reshaping vbe cpu offsets
{%- endif %} {# /* if vbe */ #}
{%- if not dense %}
use_uniq_cache_locations_bwd,
@@ -689,6 +695,7 @@ class {{ autograd_func }} :
const auto info_B_mask = static_cast<uint32_t>(aux_int[IDX_INFO_B_MASK]);

{%- if vbe %}
const int64_t max_B_int = max_B_.guard_int(__FILE__, __LINE__); // for reshaping vbe cpu offsets and grad_output
static auto generate_vbe_metadata_op =
torch::Dispatcher::singleton()
.findSchemaOrThrow("fbgemm::generate_vbe_metadata", "")
@@ -766,6 +773,7 @@ class {{ autograd_func }} :
B_offsets_,
vbe_row_output_offsets,
vbe_b_t_map,
vbe_B_offsets_rank_per_feature_, // for reshaping vbe cpu grad_output
{%- endif %}
{%- if is_gwd and "prev_iter_dev" not in args_pt2.split_function_arg_names %}
prev_iter_dev_,
@@ -808,6 +816,9 @@ class {{ autograd_func }} :
{%- if not nobag %}
ctx->saved_data["output_dtype"] = output_dtype;
{%- endif %}
{%- if vbe %}
ctx->saved_data["max_B"] = max_B_int; // for reshaping vbe cpu offsets and grad_output
{%- endif %}

{%- if not dense %}
// unpack optim args
@@ -894,6 +905,7 @@ static torch::autograd::variable_list backward(
auto B_offsets = *savedItr++;
auto vbe_row_output_offsets = *savedItr++;
auto vbe_b_t_map = *savedItr++;
auto vbe_B_offsets_rank_per_feature = *savedItr++; // for reshaping vbe cpu grad_output
{%- endif %}
{%- if is_gwd and "prev_iter_dev" not in args_pt2.split_function_arg_names %}
auto prev_iter_dev = *savedItr++;
@@ -939,6 +951,10 @@ static torch::autograd::variable_list backward(
auto output_dtype = ctx->saved_data["output_dtype"].toInt();
{%- endif %}
{%- if not dense %}
{%- if vbe %}
auto max_B = ctx->saved_data["max_B"].toInt(); // for reshaping vbe cpu offsets and grad_output
{%- endif %}

{%- for (var, _ , ivalue_cast, type) in args_pt2.unified_pt2.split_saved_data %}
auto {{ var }} = ctx->saved_data["{{ var }}"].{{ ivalue_cast }}();
{%- endfor %}
@@ -976,19 +992,6 @@ static torch::autograd::variable_list backward(
// {{ fwd_mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_pt2_cuda)
weights_dev = weights_dev.flatten();
{%- endif %}
{%- if vbe %}
// TODO: remove this once vbe_metadata for cpu is implemented
// The MTIA kernel uses weights_host but follows the CUDA implementation,
// so grad_output is already in the correct shape and must not be reshaped.
// Reshaping grad_output in the weights_host path here causes the MTIA kernel to fail.
// As a hotfix to unblock MTIA, we check the dimension so that the reshape is skipped on MTIA.
// CUDA and MTIA vbe_b_t_map is size of {total_B} - should be 1 dim
// CPU vbe_b_t_map is B_offsets_rank_per_feature, so shape should be {num_features, batch_offsets}
// This will be removed totally once vbe_metadata for cpu is implemented
if (weights_host.numel() > 1 && vbe_b_t_map.dim() > 1){
grad_output = reshape_vbe_output(grad_output, B_offsets, vbe_b_t_map, D_offsets);
}
{%- endif %}

{%- set grad_indice_weights_op =
"{}_embedding_codegen_grad_indice_weights{}_pt2_wrapper".format(fwd_mdesc, vdesc)
@@ -1023,7 +1026,9 @@ static torch::autograd::variable_list backward(
const Tensor& /*vbe_row_output_offsets*/,
const Tensor& /*vbe_b_t_map*/,
const int64_t /*info_B_num_bits*/,
const int64_t /*info_B_mask_int64*/
const int64_t /*info_B_mask_int64*/,
const Tensor& /*vbe_B_offsets_rank_per_feature*/, // for reshaping vbe cpu grad_output
const int64_t /*max_B*/ // for reshaping vbe cpu offsets and grad_output
{%- else %}
const Tensor& /*feature_requires_grad*/
{%- endif %}
@@ -1053,7 +1058,9 @@ static torch::autograd::variable_list backward(
vbe_row_output_offsets,
vbe_b_t_map,
info_B_num_bits,
info_B_mask_int64
info_B_mask_int64,
vbe_B_offsets_rank_per_feature,
max_B
{%- else %}
feature_requires_grad
{%- endif %}
@@ -26,6 +26,9 @@
#include "fbgemm_gpu/utils/ops_utils.h"
#include "fbgemm_gpu/utils/dispatch_macros.h"
#include "fbgemm_gpu/embedding_common.h"
{%- if has_vbe_support %}
#include "fbgemm_gpu/utils/pt2_autograd_utils.h"
{%- endif %}

using Tensor = at::Tensor;
using namespace fbgemm_gpu;
@@ -53,23 +56,39 @@ Tensor split_embedding_codegen_grad_indice_weights{{ vdesc }}_pt2_cpu_wrapper(
const Tensor& vbe_row_output_offsets,
const Tensor& vbe_b_t_map,
const int64_t info_B_num_bits,
const int64_t info_B_mask_int64
const int64_t info_B_mask_int64,
const Tensor& vbe_B_offsets_rank_per_feature,
const int64_t max_B
{%- else %}
const Tensor& feature_requires_grad
{%- endif %}
) {
{%- if vbe %}
const auto offsets_ = reshape_vbe_offsets(
offsets,
vbe_B_offsets_rank_per_feature,
max_B,
D_offsets.numel() - 1
);
const auto grad_output_ = reshape_vbe_output(
grad_output,
max_B,
vbe_B_offsets_rank_per_feature,
D_offsets
);
{%- endif %}
static auto op =
torch::Dispatcher::singleton()
.findSchemaOrThrow(
"fbgemm::split_embedding_codegen_grad_indice_weights_cpu", "")
.typed<Tensor(Tensor,Tensor,Tensor,Tensor,Tensor,Tensor,Tensor)>();
return op.call(
grad_output,
{{ "grad_output_" if vbe else "grad_output" }},
host_weights,
weights_offsets,
D_offsets,
indices,
offsets,
{{ "offsets_" if vbe else "offsets" }},
feature_requires_grad);
}
{%- endif %}
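Side note on the reshape calls in the wrapper above: with variable batch sizes (VBE), the incoming offsets have length sum_t(B_t) + 1 rather than the T * B + 1 layout the existing fixed-batch CPU kernels expect, so reshape_vbe_offsets pads them back out to a dense T * max_B + 1 layout before dispatch. The snippet below is only a minimal standalone sketch of that idea, using std::vector and a collapsed per-feature batch count instead of the real tensor arguments; the padding rule (repeat the feature's last offset so padded bags stay empty) and the helper name are assumptions for illustration, not taken from this PR.

#include <cstdint>
#include <vector>

// Sketch of padding VBE offsets to a dense fixed-batch layout (assumed behavior).
// vbe_offsets:          CSR-style offsets over all features, length sum_t(B_t) + 1.
// batches_per_feature:  actual batch size B_t of each feature t (B_t <= max_B).
// Returns offsets of length T * max_B + 1; slots past B_t repeat the feature's
// last offset, so the padded bags are empty and pool to zero.
std::vector<int64_t> reshape_vbe_offsets_sketch(
    const std::vector<int64_t>& vbe_offsets,
    const std::vector<int64_t>& batches_per_feature,
    int64_t max_B) {
  const int64_t T = static_cast<int64_t>(batches_per_feature.size());
  std::vector<int64_t> dense(T * max_B + 1);
  dense[0] = 0;
  int64_t src = 0;  // start of the current feature's bags in vbe_offsets
  for (int64_t t = 0; t < T; ++t) {
    const int64_t B_t = batches_per_feature[t];
    for (int64_t b = 0; b < max_B; ++b) {
      dense[t * max_B + b + 1] =
          (b < B_t) ? vbe_offsets[src + b + 1]  // real bag: copy its end offset
                    : dense[t * max_B + b];     // padded bag: repeat, i.e. empty
    }
    src += B_t;
  }
  return dense;
}

In the wrappers themselves the same information comes from vbe_B_offsets_rank_per_feature and max_B, with D_offsets.numel() - 1 supplying T.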
@@ -96,14 +115,20 @@ Tensor split_embedding_codegen_forward_{{ wdesc }}{{ vdesc }}_pt2_cpu_wrapper(
const Tensor& /*lxu_cache_locations*/,
const Tensor& /*uvm_cache_stats*/,
{%- if vbe %}
const Tensor& vbe_row_output_offsets, /*vbe_output_offsets_feature_rank*/
const Tensor& vbe_b_t_map, /*vbe_B_offsets_rank_per_feature*/
const Tensor& vbe_row_output_offsets,
const Tensor& vbe_b_t_map,
const c10::SymInt vbe_output_size,
const int64_t info_B_num_bits,
const int64_t info_B_mask_int64,
const Tensor& vbe_B_offsets_rank_per_feature,
const Tensor& vbe_output_offsets_feature_rank,
const int64_t max_B,
{%- endif %}
const bool /*is_experimental = false*/,
const int64_t output_dtype = static_cast<int64_t>(SparseType::FP32)) {
{%- if vbe %}
const auto offsets_ = reshape_vbe_offsets(offsets, vbe_B_offsets_rank_per_feature, max_B, D_offsets.numel() - 1);
{%- endif %}
static auto op =
torch::Dispatcher::singleton()
.findSchemaOrThrow("fbgemm::split_embedding_codegen_forward_cpu", "")
@@ -112,16 +137,14 @@ Tensor split_embedding_codegen_forward_{{ wdesc }}{{ vdesc }}_pt2_cpu_wrapper(
)>();
{%- if vbe %}
// TODO: remove this after vbe is implemented for CPU kernel
Tensor vbe_B_offsets_rank_per_feature = vbe_b_t_map;
Tensor vbe_output_offsets_feature_rank = vbe_row_output_offsets;
const auto output = op.call(
host_weights,
weights_offsets,
D_offsets,
total_D,
hash_size_cumsum,
indices,
offsets,
offsets_,
pooling_mode,
indice_weights,
output_dtype);
@@ -192,6 +215,8 @@ Tensor split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_p
const Tensor& B_offsets,
const Tensor& vbe_row_output_offsets,
const Tensor& vbe_b_t_map,
const Tensor& vbe_B_offsets_rank_per_feature,
const int64_t max_B,
{%- endif %}
const bool /*use_uniq_cache_locations*/,
const bool /*use_homogeneous_placements*/,
@@ -200,6 +225,10 @@ Tensor split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_p
, const int64_t output_dtype = static_cast<int64_t>(SparseType::FP32)
{%- endif %})
{
{%- if vbe %}
const auto offsets_ = reshape_vbe_offsets(offsets, vbe_B_offsets_rank_per_feature, max_B, D_offsets.numel() - 1);
const auto grad_output_ = reshape_vbe_output(grad_output, max_B, vbe_B_offsets_rank_per_feature, D_offsets);
{%- endif %}
{%- set backward_op = "split_embedding_backward_codegen_{}_cpu".format(
optimizer
)
@@ -230,7 +259,7 @@ Tensor split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_p
)>();

op.call(
grad_output,
{{ "grad_output_" if vbe else "grad_output" }},
host_weights,
weights_placements,
weights_offsets,
@@ -239,7 +268,7 @@ Tensor split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_p
hash_size_cumsum,
total_hash_size_bits,
indices,
offsets,
{{ "offsets_" if vbe else "offsets" }},
pooling_mode,
indice_weights,
stochastic_rounding,
@@ -248,7 +277,7 @@ Tensor split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_p
, output_dtype
{%- endif %}
);
return grad_output;
return Tensor();
}
{% endif %}
{%- endfor %} {#-/*for weighted*/#}
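The backward wrapper above likewise relies on reshape_vbe_output to turn the flat VBE-layout grad_output into the dense [max_B, total_D] gradient that the fixed-batch split_embedding_backward_codegen_{optimizer}_cpu kernel consumes. The sketch below shows one way such a scatter could work; it assumes the VBE gradient is ordered rank-major with one contiguous per-feature block inside each rank, and that per-rank batch counts B[t][r] can be derived from vbe_B_offsets_rank_per_feature. The helper name, container types, and layout assumption are illustrative only, not the actual fbgemm_gpu implementation.

#include <cstdint>
#include <vector>

// Sketch of scattering a flat VBE-layout gradient into a dense [max_B, total_D]
// buffer (row-major, flattened). Assumed layout: rank-major, and within each
// rank one contiguous block of B[t][r] rows of width D[t] per feature t.
std::vector<float> reshape_vbe_output_sketch(
    const std::vector<float>& vbe_grad,
    const std::vector<std::vector<int64_t>>& B,  // B[t][r]: batch of feature t on rank r
    const std::vector<int64_t>& D,               // D[t]: embedding dim of feature t
    int64_t max_B) {
  const size_t T = D.size();
  const size_t R = B.empty() ? 0 : B[0].size();
  std::vector<int64_t> D_offsets(T + 1, 0);
  for (size_t t = 0; t < T; ++t) {
    D_offsets[t + 1] = D_offsets[t] + D[t];
  }
  const int64_t total_D = D_offsets[T];
  std::vector<float> dense(max_B * total_D, 0.0f);  // padded slots stay zero
  size_t src = 0;  // running read position in the flat VBE gradient
  for (size_t r = 0; r < R; ++r) {
    for (size_t t = 0; t < T; ++t) {
      // Global row of this rank's first sample for feature t.
      int64_t row = 0;
      for (size_t rr = 0; rr < r; ++rr) {
        row += B[t][rr];
      }
      for (int64_t b = 0; b < B[t][r]; ++b, ++row) {
        for (int64_t d = 0; d < D[t]; ++d) {
          dense[row * total_D + D_offsets[t] + d] = vbe_grad[src++];
        }
      }
    }
  }
  return dense;
}

Rows past a feature's actual batch keep their zeros in that feature's columns, matching the empty bags produced by the padded offsets.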
@@ -293,6 +322,9 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
" SymInt vbe_output_size, "
" int info_B_num_bits, "
" int info_B_mask_int64, "
" Tensor vbe_B_offsets_rank_per_feature, "
" Tensor vbe_output_offsets_feature_rank, "
" int max_B, "
{%- endif %}
" bool is_experimental, "
" int output_dtype "
@@ -345,6 +377,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
" Tensor B_offsets, "
" Tensor vbe_row_output_offsets, "
" Tensor vbe_b_t_map, "
" Tensor vbe_B_offsets_rank_per_feature, "
" int max_B, "
{%- endif %}
" bool use_uniq_cache_locations, "
" bool use_homogeneous_placements,"
@@ -381,7 +415,9 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
" Tensor vbe_row_output_offsets, "
" Tensor vbe_b_t_map, "
" int info_B_num_bits, "
" int info_B_mask_int64"
" int info_B_mask_int64, "
" Tensor vbe_B_offsets_rank_per_feature, "
" int max_B "
{%- else %}
" Tensor feature_requires_grad"
{%- endif %}
@@ -93,6 +93,9 @@ Tensor {{ fwd_mdesc }}_embedding{{ ndesc }}_codegen_forward_{{ desc_suffix }}_pt
const c10::SymInt vbe_output_size,
const int64_t info_B_num_bits,
const int64_t info_B_mask_int64,
const Tensor& vbe_B_offsets_rank_per_feature,
const Tensor& vbe_output_offsets_feature_rank,
const int64_t max_B,
{%- endif %}
{%- if is_gwd %}
const Tensor& prev_iter_dev,
@@ -241,6 +244,8 @@ Tensor {{ bwd_mdesc }}_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{
const Tensor& B_offsets,
const Tensor& vbe_row_output_offsets,
const Tensor& vbe_b_t_map,
const Tensor& vbe_B_offsets_rank_per_feature,
const int64_t max_B,
{%- endif %}
const bool use_uniq_cache_locations,
const bool use_homogeneous_placements,
@@ -403,7 +408,9 @@ Tensor {{ fwd_mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_pt2_{{ d
const Tensor& vbe_row_output_offsets,
const Tensor& vbe_b_t_map,
const int64_t info_B_num_bits,
const int64_t info_B_mask_int64
const int64_t info_B_mask_int64,
const Tensor& vbe_B_offsets_rank_per_feature,
const int64_t max_B
{%- else %}
const Tensor& feature_requires_grad
{%- endif %}
@@ -529,6 +536,9 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
" SymInt vbe_output_size, "
" int info_B_num_bits, "
" int info_B_mask_int64, "
" Tensor vbe_B_offsets_rank_per_feature, "
" Tensor vbe_output_offsets_feature_rank, "
" int max_B, "
{%- endif %}
{%- if is_gwd %}
" Tensor{{ schema_annotation['prev_iter_dev'] }} prev_iter_dev, "
@@ -599,6 +609,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
" Tensor B_offsets, "
" Tensor vbe_row_output_offsets, "
" Tensor vbe_b_t_map, "
" Tensor vbe_B_offsets_rank_per_feature, "
" int max_B, "
{%- endif %}
" bool use_uniq_cache_locations, "
" bool use_homogeneous_placements,"
@@ -656,7 +668,9 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
" Tensor vbe_row_output_offsets, "
" Tensor vbe_b_t_map, "
" int info_B_num_bits, "
" int info_B_mask_int64"
" int info_B_mask_int64, "
" Tensor vbe_B_offsets_rank_per_feature, "
" int max_B "
{%- else %}
" Tensor feature_requires_grad"
{%- endif %}