
Commit f613d14

Ying1123 authored and lifuhuang committed
[PP] Fix init_memory_pool desync & add PP for mixtral (sgl-project#6223)
1 parent dab5072 commit f613d14

8 files changed: +179 −47 lines changed

.github/workflows/pr-test.yml

Lines changed: 12 additions & 0 deletions

@@ -229,6 +229,18 @@ jobs:
           cd test/srt
           python3 -m unittest test_bench_serving.TestBenchServing.test_moe_offline_throughput_without_radix_cache

+      - name: Benchmark offline decode throughput (PP=2)
+        timeout-minutes: 10
+        run: |
+          cd test/srt
+          python3 -m unittest test_bench_serving.TestBenchServing.test_pp_offline_throughput_default_decode
+
+      - name: Benchmark offline prefill throughput (PP=2)
+        timeout-minutes: 10
+        run: |
+          cd test/srt
+          python3 -m unittest test_bench_serving.TestBenchServing.test_pp_long_context_prefill
+
   accuracy-test-1-gpu:
     if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') &&
       github.event.pull_request.draft == false

python/sglang/srt/managers/schedule_policy.py

Lines changed: 0 additions & 3 deletions

@@ -468,9 +468,6 @@ def add_one_req(
                 return AddReqResult.OTHER

         with self._lock_node(req.last_node):
-            if total_tokens > self.rem_total_tokens:
-                return AddReqResult.NO_TOKEN
-
             if (
                 enable_hierarchical_cache
                 and req.last_node_global is not None

python/sglang/srt/managers/scheduler.py

Lines changed: 4 additions & 4 deletions

@@ -719,7 +719,7 @@ def event_loop_pp(self):
                     server_is_idle = False
                     result = self.run_batch(self.cur_batch)

-                # send the outputs to the next step
+                # (last rank) send the outputs to the next step
                 if self.pp_group.is_last_rank:
                     if self.cur_batch:
                         next_token_ids, bids[mb_id] = (

@@ -759,18 +759,18 @@ def event_loop_pp(self):
                         self.process_batch_result(mbs[next_mb_id], output_result)
                         last_mbs[next_mb_id] = mbs[next_mb_id]

-                # carry the outputs to the next stage
+                # (not last rank)
                 if not self.pp_group.is_last_rank:
                     if self.cur_batch:
                         bids[mb_id] = result.bid
+                    # carry the outputs to the next stage
+                    # send the outputs from the last round to let the next stage worker run post processing
                     if pp_outputs:
-                        # send the outputs from the last round to let the next stage worker run post processing
                         self.pp_group.send_tensor_dict(
                             pp_outputs.tensors,
                             all_gather_group=self.attn_tp_group,
                         )

-                if not self.pp_group.is_last_rank:
                     # send out reqs to the next stage
                     dp_offset = self.dp_rank * self.attn_tp_size
                     if self.attn_tp_rank == 0:
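
For readers unfamiliar with event_loop_pp: the reorganized comments reflect that, on a non-last pipeline rank, everything that travels downstream now lives under a single `if not self.pp_group.is_last_rank:` branch. The skeleton below is a hypothetical, self-contained sketch of that control flow; FakePPGroup, pp_step, and send_reqs are illustrative names rather than sglang APIs, and only send_tensor_dict mirrors a call that appears in the diff.

# Hypothetical sketch of the branch structure in event_loop_pp after this change.
class FakePPGroup:
    def __init__(self, rank: int, world_size: int):
        self.rank, self.world_size = rank, world_size
        self.sent = []  # records what would be sent to the next stage

    @property
    def is_last_rank(self) -> bool:
        return self.rank == self.world_size - 1

    def send_tensor_dict(self, tensors: dict) -> None:
        self.sent.append(("proxy", tensors))  # stand-in for a point-to-point send

    def send_reqs(self, reqs) -> None:
        self.sent.append(("reqs", reqs))  # stand-in for forwarding request metadata


def pp_step(pp_group: FakePPGroup, cur_batch, pp_outputs) -> None:
    if pp_group.is_last_rank:
        # (last rank) send the outputs to the next step (sampling / detokenizer).
        print(f"rank {pp_group.rank}: emit outputs for {cur_batch}")
    else:
        # (not last rank) carry the outputs to the next stage: first the proxy
        # tensors from the last round, then the requests for the next stage.
        if pp_outputs:
            pp_group.send_tensor_dict(pp_outputs)
        pp_group.send_reqs(cur_batch)


pp_step(FakePPGroup(rank=0, world_size=2), cur_batch=["req0"], pp_outputs={"hidden_states": 1})
pp_step(FakePPGroup(rank=1, world_size=2), cur_batch=["req1"], pp_outputs=None)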

python/sglang/srt/model_executor/model_runner.py

Lines changed: 9 additions & 2 deletions

@@ -32,6 +32,7 @@
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
 from sglang.srt.distributed import (
     get_tp_group,
+    get_world_group,
     init_distributed_environment,
     initialize_model_parallel,
     set_custom_all_reduce,

@@ -404,7 +405,10 @@ def init_torch_distributed(self):
         )

         min_per_gpu_memory = get_available_gpu_memory(
-            self.device, self.gpu_id, distributed=self.tp_size > 1
+            self.device,
+            self.gpu_id,
+            distributed=get_world_group().world_size > 1,
+            cpu_group=get_world_group().cpu_group,
         )
         self.tp_group = get_tp_group()
         self.attention_tp_group = get_attention_tp_group()

@@ -716,7 +720,10 @@ def init_lora_manager(self):

     def profile_max_num_token(self, total_gpu_memory: int):
         available_gpu_memory = get_available_gpu_memory(
-            self.device, self.gpu_id, distributed=self.tp_size > 1
+            self.device,
+            self.gpu_id,
+            distributed=get_world_group().world_size > 1,
+            cpu_group=get_world_group().cpu_group,
         )
         if self.use_mla_backend:
            num_layers = (
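
These two call sites are the init_memory_pool desync fix named in the commit title: free-memory profiling previously reduced only across the tensor-parallel group (`distributed=self.tp_size > 1`), so with pipeline parallelism each PP stage could observe a different minimum and size its KV-cache pool differently. The toy numbers below are made up, and max_num_token is a simplified stand-in for profile_max_num_token; they only show why reducing over the whole world group keeps every rank on the same pool size.

# Made-up free-memory values for a PP=2 x TP=2 deployment (4 ranks total).
free_gib = {
    0: 70.0, 1: 68.0,  # PP stage 0 = TP group {0, 1}
    2: 61.0, 3: 74.0,  # PP stage 1 = TP group {2, 3}
}


def max_num_token(free_mem_gib: float, cell_size_gib: float = 0.001) -> int:
    # Stand-in for profile_max_num_token: pool size scales with free memory.
    return int(free_mem_gib / cell_size_gib)


# Reducing only within each TP group: the two stages disagree on the pool size.
per_stage_min = {0: min(free_gib[0], free_gib[1]), 1: min(free_gib[2], free_gib[3])}
print({stage: max_num_token(m) for stage, m in per_stage_min.items()})  # {0: 68000, 1: 61000}

# Reducing over the whole world group: every rank derives the same pool size.
print(max_num_token(min(free_gib.values())))  # 61000 everywhere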

python/sglang/srt/models/mixtral.py

Lines changed: 98 additions & 34 deletions

@@ -16,13 +16,15 @@
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
 """Inference-only Mixtral model."""

-from typing import Iterable, Optional, Tuple
+import logging
+from typing import Iterable, Optional, Tuple, Union

 import torch
 from torch import nn
 from transformers import MixtralConfig

 from sglang.srt.distributed import (
+    get_pp_group,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )

@@ -38,14 +40,17 @@
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import add_prefix
+from sglang.srt.utils import add_prefix, make_layers
+
+logger = logging.getLogger(__name__)


 class MixtralMoE(nn.Module):

@@ -257,43 +262,68 @@ def __init__(
         super().__init__()
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
+        self.pp_group = get_pp_group()

-        self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size,
-            config.hidden_size,
-            prefix=add_prefix("embed_tokens", prefix),
-        )
-        self.layers = nn.ModuleList(
-            [
-                MixtralDecoderLayer(
-                    config,
-                    i,
-                    quant_config=quant_config,
-                    prefix=add_prefix(f"layers.{i}", prefix),
-                )
-                for i in range(config.num_hidden_layers)
-            ]
+        if self.pp_group.is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                prefix=add_prefix("embed_tokens", prefix),
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
+
+        self.layers, self.start_layer, self.end_layer = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: MixtralDecoderLayer(
+                config=config, quant_config=quant_config, layer_id=idx, prefix=prefix
+            ),
+            pp_rank=self.pp_group.rank_in_group,
+            pp_size=self.pp_group.world_size,
+            prefix="layers",
+            return_tuple=True,
        )
-        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+        if self.pp_group.is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer(return_tuple=True)

     def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
-    ) -> torch.Tensor:
-        if input_embeds is None:
-            hidden_states = self.embed_tokens(input_ids)
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[torch.Tensor, PPProxyTensors]:
+        if self.pp_group.is_first_rank:
+            if input_embeds is None:
+                hidden_states = self.embed_tokens(input_ids)
+            else:
+                hidden_states = input_embeds
+            residual = None
         else:
-            hidden_states = input_embeds
-        residual = None
-        for i in range(len(self.layers)):
+            assert pp_proxy_tensors is not None
+            hidden_states = pp_proxy_tensors["hidden_states"]
+            residual = pp_proxy_tensors["residual"]
+
+        for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states, residual = layer(
                 positions, hidden_states, forward_batch, residual
             )
-        hidden_states, _ = self.norm(hidden_states, residual)
+
+        if not self.pp_group.is_last_rank:
+            return PPProxyTensors(
+                {
+                    "hidden_states": hidden_states,
+                    "residual": residual,
+                }
+            )
+        else:
+            hidden_states, _ = self.norm(hidden_states, residual)
+
         return hidden_states


@@ -306,6 +336,7 @@ def __init__(
         prefix: str = "",
     ) -> None:
         super().__init__()
+        self.pp_group = get_pp_group()
         self.config = config
         self.quant_config = quant_config
         self.model = MixtralModel(

@@ -322,12 +353,31 @@ def forward(
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
-        return self.logits_processor(
-            input_ids, hidden_states, self.lm_head, forward_batch
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            forward_batch,
+            input_embeds,
+            pp_proxy_tensors=pp_proxy_tensors,
         )

+        if self.pp_group.is_last_rank:
+            return self.logits_processor(
+                input_ids, hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            return hidden_states
+
+    @property
+    def start_layer(self):
+        return self.model.start_layer
+
+    @property
+    def end_layer(self):
+        return self.model.end_layer
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)

@@ -348,6 +398,17 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
+            layer_id = get_layer_id(name)
+            if (
+                layer_id is not None
+                and hasattr(self.model, "start_layer")
+                and (
+                    layer_id < self.model.start_layer
+                    or layer_id >= self.model.end_layer
+                )
+            ):
+                continue
+
             if "rotary_emb.inv_freq" in name:
                 continue


@@ -398,11 +459,14 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 if name is None:
                     continue

-                param = params_dict[name]
-                weight_loader = getattr(
-                    param, "weight_loader", default_weight_loader
-                )
-                weight_loader(param, loaded_weight)
+                if name in params_dict.keys():
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+                else:
+                    logger.warning(f"Parameter {name} not found in params_dict")


 EntryClass = MixtralForCausalLM
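
Two ideas carry the Mixtral changes: make_layers gives each pipeline rank a contiguous [start_layer, end_layer) slice of the decoder stack (with PPMissingLayer placeholders for the embedding/norm a rank does not own), and load_weights skips any checkpoint tensor whose layer index falls outside that slice. The snippet below is a simplified, illustrative stand-in for make_layers/get_layer_id, assuming an even contiguous split; it is not the sglang implementation.

# Simplified stand-ins (not the actual sglang helpers) for layer partitioning
# and per-rank weight filtering under pipeline parallelism.
import re
from typing import Optional, Tuple


def split_layers(num_layers: int, pp_rank: int, pp_size: int) -> Tuple[int, int]:
    # Contiguous [start, end) range for this pipeline stage; remainder layers
    # go to the earlier stages.
    base, rem = divmod(num_layers, pp_size)
    start = pp_rank * base + min(pp_rank, rem)
    end = start + base + (1 if pp_rank < rem else 0)
    return start, end


def layer_id_from_name(name: str) -> Optional[int]:
    # Rough equivalent of get_layer_id: pull the index out of "...layers.<i>...".
    match = re.search(r"layers\.(\d+)\.", name)
    return int(match.group(1)) if match else None


start, end = split_layers(num_layers=32, pp_rank=1, pp_size=2)  # (16, 32)
for name in ["model.layers.3.mlp.w1.weight", "model.layers.20.mlp.w1.weight"]:
    layer_id = layer_id_from_name(name)
    if layer_id is not None and not (start <= layer_id < end):
        continue  # this rank does not own the layer, so the weight is skipped
    print(f"rank 1 loads {name}")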

python/sglang/srt/server_args.py

Lines changed: 6 additions & 0 deletions

@@ -347,6 +347,12 @@ def __post_init__(self):
                 f"DeepEP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
             )

+        if self.pp_size > 1:
+            self.disable_overlap_schedule = True
+            logger.warning(
+                "Pipeline parallelism is incompatible with overlap schedule."
+            )
+
         # Speculative Decoding
         if self.speculative_algorithm == "NEXTN":
             # NEXTN shares the same implementation of EAGLE
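
The guard above coerces the configuration rather than erroring: with pp_size > 1 the overlap schedule is turned off and a warning is logged. A minimal standalone sketch of the same __post_init__ pattern, using a toy dataclass rather than sglang's ServerArgs:

# Toy dataclass (not sglang's ServerArgs) illustrating the post-init guard.
import logging
from dataclasses import dataclass

logger = logging.getLogger(__name__)


@dataclass
class TinyServerArgs:
    pp_size: int = 1
    disable_overlap_schedule: bool = False

    def __post_init__(self):
        if self.pp_size > 1 and not self.disable_overlap_schedule:
            # Reconcile incompatible options at construction time.
            self.disable_overlap_schedule = True
            logger.warning("Pipeline parallelism is incompatible with overlap schedule.")


args = TinyServerArgs(pp_size=2)
assert args.disable_overlap_schedule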

python/sglang/srt/utils.py

Lines changed: 6 additions & 4 deletions

@@ -282,7 +282,9 @@ def inner_func(*args, **kwargs):
     return wrapper


-def get_available_gpu_memory(device, gpu_id, distributed=False, empty_cache=True):
+def get_available_gpu_memory(
+    device, gpu_id, distributed=False, empty_cache=True, cpu_group=None
+):
     """
     Get available memory for cuda:gpu_id device.
     When distributed is True, the available memory is the minimum available memory of all GPUs.

@@ -344,10 +346,10 @@ def get_available_gpu_memory(device, gpu_id, distributed=False, empty_cache=True
         free_gpu_memory, total_gpu_memory = torch.npu.mem_get_info()

     if distributed:
-        tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(
-            torch.device(device, gpu_id)
+        tensor = torch.tensor(free_gpu_memory, dtype=torch.float32)
+        torch.distributed.all_reduce(
+            tensor, op=torch.distributed.ReduceOp.MIN, group=cpu_group
         )
-        torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.MIN)
         free_gpu_memory = tensor.item()

     return free_gpu_memory / (1 << 30)
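
The reworked reduction keeps the free-memory value on the CPU and reduces it with MIN over the caller-supplied cpu_group, instead of first copying it onto the GPU and reducing over the default group. Below is a runnable single-process sketch of the call shape; the gloo backend, world size 1, port number, and 42.5 GiB value are all arbitrary choices for illustration.

import torch
import torch.distributed as dist

if not dist.is_initialized():
    # Single-process group just to demonstrate the API; any free port works.
    dist.init_process_group(
        backend="gloo", init_method="tcp://127.0.0.1:29511", rank=0, world_size=1
    )
cpu_group = dist.new_group(backend="gloo")

free_gpu_memory = 42.5 * (1 << 30)  # pretend this rank has 42.5 GiB free
tensor = torch.tensor(free_gpu_memory, dtype=torch.float32)
dist.all_reduce(tensor, op=dist.ReduceOp.MIN, group=cpu_group)  # min across ranks
print(tensor.item() / (1 << 30))  # global minimum free memory, in GiB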

test/srt/test_bench_serving.py

Lines changed: 44 additions & 0 deletions

@@ -272,6 +272,50 @@ def test_moe_offline_throughput_without_radix_cache(self):
         else:
             self.assertGreater(res["output_throughput"], 2200)

+    def test_pp_offline_throughput_default_decode(self):
+        res = run_bench_serving(
+            model=DEFAULT_MOE_MODEL_NAME_FOR_TEST,
+            num_prompts=1000,
+            request_rate=float("inf"),
+            random_input_len=1,
+            random_output_len=1024,
+            other_server_args=["--pp", "2"],
+            need_warmup=True,
+            seed=42,
+        )
+
+        if is_in_ci():
+            write_github_step_summary(
+                f"### test_pp_offline_throughput_default_decode\n"
+                f'Output throughput: {res["output_throughput"]:.2f} token/s\n'
+            )
+            self.assertGreater(res["output_throughput"], 7500)
+
+    def test_pp_long_context_prefill(self):
+        res = run_bench_serving(
+            model="meta-llama/Llama-3.3-70B-Instruct",
+            num_prompts=4,
+            request_rate=float("inf"),
+            random_input_len=128000,
+            random_output_len=1,
+            dataset_name="random",
+            other_server_args=[
+                "--quantization",
+                "fp8",
+                "--pp",
+                2,
+            ],
+            need_warmup=False,
+            seed=42,
+        )
+
+        if is_in_ci():
+            write_github_step_summary(
+                f"### test_pp_long_context_latency_prefill\n"
+                f'input_throughput: {res["input_throughput"]:.2f} ms\n'
+            )
+            self.assertGreater(res["input_throughput"], 4000)
+

 if __name__ == "__main__":
     unittest.main()
