Skip to content

Commit c1d8045

Browse files
xiezhq-hermann authored and jimoosciuc committed
Large page size aligned hierarchical caching (sgl-project#4581)
1 parent bc400e9 commit c1d8045

File tree

8 files changed

+242
-71
lines changed

8 files changed

+242
-71
lines changed

python/sglang/srt/managers/cache_controller.py

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -149,13 +149,15 @@ def __init__(
149149
self,
150150
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
151151
mem_pool_host: HostKVCache,
152+
page_size: int,
152153
load_cache_event: threading.Event = None,
153154
write_policy: str = "write_through_selective",
154155
):
155156
self.mem_pool_device_allocator = token_to_kv_pool_allocator
156157
self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
157158
self.mem_pool_host = mem_pool_host
158159
self.write_policy = write_policy
160+
self.page_size = page_size
159161

160162
self.load_cache_event = load_cache_event
161163
self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
@@ -184,7 +186,12 @@ def __init__(
184186
self.load_stream = torch.cuda.Stream()
185187

186188
self.write_thread = threading.Thread(
187-
target=self.write_thread_func_buffer, daemon=True
189+
target=(
190+
self.write_thread_func_buffer
191+
if self.page_size == 1
192+
else self.write_thread_func_direct
193+
),
194+
daemon=True,
188195
)
189196
self.load_thread = threading.Thread(
190197
target=self.load_thread_func_layer_by_layer, daemon=True
@@ -205,7 +212,12 @@ def reset(self):
205212
self.ack_load_queue.queue.clear()
206213

207214
self.write_thread = threading.Thread(
208-
target=self.write_thread_func_buffer, daemon=True
215+
target=(
216+
self.write_thread_func_buffer
217+
if self.page_size == 1
218+
else self.write_thread_func_direct
219+
),
220+
daemon=True,
209221
)
210222
self.load_thread = threading.Thread(
211223
target=self.load_thread_func_layer_by_layer, daemon=True
@@ -260,10 +272,12 @@ def write_thread_func_direct(self):
260272
while not self.stop_event.is_set():
261273
try:
262274
operation = self.write_queue.get(block=True, timeout=1)
263-
operation.data = self.mem_pool_device.get_flat_data(
264-
operation.device_indices
275+
self.mem_pool_host.write_page_all_layers(
276+
operation.host_indices,
277+
operation.device_indices,
278+
self.mem_pool_device,
265279
)
266-
self.mem_pool_host.transfer(operation.host_indices, operation.data)
280+
self.write_stream.synchronize()
267281
self.mem_pool_host.complete_io(operation.host_indices)
268282
for node_id in operation.node_ids:
269283
if node_id != 0:
@@ -320,12 +334,21 @@ def load_thread_func_layer_by_layer(self):
320334

321335
self.layer_done_counter.reset()
322336
for i in range(self.mem_pool_host.layer_num):
323-
flat_data = self.mem_pool_host.get_flat_data_by_layer(
324-
batch_operation.host_indices, i
325-
)
326-
self.mem_pool_device.transfer_per_layer(
327-
batch_operation.device_indices, flat_data, i
328-
)
337+
if self.page_size == 1:
338+
flat_data = self.mem_pool_host.get_flat_data_by_layer(
339+
batch_operation.host_indices, i
340+
)
341+
self.mem_pool_device.transfer_per_layer(
342+
batch_operation.device_indices, flat_data, i
343+
)
344+
else:
345+
self.mem_pool_host.load_page_per_layer(
346+
batch_operation.host_indices,
347+
batch_operation.device_indices,
348+
self.mem_pool_device,
349+
i,
350+
)
351+
self.load_stream.synchronize()
329352
self.layer_done_counter.increment()
330353

331354
self.mem_pool_host.complete_io(batch_operation.host_indices)

python/sglang/srt/managers/scheduler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1522,7 +1522,7 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
15221522
]
15231523

15241524
if self.enable_hierarchical_cache:
1525-
self.tree_cache.read_to_load_cache()
1525+
self.tree_cache.ready_to_load_cache()
15261526

15271527
if adder.new_chunked_req is not None:
15281528
assert self.chunked_req is None

python/sglang/srt/mem_cache/hiradix_cache.py

Lines changed: 60 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
TokenToKVPoolAllocator,
1717
)
1818
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
19-
from sglang.srt.mem_cache.radix_cache import _key_match_page_size1 as _key_match
2019

2120
logger = logging.getLogger(__name__)
2221

@@ -31,29 +30,25 @@ def __init__(
3130
page_size: int,
3231
hicache_ratio: float,
3332
):
34-
if page_size != 1:
35-
raise ValueError(
36-
"Page size larger than 1 is not yet supported in HiRadixCache."
37-
)
3833
self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
3934
if isinstance(self.kv_cache, MHATokenToKVPool):
4035
self.token_to_kv_pool_host = MHATokenToKVPoolHost(
41-
self.kv_cache, hicache_ratio
36+
self.kv_cache, hicache_ratio, page_size
4237
)
4338
elif isinstance(self.kv_cache, MLATokenToKVPool):
4439
self.token_to_kv_pool_host = MLATokenToKVPoolHost(
45-
self.kv_cache, hicache_ratio
40+
self.kv_cache, hicache_ratio, page_size
4641
)
4742
else:
48-
raise ValueError(f"Only MHA and MLA supports swap kv_cache to host.")
43+
raise ValueError(f"HiRadixCache only supports MHA and MLA yet")
4944

5045
self.tp_group = tp_cache_group
51-
self.page_size = page_size
5246

5347
self.load_cache_event = threading.Event()
5448
self.cache_controller = HiCacheController(
5549
token_to_kv_pool_allocator,
5650
self.token_to_kv_pool_host,
51+
page_size,
5752
load_cache_event=self.load_cache_event,
5853
)
5954

@@ -65,7 +60,7 @@ def __init__(
6560
self.write_through_threshold = 1
6661
self.load_back_threshold = 10
6762
super().__init__(
68-
req_to_token_pool, token_to_kv_pool_allocator, self.page_size, disable=False
63+
req_to_token_pool, token_to_kv_pool_allocator, page_size, disable=False
6964
)
7065

7166
def reset(self):
@@ -299,18 +294,26 @@ def init_load_back(
299294

300295
return last_node, prefix_indices
301296

302-
def read_to_load_cache(self):
297+
def ready_to_load_cache(self):
303298
self.load_cache_event.set()
304299

305300
def match_prefix(self, key: List[int], include_evicted=False, **kwargs):
306-
if self.disable:
307-
return [], self.root_node
301+
empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
302+
if self.disable or len(key) == 0:
303+
if include_evicted:
304+
return empty_value, self.root_node, self.root_node
305+
else:
306+
return empty_value, self.root_node
307+
308+
if self.page_size != 1:
309+
page_aligned_len = len(key) // self.page_size * self.page_size
310+
key = key[:page_aligned_len]
308311

309312
value, last_node = self._match_prefix_helper(self.root_node, key)
310313
if value:
311314
value = torch.cat(value)
312315
else:
313-
value = torch.tensor([], dtype=torch.int64)
316+
value = empty_value
314317

315318
last_node_global = last_node
316319
while last_node.evicted:
@@ -323,11 +326,13 @@ def match_prefix(self, key: List[int], include_evicted=False, **kwargs):
323326

324327
def _match_prefix_helper(self, node: TreeNode, key: List):
325328
node.last_access_time = time.time()
329+
child_key = self.get_child_key_fn(key)
326330
value = []
327-
while len(key) > 0 and key[0] in node.children.keys():
328-
child = node.children[key[0]]
331+
332+
while len(key) > 0 and child_key in node.children.keys():
333+
child = node.children[child_key]
329334
child.last_access_time = time.time()
330-
prefix_len = _key_match(child.key, key)
335+
prefix_len = self.key_match_fn(child.key, key)
331336
if prefix_len < len(child.key):
332337
new_node = self._split_node(child.key, child, prefix_len)
333338
if not new_node.evicted:
@@ -339,12 +344,16 @@ def _match_prefix_helper(self, node: TreeNode, key: List):
339344
value.append(child.value)
340345
node = child
341346
key = key[prefix_len:]
347+
348+
if len(key):
349+
child_key = self.get_child_key_fn(key)
350+
342351
return value, node
343352

344353
def _split_node(self, key, child: TreeNode, split_len: int):
345354
# child node split into new_node -> child
346355
new_node = TreeNode()
347-
new_node.children = {key[split_len]: child}
356+
new_node.children = {self.get_child_key_fn(key[split_len:]): child}
348357
new_node.parent = child.parent
349358
new_node.lock_ref = child.lock_ref
350359
new_node.key = child.key[:split_len]
@@ -361,60 +370,61 @@ def _split_node(self, key, child: TreeNode, split_len: int):
361370
child.host_value = child.host_value[split_len:]
362371
child.parent = new_node
363372
child.key = child.key[split_len:]
364-
new_node.parent.children[key[0]] = new_node
373+
new_node.parent.children[self.get_child_key_fn(key)] = new_node
365374
return new_node
366375

367376
def _insert_helper(self, node: TreeNode, key: List, value):
368377
node.last_access_time = time.time()
369378
if len(key) == 0:
370379
return 0
371380

372-
if key[0] in node.children.keys():
373-
child = node.children[key[0]]
374-
prefix_len = _key_match(child.key, key)
381+
child_key = self.get_child_key_fn(key)
382+
total_prefix_length = 0
383+
384+
while len(key) > 0 and child_key in node.children.keys():
385+
node = node.children[child_key]
386+
node.last_access_time = time.time()
387+
prefix_len = self.key_match_fn(node.key, key)
375388

376-
if prefix_len == len(child.key):
377-
if child.evicted:
389+
if prefix_len == len(node.key):
390+
if node.evicted:
378391
# change the reference if the node is evicted
379392
# this often happens in the case of KV cache recomputation
380-
child.value = value[:prefix_len]
381-
self.token_to_kv_pool_host.update_synced(child.host_value)
382-
self.evictable_size_ += len(value[:prefix_len])
383-
return self._insert_helper(
384-
child, key[prefix_len:], value[prefix_len:]
385-
)
393+
node.value = value[:prefix_len]
394+
self.token_to_kv_pool_host.update_synced(node.host_value)
395+
self.evictable_size_ += len(node.value)
386396
else:
387-
self.inc_hit_count(child)
388-
return prefix_len + self._insert_helper(
389-
child, key[prefix_len:], value[prefix_len:]
390-
)
391-
392-
# partial match, split the node
393-
new_node = self._split_node(child.key, child, prefix_len)
394-
if new_node.evicted:
395-
new_node.value = value[:prefix_len]
396-
self.token_to_kv_pool_host.update_synced(new_node.host_value)
397-
self.evictable_size_ += len(new_node.value)
398-
return self._insert_helper(
399-
new_node, key[prefix_len:], value[prefix_len:]
400-
)
397+
self.inc_hit_count(node)
398+
total_prefix_length += prefix_len
401399
else:
402-
self.inc_hit_count(new_node)
403-
return prefix_len + self._insert_helper(
404-
new_node, key[prefix_len:], value[prefix_len:]
405-
)
400+
# partial match, split the node
401+
new_node = self._split_node(node.key, node, prefix_len)
402+
if new_node.evicted:
403+
new_node.value = value[:prefix_len]
404+
self.token_to_kv_pool_host.update_synced(new_node.host_value)
405+
self.evictable_size_ += len(new_node.value)
406+
else:
407+
self.inc_hit_count(new_node)
408+
total_prefix_length += prefix_len
409+
node = new_node
410+
411+
key = key[prefix_len:]
412+
value = value[prefix_len:]
413+
414+
if len(key):
415+
child_key = self.get_child_key_fn(key)
406416

407417
if len(key):
408418
new_node = TreeNode()
409419
new_node.parent = node
410420
new_node.key = key
411421
new_node.value = value
412-
node.children[key[0]] = new_node
422+
node.children[child_key] = new_node
413423
self.evictable_size_ += len(value)
414424

415425
if self.cache_controller.write_policy == "write_through":
416426
self.write_backup(new_node)
417-
return 0
427+
return total_prefix_length
418428

419429
def _collect_leaves_device(self):
420430
def is_leaf(node):

0 commit comments

Comments
 (0)