@@ -197,7 +197,7 @@ __global__ void {{ type_map[emb_weight_type].enum_name }}_split_embedding{{ "_no
    return;
  }
  static_assert(
-    std::is_same<output_t, float>::value || std::is_same<output_t, at::Half>::value || std::is_same<output_t, uint8_t>::value,
+    std::is_same<output_t, float>::value || std::is_same<output_t, at::BFloat16>::value || std::is_same<output_t, at::Half>::value || std::is_same<output_t, uint8_t>::value,
    "output_t can only be float or half or bytes now"
  );

@@ -331,7 +331,7 @@ __global__ void {{ type_map[emb_weight_type].enum_name }}_split_embedding{{ "_no
      }
      {% else %}
      const int32_t output_j = indices_starts[i] + L_start + input_row_idx;
-      if (std::is_same<output_t, float>::value || std::is_same<output_t, at::Half>::value) {
+      if (std::is_same<output_t, float>::value || std::is_same<output_t, at::Half>::value || std::is_same<output_t, at::BFloat16>::value) {
        #pragma unroll MaxNum128BRows
        for (uint32_t j = 0; j < MaxNum128BRows; ++j) {
          // Read the uint8/4/2 values: note that first 4 Bytes will be ditched later:
@@ -388,7 +388,7 @@ __global__ void {{ type_map[emb_weight_type].enum_name }}_split_embedding{{ "_no
    const uint32_t b = min(static_cast<uint32_t>(bb * OutputRowsPerThread + i), static_cast<uint32_t>(B - 1));
    const float inv_L = (mean_pooling && Ls[i] != 0) ? static_cast<float>(1.0) / Ls[i] : static_cast<float>(1.0);

-    if (std::is_same<output_t, float>::value || std::is_same<output_t, at::Half>::value) {
+    if (std::is_same<output_t, float>::value || std::is_same<output_t, at::Half>::value || std::is_same<output_t, at::BFloat16>::value) {
      #pragma unroll MaxNum128BRows
      for (uint32_t j = 0; j < MaxNum128BRows; ++j) {
        const int32_t output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding;
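The three hunks above make the same extension at each point where the kernel dispatches on output_t: at::BFloat16 now takes the float-family store path instead of failing the static_assert. Below is a minimal standalone sketch of that dispatch pattern; the stand-in types Half and BFloat16 are hypothetical placeholders for at::Half / at::BFloat16 so the snippet compiles without ATen, and it is not the generated FBGEMM kernel.

// Standalone sketch of the output_t dispatch extended above.
// `Half` and `BFloat16` are stand-ins for at::Half / at::BFloat16.
#include <cstdint>
#include <cstdio>
#include <type_traits>

struct Half     { uint16_t bits; };  // stand-in for at::Half
struct BFloat16 { uint16_t bits; };  // stand-in for at::BFloat16

template <typename output_t>
const char* output_store_path() {
  // Same allow-list the kernel asserts at compile time.
  static_assert(
      std::is_same<output_t, float>::value ||
          std::is_same<output_t, BFloat16>::value ||
          std::is_same<output_t, Half>::value ||
          std::is_same<output_t, uint8_t>::value,
      "output_t can only be float or half or bytes now");
  // Same runtime branch as the kernel: every float-family output shares
  // one path that converts the fp32 accumulator on store.
  if (std::is_same<output_t, float>::value ||
      std::is_same<output_t, Half>::value ||
      std::is_same<output_t, BFloat16>::value) {
    return "float-family store (fp32 accumulator converted per element)";
  }
  // Everything else is the INT8 path: row-wise quantized with qparams.
  return "uint8 store (row-wise quantization, per-row scale/bias)";
}

int main() {
  std::printf("float:    %s\n", output_store_path<float>());
  std::printf("bfloat16: %s\n", output_store_path<BFloat16>());  // newly allowed
  std::printf("uint8:    %s\n", output_store_path<uint8_t>());
}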
@@ -625,7 +625,7 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
  Tensor output;
  const int kINT8QparamsBytes = 8;
  SparseType o_dtype = static_cast<SparseType>(output_dtype);
-  TORCH_CHECK(o_dtype == SparseType::FP32 || o_dtype == SparseType::FP16 || o_dtype == SparseType::INT8);
+  TORCH_CHECK(o_dtype == SparseType::FP32 || o_dtype == SparseType::FP16 || o_dtype == SparseType::BF16 || o_dtype == SparseType::INT8);
  {% if not nobag %}
  int64_t total_adjusted_D = total_D;
  if (o_dtype == SparseType::INT8) {
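On the host side, the TORCH_CHECK now also admits SparseType::BF16. BF16 can reuse the float/half store path because a bfloat16 value is simply the upper 16 bits of an IEEE-754 fp32, so the kernel's fp32 accumulator converts with a single round-and-truncate. A minimal sketch of that conversion, assuming round-to-nearest-even and omitting special NaN handling (ATen's at::BFloat16 performs the equivalent internally):

// Sketch of fp32 -> bf16 conversion; not the FBGEMM/ATen implementation.
#include <cstdint>
#include <cstdio>
#include <cstring>

uint16_t float_to_bf16(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));      // type-pun via memcpy
  bits += 0x7FFFu + ((bits >> 16) & 1u);     // round to nearest even
  return static_cast<uint16_t>(bits >> 16);  // keep sign + exponent + 7 mantissa bits
}

int main() {
  std::printf("bf16(1.0f)  = 0x%04X\n", float_to_bf16(1.0f));   // 0x3F80
  std::printf("bf16(-2.5f) = 0x%04X\n", float_to_bf16(-2.5f));  // 0xC020
}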