Skip to content

Commit 7b25e15

Browse files
xiezhq-hermannLayssy
authored andcommitted
Fix prefill OOM error in the case of large page size (sgl-project#5081)
1 parent c860ba4 commit 7b25e15

File tree

3 files changed

+11
-2
lines changed

3 files changed

+11
-2
lines changed

python/sglang/srt/managers/schedule_policy.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -463,7 +463,10 @@ def add_one_req(
463463
total_tokens = req.extend_input_len + min(
464464
req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION
465465
)
466-
input_tokens = req.extend_input_len
466+
input_tokens = (
467+
-(-req.extend_input_len // self.tree_cache.page_size)
468+
* self.tree_cache.page_size
469+
)
467470
prefix_len = len(req.prefix_indices)
468471

469472
if total_tokens >= self.rem_total_tokens:
@@ -540,7 +543,10 @@ def add_one_req(
540543
)
541544

542545
req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
543-
input_tokens = req.extend_input_len
546+
input_tokens = (
547+
-(-req.extend_input_len // self.tree_cache.page_size)
548+
* self.tree_cache.page_size
549+
)
544550
prefix_len = len(req.prefix_indices)
545551

546552
if self.rem_chunk_tokens is None or input_tokens <= self.rem_chunk_tokens:

python/sglang/srt/managers/scheduler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -506,6 +506,7 @@ def init_memory_pool_and_cache(self):
506506
self.tree_cache = ChunkCache(
507507
req_to_token_pool=self.req_to_token_pool,
508508
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
509+
page_size=self.page_size,
509510
)
510511
else:
511512
if self.enable_hierarchical_cache:

python/sglang/srt/mem_cache/chunk_cache.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,11 @@ def __init__(
2424
self,
2525
req_to_token_pool: ReqToTokenPool,
2626
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
27+
page_size: int,
2728
):
2829
self.req_to_token_pool = req_to_token_pool
2930
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
31+
self.page_size = page_size
3032

3133
def reset(self):
3234
pass

0 commit comments

Comments
 (0)