 from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+from sglang.srt.model_executor.forward_batch_info import (
+    ForwardBatch,
+    ForwardMode,
+    PPProxyTensors,
+)
 from sglang.srt.models.llama import LlamaForCausalLM, LlamaMLP
 from sglang.srt.utils import add_prefix, fast_topk, get_compiler_backend, make_layers

@@ -81,6 +85,7 @@ def __init__(
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
         self.top_k = config.num_experts_per_tok
+        self.device_module = torch.get_device_module()

         intermediate_size_moe = config.intermediate_size
         self.router = ReplicatedLinear(
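Aside: the stored `self.device_module` handle is what keeps the stream handling further down backend-agnostic. A minimal sketch of what that handle provides, assuming a CUDA build (where `torch.get_device_module()` resolves to `torch.cuda`); the variable names here are illustrative only and not part of the diff:

```python
import torch

# Resolve the module for the current accelerator; on a CUDA build this is
# the torch.cuda module.
device_module = torch.get_device_module()

side_stream = device_module.Stream()          # extra stream for overlapped work
main_stream = device_module.current_stream()  # stream PyTorch is issuing work to
```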
@@ -113,20 +118,61 @@ def __init__(
             reduce_results=False,  # We need to do scatter before reduce
         )

-    def forward(self, hidden_states):
+    def forward(self, hidden_states, forward_batch: ForwardBatch):
+        shared_out, routed_out = self._forward_core(
+            hidden_states, forward_batch.forward_mode
+        )
+
+        out_aD = routed_out + shared_out
+
+        if self.tp_size > 1:
+            out_aD = tensor_model_parallel_all_reduce(out_aD)
+
+        return out_aD
+
+    def _forward_core(self, hidden_states, forward_mode: ForwardMode):
+        if hidden_states.shape[0] < 4:
+            return self._forward_core_shared_routed_overlap(hidden_states)
+        else:
+            return self._forward_core_normal(hidden_states)
+
+    def _forward_core_normal(self, hidden_states):
         # router_scores: [num_tokens, num_experts]
         router_logits, _ = self.router(hidden_states)
         shared_out = self.shared_expert(hidden_states)
         routed_out = self.experts(
             hidden_states=hidden_states,
             router_logits=router_logits,
         )
-        out_aD = routed_out + shared_out
+        return shared_out, routed_out

-        if self.tp_size > 1:
-            out_aD = tensor_model_parallel_all_reduce(out_aD)
+    def _forward_core_shared_routed_overlap(self, hidden_states):
+        alt_stream = _get_or_create_alt_stream(self.device_module)

-        return out_aD
+        alt_stream.wait_stream(self.device_module.current_stream())
+
+        shared_out = self.shared_expert(hidden_states)
+
+        with self.device_module.stream(alt_stream):
+            # router_scores: [num_tokens, num_experts]
+            router_logits, _ = self.router(hidden_states)
+            routed_out = self.experts(
+                hidden_states=hidden_states,
+                router_logits=router_logits,
+            )
+        self.device_module.current_stream().wait_stream(alt_stream)
+
+        return shared_out, routed_out
+
+
+_alt_stream = None
+
+
+def _get_or_create_alt_stream(device_module):
+    global _alt_stream
+    if _alt_stream is None:
+        _alt_stream = device_module.Stream()
+    return _alt_stream


 class Llama4Attention(nn.Module):
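The overlap path above is the heart of the change: for very small batches (fewer than 4 tokens) the shared-expert GEMM and the routed-expert path each underutilize the device, so the router and routed experts are issued on a side stream while the shared expert runs on the current stream, with `wait_stream` calls fencing the fork and the join. A minimal, self-contained sketch of that pattern, written against plain `torch.cuda` with hypothetical stand-in callables rather than the repo's modules:

```python
import torch


def overlapped_moe(hidden, shared_expert, router, routed_experts, alt_stream):
    """Run two independent branches on separate CUDA streams and re-join.

    `shared_expert`, `router`, and `routed_experts` are hypothetical callables
    standing in for the real modules; `hidden` is assumed to live on a CUDA device.
    """
    main_stream = torch.cuda.current_stream()

    # Fork: the side stream must observe all work already queued on the main
    # stream (e.g. whatever produced `hidden`) before it starts.
    alt_stream.wait_stream(main_stream)

    # Branch 1 stays on the current (main) stream.
    shared_out = shared_expert(hidden)

    # Branch 2 runs on the side stream; both branches only read `hidden`.
    with torch.cuda.stream(alt_stream):
        router_logits = router(hidden)
        routed_out = routed_experts(hidden, router_logits)

    # Join: the main stream must not consume `routed_out` before the side
    # stream has finished producing it.
    main_stream.wait_stream(alt_stream)

    return shared_out + routed_out


if __name__ == "__main__" and torch.cuda.is_available():
    dim = 64
    h = torch.randn(2, dim, device="cuda")  # tiny batch, where overlap can pay off
    w_shared = torch.randn(dim, dim, device="cuda")
    w_router = torch.randn(dim, dim, device="cuda")
    out = overlapped_moe(
        h,
        shared_expert=lambda x: x @ w_shared,
        router=lambda x: x @ w_router,
        routed_experts=lambda x, logits: torch.softmax(logits, dim=-1) * x,
        alt_stream=torch.cuda.Stream(),
    )
    print(out.shape)  # torch.Size([2, 64])
```

The diff reaches the same structure through `self.device_module` so it is not hard-wired to CUDA, and it caches the side stream in the module-level `_alt_stream` so every layer reuses one extra stream instead of allocating its own.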
@@ -380,7 +426,7 @@ def forward(
         )

         # Fully Connected
-        hidden_states = self.feed_forward(hidden_states)
+        hidden_states = self.feed_forward(hidden_states, forward_batch)

         # TODO(ch-wan): ues reduce-scatter in MLP to avoid this scatter
         # Scatter