
Commit 54ffb1f

krrishdholakia authored and stefan-- committed
Anthropic - pass file URLs as Document content type + Gemini - cache token tracking on streaming calls (BerriAI#11387)

* fix(anthropic/): fix regression when passing file URLs to the 'file_id' parameter; add a test and ensure Anthropic file URLs are correctly sent as 'document' blocks
* fix(vertex_and_google_ai_studio.py): use the same usage calculation function as non-streaming. Closes BerriAI#10667
* test(test_vertex_and_google_ai_studio_gemini.py): update test
1 parent 685b600 commit 54ffb1f
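
For context, a minimal user-facing sketch of the Anthropic change: an OpenAI-style message whose file part carries a public URL in 'file_id' should now be forwarded to Anthropic as a URL-sourced 'document' block. This is not part of the commit; the model name is an assumption (any PDF-capable Anthropic model should exercise the same path), and the URL is the one used in the new test.

# Sketch only: litellm.completion call passing a file URL via file_id.
import litellm

response = litellm.completion(
    model="anthropic/claude-3-5-sonnet-20240620",  # assumed PDF-capable model
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What's this file about?"},
                {
                    "type": "file",
                    "file": {
                        # A plain URL here previously regressed into a
                        # container_upload block; it is now sent as a document.
                        "file_id": "https://upload.wikimedia.org/wikipedia/commons/2/20/Re_example.pdf"
                    },
                },
            ],
        }
    ],
)
print(response.choices[0].message.content)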

File tree

7 files changed: +161 −44 lines changed

litellm/litellm_core_utils/prompt_templates/factory.py
litellm/llms/gemini/realtime/transformation.py
litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py
litellm/types/llms/anthropic.py
tests/llm_translation/base_llm_unit_tests.py
tests/llm_translation/test_anthropic_completion.py
tests/test_litellm/llms/vertex_ai/gemini/test_vertex_and_google_ai_studio_gemini.py

litellm/litellm_core_utils/prompt_templates/factory.py

Lines changed: 24 additions & 1 deletion
@@ -1396,6 +1396,21 @@ def select_anthropic_content_block_type_for_file(
     return "container_upload"
 
 
+def anthropic_infer_file_id_content_type(
+    file_id: str,
+) -> Literal["document_url", "container_upload"]:
+    """
+    Use when 'format' not provided.
+
+    - URL's - assume are document_url
+    - Else - assume is container_upload
+    """
+    if file_id.startswith("http") or file_id.startswith("https"):
+        return "document_url"
+    else:
+        return "container_upload"
+
+
 def anthropic_process_openai_file_message(
     message: ChatCompletionFileObject,
 ) -> Union[
@@ -1425,7 +1440,7 @@ def anthropic_process_openai_file_message(
     content_block_type = (
         select_anthropic_content_block_type_for_file(format)
         if format
-        else "container_upload"
+        else anthropic_infer_file_id_content_type(file_id)
     )
     return_block_param: Optional[
         Union[
@@ -1442,6 +1457,14 @@ def anthropic_process_openai_file_message(
                 file_id=file_id,
             ),
         )
+    elif content_block_type == "document_url":
+        return_block_param = AnthropicMessagesDocumentParam(
+            type="document",
+            source=AnthropicContentParamSourceUrl(
+                type="url",
+                url=file_id,
+            ),
+        )
     elif content_block_type == "image":
         return_block_param = AnthropicMessagesImageParam(
             type="image",

litellm/llms/gemini/realtime/transformation.py

Lines changed: 1 addition & 1 deletion
@@ -658,7 +658,7 @@ def transform_response_done_event(
             modality.lower() for modality in cast(List[str], gemini_modalities)
         ]
         if "usageMetadata" in message:
-            _chat_completion_usage = VertexGeminiConfig()._calculate_usage(
+            _chat_completion_usage = VertexGeminiConfig._calculate_usage(
                 completion_response=message,
             )
         else:

litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py

Lines changed: 26 additions & 40 deletions
@@ -902,7 +902,8 @@ def _handle_content_policy_violation(
 
         return model_response
 
-    def is_candidate_token_count_inclusive(self, usage_metadata: UsageMetadata) -> bool:
+    @staticmethod
+    def is_candidate_token_count_inclusive(usage_metadata: UsageMetadata) -> bool:
         """
         Check if the candidate token count is inclusive of the thinking token count
 
@@ -919,13 +920,16 @@ def is_candidate_token_count_inclusive(self, usage_metadata: UsageMetadata) -> b
         else:
             return False
 
+    @staticmethod
     def _calculate_usage(
-        self,
         completion_response: Union[
             GenerateContentResponseBody, BidiGenerateContentServerMessage
         ],
     ) -> Usage:
-        if "usageMetadata" not in completion_response:
+        if (
+            completion_response is not None
+            and "usageMetadata" not in completion_response
+        ):
             raise ValueError(
                 f"usageMetadata not found in completion_response. Got={completion_response}"
             )
@@ -936,33 +940,30 @@ def _calculate_usage(
         reasoning_tokens: Optional[int] = None
         response_tokens: Optional[int] = None
         response_tokens_details: Optional[CompletionTokensDetailsWrapper] = None
-        if "cachedContentTokenCount" in completion_response["usageMetadata"]:
-            cached_tokens = completion_response["usageMetadata"][
-                "cachedContentTokenCount"
-            ]
+        usage_metadata = completion_response["usageMetadata"]
+        if "cachedContentTokenCount" in usage_metadata:
+            cached_tokens = usage_metadata["cachedContentTokenCount"]
 
         ## GEMINI LIVE API ONLY PARAMS ##
-        if "responseTokenCount" in completion_response["usageMetadata"]:
-            response_tokens = completion_response["usageMetadata"]["responseTokenCount"]
-        if "responseTokensDetails" in completion_response["usageMetadata"]:
+        if "responseTokenCount" in usage_metadata:
+            response_tokens = usage_metadata["responseTokenCount"]
+        if "responseTokensDetails" in usage_metadata:
             response_tokens_details = CompletionTokensDetailsWrapper()
-            for detail in completion_response["usageMetadata"]["responseTokensDetails"]:
+            for detail in usage_metadata["responseTokensDetails"]:
                 if detail["modality"] == "TEXT":
                     response_tokens_details.text_tokens = detail["tokenCount"]
                 elif detail["modality"] == "AUDIO":
                     response_tokens_details.audio_tokens = detail["tokenCount"]
         #########################################################
 
-        if "promptTokensDetails" in completion_response["usageMetadata"]:
-            for detail in completion_response["usageMetadata"]["promptTokensDetails"]:
+        if "promptTokensDetails" in usage_metadata:
+            for detail in usage_metadata["promptTokensDetails"]:
                 if detail["modality"] == "AUDIO":
                     audio_tokens = detail["tokenCount"]
                 elif detail["modality"] == "TEXT":
                     text_tokens = detail["tokenCount"]
-        if "thoughtsTokenCount" in completion_response["usageMetadata"]:
-            reasoning_tokens = completion_response["usageMetadata"][
-                "thoughtsTokenCount"
-            ]
+        if "thoughtsTokenCount" in usage_metadata:
+            reasoning_tokens = usage_metadata["thoughtsTokenCount"]
         prompt_tokens_details = PromptTokensDetailsWrapper(
             cached_tokens=cached_tokens,
             audio_tokens=audio_tokens,
@@ -973,19 +974,15 @@ def _calculate_usage(
             "candidatesTokenCount", 0
         )
         if (
-            not self.is_candidate_token_count_inclusive(
-                completion_response["usageMetadata"]
-            )
+            not VertexGeminiConfig.is_candidate_token_count_inclusive(usage_metadata)
             and reasoning_tokens
         ):
             completion_tokens = reasoning_tokens + completion_tokens
         ## GET USAGE ##
         usage = Usage(
-            prompt_tokens=completion_response["usageMetadata"].get(
-                "promptTokenCount", 0
-            ),
+            prompt_tokens=usage_metadata.get("promptTokenCount", 0),
             completion_tokens=completion_tokens,
-            total_tokens=completion_response["usageMetadata"].get("totalTokenCount", 0),
+            total_tokens=usage_metadata.get("totalTokenCount", 0),
             prompt_tokens_details=prompt_tokens_details,
             reasoning_tokens=reasoning_tokens,
             completion_tokens_details=response_tokens_details,
@@ -1169,7 +1166,9 @@ def transform_response(
             _candidates, model_response, logging_obj.optional_params
         )
 
-        usage = self._calculate_usage(completion_response=completion_response)
+        usage = VertexGeminiConfig._calculate_usage(
+            completion_response=completion_response
+        )
         setattr(model_response, "usage", usage)
 
         ## ADD METADATA TO RESPONSE ##
@@ -1806,21 +1805,8 @@ def chunk_parser(self, chunk: dict) -> "ModelResponseStream":
             ## GEMINI SETS FINISHREASON ON EVERY CHUNK!
 
             if "usageMetadata" in processed_chunk:
-                usage = Usage(
-                    prompt_tokens=processed_chunk["usageMetadata"].get(
-                        "promptTokenCount", 0
-                    ),
-                    completion_tokens=processed_chunk["usageMetadata"].get(
-                        "candidatesTokenCount", 0
-                    ),
-                    total_tokens=processed_chunk["usageMetadata"].get(
-                        "totalTokenCount", 0
-                    ),
-                    completion_tokens_details=CompletionTokensDetailsWrapper(
-                        reasoning_tokens=processed_chunk["usageMetadata"].get(
-                            "thoughtsTokenCount", 0
-                        )
-                    ),
+                usage = VertexGeminiConfig._calculate_usage(
+                    completion_response=processed_chunk,
                 )
 
             args: Dict[str, Any] = {
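
Because streaming chunks now go through the same _calculate_usage path as full responses, cached-content and thinking tokens are tracked for streamed Gemini calls too. Below is a standalone, hedged re-implementation of that arithmetic on a raw usageMetadata dict; the function name and return shape are illustrative, not litellm APIs, and the example cached count is made up.

from typing import Optional


def summarize_gemini_usage(usage_metadata: dict) -> dict:
    """Illustrative re-implementation of the shared usage math."""
    cached: Optional[int] = usage_metadata.get("cachedContentTokenCount")
    reasoning: Optional[int] = usage_metadata.get("thoughtsTokenCount")
    completion = usage_metadata.get("candidatesTokenCount", 0)
    # The real helper only adds thinking tokens when
    # is_candidate_token_count_inclusive() reports they are not already
    # counted; this sketch assumes they are not inclusive.
    if reasoning:
        completion += reasoning
    return {
        "prompt_tokens": usage_metadata.get("promptTokenCount", 0),
        "completion_tokens": completion,
        "total_tokens": usage_metadata.get("totalTokenCount", 0),
        "cached_tokens": cached,
        "reasoning_tokens": reasoning,
    }


# Example: a streaming chunk's usageMetadata reporting a prompt-cache hit.
print(summarize_gemini_usage({
    "promptTokenCount": 57,
    "cachedContentTokenCount": 32,  # illustrative cached-token count
    "candidatesTokenCount": 10,
    "totalTokenCount": 67,
}))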

litellm/types/llms/anthropic.py

Lines changed: 10 additions & 1 deletion
@@ -114,6 +114,11 @@ class AnthropicContentParamSource(TypedDict):
     data: str
 
 
+class AnthropicContentParamSourceUrl(TypedDict):
+    type: Literal["url"]
+    url: str
+
+
 class AnthropicContentParamSourceFileId(TypedDict):
     type: Literal["file"]
     file_id: str
@@ -140,7 +145,11 @@ class CitationsObject(TypedDict):
 class AnthropicMessagesDocumentParam(TypedDict, total=False):
     type: Required[Literal["document"]]
     source: Required[
-        Union[AnthropicContentParamSource, AnthropicContentParamSourceFileId]
+        Union[
+            AnthropicContentParamSource,
+            AnthropicContentParamSourceFileId,
+            AnthropicContentParamSourceUrl,
+        ]
     ]
     cache_control: Optional[Union[dict, ChatCompletionCachedContent]]
     title: str
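
As a quick usage sketch, the new TypedDict can be constructed like any other TypedDict; the import path is assumed from this file's location (litellm/types/llms/anthropic.py) and the URL is illustrative.

# Sketch: building a URL-sourced document block with the types added above.
from litellm.types.llms.anthropic import (
    AnthropicContentParamSourceUrl,
    AnthropicMessagesDocumentParam,
)

doc_block: AnthropicMessagesDocumentParam = {
    "type": "document",
    "source": AnthropicContentParamSourceUrl(
        type="url",
        url="https://example.com/report.pdf",  # illustrative URL
    ),
}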

tests/llm_translation/base_llm_unit_tests.py

Lines changed: 34 additions & 0 deletions
@@ -248,6 +248,40 @@ async def test_pdf_handling(self, pdf_messages, sync_mode):
         )
 
         assert response is not None
+
+    @pytest.mark.asyncio
+    async def test_async_pdf_handling_with_file_id(self):
+        from litellm.utils import supports_pdf_input
+        os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
+        litellm.model_cost = litellm.get_model_cost_map(url="")
+
+        litellm._turn_on_debug()
+
+
+        image_content = [
+            {"type": "text", "text": "What's this file about?"},
+            {
+                "type": "file",
+                "file": {
+                    "file_id": "https://upload.wikimedia.org/wikipedia/commons/2/20/Re_example.pdf"
+                },
+            },
+        ]
+
+        image_messages = [{"role": "user", "content": image_content}]
+
+        base_completion_call_args = self.get_base_completion_call_args()
+
+        if not supports_pdf_input(base_completion_call_args["model"], None):
+            pytest.skip("Model does not support image input")
+
+        response = await self.async_completion_function(
+            **base_completion_call_args,
+            messages=image_messages,
+        )
+
+        assert response is not None
+
 
     def test_file_data_unit_test(self, pdf_messages):
         from litellm.utils import supports_pdf_input, return_raw_request

tests/llm_translation/test_anthropic_completion.py

Lines changed: 1 addition & 0 deletions
@@ -1272,3 +1272,4 @@ def test_anthropic_text_editor():
         print(e)
 
     assert response is not None
+

tests/test_litellm/llms/vertex_ai/gemini/test_vertex_and_google_ai_studio_gemini.py

Lines changed: 65 additions & 1 deletion
@@ -309,7 +309,10 @@ def test_vertex_ai_candidate_token_count_inclusive(
     Test that the candidate token count is inclusive of the thinking token count
     """
     v = VertexGeminiConfig()
-    assert v.is_candidate_token_count_inclusive(usage_metadata) is inclusive
+    assert (
+        VertexGeminiConfig.is_candidate_token_count_inclusive(usage_metadata)
+        is inclusive
+    )
 
     usage = v._calculate_usage(completion_response={"usageMetadata": usage_metadata})
     assert usage.prompt_tokens == expected_usage.prompt_tokens
@@ -490,3 +493,64 @@ def test_vertex_ai_map_tool_with_anyof():
         "anyOf": [{"type": "string", "nullable": True, "title": "Base Branch"}]
     }, f"Expected only anyOf field and its contents to be kept, but got {tools[0]['function_declarations'][0]['parameters']['properties']['base_branch']}"
 
+
+def test_vertex_ai_streaming_usage_calculation():
+    """
+    Ensure streaming usage calculation uses same function as non-streaming usage calculation
+    """
+    from unittest.mock import patch
+
+    from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
+        ModelResponseIterator,
+        VertexGeminiConfig,
+    )
+
+    v = VertexGeminiConfig()
+    usage_metadata = {
+        "promptTokenCount": 57,
+        "candidatesTokenCount": 10,
+        "totalTokenCount": 67,
+    }
+
+    # Test streaming chunk parsing
+    with patch.object(VertexGeminiConfig, "_calculate_usage") as mock_calculate_usage:
+        # Create a streaming chunk
+        chunk = {
+            "candidates": [{"content": {"parts": [{"text": "Hello"}]}}],
+            "usageMetadata": usage_metadata,
+        }
+
+        # Create iterator and parse chunk
+        iterator = ModelResponseIterator(
+            streaming_response=[], sync_stream=True, logging_obj=MagicMock()
+        )
+        iterator.chunk_parser(chunk)
+
+        # Verify _calculate_usage was called with correct parameters
+        mock_calculate_usage.assert_called_once_with(completion_response=chunk)
+
+    # Test non-streaming response parsing
+    with patch.object(VertexGeminiConfig, "_calculate_usage") as mock_calculate_usage:
+        # Create a completion response
+        completion_response = {
+            "candidates": [{"content": {"parts": [{"text": "Hello"}]}}],
+            "usageMetadata": usage_metadata,
+        }
+
+        # Parse completion response
+        v.transform_response(
+            model="gemini-pro",
+            raw_response=MagicMock(json=lambda: completion_response),
+            model_response=ModelResponse(),
+            logging_obj=MagicMock(),
+            request_data={},
+            messages=[],
+            optional_params={},
+            litellm_params={},
+            encoding=None,
+        )
+
+        # Verify _calculate_usage was called with correct parameters
+        mock_calculate_usage.assert_called_once_with(
+            completion_response=completion_response,
+        )
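
End to end, the Gemini change means the usage attached to a streamed response is computed by the same function as in the non-streaming path, so cached-content token details come through on streams as well. A hedged sketch of how that might be observed from the public API follows; the model name and the use of stream_options are assumptions, not part of this commit.

# Sketch only: inspect cached prompt tokens on a streamed Gemini call.
import litellm

stream = litellm.completion(
    model="gemini/gemini-1.5-pro",  # assumed model
    messages=[{"role": "user", "content": "Summarize the cached context."}],
    stream=True,
    stream_options={"include_usage": True},
)
for chunk in stream:
    usage = getattr(chunk, "usage", None)
    if usage is not None and usage.prompt_tokens_details is not None:
        print("cached prompt tokens:", usage.prompt_tokens_details.cached_tokens)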
