Skip to content

Commit 4fe6701

Browse files
varun-sundar-rabindranathvarun sundar rabindranath
authored andcommitted
[Kernel] Add expert_map support to Cutlass FP8 MOE (vllm-project#16861)
Signed-off-by: varun sundar rabindranath <[email protected]> Co-authored-by: varun sundar rabindranath <[email protected]>
1 parent c2e6ce6 commit 4fe6701

File tree

5 files changed

+333
-173
lines changed

5 files changed

+333
-173
lines changed

csrc/quantization/cutlass_w8a8/moe/moe_data.cu

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,14 +46,26 @@ __global__ void compute_expert_offsets(
4646
}
4747

4848
__global__ void compute_arg_sorts(const int* __restrict__ topk_ids,
49+
const int32_t* __restrict__ expert_offsets,
4950
int32_t* input_permutation,
5051
int32_t* output_permutation,
5152
int32_t* atomic_buffer, const int topk_length,
5253
const int topk) {
53-
int expert_id = blockIdx.x;
54+
int const blk_expert_id = blockIdx.x;
55+
int const num_experts = gridDim.x;
56+
int32_t const num_tokens = expert_offsets[num_experts];
5457

5558
for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) {
56-
if (topk_ids[i] == expert_id) {
59+
int const expert_id = topk_ids[i];
60+
if (expert_id == -1 && blockIdx.x == 0) {
61+
// output_permutation is used to re-order the moe outputs. It is
62+
// used as c2 = c2[c_map], where c2 is a torch.tensor that is the
63+
// output of the cutlass kernels and c_map is the output_permutation.
64+
// c2 is initialized to zeros, therefore by setting the output_permutation
65+
// to num_tokens, we are guaranteed to fill the moe outputs to zero
66+
// for "invalid" topk_ids.
67+
output_permutation[i] = num_tokens;
68+
} else if (expert_id == blk_expert_id) {
5769
int start = atomicAdd(&atomic_buffer[expert_id], 1);
5870
input_permutation[start] = i / topk;
5971
output_permutation[i] = start;
@@ -83,6 +95,7 @@ void get_cutlass_moe_mm_data_caller(
8395
static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts);
8496
compute_arg_sorts<<<num_experts, num_threads, 0, stream>>>(
8597
static_cast<const int32_t*>(topk_ids.data_ptr()),
98+
static_cast<const int32_t*>(expert_offsets.data_ptr()),
8699
static_cast<int32_t*>(input_permutation.data_ptr()),
87100
static_cast<int32_t*>(output_permutation.data_ptr()),
88101
static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(),

0 commit comments

Comments
 (0)