@@ -29,15 +29,17 @@ def __init__(
         tp_cache_group: torch.distributed.ProcessGroup,
         page_size: int,
         hicache_ratio: float,
+        hicache_size: int,
+        hicache_write_policy: str,
     ):
         self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
         if isinstance(self.kv_cache, MHATokenToKVPool):
             self.token_to_kv_pool_host = MHATokenToKVPoolHost(
-                self.kv_cache, hicache_ratio, page_size
+                self.kv_cache, hicache_ratio, hicache_size, page_size
             )
         elif isinstance(self.kv_cache, MLATokenToKVPool):
             self.token_to_kv_pool_host = MLATokenToKVPoolHost(
-                self.kv_cache, hicache_ratio, page_size
+                self.kv_cache, hicache_ratio, hicache_size, page_size
             )
         else:
             raise ValueError(f"HiRadixCache only supports MHA and MLA yet")
@@ -50,14 +52,17 @@ def __init__(
             self.token_to_kv_pool_host,
             page_size,
             load_cache_event=self.load_cache_event,
+            write_policy=hicache_write_policy,
         )

         # record the nodes with ongoing write through
         self.ongoing_write_through = {}
         # record the node segments with ongoing load back
         self.ongoing_load_back = {}
         # todo: dynamically adjust the threshold
-        self.write_through_threshold = 1
+        self.write_through_threshold = (
+            1 if hicache_write_policy == "write_through" else 3
+        )
         self.load_back_threshold = 10
         super().__init__(
             req_to_token_pool, token_to_kv_pool_allocator, page_size, disable=False
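Not part of the diff, just to spell out the threshold logic introduced above: a minimal standalone sketch (hypothetical helper name) of how the write policy maps to a write-through threshold. Under "write_through" a node is backed up to host after its first hit, while other hit-count-based policies such as "write_through_selective" wait for three hits; "write_back" ignores the threshold entirely (see inc_hit_count below).

# Sketch only, not code from this PR: mirrors the policy-to-threshold mapping in __init__.
def pick_write_through_threshold(hicache_write_policy: str) -> int:
    # "write_through" backs up on the first hit; other hit-count-based policies
    # (e.g. "write_through_selective") wait until the third hit.
    return 1 if hicache_write_policy == "write_through" else 3

assert pick_write_through_threshold("write_through") == 1
assert pick_write_through_threshold("write_through_selective") == 3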
@@ -76,7 +81,7 @@ def get_height(self, node: TreeNode):
             height += 1
         return height

-    def write_backup(self, node: TreeNode):
+    def write_backup(self, node: TreeNode, write_back=False):
         host_indices = self.cache_controller.write(
             device_indices=node.value,
             node_id=node.id,
@@ -90,21 +95,29 @@ def write_backup(self, node: TreeNode):
         if host_indices is not None:
             node.host_value = host_indices
             self.ongoing_write_through[node.id] = node
-            self.inc_lock_ref(node)
+            if not write_back:
+                # no need to lock nodes if write back
+                self.inc_lock_ref(node)
         else:
             return 0

         return len(host_indices)

     def inc_hit_count(self, node: TreeNode):
-        if self.cache_controller.write_policy != "write_through_selective":
+        if node.backuped or self.cache_controller.write_policy == "write_back":
             return
         node.hit_count += 1
-        if node.host_value is None and node.hit_count > self.write_through_threshold:
+        if node.hit_count >= self.write_through_threshold:
             self.write_backup(node)
             node.hit_count = 0

-    def writing_check(self):
+    def writing_check(self, write_back=False):
+        if write_back:
+            # blocking till all write back complete
+            while len(self.ongoing_write_through) > 0:
+                ack_id = self.cache_controller.ack_write_queue.get()
+                del self.ongoing_write_through[ack_id]
+            return
         queue_size = torch.tensor(
             self.cache_controller.ack_write_queue.qsize(), dtype=torch.int
         )
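A simplified sketch of the hit-count gating that inc_hit_count now performs, using a hypothetical minimal Node class rather than the real TreeNode: nodes that already have a host copy and the "write_back" policy skip hit counting, and a write-through is triggered once the counter reaches the threshold.

# Sketch only: simplified stand-in for inc_hit_count, with a hypothetical Node type.
from dataclasses import dataclass

@dataclass
class Node:
    hit_count: int = 0
    backuped: bool = False  # True once the node has a host copy

def maybe_write_through(node: Node, write_policy: str, threshold: int) -> bool:
    """Return True if the node was (notionally) written through to host."""
    if node.backuped or write_policy == "write_back":
        return False
    node.hit_count += 1
    if node.hit_count >= threshold:
        node.hit_count = 0
        node.backuped = True  # stands in for write_backup(node)
        return True
    return False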
@@ -143,29 +156,25 @@ def evict(self, num_tokens: int):
         heapq.heapify(leaves)

         num_evicted = 0
-        pending_nodes = []
+        write_back_nodes = []
         while num_evicted < num_tokens and len(leaves):
             x = heapq.heappop(leaves)

             if x.lock_ref > 0:
                 continue

-            if x.host_value is None:
+            if not x.backuped:
                 if self.cache_controller.write_policy == "write_back":
-                    num_evicted += self.write_backup(x)
-                    pending_nodes.append(x)
-                elif self.cache_controller.write_policy == "write_through_selective":
-                    num_evicted += self._evict_write_through_selective(x)
+                    # write to host if the node is not backuped
+                    num_evicted += self.write_backup(x, write_back=True)
+                    write_back_nodes.append(x)
                 else:
-                    assert (
-                        self.cache_controller.write_policy != "write_through"
-                    ), "write_through should be inclusive"
-                    raise NotImplementedError
+                    num_evicted += self._evict_regular(x)
             else:
-                num_evicted += self._evict_write_through(x)
+                num_evicted += self._evict_backuped(x)

             for child in x.parent.children.values():
-                if child in pending_nodes:
+                if child in write_back_nodes:
                     continue
                 if not child.evicted:
                     break
@@ -174,23 +183,20 @@ def evict(self, num_tokens: int):
                 heapq.heappush(leaves, x.parent)

         if self.cache_controller.write_policy == "write_back":
-            # blocking till all write back complete
-            while len(self.ongoing_write_through) > 0:
-                self.writing_check()
-                time.sleep(0.1)
-            for node in pending_nodes:
-                assert node.host_value is not None
-                self._evict_write_through(node)
+            self.writing_check(write_back=True)
+            for node in write_back_nodes:
+                assert node.backuped
+                self._evict_backuped(node)

-    def _evict_write_through(self, node: TreeNode):
+    def _evict_backuped(self, node: TreeNode):
         # evict a node already written to host
         num_evicted = self.cache_controller.evict_device(node.value, node.host_value)
         assert num_evicted > 0
         self.evictable_size_ -= num_evicted
         node.value = None
         return num_evicted

-    def _evict_write_through_selective(self, node: TreeNode):
+    def _evict_regular(self, node: TreeNode):
         # evict a node not initiated write to host
         self.cache_controller.mem_pool_device_allocator.free(node.value)
         num_evicted = len(node.value)
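Schematically, the eviction paths after this change reduce to three cases per popped leaf; the sketch below (hypothetical helper, not the real controller API) only labels which path is taken.

# Sketch only: the per-leaf decision in evict() after this change.
def eviction_path(backuped: bool, write_policy: str) -> str:
    if not backuped:
        if write_policy == "write_back":
            # back the node up first; its device memory is reclaimed only after
            # writing_check(write_back=True) has drained the ack queue
            return "write_backup then _evict_backuped"
        # no host copy and no write-back requested: just free the device memory
        return "_evict_regular"
    # a host copy already exists, so the device copy can be dropped directly
    return "_evict_backuped"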
@@ -339,11 +345,13 @@ def _match_prefix_helper(self, node: TreeNode, key: List):
             prefix_len = self.key_match_fn(child.key, key)
             if prefix_len < len(child.key):
                 new_node = self._split_node(child.key, child, prefix_len)
+                self.inc_hit_count(new_node)
                 if not new_node.evicted:
                     value.append(new_node.value)
                 node = new_node
                 break
             else:
+                self.inc_hit_count(child)
                 if not child.evicted:
                     value.append(child.value)
                 node = child
@@ -369,7 +377,7 @@ def _split_node(self, key, child: TreeNode, split_len: int):
         else:
             new_node.value = child.value[:split_len]
             child.value = child.value[split_len:]
-        if child.host_value is not None:
+        if child.backuped:
             new_node.host_value = child.host_value[:split_len]
             child.host_value = child.host_value[split_len:]
         child.parent = new_node
@@ -426,8 +434,8 @@ def _insert_helper(self, node: TreeNode, key: List, value):
             node.children[child_key] = new_node
             self.evictable_size_ += len(value)

-            if self.cache_controller.write_policy == "write_through":
-                self.write_backup(new_node)
+            if self.cache_controller.write_policy != "write_back":
+                self.inc_hit_count(new_node)
         return total_prefix_length

     def _collect_leaves_device(self):
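Taken together with the _match_prefix_helper and _insert_helper changes, the apparent net effect is that write-through decisions are now driven entirely by inc_hit_count: with the "write_through" policy the threshold of 1 means a newly inserted node is still backed up immediately, "write_through_selective" defers the backup until the node has accumulated three hits (insertions and prefix matches both count), and "write_back" only copies nodes to host lazily at eviction time.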