Skip to content

Improve dp attention port assignment scheme #5889

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 23 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
f96d599
feat: dynamic DP controller port dispatch
jokerwyt Apr 28, 2025
bfecdbb
Merge remote-tracking branch 'gh/main' into dp-port-dispatch
jokerwyt Apr 29, 2025
60f8a55
Fix completions endpoint bootstrap port passing
jokerwyt Apr 27, 2025
3e5d6ed
[WIP] dynamic DP port
jokerwyt Apr 29, 2025
68fdf09
Dynamic DP port assignment
jokerwyt Apr 29, 2025
2865267
Better dynamic port, lower conflict
jokerwyt Apr 29, 2025
6f9eea5
small fix
jokerwyt Apr 29, 2025
c96c1b0
NIXL DP support (#5681)
jokerwyt Apr 29, 2025
e221b28
Remove some debug print
jokerwyt Apr 30, 2025
6468136
Merge branch 'main' of https://github.com/sgl-project/sglang into dp-…
jokerwyt May 13, 2025
8636070
Merge branch 'main' of https://github.com/sgl-project/sglang into dp-…
jokerwyt May 14, 2025
0fa37d3
Merge branch 'main' into dp-port-dispatch
jokerwyt May 14, 2025
ad828c1
Atomic assignment of dp attention scheduler ports
jokerwyt May 16, 2025
ac7662b
Merge branch 'main' of https://github.com/sgl-project/sglang into dp-…
jokerwyt May 16, 2025
f5930a0
Merge branch 'dp-port-dispatch' of github.com:jokerwyt/sglang-public …
jokerwyt May 16, 2025
dc379a6
Refine
jokerwyt May 16, 2025
b21222a
Merge branch 'main' into dp-port-dispatch
jokerwyt May 28, 2025
0872328
Merge branch 'main' into dp-port-dispatch
ispobock May 31, 2025
ee343ac
Merge branch 'main' into dp-port-dispatch
zhyncs May 31, 2025
b23a42f
Merge branch 'main' of https://github.com/sgl-project/sglang into dp-…
jokerwyt Jun 7, 2025
50fc888
Add cmd args and test
jokerwyt Jun 7, 2025
96ed813
Merge branch 'dp-port-dispatch' of github.com:jokerwyt/sglang-public …
jokerwyt Jun 7, 2025
2a7f2d0
Merge branch 'main' into dp-port-dispatch
jokerwyt Jun 7, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 73 additions & 4 deletions python/sglang/srt/managers/data_parallel_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,13 @@
from sglang.srt.managers.scheduler import run_scheduler_process
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.utils import bind_port, configure_logger, get_zmq_socket
from sglang.srt.utils import (
bind_port,
configure_logger,
get_free_port,
get_tcp_zmq_socket_binded_to_local_free_port,
get_zmq_socket,
)
from sglang.utils import get_exception_traceback

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -85,15 +91,24 @@ def __init__(self, server_args: ServerArgs, port_args: PortArgs) -> None:
self.scheduler_procs = []
self.workers = [None] * server_args.dp_size

if server_args.node_rank == 0 and server_args.pick_free_dp_port:
self.workers_port = {}
for dp_rank in range(server_args.dp_size):
port_and_socket = get_tcp_zmq_socket_binded_to_local_free_port(
self.context, zmq.PUSH
)
self.workers[dp_rank] = port_and_socket[1]
self.workers_port[dp_rank] = port_and_socket[0]
logger.debug(f"Port assign to worker {dp_rank}: {port_and_socket[0]}")

if server_args.enable_dp_attention:
dp_port_args = self.launch_dp_attention_schedulers(server_args, port_args)
self.control_message_step = server_args.tp_size
else:
dp_port_args = self.launch_dp_schedulers(server_args, port_args)
self.control_message_step = 1

# Only node rank 0 runs the real data parallel controller that dispatches the requests.
if server_args.node_rank == 0:
if server_args.node_rank == 0 and not server_args.pick_free_dp_port:
for dp_rank in range(server_args.dp_size):
self.workers[dp_rank] = get_zmq_socket(
self.context,
Expand Down Expand Up @@ -160,11 +175,62 @@ def launch_tensor_parallel_group_thread(
while True:
time.sleep(30 * 24 * 3600)

    def _dispatch_dp_attn_ctrl_zmq_port(self, server_args: ServerArgs):
        """Broadcast node 0's dynamically bound DP scheduler ports to all nodes.

        Node rank 0 binds a REP socket on the rendezvous endpoint and answers
        one handshake per peer node with the ``{dp_rank: port}`` mapping taken
        from ``self.workers_port``. Every other node sends its rank over a REQ
        socket and receives that mapping. Both sides finish by registering the
        mapping with ``PortArgs`` so ``PortArgs.init_new`` can look the ports up.

        Raises:
            zmq.Again: on a non-zero node if node 0 does not answer within 60s.
        """
        # Rendezvous endpoint for the one-shot port handshake: fall back to
        # loopback (single-machine case) when no dist_init_addr is configured.
        # NOTE(review): port + 5 is presumably chosen to avoid the other
        # port_base offsets — confirm it cannot collide with PortArgs ports.
        ex_endpoint = None
        if server_args.dist_init_addr is None:
            ex_endpoint = f"tcp://127.0.0.1:{server_args.port + 5}"
        else:
            ex_endpoint = f"tcp://{server_args.dist_init_addr}"

        if server_args.node_rank == 0:
            # Copy of the per-dp-rank ports bound earlier in __init__.
            free_ports = {i: port for i, port in self.workers_port.items()}
            logger.debug(f"Free ports: {free_ports}")

            # broadcast dp_port_args to all dp ranks
            rep_socket = get_zmq_socket(self.context, zmq.REP, ex_endpoint, True)

            connected_nodes = 0
            expected_nodes = server_args.nnodes - 1

            logger.debug(
                f"DP Controller: Node Rank 0 started, waiting for {expected_nodes} nodes to connect."
            )
            # REP answers requests strictly one at a time: recv a handshake,
            # reply with the port map, repeat until every peer has been served.
            while connected_nodes < expected_nodes:
                msg = rep_socket.recv()
                logger.debug(f"Node 0 received handshake from node {msg.decode()}")
                # send dp_port_args to the node
                rep_socket.send_pyobj(free_ports)
                connected_nodes += 1
                logger.debug(
                    f"DP Controller: {connected_nodes}/{expected_nodes} nodes connected."
                )
            logger.debug("DP Controller: All nodes connected")

            rep_socket.close()
        else:
            req_socket = get_zmq_socket(self.context, zmq.REQ, ex_endpoint, False)

            # Bounded timeouts so a missing node 0 fails loudly instead of
            # hanging the launch forever.
            req_socket.setsockopt(zmq.RCVTIMEO, 60 * 1000)  # 1 min timeout
            req_socket.setsockopt(zmq.SNDTIMEO, 60 * 1000)

            try:
                req_socket.send(str(server_args.node_rank).encode())
                free_ports = req_socket.recv_pyobj()
                logger.debug(
                    f"Node {server_args.node_rank} received handshake from node 0, len {len(free_ports)}"
                )
            except zmq.Again:
                logger.error("Handshake timeout with node 0")
                raise
        # Both branches: publish the mapping for PortArgs.init_new to consume.
        PortArgs.register_dp_controller_to_attn_tp_rk0_port(free_ports)

def launch_dp_attention_schedulers(self, server_args, port_args):
self.launch_tensor_parallel_group(server_args, port_args, 0, None)
if server_args.pick_free_dp_port:
self._dispatch_dp_attn_ctrl_zmq_port(server_args)
dp_port_args = []
for dp_rank in range(server_args.dp_size):
dp_port_args.append(PortArgs.init_new(server_args, dp_rank))
self.launch_tensor_parallel_group(server_args, port_args, 0, None)
return dp_port_args

def launch_tensor_parallel_group(
Expand Down Expand Up @@ -253,6 +319,9 @@ def round_robin_scheduler(self, req: Req):
self.workers
)
else:
logger.debug(
f"DP round_robin scheduler Room {req.bootstrap_room} -> Worker {req.bootstrap_room % len(self.workers)}"
)
self.workers[req.bootstrap_room % len(self.workers)].send_pyobj(req)

def shortest_queue_scheduler(self, input_requests):
Expand Down
20 changes: 19 additions & 1 deletion python/sglang/srt/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ class ServerArgs:
disable_chunked_prefix_cache: bool = False
disable_fast_image_processor: bool = False
mm_attention_backend: Optional[str] = None
pick_free_dp_port: bool = False

# Debug tensor dumps
debug_tensor_dump_output_folder: Optional[str] = None
Expand Down Expand Up @@ -1486,6 +1487,12 @@ def add_cli_args(parser: argparse.ArgumentParser):
help="Set multimodal attention backend.",
)

parser.add_argument(
"--pick-free-dp-port",
action="store_true",
help="Whether to pick dp ports from free ports, or use fixed dp ports by default. Useful when getting frequent port conflicts under huge DP cases.",
)

@classmethod
def from_cli_args(cls, args: argparse.Namespace):
args.tp_size = args.tensor_parallel_size
Expand Down Expand Up @@ -1557,6 +1564,9 @@ def prepare_server_args(argv: List[str]) -> ServerArgs:
ZMQ_TCP_PORT_DELTA = 233


dp_controller_zmq_ports: dict[int, int] = {}


@dataclasses.dataclass
class PortArgs:
# The ipc filename for tokenizer to receive inputs from detokenizer (zmq)
Expand All @@ -1572,6 +1582,11 @@ class PortArgs:
# The ipc filename for rpc call between Engine and Scheduler
rpc_ipc_name: str

    @staticmethod
    def register_dp_controller_to_attn_tp_rk0_port(ports: dict[int, int]):
        """Record the controller-assigned per-dp-rank scheduler input ports.

        ``ports`` maps dp_rank -> TCP port bound by the DP controller on node 0.
        Stored in the module-level ``dp_controller_zmq_ports`` global, which
        ``PortArgs.init_new`` reads when ``--pick-free-dp-port`` is enabled.
        """
        global dp_controller_zmq_ports
        dp_controller_zmq_ports = ports

@staticmethod
def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs":
port = server_args.port + random.randint(100, 1000)
Expand Down Expand Up @@ -1613,7 +1628,10 @@ def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs":
port_base + 3
) # TokenizerManager to DataParallelController
else:
scheduler_input_port = port_base + 3 + 1 + dp_rank
if server_args.pick_free_dp_port:
scheduler_input_port = dp_controller_zmq_ports[dp_rank]
else:
scheduler_input_port = port_base + 3 + 1 + dp_rank

return PortArgs(
tokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base}",
Expand Down
29 changes: 21 additions & 8 deletions python/sglang/srt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1037,9 +1037,7 @@ def pytorch_profile(name, func, *args, data_size=-1):
return result


def get_zmq_socket(
context: zmq.Context, socket_type: zmq.SocketType, endpoint: str, bind: bool
):
def config_socket(socket, socket_type: zmq.SocketType):
mem = psutil.virtual_memory()
total_mem = mem.total / 1024**3
available_mem = mem.available / 1024**3
Expand All @@ -1048,10 +1046,6 @@ def get_zmq_socket(
else:
buf_size = -1

socket = context.socket(socket_type)
if endpoint.find("[") != -1:
socket.setsockopt(zmq.IPV6, 1)

def set_send_opt():
socket.setsockopt(zmq.SNDHWM, 0)
socket.setsockopt(zmq.SNDBUF, buf_size)
Expand All @@ -1064,12 +1058,31 @@ def set_recv_opt():
set_send_opt()
elif socket_type == zmq.PULL:
set_recv_opt()
elif socket_type == zmq.DEALER:
elif socket_type in [zmq.DEALER, zmq.REQ, zmq.REP]:
set_send_opt()
set_recv_opt()
else:
raise ValueError(f"Unsupported socket type: {socket_type}")


def get_tcp_zmq_socket_binded_to_local_free_port(
    context: zmq.Context, socket_type: zmq.SocketType
) -> Tuple[int, zmq.Socket]:
    """Create a TCP ZMQ socket bound to an OS-chosen free local port.

    The socket is configured with the standard buffer/HWM options via
    ``config_socket`` and bound on all interfaces (``tcp://*``).

    Returns:
        (port, socket): the randomly assigned port and the bound socket.
    """
    sock = context.socket(socket_type)
    config_socket(sock, socket_type)
    chosen_port = sock.bind_to_random_port("tcp://*")
    return chosen_port, sock


def get_zmq_socket(
context: zmq.Context, socket_type: zmq.SocketType, endpoint: str, bind: bool
):
socket = context.socket(socket_type)
if endpoint.find("[") != -1:
socket.setsockopt(zmq.IPV6, 1)

config_socket(socket, socket_type)

if bind:
socket.bind(endpoint)
else:
Expand Down
1 change: 1 addition & 0 deletions test/srt/run_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ class TestFile:
TestFile("models/lora/test_lora_tp.py", 116),
TestFile("test_data_parallelism.py", 73),
TestFile("test_dp_attention.py", 137),
TestFile("test_dp_attention_port_picking.py", 137),
TestFile("test_mla_tp.py", 170),
TestFile("test_moe_ep.py", 181),
TestFile("test_patch_torch.py", 19),
Expand Down
70 changes: 70 additions & 0 deletions test/srt/test_dp_attention_port_picking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import unittest
from types import SimpleNamespace

from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
DEFAULT_MLA_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)


class TestDPAttentionDP2TP2PortPicking(CustomTestCase):
    """End-to-end check of --pick-free-dp-port with DP attention (dp=2, tp=2)."""

    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        launch_flags = [
            "--trust-remote-code",
            "--tp",
            "2",
            "--enable-dp-attention",
            "--dp",
            "2",
            "--enable-torch-compile",
            "--torch-compile-max-bs",
            "2",
            "--pick-free-dp-port",
        ]
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=launch_flags,
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def _run_eval(self, eval_name, num_examples, num_threads):
        # Shared driver for the benchmark evals below.
        eval_args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name=eval_name,
            num_examples=num_examples,
            num_threads=num_threads,
        )
        return run_eval(eval_args)

    def test_mmlu(self):
        metrics = self._run_eval("mmlu", 64, 32)
        print(f"{metrics=}")
        self.assertGreater(metrics["score"], 0.5)

    def test_mgsm_en(self):
        metrics = self._run_eval("mgsm_en", None, 1024)
        print(f"{metrics=}")
        self.assertGreater(metrics["score"], 0.8)


if __name__ == "__main__":
unittest.main()