write_k_back for fp8 ROPE #3679

Closed
wants to merge 1 commit
40 changes: 27 additions & 13 deletions fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu
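A hedged reading of the diff below (not the author's own description): the kernel-launch macro and rope_xpos_qkv_varseq_prefill_kernel_ gain a write_k_back parameter. When it is set and the current row is a K row on the quantized-cache path, the post-RoPE values appear to be converted back to bf16 and written over the source row before being quantized into the FP8/INT4 cache, so the caller keeps an unquantized copy of the rotated K. rope_qkv_varseq_prefill now forwards its write_k_back argument (and drops a TORCH_CHECK(!write_k_back) guard), while the xpos and decoding wrappers pass false. The sketch below is a self-contained miniature of that control flow; every name in it (rope_and_cache_k_row, apply_rope, quantize_row) is a hypothetical stand-in, and int8 stands in for the real FP8 path.

// Minimal, self-contained miniature of the write_k_back control flow this PR
// adds (hypothetical names; int8 stands in for the real FP8 quantization).
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Rotate (x[i], x[i+1]) pairs by position-dependent angles (plain RoPE,
// interleaved-pair variant).
static void apply_rope(std::vector<float>& row, int pos, float theta = 10000.f) {
  const int d = static_cast<int>(row.size());
  for (int i = 0; i + 1 < d; i += 2) {
    const float freq = std::pow(theta, -static_cast<float>(i) / d);
    const float c = std::cos(pos * freq), s = std::sin(pos * freq);
    const float x0 = row[i], x1 = row[i + 1];
    row[i] = x0 * c - x1 * s;
    row[i + 1] = x0 * s + x1 * c;
  }
}

// Naive symmetric per-row quantization stand-in; the real kernel emits
// e4m3 FP8 bytes plus per-row qparams.
static void quantize_row(const std::vector<float>& row,
                         std::vector<int8_t>& q,
                         float& scale) {
  float amax = 1e-6f;
  for (const float v : row) amax = std::max(amax, std::fabs(v));
  scale = amax / 127.0f;
  q.resize(row.size());
  for (size_t i = 0; i < row.size(); ++i)
    q[i] = static_cast<int8_t>(std::lround(row[i] / scale));
}

// The flow the kernel gains: rotate K, optionally write the rotated row back
// over the source, then quantize the rotated row into the cache.
void rope_and_cache_k_row(std::vector<float>& k_src_row,  // in/out source row
                          std::vector<int8_t>& k_cache_row,
                          float& k_scale,
                          int pos,
                          bool write_k_back) {
  std::vector<float> rotated = k_src_row;
  apply_rope(rotated, pos);
  if (write_k_back) {
    k_src_row = rotated;  // caller keeps an unquantized copy of RoPE'd K
  }
  quantize_row(rotated, k_cache_row, k_scale);  // cache stores quantized K only
}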
@@ -715,6 +715,7 @@ DEVICE_INLINE void quantize_int4_kv(fx4 dst, uint8_t* dst_row_q) {
scaling_factor, \
lo_freq_factor, \
hi_freq_factor, \
write_k_back, \
k_rms_norm) \
rope_xpos_qkv_varseq_prefill_kernel_<EMB_MODE, DTYPE, NUM_GROUPS> \
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>( \
@@ -742,6 +743,7 @@ DEVICE_INLINE void quantize_int4_kv(fx4 dst, uint8_t* dst_row_q) {
scaling_factor, \
lo_freq_factor, \
hi_freq_factor, \
write_k_back, \
k_rms_norm);

#if (defined(USE_ROCM) && ROCM_VERSION >= 60200) || \
@@ -805,6 +807,7 @@ __global__ void rope_xpos_qkv_varseq_prefill_kernel_(
double scaling_factor = 16,
double lo_freq_factor = 1,
double hi_freq_factor = 32,
bool write_k_back = false,
bool k_rms_norm = false) {
// Launch b_t_(sum(h)) warps.
auto b_t_hh = blockIdx.x * blockDim.y +
@@ -905,6 +908,13 @@ __global__ void rope_xpos_qkv_varseq_prefill_kernel_(
*reinterpret_cast<uint2*>(&dst_row[4 * threadIdx.x]) =
*reinterpret_cast<uint2*>(&dst_);
} else {
if (write_k_back && qkv == QKV::K) {
// Also write back to the source row
bfx4 dst_ = fx4_to_bfx4(dst);
*reinterpret_cast<uint2*>(&src_row[4 * threadIdx.x]) =
*reinterpret_cast<uint2*>(&dst_);
}
// quantize and write to dst_row
auto D_H = XQ.size(2);
auto D_H_q = cache_K.size(3);
if (kCacheDtype == CacheLogicalDtype::FP8) {
@@ -914,13 +924,9 @@ } else {
} else {
__half2* qparam_row = nullptr;
auto T = cache_K.size(1);
size_t idx = 0;
if (block_tables == nullptr) {
auto idx = b * (T * N_KVH) + (size_t)cache_loc_t * N_KVH + h;
if (qkv == QKV::K) {
qparam_row = reinterpret_cast<__half2*>(&qparam_k_ptr[idx]);
} else {
qparam_row = reinterpret_cast<__half2*>(&qparam_v_ptr[idx]);
}
idx = b * (T * N_KVH) + (size_t)cache_loc_t * N_KVH + h;
} else {
// This is duplicate computation with get_dst_row above.
// TODO: Maybe clean up and merge later.
@@ -929,12 +935,12 @@
int page_physical_idx =
block_tables[b * block_tables_b_stride + page_logical_idx];
int physical_t = page_physical_idx * page_size + page_offset;
auto idx = physical_t * N_KVH + h;
if (qkv == QKV::K) {
qparam_row = reinterpret_cast<__half2*>(&qparam_k_ptr[idx]);
} else {
qparam_row = reinterpret_cast<__half2*>(&qparam_v_ptr[idx]);
}
idx = physical_t * N_KVH + h;
}
if (qkv == QKV::K) {
qparam_row = reinterpret_cast<__half2*>(&qparam_k_ptr[idx]);
} else {
qparam_row = reinterpret_cast<__half2*>(&qparam_v_ptr[idx]);
}
quantize_fp8_kv(
dst, dst_row_q, qparam_row, (qkv == QKV::K && k_rms_norm));
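The two hunks above also hoist the qparam pointer selection out of the paged/contiguous branches: idx is now declared once before the branch, each branch only computes it, and the K-vs-V choice between qparam_k_ptr and qparam_v_ptr happens once afterwards. Below is a minimal sketch of the resulting index selection, not the FBGEMM implementation; the names are simplified, and in the kernel the paged-path page math is shared with get_dst_row.

// Sketch only: select the per-row index into the K/V quantization-parameter
// buffers for a contiguous or paged KV cache.
#include <cstddef>
#include <cstdint>

size_t qparam_row_index(const int32_t* block_tables,  // nullptr => contiguous
                        int block_tables_b_stride,
                        int b,            // batch index
                        int cache_loc_t,  // timestep within the sequence
                        int h,            // KV head index
                        int T,            // max timesteps (contiguous layout)
                        int N_KVH,        // number of KV heads
                        int page_size) {
  if (block_tables == nullptr) {
    // Contiguous [B, T, N_KVH] layout.
    return static_cast<size_t>(b) * T * N_KVH +
        static_cast<size_t>(cache_loc_t) * N_KVH + h;
  }
  // Paged layout: map the logical page to a physical page first.
  const int page_logical_idx = cache_loc_t / page_size;
  const int page_offset = cache_loc_t % page_size;
  const int page_physical_idx =
      block_tables[b * block_tables_b_stride + page_logical_idx];
  const int physical_t = page_physical_idx * page_size + page_offset;
  return static_cast<size_t>(physical_t) * N_KVH + h;
}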
@@ -1160,7 +1166,6 @@ at::Tensor rope_qkv_varseq_prefill(
varseq_seqpos.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>();
int32_t* qparam_k_ptr = nullptr;
int32_t* qparam_v_ptr = nullptr;
TORCH_CHECK(!write_k_back);
if (qparam_k.has_value()) {
qparam_k_ptr = static_cast<int32_t*>(qparam_k.value().data_ptr());
qparam_v_ptr = static_cast<int32_t*>(qparam_v.value().data_ptr());
@@ -1190,6 +1195,7 @@ at::Tensor rope_qkv_varseq_prefill(
scaling_factor,
lo_freq_factor,
hi_freq_factor,
write_k_back,
k_rms_norm);
C10_CUDA_KERNEL_LAUNCH_CHECK();
#else
@@ -1218,6 +1224,7 @@ at::Tensor rope_qkv_varseq_prefill(
scaling_factor,
lo_freq_factor,
hi_freq_factor,
write_k_back,
false);

C10_CUDA_KERNEL_LAUNCH_CHECK();
@@ -1341,6 +1348,7 @@ at::Tensor xpos_qkv_varseq_prefill(
scaling_factor,
lo_freq_factor,
hi_freq_factor,
false,
false);

C10_CUDA_KERNEL_LAUNCH_CHECK();
@@ -1370,6 +1378,7 @@ at::Tensor xpos_qkv_varseq_prefill(
scaling_factor,
lo_freq_factor,
hi_freq_factor,
false,
false);

C10_CUDA_KERNEL_LAUNCH_CHECK();
@@ -1490,6 +1499,7 @@ at::Tensor rope_qkv_decoding(
scaling_factor,
lo_freq_factor,
hi_freq_factor,
false,
k_rms_norm);

C10_CUDA_KERNEL_LAUNCH_CHECK();
@@ -1519,6 +1529,7 @@ at::Tensor rope_qkv_decoding(
scaling_factor,
lo_freq_factor,
hi_freq_factor,
false,
false);

C10_CUDA_KERNEL_LAUNCH_CHECK();
@@ -1642,6 +1653,7 @@ at::Tensor xpos_qkv_decoding(
scaling_factor,
lo_freq_factor,
hi_freq_factor,
false,
false);
C10_CUDA_KERNEL_LAUNCH_CHECK();
#else
@@ -1670,6 +1682,7 @@ at::Tensor xpos_qkv_decoding(
scaling_factor,
lo_freq_factor,
hi_freq_factor,
false,
false);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
@@ -2005,6 +2018,7 @@ DEVICE_INLINE void quantize_fp8_kv(

*reinterpret_cast<uint32_t*>(
&dst_row_q[4 * threadIdx.x + fp8_qparam_offset]) = packed;
// write qparams
if (threadIdx.x == 0) {
__half2* param_store = qparam;
if (param_store == nullptr) {
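For context on the last hunk (the tail of quantize_fp8_kv): each lane of the warp writes its packed 4 bytes into the cache row, and lane 0 then writes the per-row qparams once, either into an externally supplied buffer or, if none was given, into the head of the cache row. The device-side sketch below is schematic only; the stand-in packing replaces the real e4m3 conversion, and fp8_qparam_offset, the (scale, shift) pair, and the in-row qparam placement are assumptions read off the surrounding code.

// Schematic only; not the FBGEMM implementation.
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <cstdint>

// Stand-in for the real float -> FP8 (e4m3) packing.
__device__ uint32_t pack4_standin(float4 v, float inv_scale) {
  auto q = [&](float x) -> uint32_t {
    int t = __float2int_rn(x * inv_scale);
    t = max(-127, min(127, t));
    return static_cast<uint32_t>(static_cast<uint8_t>(static_cast<int8_t>(t)));
  };
  return q(v.x) | (q(v.y) << 8) | (q(v.z) << 16) | (q(v.w) << 24);
}

__device__ void write_row_q_and_qparams(
    float4 vals,             // this lane's 4 values
    uint8_t* dst_row_q,      // destination cache row
    __half2* qparam,         // external per-row qparam slot, or nullptr
    float scale,
    float shift,
    int fp8_qparam_offset) { // bytes reserved at the row head for qparams
  // Each lane stores its packed 4 bytes past the reserved qparam bytes.
  *reinterpret_cast<uint32_t*>(
      &dst_row_q[4 * threadIdx.x + fp8_qparam_offset]) =
      pack4_standin(vals, 1.0f / scale);

  // Write qparams exactly once per row.
  if (threadIdx.x == 0) {
    __half2* param_store = qparam;
    if (param_store == nullptr) {
      // No external buffer: qparams live at the start of the cache row.
      param_store = reinterpret_cast<__half2*>(&dst_row_q[0]);
    }
    *param_store = __floats2half2_rn(scale, shift);
  }
}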