
Commit 80990a6

lequytra authored and facebook-github-bot committed
Add BF16 support for reorder_batched_ad_indices (pytorch#2116)
Summary: Pull Request resolved: pytorch#2116

We use `reorder_batched_ad_indices` to [rebatch id_score_list weights](https://www.internalfb.com/code/fbsource/[e3bbe1eaf65e]/fbcode/caffe2/caffe2/fb/predictor/rebatch/GPURebatchUtils.cpp?lines=305), which are quantized to BFloat16. However, BFloat16 is not currently supported in `reorder_batched_ad_indices` (see error trace: P868895010). This diff adds support for the BFloat16 dtype.

Reviewed By: YazhiGao

Differential Revision: D50817983

fbshipit-source-id: 4949acac8d1524dc10c7931e28bdfcabd2e94477
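For intuition, here is a minimal pure-Python sketch of the kind of regrouping this operator performs. It assumes a simplified dense layout in which indices for B batches, T tables, and A ads are concatenated batch-major and the rebatch regroups them table-major; the real operator works on ragged tensors via offset arrays, and the function name here is hypothetical.

```python
def reorder_batched(values, B, T, A):
    # Input layout: [B][T][A] (batch-major); output layout: [T][B][A]
    # (table-major), i.e. all batches' entries for a table are contiguous.
    out = []
    for t in range(T):
        for b in range(B):
            start = (b * T + t) * A
            out.extend(values[start:start + A])
    return out
```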
1 parent 174d473 · commit 80990a6

File tree

2 files changed: +3 −2 lines changed

fbgemm_gpu/src/sparse_ops/sparse_reorder_batched_ad.cu

Lines changed: 2 additions & 1 deletion

```diff
@@ -243,7 +243,8 @@ DLL_PUBLIC Tensor reorder_batched_ad_indices_gpu(
   const dim3 threads(32, 32);
   const dim3 blocks((B * T + 32 - 1) / 32);

-  AT_DISPATCH_ALL_TYPES(
+  AT_DISPATCH_ALL_TYPES_AND(
+      at::ScalarType::BFloat16,
       cat_ad_indices.scalar_type(),
       "reorder_batched_ad_indices_gpu_kernel_1",
       [&] {
```
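The fix swaps `AT_DISPATCH_ALL_TYPES` for `AT_DISPATCH_ALL_TYPES_AND`, whose extra first argument appends one more dtype to the dispatched set. A rough pure-Python analog of that dispatch pattern (the names and dtype-string tags here are illustrative, not ATen's):

```python
def make_kernel(name):
    # Stand-in for a type-specialized kernel instantiation.
    def kernel(x):
        return (name, x)
    return kernel

# "ALL_TYPES": the baseline set of dispatchable dtypes.
ALL_TYPES = {d: make_kernel(d) for d in ("int32", "int64", "float32", "float64")}

# "ALL_TYPES_AND(BFloat16, ...)": the same set plus one extra dtype.
ALL_TYPES_AND_BF16 = {**ALL_TYPES, "bfloat16": make_kernel("bfloat16")}

def dispatch(dtype, kernels, *args):
    # A runtime dtype tag selects the specialized kernel, or errors out,
    # mirroring the unsupported-dtype failure described in the summary.
    try:
        kernel = kernels[dtype]
    except KeyError:
        raise TypeError(f"unsupported dtype: {dtype}")
    return kernel(*args)
```

With the baseline set, dispatching a `"bfloat16"` input raises `TypeError`; with the extended set it runs the bf16-specialized kernel.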

fbgemm_gpu/test/sparse_ops_test.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -1114,7 +1114,7 @@ def test_reorder_batched_ad_lengths_cpu(
     T=st.integers(min_value=1, max_value=20),
     L=st.integers(min_value=2, max_value=20),
     A=st.integers(min_value=1, max_value=20),
-    Dtype=st.sampled_from([torch.int32, torch.float, torch.int64]),
+    Dtype=st.sampled_from([torch.int32, torch.float, torch.int64, torch.bfloat16]),
     Itype=st.sampled_from([torch.int32, torch.int64]),
     broadcast_indices=st.booleans(),
 )
```
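The new `torch.bfloat16` case matters because bfloat16 is effectively a float32 with the low 16 mantissa bits dropped. A stdlib-only sketch of that truncation (real quantization typically rounds to nearest rather than truncating; this helper is illustrative and not part of the test suite):

```python
import struct

def to_bfloat16(x: float) -> float:
    # Reinterpret the float32 bits, keep sign + 8-bit exponent + top 7
    # mantissa bits, and zero the low 16 bits: bfloat16 by truncation.
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    return struct.unpack(">f", struct.pack(">I", bits & 0xFFFF0000))[0]
```

For example, `to_bfloat16(3.141592653589793)` yields `3.140625`: float32 pi is `0x40490FDB`, and truncating the low 16 bits leaves `0x40490000`.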
