Skip to content

[Feat SSO] Add LiteLLM SCIM Integration for Team and User management #10072

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 33 commits into from
Apr 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
121fb32
fix NewUser response type
ishaan-jaff Apr 16, 2025
a14d0d9
add scim router
ishaan-jaff Apr 16, 2025
4deb795
add v0 scim v2 endpoints
ishaan-jaff Apr 16, 2025
24978c9
working scim transformation
ishaan-jaff Apr 16, 2025
57e61bc
use 1 file for types
ishaan-jaff Apr 16, 2025
f475564
fix scim firstname and givenName storage
ishaan-jaff Apr 16, 2025
251c398
working SCIMErrorResponse
ishaan-jaff Apr 16, 2025
a6e2988
working team / group provisioning on SCIM
ishaan-jaff Apr 16, 2025
10d399a
add SCIMPatchOp
ishaan-jaff Apr 16, 2025
ce71e85
move scim folder
ishaan-jaff Apr 16, 2025
edb50d4
fix import scim_router
ishaan-jaff Apr 16, 2025
a208598
fix dont auto create scim keys
ishaan-jaff Apr 17, 2025
119ea80
add auth on all scim endpoints
ishaan-jaff Apr 17, 2025
b8a1bc5
add is_virtual_key_allowed_to_call_route
ishaan-jaff Apr 17, 2025
d16f923
fix allowed routes
ishaan-jaff Apr 17, 2025
4fe81bc
fix for key management
ishaan-jaff Apr 17, 2025
f75aaab
fix allowed routes check
ishaan-jaff Apr 17, 2025
38de64e
clean up error message
ishaan-jaff Apr 17, 2025
54aa100
fix code check
ishaan-jaff Apr 17, 2025
cd4e923
fix for route checks
ishaan-jaff Apr 17, 2025
af7ffca
ui SCIM support
ishaan-jaff Apr 17, 2025
2c3c029
add UI tab for SCIM
ishaan-jaff Apr 17, 2025
d3e2949
fixes SCIM
ishaan-jaff Apr 17, 2025
3971a9e
fixes for SCIM settings on ui
ishaan-jaff Apr 17, 2025
61190da
scim settings
ishaan-jaff Apr 17, 2025
eea7050
clean up scim view
ishaan-jaff Apr 17, 2025
a953319
Merge branch 'main' into litellm_scim_support
ishaan-jaff Apr 17, 2025
61a9fcf
add migration for allowed_routes in keys table
ishaan-jaff Apr 17, 2025
62c43ef
refactor scim transform
ishaan-jaff Apr 17, 2025
721e1f6
fix SCIM linting error
ishaan-jaff Apr 17, 2025
18b75cb
fix code quality check
ishaan-jaff Apr 17, 2025
ad41bbe
fix ui linting
ishaan-jaff Apr 17, 2025
7c4268f
test_scim_transformations.py
ishaan-jaff Apr 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
-- AlterTable
ALTER TABLE "LiteLLM_VerificationToken" ADD COLUMN "allowed_routes" TEXT[] DEFAULT ARRAY[]::TEXT[];

1 change: 1 addition & 0 deletions litellm-proxy-extras/litellm_proxy_extras/schema.prisma
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ model LiteLLM_VerificationToken {
budget_duration String?
budget_reset_at DateTime?
allowed_cache_controls String[] @default([])
allowed_routes String[] @default([])
model_spend Json @default("{}")
model_max_budget Json @default("{}")
budget_id String?
Expand Down
71 changes: 38 additions & 33 deletions litellm/proxy/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,9 +650,9 @@ class GenerateRequestBase(LiteLLMPydanticObjectBase):
allowed_cache_controls: Optional[list] = []
config: Optional[dict] = {}
permissions: Optional[dict] = {}
model_max_budget: Optional[
dict
] = {} # {"gpt-4": 5.0, "gpt-3.5-turbo": 5.0}, defaults to {}
model_max_budget: Optional[dict] = (
{}
) # {"gpt-4": 5.0, "gpt-3.5-turbo": 5.0}, defaults to {}

model_config = ConfigDict(protected_namespaces=())
model_rpm_limit: Optional[dict] = None
Expand All @@ -667,6 +667,7 @@ class KeyRequestBase(GenerateRequestBase):
budget_id: Optional[str] = None
tags: Optional[List[str]] = None
enforced_params: Optional[List[str]] = None
allowed_routes: Optional[list] = []


class GenerateKeyRequest(KeyRequestBase):
Expand Down Expand Up @@ -816,6 +817,8 @@ class NewUserResponse(GenerateKeyResponse):
teams: Optional[list] = None
user_alias: Optional[str] = None
model_max_budget: Optional[dict] = None
created_at: Optional[datetime] = None
updated_at: Optional[datetime] = None


class UpdateUserRequest(GenerateRequestBase):
Expand Down Expand Up @@ -908,12 +911,12 @@ class NewCustomerRequest(BudgetNewRequest):
alias: Optional[str] = None # human-friendly alias
blocked: bool = False # allow/disallow requests for this end-user
budget_id: Optional[str] = None # give either a budget_id or max_budget
allowed_model_region: Optional[
AllowedModelRegion
] = None # require all user requests to use models in this specific region
default_model: Optional[
str
] = None # if no equivalent model in allowed region - default all requests to this model
allowed_model_region: Optional[AllowedModelRegion] = (
None # require all user requests to use models in this specific region
)
default_model: Optional[str] = (
None # if no equivalent model in allowed region - default all requests to this model
)

@model_validator(mode="before")
@classmethod
Expand All @@ -935,12 +938,12 @@ class UpdateCustomerRequest(LiteLLMPydanticObjectBase):
blocked: bool = False # allow/disallow requests for this end-user
max_budget: Optional[float] = None
budget_id: Optional[str] = None # give either a budget_id or max_budget
allowed_model_region: Optional[
AllowedModelRegion
] = None # require all user requests to use models in this specific region
default_model: Optional[
str
] = None # if no equivalent model in allowed region - default all requests to this model
allowed_model_region: Optional[AllowedModelRegion] = (
None # require all user requests to use models in this specific region
)
default_model: Optional[str] = (
None # if no equivalent model in allowed region - default all requests to this model
)


class DeleteCustomerRequest(LiteLLMPydanticObjectBase):
Expand Down Expand Up @@ -1076,9 +1079,9 @@ class BlockKeyRequest(LiteLLMPydanticObjectBase):

class AddTeamCallback(LiteLLMPydanticObjectBase):
callback_name: str
callback_type: Optional[
Literal["success", "failure", "success_and_failure"]
] = "success_and_failure"
callback_type: Optional[Literal["success", "failure", "success_and_failure"]] = (
"success_and_failure"
)
callback_vars: Dict[str, str]

@model_validator(mode="before")
Expand Down Expand Up @@ -1144,6 +1147,7 @@ class LiteLLM_TeamTable(TeamBase):
budget_reset_at: Optional[datetime] = None
model_id: Optional[int] = None
litellm_model_table: Optional[LiteLLM_ModelTable] = None
updated_at: Optional[datetime] = None
created_at: Optional[datetime] = None

model_config = ConfigDict(protected_namespaces=())
Expand Down Expand Up @@ -1335,9 +1339,9 @@ class ConfigList(LiteLLMPydanticObjectBase):
stored_in_db: Optional[bool]
field_default_value: Any
premium_field: bool = False
nested_fields: Optional[
List[FieldDetail]
] = None # For nested dictionary or Pydantic fields
nested_fields: Optional[List[FieldDetail]] = (
None # For nested dictionary or Pydantic fields
)


class ConfigGeneralSettings(LiteLLMPydanticObjectBase):
Expand Down Expand Up @@ -1491,6 +1495,7 @@ class LiteLLM_VerificationToken(LiteLLMPydanticObjectBase):
budget_duration: Optional[str] = None
budget_reset_at: Optional[datetime] = None
allowed_cache_controls: Optional[list] = []
allowed_routes: Optional[list] = []
permissions: Dict = {}
model_spend: Dict = {}
model_max_budget: Dict = {}
Expand Down Expand Up @@ -1604,9 +1609,9 @@ class LiteLLM_OrganizationMembershipTable(LiteLLMPydanticObjectBase):
budget_id: Optional[str] = None
created_at: datetime
updated_at: datetime
user: Optional[
Any
] = None # You might want to replace 'Any' with a more specific type if available
user: Optional[Any] = (
None # You might want to replace 'Any' with a more specific type if available
)
litellm_budget_table: Optional[LiteLLM_BudgetTable] = None

model_config = ConfigDict(protected_namespaces=())
Expand Down Expand Up @@ -2354,9 +2359,9 @@ class TeamModelDeleteRequest(BaseModel):
# Organization Member Requests
class OrganizationMemberAddRequest(OrgMemberAddRequest):
organization_id: str
max_budget_in_organization: Optional[
float
] = None # Users max budget within the organization
max_budget_in_organization: Optional[float] = (
None # Users max budget within the organization
)


class OrganizationMemberDeleteRequest(MemberDeleteRequest):
Expand Down Expand Up @@ -2545,9 +2550,9 @@ class ProviderBudgetResponse(LiteLLMPydanticObjectBase):
Maps provider names to their budget configs.
"""

providers: Dict[
str, ProviderBudgetResponseObject
] = {} # Dictionary mapping provider names to their budget configurations
providers: Dict[str, ProviderBudgetResponseObject] = (
{}
) # Dictionary mapping provider names to their budget configurations


class ProxyStateVariables(TypedDict):
Expand Down Expand Up @@ -2675,9 +2680,9 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase):
enforce_rbac: bool = False
roles_jwt_field: Optional[str] = None # v2 on role mappings
role_mappings: Optional[List[RoleMapping]] = None
object_id_jwt_field: Optional[
str
] = None # can be either user / team, inferred from the role mapping
object_id_jwt_field: Optional[str] = (
None # can be either user / team, inferred from the role mapping
)
scope_mappings: Optional[List[ScopeMapping]] = None
enforce_scope_based_access: bool = False
enforce_team_based_model_access: bool = False
Expand Down
11 changes: 8 additions & 3 deletions litellm/proxy/auth/auth_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
## Common auth checks between jwt + key based auth
"""
Got Valid Token from Cache, DB
Run checks for:
Run checks for:

1. If user can call model
2. If user is in budget
3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
2. If user is in budget
3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
"""
import asyncio
import re
Expand Down Expand Up @@ -270,6 +270,11 @@ def _is_api_route_allowed(
if valid_token is None:
raise Exception("Invalid proxy server token passed. valid_token=None.")

# Check if Virtual Key is allowed to call the route - Applies to all Roles
RouteChecks.is_virtual_key_allowed_to_call_route(
route=route, valid_token=valid_token
)

if not _is_user_proxy_admin(user_obj=user_obj): # if non-admin
RouteChecks.non_proxy_admin_allowed_routes_check(
user_obj=user_obj,
Expand Down
60 changes: 60 additions & 0 deletions litellm/proxy/auth/route_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,37 @@


class RouteChecks:
@staticmethod
def is_virtual_key_allowed_to_call_route(
route: str, valid_token: UserAPIKeyAuth
) -> bool:
"""
Raises Exception if Virtual Key is not allowed to call the route
"""

# Only check if valid_token.allowed_routes is set and is a list with at least one item
if valid_token.allowed_routes is None:
return True
if not isinstance(valid_token.allowed_routes, list):
return True
if len(valid_token.allowed_routes) == 0:
return True

# explicit check for allowed routes
if route in valid_token.allowed_routes:
return True

# check if wildcard pattern is allowed
for allowed_route in valid_token.allowed_routes:
if RouteChecks._route_matches_wildcard_pattern(
route=route, pattern=allowed_route
):
return True

raise Exception(
f"Virtual key is not allowed to call this route. Only allowed to call routes: {valid_token.allowed_routes}. Tried to call route: {route}"
)

@staticmethod
def non_proxy_admin_allowed_routes_check(
user_obj: Optional[LiteLLM_UserTable],
Expand Down Expand Up @@ -220,6 +251,35 @@ def _route_matches_pattern(route: str, pattern: str) -> bool:
return True
return False

@staticmethod
def _route_matches_wildcard_pattern(route: str, pattern: str) -> bool:
"""
Check if route matches the wildcard pattern

eg.

pattern: "/scim/v2/*"
route: "/scim/v2/Users"
- returns: True

pattern: "/scim/v2/*"
route: "/chat/completions"
- returns: False


pattern: "/scim/v2/*"
route: "/scim/v2/Users/123"
- returns: True

"""
if pattern.endswith("*"):
# Get the prefix (everything before the wildcard)
prefix = pattern[:-1]
return route.startswith(prefix)
else:
# If there's no wildcard, the pattern and route should match exactly
return route == pattern

@staticmethod
def check_route_access(route: str, allowed_routes: List[str]) -> bool:
"""
Expand Down
43 changes: 23 additions & 20 deletions litellm/proxy/management_endpoints/key_management_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ async def generate_key_fn( # noqa: PLR0915
- soft_budget: Optional[float] - Specify soft budget for a given key. Will trigger a slack alert when this soft budget is reached.
- tags: Optional[List[str]] - Tags for [tracking spend](https://litellm.vercel.app/docs/proxy/enterprise#tracking-spend-for-custom-tags) and/or doing [tag-based routing](https://litellm.vercel.app/docs/proxy/tag_routing).
- enforced_params: Optional[List[str]] - List of enforced params for the key (Enterprise only). [Docs](https://docs.litellm.ai/docs/proxy/enterprise#enforce-required-params-for-llm-requests)

- allowed_routes: Optional[list] - List of allowed routes for the key. Store the actual route or store a wildcard pattern for a set of routes. Example - ["/chat/completions", "/embeddings", "/keys/*"]
Examples:

1. Allow users to turn on/off pii masking
Expand Down Expand Up @@ -577,9 +577,9 @@ async def generate_key_fn( # noqa: PLR0915
request_type="key", **data_json, table_name="key"
)

response[
"soft_budget"
] = data.soft_budget # include the user-input soft budget in the response
response["soft_budget"] = (
data.soft_budget
) # include the user-input soft budget in the response

response = GenerateKeyResponse(**response)

Expand Down Expand Up @@ -723,6 +723,7 @@ async def update_key_fn(
- config: Optional[dict] - [DEPRECATED PARAM] Key-specific config.
- temp_budget_increase: Optional[float] - Temporary budget increase for the key (Enterprise only).
- temp_budget_expiry: Optional[str] - Expiry time for the temporary budget increase (Enterprise only).
- allowed_routes: Optional[list] - List of allowed routes for the key. Store the actual route or store a wildcard pattern for a set of routes. Example - ["/chat/completions", "/embeddings", "/keys/*"]

Example:
```bash
Expand Down Expand Up @@ -1167,6 +1168,7 @@ async def generate_key_helper_fn( # noqa: PLR0915
send_invite_email: Optional[bool] = None,
created_by: Optional[str] = None,
updated_by: Optional[str] = None,
allowed_routes: Optional[list] = None,
):
from litellm.proxy.proxy_server import (
litellm_proxy_budget_name,
Expand Down Expand Up @@ -1272,6 +1274,7 @@ async def generate_key_helper_fn( # noqa: PLR0915
"blocked": blocked,
"created_by": created_by,
"updated_by": updated_by,
"allowed_routes": allowed_routes or [],
}

if (
Expand Down Expand Up @@ -1467,10 +1470,10 @@ async def delete_verification_tokens(
try:
if prisma_client:
tokens = [_hash_token_if_needed(token=key) for key in tokens]
_keys_being_deleted: List[
LiteLLM_VerificationToken
] = await prisma_client.db.litellm_verificationtoken.find_many(
where={"token": {"in": tokens}}
_keys_being_deleted: List[LiteLLM_VerificationToken] = (
await prisma_client.db.litellm_verificationtoken.find_many(
where={"token": {"in": tokens}}
)
)

# Assuming 'db' is your Prisma Client instance
Expand Down Expand Up @@ -1572,9 +1575,9 @@ async def _rotate_master_key(
from litellm.proxy.proxy_server import proxy_config

try:
models: Optional[
List
] = await prisma_client.db.litellm_proxymodeltable.find_many()
models: Optional[List] = (
await prisma_client.db.litellm_proxymodeltable.find_many()
)
except Exception:
models = None
# 2. process model table
Expand Down Expand Up @@ -1861,11 +1864,11 @@ async def validate_key_list_check(
param="user_id",
code=status.HTTP_403_FORBIDDEN,
)
complete_user_info_db_obj: Optional[
BaseModel
] = await prisma_client.db.litellm_usertable.find_unique(
where={"user_id": user_api_key_dict.user_id},
include={"organization_memberships": True},
complete_user_info_db_obj: Optional[BaseModel] = (
await prisma_client.db.litellm_usertable.find_unique(
where={"user_id": user_api_key_dict.user_id},
include={"organization_memberships": True},
)
)

if complete_user_info_db_obj is None:
Expand Down Expand Up @@ -1926,10 +1929,10 @@ async def get_admin_team_ids(
if complete_user_info is None:
return []
# Get all teams that user is an admin of
teams: Optional[
List[BaseModel]
] = await prisma_client.db.litellm_teamtable.find_many(
where={"team_id": {"in": complete_user_info.teams}}
teams: Optional[List[BaseModel]] = (
await prisma_client.db.litellm_teamtable.find_many(
where={"team_id": {"in": complete_user_info.teams}}
)
)
if teams is None:
return []
Expand Down
Loading
Loading