Skip to content

Better PD initialization #5751

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
May 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 74 additions & 23 deletions python/sglang/srt/disaggregation/mini_lb.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,45 @@
"""

import asyncio
import dataclasses
import logging
import random
import urllib
from itertools import chain
from typing import List
from typing import List, Optional

import aiohttp
import orjson
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.responses import ORJSONResponse, Response, StreamingResponse

from sglang.srt.disaggregation.utils import PDRegistryRequest


def setup_logger():
logger = logging.getLogger("pdlb")
logger.setLevel(logging.INFO)

formatter = logging.Formatter(
"[PDLB (Python)] %(asctime)s - %(levelname)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)

handler = logging.StreamHandler()
handler.setFormatter(formatter)
logger.addHandler(handler)

return logger


logger = setup_logger()


@dataclasses.dataclass
class PrefillConfig:
def __init__(self, url: str, bootstrap_port: int):
self.url = url
self.bootstrap_port = bootstrap_port
url: str
bootstrap_port: Optional[int] = None


class MiniLoadBalancer:
Expand All @@ -28,6 +51,10 @@ def __init__(self, prefill_configs: List[PrefillConfig], decode_servers: List[st
self.decode_servers = decode_servers

def select_pair(self):
# TODO: return some message instead of panic
assert len(self.prefill_configs) > 0, "No prefill servers available"
assert len(self.decode_servers) > 0, "No decode servers available"

prefill_config = random.choice(self.prefill_configs)
decode_server = random.choice(self.decode_servers)
return prefill_config.url, prefill_config.bootstrap_port, decode_server
Expand All @@ -47,7 +74,7 @@ async def generate(
session.post(f"{decode_server}/{endpoint}", json=modified_request),
]
# Wait for both responses to complete. Prefill should end first.
prefill_response, decode_response = await asyncio.gather(*tasks)
_, decode_response = await asyncio.gather(*tasks)

return ORJSONResponse(
content=await decode_response.json(),
Expand Down Expand Up @@ -268,6 +295,32 @@ async def get_models():
raise HTTPException(status_code=500, detail=str(e))


@app.post("/register")
async def register(obj: PDRegistryRequest):
if obj.mode == "prefill":
load_balancer.prefill_configs.append(
PrefillConfig(obj.registry_url, obj.bootstrap_port)
)
logger.info(
f"Registered prefill server: {obj.registry_url} with bootstrap port: {obj.bootstrap_port}"
)
elif obj.mode == "decode":
load_balancer.decode_servers.append(obj.registry_url)
logger.info(f"Registered decode server: {obj.registry_url}")
else:
raise HTTPException(
status_code=400,
detail="Invalid mode. Must be either PREFILL or DECODE.",
)

logger.info(
f"#Prefill servers: {len(load_balancer.prefill_configs)}, "
f"#Decode servers: {len(load_balancer.decode_servers)}"
)

return Response(status_code=200)


def run(prefill_configs, decode_addrs, host, port):
global load_balancer
load_balancer = MiniLoadBalancer(prefill_configs, decode_addrs)
Expand All @@ -279,15 +332,16 @@ def run(prefill_configs, decode_addrs, host, port):

parser = argparse.ArgumentParser(description="Mini Load Balancer Server")
parser.add_argument(
"--prefill", required=True, help="Comma-separated URLs for prefill servers"
"--prefill", type=str, default=[], nargs="+", help="URLs for prefill servers"
)
parser.add_argument(
"--prefill-bootstrap-ports",
help="Comma-separated bootstrap ports for prefill servers",
default="8998",
"--decode", type=str, default=[], nargs="+", help="URLs for decode servers"
)
parser.add_argument(
"--decode", required=True, help="Comma-separated URLs for decode servers"
"--prefill-bootstrap-ports",
type=int,
nargs="+",
help="Bootstrap ports for prefill servers",
)
parser.add_argument(
"--host", default="0.0.0.0", help="Host to bind the server (default: 0.0.0.0)"
Expand All @@ -297,22 +351,19 @@ def run(prefill_configs, decode_addrs, host, port):
)
args = parser.parse_args()

prefill_urls = args.prefill.split(",")
bootstrap_ports = [int(p) for p in args.prefill_bootstrap_ports.split(",")]

if len(bootstrap_ports) == 1:
bootstrap_ports = bootstrap_ports * len(prefill_urls)
bootstrap_ports = args.prefill_bootstrap_ports
if bootstrap_ports is None:
bootstrap_ports = [None] * len(args.prefill)
elif len(bootstrap_ports) == 1:
bootstrap_ports = bootstrap_ports * len(args.prefill)
else:
if len(bootstrap_ports) != len(prefill_urls):
if len(bootstrap_ports) != len(args.prefill):
raise ValueError(
"Number of prefill URLs must match number of bootstrap ports"
)
exit(1)

prefill_configs = []
for url, port in zip(prefill_urls, bootstrap_ports):
prefill_configs.append(PrefillConfig(url, port))

decode_addrs = args.decode.split(",")
prefill_configs = [
PrefillConfig(url, port) for url, port in zip(args.prefill, bootstrap_ports)
]

run(prefill_configs, decode_addrs, args.host, args.port)
run(prefill_configs, args.decode, args.host, args.port)
45 changes: 44 additions & 1 deletion python/sglang/srt/disaggregation/utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
from __future__ import annotations

import dataclasses
import warnings
from collections import deque
from enum import Enum
from typing import List
from typing import List, Optional

import numpy as np
import requests
import torch
import torch.distributed as dist

from sglang.srt.utils import get_ip


class DisaggregationMode(Enum):
NULL = "null"
Expand Down Expand Up @@ -119,3 +124,41 @@ def kv_to_page_indices(kv_indices: np.ndarray, page_size: int):
def kv_to_page_num(num_kv_indices: int, page_size: int):
# ceil(num_kv_indices / page_size)
return (num_kv_indices + page_size - 1) // page_size


@dataclasses.dataclass
class PDRegistryRequest:
"""A request to register a machine itself to the LB."""

mode: str
registry_url: str
bootstrap_port: Optional[int] = None

def __post_init__(self):
if self.mode == "prefill" and self.bootstrap_port is None:
raise ValueError("Bootstrap port must be set in PREFILL mode.")
elif self.mode == "decode" and self.bootstrap_port is not None:
raise ValueError("Bootstrap port must not be set in DECODE mode.")
elif self.mode not in ["prefill", "decode"]:
raise ValueError(
f"Invalid mode: {self.mode}. Must be 'prefill' or 'decode'."
)


def register_disaggregation_server(
mode: str, server_port: int, bootstrap_port: int, pdlb_url: str
):
boostrap_port = bootstrap_port if mode == "prefill" else None
registry_request = PDRegistryRequest(
mode=mode,
registry_url=f"http://{get_ip()}:{server_port}",
bootstrap_port=boostrap_port,
)
res = requests.post(
f"{pdlb_url}/register",
json=dataclasses.asdict(registry_request),
)
if res.status_code != 200:
warnings.warn(
f"Failed to register disaggregation server: {res.status_code} {res.text}"
)
13 changes: 12 additions & 1 deletion python/sglang/srt/entrypoints/http_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,10 @@
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import ORJSONResponse, Response, StreamingResponse

from sglang.srt.disaggregation.utils import FakeBootstrapHost
from sglang.srt.disaggregation.utils import (
FakeBootstrapHost,
register_disaggregation_server,
)
from sglang.srt.entrypoints.engine import _launch_subprocesses
from sglang.srt.function_call_parser import FunctionCallParser
from sglang.srt.managers.io_struct import (
Expand Down Expand Up @@ -871,5 +874,13 @@ def _wait_and_warmup(
if server_args.debug_tensor_dump_input_file:
kill_process_tree(os.getpid())

if server_args.pdlb_url is not None:
register_disaggregation_server(
server_args.disaggregation_mode,
server_args.port,
server_args.disaggregation_bootstrap_port,
server_args.pdlb_url,
)

if launch_callback is not None:
launch_callback()
4 changes: 4 additions & 0 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -925,6 +925,10 @@ def handle_generate_request(
)
custom_logit_processor = None

if recv_req.bootstrap_port is None:
# Use default bootstrap port
recv_req.bootstrap_port = self.server_args.disaggregation_bootstrap_port

req = Req(
recv_req.rid,
recv_req.input_text,
Expand Down
7 changes: 7 additions & 0 deletions python/sglang/srt/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ class ServerArgs:
disaggregation_bootstrap_port: int = 8998
disaggregation_transfer_backend: str = "mooncake"
disaggregation_ib_device: Optional[str] = None
pdlb_url: Optional[str] = None

def __post_init__(self):
# Expert parallelism
Expand Down Expand Up @@ -1254,6 +1255,12 @@ def add_cli_args(parser: argparse.ArgumentParser):
"or multiple comma-separated devices (e.g., --disaggregation-ib-device mlx5_0,mlx5_1). "
"Default is None, which triggers automatic device detection when mooncake backend is enabled.",
)
parser.add_argument(
"--pdlb-url",
type=str,
default=None,
help="The URL of the PD disaggregation load balancer. If set, the prefill/decode server will register with the load balancer.",
)

@classmethod
def from_cli_args(cls, args: argparse.Namespace):
Expand Down
Loading