@@ -107,13 +107,14 @@ __global__ void merge_attn_states_kernel(
 #define LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS) \
   { \
-    vllm::merge_attn_states_kernel<scalar_t, NUM_THREADS><<<grid, block>>>( \
-        reinterpret_cast<scalar_t*>(output.data_ptr()), output_lse_ptr, \
-        reinterpret_cast<scalar_t*>(prefix_output.data_ptr()), \
-        reinterpret_cast<float*>(prefix_lse.data_ptr()), \
-        reinterpret_cast<scalar_t*>(suffix_output.data_ptr()), \
-        reinterpret_cast<float*>(suffix_lse.data_ptr()), num_tokens, \
-        num_heads, head_size); \
+    vllm::merge_attn_states_kernel<scalar_t, NUM_THREADS> \
+        <<<grid, block, 0, stream>>>( \
+            reinterpret_cast<scalar_t*>(output.data_ptr()), output_lse_ptr, \
+            reinterpret_cast<scalar_t*>(prefix_output.data_ptr()), \
+            reinterpret_cast<float*>(prefix_lse.data_ptr()), \
+            reinterpret_cast<scalar_t*>(suffix_output.data_ptr()), \
+            reinterpret_cast<float*>(suffix_lse.data_ptr()), num_tokens, \
+            num_heads, head_size); \
   }

 /* @brief Merges the attention states from prefix and suffix
@@ -122,10 +123,10 @@ __global__ void merge_attn_states_kernel(
  * @param output [n,h,d] The output tensor to store the merged attention states.
  * @param output_lse [h,d] Optional tensor to store the log-sum-exp values.
  * @param prefix_output [n,h,d] The prefix attention states.
- * @param prefix_lse [h,d] The log-sum-exp values for the prefix attention
+ * @param prefix_lse [h,n] The log-sum-exp values for the prefix attention
  * states.
  * @param suffix_output [n,h,d] The suffix attention states.
- * @param suffix_lse [h,d] The log-sum-exp values for the suffix attention
+ * @param suffix_lse [h,n] The log-sum-exp values for the suffix attention
  * states.
  */
 template <typename scalar_t>
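The shape fix above reflects that the log-sum-exp is one scalar per (head, token) pair, not per (head, dim), hence [h,n] rather than [h,d]. For reference, below is a minimal host-side sketch of the standard log-sum-exp merge that this kind of kernel applies per (token, head); the function name, plain-float buffers, and loop are illustrative assumptions, not the committed CUDA code.

#include <cmath>

// Reference merge for a single (token, head) pair: combine two partial
// attention outputs, each with its own log-sum-exp (lse), into the output
// a single softmax over the union of their keys would have produced.
void merge_one_head(const float* prefix_out, float prefix_lse,
                    const float* suffix_out, float suffix_lse,
                    float* out, float* out_lse, int head_size) {
  const float max_lse = std::fmax(prefix_lse, suffix_lse);
  const float p = std::exp(prefix_lse - max_lse);  // rescaled prefix weight
  const float s = std::exp(suffix_lse - max_lse);  // rescaled suffix weight
  const float inv_sum = 1.0f / (p + s);
  for (int d = 0; d < head_size; ++d) {
    out[d] = (p * prefix_out[d] + s * suffix_out[d]) * inv_sum;
  }
  if (out_lse != nullptr) {
    // lse of the combined distribution, matching the optional output_lse.
    *out_lse = max_lse + std::log(p + s);
  }
}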
@@ -146,13 +147,17 @@ void merge_attn_states_launcher(torch::Tensor& output,
   if (output_lse.has_value()) {
     output_lse_ptr = output_lse.value().data_ptr<float>();
   }
-  // process one pack elements per thread. float -> 4, half/bf16 -> 8
+  // Process one pack of elements per thread: for float the pack_size is 4,
+  // for half/bf16 the pack_size is 8.
   const uint threads_per_head = head_size / pack_size;
   const uint total_threads = num_tokens * num_heads * threads_per_head;

   dim3 block(NUM_THREADS);
   dim3 grid((total_threads + NUM_THREADS - 1) / NUM_THREADS);

+  const c10::cuda::OptionalCUDAGuard device_guard(prefix_output.device());
+  auto stream = at::cuda::getCurrentCUDAStream();
+
   LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS);
 }
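To make the sizing concrete (the numbers below are assumptions for illustration, not values from this commit): with float inputs and head_size = 128, pack_size is 4, so threads_per_head = 32; with num_tokens = 256 and num_heads = 16 that is 131072 total threads, which NUM_THREADS = 128 covers with a 1024-block grid. A standalone sketch of the same arithmetic:

#include <cstdio>

// Example launch-sizing check with assumed problem sizes; mirrors the
// grid/block computation in merge_attn_states_launcher.
int main() {
  const unsigned num_tokens = 256, num_heads = 16, head_size = 128;
  const unsigned pack_size = 4;      // float -> 4 elements per pack
  const unsigned NUM_THREADS = 128;  // threads per block

  const unsigned threads_per_head = head_size / pack_size;  // 32
  const unsigned total_threads = num_tokens * num_heads * threads_per_head;
  const unsigned grid = (total_threads + NUM_THREADS - 1) / NUM_THREADS;
  std::printf("total_threads=%u grid=%u block=%u\n",
              total_threads, grid, NUM_THREADS);  // 131072, 1024, 128
  return 0;
}

The added OptionalCUDAGuard and getCurrentCUDAStream() lines make the launch target the device that owns prefix_output and enqueue the kernel on PyTorch's current stream rather than the legacy default stream, so it stays ordered with the surrounding ATen work.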