Commit 957d30d

BBuf authored and zhyncs committed

apply fused moe gate in ds v3/r1 (sgl-project#5371)

Co-authored-by: Yineng Zhang <[email protected]>

Parent: 307b76a

1 file changed: +37 −16 lines

python/sglang/srt/layers/moe/topk.py

@@ -12,6 +12,7 @@
 # limitations under the License.
 # ==============================================================================

+import math
 import os
 from typing import Callable, Optional

@@ -25,6 +26,8 @@
 _is_cuda = is_cuda()
 _is_hip = is_hip()

+if _is_cuda:
+    from sgl_kernel import moe_fused_gate

 expert_distribution_recorder = ExpertDistributionRecorder()

@@ -209,6 +212,10 @@ def biased_grouped_topk_impl(
     return topk_weights.to(torch.float32), topk_ids.to(torch.int32)


+def is_power_of_two(n):
+    return n > 0 and math.log2(n).is_integer()
+
+
 def biased_grouped_topk(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
@@ -220,23 +227,37 @@ def biased_grouped_topk(
     compiled: bool = True,
     n_share_experts_fusion: int = 0,
 ):
-    biased_grouped_topk_fn = (
-        torch.compile(
-            biased_grouped_topk_impl, dynamic=True, backend=get_compiler_backend()
+    # TODO: moe_fused_gate kernel is not supported for n_share_experts_fusion > 0 now.
+    if (
+        _is_cuda
+        and n_share_experts_fusion == 0
+        and is_power_of_two(correction_bias.shape[0])
+    ):
+        return moe_fused_gate(
+            gating_output,
+            correction_bias,
+            num_expert_group,
+            topk_group,
+            topk,
+        )
+    else:
+        biased_grouped_topk_fn = (
+            torch.compile(
+                biased_grouped_topk_impl, dynamic=True, backend=get_compiler_backend()
+            )
+            if compiled
+            else biased_grouped_topk_impl
+        )
+        return biased_grouped_topk_fn(
+            hidden_states,
+            gating_output,
+            correction_bias,
+            topk,
+            renormalize,
+            num_expert_group,
+            topk_group,
+            n_share_experts_fusion=n_share_experts_fusion,
         )
-        if compiled
-        else biased_grouped_topk_impl
-    )
-    return biased_grouped_topk_fn(
-        hidden_states,
-        gating_output,
-        correction_bias,
-        topk,
-        renormalize,
-        num_expert_group,
-        topk_group,
-        n_share_experts_fusion=n_share_experts_fusion,
-    )


 def select_experts(
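For context on when the new fast path fires: the fused kernel is taken only when all three guard conditions hold. The sketch below restates that guard as a standalone helper (use_fused_gate is a hypothetical name, not part of the commit), exercised with DeepSeek-V3/R1-style sizes, where the model routes over 256 experts:

import math

def is_power_of_two(n):
    # Same check the commit adds: true for 1, 2, 4, ..., 256.
    return n > 0 and math.log2(n).is_integer()

def use_fused_gate(cuda_available: bool, n_share_experts_fusion: int,
                   num_experts: int) -> bool:
    # Mirrors the dispatch in biased_grouped_topk: moe_fused_gate is used
    # only on CUDA, with no shared-expert fusion, and a power-of-two
    # expert count (num_experts == correction_bias.shape[0]).
    return (
        cuda_available
        and n_share_experts_fusion == 0
        and is_power_of_two(num_experts)
    )

print(use_fused_gate(True, 0, 256))   # True: DeepSeek-V3/R1 route over 256 experts
print(use_fused_gate(True, 16, 256))  # False: shared-expert fusion requested
print(use_fused_gate(True, 0, 160))   # False: 160 is not a power of two

Everything else (non-CUDA builds, shared-expert fusion, non-power-of-two expert counts) falls through to the existing, optionally torch.compile'd biased_grouped_topk_impl, so the change is purely an opt-in fast path.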

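For readers unfamiliar with what moe_fused_gate fuses: biased grouped top-k is the node-limited routing used by DeepSeek-V3/R1, where a per-expert correction bias steers which experts are selected but not the weights that are returned. The following is a simplified reference sketch of that algorithm under the usual DeepSeek-V3 recipe (sigmoid scores, group score = sum of the group's top-2 biased scores); it is not the sgl-project implementation verbatim, ignores n_share_experts_fusion, and assumes the expert count divides evenly into num_expert_group:

import torch

def biased_grouped_topk_ref(gating_output, correction_bias, topk,
                            renormalize, num_expert_group, topk_group):
    num_tokens, num_experts = gating_output.shape
    scores = gating_output.sigmoid()
    # The bias influences selection only; returned weights use unbiased scores.
    scores_for_choice = scores + correction_bias.unsqueeze(0)
    # Score each group by the sum of its top-2 biased expert scores.
    group_scores = (
        scores_for_choice.view(num_tokens, num_expert_group, -1)
        .topk(2, dim=-1)
        .values.sum(dim=-1)
    )
    # Keep only the best topk_group groups; mask out all other experts.
    group_idx = group_scores.topk(topk_group, dim=-1).indices
    group_mask = torch.zeros_like(group_scores)
    group_mask.scatter_(1, group_idx, 1)
    score_mask = (
        group_mask.unsqueeze(-1)
        .expand(num_tokens, num_expert_group, num_experts // num_expert_group)
        .reshape(num_tokens, num_experts)
    )
    masked_scores = scores_for_choice.masked_fill(score_mask == 0, float("-inf"))
    # Final top-k over the surviving experts.
    topk_ids = masked_scores.topk(topk, dim=-1).indices
    topk_weights = scores.gather(1, topk_ids)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)

torch.manual_seed(0)
logits = torch.randn(4, 256)      # 4 tokens, 256 routed experts
bias = torch.randn(256) * 0.01    # e_score_correction_bias
w, ids = biased_grouped_topk_ref(logits, bias, topk=8, renormalize=True,
                                 num_expert_group=8, topk_group=4)
print(w.shape, ids.shape)         # torch.Size([4, 8]) torch.Size([4, 8])

The fused kernel computes the same selection in a single launch instead of the chain of topk/scatter/mask ops above, which is presumably where the savings come from.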