Skip to content

fix(proxy): Enable partial matching for User ID filter in virtual keys page #12205

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 26 additions & 23 deletions litellm/proxy/management_endpoints/key_management_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,9 +603,9 @@ async def generate_key_fn( # noqa: PLR0915
request_type="key", **data_json, table_name="key"
)

response["soft_budget"] = (
data.soft_budget
) # include the user-input soft budget in the response
response[
"soft_budget"
] = data.soft_budget # include the user-input soft budget in the response

response = GenerateKeyResponse(**response)

Expand Down Expand Up @@ -681,9 +681,9 @@ async def _set_object_permission(
data=data_json["object_permission"],
)
)
data_json["object_permission_id"] = (
created_object_permission.object_permission_id
)
data_json[
"object_permission_id"
] = created_object_permission.object_permission_id

# delete the object_permission from the data_json
data_json.pop("object_permission")
Expand Down Expand Up @@ -1667,10 +1667,10 @@ async def delete_verification_tokens(
try:
if prisma_client:
tokens = [_hash_token_if_needed(token=key) for key in tokens]
_keys_being_deleted: List[LiteLLM_VerificationToken] = (
await prisma_client.db.litellm_verificationtoken.find_many(
where={"token": {"in": tokens}}
)
_keys_being_deleted: List[
LiteLLM_VerificationToken
] = await prisma_client.db.litellm_verificationtoken.find_many(
where={"token": {"in": tokens}}
)

if len(_keys_being_deleted) == 0:
Expand Down Expand Up @@ -1778,9 +1778,9 @@ async def _rotate_master_key(
from litellm.proxy.proxy_server import proxy_config

try:
models: Optional[List] = (
await prisma_client.db.litellm_proxymodeltable.find_many()
)
models: Optional[
List
] = await prisma_client.db.litellm_proxymodeltable.find_many()
except Exception:
models = None
# 2. process model table
Expand Down Expand Up @@ -2091,11 +2091,11 @@ async def validate_key_list_check(
param="user_id",
code=status.HTTP_403_FORBIDDEN,
)
complete_user_info_db_obj: Optional[BaseModel] = (
await prisma_client.db.litellm_usertable.find_unique(
where={"user_id": user_api_key_dict.user_id},
include={"organization_memberships": True},
)
complete_user_info_db_obj: Optional[
BaseModel
] = await prisma_client.db.litellm_usertable.find_unique(
where={"user_id": user_api_key_dict.user_id},
include={"organization_memberships": True},
)

if complete_user_info_db_obj is None:
Expand Down Expand Up @@ -2181,10 +2181,10 @@ async def get_admin_team_ids(
if complete_user_info is None:
return []
# Get all teams that user is an admin of
teams: Optional[List[BaseModel]] = (
await prisma_client.db.litellm_teamtable.find_many(
where={"team_id": {"in": complete_user_info.teams}}
)
teams: Optional[
List[BaseModel]
] = await prisma_client.db.litellm_teamtable.find_many(
where={"team_id": {"in": complete_user_info.teams}}
)
if teams is None:
return []
Expand Down Expand Up @@ -2394,7 +2394,10 @@ async def _list_key_helper(
# Base conditions for user's own keys
user_condition: Dict[str, Any] = {}
if user_id and isinstance(user_id, str):
user_condition["user_id"] = user_id
user_condition["user_id"] = {
"contains": user_id,
"mode": "insensitive", # Case-insensitive search
}
if team_id and isinstance(team_id, str):
user_condition["team_id"] = team_id
if key_alias and isinstance(key_alias, str):
Expand Down
135 changes: 134 additions & 1 deletion tests/proxy_admin_ui_tests/test_key_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,9 @@
UpdateUserRequest,
UserAPIKeyAuth,
)
from litellm.types.proxy.management_endpoints.ui_sso import LiteLLM_UpperboundKeyGenerateParams
from litellm.types.proxy.management_endpoints.ui_sso import (
LiteLLM_UpperboundKeyGenerateParams,
)

proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache())

Expand Down Expand Up @@ -1184,6 +1186,137 @@ async def test_list_key_helper_team_filtering(prisma_client):
)


@pytest.mark.asyncio
async def test_list_key_helper_user_id_partial_matching(prisma_client):
"""
Test _list_key_helper function with partial user_id matching:
1. Exact match still works
2. Partial match (substring) works
3. Case-insensitive matching works
"""
from litellm.proxy.management_endpoints.key_management_endpoints import (
_list_key_helper,
)

# Setup - create multiple test keys
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
await litellm.proxy.proxy_server.prisma_client.connect()

# Create test data with specific user IDs for partial matching
test_id = str(uuid.uuid4())[:8] # Short unique ID for this test run
test_users = [
f"user_john_smith_123_{test_id}",
f"user_jane_doe_456_{test_id}",
f"admin_john_admin_{test_id}",
f"test_user_789_{test_id}",
]

# Create keys for each test user
for user_id in test_users:
await generate_key_fn(
data=GenerateKeyRequest(
user_id=user_id,
key_alias=f"key_for_{user_id}_{uuid.uuid4()}", # Make alias unique
),
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
api_key="sk-1234",
user_id="admin",
),
)

# Test 1: Partial match - search for substring that exists in 2 user IDs
result = await _list_key_helper(
prisma_client=prisma_client,
page=1,
size=100, # Increase size to get all results
user_id="john_", # This substring exists in "user_john_smith" and "admin_john_admin"
team_id=None,
key_alias=None,
key_hash=None,
organization_id=None,
return_full_object=True, # Return full objects so we can filter by user_id
)
# Filter results to only count those from this test run
test_keys = [k for k in result["keys"] if test_id in k.user_id]
assert (
len(test_keys) == 2
), f"Should return 2 keys containing 'john_' from test_id {test_id}"

# Test 2: Case-insensitive match - search for "JOHN_" should also return 2 keys
result = await _list_key_helper(
prisma_client=prisma_client,
page=1,
size=100, # Increase size to get all results
user_id="JOHN_", # Search for JOHN_ in uppercase
team_id=None,
key_alias=None,
key_hash=None,
organization_id=None,
return_full_object=True, # Return full objects so we can filter by user_id
)
# Filter results to only count those from this test run
test_keys = [k for k in result["keys"] if test_id in k.user_id]
assert (
len(test_keys) == 2
), "Should return 2 keys containing 'john_' (case-insensitive)"

# Test 3: Partial match with numbers - search for test_id should return all 4 keys
result = await _list_key_helper(
prisma_client=prisma_client,
page=1,
size=10,
user_id=test_id, # Search for the test_id which is in all user IDs
team_id=None,
key_alias=None,
key_hash=None,
organization_id=None,
)
assert len(result["keys"]) == 4, f"Should return 4 keys containing '{test_id}'"

# Test 4: Partial match prefix - search for "user_j" should return 2 keys
result = await _list_key_helper(
prisma_client=prisma_client,
page=1,
size=10,
user_id=f"user_j", # Generic prefix search
team_id=None,
key_alias=None,
key_hash=None,
organization_id=None,
)
assert (
len(result["keys"]) >= 2
), "Should return at least 2 keys starting with 'user_j'"

# Test 5: Exact match still works - full user_id should return exactly 1 key
result = await _list_key_helper(
prisma_client=prisma_client,
page=1,
size=10,
user_id=test_users[0], # Use the exact first test user ID
team_id=None,
key_alias=None,
key_hash=None,
organization_id=None,
)
assert len(result["keys"]) == 1, "Exact match should still work and return 1 key"

# Test 6: No match - search for non-existent pattern
result = await _list_key_helper(
prisma_client=prisma_client,
page=1,
size=10,
user_id="nonexistent",
team_id=None,
key_alias=None,
key_hash=None,
organization_id=None,
)
assert len(result["keys"]) == 0, "Should return 0 keys for non-existent pattern"


@pytest.mark.asyncio
@patch("litellm.proxy.management_endpoints.key_management_endpoints.get_team_object")
async def test_key_generate_always_db_team(mock_get_team_object):
Expand Down
Loading