
Commit 4ba1eea

Add fp8 qkv_proj_with_rope kernel for CPU in sgl-kernel and add UT (#6493)

1 parent 4685fbb

File tree

5 files changed: +483 −11 lines

sgl-kernel/csrc/cpu/qkv_proj.cpp

Lines changed: 124 additions & 2 deletions
@@ -152,6 +152,85 @@ void segment_gemm_kernel_impl(
   });
 }
 
+// [C0, C1] = A @ [B0, B1]
+template <typename scalar_t>
+void segment_gemm_kernel_impl(
+    scalar_t* __restrict__ C0,
+    scalar_t* __restrict__ C1,
+    const scalar_t* __restrict__ A,
+    const at::Float8_e4m3fn* __restrict__ B0,
+    const at::Float8_e4m3fn* __restrict__ B1,
+    const float* __restrict__ Bs0,
+    const float* __restrict__ Bs1,
+    int64_t M,
+    int64_t N0,
+    int64_t N1,
+    int64_t K,
+    int64_t block_size_N,
+    int64_t block_size_K) {
+  constexpr int64_t BLOCK_M = block_size_m();
+  constexpr int64_t BLOCK_N = block_size_n();
+  const int64_t MB = div_up(M, BLOCK_M);
+  const int64_t NB0 = div_up(N0, BLOCK_N);
+  const int64_t NB1 = div_up(N1, BLOCK_N);
+  const int64_t NB = NB0 + NB1;
+
+  const int64_t scale_size_K = div_up(K, block_size_K);
+  const int64_t blocks_n_per_group = block_size_N / BLOCK_N;
+
+  const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(M);
+
+  // parallel on [MB, NB0 + NB1]
+  at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
+    int64_t mb{0}, nb{0};
+    data_index_init(begin, mb, MB, nb, NB);
+
+    // for brgemm, use float32 for accumulate
+    alignas(64) float Ctmp[BLOCK_M * BLOCK_N];
+    // for brgemm when mat2 is float8_e4m3
+    alignas(64) scalar_t Btmp[BLOCK_N * BLOCK_K];
+
+    for (int64_t i = begin; i < end; ++i) {
+      UNUSED(i);
+
+      int mb_start = mb * BLOCK_M;
+      int mb_size = std::min(M - mb_start, BLOCK_M);
+      int nb_start = nb * BLOCK_N;
+      int nb_size = BLOCK_N;
+
+      const at::Float8_e4m3fn* __restrict__ B = nb < NB0 ? B0 : B1;
+      const float* __restrict__ Bs = nb < NB0 ? Bs0 : Bs1;
+      scalar_t* __restrict__ C = nb < NB0 ? C0 : C1;
+      int64_t ldc = nb < NB0 ? N0 : N1;
+      int64_t local_nb_start = nb < NB0 ? nb_start : nb_start - N0;
+      int64_t new_nb = nb < NB0 ? nb : nb - NB0;
+
+      tinygemm_kernel<scalar_t>(
+          /* A   */ A + mb_start * K,
+          /* B   */ B + local_nb_start * K /* nb * BLOCK_N * K */,
+          /* C   */ C + mb_start * ldc + local_nb_start,
+          /* Btmp*/ Btmp,
+          /* Ctmp*/ Ctmp,
+          /* Bs  */ Bs + (new_nb / blocks_n_per_group) * scale_size_K,
+          /* M   */ mb_size,
+          /* N   */ nb_size,
+          /* K   */ K,
+          /* lda */ K,
+          /* ldb */ nb_size,
+          /* ldc */ ldc,
+          /* brg */ use_brgemm,
+          /* block_size_K */ block_size_K);
+
+      // move to the next index
+      data_index_step(mb, MB, nb, NB);
+    }
+
+    if (use_brgemm) {
+      at::native::cpublas::brgemm_release();
+    }
+  });
+}
+
 template <typename scalar_t>
 inline float reduce(const scalar_t* __restrict__ x, int64_t size) {
   using bVec = at::vec::Vectorized<scalar_t>;
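
The new overload fuses the q_a and kv_a projections into a single pass over the activations: the combined column space [NB0 + NB1] is tiled once, and each tile picks its weight, scale, and output pointers from the (B0, Bs0, C0) or (B1, Bs1, C1) segment based on its block index. As a reading aid, a minimal PyTorch sketch of the math: numerically it matches blockwise dequantization followed by a float32 GEMM, assuming row-major [N, K] weights with one float scale per [block_size_N, block_size_K] tile (the layout the CHECK_EQ shape checks later in this diff imply); all names here are illustrative, not part of the kernel API:

import torch

def ref_segment_gemm_fp8(A, B0, B1, Bs0, Bs1, block_size):
    # A: [M, K] activations (e.g. bfloat16).
    # B0: [N0, K], B1: [N1, K] float8_e4m3fn weights.
    # Bs0/Bs1: [ceil(N/block_n), ceil(K/block_k)] float scales.
    block_n, block_k = block_size

    def dequant(B, Bs):
        # Broadcast each per-block scale over its [block_n, block_k] tile.
        s = Bs.repeat_interleave(block_n, dim=0)[: B.size(0)]
        s = s.repeat_interleave(block_k, dim=1)[:, : B.size(1)]
        return B.to(torch.float32) * s

    # [C0, C1] = A @ [B0, B1]^T, computed segment by segment.
    C0 = A.to(torch.float32) @ dequant(B0, Bs0).T
    C1 = A.to(torch.float32) @ dequant(B1, Bs1).T
    return C0.to(A.dtype), C1.to(A.dtype)

The real kernel never materializes the dequantized weight: tinygemm_kernel receives the fp8 tile plus its scale row and applies the scales per K-block during accumulation.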
@@ -321,6 +400,15 @@ extern at::Tensor int8_scaled_mm_with_quant(
 extern void
 bmm_cpu(at::Tensor& out, at::Tensor& mat1, at::Tensor& mat2, bool is_vnni, const std::optional<at::Tensor>& scale);
 
+extern at::Tensor fp8_scaled_mm_cpu(
+    at::Tensor& mat1,
+    at::Tensor& mat2,
+    at::Tensor& scales2,
+    std::vector<int64_t> block_size,
+    const std::optional<at::Tensor>& bias,
+    at::ScalarType out_dtype,
+    bool is_vnni);
+
 // NB: shapes in DeepSeek R1
 //
 // hidden_states : [num_seqs, hidden_size] [1, 7168]
@@ -343,10 +431,12 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope(
     at::Tensor& cos_sin_cache,
     double eps,
     bool use_int8_w8a8,
+    bool use_fp8_w8a16,
     std::optional<at::Tensor> q_a_proj_scale,
     std::optional<at::Tensor> q_b_proj_scale,
     std::optional<at::Tensor> kv_a_proj_scale,
-    bool is_vnni) {
+    bool is_vnni,
+    std::optional<std::vector<int64_t>> block_size) {
   RECORD_FUNCTION(
       "sgl-kernel::qkv_proj_with_rope",
       std::vector<c10::IValue>({hidden_states, q_a_proj_weight, q_b_proj_weight, kv_a_proj_weight, w_kc}));
@@ -394,7 +484,13 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope(
     TORCH_CHECK(q_b_proj_scale.has_value(), "missing q_b_proj_scale for int8 w8a8.");
     TORCH_CHECK(kv_a_proj_scale.has_value(), "missing kv_a_proj_scale for int8 w8a8.");
   }
-
+  if (use_fp8_w8a16) {
+    TORCH_CHECK(q_a_proj_scale.has_value(), "missing q_a_proj_scale for fp8 w8a16.");
+    TORCH_CHECK(q_b_proj_scale.has_value(), "missing q_b_proj_scale for fp8 w8a16.");
+    TORCH_CHECK(kv_a_proj_scale.has_value(), "missing kv_a_proj_scale for fp8 w8a16.");
+    TORCH_CHECK(block_size.has_value(), "missing block_size for fp8 w8a16.");
+    TORCH_CHECK(block_size.value().size() == 2, "block_size should be 2D for fp8 w8a16.");
+  }
   // outputs and temp buffer
   const auto options = hidden_states.options();
   auto q_input = at::empty({num_seqs, num_heads, kv_lora_rank + qk_rope_head_dim}, options);
@@ -436,6 +532,29 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope(
         q_lora_rank,
         kv_lora_rank + qk_rope_head_dim,
         hidden_size);
+  } else if (use_fp8_w8a16) {
+    int64_t block_size_N = block_size.value()[0];
+    int64_t block_size_K = block_size.value()[1];
+    auto q_a_proj_s = q_a_proj_scale.value();
+    auto kv_a_proj_s = kv_a_proj_scale.value();
+    CHECK_EQ(q_a_proj_s.size(0), div_up(q_lora_rank, block_size_N));
+    CHECK_EQ(q_a_proj_s.size(1), div_up(hidden_size, block_size_K));
+    CHECK_EQ(kv_a_proj_s.size(0), div_up(kv_lora_rank + qk_rope_head_dim, block_size_N));
+    CHECK_EQ(kv_a_proj_s.size(1), div_up(hidden_size, block_size_K));
+    segment_gemm_kernel_impl<scalar_t>(
+        qa.data_ptr<scalar_t>(),
+        k_input.data_ptr<scalar_t>(),
+        hidden_states.data_ptr<scalar_t>(),
+        q_a_proj_weight.data_ptr<at::Float8_e4m3fn>(),
+        kv_a_proj_weight.data_ptr<at::Float8_e4m3fn>(),
+        q_a_proj_s.data_ptr<float>(),
+        kv_a_proj_s.data_ptr<float>(),
+        num_seqs,
+        q_lora_rank,
+        kv_lora_rank + qk_rope_head_dim,
+        hidden_size,
+        block_size_N,
+        block_size_K);
   } else {
     segment_gemm_kernel_impl<scalar_t>(
         qa.data_ptr<scalar_t>(),
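
For concreteness: with the DeepSeek R1 shapes the file comment refers to (hidden_size = 7168, q_lora_rank = 1536, kv_lora_rank = 512, qk_rope_head_dim = 64 — assumed here, only hidden_size appears in the visible hunk) and an illustrative [128, 128] block size, the CHECK_EQ constraints above work out as follows:

from math import ceil

hidden_size, q_lora_rank = 7168, 1536
kv_lora_rank, qk_rope_head_dim = 512, 64
block_n, block_k = 128, 128  # illustrative block_size

# q_a_proj_scale must be [ceil(1536/128), ceil(7168/128)] = [12, 56]
print(ceil(q_lora_rank / block_n), ceil(hidden_size / block_k))
# kv_a_proj_scale must be [ceil(576/128), ceil(7168/128)] = [5, 56]
print(ceil((kv_lora_rank + qk_rope_head_dim) / block_n), ceil(hidden_size / block_k))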
@@ -469,6 +588,9 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope(
   std::optional<at::Tensor> bias;
   if (use_int8_w8a8) {
     qb = int8_scaled_mm_with_quant(qa, q_b_proj_weight, q_b_proj_scale.value(), bias, at::kBFloat16, is_vnni);
+  } else if (use_fp8_w8a16) {
+    qb = fp8_scaled_mm_cpu(
+        qa, q_b_proj_weight, q_b_proj_scale.value(), block_size.value(), bias, at::kBFloat16, is_vnni);
   } else {
     qb = weight_packed_linear(qa, q_b_proj_weight, bias, is_vnni);
   }

sgl-kernel/csrc/cpu/torch_extension_cpu.cpp

Lines changed: 9 additions & 5 deletions
@@ -165,10 +165,12 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope(
     at::Tensor& cos_sin_cache,
     double eps,
     bool use_int8_w8a8,
+    bool use_fp8_w8a16,
     std::optional<at::Tensor> q_a_proj_scale,
     std::optional<at::Tensor> q_b_proj_scale,
     std::optional<at::Tensor> kv_a_proj_scale,
-    bool is_vnni);
+    bool is_vnni,
+    std::optional<std::vector<int64_t>> block_size);
 
 // shared memory init
 void initialize(int64_t size, int64_t rank);
@@ -209,8 +211,9 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
 
   // decode
   m.def(
-      "decode_attention_cpu(Tensor query, Tensor output, Tensor k_cache, Tensor v_cahce, Tensor attn_logits, Tensor "
-      "req_to_token, Tensor req_pool_indices, Tensor seq_lens, float sm_scale, float logit_cap) -> ()");
+      "decode_attention_cpu(Tensor query, Tensor k_cache, Tensor v_cahce, Tensor output, Tensor key, Tensor value, "
+      "Tensor loc, Tensor attn_logits, Tensor req_to_token, Tensor req_pool_indices, Tensor seq_lens, float sm_scale, "
+      "float logit_cap) -> ()");
   m.impl("decode_attention_cpu", torch::kCPU, &decode_attention_cpu);
 
   // extend
@@ -265,8 +268,9 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.def(
       "qkv_proj_with_rope(Tensor hidden_states, Tensor q_a_proj_weight, Tensor q_b_proj_weight, Tensor "
       "kv_a_proj_weight, Tensor w_kc, Tensor q_a_layernorm_weight, Tensor kv_a_layernorm_weight, Tensor positions, "
-      "Tensor cos_sin_cache, float eps, bool use_int8_w8a8, Tensor? q_a_proj_scale, Tensor? q_b_proj_scale, Tensor? "
-      "kv_a_proj_scale, bool is_vnni) -> (Tensor, Tensor, Tensor)");
+      "Tensor cos_sin_cache, float eps, bool use_int8_w8a8, bool use_fp8_w8a16, Tensor? q_a_proj_scale, Tensor? "
+      "q_b_proj_scale, Tensor? "
+      "kv_a_proj_scale, bool is_vnni, int[]? block_size) -> (Tensor, Tensor, Tensor)");
   m.impl("qkv_proj_with_rope", torch::kCPU, &qkv_proj_with_rope);
 
   // shared expert
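
On the Python side the op is reached through the dispatcher. A hedged sketch of a call hitting the new fp8 path, with tensor preparation (fp8 quantization, VNNI packing) left to the caller; the dict-based bundling and the wrapper name are ours, only the operator schema above is real:

import sgl_kernel  # registers torch.ops.sgl_kernel.* via TORCH_LIBRARY_FRAGMENT
import torch

def qkv_proj_with_rope_fp8(hidden_states, w, s, positions, cos_sin_cache,
                           eps=1e-6, block_size=(128, 128), is_vnni=True):
    """Illustrative wrapper; w/s are dicts of fp8 weights and float scales."""
    # Returns three tensors per the schema; the q/k/v naming is our reading
    # of the kernel's output buffers.
    q_input, k_input, v_input = torch.ops.sgl_kernel.qkv_proj_with_rope(
        hidden_states,
        w["q_a_proj"], w["q_b_proj"], w["kv_a_proj"], w["w_kc"],
        w["q_a_layernorm"], w["kv_a_layernorm"],
        positions, cos_sin_cache, eps,
        False,  # use_int8_w8a8
        True,   # use_fp8_w8a16
        s["q_a_proj"], s["q_b_proj"], s["kv_a_proj"],
        is_vnni,
        list(block_size),
    )
    return q_input, k_input, v_input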

test/srt/cpu/test_decode.py

Lines changed: 2 additions & 2 deletions
@@ -1,7 +1,7 @@
 import unittest
 
+import sgl_kernel
 import torch
-from sgl_kernel.common_ops import decode_attention_cpu as decode_attention
 from torch.nn.functional import scaled_dot_product_attention
 
 from sglang.test.test_utils import CustomTestCase
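
The tests now resolve the kernel through the PyTorch op registry instead of importing a binding from sgl_kernel.common_ops; the bare import suffices because loading the extension executes the TORCH_LIBRARY_FRAGMENT registration shown above. A minimal sanity check (illustrative):

import sgl_kernel  # side effect: loads the extension and registers its ops
import torch

# Both ops used by these tests are now addressable by schema name.
assert hasattr(torch.ops.sgl_kernel, "decode_attention_cpu")
assert hasattr(torch.ops.sgl_kernel, "extend_attention_cpu")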
@@ -105,7 +105,7 @@ def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V, device):
         v_buffer = v_buffer.transpose(0, 1).contiguous().transpose(0, 1)
         key = key.transpose(0, 1).contiguous().transpose(0, 1)
         value = value.transpose(0, 1).contiguous().transpose(0, 1)
-        decode_attention(
+        torch.ops.sgl_kernel.decode_attention_cpu(
             q,
             k_buffer,
             v_buffer,

test/srt/cpu/test_extend.py

Lines changed: 2 additions & 2 deletions
@@ -1,7 +1,7 @@
 import unittest
 
+import sgl_kernel
 import torch
-from sgl_kernel.common_ops import extend_attention_cpu as extend_attention
 from torch.nn.functional import scaled_dot_product_attention
 
 from sglang.test.test_utils import CustomTestCase
@@ -157,7 +157,7 @@ def _test_extend_attention_once(self, B, N_CTX, H_Q, H_KV, D, DV, mla=False):
         )
 
         o_extend = torch.empty((extend_token_num, H_Q, DV), dtype=dtype)
-        extend_attention(
+        torch.ops.sgl_kernel.extend_attention_cpu(
             q_extend,
             k_extend,
             v_extend,
