Skip to content

Commit ab73883

Browse files
committed
[PD] support pd warmup ...
1 parent 7e94424 commit ab73883

File tree

6 files changed

+157
-10
lines changed

6 files changed

+157
-10
lines changed

python/sglang/srt/disaggregation/decode.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVArgs, KVPoll
3333
from sglang.srt.disaggregation.utils import (
3434
DisaggregationMode,
35+
FakeBootstrapHost,
3536
KVClassType,
3637
ReqToMetadataIdxAllocator,
3738
TransferBackend,
@@ -133,8 +134,15 @@ def _init_kv_manager(self) -> BaseKVManager:
133134

134135
def add(self, req: Req) -> None:
135136
"""Add a request to the pending queue."""
136-
137-
kv_receiver_class = get_kv_class(self.transfer_backend, KVClassType.RECEIVER)
137+
if req.bootstrap_host == FakeBootstrapHost:
138+
# Fake transfer for warmup reqs
139+
kv_receiver_class = get_kv_class(
140+
self.transfer_backend, KVClassType.RECEIVER, fake_transfer=True
141+
)
142+
else:
143+
kv_receiver_class = get_kv_class(
144+
self.transfer_backend, KVClassType.RECEIVER
145+
)
138146
kv_receiver = kv_receiver_class(
139147
mgr=self.kv_manager,
140148
bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .conn import FakeKVReceiver, FakeKVSender
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import logging
2+
from typing import Dict, List, Optional, Tuple, Union
3+
4+
import numpy as np
5+
import numpy.typing as npt
6+
7+
from sglang.srt.disaggregation.base.conn import (
8+
BaseKVManager,
9+
BaseKVReceiver,
10+
BaseKVSender,
11+
KVArgs,
12+
KVPoll,
13+
)
14+
15+
logger = logging.getLogger(__name__)
16+
17+
18+
# Warmup requests skip the real KV transfer; they use the fake sender and receiver instead.
19+
class FakeKVSender(BaseKVSender):
    """Stand-in KV sender for warmup requests; no KV data is transferred.

    ``poll`` reports ``KVPoll.WaitingForInput`` until ``send`` has been called
    with ``is_last=True``, and ``KVPoll.Success`` afterwards.
    """

    def __init__(self, mgr: BaseKVManager, bootstrap_addr: str, bootstrap_room: int):
        # The arguments are accepted but unused: nothing is actually sent.
        self.has_sent = False

    def poll(self) -> KVPoll:
        """Return the simulated transfer state."""
        if not self.has_sent:
            # Assume handshake completed instantly
            return KVPoll.WaitingForInput
        # Assume transfer completed instantly
        logger.info("FakeKVSender poll success")
        return KVPoll.Success

    def init(
        self,
        kv_indices: list[int],
        aux_index: Optional[int] = None,
        dest_ranks: Optional[list[int]] = None,
    ):
        """Log the arguments; the fake sender has nothing to initialize."""
        # %-style arguments defer formatting until the record is emitted.
        logger.info(
            "FakeKVSender init with kv_indices: %s, aux_index: %s, dest_ranks: %s",
            kv_indices,
            aux_index,
            dest_ranks,
        )

    def send(
        self,
        kv_indices: npt.NDArray[np.int64],
        index_slice: slice,
        is_last: bool,
    ):
        """Pretend to send one chunk; only ``is_last`` affects the state.

        After a call with ``is_last=True`` the next ``poll`` returns
        ``KVPoll.Success``; a call with ``is_last=False`` resets the sender to
        the not-yet-sent state.
        """
        logger.info(
            "FakeKVSender send with kv_indices: %s, index_slice: %s, is_last: %s",
            kv_indices,
            index_slice,
            is_last,
        )
        if is_last:
            self.has_sent = True
            logger.info("FakeKVSender send success")
        else:
            self.has_sent = False
            logger.info("FakeKVSender send fake transferring")

    def failure_exception(self):
        """Raise the generic exception that represents a fake-sender failure."""
        raise Exception("Fake KVSender Exception")
61+
62+
63+
class FakeKVReceiver(BaseKVReceiver):
    """Stand-in KV receiver for warmup requests; no KV data is received.

    ``poll`` reports ``KVPoll.WaitingForInput`` until ``init`` has been
    called, and ``KVPoll.Success`` afterwards.
    """

    def __init__(
        self,
        mgr: BaseKVManager,
        bootstrap_addr: str,
        bootstrap_room: Optional[int] = None,
    ):
        # The arguments are accepted but unused: nothing is actually received.
        self.has_init = False

    def poll(self) -> KVPoll:
        """Return the simulated transfer state."""
        if not self.has_init:
            # Assume handshake completed instantly
            return KVPoll.WaitingForInput
        # Assume transfer completed instantly
        logger.info("FakeKVReceiver poll success")
        return KVPoll.Success

    def init(self, kv_indices: list[int], aux_index: Optional[int] = None):
        """Mark the receiver as initialized so later polls report success."""
        self.has_init = True
        # %-style arguments defer formatting until the record is emitted.
        logger.info(
            "FakeKVReceiver init with kv_indices: %s, aux_index: %s",
            kv_indices,
            aux_index,
        )

    def failure_exception(self):
        """Raise the generic exception that represents a fake-receiver failure."""
        raise Exception("Fake KVReceiver Exception")

python/sglang/srt/disaggregation/prefill.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from sglang.srt.disaggregation.base import BaseKVManager, KVArgs, KVPoll
2929
from sglang.srt.disaggregation.utils import (
3030
DisaggregationMode,
31+
FakeBootstrapHost,
3132
KVClassType,
3233
ReqToMetadataIdxAllocator,
3334
TransferBackend,
@@ -115,7 +116,13 @@ def _init_kv_manager(self) -> BaseKVManager:
115116
return kv_manager
116117

117118
def add(self, req: Req) -> None:
118-
kv_sender_class = get_kv_class(self.transfer_backend, KVClassType.SENDER)
119+
if req.bootstrap_host == FakeBootstrapHost:
120+
# Fake transfer for warmup reqs
121+
kv_sender_class = get_kv_class(
122+
self.transfer_backend, KVClassType.SENDER, fake_transfer=True
123+
)
124+
else:
125+
kv_sender_class = get_kv_class(self.transfer_backend, KVClassType.SENDER)
119126
req.disagg_kv_sender = kv_sender_class(
120127
mgr=self.kv_manager,
121128
bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}",

python/sglang/srt/disaggregation/utils.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ class DisaggregationMode(Enum):
1515
DECODE = "decode"
1616

1717

18+
# Sentinel bootstrap host that marks warmup requests: add() in prefill.py and
# decode.py selects the fake KV sender/receiver when req.bootstrap_host equals it.
FakeBootstrapHost = "2.2.2.2"
19+
20+
1821
def poll_and_all_reduce(pollers, gloo_group):
1922
polls = [int(poller.poll()) for poller in pollers]
2023
tensor_to_reduce = torch.tensor(polls, dtype=torch.uint8, device="cpu")
@@ -58,7 +61,13 @@ class KVClassType(Enum):
5861
BOOTSTRAP_SERVER = "bootstrap_server"
5962

6063

61-
def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
64+
def get_kv_class(
65+
transfer_backend: TransferBackend,
66+
class_type: KVClassType,
67+
fake_transfer: bool = False,
68+
):
69+
from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender
70+
6271
if transfer_backend == TransferBackend.MOONCAKE:
6372
from sglang.srt.disaggregation.mooncake import (
6473
MooncakeKVBootstrapServer,
@@ -69,8 +78,10 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
6978

7079
class_mapping = {
7180
KVClassType.MANAGER: MooncakeKVManager,
72-
KVClassType.SENDER: MooncakeKVSender,
73-
KVClassType.RECEIVER: MooncakeKVReceiver,
81+
KVClassType.SENDER: MooncakeKVSender if not fake_transfer else FakeKVSender,
82+
KVClassType.RECEIVER: (
83+
MooncakeKVReceiver if not fake_transfer else FakeKVReceiver
84+
),
7485
KVClassType.BOOTSTRAP_SERVER: MooncakeKVBootstrapServer,
7586
}
7687
return class_mapping.get(class_type)
@@ -84,8 +95,10 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
8495

8596
class_mapping = {
8697
KVClassType.MANAGER: NixlKVManager,
87-
KVClassType.SENDER: NixlKVSender,
88-
KVClassType.RECEIVER: NixlKVReceiver,
98+
KVClassType.SENDER: NixlKVSender if not fake_transfer else FakeKVSender,
99+
KVClassType.RECEIVER: (
100+
NixlKVReceiver if not fake_transfer else FakeKVReceiver
101+
),
89102
KVClassType.BOOTSTRAP_SERVER: NixlKVBootstrapServer,
90103
}
91104
return class_mapping.get(class_type)

python/sglang/srt/entrypoints/http_server.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
from fastapi.middleware.cors import CORSMiddleware
4343
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
4444

45+
from sglang.srt.disaggregation.utils import FakeBootstrapHost
4546
from sglang.srt.entrypoints.engine import _launch_subprocesses
4647
from sglang.srt.function_call_parser import FunctionCallParser
4748
from sglang.srt.managers.io_struct import (
@@ -814,9 +815,38 @@ def _wait_and_warmup(
814815
timeout=600,
815816
)
816817
assert res.status_code == 200, f"{res}"
818+
elif server_args.disaggregation_mode == "prefill":
819+
logger.info(f"Start of prefill warmup ...")
820+
json_data = {
821+
"sampling_params": {
822+
"temperature": 0.0,
823+
"max_new_tokens": 8,
824+
"ignore_eos": True,
825+
},
826+
"bootstrap_host": [FakeBootstrapHost] * server_args.dp_size,
827+
# This is a hack to ensure fake transfer is enabled during prefill warmup
828+
# ensure each dp rank has a unique bootstrap_room during prefill warmup
829+
"bootstrap_room": [
830+
i * (2**63 // server_args.dp_size) + (i % server_args.tp_size)
831+
for i in range(server_args.dp_size)
832+
],
833+
"input_ids": [[0, 1, 2, 3]] * server_args.dp_size,
834+
}
835+
res = requests.post(
836+
url + request_name,
837+
json=json_data,
838+
headers=headers,
839+
timeout=1800, # because of deep gemm precache is very long if not precache.
840+
)
841+
logger.info(
842+
f"End of prefill warmup with status {res.status_code}, resp: {res.json()}"
843+
)
817844
else:
818-
# Warmup request currently hangs in disaggregation mode, so we skip it.
819-
logger.info("Skipping warmup request in disaggregation mode")
845+
logger.info(
846+
"Skipping warmup request in mode {}".format(
847+
server_args.disaggregation_mode
848+
)
849+
)
820850
except Exception:
821851
last_traceback = get_exception_traceback()
822852
if pipe_finish_writer is not None:

0 commit comments

Comments
 (0)