@@ -167,9 +167,10 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
167
167
168
168
Tensor output;
169
169
SparseType o_dtype = static_cast <SparseType>(output_dtype);
170
- TORCH_CHECK (o_dtype == SparseType::FP32 || o_dtype == SparseType::FP16 || o_dtype == SparseType::INT8 || o_dtype == SparseType::BF16);
170
+ TORCH_CHECK (o_dtype == SparseType::FP32 || o_dtype == SparseType::FP16 || o_dtype == SparseType::INT8 || o_dtype == SparseType::BF16 || o_dtype == SparseType::INT4 );
171
171
bool output_is_bf16 = o_dtype == SparseType::BF16;
172
172
bool output_is_int8 = o_dtype == SparseType::INT8;
173
+ bool output_is_int4 = o_dtype == SparseType::INT4;
173
174
{% if not nobag %}
174
175
const int kINT8QparamsBytes = 8 ;
175
176
int64_t total_adjusted_D = total_D;
@@ -178,10 +179,13 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
178
179
}
179
180
output = at::empty ({B, total_adjusted_D}, dev_weights.options ().dtype (getScalarType (o_dtype)).pinned_memory (pinned_memory));
180
181
{% else %}
181
- const int kINT8QparamsBytes = 4 ; // no bag int8 output aligns with fbgemm weights storage size and layout
182
+ constexpr int kINT8QparamsBytes = 4 ; // no bag int8 output aligns with fbgemm weights storage size and layout
183
+ constexpr int kINT4QparamsElems = 8 ; // scale + bias takes 4 bytes which are 8 int4 elements
182
184
int64_t adjusted_D = D;
183
185
if (o_dtype == SparseType::INT8) {
184
186
adjusted_D += kINT8QparamsBytes ;
187
+ } else if (o_dtype == SparseType::INT4) {
188
+ adjusted_D += kINT4QparamsElems ;
185
189
}
186
190
output = at::empty ({total_L, adjusted_D}, dev_weights.options ().dtype (getScalarType (o_dtype)).pinned_memory (pinned_memory));
187
191
@@ -212,7 +216,7 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
212
216
using other_fbgemm_out_t = typename std::conditional<
213
217
std::is_same<output_t , at::Half>::value,
214
218
float16,
215
- std::conditional<std::is_same<output_t , at::BFloat16>::value, bfloat16, float >::type > ::type;
219
+ std::conditional<std::is_same<output_t , at::BFloat16>::value, bfloat16, float >::type> ::type;
216
220
AT_DISPATCH_INDEX_TYPES (indices.scalar_type (), " int_nbit_split_embedding{{ " _nobag" if nobag else " " }}_codegen_forward_" , [&] {
217
221
const auto * indices_acc = indices.data_ptr <index_t >();
218
222
const auto * offsets_acc = offsets.data_ptr <index_t >();
@@ -230,7 +234,8 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
230
234
const int32_t D_end = D_offsets_acc[t + 1 ];
231
235
const int32_t D = D_end - D_start;
232
236
{% else %}
233
- const int32_t D_start = offsets_acc[t * B] * adjusted_D;
237
+ const int32_t elems_D = (o_dtype == SparseType::INT4) ? at::divup (adjusted_D, 2 ) : adjusted_D;
238
+ const int32_t D_start = offsets_acc[t * B] * elems_D;
234
239
{% endif %}
235
240
236
241
const auto placement = static_cast <PlacementType>(weights_placements_ptr[t]);
@@ -266,8 +271,8 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
266
271
{% endif %}
267
272
268
273
const float * indice_weights_ptr = nullptr ;
269
- // int8 output only enabled for nobag case with ref impl
270
- const bool nobag_op = {{ " false" if not nobag else " output_is_int8" }};
274
+ // int8/int4 output only enabled for nobag case
275
+ const bool nobag_op = {{ " false" if not nobag else " output_is_int8 || output_is_int4 " }};
271
276
{% if weighted %}
272
277
indice_weights_ptr = indice_weights_acc + *offsets_begin_ptr;
273
278
{% endif %}
@@ -278,7 +283,10 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
278
283
if use_base else (" GenerateEmbeddingSpMDMNBitWithStrides"
279
284
if use_nbit else " GenerateEmbeddingSpMDMFP8WithStrides" )
280
285
%}
281
- using fbgemm_out_t = {{ " base_fbgemm_out_t" if use_base else " other_fbgemm_out_t" }};
286
+ using fbgemm_out_t = {{ " base_fbgemm_out_t" if use_base or use_nbit else " other_fbgemm_out_t" }};
287
+ {% if use_nbit %}
288
+ const int output_bit_rate = output_is_int4 ? 4 : sizeof (fbgemm_out_t ) * 8 ;
289
+ {% endif %}
282
290
// TODO: merge nobag int8 path with normal asmjit dispatch
283
291
{% if nobag %}
284
292
const index_t * offset_ptr = (output_is_int8)? offsets_begin_ptr: offsets_nobag_ptr;
@@ -299,7 +307,7 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
299
307
{% endif %}
300
308
>(
301
309
{% if use_nbit %}
302
- /* bit_rate =*/ bit_rate,
310
+ /* input_bit_rate =*/ bit_rate,
303
311
{% endif %}
304
312
D,
305
313
{% if has_asmjit %}
@@ -324,6 +332,10 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
324
332
/* no_bag=*/ nobag_op,
325
333
{% endif %}
326
334
/* is_bf16_out=*/ output_is_bf16
335
+ {% if use_nbit %}
336
+ ,/* no_bag=*/ nobag_op,
337
+ /* output_bit_rate=*/ output_bit_rate
338
+ {% endif %}
327
339
);
328
340
success = kernel (
329
341
{{ " B" if not nobag else " index_size" }},
0 commit comments