
Commit 7c09c56

jianyuh authored and facebook-github-bot committed
Add BF16 output support for inference TBE (pytorch#1498)
Summary:
Pull Request resolved: pytorch#1498

As title

Reviewed By: jiecaoyu

Differential Revision: D41835847

fbshipit-source-id: 871c23f27027d8478372342181db3e1ba53c7cdf
1 parent 81ba6c5 · commit 7c09c56

File tree

2 files changed: +282, -4 lines


fbgemm_gpu/codegen/embedding_forward_quantized_split_template.cu

Lines changed: 4 additions & 4 deletions
```diff
@@ -197,7 +197,7 @@ __global__ void {{ type_map[emb_weight_type].enum_name }}_split_embedding{{ "_no
     return;
   }
   static_assert(
-      std::is_same<output_t, float>::value || std::is_same<output_t, at::Half>::value || std::is_same<output_t, uint8_t>::value,
+      std::is_same<output_t, float>::value || std::is_same<output_t, at::BFloat16>::value || std::is_same<output_t, at::Half>::value || std::is_same<output_t, uint8_t>::value,
       "output_t can only be float or half or bytes now"
   );
```
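This static_assert is the kernel's compile-time gate on its output element type, so adding at::BFloat16 to the disjunction is what lets BF16 instantiations build at all. A minimal standalone sketch of the same gate, using hypothetical Half/BFloat16 stand-ins so it compiles without ATen headers:

```cpp
// Sketch of the compile-time output-type gate; Half and BFloat16 are
// stand-ins for at::Half / at::BFloat16 so this builds without ATen.
#include <cstdint>
#include <type_traits>

struct Half {};
struct BFloat16 {};

template <typename output_t>
void check_output_type() {
  static_assert(
      std::is_same<output_t, float>::value ||
          std::is_same<output_t, BFloat16>::value ||
          std::is_same<output_t, Half>::value ||
          std::is_same<output_t, uint8_t>::value,
      "output_t can only be float, bfloat16, half, or bytes");
}

int main() {
  check_output_type<BFloat16>();  // builds after this change
  // check_output_type<double>();  // would trip the static_assert
  return 0;
}
```

BF16 shares FP32's 8-bit exponent, so it covers the same dynamic range as float at reduced mantissa precision; that is why it can ride the existing float/half code paths changed below.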

```diff
@@ -331,7 +331,7 @@ __global__ void {{ type_map[emb_weight_type].enum_name }}_split_embedding{{ "_no
   }
   {% else %}
   const int32_t output_j = indices_starts[i] + L_start + input_row_idx;
-  if (std::is_same<output_t, float>::value || std::is_same<output_t, at::Half>::value) {
+  if (std::is_same<output_t, float>::value || std::is_same<output_t, at::Half>::value || std::is_same<output_t, at::BFloat16>::value) {
     #pragma unroll MaxNum128BRows
     for (uint32_t j = 0; j < MaxNum128BRows; ++j) {
       // Read the uint8/4/2 values: note that first 4 Bytes will be ditched later:
```
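This `if`, and the identical one in the next hunk, follow the same idiom: an ordinary runtime `if` over `std::is_same<>` constants, so both branches must compile for every output_t while the dead branch folds away at compile time. A self-contained sketch of the pattern, with simplified stand-ins for the ATen types and the quantized uint8_t branch elided:

```cpp
// Sketch of the kernel's dispatch idiom: a plain `if` on std::is_same<>
// constants. Both branches compile for every output_t, but the condition
// is a compile-time constant, so the dead branch is optimized away.
// Half/BFloat16 are simplified stand-ins for at::Half / at::BFloat16.
#include <cstdint>
#include <cstring>
#include <type_traits>

struct Half {
  uint16_t bits{};
  Half() = default;
  explicit Half(float) { /* at::Half rounds to IEEE fp16 here */ }
};

struct BFloat16 {
  uint16_t bits{};
  BFloat16() = default;
  explicit BFloat16(float f) {
    // bf16 keeps fp32's 8-bit exponent; this naive conversion keeps the
    // top 16 bits (at::BFloat16 rounds to nearest even instead).
    uint32_t u;
    std::memcpy(&u, &f, sizeof(u));
    bits = static_cast<uint16_t>(u >> 16);
  }
};

template <typename output_t>
void write_output(output_t* out, int idx, float acc) {
  if (std::is_same<output_t, float>::value ||
      std::is_same<output_t, Half>::value ||
      std::is_same<output_t, BFloat16>::value) {
    out[idx] = static_cast<output_t>(acc);  // direct convert-and-store
  } else {
    // uint8_t path: row-wise quantization with fused scale/bias (not shown)
  }
}

int main() {
  BFloat16 buf[1];
  write_output(buf, 0, 1.5f);  // stores bf16(1.5)
  return 0;
}
```

Under C++17 this would typically be `if constexpr`, but the trait-based `if` behaves the same here because every branch is well-formed for every output_t.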
```diff
@@ -388,7 +388,7 @@ __global__ void {{ type_map[emb_weight_type].enum_name }}_split_embedding{{ "_no
   const uint32_t b = min(static_cast<uint32_t>(bb * OutputRowsPerThread + i), static_cast<uint32_t>(B - 1));
   const float inv_L = (mean_pooling && Ls[i] != 0) ? static_cast<float>(1.0) / Ls[i] : static_cast<float>(1.0);

-  if (std::is_same<output_t, float>::value || std::is_same<output_t, at::Half>::value) {
+  if (std::is_same<output_t, float>::value || std::is_same<output_t, at::Half>::value || std::is_same<output_t, at::BFloat16>::value) {
     #pragma unroll MaxNum128BRows
     for (uint32_t j = 0; j < MaxNum128BRows; ++j) {
       const int32_t output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding;
```
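The inv_L scale is untouched by this diff but frames what the new BF16 branch stores: accumulation and the mean-pooling division happen in fp32, and only the final store converts to output_t. A tiny hypothetical standalone example of the scale:

```cpp
// Worked sketch of the pooling scale: sum pooling keeps the accumulator,
// mean pooling divides by the bag length L (guarding L == 0).
#include <cstdio>

int main() {
  const bool mean_pooling = true;
  const int L = 4;          // rows pooled into this bag
  const float acc = 6.0f;   // accumulated embedding sum (fp32)
  const float inv_L = (mean_pooling && L != 0) ? 1.0f / L : 1.0f;
  printf("pooled value = %f\n", acc * inv_L);  // 1.5 for mean, 6.0 for sum
  return 0;
}
```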
```diff
@@ -625,7 +625,7 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
   Tensor output;
   const int kINT8QparamsBytes = 8;
   SparseType o_dtype = static_cast<SparseType>(output_dtype);
-  TORCH_CHECK(o_dtype == SparseType::FP32 || o_dtype == SparseType::FP16 || o_dtype == SparseType::INT8);
+  TORCH_CHECK(o_dtype == SparseType::FP32 || o_dtype == SparseType::FP16 || o_dtype == SparseType::BF16 || o_dtype == SparseType::INT8);
   {% if not nobag %}
   int64_t total_adjusted_D = total_D;
   if (o_dtype == SparseType::INT8) {
```
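On the host side, the widened TORCH_CHECK admits SparseType::BF16 before the output tensor is allocated. A hedged sketch of the dtype mapping such a check implies; `SparseType` here is a local mirror of fbgemm_gpu's enum and `allocate_output` is a hypothetical helper, not the codegen function:

```cpp
// Hedged sketch: pairing the widened dtype check with output allocation.
#include <ATen/ATen.h>

enum class SparseType { FP32, FP16, BF16, INT8 };  // local mirror, for illustration

at::Tensor allocate_output(int64_t B, int64_t total_D, SparseType o_dtype,
                           const at::TensorOptions& dev) {
  TORCH_CHECK(
      o_dtype == SparseType::FP32 || o_dtype == SparseType::FP16 ||
      o_dtype == SparseType::BF16 || o_dtype == SparseType::INT8);
  switch (o_dtype) {
    case SparseType::FP16:
      return at::empty({B, total_D}, dev.dtype(at::kHalf));
    case SparseType::BF16:
      return at::empty({B, total_D}, dev.dtype(at::kBFloat16));
    case SparseType::INT8:
      // INT8 rows also carry fused scale/bias bytes, so the real code
      // widens the column count before allocating the byte tensor.
      return at::empty({B, total_D}, dev.dtype(at::kByte));
    case SparseType::FP32:
    default:
      return at::empty({B, total_D}, dev.dtype(at::kFloat));
  }
}
```

In the real function, the INT8 branch that follows this hunk adjusts total_adjusted_D, presumably by kINT8QparamsBytes per table, to make room for those fused quantization parameters; the float, half, and BF16 outputs need no such padding.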
