@@ -152,6 +152,85 @@ void segment_gemm_kernel_impl(
   });
 }
 
+// [C0, C1] = A @ [B0, B1]
+template <typename scalar_t>
+void segment_gemm_kernel_impl(
+    scalar_t* __restrict__ C0,
+    scalar_t* __restrict__ C1,
+    const scalar_t* __restrict__ A,
+    const at::Float8_e4m3fn* __restrict__ B0,
+    const at::Float8_e4m3fn* __restrict__ B1,
+    const float* __restrict__ Bs0,
+    const float* __restrict__ Bs1,
+    int64_t M,
+    int64_t N0,
+    int64_t N1,
+    int64_t K,
+    int64_t block_size_N,
+    int64_t block_size_K) {
+  constexpr int64_t BLOCK_M = block_size_m();
+  constexpr int64_t BLOCK_N = block_size_n();
+  const int64_t MB = div_up(M, BLOCK_M);
+  const int64_t NB0 = div_up(N0, BLOCK_N);
+  const int64_t NB1 = div_up(N1, BLOCK_N);
+  const int64_t NB = NB0 + NB1;
+
+  const int64_t scale_size_K = div_up(K, block_size_K);
+  const int64_t blocks_n_per_group = block_size_N / BLOCK_N;
+
+  const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(M);
+
+  // parallel on [MB, NB0 + NB1]
+  at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
+    int64_t mb{0}, nb{0};
+    data_index_init(begin, mb, MB, nb, NB);
+
+    // for brgemm, use float32 for accumulate
+    alignas(64) float Ctmp[BLOCK_M * BLOCK_N];
+    // for brgemm when mat2 is float8_e4m3
+    alignas(64) scalar_t Btmp[BLOCK_N * BLOCK_K];
+
+    for (int64_t i = begin; i < end; ++i) {
+      UNUSED(i);
+
+      int mb_start = mb * BLOCK_M;
+      int mb_size = std::min(M - mb_start, BLOCK_M);
+      int nb_start = nb * BLOCK_N;
+      int nb_size = BLOCK_N;
+
+      const at::Float8_e4m3fn* __restrict__ B = nb < NB0 ? B0 : B1;
+      const float* __restrict__ Bs = nb < NB0 ? Bs0 : Bs1;
+      scalar_t* __restrict__ C = nb < NB0 ? C0 : C1;
+      int64_t ldc = nb < NB0 ? N0 : N1;
+      int64_t local_nb_start = nb < NB0 ? nb_start : nb_start - N0;
+      int64_t new_nb = nb < NB0 ? nb : nb - NB0;
+
+      tinygemm_kernel<scalar_t>(
+          /* A   */ A + mb_start * K,
+          /* B   */ B + local_nb_start * K /* nb * BLOCK_N * K */,
+          /* C   */ C + mb_start * ldc + local_nb_start,
+          /* Btmp*/ Btmp,
+          /* Ctmp*/ Ctmp,
+          /* Bs  */ Bs + (new_nb / blocks_n_per_group) * scale_size_K,
+          /* M   */ mb_size,
+          /* N   */ nb_size,
+          /* K   */ K,
+          /* lda */ K,
+          /* ldb */ nb_size,
+          /* ldc */ ldc,
+          /* brg */ use_brgemm,
+          /* block_size_K */ block_size_K);
+
+      // move to the next index
+      data_index_step(mb, MB, nb, NB);
+    }
+
+    if (use_brgemm) {
+      at::native::cpublas::brgemm_release();
+    }
+  });
+}
+
 template <typename scalar_t>
 inline float reduce(const scalar_t* __restrict__ x, int64_t size) {
   using bVec = at::vec::Vectorized<scalar_t>;
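For readers skimming the diff, the sketch below is a plain-scalar reference of the math the new segmented kernel implements: one GEMM pass whose output column blocks are routed to two destinations, with fp8 weights replaced by plain float and the per-[block_size_N x block_size_K] weight scales applied explicitly. The names gemm_blockscale_ref, segment_gemm_ref, and div_up_ref are illustrative only and not part of the PR; the layout assumptions (row-major A [M, K], B stored as [N, K], scales stored [div_up(N, bN), div_up(K, bK)]) follow the indexing in the kernel above.

// Reference sketch only: plain float stands in for at::Float8_e4m3fn, and the
// brgemm/tinygemm tiling is replaced by a naive triple loop.
#include <cstdint>

static int64_t div_up_ref(int64_t a, int64_t b) { return (a + b - 1) / b; }

// C[m, n] = sum_k A[m, k] * B[n, k] * Bs[n / bN, k / bK]
// A: [M, K], B: [N, K], Bs: [div_up(N, bN), div_up(K, bK)], C: [M, N], row-major.
static void gemm_blockscale_ref(
    float* C, const float* A, const float* B, const float* Bs,
    int64_t M, int64_t N, int64_t K, int64_t bN, int64_t bK) {
  const int64_t scale_size_K = div_up_ref(K, bK);
  for (int64_t m = 0; m < M; ++m) {
    for (int64_t n = 0; n < N; ++n) {
      float acc = 0.f;
      for (int64_t k = 0; k < K; ++k) {
        acc += A[m * K + k] * B[n * K + k] * Bs[(n / bN) * scale_size_K + (k / bK)];
      }
      C[m * N + n] = acc;
    }
  }
}

// [C0, C1] = A @ [B0, B1]: output columns below N0 go to C0, the rest to C1,
// which is what the nb < NB0 routing in the kernel above expresses.
static void segment_gemm_ref(
    float* C0, float* C1, const float* A,
    const float* B0, const float* B1, const float* Bs0, const float* Bs1,
    int64_t M, int64_t N0, int64_t N1, int64_t K, int64_t bN, int64_t bK) {
  gemm_blockscale_ref(C0, A, B0, Bs0, M, N0, K, bN, bK);
  gemm_blockscale_ref(C1, A, B1, Bs1, M, N1, K, bN, bK);
}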
@@ -321,6 +400,15 @@ extern at::Tensor int8_scaled_mm_with_quant(
 extern void
 bmm_cpu(at::Tensor& out, at::Tensor& mat1, at::Tensor& mat2, bool is_vnni, const std::optional<at::Tensor>& scale);
 
+extern at::Tensor fp8_scaled_mm_cpu(
+    at::Tensor& mat1,
+    at::Tensor& mat2,
+    at::Tensor& scales2,
+    std::vector<int64_t> block_size,
+    const std::optional<at::Tensor>& bias,
+    at::ScalarType out_dtype,
+    bool is_vnni);
+
 // NB: shapes in DeepSeek R1
 //
 // hidden_states : [num_seqs, hidden_size] [1, 7168]
@@ -343,10 +431,12 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope(
     at::Tensor& cos_sin_cache,
     double eps,
     bool use_int8_w8a8,
+    bool use_fp8_w8a16,
     std::optional<at::Tensor> q_a_proj_scale,
     std::optional<at::Tensor> q_b_proj_scale,
     std::optional<at::Tensor> kv_a_proj_scale,
-    bool is_vnni) {
+    bool is_vnni,
+    std::optional<std::vector<int64_t>> block_size) {
   RECORD_FUNCTION(
       "sgl-kernel::qkv_proj_with_rope",
       std::vector<c10::IValue>({hidden_states, q_a_proj_weight, q_b_proj_weight, kv_a_proj_weight, w_kc}));
@@ -394,7 +484,13 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope(
     TORCH_CHECK(q_b_proj_scale.has_value(), "missing q_b_proj_scale for int8 w8a8.");
     TORCH_CHECK(kv_a_proj_scale.has_value(), "missing kv_a_proj_scale for int8 w8a8.");
   }
-
+  if (use_fp8_w8a16) {
+    TORCH_CHECK(q_a_proj_scale.has_value(), "missing q_a_proj_scale for fp8 w8a16.");
+    TORCH_CHECK(q_b_proj_scale.has_value(), "missing q_b_proj_scale for fp8 w8a16.");
+    TORCH_CHECK(kv_a_proj_scale.has_value(), "missing kv_a_proj_scale for fp8 w8a16.");
+    TORCH_CHECK(block_size.has_value(), "missing block_size for fp8 w8a16.");
+    TORCH_CHECK(block_size.value().size() == 2, "block_size should be 2D for fp8 w8a16.");
+  }
   // outputs and temp buffer
   const auto options = hidden_states.options();
   auto q_input = at::empty({num_seqs, num_heads, kv_lora_rank + qk_rope_head_dim}, options);
@@ -436,6 +532,29 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope(
         q_lora_rank,
         kv_lora_rank + qk_rope_head_dim,
         hidden_size);
+  } else if (use_fp8_w8a16) {
+    int64_t block_size_N = block_size.value()[0];
+    int64_t block_size_K = block_size.value()[1];
+    auto q_a_proj_s = q_a_proj_scale.value();
+    auto kv_a_proj_s = kv_a_proj_scale.value();
+    CHECK_EQ(q_a_proj_s.size(0), div_up(q_lora_rank, block_size_N));
+    CHECK_EQ(q_a_proj_s.size(1), div_up(hidden_size, block_size_K));
+    CHECK_EQ(kv_a_proj_s.size(0), div_up(kv_lora_rank + qk_rope_head_dim, block_size_N));
+    CHECK_EQ(kv_a_proj_s.size(1), div_up(hidden_size, block_size_K));
+    segment_gemm_kernel_impl<scalar_t>(
+        qa.data_ptr<scalar_t>(),
+        k_input.data_ptr<scalar_t>(),
+        hidden_states.data_ptr<scalar_t>(),
+        q_a_proj_weight.data_ptr<at::Float8_e4m3fn>(),
+        kv_a_proj_weight.data_ptr<at::Float8_e4m3fn>(),
+        q_a_proj_s.data_ptr<float>(),
+        kv_a_proj_s.data_ptr<float>(),
+        num_seqs,
+        q_lora_rank,
+        kv_lora_rank + qk_rope_head_dim,
+        hidden_size,
+        block_size_N,
+        block_size_K);
   } else {
     segment_gemm_kernel_impl<scalar_t>(
         qa.data_ptr<scalar_t>(),
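As a quick sanity check of the CHECK_EQ shape expectations in this hunk: with the DeepSeek R1 shape noted earlier in this file (hidden_size = 7168) and, purely as assumptions for illustration, q_lora_rank = 1536, kv_lora_rank = 512, qk_rope_head_dim = 64, and a [128, 128] block_size, the scale tensors come out as q_a_proj_scale: [12, 56] and kv_a_proj_scale: [5, 56]. The compile-time checks below just restate that arithmetic; div_up_ck and every constant except 7168 are illustrative, not taken from the PR.

// Worked example only; dimensions other than hidden_size = 7168 are assumptions.
#include <cstdint>

constexpr int64_t div_up_ck(int64_t a, int64_t b) { return (a + b - 1) / b; }

static_assert(div_up_ck(1536, 128) == 12, "q_a_proj_scale.size(0) for q_lora_rank = 1536");
static_assert(div_up_ck(7168, 128) == 56, "q_a_proj_scale.size(1) for hidden_size = 7168");
static_assert(div_up_ck(512 + 64, 128) == 5, "kv_a_proj_scale.size(0) for kv_lora_rank + qk_rope_head_dim = 576");
static_assert(div_up_ck(7168, 128) == 56, "kv_a_proj_scale.size(1) for hidden_size = 7168");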
@@ -469,6 +588,9 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope(
   std::optional<at::Tensor> bias;
   if (use_int8_w8a8) {
     qb = int8_scaled_mm_with_quant(qa, q_b_proj_weight, q_b_proj_scale.value(), bias, at::kBFloat16, is_vnni);
+  } else if (use_fp8_w8a16) {
+    qb = fp8_scaled_mm_cpu(
+        qa, q_b_proj_weight, q_b_proj_scale.value(), block_size.value(), bias, at::kBFloat16, is_vnni);
   } else {
     qb = weight_packed_linear(qa, q_b_proj_weight, bias, is_vnni);
   }