Skip to content

Commit 8bfa07c

Browse files
SungMinCho authored and facebook-github-bot committed
support rope with block tables (pytorch#3146)
Summary: Pull Request resolved: pytorch#3146. X-link: facebookresearch/FBGEMM#227. Modify `rope_xpos_qkv_varseq_prefill_kernel_` so that it uses page indirection for qparam tensors as well. Reviewed By: sgrigory. Differential Revision: D61898380.
1 parent e27c5e1 commit 8bfa07c

File tree

1 file changed

+20
-4
lines changed

1 file changed

+20
-4
lines changed

fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu

Lines changed: 20 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -795,11 +795,27 @@ __global__ void rope_xpos_qkv_varseq_prefill_kernel_(
795795
} else {
796796
__half2* qparam_row = nullptr;
797797
auto T = cache_K.size(1);
798-
auto idx = b * (T * N_KVH) + (size_t)cache_loc_t * N_KVH + h;
799-
if (qkv == QKV::K) {
800-
qparam_row = reinterpret_cast<__half2*>(&qparam_k_ptr[idx]);
798+
if (block_tables == nullptr) {
799+
auto idx = b * (T * N_KVH) + (size_t)cache_loc_t * N_KVH + h;
800+
if (qkv == QKV::K) {
801+
qparam_row = reinterpret_cast<__half2*>(&qparam_k_ptr[idx]);
802+
} else {
803+
qparam_row = reinterpret_cast<__half2*>(&qparam_v_ptr[idx]);
804+
}
801805
} else {
802-
qparam_row = reinterpret_cast<__half2*>(&qparam_v_ptr[idx]);
806+
// This is duplicate computation with get_dst_row above.
807+
// TODO: Maybe clean up and merge later.
808+
int page_logical_idx = cache_loc_t / page_size;
809+
int page_offset = cache_loc_t % page_size;
810+
int page_physical_idx =
811+
block_tables[b * block_tables_b_stride + page_logical_idx];
812+
int physical_t = page_physical_idx * page_size + page_offset;
813+
auto idx = physical_t * N_KVH + h;
814+
if (qkv == QKV::K) {
815+
qparam_row = reinterpret_cast<__half2*>(&qparam_k_ptr[idx]);
816+
} else {
817+
qparam_row = reinterpret_cast<__half2*>(&qparam_v_ptr[idx]);
818+
}
803819
}
804820
quantize_fp8_kv(dst, dst_row_q, qparam_row);
805821
}

0 commit comments

Comments
 (0)