diff --git a/python/sglang/srt/disaggregation/nixl/conn.py b/python/sglang/srt/disaggregation/nixl/conn.py
index 080babb44b..745e0c26fb 100644
--- a/python/sglang/srt/disaggregation/nixl/conn.py
+++ b/python/sglang/srt/disaggregation/nixl/conn.py
@@ -10,7 +10,7 @@
 import uuid
 from collections import defaultdict
 from functools import cache
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Set, Tuple, TypeAlias, Union
 
 import numpy as np
 import numpy.typing as npt
@@ -32,6 +32,38 @@
 logger = logging.getLogger(__name__)
 
 
+NixlEngineInfo: TypeAlias = Dict[str, Union[str, int]]
+
+
+# From Mooncake backend.
+def group_concurrent_contiguous(
+    src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
+) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
+    src_groups = []
+    dst_groups = []
+    current_src = [src_indices[0]]
+    current_dst = [dst_indices[0]]
+
+    for i in range(1, len(src_indices)):
+        src_contiguous = src_indices[i] == src_indices[i - 1] + 1
+        dst_contiguous = dst_indices[i] == dst_indices[i - 1] + 1
+        if src_contiguous and dst_contiguous:
+            current_src.append(src_indices[i])
+            current_dst.append(dst_indices[i])
+        else:
+            src_groups.append(current_src)
+            dst_groups.append(current_dst)
+            current_src = [src_indices[i]]
+            current_dst = [dst_indices[i]]
+
+    src_groups.append(current_src)
+    dst_groups.append(current_dst)
+
+    return src_groups, dst_groups
+
+
+GUARD = "NixlMsgGuard".encode("ascii")
+
 
 @dataclasses.dataclass
 class TransferInfo:
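`group_concurrent_contiguous` coalesces runs of page indices that are contiguous on both the source and the destination side, so one NIXL descriptor can later cover a whole run instead of one descriptor per page. A minimal sketch of its behavior, using the helper as defined in the hunk above with made-up indices:

```python
import numpy as np

# Pages 7-9 map onto 0-2 contiguously on both sides, so they form one
# block; the remaining pages break the runs and become blocks of one.
src = np.array([7, 8, 9, 20, 40], dtype=np.int64)
dst = np.array([0, 1, 2, 3, 9], dtype=np.int64)

src_groups, dst_groups = group_concurrent_contiguous(src, dst)
# src_groups -> [[7, 8, 9], [20], [40]]
# dst_groups -> [[0, 1, 2], [3], [9]]
```

Each resulting group becomes a single `(addr, length, device)` descriptor with `length = item_len * len(group)` in `send_kvcache` below.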
@@ -45,19 +77,36 @@ class TransferInfo:
     dst_aux_index: int
     dst_gpu_id: int
 
+    def is_dummy(self):
+        return self.endpoint == ""
+
     @classmethod
     def from_zmq(cls, msg: List[bytes]):
-        return cls(
-            room=int(msg[0].decode("ascii")),
-            endpoint=msg[1].decode("ascii"),
-            dst_port=int(msg[2].decode("ascii")),
-            agent_metadata=msg[3],
-            dst_kv_ptrs=list(struct.unpack(f"{len(msg[4])//8}Q", msg[4])),
-            dst_kv_indices=np.frombuffer(msg[5], dtype=np.int64),
-            dst_aux_ptrs=list(struct.unpack(f"{len(msg[6])//8}Q", msg[6])),
-            dst_aux_index=int(msg[7].decode("ascii")),
-            dst_gpu_id=int(msg[8].decode("ascii")),
-        )
+        if len(msg) == 1:
+            # dummy msg
+            return cls(
+                room=int(msg[0].decode("ascii")),
+                endpoint="",
+                dst_port=0,
+                agent_metadata=b"",
+                dst_kv_ptrs=[],
+                dst_kv_indices=np.array([], dtype=np.int64),
+                dst_aux_ptrs=[],
+                dst_aux_index=0,
+                dst_gpu_id=0,
+            )
+        else:
+            return cls(
+                room=int(msg[0].decode("ascii")),
+                endpoint=msg[1].decode("ascii"),
+                dst_port=int(msg[2].decode("ascii")),
+                agent_metadata=msg[3],
+                dst_kv_ptrs=list(struct.unpack(f"{len(msg[4])//8}Q", msg[4])),
+                dst_kv_indices=np.frombuffer(msg[5], dtype=np.int64),
+                dst_aux_ptrs=list(struct.unpack(f"{len(msg[6])//8}Q", msg[6])),
+                dst_aux_index=int(msg[7].decode("ascii")),
+                dst_gpu_id=int(msg[8].decode("ascii")),
+            )
 
 
 @dataclasses.dataclass
@@ -98,6 +147,19 @@ def __init__(
         # for p/d multi node infer
         self.bootstrap_port = server_args.disaggregation_bootstrap_port
         self.dist_init_addr = server_args.dist_init_addr
+        self.tp_size = server_args.tp_size
+
+        self.tp_rank = args.engine_rank
+        self.enable_dp_attention = server_args.enable_dp_attention
+        if self.enable_dp_attention:
+            assert (
+                server_args.dp_size > 1
+            ), "If dp_attention is enabled, dp size must be greater than 1 in disaggregation mode."
+            self.dp_size = server_args.dp_size
+            self.tp_size_of_dp = server_args.tp_size // server_args.dp_size
+            self.attn_tp_rank = args.engine_rank % self.tp_size_of_dp
+            self.dp_rank = args.engine_rank // self.tp_size_of_dp
+
         self.rank_port = None
         self.server_socket = zmq.Context().socket(zmq.PULL)
         self.register_buffer_to_engine()
@@ -110,7 +172,8 @@ def __init__(
             self._start_bootstrap_thread()
             self._register_to_bootstrap()
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
-            self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
+            # bootstrap key -> (remote_engine_rank -> possible remote source info)
+            self.prefill_peer_infos: Dict[str, list[Dict[int, NixlEngineInfo]]] = {}
            self.transfer_statuses: Dict[int, TransferStatus] = defaultdict(
                 TransferStatus
             )
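The dp_attention branch in `__init__` decomposes each engine rank into a DP group and a rank within that group. A worked sketch of the arithmetic, with illustrative sizes:

```python
tp_size, dp_size = 8, 2
tp_size_of_dp = tp_size // dp_size  # 4 ranks per DP group

for engine_rank in range(tp_size):
    attn_tp_rank = engine_rank % tp_size_of_dp   # position inside the DP group
    dp_rank = engine_rank // tp_size_of_dp       # which DP group
    # engine_rank 0..3 -> dp_rank 0, attn_tp_rank 0..3
    # engine_rank 4..7 -> dp_rank 1, attn_tp_rank 0..3
```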
@@ -126,6 +189,7 @@ def register_buffer_to_engine(self):
         ):
             kv_addrs.append((kv_data_ptr, kv_data_len, self.kv_args.gpu_id, ""))
         self.kv_descs = self.agent.register_memory(kv_addrs, "VRAM", is_sorted=True)
+        logger.debug(f"Register kv tensors, len(kv_addr)= {len(kv_addrs)}")
         if not self.kv_descs:
             raise Exception("NIXL memory registration failed for kv tensors")
         aux_addrs = []
@@ -134,6 +198,7 @@ def register_buffer_to_engine(self):
         ):
             aux_addrs.append((aux_data_ptr, aux_data_len, 0, ""))
         self.aux_descs = self.agent.register_memory(aux_addrs, "DRAM", is_sorted=True)
+        logger.debug(f"Register aux tensors, len(aux_addrs)= {len(aux_addrs)}")
         if not self.aux_descs:
             raise Exception("NIXL memory registration failed for aux tensors")
 
@@ -157,6 +222,12 @@ def send_kvcache(
         dst_gpu_id: int,
         notif: str,
     ):
+        # group by indices
+        prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous(
+            prefill_kv_indices, dst_kv_indices
+        )
+
+        logger.debug(f"sending kvcache to {peer_name} with notif {notif}")
         # Make descs
         num_layers = len(self.kv_args.kv_data_ptrs)
         src_addrs = []
@@ -166,12 +237,16 @@
             dst_ptr = dst_kv_ptrs[layer_id]
             item_len = self.kv_args.kv_item_lens[layer_id]
 
-            for prefill_index, decode_index in zip(prefill_kv_indices, dst_kv_indices):
-                src_addr = src_ptr + int(prefill_index) * item_len
-                dst_addr = dst_ptr + int(decode_index) * item_len
-                length = item_len
+            for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
+                src_addr = src_ptr + int(prefill_index[0]) * item_len
+                dst_addr = dst_ptr + int(decode_index[0]) * item_len
+                length = item_len * len(prefill_index)
                 src_addrs.append((src_addr, length, self.kv_args.gpu_id))
                 dst_addrs.append((dst_addr, length, dst_gpu_id))
+
+        logger.debug(
+            f"len(src_addrs): before group: {len(prefill_kv_indices)}, after group: {len(src_addrs)}"
+        )
         src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM", is_sorted=True)
         dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM", is_sorted=True)
         # Transfer data
@@ -180,7 +255,7 @@
             src_descs,
             dst_descs,
             peer_name,
-            notif.encode("ascii"),
+            notif.encode("ascii"),  # type: ignore
         )
         if not xfer_handle:
             raise Exception("KVSender failed to create transfer")
@@ -213,7 +288,7 @@ def send_aux(
             src_descs,
             dst_descs,
             peer_name,
-            notif.encode("ascii"),
+            notif.encode("ascii"),  # type: ignore
         )
         if not xfer_handle:
             raise Exception("KVSender failed to create transfer")
@@ -240,6 +315,9 @@ def add_transfer_request(
         req = self.transfer_infos[bootstrap_room]
         assert bootstrap_room == req.room
 
+        if req.is_dummy():
+            return []
+
         peer_name = self._add_remote(bootstrap_room, req.agent_metadata)
         chunked_dst_kv_indice = req.dst_kv_indices[index_slice]
         assert len(chunked_dst_kv_indice) == len(kv_indices)
@@ -256,6 +334,7 @@ def add_transfer_request(
         handles = [kv_xfer_handle]
         # Only the last chunk we need to send the aux data.
         if is_last:
+            assert aux_index is not None
             aux_xfer_handle = self.send_aux(
                 peer_name,
                 aux_index,
@@ -325,6 +404,13 @@ def bootstrap_thread():
             """This thread recvs transfer info from the decode engine"""
             while True:
                 waiting_req_bytes = self.server_socket.recv_multipart()
+                logger.debug(
+                    f"Received multipart with total byte size {sum(len(x) for x in waiting_req_bytes)}"
+                )
+                assert (
+                    waiting_req_bytes[0] == GUARD
+                ), f"First message should be {GUARD}. Foreign traffic?"
+                waiting_req_bytes = waiting_req_bytes[1:]
                 room = waiting_req_bytes[0].decode("ascii")
                 if room == "None":
                     continue
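Every handshake message is now framed with the `GUARD` bytes, which the bootstrap thread verifies and strips before parsing; a single frame after the guard is the dummy message handled by `TransferInfo.from_zmq`. A rough sketch of both sides of the framing, assuming a plain PUSH/PULL pair on a hypothetical local port:

```python
import zmq

GUARD = b"NixlMsgGuard"
ctx = zmq.Context()

# Receiving side: bind, then verify and strip the guard frame.
pull = ctx.socket(zmq.PULL)
pull.bind("tcp://127.0.0.1:5555")  # hypothetical endpoint

# Sending side: every multipart message leads with the guard frame.
push = ctx.socket(zmq.PUSH)
push.connect("tcp://127.0.0.1:5555")
push.send_multipart([GUARD, b"42"])  # dummy msg: guard + bootstrap room only

frames = pull.recv_multipart()
assert frames[0] == GUARD, "Foreign traffic?"
frames = frames[1:]  # [b"42"] -> parsed as a one-frame dummy TransferInfo
```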
@@ -372,14 +458,13 @@ def send(
 
     def poll(self) -> KVPoll:
         if not self.has_sent:
-            return KVPoll.WaitingForInput
-
+            return KVPoll.WaitingForInput  # type: ignore
         states = [self.kv_mgr.agent.check_xfer_state(x) for x in self.xfer_handles]
         if all([x == "DONE" for x in states]):
-            return KVPoll.Success
+            return KVPoll.Success  # type: ignore
         if any([x == "ERR" for x in states]):
             raise Exception("KVSender transfer encountered an error.")
-        return KVPoll.WaitingForInput
+        return KVPoll.WaitingForInput  # type: ignore
 
     def failure_exception(self):
         raise Exception("Fake KVSender Exception")
@@ -401,7 +486,7 @@ def __init__(
 
         # NOTE: key distinguished by bootstrap_addr and engine_rank
         bootstrap_key = f"{self.bootstrap_addr}_{self.kv_mgr.kv_args.engine_rank}"
-        if bootstrap_key not in self.kv_mgr.connection_pool:
+        if bootstrap_key not in self.kv_mgr.prefill_peer_infos:
             self.bootstrap_info = self._get_bootstrap_info_from_server(
                 self.kv_mgr.kv_args.engine_rank
             )
@@ -410,25 +495,79 @@ def __init__(
                     f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank}"
                 )
             else:
-                self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_info
+                self.kv_mgr.prefill_peer_infos[bootstrap_key] = self.bootstrap_info
         else:
-            self.bootstrap_info = self.kv_mgr.connection_pool[bootstrap_key]
-
+            self.bootstrap_info = self.kv_mgr.prefill_peer_infos[bootstrap_key]
         assert self.bootstrap_info is not None
 
-    def _get_bootstrap_info_from_server(self, engine_rank):
+    # return a list of remotes in a dict, [(remote_engine_rank -> NixlEngineInfo), ...]
+    # In each dict, there are multiple possible remotes named "equal sources".
+    # We only need to select one to split the traffic. i.e. we totally select len(list) remotes.
+    def _get_bootstrap_info_from_server(
+        self, engine_rank
+    ) -> Optional[List[Dict[int, NixlEngineInfo]]]:
         """Fetch the bootstrap info from the bootstrap server."""
         try:
-            url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}"
-            response = requests.get(url)
-            if response.status_code == 200:
+            if self.kv_mgr.enable_dp_attention:
+                url = f"http://{self.bootstrap_addr}/route"
+                response = requests.get(url)
+                if response.status_code != 200:
+                    logger.error(
+                        f"Failed to get prefill server info: {response.status_code}, {response.text}"
+                    )
+                    return None
+
                 bootstrap_info = response.json()
-                return bootstrap_info
-            else:
-                logger.error(
-                    f"Failed to get prefill server info: {response.status_code}, {response.text}"
+                assert isinstance(bootstrap_info, dict)
+                bootstrap_info = {int(k): v for k, v in bootstrap_info.items()}
+
+                # split out who need to send to this rank.
+                # currently for dpsk mla model, those ranks share the same latent cache.
+                # pick one as the real source
+
+                prefill_tp_size = len(bootstrap_info.keys())
+
+                assert (
+                    prefill_tp_size >= self.kv_mgr.tp_size_of_dp
+                ), f"Only support Prefill TP size >= Decode TP size of DP, now we have {prefill_tp_size} vs {self.kv_mgr.tp_size_of_dp}"
+
+                num_remote_tp_rank_we_managed = (
+                    prefill_tp_size // self.kv_mgr.tp_size_of_dp
+                )
+
+                # We handle [num * self.attn_tp_rank, num * self.attn_tp_rank + num)
+                remote_tp_ranks = list(range(0, prefill_tp_size))
+                # split it into tp_size_of_dp parts and get our part
+                remote_tp_ranks_grouped = [
+                    remote_tp_ranks[i : i + num_remote_tp_rank_we_managed]
+                    for i in range(0, prefill_tp_size, self.kv_mgr.tp_size_of_dp)
+                ]
+                managed_ranks = remote_tp_ranks_grouped[self.kv_mgr.attn_tp_rank]
+
+                assert len(managed_ranks) == num_remote_tp_rank_we_managed
+
+                logger.debug(
+                    f"Rank {self.kv_mgr.kv_args.engine_rank} source can be {managed_ranks}"
                 )
-                return None
+
+                return [
+                    {
+                        rk: bootstrap_info[rk]
+                        for rk in bootstrap_info.keys()
+                        if rk in managed_ranks
+                    }
+                ]
+            else:
+                url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}"
+                response = requests.get(url)
+                if response.status_code == 200:
+                    bootstrap_info = response.json()
+                    return [{engine_rank: bootstrap_info}]
+                else:
+                    logger.error(
+                        f"Failed to get prefill server info: {response.status_code}, {response.text}"
+                    )
+                    return None
         except Exception as e:
             logger.error(f"Error fetching prefill info from bootstrap: {e}")
             return None
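To make the rank splitting concrete: with a hypothetical prefill TP of 4 feeding a decode worker whose `tp_size_of_dp` is 2, each local attention rank manages two "equal sources" that hold the same latent (MLA) cache:

```python
prefill_tp_size = 4  # ranks registered on the bootstrap server
tp_size_of_dp = 2    # decode TP size within one DP group
num = prefill_tp_size // tp_size_of_dp  # 2 remote ranks per local rank

remote_tp_ranks = list(range(prefill_tp_size))  # [0, 1, 2, 3]
grouped = [
    remote_tp_ranks[i : i + num]
    for i in range(0, prefill_tp_size, tp_size_of_dp)
]
# grouped -> [[0, 1], [2, 3]]
# attn_tp_rank 0 manages prefill ranks [0, 1]; attn_tp_rank 1 manages [2, 3].
```

Only one rank per group is picked as the real source in `init()` (keyed by `bootstrap_room`); the others receive dummy messages so their bookkeeping stays consistent.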
@@ -440,43 +579,67 @@ def _connect(self, endpoint: str):
         return socket
 
     def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
-        self.prefill_server_url = (
-            f"{self.bootstrap_info['rank_ip']}:{self.bootstrap_info['rank_port']}"
-        )
-        logger.debug(
-            f"Fetched bootstrap info: {self.bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
-        )
-        packed_kv_data_ptrs = b"".join(
-            struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
-        )
-        packed_aux_data_ptrs = b"".join(
-            struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
-        )
-        self._connect("tcp://" + self.prefill_server_url).send_multipart(
-            [
-                str(self.bootstrap_room).encode("ascii"),
-                get_local_ip_by_remote().encode("ascii"),
-                str(self.kv_mgr.rank_port).encode("ascii"),
-                self.kv_mgr.agent.get_agent_metadata(),
-                packed_kv_data_ptrs,
-                kv_indices.tobytes(),
-                packed_aux_data_ptrs,
-                str(aux_index).encode("ascii"),
-                str(self.kv_mgr.kv_args.gpu_id).encode("ascii"),
-            ]
-        )
+        assert self.bootstrap_info is not None
+        assert self.bootstrap_room is not None
+
+        for equal_sources in self.bootstrap_info:
+            remote_rank = list(equal_sources.keys())[
+                self.bootstrap_room % len(equal_sources)
+            ]
+            self.prefill_server_url = f"{equal_sources[remote_rank]['rank_ip']}:{equal_sources[remote_rank]['rank_port']}"
+            logger.debug(
+                f"Fetched bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank}, source: {remote_rank}, all: {list(equal_sources.keys())}"
+            )
+
+            packed_kv_data_ptrs = b"".join(
+                struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
+            )
+            packed_aux_data_ptrs = b"".join(
+                struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
+            )
+
+            logger.debug(
+                f"Sending to {self.prefill_server_url} with bootstrap room {self.bootstrap_room}"
+            )
+            self._connect("tcp://" + self.prefill_server_url).send_multipart(
+                [
+                    GUARD,
+                    str(self.bootstrap_room).encode("ascii"),
+                    get_local_ip_by_remote().encode("ascii"),
+                    str(self.kv_mgr.rank_port).encode("ascii"),
+                    self.kv_mgr.agent.get_agent_metadata(),
+                    packed_kv_data_ptrs,
+                    kv_indices.tobytes(),
+                    packed_aux_data_ptrs,
+                    str(aux_index).encode("ascii"),
+                    str(self.kv_mgr.kv_args.gpu_id).encode("ascii"),
+                ]
+            )
+
+            for dummy_rank in equal_sources.keys():
+                if dummy_rank == remote_rank:
+                    continue
+                dummy_info = equal_sources[dummy_rank]
+                dummy_url = f"{dummy_info['rank_ip']}:{dummy_info['rank_port']}"
+                self._connect("tcp://" + dummy_url).send_multipart(
+                    [
+                        GUARD,
+                        str(self.bootstrap_room).encode("ascii"),
+                    ]
+                )
+
         self.started_transfer = True
 
     def poll(self) -> KVPoll:
         if not self.started_transfer:
-            return KVPoll.WaitingForInput
+            return KVPoll.WaitingForInput  # type: ignore
 
         self.kv_mgr.update_transfer_status()
-        if self.kv_mgr.check_transfer_done(self.bootstrap_room):
-            return KVPoll.Success
-        return KVPoll.WaitingForInput
+        if self.kv_mgr.check_transfer_done(self.bootstrap_room):  # type: ignore
+            return KVPoll.Success  # type: ignore
+        return KVPoll.WaitingForInput  # type: ignore
 
     def failure_exception(self):
         raise Exception("Fake KVReceiver Exception")
@@ -484,6 +647,7 @@ def failure_exception(self):
 
 class NixlKVBootstrapServer(BaseKVBootstrapServer):
     def __init__(self, port: int):
+        logger.debug(f"NixlKVBootstrapServer started on port {port}")
         self.port = port
         self.app = web.Application()
         self.store = dict()
@@ -564,13 +728,13 @@ async def _handle_route_put(self, request: web.Request):
         engine_rank = int(data["engine_rank"])
         agent_name = data["agent_name"]
 
-        # Add lock to make sure thread-safe
         if role == "Prefill":
-            self.prefill_port_table[engine_rank] = {
-                "rank_ip": rank_ip,
-                "rank_port": rank_port,
-                "agent_name": agent_name,
-            }
+            async with self.lock:
+                self.prefill_port_table[engine_rank] = {
+                    "rank_ip": rank_ip,
+                    "rank_port": rank_port,
+                    "agent_name": agent_name,
+                }
             logger.info(
                 f"Registered Prefill boostrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port} and name: {agent_name}"
             )
@@ -580,7 +744,13 @@
     async def _handle_route_get(self, request: web.Request):
         engine_rank = request.query.get("engine_rank")
         if not engine_rank:
-            return web.Response(text="Missing rank", status=400)
+            logger.debug(
+                f"No engine_rank specified, return all {len(self.prefill_port_table)} engine infos as a dict"
+            )
+            # Return a dict of all engine_rank
+            async with self.lock:
+                bootstrap_info = self.prefill_port_table
+            return web.json_response(bootstrap_info, status=200)
 
         # Find corresponding prefill info
         async with self.lock:
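With the `_handle_route_get` change, a GET on `/route` without `engine_rank` returns the whole prefill table keyed by rank; JSON object keys arrive as strings, which is why the decode side re-keys them with `int(k)`. A quick illustration against a hypothetical bootstrap address:

```python
import requests

base = "http://127.0.0.1:8998"  # hypothetical bootstrap server

# dp_attention path: fetch all registered prefill ranks at once.
table = requests.get(f"{base}/route").json()
# e.g. {"0": {"rank_ip": "...", "rank_port": 6000, "agent_name": "..."}, ...}
table = {int(k): v for k, v in table.items()}  # re-key as ints

# pre-existing path: fetch a single rank.
info = requests.get(f"{base}/route", params={"engine_rank": 0}).json()
```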