#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <algorithm>
#include <cmath>
#include <limits>
#include <optional>

#include "pytorch_extension_utils.h"

// Helper functions to convert between different data types
// (float, half, bfloat16) for the merge attention states kernel.
inline __device__ float to_float(float u) {
  return u;
}
inline __device__ float to_float(half u) {
  return __half2float(u);
}
inline __device__ float to_float(__nv_bfloat16 u) {
  return __bfloat162float(u);
}
inline __device__ void from_float(float& d, float s) {
  d = s;
}
inline __device__ void from_float(half& d, float s) {
  d = __float2half(s);
}
inline __device__ void from_float(__nv_bfloat16& d, float s) {
  d = __float2bfloat16(s);
}

// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
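//
// Sketch of the merge rule the kernel below implements (written out here for
// reference; derived from the code, not quoted from the paper): given two
// partial attention outputs O_p, O_s over disjoint key ranges with
// log-sum-exp values lse_p, lse_s, the combined result is
//
//   O   = (exp(lse_p) * O_p + exp(lse_s) * O_s) / (exp(lse_p) + exp(lse_s))
//   lse = log(exp(lse_p) + exp(lse_s))
//
// The kernel subtracts max(lse_p, lse_s) before exponentiating so the
// intermediate exp() values stay in a numerically safe range.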
template <typename scalar_t, const uint NUM_THREADS>
__global__ void merge_attn_states_kernel(
    scalar_t* output,
    float* output_lse,
    const scalar_t* prefix_output,
    const float* prefix_lse,
    const scalar_t* suffix_output,
    const float* suffix_lse,
    const uint num_tokens,
    const uint num_heads,
    const uint head_size) {
  using pack_128b_t = uint4;
  const uint pack_size = 16 / sizeof(scalar_t);
  const uint threads_per_head = head_size / pack_size;
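  // Worked example (illustrative only): for half/bf16 with head_size = 128,
  // pack_size = 16 / 2 = 8 elements per 128-bit load, so threads_per_head =
  // 128 / 8 = 16 threads cooperate on one head.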

  const uint global_idx = blockIdx.x * NUM_THREADS + threadIdx.x;
  const uint token_head_threads = num_tokens * num_heads * threads_per_head;

  if (global_idx >= token_head_threads) return;

  // global_idx -> token_idx + head_idx + pack_idx
  const uint token_head_idx = global_idx / threads_per_head;
  const uint pack_idx = global_idx % threads_per_head;

  const uint token_idx = token_head_idx / num_heads;
  const uint head_idx = token_head_idx % num_heads;

  const uint pack_offset = pack_idx * pack_size;  // (0~15)*8, etc.
  const uint head_offset = token_idx * num_heads * head_size + head_idx * head_size;
  const scalar_t* prefix_head_ptr = prefix_output + head_offset;
  const scalar_t* suffix_head_ptr = suffix_output + head_offset;
  scalar_t* output_head_ptr = output + head_offset;

  // float p_lse = prefix_lse[head_idx * num_tokens + token_idx];
  // float s_lse = suffix_lse[head_idx * num_tokens + token_idx];
  float p_lse = prefix_lse[token_idx * num_heads + head_idx];
  float s_lse = suffix_lse[token_idx * num_heads + head_idx];
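  // Clamp non-finite LSEs to -inf so that, after the max subtraction below,
  // exp() for that branch evaluates to 0 and the branch effectively drops out
  // of the merge.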
  p_lse = std::isinf(p_lse) ? -std::numeric_limits<float>::infinity() : p_lse;
  s_lse = std::isinf(s_lse) ? -std::numeric_limits<float>::infinity() : s_lse;

  const float max_lse = fmaxf(p_lse, s_lse);
  p_lse = p_lse - max_lse;
  s_lse = s_lse - max_lse;
  const float p_se = expf(p_lse);
  const float s_se = expf(s_lse);
  const float out_se = p_se + s_se;
  const float p_scale = p_se / out_se;
  const float s_scale = s_se / out_se;

  if (pack_offset < head_size) {
    // Packed 128b load
    pack_128b_t p_out_pack = reinterpret_cast<const pack_128b_t*>(prefix_head_ptr)[pack_offset / pack_size];
    pack_128b_t s_out_pack = reinterpret_cast<const pack_128b_t*>(suffix_head_ptr)[pack_offset / pack_size];
    pack_128b_t o_out_pack;

#pragma unroll
    for (uint i = 0; i < pack_size; ++i) {
      // Always use float for FMA to keep high precision.
      // half(uint16_t), bfloat16, float -> float.
      const float p_out_f = to_float(reinterpret_cast<const scalar_t*>(&p_out_pack)[i]);
      const float s_out_f = to_float(reinterpret_cast<const scalar_t*>(&s_out_pack)[i]);
      // fma: a * b + c = p_out_f * p_scale + (s_out_f * s_scale)
      const float o_out_f = p_out_f * p_scale + (s_out_f * s_scale);
      // float -> half(uint16_t), bfloat16, float.
      from_float(reinterpret_cast<scalar_t*>(&o_out_pack)[i], o_out_f);
    }

    // Packed 128b store
    reinterpret_cast<pack_128b_t*>(output_head_ptr)[pack_offset / pack_size] = o_out_pack;
  }
  // We only need to write output_lse once per (token, head) pair.
  if (output_lse != nullptr && pack_idx == 0) {
    float out_lse = logf(out_se) + max_lse;
    output_lse[token_idx * num_heads + head_idx] = out_lse;
  }
}

// Dispatches fn(scalar_t) based on the output data type. fn is expected to
// be a macro that invokes a function templated on scalar_t.
#define DISPATCH_BY_SCALAR_DTYPE(scalar_dtype, fn)                      \
  {                                                                     \
    if (scalar_dtype == at::ScalarType::Float) {                        \
      fn(float);                                                        \
    } else if (scalar_dtype == at::ScalarType::Half) {                  \
      fn(half);                                                         \
    } else if (scalar_dtype == at::ScalarType::BFloat16) {              \
      fn(__nv_bfloat16);                                                \
    } else {                                                            \
      TORCH_CHECK(false, "Unsupported data type of O: ", scalar_dtype); \
    }                                                                   \
  }

#define LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS)                              \
  {                                                                                  \
    merge_attn_states_kernel<scalar_t, NUM_THREADS><<<grid, block, 0, stream>>>(     \
        reinterpret_cast<scalar_t*>(output.data_ptr()),                              \
        reinterpret_cast<float*>(output_lse.data_ptr()),                             \
        reinterpret_cast<scalar_t*>(prefix_output.data_ptr()),                       \
        reinterpret_cast<float*>(prefix_lse.data_ptr()),                             \
        reinterpret_cast<scalar_t*>(suffix_output.data_ptr()),                       \
        reinterpret_cast<float*>(suffix_lse.data_ptr()),                             \
        num_tokens,                                                                  \
        num_heads,                                                                   \
        head_size);                                                                  \
  }

/* @brief Merges the attention states from prefix and suffix
 * into the output tensor. NUM_TOKENS: n, NUM_HEADS: h, HEAD_SIZE: d
 *
 * @param output [n,h,d] The output tensor to store the merged attention states.
 * @param output_lse [n,h] Optional tensor to store the log-sum-exp values.
 * @param prefix_output [n,h,d] The prefix attention states.
 * @param prefix_lse [n,h] The log-sum-exp values for the prefix attention
 * states.
 * @param suffix_output [n,h,d] The suffix attention states.
 * @param suffix_lse [n,h] The log-sum-exp values for the suffix attention
 * states.
 */
template <typename scalar_t>
void merge_attn_states_launcher(
    const at::Tensor& prefix_output,  // [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    const at::Tensor& prefix_lse,     // [NUM_TOKENS, NUM_HEADS]
    const at::Tensor& suffix_output,  // [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    const at::Tensor& suffix_lse,     // [NUM_TOKENS, NUM_HEADS]
    at::Tensor& output,               // [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    at::Tensor& output_lse            // [NUM_TOKENS, NUM_HEADS]
) {
  constexpr uint NUM_THREADS = 128;
  const uint num_tokens = output.size(0);
  const uint num_heads = output.size(1);
  const uint head_size = output.size(2);
  const uint pack_size = 16 / sizeof(scalar_t);
  TORCH_CHECK(head_size % pack_size == 0, "head_size must be a multiple of pack_size: ", pack_size);
  // Each thread processes one pack of elements: pack_size is 4 for float
  // and 8 for half/bf16.
  const uint threads_per_head = head_size / pack_size;
  const uint total_threads = num_tokens * num_heads * threads_per_head;

  dim3 block(NUM_THREADS);
  dim3 grid((total_threads + NUM_THREADS - 1) / NUM_THREADS);
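  // Illustrative sizing (not from the original source): num_tokens = 1024,
  // num_heads = 32, head_size = 128 with half gives threads_per_head = 16,
  // total_threads = 524288, i.e. a grid of 4096 blocks of 128 threads.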

  // Launch on the current PyTorch CUDA stream rather than the default stream.
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS);
}

#define CALL_MERGE_ATTN_STATES_LAUNCHER(scalar_t) \
  { merge_attn_states_launcher<scalar_t>(v_a, s_a, v_b, s_b, v_merged, s_merged); }

void merge_state_v2(
    at::Tensor v_a, at::Tensor s_a, at::Tensor v_b, at::Tensor s_b, at::Tensor v_merged, at::Tensor s_merged) {
  // Input tensors must be contiguous
  CHECK_INPUT(v_a);  // v_a prefix_output (seq_len, num_heads, head_dim)
  CHECK_INPUT(s_a);  // s_a prefix_lse (seq_len, num_heads)
  CHECK_INPUT(v_b);  // v_b suffix_output (seq_len, num_heads, head_dim)
  CHECK_INPUT(s_b);  // s_b suffix_lse (seq_len, num_heads)
  // v_merged output (seq_len, num_heads, head_dim)
  // s_merged output_lse (seq_len, num_heads)
  auto device = v_a.device();
  CHECK_EQ(s_a.device(), device);
  CHECK_EQ(v_b.device(), device);
  CHECK_EQ(s_b.device(), device);
  CHECK_DIM(3, v_a);
  CHECK_DIM(2, s_a);
  CHECK_DIM(3, v_b);
  CHECK_DIM(2, s_b);
  CHECK_SHAPE(v_a, v_b);
  CHECK_SHAPE(s_a, s_b);
  CHECK_EQ(v_a.size(0), s_a.size(0));
  CHECK_EQ(v_a.size(1), s_b.size(1));
  DISPATCH_BY_SCALAR_DTYPE(v_merged.dtype(), CALL_MERGE_ATTN_STATES_LAUNCHER);
}
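
// Example usage from the C++ side (illustrative only; shapes follow the checks
// above, and the LSE tensors must be float32 since the kernel reads them
// through float*). Real LSE values would come from attention kernels rather
// than randn:
//
//   const int64_t n = 1024, h = 32, d = 128;
//   auto opts = at::dtype(at::kHalf).device(at::kCUDA);
//   auto v_a = at::randn({n, h, d}, opts);
//   auto v_b = at::randn({n, h, d}, opts);
//   auto s_a = at::randn({n, h}, opts.dtype(at::kFloat));
//   auto s_b = at::randn({n, h}, opts.dtype(at::kFloat));
//   auto v_merged = at::empty_like(v_a);
//   auto s_merged = at::empty_like(s_a);
//   merge_state_v2(v_a, s_a, v_b, s_b, v_merged, s_merged);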