@@ -12,6 +12,7 @@
 # limitations under the License.
 # ==============================================================================
 
+import math
 import os
 from typing import Callable, Optional
 
@@ -25,6 +26,8 @@
 _is_cuda = is_cuda()
 _is_hip = is_hip()
 
+if _is_cuda:
+    from sgl_kernel import moe_fused_gate
 
 expert_distribution_recorder = ExpertDistributionRecorder()
 
@@ -209,6 +212,10 @@ def biased_grouped_topk_impl(
     return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
 
 
+def is_power_of_two(n):
+    return n > 0 and math.log2(n).is_integer()
+
+
 def biased_grouped_topk(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
@@ -220,23 +227,37 @@ def biased_grouped_topk(
     compiled: bool = True,
     n_share_experts_fusion: int = 0,
 ):
-    biased_grouped_topk_fn = (
-        torch.compile(
-            biased_grouped_topk_impl, dynamic=True, backend=get_compiler_backend()
+    # TODO: the moe_fused_gate kernel does not support n_share_experts_fusion > 0 yet.
+    if (
+        _is_cuda
+        and n_share_experts_fusion == 0
+        and is_power_of_two(correction_bias.shape[0])
+    ):
+        return moe_fused_gate(
+            gating_output,
+            correction_bias,
+            num_expert_group,
+            topk_group,
+            topk,
+        )
+    else:
+        biased_grouped_topk_fn = (
+            torch.compile(
+                biased_grouped_topk_impl, dynamic=True, backend=get_compiler_backend()
+            )
+            if compiled
+            else biased_grouped_topk_impl
+        )
+        return biased_grouped_topk_fn(
+            hidden_states,
+            gating_output,
+            correction_bias,
+            topk,
+            renormalize,
+            num_expert_group,
+            topk_group,
+            n_share_experts_fusion=n_share_experts_fusion,
         )
-        if compiled
-        else biased_grouped_topk_impl
-    )
-    return biased_grouped_topk_fn(
-        hidden_states,
-        gating_output,
-        correction_bias,
-        topk,
-        renormalize,
-        num_expert_group,
-        topk_group,
-        n_share_experts_fusion=n_share_experts_fusion,
-    )
 
 
 def select_experts(
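
Side note on the new guard: the fast path calls moe_fused_gate only when the expert count (correction_bias.shape[0]) is a power of two, presumably a constraint of the fused kernel; everything else falls back to the compiled biased_grouped_topk_impl. Below is a minimal, self-contained sketch of that check, runnable outside sglang. The log2-based form matches the patch; the bitwise variant is a common integer-only alternative and is not part of this change.

import math

def is_power_of_two(n):
    # Form used in the patch: exact for the modest expert counts seen here.
    return n > 0 and math.log2(n).is_integer()

def is_power_of_two_bitwise(n: int) -> bool:
    # Equivalent float-free check: a power of two has exactly one set bit.
    return n > 0 and (n & (n - 1)) == 0

# e.g. a 256-expert model would take the fused path, a 96-expert one would not
assert is_power_of_two(256) and is_power_of_two_bitwise(256)
assert not is_power_of_two(96) and not is_power_of_two_bitwise(96)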