Commit ebc6902

Aya-ZIbra authored and facebook-github-bot committed
write_k_back for fp8 ROPE (pytorch#756)
Summary:
X-link: pytorch#3679
Pull Request resolved: facebookresearch/FBGEMM#756

Needed for tree attention. Make RoPE write the rotated xk for the suffix back to the input xk before applying quantization.

Reviewed By: jianyuh

Differential Revision: D69182701

fbshipit-source-id: ba279706f4807d66fc1738796ea9e433705b285d
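For context before reading the diff: the core of the change is a conditional write-back of the post-RoPE K values into the input xk row before the quantized (FP8/INT4) cache path runs, so that downstream consumers such as tree attention see rotated keys. The sketch below is a minimal, self-contained illustration of that pattern, not FBGEMM's actual API; the helper name write_k_back_then_quantize, the is_k_head flag, and the 4-bf16-elements-per-thread layout are illustrative stand-ins for the kernel's qkv == QKV::K check and its uint2 vector stores.

// Minimal sketch of the write-back-then-quantize pattern (illustrative only).
#include <cuda_bf16.h>

__device__ inline void write_k_back_then_quantize(
    const float (&rotated)[4],  // post-RoPE K values owned by this thread
    __nv_bfloat16* src_row,     // input xk row; overwritten only if write_k_back
    uint8_t* dst_row_q,         // quantized KV-cache row (quantization elided)
    bool write_k_back,
    bool is_k_head) {           // stand-in for qkv == QKV::K
  if (write_k_back && is_k_head) {
    // Persist the rotated K back into the input tensor so later consumers
    // (e.g. tree attention over the suffix) see post-RoPE values.
    __align__(8) __nv_bfloat16 out[4];
#pragma unroll
    for (int i = 0; i < 4; ++i) {
      out[i] = __float2bfloat16(rotated[i]);
    }
    *reinterpret_cast<uint2*>(&src_row[4 * threadIdx.x]) =
        *reinterpret_cast<uint2*>(out);
  }
  // ...then quantize `rotated` into dst_row_q as before; the write-back does
  // not change what lands in the cache.
  (void)dst_row_q;
}

In the actual kernel this write-back sits at the top of the quantized-cache branch, just before quantize_fp8_kv is applied, which is why the host-side TORCH_CHECK(!write_k_back) guard could be removed.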
1 parent cb70110 commit ebc6902

File tree

1 file changed (+27, -13 lines)


fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu

Lines changed: 27 additions & 13 deletions
@@ -715,6 +715,7 @@ DEVICE_INLINE void quantize_int4_kv(fx4 dst, uint8_t* dst_row_q) {
       scaling_factor, \
       lo_freq_factor, \
       hi_freq_factor, \
+      write_k_back, \
       k_rms_norm) \
   rope_xpos_qkv_varseq_prefill_kernel_<EMB_MODE, DTYPE, NUM_GROUPS> \
       <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>( \
@@ -742,6 +743,7 @@ DEVICE_INLINE void quantize_int4_kv(fx4 dst, uint8_t* dst_row_q) {
       scaling_factor, \
       lo_freq_factor, \
       hi_freq_factor, \
+      write_k_back, \
       k_rms_norm);

 #if (defined(USE_ROCM) && ROCM_VERSION >= 60200) || \
@@ -805,6 +807,7 @@ __global__ void rope_xpos_qkv_varseq_prefill_kernel_(
     double scaling_factor = 16,
     double lo_freq_factor = 1,
     double hi_freq_factor = 32,
+    bool write_k_back = false,
     bool k_rms_norm = false) {
   // Launch b_t_(sum(h)) warps.
   auto b_t_hh = blockIdx.x * blockDim.y +
@@ -905,6 +908,13 @@ __global__ void rope_xpos_qkv_varseq_prefill_kernel_(
     *reinterpret_cast<uint2*>(&dst_row[4 * threadIdx.x]) =
         *reinterpret_cast<uint2*>(&dst_);
   } else {
+    if (write_k_back && qkv == QKV::K) {
+      // Also write back to the source row
+      bfx4 dst_ = fx4_to_bfx4(dst);
+      *reinterpret_cast<uint2*>(&src_row[4 * threadIdx.x]) =
+          *reinterpret_cast<uint2*>(&dst_);
+    }
+    // quantize and write to dst_row
     auto D_H = XQ.size(2);
     auto D_H_q = cache_K.size(3);
     if (kCacheDtype == CacheLogicalDtype::FP8) {
@@ -914,13 +924,9 @@ __global__ void rope_xpos_qkv_varseq_prefill_kernel_(
       } else {
         __half2* qparam_row = nullptr;
         auto T = cache_K.size(1);
+        size_t idx = 0;
         if (block_tables == nullptr) {
-          auto idx = b * (T * N_KVH) + (size_t)cache_loc_t * N_KVH + h;
-          if (qkv == QKV::K) {
-            qparam_row = reinterpret_cast<__half2*>(&qparam_k_ptr[idx]);
-          } else {
-            qparam_row = reinterpret_cast<__half2*>(&qparam_v_ptr[idx]);
-          }
+          idx = b * (T * N_KVH) + (size_t)cache_loc_t * N_KVH + h;
         } else {
           // This is duplicate computation with get_dst_row above.
           // TODO: Maybe clean up and merge later.
@@ -929,12 +935,12 @@ __global__ void rope_xpos_qkv_varseq_prefill_kernel_(
           int page_physical_idx =
               block_tables[b * block_tables_b_stride + page_logical_idx];
           int physical_t = page_physical_idx * page_size + page_offset;
-          auto idx = physical_t * N_KVH + h;
-          if (qkv == QKV::K) {
-            qparam_row = reinterpret_cast<__half2*>(&qparam_k_ptr[idx]);
-          } else {
-            qparam_row = reinterpret_cast<__half2*>(&qparam_v_ptr[idx]);
-          }
+          idx = physical_t * N_KVH + h;
+        }
+        if (qkv == QKV::K) {
+          qparam_row = reinterpret_cast<__half2*>(&qparam_k_ptr[idx]);
+        } else {
+          qparam_row = reinterpret_cast<__half2*>(&qparam_v_ptr[idx]);
         }
         quantize_fp8_kv(
             dst, dst_row_q, qparam_row, (qkv == QKV::K && k_rms_norm));
@@ -1160,7 +1166,6 @@ at::Tensor rope_qkv_varseq_prefill(
       varseq_seqpos.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>();
   int32_t* qparam_k_ptr = nullptr;
   int32_t* qparam_v_ptr = nullptr;
-  TORCH_CHECK(!write_k_back);
   if (qparam_k.has_value()) {
     qparam_k_ptr = static_cast<int32_t*>(qparam_k.value().data_ptr());
     qparam_v_ptr = static_cast<int32_t*>(qparam_v.value().data_ptr());
@@ -1190,6 +1195,7 @@ at::Tensor rope_qkv_varseq_prefill(
         scaling_factor,
         lo_freq_factor,
         hi_freq_factor,
+        write_k_back,
         k_rms_norm);
     C10_CUDA_KERNEL_LAUNCH_CHECK();
 #else
@@ -1218,6 +1224,7 @@ at::Tensor rope_qkv_varseq_prefill(
         scaling_factor,
         lo_freq_factor,
         hi_freq_factor,
+        write_k_back,
         false);

     C10_CUDA_KERNEL_LAUNCH_CHECK();
@@ -1341,6 +1348,7 @@ at::Tensor xpos_qkv_varseq_prefill(
         scaling_factor,
         lo_freq_factor,
         hi_freq_factor,
+        false,
         false);

     C10_CUDA_KERNEL_LAUNCH_CHECK();
@@ -1370,6 +1378,7 @@ at::Tensor xpos_qkv_varseq_prefill(
         scaling_factor,
         lo_freq_factor,
         hi_freq_factor,
+        false,
         false);

     C10_CUDA_KERNEL_LAUNCH_CHECK();
@@ -1490,6 +1499,7 @@ at::Tensor rope_qkv_decoding(
         scaling_factor,
         lo_freq_factor,
         hi_freq_factor,
+        false,
         k_rms_norm);

     C10_CUDA_KERNEL_LAUNCH_CHECK();
@@ -1519,6 +1529,7 @@ at::Tensor rope_qkv_decoding(
         scaling_factor,
         lo_freq_factor,
         hi_freq_factor,
+        false,
         false);

     C10_CUDA_KERNEL_LAUNCH_CHECK();
@@ -1642,6 +1653,7 @@ at::Tensor xpos_qkv_decoding(
         scaling_factor,
         lo_freq_factor,
         hi_freq_factor,
+        false,
         false);
     C10_CUDA_KERNEL_LAUNCH_CHECK();
 #else
@@ -1670,6 +1682,7 @@ at::Tensor xpos_qkv_decoding(
         scaling_factor,
         lo_freq_factor,
         hi_freq_factor,
+        false,
         false);
     C10_CUDA_KERNEL_LAUNCH_CHECK();
   }
@@ -2005,6 +2018,7 @@ DEVICE_INLINE void quantize_fp8_kv(

   *reinterpret_cast<uint32_t*>(
       &dst_row_q[4 * threadIdx.x + fp8_qparam_offset]) = packed;
+  // write qparams
   if (threadIdx.x == 0) {
     __half2* param_store = qparam;
     if (param_store == nullptr) {

0 commit comments
