
Commit ec89f7d

Handle gemini audio input (#10739)
* fix(vertex_ai/gemini/transformation.py): handle gemini audio data translation (Fixes #10070)
* feat(vertex_ai/gemini/transformation.py): handle audio format param translation (Fixes #10070)
* fix: fix linting error
* test: update test
* fix: fix linting error
1 parent 0d6efff commit ec89f7d
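
In OpenAI-compatible requests, audio arrives as an input_audio content part, while Gemini's generateContent API takes inline data with a full MIME type. A hedged before/after sketch of the translation this commit wires up (the Gemini-side key names are an assumption based on Google's REST API, not taken from this diff):

# OpenAI-style content part accepted by litellm.completion:
openai_audio_part = {
    "type": "input_audio",
    "input_audio": {"data": "<base64-encoded wav>", "format": "wav"},
}

# Roughly what the Gemini transformation should emit for it after this
# commit (key names assumed from Google's generateContent REST API):
gemini_audio_part = {
    "inline_data": {"mime_type": "audio/wav", "data": "<base64-encoded wav>"},
}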

File tree

3 files changed: +79 -22 lines


litellm/llms/vertex_ai/gemini/transformation.py

21 additions & 3 deletions

@@ -16,6 +16,7 @@
     _get_image_mime_type_from_url,
 )
 from litellm.litellm_core_utils.prompt_templates.factory import (
+    convert_generic_image_chunk_to_openai_image_obj,
     convert_to_anthropic_image_obj,
     convert_to_gemini_tool_call_invoke,
     convert_to_gemini_tool_call_result,
@@ -45,6 +46,7 @@
     ToolConfig,
     Tools,
 )
+from litellm.types.utils import GenericImageParsingChunk
 
 from ..common_utils import (
     _check_text_in_content,
@@ -154,10 +156,26 @@ def _gemini_convert_messages_with_history(  # noqa: PLR0915
                     _parts.append(_part)
                 elif element["type"] == "input_audio":
                     audio_element = cast(ChatCompletionAudioObject, element)
-                    if audio_element["input_audio"].get("data") is not None:
+                    audio_data = audio_element["input_audio"].get("data")
+                    audio_format = audio_element["input_audio"].get("format")
+                    if audio_data is not None and audio_format is not None:
+                        audio_format_modified = (
+                            "audio/" + audio_format
+                            if audio_format.startswith("audio/") is False
+                            else audio_format
+                        )  # Gemini expects audio/wav, audio/mp3, etc.
+                        openai_image_str = (
+                            convert_generic_image_chunk_to_openai_image_obj(
+                                image_chunk=GenericImageParsingChunk(
+                                    type="base64",
+                                    media_type=audio_format_modified,
+                                    data=audio_data,
+                                )
+                            )
+                        )
                         _part = _process_gemini_image(
-                            image_url=audio_element["input_audio"]["data"],
-                            format=audio_element["input_audio"].get("format"),
+                            image_url=openai_image_str,
+                            format=audio_format_modified,
                         )
                         _parts.append(_part)
                 elif element["type"] == "file":

litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py

14 additions & 8 deletions

@@ -398,6 +398,19 @@ def _map_thinking_param(
 
         return params
 
+    def map_response_modalities(self, value: list) -> list:
+        response_modalities = []
+        for modality in value:
+            if modality == "text":
+                response_modalities.append("TEXT")
+            elif modality == "image":
+                response_modalities.append("IMAGE")
+            elif modality == "audio":
+                response_modalities.append("AUDIO")
+            else:
+                response_modalities.append("MODALITY_UNSPECIFIED")
+        return response_modalities
+
     def map_openai_params(
         self,
         non_default_params: Dict,
@@ -465,14 +478,7 @@ def map_openai_params(
                     cast(AnthropicThinkingParam, value)
                 )
             elif param == "modalities" and isinstance(value, list):
-                response_modalities = []
-                for modality in value:
-                    if modality == "text":
-                        response_modalities.append("TEXT")
-                    elif modality == "image":
-                        response_modalities.append("IMAGE")
-                    else:
-                        response_modalities.append("MODALITY_UNSPECIFIED")
+                response_modalities = self.map_response_modalities(value)
                 optional_params["responseModalities"] = response_modalities
 
         if litellm.vertex_ai_safety_settings is not None:
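
This refactor moves the inline modality mapping out of map_openai_params into a reusable map_response_modalities helper, and in doing so adds the previously missing "audio" case. A table-driven sketch equivalent to the new helper's behavior:

def map_response_modalities(value: list) -> list:
    # OpenAI-style lowercase modalities map to Gemini's uppercase
    # responseModalities enum; anything unrecognized falls back to
    # MODALITY_UNSPECIFIED.
    mapping = {"text": "TEXT", "image": "IMAGE", "audio": "AUDIO"}
    return [mapping.get(modality, "MODALITY_UNSPECIFIED") for modality in value]


assert map_response_modalities(["text", "audio"]) == ["TEXT", "AUDIO"]
assert map_response_modalities(["video"]) == ["MODALITY_UNSPECIFIED"]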

tests/llm_translation/base_llm_unit_tests.py

44 additions & 11 deletions

@@ -525,6 +525,45 @@ class EventsList(BaseModel):
         except litellm.InternalServerError:
             pytest.skip("Model is overloaded")
 
+    @pytest.mark.flaky(retries=6, delay=1)
+    def test_audio_input(self):
+        """
+        Test that audio input is supported by the LLM API
+        """
+        from litellm.utils import supports_audio_input
+        litellm._turn_on_debug()
+        base_completion_call_args = self.get_base_completion_call_args()
+        if not supports_audio_input(base_completion_call_args["model"], None):
+            pytest.skip(
+                f"Model={base_completion_call_args['model']} does not support audio input"
+            )
+
+        url = "https://openaiassets.blob.core.windows.net/$web/API/docs/audio/alloy.wav"
+        response = httpx.get(url)
+        response.raise_for_status()
+        wav_data = response.content
+        encoded_string = base64.b64encode(wav_data).decode("utf-8")
+
+        completion = self.completion_function(
+            **base_completion_call_args,
+            messages=[
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": "What is in this recording?"},
+                        {
+                            "type": "input_audio",
+                            "input_audio": {"data": encoded_string, "format": "wav"},
+                        },
+                    ],
+                },
+            ],
+        )
+
+        print(completion.choices[0].message)
+
+
+
     @pytest.mark.flaky(retries=6, delay=1)
     def test_json_response_format_stream(self):
         """
@@ -979,7 +1018,7 @@ async def test_completion_cost(self):
         assert response._hidden_params["response_cost"] > 0
 
     @pytest.mark.parametrize("input_type", ["input_audio", "audio_url"])
-    @pytest.mark.parametrize("format_specified", [True, False])
+    @pytest.mark.parametrize("format_specified", [True])
     def test_supports_audio_input(self, input_type, format_specified):
         from litellm.utils import return_raw_request, supports_audio_input
         from litellm.types.utils import CallTypes
@@ -1010,16 +1049,10 @@ def test_supports_audio_input(self, input_type, format_specified):
         test_file_id = "gs://bucket/file.wav"
 
         if input_type == "input_audio":
-            if format_specified:
-                audio_content.append({
-                    "type": "input_audio",
-                    "input_audio": {"data": encoded_string, "format": audio_format},
-                })
-            else:
-                audio_content.append({
-                    "type": "input_audio",
-                    "input_audio": {"data": encoded_string},
-                })
+            audio_content.append({
+                "type": "input_audio",
+                "input_audio": {"data": encoded_string, "format": audio_format},
+            })
         elif input_type == "audio_url":
             audio_content.append(
                 {
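
For reference, here is the same flow the new test_audio_input exercises, as a standalone sketch (the model name is illustrative; any model for which supports_audio_input returns True should work):

import base64

import httpx
import litellm

# Sample WAV used by the test suite.
url = "https://openaiassets.blob.core.windows.net/$web/API/docs/audio/alloy.wav"
encoded_string = base64.b64encode(httpx.get(url).content).decode("utf-8")

response = litellm.completion(
    model="gemini/gemini-2.0-flash",  # illustrative model name
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is in this recording?"},
                {
                    "type": "input_audio",
                    "input_audio": {"data": encoded_string, "format": "wav"},
                },
            ],
        }
    ],
)
print(response.choices[0].message)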
