File tree Expand file tree Collapse file tree 1 file changed +20
-4
lines changed
fbgemm_gpu/experimental/gen_ai/src/kv_cache Expand file tree Collapse file tree 1 file changed +20
-4
lines changed Original file line number Diff line number Diff line change @@ -795,11 +795,27 @@ __global__ void rope_xpos_qkv_varseq_prefill_kernel_(
795
795
} else {
796
796
__half2* qparam_row = nullptr ;
797
797
auto T = cache_K.size (1 );
798
- auto idx = b * (T * N_KVH) + (size_t )cache_loc_t * N_KVH + h;
799
- if (qkv == QKV::K) {
800
- qparam_row = reinterpret_cast <__half2*>(&qparam_k_ptr[idx]);
798
+ if (block_tables == nullptr ) {
799
+ auto idx = b * (T * N_KVH) + (size_t )cache_loc_t * N_KVH + h;
800
+ if (qkv == QKV::K) {
801
+ qparam_row = reinterpret_cast <__half2*>(&qparam_k_ptr[idx]);
802
+ } else {
803
+ qparam_row = reinterpret_cast <__half2*>(&qparam_v_ptr[idx]);
804
+ }
801
805
} else {
802
- qparam_row = reinterpret_cast <__half2*>(&qparam_v_ptr[idx]);
806
+ // This duplicates the computation in get_dst_row above.
807
+ // TODO: Clean up and merge the two computations later.
808
+ int page_logical_idx = cache_loc_t / page_size;
809
+ int page_offset = cache_loc_t % page_size;
810
+ int page_physical_idx =
811
+ block_tables[b * block_tables_b_stride + page_logical_idx];
812
+ int physical_t = page_physical_idx * page_size + page_offset;
813
+ auto idx = physical_t * N_KVH + h;
814
+ if (qkv == QKV::K) {
815
+ qparam_row = reinterpret_cast <__half2*>(&qparam_k_ptr[idx]);
816
+ } else {
817
+ qparam_row = reinterpret_cast <__half2*>(&qparam_v_ptr[idx]);
818
+ }
803
819
}
804
820
quantize_fp8_kv (dst, dst_row_q, qparam_row);
805
821
}
You can’t perform that action at this time.
0 commit comments