
[PD] Add support for different TP sizes per DP rank #5922

Merged: 16 commits, May 12, 2025
Changes from 1 commit
145 changes: 88 additions & 57 deletions python/sglang/srt/disaggregation/mooncake/conn.py
@@ -141,7 +141,7 @@ def __init__(
self.register_buffer_to_engine()
if self.disaggregation_mode == DisaggregationMode.PREFILL:
self.transfer_queue = queue.Queue()
self.transfer_infos: Dict[int, TransferInfo] = {}
self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {}
self.start_prefill_thread()
self._register_to_bootstrap()
@@ -154,6 +154,7 @@ def __init__(
elif self.disaggregation_mode == DisaggregationMode.DECODE:
self.start_decode_thread()
self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
self.prefill_tp_size_table: Dict[str, int] = {}
self.prefill_dp_size_table: Dict[str, int] = {}
else:
raise ValueError(
@@ -273,8 +274,8 @@ def bootstrap_thread():
while True:
waiting_req_bytes = self.server_socket.recv_multipart()
room = waiting_req_bytes[0].decode("ascii")
mooncake_session_id = waiting_req_bytes[3].decode("ascii")
if room == "None":
mooncake_session_id = waiting_req_bytes[3].decode("ascii")
self.decode_kv_args_table[mooncake_session_id] = (
KVArgsRegisterInfo.from_zmq(waiting_req_bytes)
)
@@ -283,7 +284,11 @@
)
continue
room = int(room)
self.transfer_infos[room] = TransferInfo.from_zmq(waiting_req_bytes)
if room not in self.transfer_infos:
self.transfer_infos[room] = {}
self.transfer_infos[room][mooncake_session_id] = TransferInfo.from_zmq(
waiting_req_bytes
)

# NOTE: after bootstrapping we can mark the req as waiting for input
self.request_status[room] = KVPoll.WaitingForInput
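
The sketch below is illustrative and not part of the PR; it only shows the shape of the new bookkeeping registered above: transfer_infos is now keyed first by room and then by mooncake session id, so one prefill request can track every decode TP rank that registers for it. TransferInfoStub, register, and the sample values are assumptions made up for the example.

from dataclasses import dataclass
from typing import Dict

@dataclass
class TransferInfoStub:
    # Illustrative stand-in for TransferInfo; these field names are assumptions.
    endpoint: str
    dst_port: int

# room -> mooncake_session_id -> transfer info, matching the new
# Dict[int, Dict[str, TransferInfo]] annotation in this diff.
transfer_infos: Dict[int, Dict[str, TransferInfoStub]] = {}

def register(room: int, session_id: str, info: TransferInfoStub) -> None:
    # A single room can now be served by several decode TP ranks, each with
    # its own mooncake session id, so entries are grouped per room.
    transfer_infos.setdefault(room, {})[session_id] = info

register(42, "decode-session-a", TransferInfoStub("10.0.0.2", 17000))
register(42, "decode-session-b", TransferInfoStub("10.0.0.3", 17000))
assert len(transfer_infos[42]) == 2  # two decode ranks attached to room 42
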
@@ -293,42 +298,46 @@ def transfer_thread():
while True:
try:
kv_chunk: TransferKVChunk = self.transfer_queue.get(timeout=0.01)
req = self.transfer_infos[kv_chunk.room]
chunked_dst_kv_indice = req.dst_kv_indices[kv_chunk.index_slice]
assert len(chunked_dst_kv_indice) == len(
kv_chunk.prefill_kv_indices
), f"len(chunked_dst_kv_indice) = {len(chunked_dst_kv_indice)}, len(kv_chunk.prefill_kv_indices) = {len(kv_chunk.prefill_kv_indices)}"

ret = self.send_kvcache(
req.mooncake_session_id,
kv_chunk.prefill_kv_indices,
self.decode_kv_args_table[req.mooncake_session_id].dst_kv_ptrs,
chunked_dst_kv_indice,
)
if ret != 0:
self.request_status[kv_chunk.room] = KVPoll.Failed
self.sync_status_to_decode_endpoint(
req.endpoint, req.dst_port, req.room
)
continue

if kv_chunk.is_last:
# Only the last chunk we need to send the aux data
ret = self.send_aux(
# Note(shangming): might need to assert MLA is used when prefill instances and decode instances have different tp_size_per_dp_rank
reqs_to_be_processed = self.transfer_infos[kv_chunk.room].values()
for req in reqs_to_be_processed:
chunked_dst_kv_indice = req.dst_kv_indices[kv_chunk.index_slice]
assert len(chunked_dst_kv_indice) == len(
kv_chunk.prefill_kv_indices
), f"len(chunked_dst_kv_indice) = {len(chunked_dst_kv_indice)}, len(kv_chunk.prefill_kv_indices) = {len(kv_chunk.prefill_kv_indices)}"

ret = self.send_kvcache(
req.mooncake_session_id,
kv_chunk.prefill_aux_index,
kv_chunk.prefill_kv_indices,
self.decode_kv_args_table[
req.mooncake_session_id
].dst_aux_ptrs,
req.dst_aux_index,
].dst_kv_ptrs,
chunked_dst_kv_indice,
)
self.request_status[req.room] = (
KVPoll.Success if ret == 0 else KVPoll.Failed
)
self.sync_status_to_decode_endpoint(
req.endpoint, req.dst_port, req.room
)
self.transfer_infos.pop(req.room)
if ret != 0:
self.request_status[kv_chunk.room] = KVPoll.Failed
self.sync_status_to_decode_endpoint(
req.endpoint, req.dst_port, req.room
)
continue

if kv_chunk.is_last:
# Only the last chunk we need to send the aux data
ret = self.send_aux(
req.mooncake_session_id,
kv_chunk.prefill_aux_index,
self.decode_kv_args_table[
req.mooncake_session_id
].dst_aux_ptrs,
req.dst_aux_index,
)
self.request_status[req.room] = (
KVPoll.Success if ret == 0 else KVPoll.Failed
)
self.sync_status_to_decode_endpoint(
req.endpoint, req.dst_port, req.room
)
self.transfer_infos.pop(req.room)

except queue.Empty:
continue
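
The reworked transfer_thread above fans the same prefill chunk out to every decode session registered for the room, syncs status per rank, and sends aux data only with the last chunk. A rough, self-contained sketch of that flow under those assumptions follows; fan_out_chunk and the print calls are hypothetical stand-ins for the real send_kvcache / send_aux / status-sync calls.

from typing import Dict, List

def fan_out_chunk(
    transfer_infos: Dict[int, Dict[str, object]],
    room: int,
    prefill_kv_indices: List[int],
    is_last: bool,
) -> None:
    # The same prefill chunk is pushed to every decode session registered
    # for this room; aux data goes out only with the final chunk.
    for session_id in transfer_infos[room]:
        print(f"send {len(prefill_kv_indices)} KV pages to {session_id}")
        if is_last:
            print(f"send aux data to {session_id}")
    if is_last:
        # Once the final chunk has been fanned out, the room entry is dropped.
        transfer_infos.pop(room, None)

fan_out_chunk({7: {"sess-a": None, "sess-b": None}}, 7, [0, 1, 2], is_last=True)
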
@@ -478,40 +487,60 @@ def __init__(
self.session_id = self.kv_mgr.get_session_id()
self.kv_mgr.update_status(bootstrap_room, KVPoll.Bootstrapping)

if not self.kv_mgr.enable_dp_attention:
# We assume dp_attention should be activated simultaneously for
# both prefill role and decode role. If the decode instance does
# not enable dp_attention, then dp_attention is not enabled on the
# prefill instance as well. Therefore, we should skip questioning
# the prefill dp size to reduce bootstrap overhead.
self.prefill_dp_size = 1
elif self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
self.prefill_dp_size, tp_size_per_dp_rank = (
if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
self.prefill_tp_size, self.prefill_dp_size = (
self._get_prefill_dp_size_from_server()
)
# Currently, we don't allow prefill instance and decode instance to
# have different TP sizes per DP rank.
assert tp_size_per_dp_rank == self.kv_mgr.tp_size // self.kv_mgr.dp_size
if self.prefill_dp_size is None:
if self.prefill_tp_size is None or self.prefill_dp_size is None:
logger.error(
f"Could not fetch prefill dp_size for bootstrap_addr: {self.bootstrap_addr}"
f"Could not fetch prefill parallel info for bootstrap_addr: {self.bootstrap_addr}"
)
else:
self.kv_mgr.prefill_tp_size_table[self.bootstrap_addr] = (
self.prefill_tp_size
)
self.kv_mgr.prefill_dp_size_table[self.bootstrap_addr] = (
self.prefill_dp_size
)
else:
self.prefill_tp_size = self.kv_mgr.prefill_tp_size_table[
self.bootstrap_addr
]
self.prefill_dp_size = self.kv_mgr.prefill_dp_size_table[
self.bootstrap_addr
]

# NOTE: key distinguished by bootstrap_addr and engine_rank
# Currently, we don't allow prefill instance and decode instance to
# have different TP sizes per DP rank, except for models using MLA.
local_tp_size_per_dp_rank = self.kv_mgr.tp_size // self.kv_mgr.dp_size
prefill_tp_size_per_dp_rank = self.prefill_tp_size // self.prefill_dp_size
if local_tp_size_per_dp_rank == prefill_tp_size_per_dp_rank:
self.target_tp_rank = (
self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
)
elif local_tp_size_per_dp_rank > prefill_tp_size_per_dp_rank:
self.target_tp_rank = (
self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
) // (local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank)
else:
# NOTE(shangming): local_tp_size_per_dp_rank < prefill_tp_size_per_dp_rank is not supported.
# For non-MLA models, one decode rank needs to retrieve KVCache from multiple prefill ranks for non MLA models;
# For MLA models, we can retrieve KVCache from any prefill rank, but we still need to maintain
# multiple connections in the connection pool and have to send dummy requests to other prefill ranks,
# or the KVPoll will never be set correctly
self.target_tp_rank = None

assert (
self.target_tp_rank is not None
), "decode_tp_size_per_dp_rank < prefill_tp_size_per_dp_rank is not supported yet"
self.target_dp_group = bootstrap_room % self.prefill_dp_size
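# Worked example (illustrative values): prefill tp_size=4, dp_size=2 gives
# prefill_tp_size_per_dp_rank=2; decode tp_size=8, dp_size=2 gives
# local_tp_size_per_dp_rank=4. A decode rank with engine_rank=5 then maps to
# target_tp_rank = (5 % 4) // (4 // 2) = 0, i.e. it pulls KV from prefill TP
# rank 0 of its target DP group.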

# NOTE: key distinguished by bootstrap_addr and engine_rank
bootstrap_key = f"{self.bootstrap_addr}_{self.kv_mgr.kv_args.engine_rank}"

if bootstrap_key not in self.kv_mgr.connection_pool:
self.bootstrap_info = self._get_bootstrap_info_from_server(
self.kv_mgr.kv_args.engine_rank,
self.target_tp_rank,
self.target_dp_group,
)
if self.bootstrap_info is None:
@@ -552,8 +581,8 @@ def _get_prefill_dp_size_from_server(self) -> int:
response = requests.get(url)
if response.status_code == 200:
prefill_parallel_info = response.json()
return int(prefill_parallel_info["prefill_dp_size"]), int(
prefill_parallel_info["tp_size_per_dp_rank"]
return int(prefill_parallel_info["prefill_tp_size"]), int(
prefill_parallel_info["prefill_dp_size"]
)
else:
logger.error(
@@ -633,6 +662,7 @@ def __init__(self, port: int):
self.store = dict()
self.lock = asyncio.Lock()
self._setup_routes()
self.tp_size = None
self.dp_size = None
self.tp_size_per_dp_rank = None
self.prefill_port_table: Dict[int, Dict[int, Dict[str, Union[str, int]]]] = {}
@@ -667,6 +697,9 @@ async def _handle_route_put(self, request: web.Request):
rank_port = int(data["rank_port"])
engine_rank = int(data["engine_rank"])

if self.tp_size is None:
self.tp_size = tp_size

if self.dp_size is None:
self.dp_size = dp_size

@@ -702,17 +735,15 @@ async def _handle_route_get(self, request: web.Request):
# Currently we use engine_rank == -1 and target_dp_group == -1 to sync dp size
if int(engine_rank) == -1 and int(target_dp_group) == -1:
prefill_parallel_info = {
"prefill_tp_size": self.tp_size,
"prefill_dp_size": self.dp_size,
"tp_size_per_dp_rank": self.tp_size_per_dp_rank,
}
return web.json_response(prefill_parallel_info, status=200)

# Find corresponding prefill info
tp_rank_in_dp_group = int(engine_rank) % self.tp_size_per_dp_rank

async with self.lock:
bootstrap_info = self.prefill_port_table[int(target_dp_group)][
tp_rank_in_dp_group
int(engine_rank)
]

if bootstrap_info is not None:
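
For context, here is a hedged sketch of how a decode worker could call the GET route above to sync prefill parallel info. Only the engine_rank == -1 / target_dp_group == -1 convention and the prefill_tp_size / prefill_dp_size response keys come from this diff; the host, port, route path, and query-parameter encoding are assumptions for illustration.

import requests

# Assumed address of the prefill bootstrap server; replace with a real one.
bootstrap_addr = "10.0.0.1:8998"
try:
    resp = requests.get(
        f"http://{bootstrap_addr}/route",  # route path and params are assumptions
        params={"engine_rank": -1, "target_dp_group": -1},  # -1/-1 => sync parallel info
        timeout=5,
    )
    if resp.status_code == 200:
        info = resp.json()
        prefill_tp_size = int(info["prefill_tp_size"])
        prefill_dp_size = int(info["prefill_dp_size"])
        # TP size per DP rank on the prefill side, as derived in the KV receiver above.
        prefill_tp_size_per_dp_rank = prefill_tp_size // prefill_dp_size
        print(prefill_tp_size, prefill_dp_size, prefill_tp_size_per_dp_rank)
except requests.RequestException as exc:
    print(f"bootstrap server unreachable: {exc}")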