
Commit ca00065

Revert "[DeepEP] Reduce routed scaling overhead"
This reverts commit 2a4fc7e.
1 parent 2a4fc7e commit ca00065

2 files changed: +10, -86 lines

python/sglang/srt/layers/moe/ep_moe/kernels.py

Lines changed: 0 additions & 68 deletions
@@ -707,71 +707,3 @@ def grouped_gemm_triton(
         **config,
     )
     return c
-
-
-@triton.jit
-def _masked_scale_kernel(
-    in_ptr,
-    in_stride_0,
-    in_stride_1,
-    out_ptr,
-    out_stride_0,
-    out_stride_1,
-    masked_m_ptr,
-    scale: float,
-    D: tl.constexpr,
-    BLOCK_D: tl.constexpr,
-    BLOCK_NUM_PER_EXPERT: tl.constexpr,
-):
-    pid_expert = tl.program_id(0)
-    pid_token = tl.program_id(1)
-    pid_dim = tl.program_id(2)
-
-    TOKENS_CUR_EXPERT = tl.load(masked_m_ptr + pid_expert).to(tl.int32)
-
-    offs_in_d = pid_dim * BLOCK_D + tl.arange(0, BLOCK_D)
-    mask = offs_in_d < D
-
-    in_ptr_offs = in_ptr + pid_expert * in_stride_0 + offs_in_d
-    out_ptr_offs = out_ptr + pid_expert * out_stride_0 + offs_in_d
-
-    for token_index in tl.range(pid_token, TOKENS_CUR_EXPERT, BLOCK_NUM_PER_EXPERT):
-        v = tl.load(in_ptr_offs + token_index * in_stride_1, mask)
-        tl.store(out_ptr_offs + token_index * out_stride_1, v * scale, mask)
-
-
-def masked_scale(
-    x: torch.Tensor, masked_m: torch.Tensor, scale: float, out: torch.Tensor = None
-):
-    assert x.stride(-1) == 1
-
-    if out is None:
-        out = torch.empty_like(x)
-
-    expert_num = len(masked_m)
-
-    if expert_num < 4:
-        BLOCK_NUM_PER_EXPERT = 64
-    else:
-        BLOCK_NUM_PER_EXPERT = 32
-
-    BLOCK_D = 512
-
-    grid = (len(masked_m), BLOCK_NUM_PER_EXPERT, triton.cdiv(x.size(-1), BLOCK_D))
-
-    _masked_scale_kernel[grid](
-        x,
-        x.stride(0),
-        x.stride(1),
-        out,
-        out.stride(0),
-        out.stride(1),
-        masked_m,
-        scale,
-        x.size(-1),
-        BLOCK_D=BLOCK_D,
-        BLOCK_NUM_PER_EXPERT=BLOCK_NUM_PER_EXPERT,
-        num_warps=1,
-        num_stages=6,
-    )
-    return out
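
For reference, the deleted masked_scale applied the routed scaling factor only to the first masked_m[e] rows of each expert's padded token buffer, leaving padding rows untouched. A minimal pure-PyTorch sketch of that semantics (illustrative only; masked_scale_reference is a hypothetical name, not code from this repository):

import torch

def masked_scale_reference(
    x: torch.Tensor, masked_m: torch.Tensor, scale: float
) -> torch.Tensor:
    # x is laid out as (num_experts, max_tokens_per_expert, hidden_dim),
    # with each expert's valid tokens packed at the front of its slice.
    out = x.clone()
    for e in range(x.size(0)):
        n = int(masked_m[e])  # number of valid tokens for expert e
        out[e, :n] = x[e, :n] * scale  # padding rows beyond n stay as-is
    return out

The deleted Triton version parallelized the same loop over a (num_experts, BLOCK_NUM_PER_EXPERT, cdiv(D, BLOCK_D)) grid, and the deepseek_v2.py call site below ran it in place by passing the input tensor as out.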

python/sglang/srt/models/deepseek_v2.py

Lines changed: 10 additions & 18 deletions
@@ -49,7 +49,6 @@
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.kernels import masked_scale
 from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.moe.topk import select_experts

@@ -338,24 +337,17 @@ def forward_deepep(
                 topk_weights,
                 forward_mode=forward_mode,
             )
-        final_hidden_states = self.experts(
-            hidden_states=hidden_states,
-            reorder_topk_ids=reorder_topk_ids,
-            seg_indptr=seg_indptr,
-            masked_m=masked_m,
-            expected_m=expected_m,
-            forward_mode=forward_mode,
-        )
-        if self.ep_size > 1 and masked_m is not None:
-            final_hidden_states = masked_scale(
-                final_hidden_states,
-                masked_m,
-                self.routed_scaling_factor,
-                final_hidden_states,
+        final_hidden_states = (
+            self.experts(
+                hidden_states=hidden_states,
+                reorder_topk_ids=reorder_topk_ids,
+                seg_indptr=seg_indptr,
+                masked_m=masked_m,
+                expected_m=expected_m,
+                forward_mode=forward_mode,
             )
-        else:
-            final_hidden_states *= self.routed_scaling_factor
-
+            * self.routed_scaling_factor
+        )
         if self.ep_size > 1:
             final_hidden_states = self.deepep_dispatcher.combine(
                 final_hidden_states,
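
The revert replaces the masked, in-place scaling with a plain out-of-place multiply over the whole expert output, padding rows included. On the valid rows the two paths agree; a small self-contained sanity check under assumed toy shapes (not part of the commit):

import torch

num_experts, max_tokens, hidden = 4, 8, 16  # illustrative shapes
x = torch.randn(num_experts, max_tokens, hidden)
masked_m = torch.tensor([8, 5, 0, 3])  # valid-token counts per expert
scale = 2.5

full = x * scale  # reverted code path: scale everything

masked = x.clone()  # removed code path: scale only the valid rows
for e in range(num_experts):
    n = int(masked_m[e])
    masked[e, :n] *= scale

for e in range(num_experts):
    n = int(masked_m[e])
    assert torch.equal(full[e, :n], masked[e, :n])  # valid rows match

The two paths differ only on padding rows, so the plain multiply does redundant work on padding; avoiding that work was presumably the point of the optimization this commit reverts.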
