
Commit 2650c39

alugorey authored and facebook-github-bot committed
Update ck (pytorch#782)
Summary:
X-link: pytorch#3701
Pull Request resolved: facebookresearch/FBGEMM#782

Updates the CK version and re-implements kernel generation.

cc albanD

X-link: pytorch/pytorch#144799

Reviewed By: jianyuh

Differential Revision: D68613917

Pulled By: xw285cornell

fbshipit-source-id: 0be7a88ef9e0245714b671d1c5cb23fc35ed4b7e
1 parent b64488a commit 2650c39


8 files changed: +198 −105 lines


fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fused_moe/fused_moe.hpp

Lines changed: 3 additions & 0 deletions
@@ -15,6 +15,7 @@ struct fused_moe_args {
   const void* d_scale_ptr; // [e, 1, k], down scale
   const void*
       y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input
+  const void* local_expert_mask_ptr; // [e], local_expert_mask_ptr for EP
   void* o_ptr; // [m, k], output token (no need to do zeroing)
 
   const void* topk_ids_ptr; // [tokens, topk]
@@ -50,6 +51,8 @@ struct fused_moe_traits {
   int activation; // 0:gelu, 1:silu
   int gate_only; // 0:g1u0, 1:g1u1
   int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
+
+  bool local_expert_masking; // if mask experts as local expert
 };
 
 float fused_moe(
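The two additions above are the whole expert-parallel (EP) surface of the C++ API: a per-expert mask pointer in fused_moe_args and a local_expert_masking switch in fused_moe_traits. A minimal caller-side sketch follows; only the field names come from fused_moe.hpp. The mask layout (one int32 flag per expert, nonzero meaning the expert is resident on this rank) and the even, contiguous expert-to-rank split are illustrative assumptions, and the buffer is built host-side only for brevity; like the other pointers in fused_moe_args it would need to live in device memory.

// Hedged sketch, not part of this commit: only the struct/field names are taken
// from fused_moe.hpp; mask semantics and the expert split are assumptions.
#include <cstdint>
#include <vector>
#include "fused_moe.hpp"

float run_moe_with_ep(fused_moe_traits t,
                      fused_moe_args a,
                      const ck_tile::stream_config& s,
                      int rank,
                      int world_size) {
    // Mark the experts assumed to be owned by this rank; all others are
    // masked out during token sorting.
    std::vector<int32_t> mask(a.num_experts, 0);
    const int per_rank = a.num_experts / world_size; // assumes an even split
    for (int e = rank * per_rank; e < (rank + 1) * per_rank; ++e)
        mask[e] = 1;

    t.local_expert_masking = true;          // new trait added in this commit
    a.local_expert_mask_ptr = mask.data();  // [e] mask; a device buffer in practice

    return fused_moe(t, a, s);
}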

fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fused_moe/fused_moe_kernel.hip

Lines changed: 1 addition & 0 deletions
@@ -100,6 +100,7 @@ at::Tensor fused_moe_impl(
       gate_up_scales.has_value() ? gate_up_scales->data_ptr() : nullptr,
       down_scales.has_value() ? down_scales->data_ptr() : nullptr,
       smooth_scales.has_value() ? smooth_scales->data_ptr() : nullptr,
+      nullptr,
       output.data_ptr(),
       topk_ids.data_ptr(),
       topk_weights.data_ptr(),

fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fused_moe/fused_moesorting.hpp

Lines changed: 1 addition & 0 deletions
@@ -10,6 +10,7 @@
 struct fused_moesorting_trait {
   std::string index_type;
   std::string weight_type; // currently always float
+  bool local_expert_masking; // if mask experts as local expert
 };
 
 struct fused_moesorting_args : public ck_tile::MoeSortingHostArgs {};

fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fused_moe/instances/fused_moe_api.hip

Lines changed: 69 additions & 75 deletions
@@ -3,86 +3,80 @@
 
 #include "fused_moe.hpp"
 
-float fused_moe(
-    fused_moe_traits t,
-    fused_moe_args a,
-    const ck_tile::stream_config& s) {
-  auto s_sub = ck_tile::stream_config{s.stream_id_, false, s.log_level_, 0, 1};
+float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_config& s)
+{
+    auto s_sub = ck_tile::stream_config{s.stream_id_, false, s.log_level_, 0, 1};
 
-  auto o_data_bytes = [&]() {
-    if (t.prec_o == "fp32")
-      return 4;
-    else if (t.prec_o == "fp16" || t.prec_o == "bf16")
-      return 2;
-    else if (t.prec_o == "int8" || t.prec_o == "fp8")
-      return 1;
-    return 1;
-  }();
+    auto o_data_bytes = [&]() {
+        if(t.prec_o == "fp32")
+            return 4;
+        else if(t.prec_o == "fp16" || t.prec_o == "bf16")
+            return 2;
+        else if(t.prec_o == "int8" || t.prec_o == "fp8")
+            return 1;
+        return 1;
+    }();
 
-  auto t0 = fused_moesorting_trait{"int32", "fp32"};
-  auto a0 = fused_moesorting_args{
-      a.topk_ids_ptr, // const void* p_topk_ids;
-      a.topk_weight_ptr, // const void* p_weights;
-      a.sorted_token_ids_ptr, // void* p_sorted_token_ids;
-      a.sorted_weight_ptr, // void* p_sorted_weights;
-      a.sorted_expert_ids_ptr, // void* p_sorted_expert_ids;
-      a.num_sorted_tiles_ptr, // void* p_total_tokens_post_pad;
-      a.o_ptr, // void* p_moe_buf;
-      a.num_tokens, // index_t tokens;
-      a.block_m, // index_t unit_size;
-      a.num_experts, // index_t num_experts;
-      a.topk, // index_t topk;
-      a.num_tokens * a.stride_token * o_data_bytes // index_t moe_buf_bytes;
-  };
+    auto t0 = fused_moesorting_trait{"int32", "fp32", t.local_expert_masking};
+    auto a0 = fused_moesorting_args{
+        a.topk_ids_ptr, // const void* p_topk_ids;
+        a.topk_weight_ptr, // const void* p_weights;
+        a.local_expert_mask_ptr, // const void* p_local_expert_mask;
+        a.sorted_token_ids_ptr, // void* p_sorted_token_ids;
+        a.sorted_weight_ptr, // void* p_sorted_weights;
+        a.sorted_expert_ids_ptr, // void* p_sorted_expert_ids;
+        a.num_sorted_tiles_ptr, // void* p_total_tokens_post_pad;
+        a.o_ptr, // void* p_moe_buf;
+        a.num_tokens, // index_t tokens;
+        a.block_m, // index_t unit_size;
+        a.num_experts, // index_t num_experts;
+        a.topk, // index_t topk;
+        a.num_tokens * a.stride_token * o_data_bytes // index_t moe_buf_bytes;
+    };
 
-  auto t1 = fused_moegemm_traits{
-      t.prec_i,
-      t.prec_w,
-      t.prec_o,
-      t.prec_st,
-      t.prec_sw,
-      t.prec_sq,
-      t.prec_kw,
-      t.block_m,
-      t.activation,
-      t.gate_only,
-      t.fused_quant};
-  auto a1 = fused_moegemm_args{
-      a.a_ptr, // const void* a_ptr;
-      a.a_scale_ptr, // const void* a_scale_ptr;
-      a.g_ptr, // const void* g_ptr;
-      a.d_ptr, // const void* d_ptr;
-      a.g_scale_ptr, // const void* g_scale_ptr;
-      a.d_scale_ptr, // const void* d_scale_ptr;
-      a.y_smooth_scale_ptr, // const void* y_smooth_scale_ptr;
-      a.o_ptr, // void* o_ptr;
-      a.sorted_token_ids_ptr, // const void* sorted_token_ids_ptr;
-      a.sorted_weight_ptr, // const void* sorted_weight_ptr;
-      a.sorted_expert_ids_ptr, // const void* sorted_expert_ids_ptr;
-      a.num_sorted_tiles_ptr, // const void* num_sorted_tiles_ptr;
-      a.hidden_size, // index_t hidden_size;
-      a.intermediate_size, // index_t intermediate_size;
-      a.num_tokens, // index_t num_tokens;
-      a.num_experts, // index_t num_experts;
-      a.topk, // index_t topk;
-      a.stride_token // index_t stride_token;
-  };
+    auto t1 = fused_moegemm_traits{t.prec_i,
+                                   t.prec_w,
+                                   t.prec_o,
+                                   t.prec_st,
+                                   t.prec_sw,
+                                   t.prec_sq,
+                                   t.prec_kw,
+                                   t.block_m,
+                                   t.activation,
+                                   t.gate_only,
+                                   t.fused_quant};
+    auto a1 = fused_moegemm_args{
+        a.a_ptr, // const void* a_ptr;
+        a.a_scale_ptr, // const void* a_scale_ptr;
+        a.g_ptr, // const void* g_ptr;
+        a.d_ptr, // const void* d_ptr;
+        a.g_scale_ptr, // const void* g_scale_ptr;
+        a.d_scale_ptr, // const void* d_scale_ptr;
+        a.y_smooth_scale_ptr, // const void* y_smooth_scale_ptr;
+        a.o_ptr, // void* o_ptr;
+        a.sorted_token_ids_ptr, // const void* sorted_token_ids_ptr;
+        a.sorted_weight_ptr, // const void* sorted_weight_ptr;
+        a.sorted_expert_ids_ptr, // const void* sorted_expert_ids_ptr;
+        a.num_sorted_tiles_ptr, // const void* num_sorted_tiles_ptr;
+        a.hidden_size, // index_t hidden_size;
+        a.intermediate_size, // index_t intermediate_size;
+        a.num_tokens, // index_t num_tokens;
+        a.num_experts, // index_t num_experts;
+        a.topk, // index_t topk;
+        a.stride_token // index_t stride_token;
+    };
 
-  float r0 = -1;
-  float r1 = -1;
+    float r0 = -1;
+    float r1 = -1;
 
-  float r = ck_tile::launch_kernel(
-      s,
-      [=, &r0](const ck_tile::stream_config&) {
-        r0 = fused_moesorting(t0, a0, s_sub);
-      },
-      [=, &r1](const ck_tile::stream_config&) {
-        r1 = fused_moegemm(t1, a1, s_sub);
-      });
+    float r = ck_tile::launch_kernel(
+        s,
+        [=, &r0](const ck_tile::stream_config&) { r0 = fused_moesorting(t0, a0, s_sub); },
+        [=, &r1](const ck_tile::stream_config&) { r1 = fused_moegemm(t1, a1, s_sub); });
 
-  // keep unsupported case return negative
-  if (r0 < 0 || r1 < 0)
-    return -1;
+    // keep unsupported case return negative
+    if(r0 < 0 || r1 < 0)
+        return -1;
 
-  return r;
+    return r;
 }
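Two behaviors of the rewritten launcher are worth noting. First, the EP mask is consumed only by the token-sorting stage: t.local_expert_masking is forwarded into fused_moesorting_trait and a.local_expert_mask_ptr becomes the new third entry of fused_moesorting_args, while fused_moegemm_args is unchanged. Second, the function keeps the existing convention of returning a negative value when either sub-kernel has no generated instance for the requested precision/tile combination. A short hedged sketch of checking that convention follows; the fallback path is hypothetical and not part of this commit.

// Sketch: fused_moe returns a non-negative value (the result of
// ck_tile::launch_kernel) on success and a negative value when the
// precision/tile combination is unsupported.
#include "fused_moe.hpp"

float launch_or_fallback(fused_moe_traits t,
                         fused_moe_args a,
                         const ck_tile::stream_config& s) {
    float r = fused_moe(t, a, s);
    if (r < 0) {
        // hypothetical fallback: dispatch a non-fused MoE implementation here
    }
    return r;
}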

fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fused_moe/instances/fused_moegemm_api.hip

Lines changed: 6 additions & 9 deletions
@@ -5,19 +5,16 @@
 #include "fused_moegemm.hpp"
 #include "fused_moegemm_api_traits.hpp"
 
-// Note: this internal API only declare, not define here, otherwise will block
-// `make -j`
+// Note: this internal API only declare, not define here, otherwise will block `make -j`
 template <typename Traits_>
 float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a);
 
 template <ck_tile::index_t... Is>
 using S = ck_tile::sequence<Is...>;
 
-float fused_moegemm(
-    fused_moegemm_traits t,
-    fused_moegemm_args a,
-    const ck_tile::stream_config& s) {
-  // clang-format off
+float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile::stream_config& s)
+{
+    // clang-format off
 float r = -1;
 if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
   t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1 && t.activation == 0)
@@ -83,6 +80,6 @@ float fused_moegemm(
 using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
 r = fused_moegemm_<t_>(s, a);
 }
-// clang-format on
-return r;
+    // clang-format on
+    return r;
 }

fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fused_moe/instances/fused_moegemm_bf16_m32.hip

Lines changed: 1 addition & 1 deletion
@@ -3,8 +3,8 @@
 
 #include <ck_tile/core.hpp>
 #include "fused_moegemm.hpp"
-#include "fused_moegemm_api_internal.hpp"
 #include "fused_moegemm_api_traits.hpp"
+#include "fused_moegemm_api_internal.hpp"
 
 // clang-format off
 template float fused_moegemm_<

fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fused_moe/instances/fused_moegemm_fp16_m32.hip

Lines changed: 1 addition & 1 deletion
@@ -3,8 +3,8 @@
 
 #include <ck_tile/core.hpp>
 #include "fused_moegemm.hpp"
-#include "fused_moegemm_api_internal.hpp"
 #include "fused_moegemm_api_traits.hpp"
+#include "fused_moegemm_api_internal.hpp"
 
 // clang-format off
 template float fused_moegemm_<
