Commit 3b48346

Support page size > 1

1 parent e35a93f commit 3b48346

17 files changed: 902 additions & 325 deletions

python/sglang/srt/managers/schedule_batch.py

Lines changed: 186 additions & 88 deletions
Large diffs are not rendered by default.

python/sglang/srt/managers/schedule_policy.py

Lines changed: 12 additions & 13 deletions
@@ -73,22 +73,23 @@ class CacheAgnosticPolicy(Enum):
 class SchedulePolicy:
     Policy = Union[CacheAwarePolicy, CacheAgnosticPolicy]
 
-    def __init__(
-        self,
-        policy: str,
-        tree_cache: BasePrefixCache,
-        enable_hierarchical_cache: bool = False,
-    ):
+    def __init__(self, policy: str, tree_cache: BasePrefixCache):
         self.policy = self._validate_and_adjust_policy(policy, tree_cache)
         self.tree_cache = tree_cache
-        self.enable_hierarchical_cache = enable_hierarchical_cache
 
         # It is used to find the matching prefix for in-batch prefix caching.
         self.waiting_queue_radix_tree = RadixCache(
-            req_to_token_pool=None, token_to_kv_pool_allocator=None, disable=False
+            req_to_token_pool=None,
+            token_to_kv_pool_allocator=None,
+            page_size=1,
+            disable=False,
         )
 
     def calc_priority(self, waiting_queue: List[Req]) -> bool:
+        if self.policy == CacheAgnosticPolicy.FCFS:
+            # A shortcut for FCFS
+            return
+
         policy = self._determine_active_policy(waiting_queue)
 
         prefix_computed = False
@@ -106,9 +107,7 @@ def calc_priority(self, waiting_queue: List[Req]) -> bool:
             else:
                 raise ValueError(f"Unknown CacheAware Policy: {policy=}")
         else:
-            if policy == CacheAgnosticPolicy.FCFS:
-                pass
-            elif policy == CacheAgnosticPolicy.LOF:
+            if policy == CacheAgnosticPolicy.LOF:
                 SchedulePolicy._sort_by_longest_output(waiting_queue)
             elif policy == CacheAgnosticPolicy.RANDOM:
                 SchedulePolicy._sort_randomly(waiting_queue)
@@ -118,7 +117,7 @@ def calc_priority(self, waiting_queue: List[Req]) -> bool:
         return prefix_computed
 
     def _determine_active_policy(self, waiting_queue: List[Req]) -> Policy:
-        if len(waiting_queue) > 128 and self.policy == CacheAwarePolicy.LPM:
+        if self.policy == CacheAwarePolicy.LPM and len(waiting_queue) > 128:
             # Turn off the expensive prefix matching and sorting when the #queue is large.
             return CacheAgnosticPolicy.FCFS
         return self.policy
@@ -442,7 +441,7 @@ def add_req_state(r, insert_sort=False):
     def add_one_req(
         self, req: Req, has_chunked_req: bool, enable_hierarchical_cache: bool = False
     ):
-        if req.sampling_params.ignore_eos and self.tree_cache.disable:
+        if req.sampling_params.ignore_eos and getattr(self.tree_cache, "disable", True):
             return self.add_one_req_ignore_eos(req, has_chunked_req)
 
         total_tokens = req.extend_input_len + min(
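Not part of the diff above: a minimal sketch of the page-size arithmetic that a page_size > 1 setting implies for the radix cache, assuming cached KV can only be reused and allocated in whole pages. The helper names (align_prefix_to_page, pages_needed) are hypothetical and do not appear in this commit.

# Hypothetical helpers, not from this commit: illustrate how token counts
# map onto pages when page_size > 1.

def align_prefix_to_page(matched_len: int, page_size: int) -> int:
    # Round a matched prefix down to whole pages; a partially filled page
    # cannot be shared from the KV cache.
    return (matched_len // page_size) * page_size


def pages_needed(num_tokens: int, page_size: int) -> int:
    # Ceiling division: number of pages required to hold num_tokens tokens.
    return -(-num_tokens // page_size)


# Example: with page_size=4, a 10-token prefix match reuses only 8 tokens,
# and storing 10 new tokens requires 3 pages.
assert align_prefix_to_page(10, 4) == 8
assert pages_needed(10, 4) == 3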
