@@ -73,22 +73,23 @@ class CacheAgnosticPolicy(Enum):
73
73
class SchedulePolicy :
74
74
Policy = Union [CacheAwarePolicy , CacheAgnosticPolicy ]
75
75
76
- def __init__ (
77
- self ,
78
- policy : str ,
79
- tree_cache : BasePrefixCache ,
80
- enable_hierarchical_cache : bool = False ,
81
- ):
76
+ def __init__ (self , policy : str , tree_cache : BasePrefixCache ):
82
77
self .policy = self ._validate_and_adjust_policy (policy , tree_cache )
83
78
self .tree_cache = tree_cache
84
- self .enable_hierarchical_cache = enable_hierarchical_cache
85
79
86
80
# It is used to find the matching prefix for in-batch prefix caching.
87
81
self .waiting_queue_radix_tree = RadixCache (
88
- req_to_token_pool = None , token_to_kv_pool_allocator = None , disable = False
82
+ req_to_token_pool = None ,
83
+ token_to_kv_pool_allocator = None ,
84
+ page_size = 1 ,
85
+ disable = False ,
89
86
)
90
87
91
88
def calc_priority (self , waiting_queue : List [Req ]) -> bool :
89
+ if self .policy == CacheAgnosticPolicy .FCFS :
90
+ # A shortcut for FCFS
91
+ return
92
+
92
93
policy = self ._determine_active_policy (waiting_queue )
93
94
94
95
prefix_computed = False
@@ -106,9 +107,7 @@ def calc_priority(self, waiting_queue: List[Req]) -> bool:
106
107
else :
107
108
raise ValueError (f"Unknown CacheAware Policy: { policy = } " )
108
109
else :
109
- if policy == CacheAgnosticPolicy .FCFS :
110
- pass
111
- elif policy == CacheAgnosticPolicy .LOF :
110
+ if policy == CacheAgnosticPolicy .LOF :
112
111
SchedulePolicy ._sort_by_longest_output (waiting_queue )
113
112
elif policy == CacheAgnosticPolicy .RANDOM :
114
113
SchedulePolicy ._sort_randomly (waiting_queue )
@@ -118,7 +117,7 @@ def calc_priority(self, waiting_queue: List[Req]) -> bool:
118
117
return prefix_computed
119
118
120
119
def _determine_active_policy (self , waiting_queue : List [Req ]) -> Policy :
121
- if len ( waiting_queue ) > 128 and self .policy == CacheAwarePolicy .LPM :
120
+ if self .policy == CacheAwarePolicy .LPM and len ( waiting_queue ) > 128 :
122
121
# Turn off the expensive prefix matching and sorting when the #queue is large.
123
122
return CacheAgnosticPolicy .FCFS
124
123
return self .policy
@@ -442,7 +441,7 @@ def add_req_state(r, insert_sort=False):
442
441
def add_one_req (
443
442
self , req : Req , has_chunked_req : bool , enable_hierarchical_cache : bool = False
444
443
):
445
- if req .sampling_params .ignore_eos and self .tree_cache . disable :
444
+ if req .sampling_params .ignore_eos and getattr ( self .tree_cache , " disable" , True ) :
446
445
return self .add_one_req_ignore_eos (req , has_chunked_req )
447
446
448
447
total_tokens = req .extend_input_len + min (
0 commit comments