Skip to content

Commit b4247e7

Browse files
krrishdholakiastefan--
authored and committed
Rate Limiting: Check all slots on redis, Reduce number of cache writes (BerriAI#11299)
* fix(base_routing_strategy.py): compress increments to redis - reduces write ops * fix(base_routing_strategy.py): make get and reset in memory keys atomic * fix(base_routing_strategy.py): don't reset keys - causes discrepency on subsequent requests to instance * fix(parallel_request_limiter.py): retrieve values of previous slots from cache more accurate rate limiting with sliding window * fix: fix test * fix: fix linting error
1 parent 939ba2f commit b4247e7

File tree

4 files changed

+168
-70
lines changed

4 files changed

+168
-70
lines changed

litellm/proxy/_new_secret_config.yaml

Lines changed: 94 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,100 @@
11
model_list:
2-
- model_name: fake-openai-endpoint
2+
- model_name: "gemini-2.0-flash"
33
litellm_params:
4-
model: openai/fake
5-
api_key: fake-key
6-
api_base: https://exampleopenaiendpoint-production.up.railway.app/
7-
- model_name: "text-embedding-3-small"
4+
model: vertex_ai/gemini-2.0-flash
5+
vertex_project: my-project-id
6+
vertex_location: us-central1
7+
- model_name: "gpt-4o-mini-openai"
88
litellm_params:
9-
model: text-embedding-3-small
9+
model: gpt-4o-mini
1010
api_key: os.environ/OPENAI_API_KEY
11-
- model_name: papluca/xlm-roberta-base-language-detection
11+
- model_name: "bedrock-nova"
12+
litellm_params:
13+
model: us.amazon.nova-pro-v1:0
14+
- model_name: openrouter_model
15+
litellm_params:
16+
model: openrouter/openrouter_model
17+
api_key: os.environ/OPENROUTER_API_KEY
18+
api_base: http://0.0.0.0:8090
19+
- model_name: dall-e-3-azure
1220
litellm_params:
13-
model: openai/gpt-3.5-turbo
14-
api_base: https://api.openai.com
21+
model: azure/dall-e-3-test
22+
api_version: "2023-12-01-preview"
23+
api_base: os.environ/AZURE_SWEDEN_API_BASE
24+
api_key: os.environ/AZURE_SWEDEN_API_KEY
25+
model_info:
26+
input_cost_per_pixel: 10
27+
- model_name: "claude-3-7-sonnet"
28+
litellm_params:
29+
model: databricks/databricks-claude-3-7-sonnet
30+
api_key: os.environ/DATABRICKS_API_KEY
31+
api_base: os.environ/DATABRICKS_API_BASE
32+
- model_name: "gpt-4.1"
33+
litellm_params:
34+
model: azure/gpt-4.1
35+
api_key: os.environ/AZURE_API_KEY_REALTIME
36+
api_base: https://krris-m2f9a9i7-eastus2.openai.azure.com/
37+
- model_name: "xai/*"
38+
litellm_params:
39+
model: xai/*
40+
api_key: os.environ/XAI_API_KEY
41+
- model_name: "text-embedding-ada-002"
42+
litellm_params:
43+
model: text-embedding-ada-002
1544
api_key: os.environ/OPENAI_API_KEY
16-
45+
- model_name: gemini/gemini-2.0-flash
46+
litellm_params:
47+
model: gemini/gemini-2.0-flash
48+
- model_name: llama-qwen
49+
litellm_params:
50+
model: ollama/qwen2:0.5b
51+
model_info:
52+
input_cost_per_token: 0.75
53+
output_cost_per_token: 3
54+
- model_name: gpt-image-1
55+
litellm_params:
56+
model: gpt-image-1
57+
api_key: os.environ/OPENAI_API_KEY
58+
# drop_params: true
59+
- model_name: "gpt-4o-batch"
60+
litellm_params:
61+
model: azure/gpt-4o-mini
62+
api_base: os.environ/AZURE_API_BASE
63+
api_key: os.environ/AZURE_API_KEY
64+
model_info:
65+
id: my-general-azure-deployment
66+
mode: batch
67+
- model_name: "gpt-4o-batch"
68+
litellm_params:
69+
model: azure/gpt-4o-mini
70+
api_base: https://krris-m2f9a9i7-eastus2.openai.azure.com
71+
api_key: 04d22fb7e9ad4d9c8afe7c6abf97a6fc
72+
model_info:
73+
id: my-unique-azure-deployment
74+
mode: batch
75+
- model_name: fake-openai-endpoint
76+
litellm_params:
77+
model: openai/fake
78+
api_key: fake-key
79+
api_base: https://exampleopenaiendpoint-production.up.railway.app/
80+
81+
general_settings:
82+
store_model_in_db: true
83+
store_prompts_in_spend_logs: true
84+
disable_prisma_schema_update: true
85+
# master_key: os.environ/PROXY_MASTER_KEY
86+
87+
litellm_settings:
88+
cache: true
89+
cache_params:
90+
type: redis
91+
ttl: 600
92+
password: os.environ/REDIS_PASSWORD
93+
supported_call_types: ["acompletion", "completion"]
94+
95+
router_settings:
96+
redis_password: os.environ/REDIS_PASSWORD
97+
98+
99+
100+

litellm/proxy/hooks/parallel_request_limiter_v2.py

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def __init__(self, internal_usage_cache: InternalUsageCache):
6767
self,
6868
dual_cache=internal_usage_cache.dual_cache,
6969
should_batch_redis_writes=True,
70-
default_sync_interval=0.01,
70+
default_sync_interval=1,
7171
)
7272

7373
def print_verbose(self, print_statement):
@@ -120,6 +120,27 @@ def _get_current_usage_key(
120120
def get_key_pattern_to_sync(self) -> Optional[str]:
121121
return self.prefix + "::"
122122

123+
def _get_slots_to_check(self, current_slot: int) -> List[str]:
124+
slots_to_check = []
125+
current_time = datetime.now()
126+
for i in range(4):
127+
slot_number = (current_slot - i) % 4 # This ensures we wrap around properly
128+
minute = current_time.minute
129+
hour = current_time.hour
130+
131+
# If we need to look at previous minute
132+
if current_slot - i < 0:
133+
if minute == 0:
134+
# If we're at minute 0, go to previous hour
135+
hour = (current_time.hour - 1) % 24
136+
minute = 59
137+
else:
138+
minute = current_time.minute - 1
139+
140+
slot_key = f"{current_time.strftime('%Y-%m-%d')}-{hour:02d}-{minute:02d}-{slot_number}"
141+
slots_to_check.append(slot_key)
142+
return slots_to_check
143+
123144
async def check_key_in_limits_v2(
124145
self,
125146
user_api_key_dict: UserAPIKeyAuth,
@@ -145,25 +166,9 @@ async def check_key_in_limits_v2(
145166
current_slot = (
146167
current_time.second // 15
147168
) # This gives us 0-3 for the current 15s slot
148-
slots_to_check = []
169+
slots_to_check = self._get_slots_to_check(current_slot)
149170
slot_cache_keys = []
150171
# Calculate the last 4 slots, handling minute boundaries
151-
for i in range(4):
152-
slot_number = (current_slot - i) % 4 # This ensures we wrap around properly
153-
minute = current_time.minute
154-
hour = current_time.hour
155-
156-
# If we need to look at previous minute
157-
if current_slot - i < 0:
158-
if minute == 0:
159-
# If we're at minute 0, go to previous hour
160-
hour = (current_time.hour - 1) % 24
161-
minute = 59
162-
else:
163-
minute = current_time.minute - 1
164-
165-
slot_key = f"{current_time.strftime('%Y-%m-%d')}-{hour:02d}-{minute:02d}-{slot_number}"
166-
slots_to_check.append(slot_key)
167172

168173
# For each slot, create keys for all rate limit groups
169174
for slot_key in slots_to_check:
@@ -183,6 +188,8 @@ async def check_key_in_limits_v2(
183188
decrement_list.append(
184189
(key, -1 if increment_value_by_group[group] == 1 else 0)
185190
)
191+
else:
192+
self.add_to_in_memory_keys_to_update(key=key)
186193
slot_cache_keys.append(key)
187194

188195
if (

litellm/router_strategy/base_routing_strategy.py

Lines changed: 39 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import asyncio
66
from abc import ABC
7-
from typing import List, Optional, Set, Tuple, Union
7+
from typing import Dict, List, Optional, Set, Tuple, Union
88

99
from litellm._logging import verbose_router_logger
1010
from litellm.caching.caching import DualCache
@@ -87,6 +87,7 @@ async def _increment_value_in_current_window(
8787
increment_value=value,
8888
ttl=ttl,
8989
)
90+
9091
self.redis_increment_operation_queue.append(increment_op)
9192
self.add_to_in_memory_keys_to_update(key=key)
9293
return result
@@ -116,24 +117,40 @@ async def _push_in_memory_increments_to_redis(self):
116117
"""
117118
How this works:
118119
- async_log_success_event collects all provider spend increments in `redis_increment_operation_queue`
119-
- This function pushes all increments to Redis in a batched pipeline to optimize performance
120+
- This function compresses multiple increments for the same key into a single operation
121+
- Then pushes all increments to Redis in a batched pipeline to optimize performance
120122
121123
Only runs if Redis is initialized
122124
"""
123125
try:
124126
if not self.dual_cache.redis_cache:
125127
return # Redis is not initialized
126128

127-
# verbose_router_logger.debug(
128-
# "Pushing Redis Increment Pipeline for queue: %s",
129-
# self.redis_increment_operation_queue,
130-
# )
131129
if len(self.redis_increment_operation_queue) > 0:
130+
# Compress operations for the same key
131+
compressed_ops: Dict[str, RedisPipelineIncrementOperation] = {}
132+
ops_to_remove = []
133+
for idx, op in enumerate(self.redis_increment_operation_queue):
134+
if op["key"] in compressed_ops:
135+
# Add to existing increment
136+
compressed_ops[op["key"]]["increment_value"] += op[
137+
"increment_value"
138+
]
139+
else:
140+
compressed_ops[op["key"]] = op
141+
142+
ops_to_remove.append(idx)
143+
# Convert back to list
144+
compressed_queue = list(compressed_ops.values())
145+
132146
await self.dual_cache.redis_cache.async_increment_pipeline(
133-
increment_list=self.redis_increment_operation_queue,
147+
increment_list=compressed_queue,
134148
)
135-
136-
self.redis_increment_operation_queue = []
149+
self.redis_increment_operation_queue = [
150+
op
151+
for idx, op in enumerate(self.redis_increment_operation_queue)
152+
if idx not in ops_to_remove
153+
]
137154

138155
except Exception as e:
139156
verbose_router_logger.error(
@@ -153,6 +170,12 @@ def get_key_pattern_to_sync(self) -> Optional[str]:
153170
def get_in_memory_keys_to_update(self) -> Set[str]:
154171
return self.in_memory_keys_to_update
155172

173+
def get_and_reset_in_memory_keys_to_update(self) -> Set[str]:
174+
"""Atomic get and reset in-memory keys to update"""
175+
keys = self.in_memory_keys_to_update
176+
self.in_memory_keys_to_update = set()
177+
return keys
178+
156179
def reset_in_memory_keys_to_update(self):
157180
self.in_memory_keys_to_update = set()
158181

@@ -174,9 +197,6 @@ async def _sync_in_memory_spend_with_redis(self):
174197
if self.dual_cache.redis_cache is None:
175198
return
176199

177-
# 1. Push all provider spend increments to Redis
178-
await self._push_in_memory_increments_to_redis()
179-
180200
# 2. Fetch all current provider spend from Redis to update in-memory cache
181201
cache_keys = (
182202
self.get_in_memory_keys_to_update()
@@ -195,39 +215,29 @@ async def _sync_in_memory_spend_with_redis(self):
195215
)
196216
)
197217
for k, v in zip(cache_keys_list, in_memory_before):
198-
in_memory_before_dict[k] = v
218+
in_memory_before_dict[k] = float(v or 0)
219+
220+
# 1. Push all provider spend increments to Redis
221+
await self._push_in_memory_increments_to_redis()
199222

200223
# 2. Fetch from Redis
201224
redis_values = await self.dual_cache.redis_cache.async_batch_get_cache(
202225
key_list=cache_keys_list
203226
)
204227

205-
# 3. Snapshot in-memory after
206-
in_memory_after = (
207-
await self.dual_cache.in_memory_cache.async_batch_get_cache(
208-
keys=cache_keys_list
209-
)
210-
)
211-
in_memory_after_dict = {}
212-
for k, v in zip(cache_keys_list, in_memory_after):
213-
in_memory_after_dict[k] = v
214-
215228
# 4. Merge
216229
for key in cache_keys_list:
217230
redis_val = float(redis_values.get(key, 0) or 0)
218231
before = float(in_memory_before_dict.get(key, 0) or 0)
219-
after = float(in_memory_after_dict.get(key, 0) or 0)
232+
after = float(
233+
await self.dual_cache.in_memory_cache.async_get_cache(key=key) or 0
234+
)
220235
delta = after - before
221-
if delta > 0:
222-
await self._increment_value_in_current_window(
223-
key=key, value=delta, ttl=60
224-
)
225236
merged = redis_val + delta
226237
await self.dual_cache.in_memory_cache.async_set_cache(
227238
key=key, value=merged
228239
)
229240

230-
self.reset_in_memory_keys_to_update()
231241
except Exception as e:
232242
verbose_router_logger.exception(
233243
f"Error syncing in-memory cache with Redis: {str(e)}"

tests/test_litellm/router_strategy/test_base_routing_strategy.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ async def test_sync_in_memory_spend_with_redis(base_strategy, mock_dual_cache):
9999
# Setup test data
100100
base_strategy.in_memory_keys_to_update = {"key1"}
101101

102-
# Mock the in-memory cache batch get responses
102+
# Mock the in-memory cache batch get responses for before snapshot
103103
in_memory_before_future: asyncio.Future[List[str]] = asyncio.Future()
104104
in_memory_before_future.set_result(["5.0"]) # Initial values
105105
mock_dual_cache.in_memory_cache.async_batch_get_cache.return_value = (
@@ -111,13 +111,12 @@ async def test_sync_in_memory_spend_with_redis(base_strategy, mock_dual_cache):
111111
redis_future.set_result({"key1": "15.0"}) # Redis values
112112
mock_dual_cache.redis_cache.async_batch_get_cache.return_value = redis_future
113113

114-
# Mock in-memory after snapshot
115-
in_memory_after_future: asyncio.Future[List[str]] = asyncio.Future()
116-
in_memory_after_future.set_result(["8.0"]) # Values after potential updates
117-
mock_dual_cache.in_memory_cache.async_batch_get_cache.side_effect = [
118-
in_memory_before_future, # First call for before snapshot
119-
in_memory_after_future, # Second call for after snapshot
120-
]
114+
# Mock in-memory get for after snapshot
115+
in_memory_after_future: asyncio.Future[Optional[str]] = asyncio.Future()
116+
in_memory_after_future.set_result("8.0") # Value after potential updates
117+
mock_dual_cache.in_memory_cache.async_get_cache.return_value = (
118+
in_memory_after_future
119+
)
121120

122121
await base_strategy._sync_in_memory_spend_with_redis()
123122

@@ -129,19 +128,17 @@ async def test_sync_in_memory_spend_with_redis(base_strategy, mock_dual_cache):
129128

130129
# Verify in-memory cache was updated with merged values
131130
# For key1: redis_val(15.0) + delta(8.0 - 5.0) = 18.0
132-
# For key2: redis_val(20.0) + delta(12.0 - 10.0) = 22.0
133131
assert mock_dual_cache.in_memory_cache.async_set_cache.call_count == 1
134132

135133
# Verify the final merged values
136134
set_cache_calls = mock_dual_cache.in_memory_cache.async_set_cache.call_args_list
137-
print(f"set_cache_calls: {set_cache_calls}")
138135
assert any(
139-
call.kwargs["key"] == "key1" and call.kwargs["value"] == 18.0
136+
call.kwargs["key"] == "key1" and float(call.kwargs["value"]) == 18.0
140137
for call in set_cache_calls
141138
)
142139

143-
# Verify cache keys were reset
144-
assert len(base_strategy.in_memory_keys_to_update) == 0
140+
# Verify cache keys still exist
141+
assert len(base_strategy.in_memory_keys_to_update) == 1
145142

146143

147144
def test_cache_keys_management(base_strategy):

0 commit comments

Comments
 (0)