Skip to content

Fix: Potential SQLi in spend_management_endpoints.py #9878

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Apr 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 18 additions & 9 deletions litellm/proxy/spend_tracking/spend_management_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -1919,9 +1919,7 @@ async def view_spend_logs( # noqa: PLR0915
):
result: dict = {}
for record in response:
dt_object = datetime.strptime(
str(record["startTime"]), "%Y-%m-%dT%H:%M:%S.%fZ" # type: ignore
) # type: ignore
dt_object = datetime.strptime(str(record["startTime"]), "%Y-%m-%dT%H:%M:%S.%fZ") # type: ignore
date = dt_object.date()
if date not in result:
result[date] = {"users": {}, "models": {}}
Expand Down Expand Up @@ -2097,8 +2095,7 @@ async def is_materialized_global_spend_view() -> bool:
try:
resp = await prisma_client.db.query_raw(sql_query)

assert resp[0]["relkind"] == "m"
return True
return resp[0]["relkind"] == "m"
except Exception:
return False

Expand Down Expand Up @@ -2396,9 +2393,21 @@ async def global_spend_keys(
return response
if prisma_client is None:
raise HTTPException(status_code=500, detail={"error": "No db connected"})
sql_query = f"""SELECT * FROM "Last30dKeysBySpend" LIMIT {limit};"""
sql_query = """SELECT * FROM "Last30dKeysBySpend";"""

response = await prisma_client.db.query_raw(query=sql_query)
if limit is None:
response = await prisma_client.db.query_raw(sql_query)
return response
try:
limit = int(limit)
if limit < 1:
raise ValueError("Limit must be greater than 0")
sql_query = """SELECT * FROM "Last30dKeysBySpend" LIMIT $1 ;"""
response = await prisma_client.db.query_raw(sql_query, limit)
except ValueError as e:
raise HTTPException(
status_code=422, detail={"error": f"Invalid limit: {limit}, error: {e}"}
) from e

return response

Expand Down Expand Up @@ -2646,9 +2655,9 @@ async def global_spend_models(
if prisma_client is None:
raise HTTPException(status_code=500, detail={"error": "No db connected"})

sql_query = f"""SELECT * FROM "Last30dModelsBySpend" LIMIT {limit};"""
sql_query = """SELECT * FROM "Last30dModelsBySpend" LIMIT $1 ;"""

response = await prisma_client.db.query_raw(query=sql_query)
response = await prisma_client.db.query_raw(sql_query, int(limit))

return response

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -731,3 +731,66 @@ def _compare_nested_dicts(
f"Value mismatch at {current_path}:\n expected: {expected_str}\n got: {actual_str}"
)
return differences


@pytest.mark.asyncio
async def test_global_spend_keys_endpoint_limit_validation(client, monkeypatch):
"""
Test to ensure that the global_spend_keys endpoint is protected against SQL injection attacks.
Verifies that the limit parameter is properly parameterized and not directly interpolated.
"""
# Create a simple mock for prisma client with empty response
mock_prisma_client = MagicMock()
mock_db = MagicMock()
mock_query_raw = MagicMock()
mock_query_raw.return_value = asyncio.Future()
mock_query_raw.return_value.set_result([])
mock_db.query_raw = mock_query_raw
mock_prisma_client.db = mock_db
# Apply the mock to the prisma_client module
monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", mock_prisma_client)

# Call the endpoint without specifying a limit
no_limit_response = client.get("/global/spend/keys")
assert no_limit_response.status_code == 200
mock_query_raw.assert_called_once_with('SELECT * FROM "Last30dKeysBySpend";')
# Reset the mock for the next test
mock_query_raw.reset_mock()
# Test with valid input
normal_limit = "10"
good_input_response = client.get(f"/global/spend/keys?limit={normal_limit}")
assert good_input_response.status_code == 200
# Verify the mock was called with the correct parameters
mock_query_raw.assert_called_once_with(
'SELECT * FROM "Last30dKeysBySpend" LIMIT $1 ;', 10
)
# Reset the mock for the next test
mock_query_raw.reset_mock()
# Test with SQL injection payload
sql_injection_limit = "10; DROP TABLE spend_logs; --"
response = client.get(f"/global/spend/keys?limit={sql_injection_limit}")
# Verify the response is a validation error (422)
assert response.status_code == 422
# Verify the mock was not called with the SQL injection payload
# This confirms that the validation happens before the database query
mock_query_raw.assert_not_called()
# Reset the mock for the next test
mock_query_raw.reset_mock()
# Test with non-numeric input
non_numeric_limit = "abc"
response = client.get(f"/global/spend/keys?limit={non_numeric_limit}")
assert response.status_code == 422
mock_query_raw.assert_not_called()
mock_query_raw.reset_mock()
# Test with negative number
negative_limit = "-5"
response = client.get(f"/global/spend/keys?limit={negative_limit}")
assert response.status_code == 422
mock_query_raw.assert_not_called()
mock_query_raw.reset_mock()
# Test with zero
zero_limit = "0"
response = client.get(f"/global/spend/keys?limit={zero_limit}")
assert response.status_code == 422
mock_query_raw.assert_not_called()
mock_query_raw.reset_mock()
Loading