Commit e6896d3

support act order
Signed-off-by: Jinzhen Lin <[email protected]>
1 parent 8772450 commit e6896d3

1 file changed: +117 -63 lines changed
csrc/moe/marlin_moe_wna16.cu

@@ -38,10 +38,16 @@ namespace MARLIN_NAMESPACE_NAME {
 
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
 
-__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
-                                    int const* __restrict__ perm_int_ptr,
-                                    int4* __restrict__ out_int4_ptr, int size_m,
-                                    int size_k, int block_rows) {}
+template <int moe_block_size>
+__global__ void permute_cols_kernel(
+    int4 const* __restrict__ a_int4_ptr,
+    int const* __restrict__ perm_int_ptr,
+    int4* __restrict__ out_int4_ptr,
+    const int32_t* __restrict__ sorted_token_ids_ptr,
+    const int32_t* __restrict__ expert_ids_ptr,
+    const int32_t* __restrict__ num_tokens_past_padded_ptr,
+    int size_m, int size_k, int top_k) {};
+
 
 template <typename scalar_t,  // compute dtype, half or nv_float16
           const vllm::ScalarTypeId w_type_id,  // weight ScalarType id
@@ -54,6 +60,7 @@ template <typename scalar_t, // compute dtype, half or nv_float16
           const int stages,  // number of stages for the async global->shared
                              // fetch pipeline
           const bool has_act_order,  // whether act_order is enabled
+          const bool has_zp,         // whether zero-points are enabled
           const int group_blocks = -1,  // number of consecutive 16x16 blocks
                                         // with a separate quantization scale
           const bool is_zp_float  // is zero point of float16 type?
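
Note: making has_zp a compile-time template parameter (rather than a runtime flag) lets the zero-point pointer arithmetic later in this diff be guarded with `if constexpr`, so instantiations without zero-points carry no dead zp code. A minimal sketch of that pattern, with illustrative names not taken from this file:

    // Sketch only: compile-time flag eliminating zero-point handling.
    template <bool has_zp>
    __global__ void kernel_sketch(const int4* zp_ptr, int stride) {
      if constexpr (has_zp) {
        zp_ptr += stride;  // emitted only in the has_zp = true instantiation
      }
      // the has_zp = false instantiation never references zp_ptr
    }
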
@@ -65,12 +72,22 @@ __global__ void Marlin(
     int4* __restrict__ C_tmp,             // fp32 tmp output buffer (for reduce)
     const int4* __restrict__ scales_ptr,  // fp16 quantization scales of shape
                                           // (k/groupsize)xn
+    const int4* __restrict__ zp_ptr,      // 4bit packed zero-points of shape
+                                          // (k/groupsize)x(n/pack_factor)
     const int* __restrict__ g_idx,        // int32 group indices of shape k
+    const int32_t* __restrict__ sorted_token_ids_ptr,        // moe sorted_ids
+    const int32_t* __restrict__ expert_ids_ptr,              // moe expert ids
+    const int32_t* __restrict__ num_tokens_past_padded_ptr,  // moe num tokens
+    const float* __restrict__ topk_weights_ptr,  // moe top weights
+    int top_k,              // num of experts per token
+    bool mul_topk_weights,  // mul topk weights or not
+    bool is_ep,             // expert parallelism
     int num_groups,         // number of scale groups per output channel
     int prob_m,             // batch dimension m
     int prob_n,             // output dimension n
     int prob_k,             // reduction dimension k
     int* locks,             // extra global storage for barrier synchronization
+    bool use_atomic_add,    // whether to use atomic add to reduce
     bool use_fp32_reduce    // whether to use fp32 global reduce
 ) {}
 
@@ -455,27 +472,47 @@ __device__ inline void barrier_release(int* lock, bool reset = false) {
 
 // For a given "a" of size [M,K] performs a permutation of the K columns based
 // on the given "perm" indices.
-__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
-                                    int const* __restrict__ perm_int_ptr,
-                                    int4* __restrict__ out_int4_ptr, int size_m,
-                                    int size_k, int block_rows) {
-  int start_row = block_rows * blockIdx.x;
-  int finish_row = start_row + block_rows;
-  if (finish_row > size_m) {
-    finish_row = size_m;
-  }
-  int cur_block_rows = finish_row - start_row;
+template <int moe_block_size>
+__global__ void permute_cols_kernel(
+    int4 const* __restrict__ a_int4_ptr,
+    int const* __restrict__ perm_int_ptr,
+    int4* __restrict__ out_int4_ptr,
+    const int32_t* __restrict__ sorted_token_ids_ptr,
+    const int32_t* __restrict__ expert_ids_ptr,
+    const int32_t* __restrict__ num_tokens_past_padded_ptr,
+    int size_m, int size_k, int top_k) {
 
+  int num_tokens_past_padded = num_tokens_past_padded_ptr[0];
+  int num_moe_blocks = div_ceil(num_tokens_past_padded, moe_block_size);
+  int32_t block_sorted_ids[moe_block_size];
+  int block_num_valid_tokens = 0;
+  int64_t old_expert_id = 0;
+  int64_t expert_id = 0;
   int row_stride = size_k * sizeof(half) / 16;
 
+  auto read_moe_block_data = [&](int block_id) {
+    block_num_valid_tokens = moe_block_size;
+    int4* tmp_block_sorted_ids = reinterpret_cast<int4*>(block_sorted_ids);
+    for (int i = 0; i < moe_block_size / 4; i++) {
+      tmp_block_sorted_ids[i] = ((int4*)sorted_token_ids_ptr)[block_id * moe_block_size / 4 + i];
+    }
+    for (int i = 0; i < moe_block_size; i++) {
+      if (block_sorted_ids[i] >= size_m * top_k) {
+        block_num_valid_tokens = i;
+        break;
+      };
+    }
+  };
+
   auto permute_row = [&](int row) {
     int iters = size_k / default_threads;
     int rest = size_k % default_threads;
 
-    int offset = row * row_stride;
+    int in_offset = (row / top_k) * row_stride;
+    int out_offset = row * row_stride;
 
-    half const* a_row_half = reinterpret_cast<half const*>(a_int4_ptr + offset);
-    half* out_half = reinterpret_cast<half*>(out_int4_ptr + offset);
+    half const* a_row_half = reinterpret_cast<half const*>(a_int4_ptr + in_offset);
+    half* out_half = reinterpret_cast<half*>(out_int4_ptr + out_offset);
 
     int base_k = 0;
 
@@ -498,11 +535,16 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
     }
   };
 
-  for (int i = 0; i < cur_block_rows; i++) {
-    int cur_row = start_row + i;
-    if (cur_row < size_m) {
-      permute_row(cur_row);
-    }
+  for (int index = blockIdx.x; index < num_moe_blocks; index += gridDim.x) {
+    old_expert_id = expert_id;
+    int tmp_expert_id = expert_ids_ptr[index];
+    if (tmp_expert_id == -1) continue;
+    expert_id = tmp_expert_id;
+    perm_int_ptr += (expert_id - old_expert_id) * size_k;
+    read_moe_block_data(index);
+
+    for (int i = 0; i < block_num_valid_tokens; i++)
+      permute_row(block_sorted_ids[i]);
   }
 }
 
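
Note: the rewritten permute_cols_kernel walks MoE blocks (grid-stride over num_moe_blocks) instead of a fixed row range, and writes one output row per (token, expert) slot in sorted_token_ids, reading the shared input row `row / top_k` through the current expert's slice of perm. A hedged host-side sketch of the mapping it implements, with illustrative names (here `perm` means the current expert's permutation row):

    #include <cuda_fp16.h>
    // Reference semantics only; the real kernel vectorizes with int4 loads
    // and advances perm per expert as it moves between MoE blocks.
    void permute_cols_reference(const half* a, half* out, const int* perm,
                                const int32_t* sorted_token_ids, int num_ids,
                                int size_m, int size_k, int top_k) {
      for (int i = 0; i < num_ids; i++) {
        int row = sorted_token_ids[i];
        if (row >= size_m * top_k) continue;  // padding slot, skip
        const half* src = a + (int64_t)(row / top_k) * size_k;  // shared input row
        half* dst = out + (int64_t)row * size_k;  // replicated output row
        for (int k = 0; k < size_k; k++) dst[k] = src[perm[k]];
      }
    }
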
@@ -570,7 +612,10 @@ __global__ void Marlin(
 
   constexpr int pack_factor = 32 / w_type.size_bits();
   constexpr int moe_block_size = 16 * thread_m_blocks;
-  constexpr int group_size = 16 * group_blocks;
+  const int group_size = (!has_act_order && group_blocks == -1) ?
+      prob_k : 16 * group_blocks;
+  const int zp_row_stride = is_zp_float ?
+      prob_k / group_size / 8 : prob_k / group_size / (pack_factor * 4);
 
   // parallel: num valid moe blocks
   int num_tokens_past_padded = num_tokens_past_padded_ptr[0];
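
Note: group_size becomes a runtime value here because channelwise quantization (group_blocks == -1 without act_order) uses the whole K dimension as one group. A hedged arithmetic check of zp_row_stride, with illustrative shapes not taken from the commit:

    // 4-bit weights: pack_factor = 32 / 4 = 8, so one int4 (four 32-bit
    // words) packs pack_factor * 4 = 32 zero-points.
    // With prob_k = 4096 and group_size = 128 there are 32 groups, so
    //   packed zp:  4096 / 128 / (8 * 4) = 1 int4 per output column
    //   fp16 zp:    4096 / 128 / 8      = 4 int4s per column (8 halves/int4)
    // The per-expert advance is then prob_n * zp_row_stride int4 elements.
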
@@ -657,8 +702,9 @@ __global__ void Marlin(
   // when move to next moe block, find the next block_id and expert_id
   // and then read moe block data
   auto update_next_moe_block_data = [&]() {
-    old_expert_id = expert_id;
+    if (par_id >= parallel) return;
 
+    old_expert_id = expert_id;
     if (num_invalid_blocks > 0) {
       int skip_count = block_id == -1 ? par_id : 0;
       block_id++;
@@ -679,7 +725,9 @@ __global__ void Marlin(
 
     B_expert_off = expert_id * prob_n * prob_k / (pack_factor * 4);
     scales_ptr += (expert_id - old_expert_id) * prob_n * prob_k / group_size / 8;
-    zp_ptr += (expert_id - old_expert_id) * prob_n * prob_k / group_size / (pack_factor * 4);
+    if constexpr (has_zp) {
+      zp_ptr += (expert_id - old_expert_id) * prob_n * zp_row_stride;
+    }
 
     read_moe_block_data(block_id);
   };
@@ -701,12 +749,13 @@ __global__ void Marlin(
     while (remaining_ntiles_global > 0) {
       int skip_count = block_id_write == -1 ?
           (num_tiles_write_zero * blockIdx.x) / n_tiles : 0;
+      block_id_write++;
       for (int i = block_id_write; i < num_tokens_past_padded / moe_block_size; i++) {
-        if (expert_ids_ptr[i] != -1) {
+        if (expert_ids_ptr[i] == -1) {
           if (skip_count == 0) {
             block_id_write = i;
             break;
-          };
+          }
           skip_count--;
         };
       }
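
Note: moving block_id_write++ ahead of the search appears to fix an off-by-one: the first search now starts at index 0 (block_id_write starts at -1), and each later search resumes just past the previous match instead of re-matching it; the flipped comparison makes this write-zero path look for blocks with no assigned expert (expert_ids_ptr[i] == -1). An illustrative standalone helper, not from the file, showing the resume-after-last-hit pattern:

    // Find the next block whose expert id is -1, starting at `start`.
    int next_invalid_block(const int* expert_ids, int n, int start) {
      for (int i = start; i < n; i++)
        if (expert_ids[i] == -1) return i;
      return n;  // no further invalid block
    }
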
@@ -725,18 +774,15 @@ __global__ void Marlin(
       int num_int4s = moe_block_size * stride_n;
       int num_int4s_per_thread = div_ceil(num_int4s, threads);
 
-      for (int i = 0; i < num_int4s_per_thread; i++) {
-        int index = num_int4s_per_thread * threadIdx.x + i;
-        if (index < num_int4s) break;
-
-        int row = num_int4s / stride_n;
+      for (int index = threadIdx.x; index < num_int4s; index += threads) {
+        int row = index / stride_n;
+        if (row >= block_num_valid_tokens) break;
         int sorted_row = block_sorted_ids[row];
-        int col = num_int4s % stride_n;
+        int col = index % stride_n;
         int true_index = sorted_row * global_stride_n + off_stride_n + col;
         C[true_index] = {0, 0, 0, 0};
       }
 
-      block_id_write++;
       ntile_id = 0;
       remaining_ntiles_global -= remaining_ntiles_in_block;
     }
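
Note: the old zero-fill loop had an inverted bound check (`if (index < num_int4s) break;`) and derived row/col from the constant num_int4s rather than the loop index, so it never wrote the intended elements. The rewrite adopts the standard thread-strided pattern; a minimal standalone sketch of that pattern (illustrative kernel, not from the file):

    // Each thread clears every blockDim.x-th int4, covering num_int4s total.
    __global__ void fill_zero_sketch(int4* C, int num_int4s) {
      for (int index = threadIdx.x; index < num_int4s; index += blockDim.x) {
        C[index] = {0, 0, 0, 0};
      }
    }
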
@@ -2305,10 +2351,19 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
 
   if (has_act_order) {
     // Permute A columns
-    int block_rows = div_ceil(prob_m, blocks);
-    permute_cols_kernel<<<blocks, default_threads, 0, stream>>>(
-        A_ptr, perm_ptr, a_tmp_ptr, prob_m, prob_k, block_rows);
+    auto kernel = permute_cols_kernel<16>;
+    if (moe_block_size == 16) {}
+    else if (moe_block_size == 32) kernel = permute_cols_kernel<32>;
+    else if (moe_block_size == 48) kernel = permute_cols_kernel<48>;
+    else if (moe_block_size == 64) kernel = permute_cols_kernel<64>;
+    else TORCH_CHECK(false, "unsupported moe_block_size ", moe_block_size);
+
+    kernel<<<blocks, default_threads, 0, stream>>>(
+        A_ptr, perm_ptr, a_tmp_ptr, sorted_token_ids_ptr,
+        expert_ids_ptr, num_tokens_past_padded_ptr, prob_m, prob_k, top_k);
     A_ptr = a_tmp_ptr;
+    prob_m = prob_m * top_k;
+    top_k = 1;
   }
 
 // If we have a full K, then we can run the non-act-order version of Marlin
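
Note: after the permute step, a_tmp holds one row per (token, expert) pair (size_m * top_k rows), each already gathered through its expert's perm, so the main GEMM can treat every row as its own token: prob_m is scaled by top_k and top_k collapses to 1. Illustrative shapes (assumed, not from the file):

    // size_m = 2 tokens, top_k = 2 experts per token:
    //   before: A is [2, K]; sorted_token_ids addresses 4 logical rows 0..3
    //   after permute: a_tmp is [4, K]; row r = A[r / 2, :] permuted by the
    //                  expert owning slot r
    //   the GEMM then runs with prob_m = 4, top_k = 1 (row r -> output row r)
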
@@ -2320,23 +2375,23 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
 
   if (false) {
   }
-  // GPTQ_CALL_IF(vllm::kU4B8, 16, 4, 256)
-  // GPTQ_CALL_IF(vllm::kU4B8, 8, 8, 256)
-  // GPTQ_CALL_IF(vllm::kU4B8, 8, 4, 128)
-  // GPTQ_CALL_IF(vllm::kU4B8, 4, 8, 128)
-  // GPTQ_CALL_IF(vllm::kU8B128, 16, 4, 256)
-  // GPTQ_CALL_IF(vllm::kU8B128, 8, 8, 256)
-  // GPTQ_CALL_IF(vllm::kU8B128, 8, 4, 128)
-  // GPTQ_CALL_IF(vllm::kU8B128, 4, 8, 128)
+  GPTQ_CALL_IF(vllm::kU4B8, 16, 4, 256)
+  GPTQ_CALL_IF(vllm::kU4B8, 8, 8, 256)
+  GPTQ_CALL_IF(vllm::kU4B8, 8, 4, 128)
+  GPTQ_CALL_IF(vllm::kU4B8, 4, 8, 128)
+  GPTQ_CALL_IF(vllm::kU8B128, 16, 4, 256)
+  GPTQ_CALL_IF(vllm::kU8B128, 8, 8, 256)
+  GPTQ_CALL_IF(vllm::kU8B128, 8, 4, 128)
+  GPTQ_CALL_IF(vllm::kU8B128, 4, 8, 128)
 
   AWQ_CALL_IF(vllm::kU4, 16, 4, 256)
   AWQ_CALL_IF(vllm::kU4, 8, 8, 256)
   AWQ_CALL_IF(vllm::kU4, 8, 4, 128)
   AWQ_CALL_IF(vllm::kU4, 4, 8, 128)
-  // AWQ_CALL_IF(vllm::kU8, 16, 4, 256)
-  // AWQ_CALL_IF(vllm::kU8, 8, 8, 256)
-  // AWQ_CALL_IF(vllm::kU8, 8, 4, 128)
-  // AWQ_CALL_IF(vllm::kU8, 4, 8, 128)
+  AWQ_CALL_IF(vllm::kU8, 16, 4, 256)
+  AWQ_CALL_IF(vllm::kU8, 8, 8, 256)
+  AWQ_CALL_IF(vllm::kU8, 8, 4, 128)
+  AWQ_CALL_IF(vllm::kU8, 4, 8, 128)
 
   // HQQ_CALL_IF(vllm::kU4, 16, 4, 256)
   // HQQ_CALL_IF(vllm::kU4, 8, 8, 256)
@@ -2470,7 +2525,6 @@ torch::Tensor moe_wna16_marlin_gemm(
         "Unexpected g_idx.size(-1) = ", g_idx.size(-1),
         " and perm.size(-1) = ", perm.size(-1),
         ", where size_k = ", size_k);
-
   } else {
     g_idx = torch::empty({0}, options);
     perm = torch::empty({0}, options);
@@ -2479,7 +2533,7 @@ torch::Tensor moe_wna16_marlin_gemm(
   bool has_act_order = g_idx.size(-1) > 0 && perm.size(-1) > 0;
 
   if (has_act_order) {
-    a_tmp = torch::empty({size_m, size_k}, options);
+    a_tmp = torch::empty({size_m * top_k, size_k}, options);
     if (is_k_full) {
       TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1");
       TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k,
@@ -2577,18 +2631,18 @@ torch::Tensor moe_wna16_marlin_gemm(
         is_k_full, has_zp, num_groups, group_size, dev,
         at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
         use_atomic_add, use_fp32_reduce, is_zp_float);
-  // } else if (a.scalar_type() == at::ScalarType::BFloat16) {
-  //   MARLIN_NAMESPACE_NAME::marlin_mm<nv_bfloat16>(
-  //       a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
-  //       c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(),
-  //       b_scales.data_ptr<at::BFloat16>(), b_zeros.data_ptr(), g_idx.data_ptr(),
-  //       perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(),
-  //       sorted_token_ids.data_ptr(), expert_ids.data_ptr(),
-  //       num_tokens_past_padded.data_ptr(), topk_weights.data_ptr(),
-  //       moe_block_size, top_k, mul_topk_weights, size_m, size_n, size_k,
-  //       workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp,
-  //       num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
-  //       thread_k, thread_n, sms, use_atomic_add, use_fp32_reduce, is_zp_float);
+  } else if (a.scalar_type() == at::ScalarType::BFloat16) {
+    MARLIN_NAMESPACE_NAME::marlin_mm<nv_bfloat16>(
+        a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(), c.data_ptr<at::BFloat16>(),
+        c_tmp.data_ptr<float>(), b_scales.data_ptr<at::BFloat16>(),
+        b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
+        a_tmp.data_ptr<at::BFloat16>(), sorted_token_ids.data_ptr(),
+        expert_ids.data_ptr(), num_tokens_past_padded.data_ptr(),
+        topk_weights.data_ptr(), moe_block_size, top_k, mul_topk_weights, is_ep,
+        size_m, size_n, size_k, workspace.data_ptr(), b_q_type, has_act_order,
+        is_k_full, has_zp, num_groups, group_size, dev,
+        at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
+        use_atomic_add, use_fp32_reduce, is_zp_float);
   } else {
     TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16");
   }
