@@ -38,10 +38,16 @@ namespace MARLIN_NAMESPACE_NAME {

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800

- __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
-                                     int const* __restrict__ perm_int_ptr,
-                                     int4* __restrict__ out_int4_ptr, int size_m,
-                                     int size_k, int block_rows) {}
+ template <int moe_block_size>
+ __global__ void permute_cols_kernel(
+     int4 const* __restrict__ a_int4_ptr,
+     int const* __restrict__ perm_int_ptr,
+     int4* __restrict__ out_int4_ptr,
+     const int32_t* __restrict__ sorted_token_ids_ptr,
+     const int32_t* __restrict__ expert_ids_ptr,
+     const int32_t* __restrict__ num_tokens_past_padded_ptr,
+     int size_m, int size_k, int top_k) {};
+

template <typename scalar_t,  // compute dtype, half or nv_float16
          const vllm::ScalarTypeId w_type_id,  // weight ScalarType id
@@ -54,6 +60,7 @@ template <typename scalar_t, // compute dtype, half or nv_float16
          const int stages,  // number of stages for the async global->shared
                             // fetch pipeline
          const bool has_act_order,  // whether act_order is enabled
+         const bool has_zp,         // whether zero-points are enabled
          const int group_blocks = -1,  // number of consecutive 16x16 blocks
                                        // with a separate quantization scale
          const bool is_zp_float  // is zero point of float16 type?
@@ -65,12 +72,22 @@ __global__ void Marlin(
    int4* __restrict__ C_tmp,             // fp32 tmp output buffer (for reduce)
    const int4* __restrict__ scales_ptr,  // fp16 quantization scales of shape
                                          // (k/groupsize)xn
+   const int4* __restrict__ zp_ptr,      // 4bit packed zero-points of shape
+                                         // (k/groupsize)x(n/pack_factor)
    const int* __restrict__ g_idx,        // int32 group indices of shape k
+   const int32_t* __restrict__ sorted_token_ids_ptr,        // moe sorted_ids
+   const int32_t* __restrict__ expert_ids_ptr,              // moe expert ids
+   const int32_t* __restrict__ num_tokens_past_padded_ptr,  // moe num tokens
+   const float* __restrict__ topk_weights_ptr,              // moe top weights
+   int top_k,              // num of experts per token
+   bool mul_topk_weights,  // mul topk weights or not
+   bool is_ep,             // expert parallelism
    int num_groups,         // number of scale groups per output channel
    int prob_m,             // batch dimension m
    int prob_n,             // output dimension n
    int prob_k,             // reduction dimension k
    int* locks,             // extra global storage for barrier synchronization
+   bool use_atomic_add,    // whether to use atomic add to reduce
    bool use_fp32_reduce    // whether to use fp32 global reduce
) {}

@@ -455,27 +472,47 @@ __device__ inline void barrier_release(int* lock, bool reset = false) {

// For a given "a" of size [M,K] performs a permutation of the K columns based
// on the given "perm" indices.
- __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
-                                     int const* __restrict__ perm_int_ptr,
-                                     int4* __restrict__ out_int4_ptr, int size_m,
-                                     int size_k, int block_rows) {
-   int start_row = block_rows * blockIdx.x;
-   int finish_row = start_row + block_rows;
-   if (finish_row > size_m) {
-     finish_row = size_m;
-   }
-   int cur_block_rows = finish_row - start_row;
+ template <int moe_block_size>
+ __global__ void permute_cols_kernel(
+     int4 const* __restrict__ a_int4_ptr,
+     int const* __restrict__ perm_int_ptr,
+     int4* __restrict__ out_int4_ptr,
+     const int32_t* __restrict__ sorted_token_ids_ptr,
+     const int32_t* __restrict__ expert_ids_ptr,
+     const int32_t* __restrict__ num_tokens_past_padded_ptr,
+     int size_m, int size_k, int top_k) {

+   int num_tokens_past_padded = num_tokens_past_padded_ptr[0];
+   int num_moe_blocks = div_ceil(num_tokens_past_padded, moe_block_size);
+   int32_t block_sorted_ids[moe_block_size];
+   int block_num_valid_tokens = 0;
+   int64_t old_expert_id = 0;
+   int64_t expert_id = 0;
  int row_stride = size_k * sizeof(half) / 16;

+   auto read_moe_block_data = [&](int block_id) {
+     block_num_valid_tokens = moe_block_size;
+     int4* tmp_block_sorted_ids = reinterpret_cast<int4*>(block_sorted_ids);
+     for (int i = 0; i < moe_block_size / 4; i++) {
+       tmp_block_sorted_ids[i] = ((int4*)sorted_token_ids_ptr)[block_id * moe_block_size / 4 + i];
+     }
+     for (int i = 0; i < moe_block_size; i++) {
+       if (block_sorted_ids[i] >= size_m * top_k) {
+         block_num_valid_tokens = i;
+         break;
+       };
+     }
+   };
+
  auto permute_row = [&](int row) {
    int iters = size_k / default_threads;
    int rest = size_k % default_threads;

-     int offset = row * row_stride;
+     int in_offset = (row / top_k) * row_stride;
+     int out_offset = row * row_stride;

-     half const* a_row_half = reinterpret_cast<half const*>(a_int4_ptr + offset);
-     half* out_half = reinterpret_cast<half*>(out_int4_ptr + offset);
+     half const* a_row_half = reinterpret_cast<half const*>(a_int4_ptr + in_offset);
+     half* out_half = reinterpret_cast<half*>(out_int4_ptr + out_offset);

    int base_k = 0;

@@ -498,11 +535,16 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
    }
  };

-   for (int i = 0; i < cur_block_rows; i++) {
-     int cur_row = start_row + i;
-     if (cur_row < size_m) {
-       permute_row(cur_row);
-     }
+   for (int index = blockIdx.x; index < num_moe_blocks; index += gridDim.x) {
+     old_expert_id = expert_id;
+     int tmp_expert_id = expert_ids_ptr[index];
+     if (tmp_expert_id == -1) continue;
+     expert_id = tmp_expert_id;
+     perm_int_ptr += (expert_id - old_expert_id) * size_k;
+     read_moe_block_data(index);
+
+     for (int i = 0; i < block_num_valid_tokens; i++)
+       permute_row(block_sorted_ids[i]);
  }
}

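A note on the index math in the rewritten kernel above: each entry of sorted_token_ids names a destination row in the top_k-expanded output buffer, the source row in the original [size_m, size_k] activations is that id divided by top_k, and the columns are gathered through the current expert's slice of perm. The host-side C++ sketch below only illustrates that mapping; the function and variable names are invented for the example, and the gather direction (output column k reads input column perm[k]) is an assumption, not taken from this hunk.

// Host-side sketch (illustrative only) of the MoE-aware row/column mapping:
// output row = sorted token id, input row = sorted token id / top_k.
#include <cstdio>
#include <vector>

void permute_block(const std::vector<float>& in, std::vector<float>& out,
                   const std::vector<int>& perm, const std::vector<int>& sorted_ids,
                   int size_k, int top_k) {
  for (int sorted_id : sorted_ids) {
    int in_row = sorted_id / top_k;  // source token in the original activations
    int out_row = sorted_id;         // destination row in the top_k-expanded buffer
    for (int k = 0; k < size_k; ++k)
      out[out_row * size_k + k] = in[in_row * size_k + perm[k]];
  }
}

int main() {
  int size_m = 2, size_k = 4, top_k = 2;
  std::vector<float> in = {0, 1, 2, 3, 10, 11, 12, 13};  // two tokens
  std::vector<float> out(size_m * top_k * size_k, 0.f);
  std::vector<int> perm = {3, 2, 1, 0};        // one expert's column order
  std::vector<int> sorted_ids = {0, 1, 2, 3};  // one full moe block
  permute_block(in, out, perm, sorted_ids, size_k, top_k);
  for (int r = 0; r < size_m * top_k; ++r) {
    for (int k = 0; k < size_k; ++k) std::printf("%g ", out[r * size_k + k]);
    std::printf("\n");
  }
  return 0;
}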
@@ -570,7 +612,10 @@ __global__ void Marlin(

  constexpr int pack_factor = 32 / w_type.size_bits();
  constexpr int moe_block_size = 16 * thread_m_blocks;
-   constexpr int group_size = 16 * group_blocks;
+   const int group_size = (!has_act_order && group_blocks == -1) ?
+       prob_k : 16 * group_blocks;
+   const int zp_row_stride = is_zp_float ?
+       prob_k / group_size / 8 : prob_k / group_size / (pack_factor * 4);

  // parallel: num valid moe blocks
  int num_tokens_past_padded = num_tokens_past_padded_ptr[0];
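A quick worked example for the two derived values above, assuming 4-bit weights (pack_factor = 32 / 4 = 8), prob_k = 4096 and group_blocks = 8: group_size = 16 * 8 = 128, giving 4096 / 128 = 32 groups along k; with packed integer zero-points zp_row_stride = 32 / (8 * 4) = 1 int4 per output column, while with fp16 zero-points (is_zp_float) it would be 32 / 8 = 4. When group_blocks == -1 and act_order is off, group_size falls back to prob_k, i.e. a single channel-wise group.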
@@ -657,8 +702,9 @@ __global__ void Marlin(
  // when move to next moe block, find the next block_id and expert_id
  // and then read moe block data
  auto update_next_moe_block_data = [&]() {
-     old_expert_id = expert_id;
+     if (par_id >= parallel) return;

+     old_expert_id = expert_id;
    if (num_invalid_blocks > 0) {
      int skip_count = block_id == -1 ? par_id : 0;
      block_id++;
@@ -679,7 +725,9 @@ __global__ void Marlin(

    B_expert_off = expert_id * prob_n * prob_k / (pack_factor * 4);
    scales_ptr += (expert_id - old_expert_id) * prob_n * prob_k / group_size / 8;
-     zp_ptr += (expert_id - old_expert_id) * prob_n * prob_k / group_size / (pack_factor * 4);
+     if constexpr (has_zp) {
+       zp_ptr += (expert_id - old_expert_id) * prob_n * zp_row_stride;
+     }

    read_moe_block_data(block_id);
  };
@@ -701,12 +749,13 @@ __global__ void Marlin(
    while (remaining_ntiles_global > 0) {
      int skip_count = block_id_write == -1 ?
          (num_tiles_write_zero * blockIdx.x) / n_tiles : 0;
+     block_id_write++;
      for (int i = block_id_write; i < num_tokens_past_padded / moe_block_size; i++) {
-       if (expert_ids_ptr[i] != -1) {
+       if (expert_ids_ptr[i] == -1) {
          if (skip_count == 0) {
            block_id_write = i;
            break;
-         };
+         }
          skip_count--;
        };
      }
@@ -725,18 +774,15 @@ __global__ void Marlin(
      int num_int4s = moe_block_size * stride_n;
      int num_int4s_per_thread = div_ceil(num_int4s, threads);

-     for (int i = 0; i < num_int4s_per_thread; i++) {
-       int index = num_int4s_per_thread * threadIdx.x + i;
-       if (index < num_int4s) break;
-
-       int row = num_int4s / stride_n;
+     for (int index = threadIdx.x; index < num_int4s; index += threads) {
+       int row = index / stride_n;
+       if (row >= block_num_valid_tokens) break;
        int sorted_row = block_sorted_ids[row];
-       int col = num_int4s % stride_n;
+       int col = index % stride_n;
        int true_index = sorted_row * global_stride_n + off_stride_n + col;
        C[true_index] = {0, 0, 0, 0};
      }

-     block_id_write++;
      ntile_id = 0;
      remaining_ntiles_global -= remaining_ntiles_in_block;
    }
@@ -2305,10 +2351,19 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,

  if (has_act_order) {
    // Permute A columns
-   int block_rows = div_ceil(prob_m, blocks);
-   permute_cols_kernel<<<blocks, default_threads, 0, stream>>>(
-       A_ptr, perm_ptr, a_tmp_ptr, prob_m, prob_k, block_rows);
+   auto kernel = permute_cols_kernel<16>;
+   if (moe_block_size == 16) {}
+   else if (moe_block_size == 32) kernel = permute_cols_kernel<32>;
+   else if (moe_block_size == 48) kernel = permute_cols_kernel<48>;
+   else if (moe_block_size == 64) kernel = permute_cols_kernel<64>;
+   else TORCH_CHECK(false, "unsupported moe_block_size ", moe_block_size);
+
+   kernel<<<blocks, default_threads, 0, stream>>>(
+       A_ptr, perm_ptr, a_tmp_ptr, sorted_token_ids_ptr,
+       expert_ids_ptr, num_tokens_past_padded_ptr, prob_m, prob_k, top_k);
    A_ptr = a_tmp_ptr;
+   prob_m = prob_m * top_k;
+   top_k = 1;
  }

  // If we have a full K, then we can run the non-act-order version of Marlin
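The dispatch above picks a compile-time instantiation because the rewritten permute_cols_kernel keeps int32_t block_sorted_ids[moe_block_size] as a local array, so moe_block_size has to be a template constant rather than a runtime argument. A minimal sketch of the same select-an-instantiation pattern, with a hypothetical run<N> standing in for the real kernel:

// Illustrative only: runtime value -> compile-time template instantiation,
// mirroring the permute_cols_kernel<16/32/48/64> selection above.
#include <cstdio>
#include <stdexcept>

template <int N>
void run(int arg) { std::printf("moe_block_size=%d arg=%d\n", N, arg); }

void dispatch(int moe_block_size, int arg) {
  auto fn = run<16>;  // default instantiation
  if (moe_block_size == 16) { /* keep default */ }
  else if (moe_block_size == 32) fn = run<32>;
  else if (moe_block_size == 48) fn = run<48>;
  else if (moe_block_size == 64) fn = run<64>;
  else throw std::runtime_error("unsupported moe_block_size");
  fn(arg);  // invoke the chosen instantiation
}

int main() { dispatch(48, 7); }

The prob_m = prob_m * top_k and top_k = 1 lines reflect that, after permutation, a_tmp already holds top_k permuted rows per token, so the rest of the launch can treat them as an M * top_k batch.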
@@ -2320,23 +2375,23 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,

  if (false) {
  }
-   // GPTQ_CALL_IF(vllm::kU4B8, 16, 4, 256)
-   // GPTQ_CALL_IF(vllm::kU4B8, 8, 8, 256)
-   // GPTQ_CALL_IF(vllm::kU4B8, 8, 4, 128)
-   // GPTQ_CALL_IF(vllm::kU4B8, 4, 8, 128)
-   // GPTQ_CALL_IF(vllm::kU8B128, 16, 4, 256)
-   // GPTQ_CALL_IF(vllm::kU8B128, 8, 8, 256)
-   // GPTQ_CALL_IF(vllm::kU8B128, 8, 4, 128)
-   // GPTQ_CALL_IF(vllm::kU8B128, 4, 8, 128)
+   GPTQ_CALL_IF(vllm::kU4B8, 16, 4, 256)
+   GPTQ_CALL_IF(vllm::kU4B8, 8, 8, 256)
+   GPTQ_CALL_IF(vllm::kU4B8, 8, 4, 128)
+   GPTQ_CALL_IF(vllm::kU4B8, 4, 8, 128)
+   GPTQ_CALL_IF(vllm::kU8B128, 16, 4, 256)
+   GPTQ_CALL_IF(vllm::kU8B128, 8, 8, 256)
+   GPTQ_CALL_IF(vllm::kU8B128, 8, 4, 128)
+   GPTQ_CALL_IF(vllm::kU8B128, 4, 8, 128)

  AWQ_CALL_IF(vllm::kU4, 16, 4, 256)
  AWQ_CALL_IF(vllm::kU4, 8, 8, 256)
  AWQ_CALL_IF(vllm::kU4, 8, 4, 128)
  AWQ_CALL_IF(vllm::kU4, 4, 8, 128)
-   // AWQ_CALL_IF(vllm::kU8, 16, 4, 256)
-   // AWQ_CALL_IF(vllm::kU8, 8, 8, 256)
-   // AWQ_CALL_IF(vllm::kU8, 8, 4, 128)
-   // AWQ_CALL_IF(vllm::kU8, 4, 8, 128)
+   AWQ_CALL_IF(vllm::kU8, 16, 4, 256)
+   AWQ_CALL_IF(vllm::kU8, 8, 8, 256)
+   AWQ_CALL_IF(vllm::kU8, 8, 4, 128)
+   AWQ_CALL_IF(vllm::kU8, 4, 8, 128)

  // HQQ_CALL_IF(vllm::kU4, 16, 4, 256)
  // HQQ_CALL_IF(vllm::kU4, 8, 8, 256)
@@ -2470,7 +2525,6 @@ torch::Tensor moe_wna16_marlin_gemm(
                "Unexpected g_idx.size(-1) = ", g_idx.size(-1),
                " and perm.size(-1) = ", perm.size(-1),
                ", where size_k = ", size_k);
-
  } else {
    g_idx = torch::empty({0}, options);
    perm = torch::empty({0}, options);
@@ -2479,7 +2533,7 @@ torch::Tensor moe_wna16_marlin_gemm(
  bool has_act_order = g_idx.size(-1) > 0 && perm.size(-1) > 0;

  if (has_act_order) {
-   a_tmp = torch::empty({size_m, size_k}, options);
+   a_tmp = torch::empty({size_m * top_k, size_k}, options);
    if (is_k_full) {
      TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1");
      TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k,
@@ -2577,18 +2631,18 @@ torch::Tensor moe_wna16_marlin_gemm(
        is_k_full, has_zp, num_groups, group_size, dev,
        at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
        use_atomic_add, use_fp32_reduce, is_zp_float);
-   // } else if (a.scalar_type() == at::ScalarType::BFloat16) {
-   //   MARLIN_NAMESPACE_NAME::marlin_mm<nv_bfloat16>(
-   //       a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
-   //       c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(),
-   //       b_scales.data_ptr<at::BFloat16>(), b_zeros.data_ptr(), g_idx.data_ptr(),
-   //       perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(),
-   //       sorted_token_ids.data_ptr(), expert_ids.data_ptr(),
-   //       num_tokens_past_padded.data_ptr(), topk_weights.data_ptr(),
-   //       moe_block_size, top_k, mul_topk_weights, size_m, size_n, size_k,
-   //       workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp,
-   //       num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
-   //       thread_k, thread_n, sms, use_atomic_add, use_fp32_reduce, is_zp_float);
+ } else if (a.scalar_type() == at::ScalarType::BFloat16) {
+   MARLIN_NAMESPACE_NAME::marlin_mm<nv_bfloat16>(
+       a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(), c.data_ptr<at::BFloat16>(),
+       c_tmp.data_ptr<float>(), b_scales.data_ptr<at::BFloat16>(),
+       b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
+       a_tmp.data_ptr<at::BFloat16>(), sorted_token_ids.data_ptr(),
+       expert_ids.data_ptr(), num_tokens_past_padded.data_ptr(),
+       topk_weights.data_ptr(), moe_block_size, top_k, mul_topk_weights, is_ep,
+       size_m, size_n, size_k, workspace.data_ptr(), b_q_type, has_act_order,
+       is_k_full, has_zp, num_groups, group_size, dev,
+       at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
+       use_atomic_add, use_fp32_reduce, is_zp_float);
  } else {
    TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16");
  }