Skip to content

Commit 0978fca

Browse files
whybeyoungxwu-intel
authored andcommitted
[PD] support pd fake transfer for warmup (sgl-project#5726)
1 parent 70faced commit 0978fca

File tree

6 files changed

+146
-7
lines changed

6 files changed

+146
-7
lines changed

python/sglang/srt/disaggregation/decode.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVArgs, KVPoll
3333
from sglang.srt.disaggregation.utils import (
3434
DisaggregationMode,
35+
FakeBootstrapHost,
3536
KVClassType,
3637
ReqToMetadataIdxAllocator,
3738
TransferBackend,
@@ -133,8 +134,13 @@ def _init_kv_manager(self) -> BaseKVManager:
133134

134135
def add(self, req: Req) -> None:
135136
"""Add a request to the pending queue."""
136-
137-
kv_receiver_class = get_kv_class(self.transfer_backend, KVClassType.RECEIVER)
137+
if req.bootstrap_host == FakeBootstrapHost:
138+
# Fake transfer for warmup reqs
139+
kv_receiver_class = get_kv_class(TransferBackend.FAKE, KVClassType.RECEIVER)
140+
else:
141+
kv_receiver_class = get_kv_class(
142+
self.transfer_backend, KVClassType.RECEIVER
143+
)
138144
kv_receiver = kv_receiver_class(
139145
mgr=self.kv_manager,
140146
bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .conn import FakeKVReceiver, FakeKVSender
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import logging
2+
from typing import Dict, List, Optional, Tuple, Union
3+
4+
import numpy as np
5+
import numpy.typing as npt
6+
7+
from sglang.srt.disaggregation.base.conn import (
8+
BaseKVManager,
9+
BaseKVReceiver,
10+
BaseKVSender,
11+
KVArgs,
12+
KVPoll,
13+
)
14+
15+
logger = logging.getLogger(__name__)
16+
17+
18+
# For warmup reqs, we don't kv transfer, we use the fake sender and receiver
19+
class FakeKVSender(BaseKVSender):
20+
def __init__(self, mgr: BaseKVManager, bootstrap_addr: str, bootstrap_room: int):
21+
self.has_sent = False
22+
23+
def poll(self) -> KVPoll:
24+
if self.has_sent is False:
25+
# Assume handshake completed instantly
26+
return KVPoll.WaitingForInput
27+
else:
28+
# Assume transfer completed instantly
29+
logger.info("FakeKVSender poll success")
30+
return KVPoll.Success
31+
32+
def init(
33+
self,
34+
kv_indices: list[int],
35+
aux_index: Optional[int] = None,
36+
dest_ranks: Optional[list[int]] = None,
37+
):
38+
logger.info(
39+
f"FakeKVSender init with kv_indices: {kv_indices}, aux_index: {aux_index}, dest_ranks: {dest_ranks}"
40+
)
41+
pass
42+
43+
def send(
44+
self,
45+
kv_indices: npt.NDArray[np.int64],
46+
index_slice: slice,
47+
is_last: bool,
48+
):
49+
logger.info(
50+
f"FakeKVSender send with kv_indices: {kv_indices}, index_slice: {index_slice}, is_last: {is_last}"
51+
)
52+
if is_last:
53+
self.has_sent = True
54+
logger.info(f"FakeKVSender send success")
55+
else:
56+
self.has_sent = False
57+
logger.info(f"FakeKVSender send fake transfering")
58+
59+
def failure_exception(self):
60+
raise Exception("Fake KVSender Exception")
61+
62+
63+
class FakeKVReceiver(BaseKVReceiver):
64+
def __init__(
65+
self,
66+
mgr: BaseKVManager,
67+
bootstrap_addr: str,
68+
bootstrap_room: Optional[int] = None,
69+
):
70+
self.has_init = False
71+
72+
def poll(self) -> KVPoll:
73+
if self.has_init is False:
74+
# Assume handshake completed instantly
75+
return KVPoll.WaitingForInput
76+
else:
77+
# Assume transfer completed instantly
78+
logger.info("FakeKVReceiver poll success")
79+
return KVPoll.Success
80+
81+
def init(self, kv_indices: list[int], aux_index: Optional[int] = None):
82+
self.has_init = True
83+
logger.info(
84+
f"FakeKVReceiver init with kv_indices: {kv_indices}, aux_index: {aux_index}"
85+
)
86+
87+
def failure_exception(self):
88+
raise Exception("Fake KVReceiver Exception")

python/sglang/srt/disaggregation/prefill.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from sglang.srt.disaggregation.base import BaseKVManager, KVArgs, KVPoll
3030
from sglang.srt.disaggregation.utils import (
3131
DisaggregationMode,
32+
FakeBootstrapHost,
3233
KVClassType,
3334
ReqToMetadataIdxAllocator,
3435
TransferBackend,
@@ -116,7 +117,11 @@ def _init_kv_manager(self) -> BaseKVManager:
116117
return kv_manager
117118

118119
def add(self, req: Req) -> None:
119-
kv_sender_class = get_kv_class(self.transfer_backend, KVClassType.SENDER)
120+
if req.bootstrap_host == FakeBootstrapHost:
121+
# Fake transfer for warmup reqs
122+
kv_sender_class = get_kv_class(TransferBackend.FAKE, KVClassType.SENDER)
123+
else:
124+
kv_sender_class = get_kv_class(self.transfer_backend, KVClassType.SENDER)
120125
req.disagg_kv_sender = kv_sender_class(
121126
mgr=self.kv_manager,
122127
bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}",

python/sglang/srt/disaggregation/utils.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ class DisaggregationMode(Enum):
1515
DECODE = "decode"
1616

1717

18+
FakeBootstrapHost = "2.2.2.2"
19+
20+
1821
def poll_and_all_reduce(pollers, gloo_group):
1922
polls = [int(poller.poll()) for poller in pollers]
2023
tensor_to_reduce = torch.tensor(polls, dtype=torch.uint8, device="cpu")
@@ -59,6 +62,8 @@ class KVClassType(Enum):
5962

6063

6164
def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
65+
from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender
66+
6267
if transfer_backend == TransferBackend.MOONCAKE:
6368
from sglang.srt.disaggregation.mooncake import (
6469
MooncakeKVBootstrapServer,
@@ -70,7 +75,7 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
7075
class_mapping = {
7176
KVClassType.MANAGER: MooncakeKVManager,
7277
KVClassType.SENDER: MooncakeKVSender,
73-
KVClassType.RECEIVER: MooncakeKVReceiver,
78+
KVClassType.RECEIVER: (MooncakeKVReceiver),
7479
KVClassType.BOOTSTRAP_SERVER: MooncakeKVBootstrapServer,
7580
}
7681
return class_mapping.get(class_type)
@@ -85,10 +90,19 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
8590
class_mapping = {
8691
KVClassType.MANAGER: NixlKVManager,
8792
KVClassType.SENDER: NixlKVSender,
88-
KVClassType.RECEIVER: NixlKVReceiver,
93+
KVClassType.RECEIVER: (NixlKVReceiver),
8994
KVClassType.BOOTSTRAP_SERVER: NixlKVBootstrapServer,
9095
}
9196
return class_mapping.get(class_type)
97+
if transfer_backend == TransferBackend.FAKE:
98+
from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender
99+
100+
class_mapping = {
101+
KVClassType.SENDER: FakeKVSender,
102+
KVClassType.RECEIVER: (FakeKVReceiver),
103+
}
104+
return class_mapping.get(class_type)
105+
92106
raise ValueError(f"Unsupported transfer backend: {transfer_backend}")
93107

94108

python/sglang/srt/entrypoints/http_server.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
from fastapi.middleware.cors import CORSMiddleware
4343
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
4444

45+
from sglang.srt.disaggregation.utils import FakeBootstrapHost
4546
from sglang.srt.entrypoints.engine import _launch_subprocesses
4647
from sglang.srt.function_call_parser import FunctionCallParser
4748
from sglang.srt.managers.io_struct import (
@@ -821,8 +822,32 @@ def _wait_and_warmup(
821822
)
822823
assert res.status_code == 200, f"{res}"
823824
else:
824-
# Warmup request currently hangs in disaggregation mode, so we skip it.
825-
logger.info("Skipping warmup request in disaggregation mode")
825+
logger.info(f"Start of prefill warmup ...")
826+
json_data = {
827+
"sampling_params": {
828+
"temperature": 0.0,
829+
"max_new_tokens": 8,
830+
"ignore_eos": True,
831+
},
832+
"bootstrap_host": [FakeBootstrapHost] * server_args.dp_size,
833+
# This is a hack to ensure fake transfer is enabled during prefill warmup
834+
# ensure each dp rank has a unique bootstrap_room during prefill warmup
835+
"bootstrap_room": [
836+
i * (2**63 // server_args.dp_size) + (i % server_args.tp_size)
837+
for i in range(server_args.dp_size)
838+
],
839+
"input_ids": [[0, 1, 2, 3]] * server_args.dp_size,
840+
}
841+
res = requests.post(
842+
url + request_name,
843+
json=json_data,
844+
headers=headers,
845+
timeout=1800, # because of deep gemm precache is very long if not precache.
846+
)
847+
logger.info(
848+
f"End of prefill warmup with status {res.status_code}, resp: {res.json()}"
849+
)
850+
826851
except Exception:
827852
last_traceback = get_exception_traceback()
828853
if pipe_finish_writer is not None:

0 commit comments

Comments
 (0)