Skip to content

Commit f8e4609

Browse files
Fix prefill OOM error in the case of large page size (#5081)
1 parent 683707c commit f8e4609

File tree

3 files changed: +11 -2 lines changed

python/sglang/srt/managers/schedule_policy.py

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -455,7 +455,10 @@ def add_one_req(
455 455   total_tokens = req.extend_input_len + min(
456 456   req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION
457 457   )
458     - input_tokens = req.extend_input_len
    458 + input_tokens = (
    459 + -(-req.extend_input_len // self.tree_cache.page_size)
    460 + * self.tree_cache.page_size
    461 + )
459 462   prefix_len = len(req.prefix_indices)
460 463
461 464   if total_tokens >= self.rem_total_tokens:
@@ -477,7 +480,10 @@ def add_one_req(
477 480   req.last_node_global, req.prefix_indices
478 481   )
479 482   req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
480     - input_tokens = req.extend_input_len
    483 + input_tokens = (
    484 + -(-req.extend_input_len // self.tree_cache.page_size)
    485 + * self.tree_cache.page_size
    486 + )
481 487   prefix_len = len(req.prefix_indices)
482 488
483 489   if self.rem_chunk_tokens is None or input_tokens <= self.rem_chunk_tokens:

python/sglang/srt/managers/scheduler.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -502,6 +502,7 @@ def init_memory_pool_and_cache(self):
502 502   self.tree_cache = ChunkCache(
503 503   req_to_token_pool=self.req_to_token_pool,
504 504   token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
    505 + page_size=self.page_size,
505 506   )
506 507   else:
507 508   if self.enable_hierarchical_cache:

python/sglang/srt/mem_cache/chunk_cache.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -24,9 +24,11 @@ def __init__(
24 24   self,
25 25   req_to_token_pool: ReqToTokenPool,
26 26   token_to_kv_pool_allocator: TokenToKVPoolAllocator,
   27 + page_size: int,
27 28   ):
28 29   self.req_to_token_pool = req_to_token_pool
29 30   self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
   31 + self.page_size = page_size
30 32
31 33   def reset(self):
32 34   pass

0 commit comments

Comments
 (0)