Skip to content

Commit da1bd6b

Browse files
lambert0312 authored and jimoosciuc committed
Fix broadcast using the CUDA device, which leads to unbalanced memory capacity (sgl-project#5416)
1 parent 3d31c0f commit da1bd6b

File tree

3 files changed

+35
-11
lines changed

3 files changed

+35
-11
lines changed

python/sglang/srt/disaggregation/mooncake/conn.py

Lines changed: 30 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -31,6 +31,7 @@
3131

3232
logger = logging.getLogger(__name__)
3333

34+
3435
def find_available_ports(base_port: int, count: int) -> List[int]:
3536
"""Find consecutive available ports starting from base_port."""
3637
available_ports = []
@@ -43,6 +44,7 @@ def find_available_ports(base_port: int, count: int) -> List[int]:
4344

4445
return available_ports
4546

47+
4648
def group_concurrent_contiguous(
4749
src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
4850
) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
@@ -265,7 +267,9 @@ def transfer_thread():
265267
)
266268
if ret != 0:
267269
self.request_status[kv_chunk.room] = KVPoll.Failed
268-
self.sync_status_to_decode_endpoint(req.endpoint, req.dst_port, req.room)
270+
self.sync_status_to_decode_endpoint(
271+
req.endpoint, req.dst_port, req.room
272+
)
269273
continue
270274

271275
if kv_chunk.is_last:
@@ -279,7 +283,9 @@ def transfer_thread():
279283
self.request_status[req.room] = (
280284
KVPoll.Success if ret == 0 else KVPoll.Failed
281285
)
282-
self.sync_status_to_decode_endpoint(req.endpoint, req.dst_port, req.room)
286+
self.sync_status_to_decode_endpoint(
287+
req.endpoint, req.dst_port, req.room
288+
)
283289
self.transfer_infos.pop(req.room)
284290

285291
except queue.Empty:
@@ -443,13 +449,14 @@ def _get_prefill_info_from_bootstrap(self, tp_rank: int):
443449
prefill_info = response.json()
444450
return prefill_info
445451
else:
446-
logger.error(f"Failed to get prefill server info: {response.status_code}, {response.text}")
452+
logger.error(
453+
f"Failed to get prefill server info: {response.status_code}, {response.text}"
454+
)
447455
return None
448456
except Exception as e:
449457
logger.error(f"Error fetching prefill info from bootstrap: {e}")
450458
return None
451459

452-
453460
@cache
454461
def _connect(self, endpoint: str):
455462
socket = zmq.Context().socket(zmq.PUSH)
@@ -466,17 +473,25 @@ def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = Non
466473
)
467474
if prefill_info is None:
468475
logger.error(
469-
logger.error(f"Could not fetch prefill server info for tp_rank {self.kv_mgr.kv_args.engine_rank}")
476+
logger.error(
477+
f"Could not fetch prefill server info for tp_rank {self.kv_mgr.kv_args.engine_rank}"
478+
)
470479
)
471480
else:
472-
self.kv_mgr.connection_pool[self.kv_mgr.kv_args.engine_rank] = prefill_info
481+
self.kv_mgr.connection_pool[self.kv_mgr.kv_args.engine_rank] = (
482+
prefill_info
483+
)
473484
else:
474485
prefill_info = self.kv_mgr.connection_pool[self.kv_mgr.kv_args.engine_rank]
475486

476487
if prefill_info:
477-
self.prefill_server_url = f"{prefill_info['serve_ip']}:{prefill_info['serve_port']}"
488+
self.prefill_server_url = (
489+
f"{prefill_info['serve_ip']}:{prefill_info['serve_port']}"
490+
)
478491

479-
logger.info(f"Fetched prefill server info: {prefill_info} for tp_rank {self.kv_mgr.kv_args.engine_rank}")
492+
logger.info(
493+
f"Fetched prefill server info: {prefill_info} for tp_rank {self.kv_mgr.kv_args.engine_rank}"
494+
)
480495
self.handshake_prefill_server(kv_indices, aux_index)
481496

482497
def handshake_prefill_server(
@@ -598,8 +613,13 @@ async def _handle_kv_route_put(self, request: web.Request):
598613
# Add lock to make sure thread-safe
599614
if role == "Prefill":
600615
async with self.lock:
601-
self.prefill_port_table[tp_rank] = {"serve_ip": serve_ip, "serve_port": serve_port}
602-
logger.info(f"Registered Prefill tp_rank: {tp_rank} with serve_ip: {serve_ip} and serve_port: {serve_port}")
616+
self.prefill_port_table[tp_rank] = {
617+
"serve_ip": serve_ip,
618+
"serve_port": serve_port,
619+
}
620+
logger.info(
621+
f"Registered Prefill tp_rank: {tp_rank} with serve_ip: {serve_ip} and serve_port: {serve_port}"
622+
)
603623

604624
return web.Response(text="OK", status=200)
605625

python/sglang/srt/entrypoints/verl_engine.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -118,6 +118,7 @@ def generate(
118118
rank=self._tp_rank,
119119
dist_group=self._device_mesh_cpu.get_group(),
120120
src=self._device_mesh_cpu.mesh[0].item(),
121+
force_cpu_device=False,
121122
)
122123

123124
return output

python/sglang/srt/utils.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -846,9 +846,12 @@ def broadcast_pyobj(
846846
rank: int,
847847
dist_group: Optional[torch.distributed.ProcessGroup] = None,
848848
src: int = 0,
849+
force_cpu_device: bool = True,
849850
):
850851
"""Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
851-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
852+
device = torch.device(
853+
"cuda" if torch.cuda.is_available() and not force_cpu_device else "cpu"
854+
)
852855

853856
if rank == 0:
854857
if len(data) == 0:

0 commit comments

Comments
 (0)