@@ -46,14 +46,26 @@ __global__ void compute_expert_offsets(
 }
 
 __global__ void compute_arg_sorts(const int* __restrict__ topk_ids,
+                                  const int32_t* __restrict__ expert_offsets,
                                   int32_t* input_permutation,
                                   int32_t* output_permutation,
                                   int32_t* atomic_buffer, const int topk_length,
                                   const int topk) {
-  int expert_id = blockIdx.x;
+  int const blk_expert_id = blockIdx.x;
+  int const num_experts = gridDim.x;
+  int32_t const num_tokens = expert_offsets[num_experts];
 
   for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) {
-    if (topk_ids[i] == expert_id) {
+    int const expert_id = topk_ids[i];
+    if (expert_id == -1 && blockIdx.x == 0) {
+      // output_permutation is used to re-order the moe outputs. It is
+      // used as c2 = c2[c_map], where c2 is a torch.tensor that is the
+      // output of the cutlass kernels and c_map is the output_permutation.
+      // c2 is initialized to zeros; therefore, by setting the
+      // output_permutation entry to num_tokens, we are guaranteed to fill
+      // the moe outputs with zeros for "invalid" topk_ids.
+      output_permutation[i] = num_tokens;
+    } else if (expert_id == blk_expert_id) {
       int start = atomicAdd(&atomic_buffer[expert_id], 1);
       input_permutation[start] = i / topk;
       output_permutation[i] = start;
@@ -83,6 +95,7 @@ void get_cutlass_moe_mm_data_caller(
       static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts);
   compute_arg_sorts<<<num_experts, num_threads, 0, stream>>>(
       static_cast<const int32_t*>(topk_ids.data_ptr()),
+      static_cast<const int32_t*>(expert_offsets.data_ptr()),
      static_cast<int32_t*>(input_permutation.data_ptr()),
       static_cast<int32_t*>(output_permutation.data_ptr()),
       static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(),
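The new comment in the first hunk describes the zero-fill trick at the torch level: output_permutation entries for invalid topk_ids point one past the last real row of c2, and the subsequent gather maps them to zeros. Below is a minimal PyTorch sketch of that gather, assuming c2 carries one spare all-zero row at index num_tokens, as the comment implies; the names c2 and c_map come from the comment, while the shapes and values are purely illustrative, not the caller's actual sizing.

import torch

num_tokens, hidden = 4, 8
# Assumed layout: one spare row at index num_tokens stays zero because c2
# starts zeroed and the kernels only write the first num_tokens rows.
c2 = torch.zeros(num_tokens + 1, hidden)
c2[:num_tokens] = torch.randn(num_tokens, hidden)  # stand-in for kernel output

# c_map plays the role of output_permutation: the slot for an invalid
# topk_id was set to num_tokens by compute_arg_sorts.
c_map = torch.tensor([2, num_tokens, 0, 1])
out = c2[c_map]                                  # the comment's c2 = c2[c_map]
assert torch.equal(out[1], torch.zeros(hidden))  # invalid slot gathers zeros

The appeal of this scheme is that the gather itself performs the masking: invalid entries are routed to a sacrificial zero row instead of being handled by a branch in the output path.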