Skip to content

Commit ab4b560

Browse files
authored
[PD] Support page size > 1 (#5561)
1 parent 20f1c8e commit ab4b560

File tree

4 files changed

+58
-9
lines changed

4 files changed

+58
-9
lines changed

python/sglang/srt/disaggregation/decode.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
ReqToMetadataIdxAllocator,
3636
TransferBackend,
3737
get_kv_class,
38+
kv_to_page_indices,
3839
poll_and_all_reduce,
3940
)
4041
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
@@ -205,7 +206,10 @@ def pop_preallocated(self) -> List[DecodeRequest]:
205206
self.req_to_metadata_buffer_idx_allocator.alloc()
206207
)
207208
assert decode_req.metadata_buffer_index is not None
208-
decode_req.kv_receiver.init(kv_indices, decode_req.metadata_buffer_index)
209+
page_indices = kv_to_page_indices(
210+
kv_indices, self.token_to_kv_pool_allocator.page_size
211+
)
212+
decode_req.kv_receiver.init(page_indices, decode_req.metadata_buffer_index)
209213
preallocated_reqs.append(decode_req)
210214
indices_to_remove.add(i)
211215

@@ -245,10 +249,30 @@ def _pre_alloc(self, req: Req) -> torch.Tensor:
245249
assert req_pool_indices is not None
246250

247251
req.req_pool_idx = req_pool_indices[0]
248-
kv_loc = self.token_to_kv_pool_allocator.alloc(
249-
len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
250-
)
251-
252+
if self.token_to_kv_pool_allocator.page_size == 1:
253+
kv_loc = self.token_to_kv_pool_allocator.alloc(
254+
len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
255+
)
256+
else:
257+
num_tokens = len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
258+
kv_loc = self.token_to_kv_pool_allocator.alloc_extend(
259+
prefix_lens=torch.tensor(
260+
[0],
261+
dtype=torch.int64,
262+
device=self.token_to_kv_pool_allocator.device,
263+
),
264+
seq_lens=torch.tensor(
265+
[num_tokens],
266+
dtype=torch.int64,
267+
device=self.token_to_kv_pool_allocator.device,
268+
),
269+
last_loc=torch.tensor(
270+
[-1],
271+
dtype=torch.int64,
272+
device=self.token_to_kv_pool_allocator.device,
273+
),
274+
extend_num_tokens=num_tokens,
275+
)
252276
assert kv_loc is not None
253277

254278
self.req_to_token_pool.write((req.req_pool_idx, slice(0, len(kv_loc))), kv_loc)

python/sglang/srt/disaggregation/prefill.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131
ReqToMetadataIdxAllocator,
3232
TransferBackend,
3333
get_kv_class,
34+
kv_to_page_indices,
35+
kv_to_page_num,
3436
poll_and_all_reduce,
3537
)
3638
from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch
@@ -154,7 +156,8 @@ def pop_bootstrapped(self) -> List[Req]:
154156
self.req_to_metadata_buffer_idx_allocator.alloc()
155157
)
156158
assert req.metadata_buffer_index is not None
157-
req.disagg_kv_sender.init(num_kv_indices, req.metadata_buffer_index)
159+
num_pages = kv_to_page_num(num_kv_indices, self.token_to_kv_pool.page_size)
160+
req.disagg_kv_sender.init(num_pages, req.metadata_buffer_index)
158161

159162
bootstrapped_reqs.append(req)
160163
indices_to_remove.add(i)
@@ -300,4 +303,7 @@ def send_kv_chunk(
300303
req.metadata_buffer_index, token_id
301304
)
302305
is_last = token_id is not None
303-
req.disagg_kv_sender.send(kv_indices, slice(start_idx, end_idx), is_last)
306+
page_indices = kv_to_page_indices(
307+
kv_indices, self.token_to_kv_pool_allocator.page_size
308+
)
309+
req.disagg_kv_sender.send(page_indices, slice(start_idx, end_idx), is_last)

python/sglang/srt/disaggregation/utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from enum import Enum
55
from typing import List
66

7+
import numpy as np
78
import torch
89
import torch.distributed as dist
910

@@ -73,3 +74,17 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
7374
}
7475
return class_mapping.get(class_type)
7576
raise ValueError(f"Unsupported transfer backend: {transfer_backend}")
77+
78+
79+
def kv_to_page_indices(kv_indices: np.ndarray, page_size: int):
80+
# 1. The page is guaranteed to be full except the last page.
81+
# 2. page index = kv_index // page_size
82+
# The return vector is kv_indices[::page_size] // page_size
83+
if page_size == 1: # shortcut
84+
return kv_indices
85+
return kv_indices[::page_size] // page_size
86+
87+
88+
def kv_to_page_num(num_kv_indices: int, page_size: int):
89+
# ceil(num_kv_indices / page_size)
90+
return (num_kv_indices + page_size - 1) // page_size

python/sglang/srt/mem_cache/memory_pool.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -286,8 +286,12 @@ def get_contiguous_buf_infos(self):
286286
self.get_key_buffer(i).nbytes for i in range(self.layer_num)
287287
] + [self.get_value_buffer(i).nbytes for i in range(self.layer_num)]
288288
kv_item_lens = [
289-
self.get_key_buffer(i)[0].nbytes for i in range(self.layer_num)
290-
] + [self.get_value_buffer(i)[0].nbytes for i in range(self.layer_num)]
289+
self.get_key_buffer(i)[0].nbytes * self.page_size
290+
for i in range(self.layer_num)
291+
] + [
292+
self.get_value_buffer(i)[0].nbytes * self.page_size
293+
for i in range(self.layer_num)
294+
]
291295
return kv_data_ptrs, kv_data_lens, kv_item_lens
292296

293297
# Todo: different memory layout

0 commit comments

Comments
 (0)