@@ -715,6 +715,7 @@ DEVICE_INLINE void quantize_int4_kv(fx4 dst, uint8_t* dst_row_q) {
       scaling_factor, \
       lo_freq_factor, \
       hi_freq_factor, \
+      write_k_back, \
       k_rms_norm) \
   rope_xpos_qkv_varseq_prefill_kernel_<EMB_MODE, DTYPE, NUM_GROUPS> \
       <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>( \
@@ -742,6 +743,7 @@ DEVICE_INLINE void quantize_int4_kv(fx4 dst, uint8_t* dst_row_q) {
       scaling_factor, \
       lo_freq_factor, \
       hi_freq_factor, \
+      write_k_back, \
       k_rms_norm);

 #if (defined(USE_ROCM) && ROCM_VERSION >= 60200) || \
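The two hunks above thread the new write_k_back flag through the kernel-launch macro, which means touching both the macro's parameter list and the <<<...>>> launch it expands to. A hypothetical reduction of that pattern (the real FBGEMM macro takes many more arguments; names here are invented):

    __global__ void prefill_kernel_sketch(bool write_k_back, bool k_rms_norm) {
      // kernel body elided
    }

    // Every new kernel parameter must be added to both argument lists,
    // keeping the macro and the launch it expands to in sync.
    #define LAUNCH_PREFILL_SKETCH(blocks, threads, write_k_back, k_rms_norm) \
      prefill_kernel_sketch<<<(blocks), (threads)>>>((write_k_back), (k_rms_norm))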
@@ -805,6 +807,7 @@ __global__ void rope_xpos_qkv_varseq_prefill_kernel_(
     double scaling_factor = 16,
     double lo_freq_factor = 1,
     double hi_freq_factor = 32,
+    bool write_k_back = false,
     bool k_rms_norm = false) {
   // Launch b_t_(sum(h)) warps.
   auto b_t_hh = blockIdx.x * blockDim.y +
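Because write_k_back is inserted before k_rms_norm rather than appended, every positional call site below now has to pass an explicit value for it; that is why the remaining hunks add a write_k_back or false argument ahead of the final one. A tiny standalone illustration of the hazard, with a hypothetical function f:

    #include <cstdio>

    // Inserting a defaulted parameter ahead of an existing one shifts the
    // positional meaning of later arguments at every call site.
    void f(int x, bool write_k_back = false, bool k_rms_norm = false) {
      std::printf("x=%d write_k_back=%d k_rms_norm=%d\n",
                  x, write_k_back, k_rms_norm);
    }

    int main() {
      f(1, false, true); // was f(1, true) before write_k_back existed
      return 0;
    }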
@@ -905,6 +908,13 @@ __global__ void rope_xpos_qkv_varseq_prefill_kernel_(
       *reinterpret_cast<uint2*>(&dst_row[4 * threadIdx.x]) =
           *reinterpret_cast<uint2*>(&dst_);
     } else {
+      if (write_k_back && qkv == QKV::K) {
+        // Also write back to the source row
+        bfx4 dst_ = fx4_to_bfx4(dst);
+        *reinterpret_cast<uint2*>(&src_row[4 * threadIdx.x]) =
+            *reinterpret_cast<uint2*>(&dst_);
+      }
+      // quantize and write to dst_row
       auto D_H = XQ.size(2);
       auto D_H_q = cache_K.size(3);
       if (kCacheDtype == CacheLogicalDtype::FP8) {
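The new write_k_back branch reuses the kernel's vectorized-store idiom: the rotated K row is converted to bf16 and written back over the input row with a single 8-byte store per thread. A self-contained sketch of that idiom using cuda_bf16.h types (bfx4 and fx4_to_bfx4 in the real kernel are FBGEMM-internal helpers; this is an assumed equivalent, not their actual definition):

    #include <cuda_bf16.h>

    struct bfx4_sketch {
      __nv_bfloat162 vals[2]; // 4 bf16 values = 8 bytes, same size as uint2
    };

    __device__ inline bfx4_sketch fx4_to_bfx4_sketch(float4 v) {
      bfx4_sketch out;
      out.vals[0] = __floats2bfloat162_rn(v.x, v.y);
      out.vals[1] = __floats2bfloat162_rn(v.z, v.w);
      return out;
    }

    __global__ void write_back_sketch(__nv_bfloat16* src_row, float4 dst) {
      // Each thread owns 4 contiguous bf16 elements, so one uint2 store
      // covers its slice of the row.
      bfx4_sketch dst_ = fx4_to_bfx4_sketch(dst);
      *reinterpret_cast<uint2*>(&src_row[4 * threadIdx.x]) =
          *reinterpret_cast<uint2*>(&dst_);
    }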
@@ -914,13 +924,9 @@ __global__ void rope_xpos_qkv_varseq_prefill_kernel_(
       } else {
         __half2* qparam_row = nullptr;
         auto T = cache_K.size(1);
+        size_t idx = 0;
         if (block_tables == nullptr) {
-          auto idx = b * (T * N_KVH) + (size_t)cache_loc_t * N_KVH + h;
-          if (qkv == QKV::K) {
-            qparam_row = reinterpret_cast<__half2*>(&qparam_k_ptr[idx]);
-          } else {
-            qparam_row = reinterpret_cast<__half2*>(&qparam_v_ptr[idx]);
-          }
+          idx = b * (T * N_KVH) + (size_t)cache_loc_t * N_KVH + h;
         } else {
           // This is duplicate computation with get_dst_row above.
           // TODO: Maybe clean up and merge later.
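In the non-paged branch, the qparam index flattens what reads as a [B][MAX_T][N_KVH] layout (an assumption from the arithmetic, not a documented contract); a trivial host-side check of that flattening:

    #include <cassert>
    #include <cstddef>

    int main() {
      // Hypothetical sizes; the kernel reads T from cache_K.size(1).
      const size_t T = 4096, N_KVH = 8;
      const size_t b = 2, cache_loc_t = 17, h = 3;
      const size_t idx = b * (T * N_KVH) + cache_loc_t * N_KVH + h;
      // Same index, written as row-major [B][T][N_KVH] flattening.
      assert(idx == (b * T + cache_loc_t) * N_KVH + h);
      return 0;
    }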
@@ -929,12 +935,12 @@ __global__ void rope_xpos_qkv_varseq_prefill_kernel_(
           int page_physical_idx =
               block_tables[b * block_tables_b_stride + page_logical_idx];
           int physical_t = page_physical_idx * page_size + page_offset;
-          auto idx = physical_t * N_KVH + h;
-          if (qkv == QKV::K) {
-            qparam_row = reinterpret_cast<__half2*>(&qparam_k_ptr[idx]);
-          } else {
-            qparam_row = reinterpret_cast<__half2*>(&qparam_v_ptr[idx]);
-          }
+          idx = physical_t * N_KVH + h;
+        }
+        if (qkv == QKV::K) {
+          qparam_row = reinterpret_cast<__half2*>(&qparam_k_ptr[idx]);
+        } else {
+          qparam_row = reinterpret_cast<__half2*>(&qparam_v_ptr[idx]);
         }
         quantize_fp8_kv(
             dst, dst_row_q, qparam_row, (qkv == QKV::K && k_rms_norm));
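The last two hunks are a pure hoisting refactor: each branch now computes only idx, and the K/V pointer selection appears once instead of being duplicated in both branches. A hypothetical standalone form of the resulting shape (not the actual FBGEMM helper):

    #include <cuda_fp16.h>
    #include <cstddef>
    #include <cstdint>

    // Branches yield only an index; the K/V selection happens once after.
    __device__ __half2* select_qparam_row_sketch(
        bool paged,
        size_t contiguous_idx,
        size_t paged_idx,
        bool is_k,
        int32_t* qparam_k_ptr,
        int32_t* qparam_v_ptr) {
      const size_t idx = paged ? paged_idx : contiguous_idx;
      return reinterpret_cast<__half2*>(
          is_k ? &qparam_k_ptr[idx] : &qparam_v_ptr[idx]);
    }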
@@ -1160,7 +1166,6 @@ at::Tensor rope_qkv_varseq_prefill(
       varseq_seqpos.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>();
   int32_t* qparam_k_ptr = nullptr;
   int32_t* qparam_v_ptr = nullptr;
-  TORCH_CHECK(!write_k_back);
   if (qparam_k.has_value()) {
     qparam_k_ptr = static_cast<int32_t*>(qparam_k.value().data_ptr());
     qparam_v_ptr = static_cast<int32_t*>(qparam_v.value().data_ptr());
@@ -1190,6 +1195,7 @@ at::Tensor rope_qkv_varseq_prefill(
         scaling_factor,
         lo_freq_factor,
         hi_freq_factor,
+        write_k_back,
         k_rms_norm);
     C10_CUDA_KERNEL_LAUNCH_CHECK();
 #else
@@ -1218,6 +1224,7 @@ at::Tensor rope_qkv_varseq_prefill(
         scaling_factor,
         lo_freq_factor,
         hi_freq_factor,
+        write_k_back,
         false);

     C10_CUDA_KERNEL_LAUNCH_CHECK();
@@ -1341,6 +1348,7 @@ at::Tensor xpos_qkv_varseq_prefill(
         scaling_factor,
         lo_freq_factor,
         hi_freq_factor,
+        false,
         false);

     C10_CUDA_KERNEL_LAUNCH_CHECK();
@@ -1370,6 +1378,7 @@ at::Tensor xpos_qkv_varseq_prefill(
         scaling_factor,
         lo_freq_factor,
         hi_freq_factor,
+        false,
         false);

     C10_CUDA_KERNEL_LAUNCH_CHECK();
@@ -1490,6 +1499,7 @@ at::Tensor rope_qkv_decoding(
         scaling_factor,
         lo_freq_factor,
         hi_freq_factor,
+        false,
         k_rms_norm);

     C10_CUDA_KERNEL_LAUNCH_CHECK();
@@ -1519,6 +1529,7 @@ at::Tensor rope_qkv_decoding(
         scaling_factor,
         lo_freq_factor,
         hi_freq_factor,
+        false,
         false);

     C10_CUDA_KERNEL_LAUNCH_CHECK();
@@ -1642,6 +1653,7 @@ at::Tensor xpos_qkv_decoding(
         scaling_factor,
         lo_freq_factor,
         hi_freq_factor,
+        false,
         false);
     C10_CUDA_KERNEL_LAUNCH_CHECK();
 #else
@@ -1670,6 +1682,7 @@ at::Tensor xpos_qkv_decoding(
         scaling_factor,
         lo_freq_factor,
         hi_freq_factor,
+        false,
         false);
     C10_CUDA_KERNEL_LAUNCH_CHECK();
 }
@@ -2005,6 +2018,7 @@ DEVICE_INLINE void quantize_fp8_kv(

   *reinterpret_cast<uint32_t*>(
       &dst_row_q[4 * threadIdx.x + fp8_qparam_offset]) = packed;
+  // write qparams
   if (threadIdx.x == 0) {
     __half2* param_store = qparam;
     if (param_store == nullptr) {
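The new // write qparams comment labels the single-lane store of the row's quantization parameters: the pair is shared by the whole row, so only thread 0 writes it. A hedged sketch of that store, assuming the two qparams pack into one __half2 as the surrounding code suggests:

    #include <cuda_fp16.h>

    __device__ void store_qparams_sketch(
        __half2* param_store, float scale, float shift) {
      // One lane per row writes the packed (scale, shift) pair as a single
      // 4-byte store; the remaining lanes skip it.
      if (threadIdx.x == 0) {
        *param_store = __floats2half2_rn(scale, shift);
      }
    }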