
Commit 84615c4

DefTruth authored and Mu Huai committed
[Bugfix][Kernel] fix potential cuda graph broken for merge_attn_states kernel (vllm-project#16693)
Signed-off-by: DefTruth <[email protected]>
Signed-off-by: Mu Huai <[email protected]>
1 parent 3eae3df commit 84615c4
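
For context on the fix: the old macro launched the kernel with <<<grid, block>>>, i.e. on the legacy default stream, while CUDA graph capture records work on a non-default stream obtained from PyTorch, so a default-stream launch during capture is not recorded and can invalidate the capture. A minimal sketch of the pattern the commit adopts (the kernel and launcher names below are illustrative, not the vLLM code):

// Sketch of launching on the current PyTorch CUDA stream instead of the
// implicit default stream; during torch.cuda.graph capture this is the
// capture stream, so the kernel is recorded into the graph.
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/torch.h>

__global__ void scale_kernel(float* out, const float* in, int n) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = 2.0f * in[i];
}

// Hypothetical launcher, not the vLLM merge_attn_states_launcher.
void launch_on_current_stream(torch::Tensor& out, const torch::Tensor& in) {
  const int n = static_cast<int>(in.numel());
  constexpr int kThreads = 256;
  const int blocks = (n + kThreads - 1) / kThreads;

  // Bind to the input's device, then query the stream PyTorch is currently
  // using on that device.
  const c10::cuda::OptionalCUDAGuard device_guard(in.device());
  auto stream = at::cuda::getCurrentCUDAStream();

  scale_kernel<<<blocks, kThreads, 0, stream>>>(
      out.data_ptr<float>(), in.data_ptr<float>(), n);
}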

File tree

1 file changed (+15, -10 lines)

csrc/attention/merge_attn_states.cu

Lines changed: 15 additions & 10 deletions
@@ -107,13 +107,14 @@ __global__ void merge_attn_states_kernel(
 
 #define LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS)                     \
   {                                                                         \
-    vllm::merge_attn_states_kernel<scalar_t, NUM_THREADS><<<grid, block>>>( \
-        reinterpret_cast<scalar_t*>(output.data_ptr()), output_lse_ptr,     \
-        reinterpret_cast<scalar_t*>(prefix_output.data_ptr()),              \
-        reinterpret_cast<float*>(prefix_lse.data_ptr()),                    \
-        reinterpret_cast<scalar_t*>(suffix_output.data_ptr()),              \
-        reinterpret_cast<float*>(suffix_lse.data_ptr()), num_tokens,        \
-        num_heads, head_size);                                              \
+    vllm::merge_attn_states_kernel<scalar_t, NUM_THREADS>                   \
+        <<<grid, block, 0, stream>>>(                                       \
+            reinterpret_cast<scalar_t*>(output.data_ptr()), output_lse_ptr, \
+            reinterpret_cast<scalar_t*>(prefix_output.data_ptr()),          \
+            reinterpret_cast<float*>(prefix_lse.data_ptr()),                \
+            reinterpret_cast<scalar_t*>(suffix_output.data_ptr()),          \
+            reinterpret_cast<float*>(suffix_lse.data_ptr()), num_tokens,    \
+            num_heads, head_size);                                          \
   }
 
 /*@brief Merges the attention states from prefix and suffix
@@ -122,10 +123,10 @@
  * @param output [n,h,d] The output tensor to store the merged attention states.
  * @param output_lse [h,d] Optional tensor to store the log-sum-exp values.
  * @param prefix_output [n,h,d] The prefix attention states.
- * @param prefix_lse [h,d] The log-sum-exp values for the prefix attention
+ * @param prefix_lse [h,n] The log-sum-exp values for the prefix attention
  * states.
  * @param suffix_output [n,h,d] The suffix attention states.
- * @param suffix_lse [h,d] The log-sum-exp values for the suffix attention
+ * @param suffix_lse [h,n] The log-sum-exp values for the suffix attention
  * states.
  */
 template <typename scalar_t>
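
The shape correction above ([h,d] to [h,n]) reflects that there is one log-sum-exp value per (head, token) pair, not per head dimension. As a rough illustration of the merge the docstring describes, a sketch of the standard numerically stable formulation (not the actual vLLM kernel body):

// Illustrative device helper: merge one output element from the prefix and
// suffix partial attention results using their log-sum-exp values.
__device__ inline void merge_one(float prefix_out, float suffix_out,
                                 float prefix_lse, float suffix_lse,
                                 float* merged_out, float* merged_lse) {
  const float max_lse = fmaxf(prefix_lse, suffix_lse);
  const float p = __expf(prefix_lse - max_lse);  // weight of the prefix part
  const float s = __expf(suffix_lse - max_lse);  // weight of the suffix part
  *merged_out = (p * prefix_out + s * suffix_out) / (p + s);
  if (merged_lse != nullptr) {
    // log(exp(prefix_lse) + exp(suffix_lse)), evaluated stably.
    *merged_lse = max_lse + __logf(p + s);
  }
}

Each (token, head) pair reuses a single pair of lse values across all head_size output elements, which is why the lse tensors are [num_heads, num_tokens].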
@@ -146,13 +147,17 @@ void merge_attn_states_launcher(torch::Tensor& output,
   if (output_lse.has_value()) {
     output_lse_ptr = output_lse.value().data_ptr<float>();
   }
-  // process one pack elements per thread. float -> 4, half/bf16 -> 8
+  // Process one pack elements per thread. for float, the
+  // pack_size is 4 for half/bf16, the pack_size is 8.
   const uint threads_per_head = head_size / pack_size;
   const uint total_threads = num_tokens * num_heads * threads_per_head;
 
   dim3 block(NUM_THREADS);
   dim3 grid((total_threads + NUM_THREADS - 1) / NUM_THREADS);
 
+  const c10::cuda::OptionalCUDAGuard device_guard(prefix_output.device());
+  auto stream = at::cuda::getCurrentCUDAStream();
+
   LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS);
 }
 
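Independent of vLLM, a self-contained sketch of why the explicit stream argument matters for graphs: capture records only work submitted to the capturing stream, and a legacy default-stream launch during capture typically fails with a stream-capture error. Built with nvcc (CUDA 11.4 or newer for cudaGraphInstantiateWithFlags):

#include <cuda_runtime.h>
#include <cstdio>

__global__ void scale_kernel(float* data, int n) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] *= 2.0f;
}

int main() {
  const int n = 1024;
  float* d_data = nullptr;
  cudaMalloc(&d_data, n * sizeof(float));

  cudaStream_t stream;
  cudaStreamCreate(&stream);

  cudaGraph_t graph;
  cudaGraphExec_t graph_exec;

  // Record work submitted to `stream` into a graph.
  cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);

  // Captured: the launch names the capturing stream.
  scale_kernel<<<(n + 255) / 256, 256, 0, stream>>>(d_data, n);

  // Not captured: omitting the stream targets the legacy default stream,
  // which is not under capture; during capture this typically returns a
  // stream-capture error and invalidates the graph (the failure mode the
  // commit title refers to).
  // scale_kernel<<<(n + 255) / 256, 256>>>(d_data, n);

  cudaStreamEndCapture(stream, &graph);
  cudaGraphInstantiateWithFlags(&graph_exec, graph, 0);
  cudaGraphLaunch(graph_exec, stream);
  cudaStreamSynchronize(stream);

  cudaGraphExecDestroy(graph_exec);
  cudaGraphDestroy(graph);
  cudaStreamDestroy(stream);
  cudaFree(d_data);
  printf("graph replayed once\n");
  return 0;
}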