Skip to content

Commit adca585

Browse files
yuleilch-wan
andauthored
[DeepEP] Reduce routed scaling overhead (#5277)
Co-authored-by: Cheng Wan <[email protected]>
1 parent 39d9044 commit adca585

File tree

1 file changed

+9
-10
lines changed

1 file changed

+9
-10
lines changed

python/sglang/srt/models/deepseek_v2.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -337,16 +337,13 @@ def forward_deepep(
337337
topk_weights,
338338
forward_mode=forward_mode,
339339
)
340-
final_hidden_states = (
341-
self.experts(
342-
hidden_states=hidden_states,
343-
reorder_topk_ids=reorder_topk_ids,
344-
seg_indptr=seg_indptr,
345-
masked_m=masked_m,
346-
expected_m=expected_m,
347-
forward_mode=forward_mode,
348-
)
349-
* self.routed_scaling_factor
340+
final_hidden_states = self.experts(
341+
hidden_states=hidden_states,
342+
reorder_topk_ids=reorder_topk_ids,
343+
seg_indptr=seg_indptr,
344+
masked_m=masked_m,
345+
expected_m=expected_m,
346+
forward_mode=forward_mode,
350347
)
351348
if self.ep_size > 1:
352349
final_hidden_states = self.deepep_dispatcher.combine(
@@ -355,6 +352,8 @@ def forward_deepep(
355352
topk_weights,
356353
forward_mode,
357354
)
355+
final_hidden_states *= self.routed_scaling_factor
356+
358357
if shared_output is not None:
359358
final_hidden_states = final_hidden_states + shared_output
360359

0 commit comments

Comments
 (0)