host memory manager #6

Open

wants to merge 15 commits into base: feature/pd-disaggregation
41 changes: 33 additions & 8 deletions python/sglang/srt/managers/kv_transfer_agent.py
@@ -17,6 +17,7 @@
from safetensors.torch import load as safetensors_load
from safetensors.torch import save as safetensors_save
from sglang.srt.utils import get_zmq_socket
from sglang.srt.mem_cache.memory_pool import HostMemoryManager

logger = logging.getLogger(__name__)

@@ -106,13 +107,14 @@ def read_bytes_from_buffer(self, buffer: int, length: int) -> bytes:

class KVTransferAgent:
    def __init__(self,
                 server_args: ServerArgs,
                 req_to_token_pool: ReqToTokenPool = None,
                 token_to_kv_pool_allocator: TokenToKVPoolAllocator = None,
                 layer_num: int = 0,
                 tp_rank: int = 0,
                 attn_tp_cpu_group: torch.distributed.ProcessGroup = None,
                 device: str = "cpu"):
                 server_args: ServerArgs,
                 host_memory_manager: HostMemoryManager,
                 req_to_token_pool: ReqToTokenPool = None,
                 token_to_kv_pool_allocator: TokenToKVPoolAllocator = None,
                 layer_num: int = 0,
                 tp_rank: int = 0,
                 attn_tp_cpu_group: torch.distributed.ProcessGroup = None,
                 device: str = "cpu"):
        self.kv_buffer = {}
        self.layer_num = layer_num
        self.tp_rank = tp_rank
@@ -154,6 +156,7 @@ def __init__(self,
        self.device = device
        self.engine = TransferEngine(
            self.addr, server_args.kv_transfer_config.transfer_engine_metadata_server, server_args.kv_transfer_config.transfer_engine_rdma_device)
        self.host_memory_manager = host_memory_manager

    def set_kv_buffer(self, req: Req) -> int:
        if self.attn_tp_rank != 0:
@@ -167,9 +170,26 @@ def set_kv_buffer(self, req: Req) -> int:
             for i in range(self.layer_num)]
        )
        kv_cache = safetensors_save({"tensor": flatten.to(self.device)})
        self.kv_buffer[req.rid] = kv_cache
        self.host_memory_manager.store_data(req.rid, kv_cache, 0)
        return len(kv_cache)

    # Called on the prefill node when it receives a KV transfer request
    def _handle_kv_transfer_fetch(self, req: KVTransferFetch):
        kv_cache = self.host_memory_manager.retrieve_data(req.rid)
        if kv_cache is None:
            logger.error(f"Failed to retrieve KV cache for request {req.rid} from host memory")
            self.send_to_pd_disagg_controller.send_pyobj(KVTransferAck(req.rid, req.dst_addr, 1))
            return

        kv_cache_length = len(kv_cache)
        src_ptr = self._allocate_transfer_kv_buffer(kv_cache_length)
        self._write_bytes_to_buffer(src_ptr, kv_cache, kv_cache_length)
        op_write = 1  # opcode selecting a write to the remote buffer
        self.engine.transfer_sync(req.dst_addr, src_ptr, req.dst_ptr, kv_cache_length, op_write)
        self.send_to_pd_disagg_controller.send_pyobj(KVTransferAck(req.rid, req.dst_addr, 0))
        self._free_transfer_kv_buffer(src_ptr, kv_cache_length)
        self.host_memory_manager.free(req.rid)

    def get_kv_buffer(self, req: Req) -> torch.Tensor:
        if self.attn_tp_rank == 0:
            dst_ptr = self._allocate_transfer_kv_buffer(req.kv_cache_length)
@@ -240,6 +260,7 @@ def _handle_kv_transfer_fetch(self, req: KVTransferFetch):
        self._free_transfer_kv_buffer(src_ptr, kv_cache_length)
        del self.kv_buffer[req.rid]


    def _allocate_transfer_kv_buffer(self, length: int) -> int:
        return self.engine.allocate_managed_buffer(length)

@@ -261,3 +282,7 @@ def event_loop(self):
                self._handle_kv_transfer_fetch(recv_obj)
            else:
                raise ValueError(f"Unknown message type: {type(recv_obj)}")

    def free_kv_buffer(self, req_id: str):
        if self.host_memory_manager is not None:
            self.host_memory_manager.free(req_id)
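Taken together, these hunks assume a HostMemoryManager (imported from sglang.srt.mem_cache.memory_pool) with four operations: allocate a per-request reservation sized in tokens, store_data/retrieve_data for the serialized KV bytes, and free. The sketch below is only a minimal reading inferred from the call sites in this diff; the constructor, the byte accounting, the locking, and the meaning of store_data's third argument are assumptions, not the PR's implementation.

# Minimal sketch of the HostMemoryManager interface assumed by this diff.
# Inferred from call sites only; internals are illustrative, not the real
# implementation in sglang/srt/mem_cache/memory_pool.py.
import threading
from typing import Dict, Optional


class HostMemoryManagerSketch:
    def __init__(self, limit_bytes: int, cell_size: int):
        self.limit_bytes = limit_bytes       # host-memory budget (assumed form)
        self.cell_size = cell_size           # bytes of KV cache per token
        self.used_bytes = 0
        self._store: Dict[str, bytes] = {}   # rid -> serialized KV cache
        self._reserved: Dict[str, int] = {}  # rid -> reserved bytes
        self._lock = threading.Lock()

    def allocate(self, rid: str, num_tokens: int) -> bool:
        """Reserve room for a request; False tells the scheduler to back off."""
        need = num_tokens * self.cell_size
        with self._lock:
            if self.used_bytes + need > self.limit_bytes:
                return False
            self.used_bytes += need
            self._reserved[rid] = self._reserved.get(rid, 0) + need
            return True

    def store_data(self, rid: str, data: bytes, offset: int) -> None:
        # Stash serialized KV bytes until the decode node fetches them.
        # The offset argument mirrors the call site; its semantics are assumed.
        with self._lock:
            self._store[rid] = data

    def retrieve_data(self, rid: str) -> Optional[bytes]:
        with self._lock:
            return self._store.get(rid)

    def free(self, rid: str) -> None:
        """Drop the stored bytes and release the reservation."""
        with self._lock:
            self._store.pop(rid, None)
            self.used_bytes -= self._reserved.pop(rid, 0)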
30 changes: 28 additions & 2 deletions python/sglang/srt/managers/schedule_policy.py
@@ -30,6 +30,7 @@
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
from sglang.srt.mem_cache.memory_pool import HostMemoryManager

# Clip the max_new_tokens estimation for requests whose max_new_tokens is very large.
# This keeps the server from being too conservative.
@@ -275,6 +276,7 @@ def __init__(
        tree_cache: BasePrefixCache,
        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
        running_batch: ScheduleBatch,
        host_memory_manager: HostMemoryManager,
        new_token_ratio: float,
        rem_input_tokens: int,
        rem_chunk_tokens: Optional[int],
@@ -283,6 +285,7 @@
        self.tree_cache = tree_cache
        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
        self.running_batch = running_batch
        self.host_memory_manager = host_memory_manager
        self.new_token_ratio = new_token_ratio
        self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
        self.rem_chunk_tokens = rem_chunk_tokens
@@ -353,6 +356,11 @@ def add_chunked_req(self, req: Req):
        truncated = req.extend_input_len > self.rem_chunk_tokens
        req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
        req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]

        if self.host_memory_manager is not None:
            if not self.host_memory_manager.allocate(req.rid, req.extend_input_len):
                return req if truncated else None

        self.can_run_list.append(req)
        self._prefill_one_req(
            0,
@@ -424,7 +432,13 @@ def add_req_state(r, insert_sort=False):
            self.rem_chunk_tokens is None  # chunked prefill is disabled
            or req.extend_input_len <= self.rem_chunk_tokens  # it is the last chunk
        ):
            # Non-chunked prefill
            if self.host_memory_manager is not None:
                total_tokens = req.extend_input_len + min(
                    req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION
                )
                if not self.host_memory_manager.allocate(req.rid, total_tokens):
                    return AddReqResult.NO_TOKEN

            self.can_run_list.append(req)
            self._prefill_one_req(
                0,
@@ -440,6 +454,11 @@

            req.extend_input_len = trunc_len
            req.fill_ids = req.fill_ids[:trunc_len]

            if self.host_memory_manager is not None:
                if not self.host_memory_manager.allocate(req.rid, trunc_len):
                    return AddReqResult.NO_TOKEN

            self.can_run_list.append(req)
            self.new_chunked_req = req
            self._prefill_one_req(0, trunc_len, 0)
@@ -481,7 +500,10 @@ def add_one_req(
        prefix_len = len(req.prefix_indices)

        if self.rem_chunk_tokens is None or input_tokens <= self.rem_chunk_tokens:
            # Non-chunked prefill
            if self.host_memory_manager is not None:
                if not self.host_memory_manager.allocate(req.rid, total_tokens):
                    return AddReqResult.NO_TOKEN

            self.can_run_list.append(req)
            self.tree_cache.inc_lock_ref(req.last_node)
            self._prefill_one_req(
@@ -502,6 +524,10 @@

            req.extend_input_len = trunc_len
            req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]

            if self.host_memory_manager is not None:
                if not self.host_memory_manager.allocate(req.rid, trunc_len):
                    return AddReqResult.NO_TOKEN

            self.can_run_list.append(req)
            self.new_chunked_req = req
            self.tree_cache.inc_lock_ref(req.last_node)
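For scale, here is a hedged arithmetic sketch of the reservation made on the non-chunked path above. The token counts are invented, and CLIP_MAX_NEW_TOKENS_ESTIMATION is assumed to be 4096 (in sglang it is configurable):

# Hedged example (all values invented) of the non-chunked reservation size.
CLIP_MAX_NEW_TOKENS_ESTIMATION = 4096  # assumed clip value

extend_input_len = 1024   # prompt tokens still to prefill
max_new_tokens = 32768    # user-requested generation cap

# The prompt plus a clipped output estimate is reserved up front, so a huge
# max_new_tokens alone cannot starve host memory.
total_tokens = extend_input_len + min(max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION)
assert total_tokens == 5120  # tokens passed to host_memory_manager.allocate(rid, ...)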
30 changes: 23 additions & 7 deletions python/sglang/srt/managers/scheduler.py
@@ -95,6 +95,7 @@
from sglang.srt.managers.session_controller import Session
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
from sglang.srt.mem_cache.memory_pool import HostMemoryManager
from sglang.srt.managers.utils import validate_input_length
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
@@ -180,7 +181,6 @@ def __init__(
                self.dp_size,
            )
        )

        # Init inter-process communication
        context = zmq.Context(2)
        if self.attn_tp_rank == 0:
@@ -235,6 +235,7 @@ def __init__(
            dp_rank=dp_rank,
            nccl_port=port_args.nccl_port,
        )
        self.cell_size = self.tp_worker.get_cell_size()

        # Launch a draft worker for speculative decoding
        if self.spec_algorithm.is_eagle():
@@ -251,6 +252,17 @@
        else:
            self.draft_worker = None

        # Launch host memory manager
        self.host_memory_manager = HostMemoryManager(
            server_args.limit_method,
            server_args.limit_value,
            server_args.reserve_memory_bytes,
            server_args.enable_manager,
            server_args.memory_monitor_interval,
            server_args.pre_allocate,
            self.cell_size,
        )

        # Get token and memory info from the model worker
        (
            self.max_total_num_tokens,
@@ -357,6 +369,7 @@ def __init__(
        # Init kv transfer agent
        self.kv_transfer_agent = KVTransferAgent(
            server_args,
            self.host_memory_manager,
            self.req_to_token_pool,
            self.token_to_kv_pool_allocator,
            self.model_config.num_hidden_layers,
@@ -637,6 +650,7 @@ def process_input_requests(self, recv_reqs: List):
            print(f"process_input_requests recv_req: {recv_req}")

            output = self._request_dispatcher(recv_req)

            if output is not None:
                if isinstance(output, RpcReqOutput):
                    if self.recv_from_rpc is not None:
@@ -903,7 +917,7 @@ def log_prefill_stats(
        self.stats.num_queue_reqs = len(self.waiting_queue)
        self.stats.cache_hit_rate = cache_hit_rate
        self.metrics_collector.log_stats(self.stats)

    def log_recover_stats(
        self,
        adder: PrefillAdder,
@@ -1053,7 +1067,7 @@ def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
        if self.server_args.kv_transfer_config is not None:
            if self.last_batch:
                self.running_batch = self.last_batch

            if self.server_args.kv_transfer_config.role == "decode":
                ret = self.recover_new_prefilled_batch()
                if not self.running_batch.is_empty():
@@ -1067,12 +1081,12 @@ def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
                ret = None
            else:
                ret = self.get_new_batch_prefill()

            if self.server_args.enable_dp_attention:
                ret, _ = self.prepare_dp_attn_batch(ret)

            return ret

        # Merge the prefill batch into the running batch
        if self.last_batch and self.last_batch.forward_mode.is_extend():
            if self.chunked_req:
@@ -1115,13 +1129,13 @@ def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
            ret, _ = self.prepare_dp_attn_batch(ret)

        return ret

    def recover_new_prefilled_batch(self) -> Optional[ScheduleBatch]:
        if (
            self.running_batch.batch_is_full or len(self.waiting_queue) == 0
        ) and self.chunked_req is None:
            return None

        running_bs = len(self.running_batch.reqs) if self.running_batch else 0
        if running_bs >= self.max_running_requests:
            self.running_batch.batch_is_full = True
@@ -1133,6 +1147,7 @@ def recover_new_prefilled_batch(self) -> Optional[ScheduleBatch]:
            self.tree_cache,
            self.token_to_kv_pool_allocator,
            self.running_batch,
            self.host_memory_manager,
            self.new_token_ratio,
            self.max_prefill_tokens,
            None,
@@ -1265,6 +1280,7 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
            self.tree_cache,
            self.token_to_kv_pool_allocator,
            self.running_batch,
            self.host_memory_manager,
            self.new_token_ratio,
            self.max_prefill_tokens,
            self.chunked_prefill_size,
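The cell_size the scheduler pulls from the worker is what turns those token reservations into bytes. A hedged back-of-the-envelope sketch follows; every model number is invented for illustration, and the real value comes from model_runner.get_cell_size():

# Hedged sketch: tokens -> host-memory bytes via cell_size (numbers invented,
# roughly shaped like a 7B dense model).
num_layers = 32
num_kv_heads = 32
head_dim = 128
dtype_bytes = 2  # fp16/bf16
kv_factor = 2    # one K and one V tensor per layer

cell_size = num_layers * kv_factor * num_kv_heads * head_dim * dtype_bytes
assert cell_size == 512 * 1024  # 512 KiB of KV cache per token

# The 5120-token reservation from the schedule_policy example would then
# pin 5120 * 512 KiB = 2.5 GiB of host memory for a single request.
assert 5120 * cell_size / 2**30 == 2.5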
5 changes: 5 additions & 0 deletions python/sglang/srt/managers/tp_worker.py
@@ -52,6 +52,7 @@ def __init__(
        is_draft_worker: bool = False,
        req_to_token_pool: Optional[ReqToTokenPool] = None,
        token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
        cell_size: int = 0,
    ):
        # Parse args
        self.tp_rank = tp_rank
@@ -83,6 +84,7 @@ def __init__(
            req_to_token_pool=req_to_token_pool,
            token_to_kv_pool_allocator=token_to_kv_pool_allocator,
        )
        self.cell_size = self.model_runner.get_cell_size()
        if server_args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
        else:
@@ -131,6 +133,9 @@ def __init__(
            self.model_runner.tp_group.cpu_group,
        )[0]
        set_random_seed(self.random_seed)

    def get_cell_size(self):
        return self.cell_size

    def get_worker_info(self):
        return (
4 changes: 4 additions & 0 deletions python/sglang/srt/managers/tp_worker_overlap_thread.py
@@ -61,6 +61,7 @@
    ):
        # Load the model
        self.worker = TpModelWorker(server_args, gpu_id, tp_rank, dp_rank, nccl_port)
        self.cell_size = self.worker.get_cell_size()
        self.max_running_requests = self.worker.max_running_requests
        self.device = self.worker.device
        self.gpu_id = gpu_id
Expand All @@ -85,6 +86,9 @@ def __init__(
if self.device == "cpu":
self.scheduler_stream.synchronize = lambda: None # No-op for CPU

def get_cell_size(self):
return self.cell_size

def get_worker_info(self):
return self.worker.get_worker_info()

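Both worker flavors end up exposing the same get_cell_size() so the scheduler stays agnostic to overlap mode. A minimal sketch of that delegation chain, using illustrative stand-in classes rather than the real sglang ones:

# Stand-in classes (not the real sglang ones) showing the delegation chain
# this PR wires up: ModelRunner -> TpModelWorker -> TpModelWorkerClient.
class ModelRunnerSketch:
    def get_cell_size(self) -> int:
        return 512 * 1024  # bytes of KV cache per token; see the sketch above


class TpModelWorkerSketch:
    def __init__(self) -> None:
        self.cell_size = ModelRunnerSketch().get_cell_size()

    def get_cell_size(self) -> int:
        return self.cell_size


class TpModelWorkerClientSketch:
    """Overlap-thread wrapper: caches and re-exports the worker's value."""

    def __init__(self) -> None:
        self.worker = TpModelWorkerSketch()
        self.cell_size = self.worker.get_cell_size()

    def get_cell_size(self) -> int:
        return self.cell_size


# Either flavor answers the scheduler's question identically.
assert TpModelWorkerClientSketch().get_cell_size() == 512 * 1024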