Skip to content

Commit fc75032

Browse files
ShangmingCaiHanHan009527
authored andcommitted
[PD] Add support for different TP sizes per DP rank (sgl-project#5922)
Signed-off-by: Shangming Cai <[email protected]>
1 parent 1436242 commit fc75032

File tree

8 files changed

+472
-132
lines changed

8 files changed

+472
-132
lines changed

python/sglang/srt/disaggregation/base/conn.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def __init__(
3737
args: KVArgs,
3838
disaggregation_mode: DisaggregationMode,
3939
server_args: ServerArgs,
40+
is_mla_backend: Optional[bool] = False,
4041
): ...
4142

4243

python/sglang/srt/disaggregation/decode.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
ReqToMetadataIdxAllocator,
3838
TransferBackend,
3939
get_kv_class,
40+
is_mla_backend,
4041
kv_to_page_indices,
4142
poll_and_all_reduce,
4243
)
@@ -86,6 +87,7 @@ def __init__(
8687
self.req_to_token_pool = req_to_token_pool
8788
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
8889
self.token_to_kv_pool = token_to_kv_pool_allocator.get_kvcache()
90+
self.is_mla_backend = is_mla_backend(self.token_to_kv_pool)
8991
self.aux_dtype = aux_dtype
9092
self.metadata_buffers = metadata_buffers
9193
self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
@@ -131,7 +133,10 @@ def _init_kv_manager(self) -> BaseKVManager:
131133
kv_args.gpu_id = self.scheduler.gpu_id
132134
kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
133135
kv_manager = kv_manager_class(
134-
kv_args, DisaggregationMode.DECODE, self.scheduler.server_args
136+
kv_args,
137+
DisaggregationMode.DECODE,
138+
self.scheduler.server_args,
139+
self.is_mla_backend,
135140
)
136141
return kv_manager
137142

python/sglang/srt/disaggregation/mooncake/conn.py

Lines changed: 246 additions & 119 deletions
Large diffs are not rendered by default.

python/sglang/srt/disaggregation/nixl/conn.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ def __init__(
8383
args: KVArgs,
8484
disaggregation_mode: DisaggregationMode,
8585
server_args: ServerArgs,
86+
is_mla_backend: Optional[bool] = False,
8687
):
8788
try:
8889
from nixl._api import nixl_agent

python/sglang/srt/disaggregation/prefill.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
ReqToMetadataIdxAllocator,
3434
TransferBackend,
3535
get_kv_class,
36+
is_mla_backend,
3637
kv_to_page_indices,
3738
kv_to_page_num,
3839
poll_and_all_reduce,
@@ -68,6 +69,7 @@ def __init__(
6869
scheduler: Scheduler,
6970
):
7071
self.token_to_kv_pool = token_to_kv_pool
72+
self.is_mla_backend = is_mla_backend(token_to_kv_pool)
7173
self.aux_dtype = aux_dtype
7274

7375
self.metadata_buffers = metadata_buffers
@@ -114,7 +116,7 @@ def _init_kv_manager(self) -> BaseKVManager:
114116
kv_args,
115117
DisaggregationMode.PREFILL,
116118
self.scheduler.server_args,
117-
# self.scheduler.disagg_launch_done,
119+
self.is_mla_backend,
118120
)
119121
return kv_manager
120122

python/sglang/srt/disaggregation/utils.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,3 +105,47 @@ def kv_to_page_indices(kv_indices: np.ndarray, page_size: int):
105105
def kv_to_page_num(num_kv_indices: int, page_size: int):
106106
# ceil(num_kv_indices / page_size)
107107
return (num_kv_indices + page_size - 1) // page_size
108+
109+
110+
@dataclasses.dataclass
111+
class PDRegistryRequest:
112+
"""A request to register a machine itself to the LB."""
113+
114+
mode: str
115+
registry_url: str
116+
bootstrap_port: Optional[int] = None
117+
118+
def __post_init__(self):
119+
if self.mode == "prefill" and self.bootstrap_port is None:
120+
raise ValueError("Bootstrap port must be set in PREFILL mode.")
121+
elif self.mode == "decode" and self.bootstrap_port is not None:
122+
raise ValueError("Bootstrap port must not be set in DECODE mode.")
123+
elif self.mode not in ["prefill", "decode"]:
124+
raise ValueError(
125+
f"Invalid mode: {self.mode}. Must be 'prefill' or 'decode'."
126+
)
127+
128+
129+
def register_disaggregation_server(
130+
mode: str, server_port: int, bootstrap_port: int, pdlb_url: str
131+
):
132+
boostrap_port = bootstrap_port if mode == "prefill" else None
133+
registry_request = PDRegistryRequest(
134+
mode=mode,
135+
registry_url=f"http://{get_ip()}:{server_port}",
136+
bootstrap_port=boostrap_port,
137+
)
138+
res = requests.post(
139+
f"{pdlb_url}/register",
140+
json=dataclasses.asdict(registry_request),
141+
)
142+
if res.status_code != 200:
143+
warnings.warn(
144+
f"Failed to register disaggregation server: {res.status_code} {res.text}"
145+
)
146+
147+
148+
def is_mla_backend(target_kv_pool) -> bool:
149+
from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
150+
151+
return isinstance(target_kv_pool, MLATokenToKVPool)

test/srt/run_suite.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -83,17 +83,26 @@ class TestFile:
8383
TestFile("test_triton_moe_channel_fp8_kernel.py", 25),
8484
],
8585
"per-commit-2-gpu": [
86-
TestFile("models/lora/test_lora_tp.py", 300),
87-
TestFile("test_data_parallelism.py", 90),
88-
TestFile("test_dp_attention.py", 90),
89-
TestFile("test_expert_distribution.py", 100),
90-
TestFile("test_eplb.py", 100),
91-
TestFile("test_mla_tp.py", 420),
92-
TestFile("test_moe_ep.py", 220),
93-
TestFile("test_patch_torch.py", 30),
94-
TestFile("test_update_weights_from_distributed.py", 100),
95-
TestFile("test_verl_engine.py", 100),
96-
TestFile("test_two_batch_overlap.py", 100),
86+
TestFile("models/lora/test_lora_tp.py", 116),
87+
TestFile("test_data_parallelism.py", 73),
88+
TestFile("test_dp_attention.py", 137),
89+
TestFile("test_mla_tp.py", 170),
90+
TestFile("test_moe_ep.py", 181),
91+
TestFile("test_patch_torch.py", 19),
92+
TestFile("test_update_weights_from_distributed.py", 103),
93+
TestFile("test_verl_engine.py", 64),
94+
],
95+
"per-commit-8-gpu": [
96+
# Disabled deepep tests temporarily because it takes too much time.
97+
# TODO: re-enable them after reducing the test time with compilation cache and smaller models.
98+
# TestFile("test_deepep_intranode.py", 50),
99+
# TestFile("test_deepep_low_latency.py", 50),
100+
# TestFile("test_moe_deepep_eval_accuracy_large.py", 250),
101+
TestFile("test_disaggregation.py", 210),
102+
TestFile("test_local_attn.py", 250),
103+
TestFile("test_disaggregation_different_tp.py", 210),
104+
TestFile("test_full_deepseek_v3.py", 250),
105+
TestFile("test_pp_single_node.py", 150),
97106
],
98107
"nightly": [
99108
TestFile("test_nightly_gsm8k_eval.py"),
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
import os
2+
import subprocess
3+
import time
4+
import unittest
5+
from types import SimpleNamespace
6+
7+
import requests
8+
9+
from sglang.srt.utils import kill_process_tree
10+
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
11+
from sglang.test.test_utils import (
12+
DEFAULT_MODEL_NAME_FOR_TEST_MLA,
13+
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
14+
DEFAULT_URL_FOR_TEST,
15+
CustomTestCase,
16+
popen_launch_pd_server,
17+
run_with_timeout,
18+
)
19+
20+
21+
class TestDisaggregationMooncakeDifferentTP(CustomTestCase):
22+
@classmethod
23+
def setUpClass(cls):
24+
# Temporarily disable JIT DeepGEMM
25+
cls.original_jit_deepgemm = os.environ.get("SGL_ENABLE_JIT_DEEPGEMM")
26+
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false"
27+
28+
cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
29+
cls.base_host = "127.0.0.1"
30+
cls.base_port = int(DEFAULT_URL_FOR_TEST.split(":")[-1])
31+
cls.lb_url = DEFAULT_URL_FOR_TEST
32+
cls.prefill_url = f"http://{cls.base_host}:{cls.base_port + 100}"
33+
cls.decode_url = f"http://{cls.base_host}:{cls.base_port + 200}"
34+
35+
run_with_timeout(cls.start_prefill, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH)
36+
run_with_timeout(cls.start_decode, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH)
37+
38+
cls.wait_server_ready(cls.prefill_url + "/health")
39+
cls.wait_server_ready(cls.decode_url + "/health")
40+
41+
lb_command = [
42+
"python3",
43+
"-m",
44+
"sglang.srt.disaggregation.mini_lb",
45+
"--prefill",
46+
cls.prefill_url,
47+
"--decode",
48+
cls.decode_url,
49+
"--host",
50+
cls.base_host,
51+
"--port",
52+
str(cls.base_port),
53+
]
54+
55+
print("Starting load balancer:", " ".join(lb_command))
56+
cls.process_lb = subprocess.Popen(
57+
lb_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE
58+
)
59+
cls.wait_server_ready(cls.lb_url + "/health")
60+
61+
@classmethod
62+
def start_prefill(cls):
63+
prefill_args = [
64+
"--trust-remote-code",
65+
"--disaggregation-mode",
66+
"prefill",
67+
"--host",
68+
cls.base_host,
69+
"--port",
70+
str(cls.base_port + 100),
71+
"--tp",
72+
"4",
73+
]
74+
cls.process_prefill = popen_launch_pd_server(
75+
cls.model,
76+
cls.prefill_url,
77+
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
78+
other_args=prefill_args,
79+
)
80+
81+
@classmethod
82+
def start_decode(cls):
83+
decode_args = [
84+
"--trust-remote-code",
85+
"--disaggregation-mode",
86+
"decode",
87+
"--host",
88+
cls.base_host,
89+
"--port",
90+
str(cls.base_port + 200),
91+
"--tp",
92+
"2",
93+
"--base-gpu-id",
94+
"4",
95+
]
96+
cls.process_decode = popen_launch_pd_server(
97+
cls.model,
98+
cls.decode_url,
99+
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
100+
other_args=decode_args,
101+
)
102+
103+
@classmethod
104+
def wait_server_ready(cls, url, timeout=60):
105+
start_time = time.time()
106+
while True:
107+
try:
108+
response = requests.get(url)
109+
if response.status_code == 200:
110+
print(f"Server {url} is ready")
111+
return
112+
except Exception:
113+
pass
114+
115+
if time.time() - start_time > timeout:
116+
raise RuntimeError(f"Server {url} failed to start in {timeout}s")
117+
time.sleep(1)
118+
119+
@classmethod
120+
def tearDownClass(cls):
121+
# Restore JIT DeepGEMM environment variable
122+
if cls.original_jit_deepgemm is not None:
123+
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = cls.original_jit_deepgemm
124+
else:
125+
os.environ.pop("SGL_ENABLE_JIT_DEEPGEMM", None)
126+
127+
for process in [cls.process_lb, cls.process_decode, cls.process_prefill]:
128+
if process:
129+
try:
130+
kill_process_tree(process.pid)
131+
except Exception as e:
132+
print(f"Error killing process {process.pid}: {e}")
133+
134+
def test_gsm8k(self):
135+
args = SimpleNamespace(
136+
num_shots=5,
137+
data_path=None,
138+
num_questions=200,
139+
max_new_tokens=512,
140+
parallel=128,
141+
host="http://127.0.0.1",
142+
port=int(self.lb_url.split(":")[-1]),
143+
)
144+
metrics = run_eval_few_shot_gsm8k(args)
145+
print(f"Evaluation metrics: {metrics}")
146+
147+
self.assertGreater(metrics["accuracy"], 0.60)
148+
149+
150+
if __name__ == "__main__":
151+
unittest.main()

0 commit comments

Comments
 (0)