
Commit b2e0bca

excelle08 authored and facebook-github-bot committed
Enable int4 to int4 CPU STBE in fbgemm_gpu TBE API (pytorch#2994)
Summary:
Pull Request resolved: pytorch#2994
X-link: facebookresearch/FBGEMM#89

Enable int4 to int4 sequential CPU TBE in the codegen template so that fbgemm_gpu's `IntNBitTableBatchedEmbeddingBagsCodegen` can support it.

Differential Revision: D61305978
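For reference, a minimal sketch of how the newly enabled int4 -> int4 sequential (nobag) CPU path could be exercised from Python. The module paths, constructor arguments, and table shape below are illustrative assumptions about the fbgemm_gpu API, not part of this change, and may differ across versions:

    # Hypothetical usage sketch; module paths and argument names may vary by fbgemm_gpu version.
    import torch
    from fbgemm_gpu.split_embedding_configs import SparseType
    from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
        EmbeddingLocation,
        PoolingMode,
    )
    from fbgemm_gpu.split_table_batched_embeddings_ops_inference import (
        IntNBitTableBatchedEmbeddingBagsCodegen,
    )

    tbe = IntNBitTableBatchedEmbeddingBagsCodegen(
        embedding_specs=[
            # (feature name, rows, embedding dim, weight dtype, location) -- example values
            ("table_0", 1000, 128, SparseType.INT4, EmbeddingLocation.HOST),
        ],
        pooling_mode=PoolingMode.NONE,   # sequential (nobag) lookup
        output_dtype=SparseType.INT4,    # int4 output, enabled by this change
        device="cpu",
    )
    tbe.fill_random_weights()

    indices = torch.tensor([3, 7, 42], dtype=torch.int32)
    offsets = torch.tensor([0, 3], dtype=torch.int32)
    out = tbe(indices, offsets)  # quantized int4 rows, one per looked-up index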
1 parent 6cc4dd5 commit b2e0bca

2 files changed, +22 -8 lines changed


fbgemm_gpu/codegen/inference/embedding_forward_quantized_cpu_template.cpp

Lines changed: 20 additions & 8 deletions
@@ -167,9 +167,10 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
 
 Tensor output;
 SparseType o_dtype = static_cast<SparseType>(output_dtype);
-TORCH_CHECK(o_dtype == SparseType::FP32 || o_dtype == SparseType::FP16 || o_dtype == SparseType::INT8 || o_dtype == SparseType::BF16);
+TORCH_CHECK(o_dtype == SparseType::FP32 || o_dtype == SparseType::FP16 || o_dtype == SparseType::INT8 || o_dtype == SparseType::BF16 || o_dtype == SparseType::INT4);
 bool output_is_bf16 = o_dtype == SparseType::BF16;
 bool output_is_int8 = o_dtype == SparseType::INT8;
+bool output_is_int4 = o_dtype == SparseType::INT4;
 {% if not nobag %}
 const int kINT8QparamsBytes = 8;
 int64_t total_adjusted_D = total_D;
@@ -178,10 +179,13 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
 }
 output = at::empty({B, total_adjusted_D}, dev_weights.options().dtype(getScalarType(o_dtype)).pinned_memory(pinned_memory));
 {% else %}
-const int kINT8QparamsBytes = 4; // no bag int8 output aligns with fbgemm weights storage size and layout
+constexpr int kINT8QparamsBytes = 4; // no bag int8 output aligns with fbgemm weights storage size and layout
+constexpr int kINT4QparamsElems = 8; // scale + bias takes 4 bytes which are 8 int4 elements
 int64_t adjusted_D = D;
 if (o_dtype == SparseType::INT8) {
   adjusted_D += kINT8QparamsBytes;
+} else if (o_dtype == SparseType::INT4) {
+  adjusted_D += kINT4QparamsElems;
 }
 output = at::empty({total_L, adjusted_D}, dev_weights.options().dtype(getScalarType(o_dtype)).pinned_memory(pinned_memory));
 
@@ -212,7 +216,7 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
 using other_fbgemm_out_t = typename std::conditional<
 std::is_same<output_t, at::Half>::value,
 float16,
-std::conditional<std::is_same<output_t, at::BFloat16>::value, bfloat16, float>::type >::type;
+std::conditional<std::is_same<output_t, at::BFloat16>::value, bfloat16, float>::type> ::type;
 AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_", [&] {
 const auto* indices_acc = indices.data_ptr<index_t>();
 const auto* offsets_acc = offsets.data_ptr<index_t>();
@@ -230,7 +234,8 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
 const int32_t D_end = D_offsets_acc[t + 1];
 const int32_t D = D_end - D_start;
 {% else %}
-const int32_t D_start = offsets_acc[t * B] * adjusted_D;
+const int32_t elems_D = (o_dtype == SparseType::INT4) ? at::divup(adjusted_D, 2) : adjusted_D;
+const int32_t D_start = offsets_acc[t * B] * elems_D;
 {% endif %}
 
 const auto placement = static_cast<PlacementType>(weights_placements_ptr[t]);
@@ -266,8 +271,8 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
 {% endif %}
 
 const float* indice_weights_ptr = nullptr;
-// int8 output only enabled for nobag case with ref impl
-const bool nobag_op = {{ "false" if not nobag else "output_is_int8" }};
+// int8/int4 output only enabled for nobag case
+const bool nobag_op = {{ "false" if not nobag else "output_is_int8 || output_is_int4" }};
 {% if weighted %}
 indice_weights_ptr = indice_weights_acc + *offsets_begin_ptr;
 {% endif %}
@@ -278,7 +283,10 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
 if use_base else ("GenerateEmbeddingSpMDMNBitWithStrides"
 if use_nbit else "GenerateEmbeddingSpMDMFP8WithStrides")
 %}
-using fbgemm_out_t = {{ "base_fbgemm_out_t" if use_base else "other_fbgemm_out_t" }};
+using fbgemm_out_t = {{ "base_fbgemm_out_t" if use_base or use_nbit else "other_fbgemm_out_t" }};
+{% if use_nbit %}
+const int output_bit_rate = output_is_int4 ? 4 : sizeof(fbgemm_out_t) * 8;
+{% endif %}
 // TODO: merge nobag int8 path with normal asmjit dispatch
 {% if nobag %}
 const index_t* offset_ptr = (output_is_int8)? offsets_begin_ptr: offsets_nobag_ptr;
@@ -299,7 +307,7 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
 {% endif %}
 >(
 {% if use_nbit %}
-/*bit_rate=*/bit_rate,
+/*input_bit_rate=*/bit_rate,
 {% endif %}
 D,
 {% if has_asmjit %}
@@ -324,6 +332,10 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
 /*no_bag=*/nobag_op,
 {% endif %}
 /*is_bf16_out=*/output_is_bf16
+{% if use_nbit %}
+,/*no_bag=*/nobag_op,
+/*output_bit_rate=*/output_bit_rate
+{% endif %}
 );
 success = kernel(
 {{ "B" if not nobag else "index_size"}},

fbgemm_gpu/include/fbgemm_gpu/utils/dispatch_macros.h

Lines changed: 2 additions & 0 deletions
@@ -122,6 +122,8 @@
 at::ScalarType::BFloat16, at::BFloat16, __VA_ARGS__) \
 PRIVATE_CASE_TYPE_OUTPUT2(at::ScalarType::Float, float, __VA_ARGS__) \
 PRIVATE_CASE_TYPE_OUTPUT2(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \
+PRIVATE_CASE_TYPE_OUTPUT2( \
+at::ScalarType::QUInt4x2, uint8_t, __VA_ARGS__) \
 default: \
 AT_ERROR( \
 #NAME, \
