Skip to content

Commit 44adc90

Browse files
xiezhq-hermanntarinkk
authored andcommitted
upstream hicache fixes (sgl-project#5570)
1 parent 6c45067 commit 44adc90

File tree

8 files changed

+89
-46
lines changed

8 files changed

+89
-46
lines changed

python/sglang/srt/managers/schedule_batch.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -571,6 +571,14 @@ def init_next_round_input(
571571
self.prefix_indices, self.last_node = tree_cache.match_prefix(
572572
rid=self.rid, key=self.adjust_max_prefix_ids()
573573
)
574+
elif enable_hierarchical_cache:
575+
# in case last_node is evicted during scheduling, we need to update the prefix_indices
576+
while self.last_node.evicted:
577+
self.prefix_indices = self.prefix_indices[
578+
: -len(self.last_node.host_value)
579+
]
580+
self.last_node = self.last_node.parent
581+
574582
self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
575583

576584
def adjust_max_prefix_ids(self):

python/sglang/srt/managers/scheduler.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -489,6 +489,8 @@ def init_memory_pool_and_cache(self):
489489
tp_cache_group=self.tp_cpu_group,
490490
page_size=self.page_size,
491491
hicache_ratio=server_args.hicache_ratio,
492+
hicache_size=server_args.hicache_size,
493+
hicache_write_policy=server_args.hicache_write_policy,
492494
)
493495
else:
494496
self.tree_cache = RadixCache(

python/sglang/srt/mem_cache/hiradix_cache.py

Lines changed: 40 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,17 @@ def __init__(
2929
tp_cache_group: torch.distributed.ProcessGroup,
3030
page_size: int,
3131
hicache_ratio: float,
32+
hicache_size: int,
33+
hicache_write_policy: str,
3234
):
3335
self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
3436
if isinstance(self.kv_cache, MHATokenToKVPool):
3537
self.token_to_kv_pool_host = MHATokenToKVPoolHost(
36-
self.kv_cache, hicache_ratio, page_size
38+
self.kv_cache, hicache_ratio, hicache_size, page_size
3739
)
3840
elif isinstance(self.kv_cache, MLATokenToKVPool):
3941
self.token_to_kv_pool_host = MLATokenToKVPoolHost(
40-
self.kv_cache, hicache_ratio, page_size
42+
self.kv_cache, hicache_ratio, hicache_size, page_size
4143
)
4244
else:
4345
raise ValueError(f"HiRadixCache only supports MHA and MLA yet")
@@ -50,14 +52,17 @@ def __init__(
5052
self.token_to_kv_pool_host,
5153
page_size,
5254
load_cache_event=self.load_cache_event,
55+
write_policy=hicache_write_policy,
5356
)
5457

5558
# record the nodes with ongoing write through
5659
self.ongoing_write_through = {}
5760
# record the node segments with ongoing load back
5861
self.ongoing_load_back = {}
5962
# todo: dynamically adjust the threshold
60-
self.write_through_threshold = 1
63+
self.write_through_threshold = (
64+
1 if hicache_write_policy == "write_through" else 3
65+
)
6166
self.load_back_threshold = 10
6267
super().__init__(
6368
req_to_token_pool, token_to_kv_pool_allocator, page_size, disable=False
@@ -76,7 +81,7 @@ def get_height(self, node: TreeNode):
7681
height += 1
7782
return height
7883

79-
def write_backup(self, node: TreeNode):
84+
def write_backup(self, node: TreeNode, write_back=False):
8085
host_indices = self.cache_controller.write(
8186
device_indices=node.value,
8287
node_id=node.id,
@@ -90,21 +95,29 @@ def write_backup(self, node: TreeNode):
9095
if host_indices is not None:
9196
node.host_value = host_indices
9297
self.ongoing_write_through[node.id] = node
93-
self.inc_lock_ref(node)
98+
if not write_back:
99+
# no need to lock nodes if write back
100+
self.inc_lock_ref(node)
94101
else:
95102
return 0
96103

97104
return len(host_indices)
98105

99106
def inc_hit_count(self, node: TreeNode):
100-
if self.cache_controller.write_policy != "write_through_selective":
107+
if node.backuped or self.cache_controller.write_policy == "write_back":
101108
return
102109
node.hit_count += 1
103-
if node.host_value is None and node.hit_count > self.write_through_threshold:
110+
if node.hit_count >= self.write_through_threshold:
104111
self.write_backup(node)
105112
node.hit_count = 0
106113

107-
def writing_check(self):
114+
def writing_check(self, write_back=False):
115+
if write_back:
116+
# blocking till all write back complete
117+
while len(self.ongoing_write_through) > 0:
118+
ack_id = self.cache_controller.ack_write_queue.get()
119+
del self.ongoing_write_through[ack_id]
120+
return
108121
queue_size = torch.tensor(
109122
self.cache_controller.ack_write_queue.qsize(), dtype=torch.int
110123
)
@@ -143,29 +156,25 @@ def evict(self, num_tokens: int):
143156
heapq.heapify(leaves)
144157

145158
num_evicted = 0
146-
pending_nodes = []
159+
write_back_nodes = []
147160
while num_evicted < num_tokens and len(leaves):
148161
x = heapq.heappop(leaves)
149162

150163
if x.lock_ref > 0:
151164
continue
152165

153-
if x.host_value is None:
166+
if not x.backuped:
154167
if self.cache_controller.write_policy == "write_back":
155-
num_evicted += self.write_backup(x)
156-
pending_nodes.append(x)
157-
elif self.cache_controller.write_policy == "write_through_selective":
158-
num_evicted += self._evict_write_through_selective(x)
168+
# write to host if the node is not backuped
169+
num_evicted += self.write_backup(x, write_back=True)
170+
write_back_nodes.append(x)
159171
else:
160-
assert (
161-
self.cache_controller.write_policy != "write_through"
162-
), "write_through should be inclusive"
163-
raise NotImplementedError
172+
num_evicted += self._evict_regular(x)
164173
else:
165-
num_evicted += self._evict_write_through(x)
174+
num_evicted += self._evict_backuped(x)
166175

167176
for child in x.parent.children.values():
168-
if child in pending_nodes:
177+
if child in write_back_nodes:
169178
continue
170179
if not child.evicted:
171180
break
@@ -174,23 +183,20 @@ def evict(self, num_tokens: int):
174183
heapq.heappush(leaves, x.parent)
175184

176185
if self.cache_controller.write_policy == "write_back":
177-
# blocking till all write back complete
178-
while len(self.ongoing_write_through) > 0:
179-
self.writing_check()
180-
time.sleep(0.1)
181-
for node in pending_nodes:
182-
assert node.host_value is not None
183-
self._evict_write_through(node)
186+
self.writing_check(write_back=True)
187+
for node in write_back_nodes:
188+
assert node.backuped
189+
self._evict_backuped(node)
184190

185-
def _evict_write_through(self, node: TreeNode):
191+
def _evict_backuped(self, node: TreeNode):
186192
# evict a node already written to host
187193
num_evicted = self.cache_controller.evict_device(node.value, node.host_value)
188194
assert num_evicted > 0
189195
self.evictable_size_ -= num_evicted
190196
node.value = None
191197
return num_evicted
192198

193-
def _evict_write_through_selective(self, node: TreeNode):
199+
def _evict_regular(self, node: TreeNode):
194200
# evict a node not initiated write to host
195201
self.cache_controller.mem_pool_device_allocator.free(node.value)
196202
num_evicted = len(node.value)
@@ -339,11 +345,13 @@ def _match_prefix_helper(self, node: TreeNode, key: List):
339345
prefix_len = self.key_match_fn(child.key, key)
340346
if prefix_len < len(child.key):
341347
new_node = self._split_node(child.key, child, prefix_len)
348+
self.inc_hit_count(new_node)
342349
if not new_node.evicted:
343350
value.append(new_node.value)
344351
node = new_node
345352
break
346353
else:
354+
self.inc_hit_count(child)
347355
if not child.evicted:
348356
value.append(child.value)
349357
node = child
@@ -369,7 +377,7 @@ def _split_node(self, key, child: TreeNode, split_len: int):
369377
else:
370378
new_node.value = child.value[:split_len]
371379
child.value = child.value[split_len:]
372-
if child.host_value is not None:
380+
if child.backuped:
373381
new_node.host_value = child.host_value[:split_len]
374382
child.host_value = child.host_value[split_len:]
375383
child.parent = new_node
@@ -426,8 +434,8 @@ def _insert_helper(self, node: TreeNode, key: List, value):
426434
node.children[child_key] = new_node
427435
self.evictable_size_ += len(value)
428436

429-
if self.cache_controller.write_policy == "write_through":
430-
self.write_backup(new_node)
437+
if self.cache_controller.write_policy != "write_back":
438+
self.inc_hit_count(new_node)
431439
return total_prefix_length
432440

433441
def _collect_leaves_device(self):

python/sglang/srt/mem_cache/memory_pool.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -624,26 +624,27 @@ def __init__(
624624
self,
625625
device_pool: MHATokenToKVPool,
626626
host_to_device_ratio: float,
627+
host_size: int,
627628
pin_memory: bool,
628629
device: str,
629630
page_size: int,
630631
):
631-
assert (
632-
host_to_device_ratio >= 1
633-
), "The host memory should be larger than the device memory with the current protocol"
634-
# todo, other ways of configuring the size
635-
636632
self.device_pool = device_pool
637-
self.host_to_device_ratio = host_to_device_ratio
633+
self.dtype = device_pool.store_dtype
638634
self.pin_memory = pin_memory
639635
self.device = device
640636
self.page_size = page_size
641-
642-
self.size = int(device_pool.size * host_to_device_ratio)
637+
self.size_per_token = self.get_size_per_token()
638+
if host_size > 0:
639+
self.size = int(host_size * 1e9 // self.size_per_token)
640+
else:
641+
self.size = int(device_pool.size * host_to_device_ratio)
643642
# Align the host memory pool size to the page size
644643
self.size = self.size - (self.size % self.page_size)
645-
self.dtype = device_pool.store_dtype
646-
self.size_per_token = self.get_size_per_token()
644+
645+
assert (
646+
self.size > device_pool.size
647+
), "The host memory should be larger than the device memory with the current protocol"
647648

648649
# Verify there is enough available host memory.
649650
host_mem = psutil.virtual_memory()
@@ -795,12 +796,13 @@ def __init__(
795796
self,
796797
device_pool: MHATokenToKVPool,
797798
host_to_device_ratio: float,
799+
host_size: int,
798800
page_size: int,
799801
pin_memory: bool = True,
800802
device: str = "cpu",
801803
):
802804
super().__init__(
803-
device_pool, host_to_device_ratio, pin_memory, device, page_size
805+
device_pool, host_to_device_ratio, host_size, pin_memory, device, page_size
804806
)
805807

806808
def get_size_per_token(self):
@@ -869,12 +871,13 @@ def __init__(
869871
self,
870872
device_pool: MLATokenToKVPool,
871873
host_to_device_ratio: float,
874+
host_size: int,
872875
page_size: int,
873876
pin_memory: bool = True,
874877
device: str = "cpu",
875878
):
876879
super().__init__(
877-
device_pool, host_to_device_ratio, pin_memory, device, page_size
880+
device_pool, host_to_device_ratio, host_size, pin_memory, device, page_size
878881
)
879882

880883
def get_size_per_token(self):

python/sglang/srt/server_args.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,8 @@ class ServerArgs:
180180
tool_call_parser: Optional[str] = None
181181
enable_hierarchical_cache: bool = False
182182
hicache_ratio: float = 2.0
183+
hicache_size: int = 0
184+
hicache_write_policy: str = "write_through_selective"
183185
flashinfer_mla_disable_ragged: bool = False
184186
warmups: Optional[str] = None
185187
moe_dense_tp_size: Optional[int] = None
@@ -1116,10 +1118,22 @@ def add_cli_args(parser: argparse.ArgumentParser):
11161118
parser.add_argument(
11171119
"--hicache-ratio",
11181120
type=float,
1119-
required=False,
11201121
default=ServerArgs.hicache_ratio,
11211122
help="The ratio of the size of host KV cache memory pool to the size of device pool.",
11221123
)
1124+
parser.add_argument(
1125+
"--hicache-size",
1126+
type=int,
1127+
default=ServerArgs.hicache_size,
1128+
help="The size of host KV cache memory pool in gigabytes, which will override the hicache_ratio if set.",
1129+
)
1130+
parser.add_argument(
1131+
"--hicache-write-policy",
1132+
type=str,
1133+
choices=["write_back", "write_through", "write_through_selective"],
1134+
default=ServerArgs.hicache_write_policy,
1135+
help="The write policy of hierarchical cache.",
1136+
)
11231137
parser.add_argument(
11241138
"--enable-deepep-moe",
11251139
action="store_true",

test/srt/test_hicache.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ def setUpClass(cls):
2323
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
2424
other_args=[
2525
"--enable-hierarchical-cache",
26+
"--mem-fraction-static",
27+
0.7,
28+
"--hicache-size",
29+
100,
2630
],
2731
)
2832

test/srt/test_hicache_mla.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ def setUpClass(cls):
2424
other_args=[
2525
"--trust-remote-code",
2626
"--enable-hierarchical-cache",
27+
"--hicache-ratio",
28+
2,
2729
],
2830
)
2931

test/srt/test_hicache_page.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@ def setUpClass(cls):
2424
other_args=[
2525
"--enable-hierarchical-cache",
2626
"--page-size",
27-
"32",
27+
32,
28+
"--hicache-write-policy",
29+
"write-back",
2830
],
2931
)
3032

0 commit comments

Comments
 (0)