diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py index 9215060dfed2..aed0c70695ec 100644 --- a/litellm/proxy/management_endpoints/key_management_endpoints.py +++ b/litellm/proxy/management_endpoints/key_management_endpoints.py @@ -603,9 +603,9 @@ async def generate_key_fn( # noqa: PLR0915 request_type="key", **data_json, table_name="key" ) - response["soft_budget"] = ( - data.soft_budget - ) # include the user-input soft budget in the response + response[ + "soft_budget" + ] = data.soft_budget # include the user-input soft budget in the response response = GenerateKeyResponse(**response) @@ -681,9 +681,9 @@ async def _set_object_permission( data=data_json["object_permission"], ) ) - data_json["object_permission_id"] = ( - created_object_permission.object_permission_id - ) + data_json[ + "object_permission_id" + ] = created_object_permission.object_permission_id # delete the object_permission from the data_json data_json.pop("object_permission") @@ -1667,10 +1667,10 @@ async def delete_verification_tokens( try: if prisma_client: tokens = [_hash_token_if_needed(token=key) for key in tokens] - _keys_being_deleted: List[LiteLLM_VerificationToken] = ( - await prisma_client.db.litellm_verificationtoken.find_many( - where={"token": {"in": tokens}} - ) + _keys_being_deleted: List[ + LiteLLM_VerificationToken + ] = await prisma_client.db.litellm_verificationtoken.find_many( + where={"token": {"in": tokens}} ) if len(_keys_being_deleted) == 0: @@ -1778,9 +1778,9 @@ async def _rotate_master_key( from litellm.proxy.proxy_server import proxy_config try: - models: Optional[List] = ( - await prisma_client.db.litellm_proxymodeltable.find_many() - ) + models: Optional[ + List + ] = await prisma_client.db.litellm_proxymodeltable.find_many() except Exception: models = None # 2. process model table @@ -2091,11 +2091,11 @@ async def validate_key_list_check( param="user_id", code=status.HTTP_403_FORBIDDEN, ) - complete_user_info_db_obj: Optional[BaseModel] = ( - await prisma_client.db.litellm_usertable.find_unique( - where={"user_id": user_api_key_dict.user_id}, - include={"organization_memberships": True}, - ) + complete_user_info_db_obj: Optional[ + BaseModel + ] = await prisma_client.db.litellm_usertable.find_unique( + where={"user_id": user_api_key_dict.user_id}, + include={"organization_memberships": True}, ) if complete_user_info_db_obj is None: @@ -2181,10 +2181,10 @@ async def get_admin_team_ids( if complete_user_info is None: return [] # Get all teams that user is an admin of - teams: Optional[List[BaseModel]] = ( - await prisma_client.db.litellm_teamtable.find_many( - where={"team_id": {"in": complete_user_info.teams}} - ) + teams: Optional[ + List[BaseModel] + ] = await prisma_client.db.litellm_teamtable.find_many( + where={"team_id": {"in": complete_user_info.teams}} ) if teams is None: return [] @@ -2394,7 +2394,10 @@ async def _list_key_helper( # Base conditions for user's own keys user_condition: Dict[str, Any] = {} if user_id and isinstance(user_id, str): - user_condition["user_id"] = user_id + user_condition["user_id"] = { + "contains": user_id, + "mode": "insensitive", # Case-insensitive search + } if team_id and isinstance(team_id, str): user_condition["team_id"] = team_id if key_alias and isinstance(key_alias, str): diff --git a/tests/proxy_admin_ui_tests/test_key_management.py b/tests/proxy_admin_ui_tests/test_key_management.py index b3522c881644..092c5acaed03 100644 --- a/tests/proxy_admin_ui_tests/test_key_management.py +++ b/tests/proxy_admin_ui_tests/test_key_management.py @@ -91,7 +91,9 @@ UpdateUserRequest, UserAPIKeyAuth, ) -from litellm.types.proxy.management_endpoints.ui_sso import LiteLLM_UpperboundKeyGenerateParams +from litellm.types.proxy.management_endpoints.ui_sso import ( + LiteLLM_UpperboundKeyGenerateParams, +) proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache()) @@ -1184,6 +1186,137 @@ async def test_list_key_helper_team_filtering(prisma_client): ) +@pytest.mark.asyncio +async def test_list_key_helper_user_id_partial_matching(prisma_client): + """ + Test _list_key_helper function with partial user_id matching: + 1. Exact match still works + 2. Partial match (substring) works + 3. Case-insensitive matching works + """ + from litellm.proxy.management_endpoints.key_management_endpoints import ( + _list_key_helper, + ) + + # Setup - create multiple test keys + setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) + setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") + await litellm.proxy.proxy_server.prisma_client.connect() + + # Create test data with specific user IDs for partial matching + test_id = str(uuid.uuid4())[:8] # Short unique ID for this test run + test_users = [ + f"user_john_smith_123_{test_id}", + f"user_jane_doe_456_{test_id}", + f"admin_john_admin_{test_id}", + f"test_user_789_{test_id}", + ] + + # Create keys for each test user + for user_id in test_users: + await generate_key_fn( + data=GenerateKeyRequest( + user_id=user_id, + key_alias=f"key_for_{user_id}_{uuid.uuid4()}", # Make alias unique + ), + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-1234", + user_id="admin", + ), + ) + + # Test 1: Partial match - search for substring that exists in 2 user IDs + result = await _list_key_helper( + prisma_client=prisma_client, + page=1, + size=100, # Increase size to get all results + user_id="john_", # This substring exists in "user_john_smith" and "admin_john_admin" + team_id=None, + key_alias=None, + key_hash=None, + organization_id=None, + return_full_object=True, # Return full objects so we can filter by user_id + ) + # Filter results to only count those from this test run + test_keys = [k for k in result["keys"] if test_id in k.user_id] + assert ( + len(test_keys) == 2 + ), f"Should return 2 keys containing 'john_' from test_id {test_id}" + + # Test 2: Case-insensitive match - search for "JOHN_" should also return 2 keys + result = await _list_key_helper( + prisma_client=prisma_client, + page=1, + size=100, # Increase size to get all results + user_id="JOHN_", # Search for JOHN_ in uppercase + team_id=None, + key_alias=None, + key_hash=None, + organization_id=None, + return_full_object=True, # Return full objects so we can filter by user_id + ) + # Filter results to only count those from this test run + test_keys = [k for k in result["keys"] if test_id in k.user_id] + assert ( + len(test_keys) == 2 + ), "Should return 2 keys containing 'john_' (case-insensitive)" + + # Test 3: Partial match with numbers - search for test_id should return all 4 keys + result = await _list_key_helper( + prisma_client=prisma_client, + page=1, + size=10, + user_id=test_id, # Search for the test_id which is in all user IDs + team_id=None, + key_alias=None, + key_hash=None, + organization_id=None, + ) + assert len(result["keys"]) == 4, f"Should return 4 keys containing '{test_id}'" + + # Test 4: Partial match prefix - search for "user_j" should return 2 keys + result = await _list_key_helper( + prisma_client=prisma_client, + page=1, + size=10, + user_id=f"user_j", # Generic prefix search + team_id=None, + key_alias=None, + key_hash=None, + organization_id=None, + ) + assert ( + len(result["keys"]) >= 2 + ), "Should return at least 2 keys starting with 'user_j'" + + # Test 5: Exact match still works - full user_id should return exactly 1 key + result = await _list_key_helper( + prisma_client=prisma_client, + page=1, + size=10, + user_id=test_users[0], # Use the exact first test user ID + team_id=None, + key_alias=None, + key_hash=None, + organization_id=None, + ) + assert len(result["keys"]) == 1, "Exact match should still work and return 1 key" + + # Test 6: No match - search for non-existent pattern + result = await _list_key_helper( + prisma_client=prisma_client, + page=1, + size=10, + user_id="nonexistent", + team_id=None, + key_alias=None, + key_hash=None, + organization_id=None, + ) + assert len(result["keys"]) == 0, "Should return 0 keys for non-existent pattern" + + @pytest.mark.asyncio @patch("litellm.proxy.management_endpoints.key_management_endpoints.get_team_object") async def test_key_generate_always_db_team(mock_get_team_object):