Skip to content

apply fused moe gate in ds v3/r1 #5371

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Apr 14, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 37 additions & 16 deletions python/sglang/srt/layers/moe/topk.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# limitations under the License.
# ==============================================================================

import math
import os
from typing import Callable, Optional

Expand All @@ -25,6 +26,8 @@
_is_cuda = is_cuda()
_is_hip = is_hip()

if _is_cuda:
from sgl_kernel import moe_fused_gate

expert_distribution_recorder = ExpertDistributionRecorder()

Expand Down Expand Up @@ -209,6 +212,10 @@ def biased_grouped_topk_impl(
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)


def is_power_of_two(n):
return n > 0 and math.log2(n).is_integer()


def biased_grouped_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
Expand All @@ -220,23 +227,37 @@ def biased_grouped_topk(
compiled: bool = True,
n_share_experts_fusion: int = 0,
):
biased_grouped_topk_fn = (
torch.compile(
biased_grouped_topk_impl, dynamic=True, backend=get_compiler_backend()
# TODO: moe_fused_gate kernel is not supported for n_share_experts_fusion > 0 now.
if (
_is_cuda
and n_share_experts_fusion == 0
and is_power_of_two(correction_bias.shape[0])
):
return moe_fused_gate(
gating_output,
correction_bias,
num_expert_group,
topk_group,
topk,
)
else:
biased_grouped_topk_fn = (
torch.compile(
biased_grouped_topk_impl, dynamic=True, backend=get_compiler_backend()
)
if compiled
else biased_grouped_topk_impl
)
return biased_grouped_topk_fn(
hidden_states,
gating_output,
correction_bias,
topk,
renormalize,
num_expert_group,
topk_group,
n_share_experts_fusion=n_share_experts_fusion,
)
if compiled
else biased_grouped_topk_impl
)
return biased_grouped_topk_fn(
hidden_states,
gating_output,
correction_bias,
topk,
renormalize,
num_expert_group,
topk_group,
n_share_experts_fusion=n_share_experts_fusion,
)


def select_experts(
Expand Down
Loading