Skip to content

Commit 0c6d683

Browse files
Aya-ZIbra and facebook-github-bot
authored and committed
move common op to vector utils (pytorch#3759)
Summary: Pull Request resolved: pytorch#3759. X-link: facebookresearch/FBGEMM#840. As the title says: the common op is moved to the vector utils. The op is used in the next diff. Reviewed By: sgrigory. Differential Revision: D70387235. fbshipit-source-id: c5acebd1e371e67822d51f3b5cb6f194c99ef1eb
1 parent 5158b8d commit 0c6d683

File tree

2 files changed

+11
-7
lines changed

2 files changed

+11
-7
lines changed

fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -546,13 +546,7 @@ DEVICE_INLINE fx4 rope_xpos(
546546
double hi_freq_factor = 32) {
547547
fx4 dst; // read 4 bf16 from src and store in 4 float registers
548548
if (head == QKV::V) {
549-
auto r0 = bf1622float2(src.vals[0]);
550-
auto r1 = bf1622float2(src.vals[1]);
551-
dst.x = r0.x;
552-
dst.y = r0.y;
553-
dst.z = r1.x;
554-
dst.w = r1.y;
555-
return dst;
549+
return bfx4_to_fx4(src);
556550
}
557551
int32_t offset_0 = ((4 * threadIdx.x) / 2 + 0);
558552
int32_t offset_1 = ((4 * threadIdx.x) / 2 + 1);

fbgemm_gpu/include/fbgemm_gpu/utils/vec_quant.cuh

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,16 @@ DEVICE_INLINE fx4 bfx4_scale_acc(fx4 acc, bfx4 a, float b) {
173173
return acc;
174174
}
175175

176+
// Converts a bfx4 to an fx4: reads the 4 bf16 values held in src and returns
// them as 4 floats. src.vals[0] and src.vals[1] are each expanded with
// bf1622float2, and the resulting .x/.y pairs fill dst.x, dst.y, dst.z, dst.w
// in that order. This is the same sequence that rope_xpos previously inlined
// for its QKV::V case.
DEVICE_INLINE fx4 bfx4_to_fx4(bfx4 src) {
177+
fx4 dst;
178+
auto r0 = bf1622float2(src.vals[0]); // first pair  -> dst.x, dst.y
179+
auto r1 = bf1622float2(src.vals[1]); // second pair -> dst.z, dst.w
180+
dst.x = r0.x;
181+
dst.y = r0.y;
182+
dst.z = r1.x;
183+
dst.w = r1.y;
184+
return dst;
185+
}
176186
DEVICE_INLINE fx4 fx4_acc(fx4 a, fx4 b) {
177187
a.x += b.x;
178188
a.y += b.y;

0 commit comments

Comments
 (0)