|
3 | 3 |
|
4 | 4 | #include "fused_moe.hpp"
|
5 | 5 |
|
6 |
| -float fused_moe( |
7 |
| - fused_moe_traits t, |
8 |
| - fused_moe_args a, |
9 |
| - const ck_tile::stream_config& s) { |
10 |
| - auto s_sub = ck_tile::stream_config{s.stream_id_, false, s.log_level_, 0, 1}; |
| 6 | +float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_config& s) |
| 7 | +{ |
| 8 | + auto s_sub = ck_tile::stream_config{s.stream_id_, false, s.log_level_, 0, 1}; |
11 | 9 |
|
12 |
| - auto o_data_bytes = [&]() { |
13 |
| - if (t.prec_o == "fp32") |
14 |
| - return 4; |
15 |
| - else if (t.prec_o == "fp16" || t.prec_o == "bf16") |
16 |
| - return 2; |
17 |
| - else if (t.prec_o == "int8" || t.prec_o == "fp8") |
18 |
| - return 1; |
19 |
| - return 1; |
20 |
| - }(); |
| 10 | + auto o_data_bytes = [&]() { |
| 11 | + if(t.prec_o == "fp32") |
| 12 | + return 4; |
| 13 | + else if(t.prec_o == "fp16" || t.prec_o == "bf16") |
| 14 | + return 2; |
| 15 | + else if(t.prec_o == "int8" || t.prec_o == "fp8") |
| 16 | + return 1; |
| 17 | + return 1; |
| 18 | + }(); |
21 | 19 |
|
22 |
| - auto t0 = fused_moesorting_trait{"int32", "fp32"}; |
23 |
| - auto a0 = fused_moesorting_args{ |
24 |
| - a.topk_ids_ptr, // const void* p_topk_ids; |
25 |
| - a.topk_weight_ptr, // const void* p_weights; |
26 |
| - a.sorted_token_ids_ptr, // void* p_sorted_token_ids; |
27 |
| - a.sorted_weight_ptr, // void* p_sorted_weights; |
28 |
| - a.sorted_expert_ids_ptr, // void* p_sorted_expert_ids; |
29 |
| - a.num_sorted_tiles_ptr, // void* p_total_tokens_post_pad; |
30 |
| - a.o_ptr, // void* p_moe_buf; |
31 |
| - a.num_tokens, // index_t tokens; |
32 |
| - a.block_m, // index_t unit_size; |
33 |
| - a.num_experts, // index_t num_experts; |
34 |
| - a.topk, // index_t topk; |
35 |
| - a.num_tokens * a.stride_token * o_data_bytes // index_t moe_buf_bytes; |
36 |
| - }; |
| 20 | + auto t0 = fused_moesorting_trait{"int32", "fp32", t.local_expert_masking}; |
| 21 | + auto a0 = fused_moesorting_args{ |
| 22 | + a.topk_ids_ptr, // const void* p_topk_ids; |
| 23 | + a.topk_weight_ptr, // const void* p_weights; |
| 24 | + a.local_expert_mask_ptr, // const void* p_local_expert_mask; |
| 25 | + a.sorted_token_ids_ptr, // void* p_sorted_token_ids; |
| 26 | + a.sorted_weight_ptr, // void* p_sorted_weights; |
| 27 | + a.sorted_expert_ids_ptr, // void* p_sorted_expert_ids; |
| 28 | + a.num_sorted_tiles_ptr, // void* p_total_tokens_post_pad; |
| 29 | + a.o_ptr, // void* p_moe_buf; |
| 30 | + a.num_tokens, // index_t tokens; |
| 31 | + a.block_m, // index_t unit_size; |
| 32 | + a.num_experts, // index_t num_experts; |
| 33 | + a.topk, // index_t topk; |
| 34 | + a.num_tokens * a.stride_token * o_data_bytes // index_t moe_buf_bytes; |
| 35 | + }; |
37 | 36 |
|
38 |
| - auto t1 = fused_moegemm_traits{ |
39 |
| - t.prec_i, |
40 |
| - t.prec_w, |
41 |
| - t.prec_o, |
42 |
| - t.prec_st, |
43 |
| - t.prec_sw, |
44 |
| - t.prec_sq, |
45 |
| - t.prec_kw, |
46 |
| - t.block_m, |
47 |
| - t.activation, |
48 |
| - t.gate_only, |
49 |
| - t.fused_quant}; |
50 |
| - auto a1 = fused_moegemm_args{ |
51 |
| - a.a_ptr, // const void* a_ptr; |
52 |
| - a.a_scale_ptr, // const void* a_scale_ptr; |
53 |
| - a.g_ptr, // const void* g_ptr; |
54 |
| - a.d_ptr, // const void* d_ptr; |
55 |
| - a.g_scale_ptr, // const void* g_scale_ptr; |
56 |
| - a.d_scale_ptr, // const void* d_scale_ptr; |
57 |
| - a.y_smooth_scale_ptr, // const void* y_smooth_scale_ptr; |
58 |
| - a.o_ptr, // void* o_ptr; |
59 |
| - a.sorted_token_ids_ptr, // const void* sorted_token_ids_ptr; |
60 |
| - a.sorted_weight_ptr, // const void* sorted_weight_ptr; |
61 |
| - a.sorted_expert_ids_ptr, // const void* sorted_expert_ids_ptr; |
62 |
| - a.num_sorted_tiles_ptr, // const void* num_sorted_tiles_ptr; |
63 |
| - a.hidden_size, // index_t hidden_size; |
64 |
| - a.intermediate_size, // index_t intermediate_size; |
65 |
| - a.num_tokens, // index_t num_tokens; |
66 |
| - a.num_experts, // index_t num_experts; |
67 |
| - a.topk, // index_t topk; |
68 |
| - a.stride_token // index_t stride_token; |
69 |
| - }; |
| 37 | + auto t1 = fused_moegemm_traits{t.prec_i, |
| 38 | + t.prec_w, |
| 39 | + t.prec_o, |
| 40 | + t.prec_st, |
| 41 | + t.prec_sw, |
| 42 | + t.prec_sq, |
| 43 | + t.prec_kw, |
| 44 | + t.block_m, |
| 45 | + t.activation, |
| 46 | + t.gate_only, |
| 47 | + t.fused_quant}; |
| 48 | + auto a1 = fused_moegemm_args{ |
| 49 | + a.a_ptr, // const void* a_ptr; |
| 50 | + a.a_scale_ptr, // const void* a_scale_ptr; |
| 51 | + a.g_ptr, // const void* g_ptr; |
| 52 | + a.d_ptr, // const void* d_ptr; |
| 53 | + a.g_scale_ptr, // const void* g_scale_ptr; |
| 54 | + a.d_scale_ptr, // const void* d_scale_ptr; |
| 55 | + a.y_smooth_scale_ptr, // const void* y_smooth_scale_ptr; |
| 56 | + a.o_ptr, // void* o_ptr; |
| 57 | + a.sorted_token_ids_ptr, // const void* sorted_token_ids_ptr; |
| 58 | + a.sorted_weight_ptr, // const void* sorted_weight_ptr; |
| 59 | + a.sorted_expert_ids_ptr, // const void* sorted_expert_ids_ptr; |
| 60 | + a.num_sorted_tiles_ptr, // const void* num_sorted_tiles_ptr; |
| 61 | + a.hidden_size, // index_t hidden_size; |
| 62 | + a.intermediate_size, // index_t intermediate_size; |
| 63 | + a.num_tokens, // index_t num_tokens; |
| 64 | + a.num_experts, // index_t num_experts; |
| 65 | + a.topk, // index_t topk; |
| 66 | + a.stride_token // index_t stride_token; |
| 67 | + }; |
70 | 68 |
|
71 |
| - float r0 = -1; |
72 |
| - float r1 = -1; |
| 69 | + float r0 = -1; |
| 70 | + float r1 = -1; |
73 | 71 |
|
74 |
| - float r = ck_tile::launch_kernel( |
75 |
| - s, |
76 |
| - [=, &r0](const ck_tile::stream_config&) { |
77 |
| - r0 = fused_moesorting(t0, a0, s_sub); |
78 |
| - }, |
79 |
| - [=, &r1](const ck_tile::stream_config&) { |
80 |
| - r1 = fused_moegemm(t1, a1, s_sub); |
81 |
| - }); |
| 72 | + float r = ck_tile::launch_kernel( |
| 73 | + s, |
| 74 | + [=, &r0](const ck_tile::stream_config&) { r0 = fused_moesorting(t0, a0, s_sub); }, |
| 75 | + [=, &r1](const ck_tile::stream_config&) { r1 = fused_moegemm(t1, a1, s_sub); }); |
82 | 76 |
|
83 |
| - // keep unsupported case return negative |
84 |
| - if (r0 < 0 || r1 < 0) |
85 |
| - return -1; |
| 77 | + // keep unsupported case return negative |
| 78 | + if(r0 < 0 || r1 < 0) |
| 79 | + return -1; |
86 | 80 |
|
87 |
| - return r; |
| 81 | + return r; |
88 | 82 | }
|
0 commit comments