Skip to content

Commit 7a4e807

Browse files
Litellm add file validation (#11081)
* fix: cleanup print statement * feat(managed_files.py): add auth check on managed files Implemented for file retrieve + delete calls * feat(files_endpoints.py): support returning files by model name enables managed file support * feat(managed_files/): filter list of files by the ones created by user prevents a user from seeing another user's files * test: update test * fix(files_endpoints.py): list_files - always default to provider based routing * build: add new table to prisma schema
1 parent 1daf23d commit 7a4e807

File tree

12 files changed

+265
-21
lines changed

12 files changed

+265
-21
lines changed

enterprise/enterprise_hooks/managed_files.py

Lines changed: 100 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import base64
66
import json
77
import uuid
8-
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union, cast
8+
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union, cast
99

1010
from fastapi import HTTPException
1111

@@ -26,8 +26,10 @@
2626
)
2727
from litellm.types.llms.openai import (
2828
AllMessageValues,
29+
AsyncCursorPage,
2930
ChatCompletionFileObject,
3031
CreateFileRequest,
32+
FileObject,
3133
OpenAIFileObject,
3234
OpenAIFilesPurpose,
3335
)
@@ -67,6 +69,7 @@ async def store_unified_file_id(
6769
file_object: OpenAIFileObject,
6870
litellm_parent_otel_span: Optional[Span],
6971
model_mappings: Dict[str, str],
72+
user_api_key_dict: UserAPIKeyAuth,
7073
) -> None:
7174
verbose_logger.info(
7275
f"Storing LiteLLM Managed File object with id={file_id} in cache"
@@ -75,6 +78,9 @@ async def store_unified_file_id(
7578
unified_file_id=file_id,
7679
file_object=file_object,
7780
model_mappings=model_mappings,
81+
flat_model_file_ids=list(model_mappings.values()),
82+
created_by=user_api_key_dict.user_id,
83+
updated_by=user_api_key_dict.user_id,
7884
)
7985
await self.internal_usage_cache.async_set_cache(
8086
key=file_id,
@@ -87,6 +93,9 @@ async def store_unified_file_id(
8793
"unified_file_id": file_id,
8894
"file_object": file_object.model_dump_json(),
8995
"model_mappings": json.dumps(model_mappings),
96+
"flat_model_file_ids": list(model_mappings.values()),
97+
"created_by": user_api_key_dict.user_id,
98+
"updated_by": user_api_key_dict.user_id,
9099
}
91100
)
92101

@@ -169,6 +178,18 @@ async def delete_unified_file_id(
169178
)
170179
return initial_value.file_object
171180

181+
async def can_user_call_unified_file_id(
    self, unified_file_id: str, user_api_key_dict: UserAPIKeyAuth
) -> bool:
    """
    Check whether the caller is the creator of a managed (unified) file.

    Args:
        unified_file_id: Value matched against the `unified_file_id` column
            of the managed file table.
        user_api_key_dict: Auth object of the caller; only `user_id` is read.

    Returns:
        True only when a managed file row exists for `unified_file_id` and
        its `created_by` equals the caller's non-None `user_id`.
        False when the caller has no `user_id`, when no row is found, or
        when the row was created by someone else.
    """
    user_id = user_api_key_dict.user_id
    if user_id is None:
        # `created_by` is nullable. Without this guard a caller that has no
        # user_id would satisfy `created_by == user_id` (None == None) for
        # every row with no recorded creator. Fail closed instead.
        return False

    managed_file = await self.prisma_client.db.litellm_managedfiletable.find_first(
        where={"unified_file_id": unified_file_id}
    )
    if not managed_file:
        return False
    return managed_file.created_by == user_id
192+
172193
async def can_user_call_unified_object_id(
173194
self, unified_object_id: str, user_api_key_dict: UserAPIKeyAuth
174195
) -> bool:
@@ -184,6 +205,44 @@ async def can_user_call_unified_object_id(
184205
return managed_object.created_by == user_id
185206
return False
186207

208+
async def get_user_created_file_ids(
    self, user_api_key_dict: UserAPIKeyAuth, model_object_ids: List[str]
) -> List[OpenAIFileObject]:
    """
    Get the managed files created by the caller for a list of model object ids.

    Despite the method name, this returns file objects, not ids.

    Args:
        user_api_key_dict: Auth object of the caller; only `user_id` is read.
        model_object_ids: Ids matched against the `flat_model_file_ids`
            column (a row matches when it contains at least one of them).

    Returns:
        - List of OpenAIFileObject's, built from the `file_object` column of
          each matching row. Empty when the caller has no `user_id` or when
          `model_object_ids` is empty.
    """
    user_id = user_api_key_dict.user_id
    if user_id is None or not model_object_ids:
        # No user_id: `created_by` is nullable, so filtering on None would
        # return rows with no recorded creator rather than the caller's own.
        # Empty id list: `hasSome` cannot match anything, so skip the query.
        return []

    managed_files = await self.prisma_client.db.litellm_managedfiletable.find_many(
        where={
            "created_by": user_id,
            "flat_model_file_ids": {"hasSome": model_object_ids},
        }
    )
    return [
        OpenAIFileObject(**managed_file.file_object)
        for managed_file in managed_files
    ]
224+
225+
async def check_managed_file_id_access(
    self, data: Dict, user_api_key_dict: UserAPIKeyAuth
) -> bool:
    """
    Enforce ownership for requests that target a managed (unified) file id.

    Args:
        data: Request payload; only the "file_id" key is read.
        user_api_key_dict: Auth object of the caller.

    Returns:
        True when "file_id" is recognised by
        `_is_base64_encoded_unified_file_id` and the caller may use it.
        False when "file_id" is missing/empty or is not recognised as a
        unified file id (nothing to check in that case).

    Raises:
        HTTPException: 403 when the id is a unified file id but
            `can_user_call_unified_file_id` denies access.
    """
    retrieve_file_id = cast(Optional[str], data.get("file_id"))

    # Nothing to verify when no file id was supplied.
    if not retrieve_file_id:
        return False

    # Ids that are not unified file ids are not subject to this check.
    if not _is_base64_encoded_unified_file_id(retrieve_file_id):
        return False

    is_allowed = await self.can_user_call_unified_file_id(
        retrieve_file_id, user_api_key_dict
    )
    if not is_allowed:
        raise HTTPException(
            status_code=403,
            detail=f"User {user_api_key_dict.user_id} does not have access to the file {retrieve_file_id}",
        )
    return True
245+
187246
async def async_pre_call_hook(
188247
self,
189248
user_api_key_dict: UserAPIKeyAuth,
@@ -200,6 +259,9 @@ async def async_pre_call_hook(
200259
"rerank",
201260
"acreate_batch",
202261
"aretrieve_batch",
262+
"acreate_file",
263+
"afile_list",
264+
"afile_delete",
203265
"afile_content",
204266
"acreate_fine_tuning_job",
205267
"aretrieve_fine_tuning_job",
@@ -211,9 +273,14 @@ async def async_pre_call_hook(
211273
- Detect litellm_proxy/ file_id
212274
- add dictionary of mappings of litellm_proxy/ file_id -> provider_file_id => {litellm_proxy/file_id: {"model_id": id, "file_id": provider_file_id}}
213275
"""
214-
print(
215-
"CALLS ASYNC PRE CALL HOOK - DATA={}, CALL_TYPE={}".format(data, call_type)
216-
)
276+
### HANDLE FILE ACCESS ### - ensure user has access to the file
277+
if (
278+
call_type == CallTypes.afile_content.value
279+
or call_type == CallTypes.afile_delete.value
280+
):
281+
await self.check_managed_file_id_access(data, user_api_key_dict)
282+
283+
### HANDLE TRANSFORMATIONS ###
217284
if call_type == CallTypes.completion.value:
218285
messages = data.get("messages")
219286
if messages:
@@ -298,7 +365,6 @@ async def async_pre_call_hook(
298365
[input_file_id], user_api_key_dict.parent_otel_span
299366
)
300367

301-
print("DATA={}".format(data))
302368
return data
303369

304370
async def async_pre_call_deployment_hook(
@@ -416,6 +482,7 @@ async def acreate_file(
416482
llm_router: Router,
417483
target_model_names_list: List[str],
418484
litellm_parent_otel_span: Span,
485+
user_api_key_dict: UserAPIKeyAuth,
419486
) -> OpenAIFileObject:
420487
responses = await self.create_file_for_each_model(
421488
llm_router=llm_router,
@@ -448,6 +515,7 @@ async def acreate_file(
448515
file_object=response,
449516
litellm_parent_otel_span=litellm_parent_otel_span,
450517
model_mappings=model_mappings,
518+
user_api_key_dict=user_api_key_dict,
451519
)
452520
return response
453521

@@ -560,6 +628,7 @@ def get_batch_id_from_unified_batch_id(self, file_id: str) -> str:
560628
async def async_post_call_success_hook(
561629
self, data: Dict, user_api_key_dict: UserAPIKeyAuth, response: LLMResponseTypes
562630
) -> Any:
631+
print(f"response: {response}, type: {type(response)}")
563632
if isinstance(response, LiteLLMBatch):
564633
## Check if unified_file_id is in the response
565634
unified_file_id = response._hidden_params.get(
@@ -619,6 +688,31 @@ async def async_post_call_success_hook(
619688
user_api_key_dict=user_api_key_dict,
620689
)
621690
)
691+
elif isinstance(response, AsyncCursorPage):
692+
"""
693+
For listing files, filter for the ones created by the user
694+
"""
695+
print("INSIDE ASYNC CURSOR PAGE BLOCK")
696+
## check if file object
697+
if hasattr(response, "data") and isinstance(response.data, list):
698+
if all(
699+
isinstance(file_object, FileObject) for file_object in response.data
700+
):
701+
## Get all file id's
702+
## Check which file id's were created by the user
703+
## Filter the response to only include the files created by the user
704+
## Return the filtered response
705+
file_ids = [
706+
file_object.id
707+
for file_object in cast(List[FileObject], response.data) # type: ignore
708+
]
709+
user_created_file_ids = await self.get_user_created_file_ids(
710+
user_api_key_dict, file_ids
711+
)
712+
## Filter the response to only include the files created by the user
713+
response.data = user_created_file_ids # type: ignore
714+
return response
715+
return response
622716
return response
623717

624718
async def afile_retrieve(
@@ -638,6 +732,7 @@ async def afile_list(
638732
litellm_parent_otel_span: Optional[Span],
639733
**data: Dict,
640734
) -> List[OpenAIFileObject]:
735+
"""Handled in files_endpoints.py"""
641736
return []
642737

643738
async def afile_delete(
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
-- Migration: adds created_by / updated_by / flat_model_file_ids to
-- "LiteLLM_ManagedFileTable" and creates "LiteLLM_ManagedObjectTable".
-- AlterTable
2+
ALTER TABLE "LiteLLM_ManagedFileTable" ADD COLUMN "created_by" TEXT,
3+
ADD COLUMN "flat_model_file_ids" TEXT[] DEFAULT ARRAY[]::TEXT[],
4+
ADD COLUMN "updated_by" TEXT;
5+
6+
-- CreateTable
7+
CREATE TABLE "LiteLLM_ManagedObjectTable" (
8+
"id" TEXT NOT NULL,
9+
"unified_object_id" TEXT NOT NULL,
10+
"model_object_id" TEXT NOT NULL,
11+
"file_object" JSONB NOT NULL,
12+
"file_purpose" TEXT NOT NULL,
13+
"created_at" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
14+
"created_by" TEXT,
15+
"updated_at" TIMESTAMP(3) NOT NULL,
16+
"updated_by" TEXT,
17+
18+
CONSTRAINT "LiteLLM_ManagedObjectTable_pkey" PRIMARY KEY ("id")
19+
);
20+
21+
-- CreateIndex
22+
CREATE UNIQUE INDEX "LiteLLM_ManagedObjectTable_unified_object_id_key" ON "LiteLLM_ManagedObjectTable"("unified_object_id");
23+
24+
-- CreateIndex
25+
CREATE UNIQUE INDEX "LiteLLM_ManagedObjectTable_model_object_id_key" ON "LiteLLM_ManagedObjectTable"("model_object_id");
26+
27+
-- NOTE(review): the two plain indexes below cover the same single columns
-- as the UNIQUE indexes above; confirm the duplication is intended.
-- CreateIndex
28+
CREATE INDEX "LiteLLM_ManagedObjectTable_unified_object_id_idx" ON "LiteLLM_ManagedObjectTable"("unified_object_id");
29+
30+
-- CreateIndex
31+
CREATE INDEX "LiteLLM_ManagedObjectTable_model_object_id_idx" ON "LiteLLM_ManagedObjectTable"("model_object_id");
32+

litellm-proxy-extras/litellm_proxy_extras/schema.prisma

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -453,13 +453,30 @@ model LiteLLM_ManagedFileTable {
453453
id String @id @default(uuid())
454454
unified_file_id String @unique // The base64 encoded unified file ID
455455
file_object Json // Stores the OpenAIFileObject
456-
model_mappings Json // Stores the mapping of model_id -> provider_file_id
456+
model_mappings Json
457+
flat_model_file_ids String[] @default([]) // Flat list of model file id's - for faster querying of model id -> unified file id
457458
created_at DateTime @default(now())
459+
created_by String?
458460
updated_at DateTime @updatedAt
461+
updated_by String?
459462
460463
@@index([unified_file_id])
461464
}
462465

466+
model LiteLLM_ManagedObjectTable { // for batches or fine-tuning jobs (see file_purpose). NOTE(review): the original comment was cut off after "which use the" - confirm the intended wording
467+
id String @id @default(uuid())
468+
unified_object_id String @unique // The base64 encoded unified object ID (comment previously said "file ID"; TODO confirm)
469+
model_object_id String @unique // the id returned by the backend API provider
470+
file_object Json // Stores the OpenAIFileObject
471+
file_purpose String // either 'batch' or 'fine-tune'
472+
created_at DateTime @default(now())
473+
created_by String?
474+
updated_at DateTime @updatedAt
475+
updated_by String?
476+
477+
@@index([unified_object_id])
478+
@@index([model_object_id])
479+
}
463480

464481
model LiteLLM_ManagedVectorStoresTable {
465482
vector_store_id String @id

litellm/llms/base_llm/files/transformation.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import httpx
55

6+
from litellm.proxy._types import UserAPIKeyAuth
67
from litellm.types.llms.openai import (
78
AllMessageValues,
89
CreateFileRequest,
@@ -115,6 +116,7 @@ async def acreate_file(
115116
llm_router: Router,
116117
target_model_names_list: List[str],
117118
litellm_parent_otel_span: Span,
119+
user_api_key_dict: UserAPIKeyAuth,
118120
) -> OpenAIFileObject:
119121
pass
120122

litellm/proxy/_new_secret_config.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@ model_list:
22
- model_name: "gemini-2.0-flash-gemini"
33
litellm_params:
44
model: gemini/gemini-2.0-flash
5-
- model_name: "gpt-4o-mini-openai"
5+
- model_name: "gpt-4.1-openai"
66
litellm_params:
7-
model: gpt-4.1-mini-2025-04-14
8-
api_key: os.environ/OPENAI_API_KEY_2
7+
model: gpt-4.1
8+
api_key: os.environ/OPENAI_API_KEY
99
model_info:
1010
access_groups: ["default-openai-models"]
1111
- model_name: "gpt-4o-realtime-preview"

litellm/proxy/_types.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2883,6 +2883,9 @@ class LiteLLM_ManagedFileTable(LiteLLMPydanticObjectBase):
28832883
unified_file_id: str
28842884
file_object: OpenAIFileObject
28852885
model_mappings: Dict[str, str]
2886+
flat_model_file_ids: List[str]
2887+
created_by: Optional[str]
2888+
updated_by: Optional[str]
28862889

28872890

28882891
class LiteLLM_ManagedObjectTable(LiteLLMPydanticObjectBase):

0 commit comments

Comments
 (0)