16
16
TokenToKVPoolAllocator ,
17
17
)
18
18
from sglang .srt .mem_cache .radix_cache import RadixCache , TreeNode
19
- from sglang .srt .mem_cache .radix_cache import _key_match_page_size1 as _key_match
20
19
21
20
logger = logging .getLogger (__name__ )
22
21
@@ -31,29 +30,25 @@ def __init__(
31
30
page_size : int ,
32
31
hicache_ratio : float ,
33
32
):
34
- if page_size != 1 :
35
- raise ValueError (
36
- "Page size larger than 1 is not yet supported in HiRadixCache."
37
- )
38
33
self .kv_cache = token_to_kv_pool_allocator .get_kvcache ()
39
34
if isinstance (self .kv_cache , MHATokenToKVPool ):
40
35
self .token_to_kv_pool_host = MHATokenToKVPoolHost (
41
- self .kv_cache , hicache_ratio
36
+ self .kv_cache , hicache_ratio , page_size
42
37
)
43
38
elif isinstance (self .kv_cache , MLATokenToKVPool ):
44
39
self .token_to_kv_pool_host = MLATokenToKVPoolHost (
45
- self .kv_cache , hicache_ratio
40
+ self .kv_cache , hicache_ratio , page_size
46
41
)
47
42
else :
48
- raise ValueError (f"Only MHA and MLA supports swap kv_cache to host. " )
43
+ raise ValueError (f"HiRadixCache only supports MHA and MLA yet " )
49
44
50
45
self .tp_group = tp_cache_group
51
- self .page_size = page_size
52
46
53
47
self .load_cache_event = threading .Event ()
54
48
self .cache_controller = HiCacheController (
55
49
token_to_kv_pool_allocator ,
56
50
self .token_to_kv_pool_host ,
51
+ page_size ,
57
52
load_cache_event = self .load_cache_event ,
58
53
)
59
54
@@ -65,7 +60,7 @@ def __init__(
65
60
self .write_through_threshold = 1
66
61
self .load_back_threshold = 10
67
62
super ().__init__ (
68
- req_to_token_pool , token_to_kv_pool_allocator , self . page_size , disable = False
63
+ req_to_token_pool , token_to_kv_pool_allocator , page_size , disable = False
69
64
)
70
65
71
66
def reset (self ):
@@ -299,18 +294,26 @@ def init_load_back(
299
294
300
295
return last_node , prefix_indices
301
296
302
- def read_to_load_cache (self ):
297
+ def ready_to_load_cache (self ):
303
298
self .load_cache_event .set ()
304
299
305
300
def match_prefix (self , key : List [int ], include_evicted = False , ** kwargs ):
306
- if self .disable :
307
- return [], self .root_node
301
+ empty_value = torch .empty ((0 ,), dtype = torch .int64 , device = self .device )
302
+ if self .disable or len (key ) == 0 :
303
+ if include_evicted :
304
+ return empty_value , self .root_node , self .root_node
305
+ else :
306
+ return empty_value , self .root_node
307
+
308
+ if self .page_size != 1 :
309
+ page_aligned_len = len (key ) // self .page_size * self .page_size
310
+ key = key [:page_aligned_len ]
308
311
309
312
value , last_node = self ._match_prefix_helper (self .root_node , key )
310
313
if value :
311
314
value = torch .cat (value )
312
315
else :
313
- value = torch . tensor ([], dtype = torch . int64 )
316
+ value = empty_value
314
317
315
318
last_node_global = last_node
316
319
while last_node .evicted :
@@ -323,11 +326,13 @@ def match_prefix(self, key: List[int], include_evicted=False, **kwargs):
323
326
324
327
def _match_prefix_helper (self , node : TreeNode , key : List ):
325
328
node .last_access_time = time .time ()
329
+ child_key = self .get_child_key_fn (key )
326
330
value = []
327
- while len (key ) > 0 and key [0 ] in node .children .keys ():
328
- child = node .children [key [0 ]]
331
+
332
+ while len (key ) > 0 and child_key in node .children .keys ():
333
+ child = node .children [child_key ]
329
334
child .last_access_time = time .time ()
330
- prefix_len = _key_match (child .key , key )
335
+ prefix_len = self . key_match_fn (child .key , key )
331
336
if prefix_len < len (child .key ):
332
337
new_node = self ._split_node (child .key , child , prefix_len )
333
338
if not new_node .evicted :
@@ -339,12 +344,16 @@ def _match_prefix_helper(self, node: TreeNode, key: List):
339
344
value .append (child .value )
340
345
node = child
341
346
key = key [prefix_len :]
347
+
348
+ if len (key ):
349
+ child_key = self .get_child_key_fn (key )
350
+
342
351
return value , node
343
352
344
353
def _split_node (self , key , child : TreeNode , split_len : int ):
345
354
# child node split into new_node -> child
346
355
new_node = TreeNode ()
347
- new_node .children = {key [split_len ] : child }
356
+ new_node .children = {self . get_child_key_fn ( key [split_len :]) : child }
348
357
new_node .parent = child .parent
349
358
new_node .lock_ref = child .lock_ref
350
359
new_node .key = child .key [:split_len ]
@@ -361,60 +370,61 @@ def _split_node(self, key, child: TreeNode, split_len: int):
361
370
child .host_value = child .host_value [split_len :]
362
371
child .parent = new_node
363
372
child .key = child .key [split_len :]
364
- new_node .parent .children [key [ 0 ] ] = new_node
373
+ new_node .parent .children [self . get_child_key_fn ( key ) ] = new_node
365
374
return new_node
366
375
367
376
def _insert_helper (self , node : TreeNode , key : List , value ):
368
377
node .last_access_time = time .time ()
369
378
if len (key ) == 0 :
370
379
return 0
371
380
372
- if key [0 ] in node .children .keys ():
373
- child = node .children [key [0 ]]
374
- prefix_len = _key_match (child .key , key )
381
+ child_key = self .get_child_key_fn (key )
382
+ total_prefix_length = 0
383
+
384
+ while len (key ) > 0 and child_key in node .children .keys ():
385
+ node = node .children [child_key ]
386
+ node .last_access_time = time .time ()
387
+ prefix_len = self .key_match_fn (node .key , key )
375
388
376
- if prefix_len == len (child .key ):
377
- if child .evicted :
389
+ if prefix_len == len (node .key ):
390
+ if node .evicted :
378
391
# change the reference if the node is evicted
379
392
# this often happens in the case of KV cache recomputation
380
- child .value = value [:prefix_len ]
381
- self .token_to_kv_pool_host .update_synced (child .host_value )
382
- self .evictable_size_ += len (value [:prefix_len ])
383
- return self ._insert_helper (
384
- child , key [prefix_len :], value [prefix_len :]
385
- )
393
+ node .value = value [:prefix_len ]
394
+ self .token_to_kv_pool_host .update_synced (node .host_value )
395
+ self .evictable_size_ += len (node .value )
386
396
else :
387
- self .inc_hit_count (child )
388
- return prefix_len + self ._insert_helper (
389
- child , key [prefix_len :], value [prefix_len :]
390
- )
391
-
392
- # partial match, split the node
393
- new_node = self ._split_node (child .key , child , prefix_len )
394
- if new_node .evicted :
395
- new_node .value = value [:prefix_len ]
396
- self .token_to_kv_pool_host .update_synced (new_node .host_value )
397
- self .evictable_size_ += len (new_node .value )
398
- return self ._insert_helper (
399
- new_node , key [prefix_len :], value [prefix_len :]
400
- )
397
+ self .inc_hit_count (node )
398
+ total_prefix_length += prefix_len
401
399
else :
402
- self .inc_hit_count (new_node )
403
- return prefix_len + self ._insert_helper (
404
- new_node , key [prefix_len :], value [prefix_len :]
405
- )
400
+ # partial match, split the node
401
+ new_node = self ._split_node (node .key , node , prefix_len )
402
+ if new_node .evicted :
403
+ new_node .value = value [:prefix_len ]
404
+ self .token_to_kv_pool_host .update_synced (new_node .host_value )
405
+ self .evictable_size_ += len (new_node .value )
406
+ else :
407
+ self .inc_hit_count (new_node )
408
+ total_prefix_length += prefix_len
409
+ node = new_node
410
+
411
+ key = key [prefix_len :]
412
+ value = value [prefix_len :]
413
+
414
+ if len (key ):
415
+ child_key = self .get_child_key_fn (key )
406
416
407
417
if len (key ):
408
418
new_node = TreeNode ()
409
419
new_node .parent = node
410
420
new_node .key = key
411
421
new_node .value = value
412
- node .children [key [ 0 ] ] = new_node
422
+ node .children [child_key ] = new_node
413
423
self .evictable_size_ += len (value )
414
424
415
425
if self .cache_controller .write_policy == "write_through" :
416
426
self .write_backup (new_node )
417
- return 0
427
+ return total_prefix_length
418
428
419
429
def _collect_leaves_device (self ):
420
430
def is_leaf (node ):
0 commit comments