
[Bug Fix] Add Cost Tracking for gpt-image-1 when quality is unspecified #10247


Merged · 3 commits · Apr 23, 2025
52 changes: 30 additions & 22 deletions litellm/cost_calculator.py
@@ -57,6 +57,7 @@
from litellm.responses.utils import ResponseAPILoggingUtils
from litellm.types.llms.openai import (
HttpxBinaryResponseContent,
ImageGenerationRequestQuality,
OpenAIRealtimeStreamList,
OpenAIRealtimeStreamResponseBaseObject,
OpenAIRealtimeStreamSessionEvents,
@@ -642,9 +643,9 @@ def completion_cost(  # noqa: PLR0915
or isinstance(completion_response, dict)
): # tts returns a custom class
if isinstance(completion_response, dict):
usage_obj: Optional[
Union[dict, Usage]
] = completion_response.get("usage", {})
usage_obj: Optional[Union[dict, Usage]] = (
completion_response.get("usage", {})
)
else:
usage_obj = getattr(completion_response, "usage", {})
if isinstance(usage_obj, BaseModel) and not _is_known_usage_objects(
@@ -913,7 +914,7 @@ def completion_cost(  # noqa: PLR0915


def get_response_cost_from_hidden_params(
hidden_params: Union[dict, BaseModel]
hidden_params: Union[dict, BaseModel],
) -> Optional[float]:
if isinstance(hidden_params, BaseModel):
_hidden_params_dict = hidden_params.model_dump()
@@ -1101,31 +1102,38 @@ def default_image_cost_calculator(
f"{quality}/{base_model_name}" if quality else base_model_name
)

    # gpt-image-1 models use low, medium, high quality. If the user did not specify quality, use medium for the gpt-image-1 model family
model_name_with_v2_quality = (
f"{ImageGenerationRequestQuality.MEDIUM.value}/{base_model_name}"
)

verbose_logger.debug(
f"Looking up cost for models: {model_name_with_quality}, {base_model_name}"
)

model_without_provider = f"{size_str}/{model.split('/')[-1]}"
model_with_quality_without_provider = (
f"{quality}/{model_without_provider}" if quality else model_without_provider
)

# Try model with quality first, fall back to base model name
if model_name_with_quality in litellm.model_cost:
cost_info = litellm.model_cost[model_name_with_quality]
elif base_model_name in litellm.model_cost:
cost_info = litellm.model_cost[base_model_name]
else:
# Try without provider prefix
model_without_provider = f"{size_str}/{model.split('/')[-1]}"
model_with_quality_without_provider = (
f"{quality}/{model_without_provider}" if quality else model_without_provider
cost_info: Optional[dict] = None
models_to_check = [
model_name_with_quality,
base_model_name,
model_name_with_v2_quality,
model_with_quality_without_provider,
model_without_provider,
]
for model in models_to_check:
if model in litellm.model_cost:
cost_info = litellm.model_cost[model]
break
if cost_info is None:
raise Exception(
f"Model not found in cost map. Tried checking {models_to_check}"
)

if model_with_quality_without_provider in litellm.model_cost:
cost_info = litellm.model_cost[model_with_quality_without_provider]
elif model_without_provider in litellm.model_cost:
cost_info = litellm.model_cost[model_without_provider]
else:
raise Exception(
f"Model not found in cost map. Tried {model_name_with_quality}, {base_model_name}, {model_with_quality_without_provider}, and {model_without_provider}"
)

return cost_info["input_cost_per_pixel"] * height * width * n
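
For context, the new fallback order can be sketched outside the diff. The snippet below is a minimal, self-contained approximation of the lookup chain added above; the cost-map entry and the "1024-x-1024" size-key format are illustrative assumptions, since the real prices live in litellm.model_cost.

# Minimal sketch of the fallback order introduced in this PR. The cost-map
# entry and the size-key format are illustrative assumptions; the real values
# come from litellm.model_cost.
from typing import Optional

MOCK_MODEL_COST = {
    "medium/1024-x-1024/gpt-image-1": {"input_cost_per_pixel": 1.1e-08},
}

def lookup_image_cost_info(model: str, size_str: str, quality: Optional[str]) -> dict:
    base_model_name = f"{size_str}/{model}"
    model_name_with_quality = (
        f"{quality}/{base_model_name}" if quality else base_model_name
    )
    # New in this PR: if quality is unspecified, also try the medium-quality
    # key used by the gpt-image-1 family (low/medium/high).
    model_name_with_v2_quality = f"medium/{base_model_name}"
    model_without_provider = f"{size_str}/{model.split('/')[-1]}"
    model_with_quality_without_provider = (
        f"{quality}/{model_without_provider}" if quality else model_without_provider
    )
    models_to_check = [
        model_name_with_quality,
        base_model_name,
        model_name_with_v2_quality,
        model_with_quality_without_provider,
        model_without_provider,
    ]
    for candidate in models_to_check:
        if candidate in MOCK_MODEL_COST:
            return MOCK_MODEL_COST[candidate]
    raise Exception(f"Model not found in cost map. Tried checking {models_to_check}")

# gpt-image-1 with no explicit quality now resolves via the medium-quality key
# instead of raising:
print(lookup_image_cost_info("gpt-image-1", "1024-x-1024", quality=None))

Because the quality-prefixed key is still checked first, requests that do pass an explicit quality keep resolving exactly as before; only the unspecified-quality case gains the new medium fallback.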


27 changes: 14 additions & 13 deletions litellm/main.py
@@ -182,6 +182,7 @@
ChatCompletionPredictionContentParam,
ChatCompletionUserMessage,
HttpxBinaryResponseContent,
ImageGenerationRequestQuality,
)
from .types.utils import (
LITELLM_IMAGE_VARIATION_PROVIDERS,
@@ -2688,9 +2689,9 @@ def completion(  # type: ignore # noqa: PLR0915
"aws_region_name" not in optional_params
or optional_params["aws_region_name"] is None
):
optional_params[
"aws_region_name"
] = aws_bedrock_client.meta.region_name
optional_params["aws_region_name"] = (
aws_bedrock_client.meta.region_name
)

bedrock_route = BedrockModelInfo.get_bedrock_route(model)
if bedrock_route == "converse":
@@ -4412,9 +4413,9 @@ def adapter_completion(
new_kwargs = translation_obj.translate_completion_input_params(kwargs=kwargs)

response: Union[ModelResponse, CustomStreamWrapper] = completion(**new_kwargs) # type: ignore
translated_response: Optional[
Union[BaseModel, AdapterCompletionStreamWrapper]
] = None
translated_response: Optional[Union[BaseModel, AdapterCompletionStreamWrapper]] = (
None
)
if isinstance(response, ModelResponse):
translated_response = translation_obj.translate_completion_output_params(
response=response
@@ -4567,7 +4568,7 @@ def image_generation(  # noqa: PLR0915
prompt: str,
model: Optional[str] = None,
n: Optional[int] = None,
quality: Optional[str] = None,
quality: Optional[Union[str, ImageGenerationRequestQuality]] = None,
response_format: Optional[str] = None,
size: Optional[str] = None,
style: Optional[str] = None,
@@ -5834,9 +5835,9 @@ def stream_chunk_builder(  # noqa: PLR0915
]

if len(content_chunks) > 0:
response["choices"][0]["message"][
"content"
] = processor.get_combined_content(content_chunks)
response["choices"][0]["message"]["content"] = (
processor.get_combined_content(content_chunks)
)

reasoning_chunks = [
chunk
@@ -5847,9 +5848,9 @@
]

if len(reasoning_chunks) > 0:
response["choices"][0]["message"][
"reasoning_content"
] = processor.get_combined_reasoning_content(reasoning_chunks)
response["choices"][0]["message"]["reasoning_content"] = (
processor.get_combined_reasoning_content(reasoning_chunks)
)

audio_chunks = [
chunk
45 changes: 27 additions & 18 deletions litellm/types/llms/openai.py
@@ -824,12 +824,12 @@ def __init__(self, **kwargs):

class Hyperparameters(BaseModel):
batch_size: Optional[Union[str, int]] = None # "Number of examples in each batch."
learning_rate_multiplier: Optional[
Union[str, float]
] = None # Scaling factor for the learning rate
n_epochs: Optional[
Union[str, int]
] = None # "The number of epochs to train the model for"
learning_rate_multiplier: Optional[Union[str, float]] = (
None # Scaling factor for the learning rate
)
n_epochs: Optional[Union[str, int]] = (
None # "The number of epochs to train the model for"
)


class FineTuningJobCreate(BaseModel):
@@ -856,18 +856,18 @@ class FineTuningJobCreate(BaseModel):

model: str # "The name of the model to fine-tune."
training_file: str # "The ID of an uploaded file that contains training data."
hyperparameters: Optional[
Hyperparameters
] = None # "The hyperparameters used for the fine-tuning job."
suffix: Optional[
str
] = None # "A string of up to 18 characters that will be added to your fine-tuned model name."
validation_file: Optional[
str
] = None # "The ID of an uploaded file that contains validation data."
integrations: Optional[
List[str]
] = None # "A list of integrations to enable for your fine-tuning job."
hyperparameters: Optional[Hyperparameters] = (
None # "The hyperparameters used for the fine-tuning job."
)
suffix: Optional[str] = (
None # "A string of up to 18 characters that will be added to your fine-tuned model name."
)
validation_file: Optional[str] = (
None # "The ID of an uploaded file that contains validation data."
)
integrations: Optional[List[str]] = (
None # "A list of integrations to enable for your fine-tuning job."
)
seed: Optional[int] = None # "The seed controls the reproducibility of the job."


@@ -1259,3 +1259,12 @@ class OpenAIRealtimeStreamResponseBaseObject(TypedDict):
OpenAIRealtimeStreamList = List[
Union[OpenAIRealtimeStreamResponseBaseObject, OpenAIRealtimeStreamSessionEvents]
]


class ImageGenerationRequestQuality(str, Enum):
LOW = "low"
MEDIUM = "medium"
HIGH = "high"
AUTO = "auto"
STANDARD = "standard"
HD = "hd"
6 changes: 3 additions & 3 deletions tests/image_gen_tests/base_image_generation_test.py
@@ -66,8 +66,8 @@ async def test_basic_image_generation(self):
logged_standard_logging_payload = custom_logger.standard_logging_payload
print("logged_standard_logging_payload", logged_standard_logging_payload)
assert logged_standard_logging_payload is not None
# assert logged_standard_logging_payload["response_cost"] is not None
# assert logged_standard_logging_payload["response_cost"] > 0
assert logged_standard_logging_payload["response_cost"] is not None
assert logged_standard_logging_payload["response_cost"] > 0

from openai.types.images_response import ImagesResponse

@@ -85,4 +85,4 @@
if "Your task failed as a result of our safety system." in str(e):
pass
else:
pytest.fail(f"An exception occurred - {str(e)}")
pytest.fail(f"An exception occurred - {str(e)}")
3 changes: 3 additions & 0 deletions tests/image_gen_tests/test_image_generation.py
@@ -161,6 +161,9 @@ class TestOpenAIDalle3(BaseImageGenTest):
def get_base_image_generation_call_args(self) -> dict:
return {"model": "dall-e-3"}

class TestOpenAIGPTImage1(BaseImageGenTest):
def get_base_image_generation_call_args(self) -> dict:
return {"model": "gpt-image-1"}

class TestAzureOpenAIDalle3(BaseImageGenTest):
def get_base_image_generation_call_args(self) -> dict: