Skip to content

Commit c96c1b0

Browse files
committed
NIXL DP support (sgl-project#5681)
1 parent 6f9eea5 commit c96c1b0

File tree

1 file changed

+160
-47
lines changed
  • python/sglang/srt/disaggregation/nixl

1 file changed

+160
-47
lines changed

python/sglang/srt/disaggregation/nixl/conn.py

Lines changed: 160 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import uuid
1111
from collections import defaultdict
1212
from functools import cache
13-
from typing import Dict, List, Optional, Tuple, Union
13+
from typing import Dict, List, Optional, Set, Tuple, TypeAlias, Union
1414

1515
import numpy as np
1616
import numpy.typing as npt
@@ -32,6 +32,8 @@
3232

3333
logger = logging.getLogger(__name__)
3434

35+
NixlEngineInfo: TypeAlias = Dict[str, Union[str, int]]
36+
3537

3638
@dataclasses.dataclass
3739
class TransferInfo:
@@ -45,19 +47,36 @@ class TransferInfo:
4547
dst_aux_index: int
4648
dst_gpu_id: int
4749

50+
def is_dummy(self):
    """Tell whether this transfer info is a placeholder without a real peer.

    Placeholders are produced by ``from_zmq`` for single-frame messages,
    which leave ``endpoint`` as the empty string; that empty endpoint is
    the only thing tested here.
    """
    endpoint_is_blank = self.endpoint == ""
    return endpoint_is_blank
52+
4853
@classmethod
def from_zmq(cls, msg: List[bytes]):
    """Build an instance from the frames of a ZMQ multipart message.

    A message with exactly one frame carries only the bootstrap room and
    yields a "dummy" instance: empty endpoint, zero port/indices, and empty
    pointer lists. Any other message is read positionally as nine frames:
    room, endpoint, dst_port, agent_metadata, packed KV pointers, KV indices
    (int64 buffer), packed aux pointers, aux index, and destination GPU id.

    Raises:
        IndexError: if the message is empty or has between 2 and 8 frames.
    """

    def ascii_int(frame: bytes) -> int:
        # Numeric frames are sent as ASCII decimal text.
        return int(frame.decode("ascii"))

    def u64_list(frame: bytes) -> List[int]:
        # Pointer frames are a packed run of unsigned 64-bit values.
        return list(struct.unpack(f"{len(frame)//8}Q", frame))

    room = ascii_int(msg[0])

    if len(msg) == 1:
        # Dummy message: only the room is known.
        return cls(
            room=room,
            endpoint="",
            dst_port=0,
            agent_metadata=b"",
            dst_kv_ptrs=[],
            dst_kv_indices=np.array([], dtype=np.int64),
            dst_aux_ptrs=[],
            dst_aux_index=0,
            dst_gpu_id=0,
        )

    # Decode in frame order so a malformed message fails on the same
    # frame it always did.
    endpoint = msg[1].decode("ascii")
    dst_port = ascii_int(msg[2])
    agent_metadata = msg[3]
    kv_ptrs = u64_list(msg[4])
    kv_indices = np.frombuffer(msg[5], dtype=np.int64)
    aux_ptrs = u64_list(msg[6])
    aux_index = ascii_int(msg[7])
    gpu_id = ascii_int(msg[8])

    return cls(
        room=room,
        endpoint=endpoint,
        dst_port=dst_port,
        agent_metadata=agent_metadata,
        dst_kv_ptrs=kv_ptrs,
        dst_kv_indices=kv_indices,
        dst_aux_ptrs=aux_ptrs,
        dst_aux_index=aux_index,
        dst_gpu_id=gpu_id,
    )
6180

6281

6382
@dataclasses.dataclass
@@ -98,6 +117,19 @@ def __init__(
98117
# for p/d multi node infer
99118
self.bootstrap_port = server_args.disaggregation_bootstrap_port
100119
self.dist_init_addr = server_args.dist_init_addr
120+
self.tp_size = server_args.tp_size
121+
122+
self.tp_rank = args.engine_rank
123+
self.enable_dp_attention = server_args.enable_dp_attention
124+
if self.enable_dp_attention:
125+
assert (
126+
server_args.dp_size > 1
127+
), "If dp_attention is enabled, dp size must be greater than 1 in disaggregation mode."
128+
self.dp_size = server_args.dp_size
129+
self.tp_size_of_dp = server_args.tp_size // server_args.dp_size
130+
self.attn_tp_rank = args.engine_rank % self.tp_size_of_dp
131+
self.dp_rank = args.engine_rank // self.tp_size_of_dp
132+
101133
self.rank_port = None
102134
self.server_socket = zmq.Context().socket(zmq.PULL)
103135
self.register_buffer_to_engine()
@@ -110,7 +142,10 @@ def __init__(
110142
self._start_bootstrap_thread()
111143
self._register_to_bootstrap()
112144
elif self.disaggregation_mode == DisaggregationMode.DECODE:
113-
self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
145+
# bootstrap key -> (engine_rank -> real source remote, engine_rank -> dummy remote)
146+
self.prefill_peer_infos: Dict[
147+
str, Tuple[Dict[int, NixlEngineInfo], Dict[int, NixlEngineInfo]]
148+
] = {}
114149
self.transfer_statuses: Dict[int, TransferStatus] = defaultdict(
115150
TransferStatus
116151
)
@@ -180,7 +215,7 @@ def send_kvcache(
180215
src_descs,
181216
dst_descs,
182217
peer_name,
183-
notif.encode("ascii"),
218+
notif.encode("ascii"), # type: ignore
184219
)
185220
if not xfer_handle:
186221
raise Exception("KVSender failed to create transfer")
@@ -213,7 +248,7 @@ def send_aux(
213248
src_descs,
214249
dst_descs,
215250
peer_name,
216-
notif.encode("ascii"),
251+
notif.encode("ascii"), # type: ignore
217252
)
218253
if not xfer_handle:
219254
raise Exception("KVSender failed to create transfer")
@@ -240,6 +275,9 @@ def add_transfer_request(
240275
req = self.transfer_infos[bootstrap_room]
241276
assert bootstrap_room == req.room
242277

278+
if req.is_dummy():
279+
return []
280+
243281
peer_name = self._add_remote(bootstrap_room, req.agent_metadata)
244282
chunked_dst_kv_indice = req.dst_kv_indices[index_slice]
245283
assert len(chunked_dst_kv_indice) == len(kv_indices)
@@ -256,6 +294,7 @@ def add_transfer_request(
256294
handles = [kv_xfer_handle]
257295
# Only the last chunk we need to send the aux data.
258296
if is_last:
297+
assert aux_index is not None
259298
aux_xfer_handle = self.send_aux(
260299
peer_name,
261300
aux_index,
@@ -372,14 +411,13 @@ def send(
372411

373412
def poll(self) -> KVPoll:
    """Report the sender-side transfer state.

    Returns ``KVPoll.WaitingForInput`` until something has been sent, and
    afterwards while any transfer handle is still in flight. Returns
    ``KVPoll.Success`` once every handle reports ``"DONE"`` (this includes
    the case where there are no handles at all).

    Raises:
        Exception: if any handle reports ``"ERR"``.
    """
    if not self.has_sent:
        return KVPoll.WaitingForInput  # type: ignore

    # Query every handle up front, then inspect the collected states.
    check = self.kv_mgr.agent.check_xfer_state
    handle_states = [check(handle) for handle in self.xfer_handles]

    everything_done = all(state == "DONE" for state in handle_states)
    if everything_done:
        return KVPoll.Success  # type: ignore

    for state in handle_states:
        if state == "ERR":
            raise Exception("KVSender transfer encountered an error.")

    return KVPoll.WaitingForInput  # type: ignore
383421

384422
def failure_exception(self):
385423
raise Exception("Fake KVSender Exception")
@@ -401,7 +439,7 @@ def __init__(
401439
# NOTE: key distinguished by bootstrap_addr and engine_rank
402440
bootstrap_key = f"{self.bootstrap_addr}_{self.kv_mgr.kv_args.engine_rank}"
403441

404-
if bootstrap_key not in self.kv_mgr.connection_pool:
442+
if bootstrap_key not in self.kv_mgr.prefill_peer_infos:
405443
self.bootstrap_info = self._get_bootstrap_info_from_server(
406444
self.kv_mgr.kv_args.engine_rank
407445
)
@@ -410,25 +448,76 @@ def __init__(
410448
f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank}"
411449
)
412450
else:
413-
self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_info
451+
self.kv_mgr.prefill_peer_infos[bootstrap_key] = self.bootstrap_info
414452
else:
415-
self.bootstrap_info = self.kv_mgr.connection_pool[bootstrap_key]
416-
453+
self.bootstrap_info = self.kv_mgr.prefill_peer_infos[bootstrap_key]
417454
assert self.bootstrap_info is not None
418455

419-
def _get_bootstrap_info_from_server(
    self, engine_rank
) -> Optional[Tuple[Dict[int, NixlEngineInfo], Dict[int, NixlEngineInfo]]]:
    """Fetch the bootstrap info from the bootstrap server.

    Returns a pair ``(sources, dummies)`` of ``{prefill rank: engine info}``
    dicts, or ``None`` when the server answers with a non-200 status or any
    exception is raised (the failure is logged in both cases).

    * Without DP attention, only ``engine_rank`` is queried; it becomes the
      single source and there are no dummies.
    * With DP attention, the whole route table is fetched. The prefill
      ranks are split into ``tp_size_of_dp`` consecutive windows of
      ``prefill_tp_size // tp_size_of_dp`` ranks each, and this rank owns
      window number ``attn_tp_rank``. The first rank of the window is the
      real source; the remaining ranks of the window are dummies.
    """
    try:
        if self.kv_mgr.enable_dp_attention:
            url = f"http://{self.bootstrap_addr}/route"
            response = requests.get(url)
            if response.status_code != 200:
                logger.error(
                    f"Failed to get prefill server info: {response.status_code}, {response.text}"
                )
                return None

            bootstrap_info = response.json()
            assert isinstance(bootstrap_info, dict)
            # JSON object keys arrive as strings; ranks are used as ints.
            bootstrap_info = {int(k): v for k, v in bootstrap_info.items()}

            # split out who need to send to this rank.
            # currently for dpsk mla model, those ranks share the same latent cache.
            # pick one as the real source

            prefill_tp_size = len(bootstrap_info)
            tp_size_of_dp = self.kv_mgr.tp_size_of_dp

            assert (
                prefill_tp_size >= tp_size_of_dp
            ), f"Only support Prefill TP size >= Decode TP size of DP, now we have {prefill_tp_size} vs {self.kv_mgr.tp_size_of_dp}"

            num_remote_tp_rank_we_managed = prefill_tp_size // tp_size_of_dp

            # We handle [num * attn_tp_rank, num * attn_tp_rank + num).
            # The window start must advance by the window length; stepping
            # by tp_size_of_dp made windows overlap (or ran out of windows)
            # whenever the two values differed.
            first_rank = num_remote_tp_rank_we_managed * self.kv_mgr.attn_tp_rank
            managed_ranks = list(
                range(first_rank, first_rank + num_remote_tp_rank_we_managed)
            )
            picked_rank = managed_ranks[0]

            logger.debug(
                f"Rank {self.kv_mgr.kv_args.engine_rank} managed {managed_ranks}, picked {picked_rank} as real source"
            )

            sources = {picked_rank: bootstrap_info[picked_rank]}
            # Managed ranks missing from the table are skipped.
            dummies = {
                rk: bootstrap_info[rk]
                for rk in managed_ranks[1:]
                if rk in bootstrap_info
            }
            return sources, dummies
        else:
            url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}"
            response = requests.get(url)
            if response.status_code == 200:
                bootstrap_info = response.json()
                return {engine_rank: bootstrap_info}, {}
            else:
                logger.error(
                    f"Failed to get prefill server info: {response.status_code}, {response.text}"
                )
                return None
    except Exception as e:
        # Deliberately broad: any failure here is reported to the caller as
        # "no bootstrap info" rather than propagated.
        logger.error(f"Error fetching prefill info from bootstrap: {e}")
        return None
@@ -440,11 +529,20 @@ def _connect(self, endpoint: str):
440529
return socket
441530

442531
def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
443-
self.prefill_server_url = (
444-
f"{self.bootstrap_info['rank_ip']}:{self.bootstrap_info['rank_port']}"
445-
)
532+
533+
assert self.bootstrap_info is not None
534+
535+
sources = self.bootstrap_info[0]
536+
dummy = self.bootstrap_info[1]
537+
538+
assert len(sources) == 1, "Only support one source now"
539+
540+
remote_rank = list(self.bootstrap_info[0].keys())[0]
541+
542+
self.prefill_server_url = f"{self.bootstrap_info[0][remote_rank]['rank_ip']}:{self.bootstrap_info[0][remote_rank]['rank_port']}"
543+
446544
logger.debug(
447-
f"Fetched bootstrap info: {self.bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
545+
f"Fetched bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank}, source: {self.bootstrap_info[0].keys()}, dummy: {self.bootstrap_info[1].keys()} "
448546
)
449547

450548
packed_kv_data_ptrs = b"".join(
@@ -466,17 +564,26 @@ def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = Non
466564
str(self.kv_mgr.kv_args.gpu_id).encode("ascii"),
467565
]
468566
)
567+
568+
for dummy_rank, dummy_info in dummy.items():
569+
dummy_url = f"{dummy_info['rank_ip']}:{dummy_info['rank_port']}"
570+
self._connect("tcp://" + dummy_url).send_multipart(
571+
[
572+
str(self.bootstrap_room).encode("ascii"),
573+
]
574+
)
575+
469576
self.started_transfer = True
470577

471578
def poll(self) -> KVPoll:
    """Report the receiver-side transfer state.

    Before ``started_transfer`` is set, returns ``KVPoll.WaitingForInput``
    without touching the manager. Afterwards it asks the KV manager to
    refresh its transfer statuses and returns ``KVPoll.Success`` if the
    manager reports this bootstrap room as done, else
    ``KVPoll.WaitingForInput``.
    """
    if not self.started_transfer:
        return KVPoll.WaitingForInput  # type: ignore

    manager = self.kv_mgr
    manager.update_transfer_status()

    room_done = manager.check_transfer_done(self.bootstrap_room)  # type: ignore
    return KVPoll.Success if room_done else KVPoll.WaitingForInput  # type: ignore
480587

481588
def failure_exception(self):
482589
raise Exception("Fake KVReceiver Exception")
@@ -564,13 +671,13 @@ async def _handle_route_put(self, request: web.Request):
564671
engine_rank = int(data["engine_rank"])
565672
agent_name = data["agent_name"]
566673

567-
# Add lock to make sure thread-safe
568674
if role == "Prefill":
569-
self.prefill_port_table[engine_rank] = {
570-
"rank_ip": rank_ip,
571-
"rank_port": rank_port,
572-
"agent_name": agent_name,
573-
}
675+
async with self.lock:
676+
self.prefill_port_table[engine_rank] = {
677+
"rank_ip": rank_ip,
678+
"rank_port": rank_port,
679+
"agent_name": agent_name,
680+
}
574681
logger.info(
575682
f"Registered Prefill boostrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port} and name: {agent_name}"
576683
)
@@ -580,7 +687,13 @@ async def _handle_route_put(self, request: web.Request):
580687
async def _handle_route_get(self, request: web.Request):
581688
engine_rank = request.query.get("engine_rank")
582689
if not engine_rank:
583-
return web.Response(text="Missing rank", status=400)
690+
logger.debug(
691+
f"No engine_rank specified, return all {len(self.prefill_port_table)} engine infos as a dict"
692+
)
693+
# Return a dict of all engine_rank
694+
async with self.lock:
695+
bootstrap_info = self.prefill_port_table
696+
return web.json_response(bootstrap_info, status=200)
584697

585698
# Find corresponding prefill info
586699
async with self.lock:

0 commit comments

Comments
 (0)