5
5
import base64
6
6
import json
7
7
import uuid
8
- from typing import TYPE_CHECKING , Any , Dict , List , Literal , Optional , Union , cast
8
+ from typing import TYPE_CHECKING , Any , Dict , List , Literal , Optional , Tuple , Union , cast
9
9
10
10
from fastapi import HTTPException
11
11
26
26
)
27
27
from litellm .types .llms .openai import (
28
28
AllMessageValues ,
29
+ AsyncCursorPage ,
29
30
ChatCompletionFileObject ,
30
31
CreateFileRequest ,
32
+ FileObject ,
31
33
OpenAIFileObject ,
32
34
OpenAIFilesPurpose ,
33
35
)
@@ -67,6 +69,7 @@ async def store_unified_file_id(
67
69
file_object : OpenAIFileObject ,
68
70
litellm_parent_otel_span : Optional [Span ],
69
71
model_mappings : Dict [str , str ],
72
+ user_api_key_dict : UserAPIKeyAuth ,
70
73
) -> None :
71
74
verbose_logger .info (
72
75
f"Storing LiteLLM Managed File object with id={ file_id } in cache"
@@ -75,6 +78,9 @@ async def store_unified_file_id(
75
78
unified_file_id = file_id ,
76
79
file_object = file_object ,
77
80
model_mappings = model_mappings ,
81
+ flat_model_file_ids = list (model_mappings .values ()),
82
+ created_by = user_api_key_dict .user_id ,
83
+ updated_by = user_api_key_dict .user_id ,
78
84
)
79
85
await self .internal_usage_cache .async_set_cache (
80
86
key = file_id ,
@@ -87,6 +93,9 @@ async def store_unified_file_id(
87
93
"unified_file_id" : file_id ,
88
94
"file_object" : file_object .model_dump_json (),
89
95
"model_mappings" : json .dumps (model_mappings ),
96
+ "flat_model_file_ids" : list (model_mappings .values ()),
97
+ "created_by" : user_api_key_dict .user_id ,
98
+ "updated_by" : user_api_key_dict .user_id ,
90
99
}
91
100
)
92
101
@@ -169,6 +178,18 @@ async def delete_unified_file_id(
169
178
)
170
179
return initial_value .file_object
171
180
181
async def can_user_call_unified_file_id(
    self, unified_file_id: str, user_api_key_dict: UserAPIKeyAuth
) -> bool:
    """
    Check if the calling user is the creator of a unified file id.

    Looks up the first `litellm_managedfiletable` row whose `unified_file_id`
    matches, and compares its `created_by` with `user_api_key_dict.user_id`.

    Args:
    - unified_file_id: the unified file id to look up
    - user_api_key_dict: auth object of the caller; only `user_id` is read

    Returns:
    - True if a row exists and its `created_by` equals the caller's user_id
    - False if no row exists, or the row was created by someone else
    """
    user_id = user_api_key_dict.user_id
    managed_file = await self.prisma_client.db.litellm_managedfiletable.find_first(
        where={"unified_file_id": unified_file_id}
    )
    if not managed_file:
        return False
    # NOTE(review): when the key has no user_id (None), this also matches rows
    # whose created_by is None - confirm that is the intended behaviour.
    return managed_file.created_by == user_id
192
+
172
193
async def can_user_call_unified_object_id (
173
194
self , unified_object_id : str , user_api_key_dict : UserAPIKeyAuth
174
195
) -> bool :
@@ -184,6 +205,44 @@ async def can_user_call_unified_object_id(
184
205
return managed_object .created_by == user_id
185
206
return False
186
207
208
async def get_user_created_file_ids(
    self, user_api_key_dict: UserAPIKeyAuth, model_object_ids: List[str]
) -> List[OpenAIFileObject]:
    """
    Get all files created by the user for a list of model object ids

    Queries `litellm_managedfiletable` for rows where `created_by` equals
    `user_api_key_dict.user_id` and `flat_model_file_ids` has at least one
    value in common with `model_object_ids`.

    Args:
    - user_api_key_dict: auth object of the caller; only `user_id` is read
    - model_object_ids: file ids to match against `flat_model_file_ids`

    Returns:
    - List of OpenAIFileObject's, built from the `file_object` of each
      matching row. Empty when `model_object_ids` is empty.
    """
    if not model_object_ids:
        # Nothing to match against - skip the DB round trip.
        return []

    managed_files = await self.prisma_client.db.litellm_managedfiletable.find_many(
        where={
            "created_by": user_api_key_dict.user_id,
            "flat_model_file_ids": {"hasSome": model_object_ids},
        }
    )
    # NOTE(review): assumes `file_object` comes back as a dict of
    # OpenAIFileObject fields - confirm against the table schema.
    return [
        OpenAIFileObject(**managed_file.file_object)
        for managed_file in managed_files
    ]
224
+
225
async def check_managed_file_id_access(
    self, data: Dict, user_api_key_dict: UserAPIKeyAuth
) -> bool:
    """
    Ensure the user has access to the managed file referenced by the request.

    Reads `data["file_id"]`. If it is set and
    `_is_base64_encoded_unified_file_id` recognises it, the caller must pass
    `can_user_call_unified_file_id` for that id.

    Returns:
    - True if `file_id` is a unified file id the user has access to
    - False if `file_id` is missing/empty, or is not a unified file id

    Raises:
    - HTTPException (403) if `file_id` is a unified file id the user does
      not have access to
    """
    retrieve_file_id = cast(Optional[str], data.get("file_id"))
    if not retrieve_file_id:
        return False

    # Only unified (LiteLLM managed) file ids are subject to the access check.
    if not _is_base64_encoded_unified_file_id(retrieve_file_id):
        return False

    if await self.can_user_call_unified_file_id(retrieve_file_id, user_api_key_dict):
        return True

    raise HTTPException(
        status_code=403,
        detail=f"User {user_api_key_dict.user_id} does not have access to the file {retrieve_file_id}",
    )
245
+
187
246
async def async_pre_call_hook (
188
247
self ,
189
248
user_api_key_dict : UserAPIKeyAuth ,
@@ -200,6 +259,9 @@ async def async_pre_call_hook(
200
259
"rerank" ,
201
260
"acreate_batch" ,
202
261
"aretrieve_batch" ,
262
+ "acreate_file" ,
263
+ "afile_list" ,
264
+ "afile_delete" ,
203
265
"afile_content" ,
204
266
"acreate_fine_tuning_job" ,
205
267
"aretrieve_fine_tuning_job" ,
@@ -211,9 +273,14 @@ async def async_pre_call_hook(
211
273
- Detect litellm_proxy/ file_id
212
274
- add dictionary of mappings of litellm_proxy/ file_id -> provider_file_id => {litellm_proxy/file_id: {"model_id": id, "file_id": provider_file_id}}
213
275
"""
214
- print (
215
- "CALLS ASYNC PRE CALL HOOK - DATA={}, CALL_TYPE={}" .format (data , call_type )
216
- )
276
+ ### HANDLE FILE ACCESS ### - ensure user has access to the file
277
+ if (
278
+ call_type == CallTypes .afile_content .value
279
+ or call_type == CallTypes .afile_delete .value
280
+ ):
281
+ await self .check_managed_file_id_access (data , user_api_key_dict )
282
+
283
+ ### HANDLE TRANSFORMATIONS ###
217
284
if call_type == CallTypes .completion .value :
218
285
messages = data .get ("messages" )
219
286
if messages :
@@ -298,7 +365,6 @@ async def async_pre_call_hook(
298
365
[input_file_id ], user_api_key_dict .parent_otel_span
299
366
)
300
367
301
- print ("DATA={}" .format (data ))
302
368
return data
303
369
304
370
async def async_pre_call_deployment_hook (
@@ -416,6 +482,7 @@ async def acreate_file(
416
482
llm_router : Router ,
417
483
target_model_names_list : List [str ],
418
484
litellm_parent_otel_span : Span ,
485
+ user_api_key_dict : UserAPIKeyAuth ,
419
486
) -> OpenAIFileObject :
420
487
responses = await self .create_file_for_each_model (
421
488
llm_router = llm_router ,
@@ -448,6 +515,7 @@ async def acreate_file(
448
515
file_object = response ,
449
516
litellm_parent_otel_span = litellm_parent_otel_span ,
450
517
model_mappings = model_mappings ,
518
+ user_api_key_dict = user_api_key_dict ,
451
519
)
452
520
return response
453
521
@@ -560,6 +628,7 @@ def get_batch_id_from_unified_batch_id(self, file_id: str) -> str:
560
628
async def async_post_call_success_hook (
561
629
self , data : Dict , user_api_key_dict : UserAPIKeyAuth , response : LLMResponseTypes
562
630
) -> Any :
631
+ print (f"response: { response } , type: { type (response )} " )
563
632
if isinstance (response , LiteLLMBatch ):
564
633
## Check if unified_file_id is in the response
565
634
unified_file_id = response ._hidden_params .get (
@@ -619,6 +688,31 @@ async def async_post_call_success_hook(
619
688
user_api_key_dict = user_api_key_dict ,
620
689
)
621
690
)
691
+ elif isinstance (response , AsyncCursorPage ):
692
+ """
693
+ For listing files, filter for the ones created by the user
694
+ """
695
+ print ("INSIDE ASYNC CURSOR PAGE BLOCK" )
696
+ ## check if file object
697
+ if hasattr (response , "data" ) and isinstance (response .data , list ):
698
+ if all (
699
+ isinstance (file_object , FileObject ) for file_object in response .data
700
+ ):
701
+ ## Get all file id's
702
+ ## Check which file id's were created by the user
703
+ ## Filter the response to only include the files created by the user
704
+ ## Return the filtered response
705
+ file_ids = [
706
+ file_object .id
707
+ for file_object in cast (List [FileObject ], response .data ) # type: ignore
708
+ ]
709
+ user_created_file_ids = await self .get_user_created_file_ids (
710
+ user_api_key_dict , file_ids
711
+ )
712
+ ## Filter the response to only include the files created by the user
713
+ response .data = user_created_file_ids # type: ignore
714
+ return response
715
+ return response
622
716
return response
623
717
624
718
async def afile_retrieve (
@@ -638,6 +732,7 @@ async def afile_list(
638
732
litellm_parent_otel_span : Optional [Span ],
639
733
** data : Dict ,
640
734
) -> List [OpenAIFileObject ]:
735
+ """Handled in files_endpoints.py"""
641
736
return []
642
737
643
738
async def afile_delete (
0 commit comments