Commit c0bcbd2

fzyzcjy authored and lifuhuang committed

Overlap shared expert and routed expert computations (sgl-project#5121)

1 parent 6d0de77 · commit c0bcbd2

2 files changed: +54 additions, -8 deletions

python/sglang/srt/models/llama.py
Lines changed: 1 addition & 1 deletion

@@ -90,7 +90,7 @@ def __init__(
         )
         self.act_fn = SiluAndMul()
 
-    def forward(self, x):
+    def forward(self, x, forward_batch=None):
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(x)
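
The only change in llama.py is an optional forward_batch parameter on LlamaMLP.forward, which defaults to None so existing callers keep working while Llama 4's MoE layer can thread batch information through a uniform call signature. Below is a minimal, self-contained sketch of that pattern; it uses plain torch.nn layers rather than SGLang's parallel linear layers, and the class and variable names are illustrative only.

import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyGatedMLP(nn.Module):
    """Toy stand-in for LlamaMLP: fused gate/up projection + SiLU-and-multiply."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor, forward_batch=None) -> torch.Tensor:
        # forward_batch is accepted but unused here; it exists so callers that
        # carry batch metadata (e.g. an MoE wrapper) can use one call signature.
        gate, up = self.gate_up_proj(x).chunk(2, dim=-1)
        return self.down_proj(F.silu(gate) * up)


mlp = TinyGatedMLP(hidden_size=16, intermediate_size=32)
x = torch.randn(4, 16)
# Both call forms work: old call sites omit forward_batch, new ones pass it.
assert torch.allclose(mlp(x), mlp(x, forward_batch=None))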

python/sglang/srt/models/llama4.py
Lines changed: 53 additions & 7 deletions

@@ -46,7 +46,11 @@
 from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+from sglang.srt.model_executor.forward_batch_info import (
+    ForwardBatch,
+    ForwardMode,
+    PPProxyTensors,
+)
 from sglang.srt.models.llama import LlamaForCausalLM, LlamaMLP
 from sglang.srt.utils import add_prefix, fast_topk, get_compiler_backend, make_layers

@@ -81,6 +85,7 @@ def __init__(
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
         self.top_k = config.num_experts_per_tok
+        self.device_module = torch.get_device_module()
 
         intermediate_size_moe = config.intermediate_size
         self.router = ReplicatedLinear(

@@ -113,20 +118,61 @@ def __init__(
             reduce_results=False,  # We need to do scatter before reduce
         )
 
-    def forward(self, hidden_states):
+    def forward(self, hidden_states, forward_batch: ForwardBatch):
+        shared_out, routed_out = self._forward_core(
+            hidden_states, forward_batch.forward_mode
+        )
+
+        out_aD = routed_out + shared_out
+
+        if self.tp_size > 1:
+            out_aD = tensor_model_parallel_all_reduce(out_aD)
+
+        return out_aD
+
+    def _forward_core(self, hidden_states, forward_mode: ForwardMode):
+        if hidden_states.shape[0] < 4:
+            return self._forward_core_shared_routed_overlap(hidden_states)
+        else:
+            return self._forward_core_normal(hidden_states)
+
+    def _forward_core_normal(self, hidden_states):
         # router_scores: [num_tokens, num_experts]
         router_logits, _ = self.router(hidden_states)
         shared_out = self.shared_expert(hidden_states)
         routed_out = self.experts(
             hidden_states=hidden_states,
             router_logits=router_logits,
         )
-        out_aD = routed_out + shared_out
+        return shared_out, routed_out
 
-        if self.tp_size > 1:
-            out_aD = tensor_model_parallel_all_reduce(out_aD)
+    def _forward_core_shared_routed_overlap(self, hidden_states):
+        alt_stream = _get_or_create_alt_stream(self.device_module)
 
-        return out_aD
+        alt_stream.wait_stream(self.device_module.current_stream())
+
+        shared_out = self.shared_expert(hidden_states)
+
+        with self.device_module.stream(alt_stream):
+            # router_scores: [num_tokens, num_experts]
+            router_logits, _ = self.router(hidden_states)
+            routed_out = self.experts(
+                hidden_states=hidden_states,
+                router_logits=router_logits,
+            )
+        self.device_module.current_stream().wait_stream(alt_stream)
+
+        return shared_out, routed_out
+
+
+_alt_stream = None
+
+
+def _get_or_create_alt_stream(device_module):
+    global _alt_stream
+    if _alt_stream is None:
+        _alt_stream = device_module.Stream()
+    return _alt_stream
 
 
 class Llama4Attention(nn.Module):

@@ -380,7 +426,7 @@ def forward(
         )
 
         # Fully Connected
-        hidden_states = self.feed_forward(hidden_states)
+        hidden_states = self.feed_forward(hidden_states, forward_batch)
 
         # TODO(ch-wan): ues reduce-scatter in MLP to avoid this scatter
         # Scatter
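
The heart of the llama4.py change is _forward_core_shared_routed_overlap: for very small batches (fewer than 4 tokens, i.e. typical decode steps), the shared expert runs on the current stream while the router and routed experts run on an alternate stream, and the current stream waits on the alternate stream before the two outputs are summed. Presumably the overlap only pays off when the batch is too small to saturate the GPU, which is why larger batches take the sequential _forward_core_normal path. Below is a self-contained sketch of the same dual-stream pattern; it is illustrative only (not SGLang code) and assumes a recent PyTorch with torch.get_device_module and a CUDA device.

import torch


def overlapped(a, f, g):
    """Run f(a) on the current stream and g(a) on an alternate stream, concurrently."""
    device_module = torch.get_device_module()  # e.g. torch.cuda on CUDA builds
    alt_stream = device_module.Stream()

    # The alternate stream must first see all work already queued for `a`.
    alt_stream.wait_stream(device_module.current_stream())

    f_out = f(a)  # analogous to the shared expert, on the current stream

    with device_module.stream(alt_stream):
        g_out = g(a)  # analogous to router + routed experts, on the alternate stream

    # Later work on the current stream must wait for the alternate stream.
    device_module.current_stream().wait_stream(alt_stream)
    return f_out, g_out


if torch.cuda.is_available():
    x = torch.randn(2, 1024, device="cuda")
    w_shared = torch.randn(1024, 1024, device="cuda")
    w_routed = torch.randn(1024, 1024, device="cuda")
    shared_out, routed_out = overlapped(x, lambda t: t @ w_shared, lambda t: t @ w_routed)
    out = shared_out + routed_out  # safe: the current stream waited on alt_stream

The two wait_stream calls are what make this correct: the first orders the alternate stream after everything already queued on the current stream (so the inputs are ready), and the second orders subsequent current-stream work after the alternate stream (so the routed output is ready before the addition). The commit caches the alternate stream in a module-level _alt_stream so it is created only once per process.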
