
Litellm batch api background cost calc #12125


Merged
merged 4 commits on Jun 28, 2025
@@ -2,6 +2,8 @@
Polls LiteLLM_ManagedObjectTable to check if the batch job is complete, and if the cost has been tracked.
"""

import uuid
from datetime import datetime
from typing import TYPE_CHECKING, Optional, cast

from litellm._logging import verbose_proxy_logger
@@ -42,6 +44,7 @@ async def check_batch_cost(self):
calculate_batch_cost_and_usage,
)
from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
from litellm.proxy.openai_files_endpoints.common_utils import (
_is_base64_encoded_unified_file_id,
get_batch_id_from_unified_batch_id,
@@ -55,6 +58,8 @@ async def check_batch_cost(self):
}
)

completed_jobs = []

for job in jobs:
# get the model from the job
unified_object_id = job.unified_object_id
@@ -81,6 +86,10 @@
response = await self.llm_router.aretrieve_batch(
model=model_id,
batch_id=batch_id,
litellm_metadata={
"user_api_key_user_id": job.created_by or "default-user-id",
"batch_ignore_default_logging": True,
},
)

## RETRIEVE THE BATCH JOB OUTPUT FILE
@@ -129,6 +138,38 @@ async def check_batch_cost(self):
)
)

if response.status != "validating":
# mark for updating
pass
logging_obj = LiteLLMLogging(
model=batch_models[0],
messages=[{"role": "user", "content": "<retrieve_batch>"}],
stream=False,
call_type="aretrieve_batch",
start_time=datetime.now(),
litellm_call_id=str(uuid.uuid4()),
function_id=str(uuid.uuid4()),
)

logging_obj.update_environment_variables(
litellm_params={
"metadata": {
"user_api_key_user_id": job.created_by or "default-user-id",
}
},
optional_params={},
)

await logging_obj.async_success_handler(
result=response,
batch_cost=batch_cost,
batch_usage=batch_usage,
batch_models=batch_models,
)

# mark the job as complete
completed_jobs.append(job)

if len(completed_jobs) > 0:
# mark the jobs as complete
await self.prisma_client.db.litellm_managedobjecttable.update_many(
where={"id": {"in": [job.id for job in completed_jobs]}},
data={"status": "complete"},
)
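Taken together, this first diff (apparently the enterprise check_batch_cost module, judging by the CheckBatchCost import added to proxy_server.py further down) makes the background job do three things per cycle: retrieve each tracked batch with a metadata flag that keeps the default logging path quiet, build a one-off Logging object and emit a single success log carrying the pre-computed batch cost, usage, and models, and finally flip the processed rows to "complete" in one update_many call. Below is a self-contained sketch of that flow for orientation only; ManagedBatchJob, BatchStatus, and FakeRouter are illustrative stand-ins for the Prisma rows and the router calls shown above, not the merged code.

```python
import asyncio
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ManagedBatchJob:
    """Stand-in for a LiteLLM_ManagedObjectTable row tracked by the job."""
    id: str
    batch_id: str
    model_id: str
    created_by: Optional[str] = None
    status: str = "pending"


@dataclass
class BatchStatus:
    status: str            # e.g. "validating", "in_progress", "completed"
    response_cost: float


class FakeRouter:
    """Stand-in for llm_router; always reports the batch as completed."""

    async def aretrieve_batch(
        self, model: str, batch_id: str, litellm_metadata: dict
    ) -> BatchStatus:
        return BatchStatus(status="completed", response_cost=0.42)


async def check_batch_cost(
    jobs: List[ManagedBatchJob], router: FakeRouter
) -> List[ManagedBatchJob]:
    completed: List[ManagedBatchJob] = []
    for job in jobs:
        response = await router.aretrieve_batch(
            model=job.model_id,
            batch_id=job.batch_id,
            litellm_metadata={
                # Attribute spend to whoever created the batch.
                "user_api_key_user_id": job.created_by or "default-user-id",
                # Polling runs frequently; silence the default per-call log and
                # emit one explicit success log with the computed cost instead.
                "batch_ignore_default_logging": True,
            },
        )
        if response.status == "validating":
            continue  # not finished yet; check again on the next interval
        # The real job builds a Logging object here and calls
        # async_success_handler with batch_cost / batch_usage / batch_models.
        print(f"batch {job.batch_id}: logged ${response.response_cost:.2f} once")
        completed.append(job)
    # The real job runs a single update_many() to flip these rows to "complete".
    for job in completed:
        job.status = "complete"
    return completed


if __name__ == "__main__":
    pending = [ManagedBatchJob(id="1", batch_id="batch_abc", model_id="gpt-4o")]
    asyncio.run(check_batch_cost(pending, FakeRouter()))
```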
25 changes: 20 additions & 5 deletions litellm/litellm_core_utils/litellm_logging.py
@@ -152,6 +152,7 @@
from .specialty_caches.dynamic_logging_cache import DynamicLoggingCache

if TYPE_CHECKING:
from litellm.llms.base_llm.files.transformation import BaseFileEndpoints
from litellm.llms.base_llm.passthrough.transformation import BasePassthroughConfig
try:
from litellm_enterprise.enterprise_callbacks.callback_controls import (
@@ -1997,18 +1998,32 @@ async def async_success_handler( # noqa: PLR0915
return

## CALCULATE COST FOR BATCH JOBS
if (
self.call_type == CallTypes.aretrieve_batch.value
and isinstance(result, LiteLLMBatch)
and result.status == "completed"
if self.call_type == CallTypes.aretrieve_batch.value and isinstance(
result, LiteLLMBatch
):
litellm_params = self.litellm_params or {}
litellm_metadata = litellm_params.get("litellm_metadata", {})
if (
litellm_metadata.get("batch_ignore_default_logging", False) is True
): # polling job will query these frequently, don't spam db logs
return

from litellm.proxy.openai_files_endpoints.common_utils import (
_is_base64_encoded_unified_file_id,
)

# check if file id is a unified file id
is_base64_unified_file_id = _is_base64_encoded_unified_file_id(result.id)
if not is_base64_unified_file_id: # only run for non-unified file ids

batch_cost = kwargs.get("batch_cost", None)
batch_usage = kwargs.get("batch_usage", None)
batch_models = kwargs.get("batch_models", None)
if all([batch_cost, batch_usage, batch_models]) is not None:
result._hidden_params["response_cost"] = batch_cost
result._hidden_params["batch_models"] = batch_models
result.usage = batch_usage

elif not is_base64_unified_file_id: # only run for non-unified file ids
response_cost, batch_usage, batch_models = (
await _handle_completed_batch(
batch=result, custom_llm_provider=self.custom_llm_provider
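On the logging side, the change above boils down to a three-way decision inside async_success_handler for aretrieve_batch results. The sketch below restates that decision order in a stripped-down, self-contained form: plain dicts replace the real LiteLLMBatch and Logging objects, the unified-file-id handling is omitted, the kwargs check is written as explicit `is not None` tests for clarity, and compute_cost_from_output_file is a placeholder for _handle_completed_batch. It is an orientation aid, not the merged implementation.

```python
from typing import Any, Dict, List, Optional


def handle_retrieved_batch(
    result: Dict[str, Any],
    litellm_metadata: Dict[str, Any],
    batch_cost: Optional[float] = None,
    batch_usage: Optional[Dict[str, int]] = None,
    batch_models: Optional[List[str]] = None,
) -> Optional[Dict[str, Any]]:
    """Simplified restatement of the aretrieve_batch branch shown above."""
    # 1) The polling job marks its own retrievals; skip default logging so the
    #    hourly poll does not flood the spend logs with retrieve_batch entries.
    if litellm_metadata.get("batch_ignore_default_logging", False) is True:
        return None

    hidden_params: Dict[str, Any] = {}

    # 2) If the caller already computed cost, usage, and models (as the new
    #    background job does), attach them directly instead of re-deriving.
    if batch_cost is not None and batch_usage is not None and batch_models is not None:
        hidden_params["response_cost"] = batch_cost
        hidden_params["batch_models"] = batch_models
        result["usage"] = batch_usage

    # 3) Otherwise fall back to computing cost/usage from the batch output
    #    file, as _handle_completed_batch does in the existing path.
    else:
        hidden_params["response_cost"] = compute_cost_from_output_file(result)

    result["_hidden_params"] = hidden_params
    return result


def compute_cost_from_output_file(result: Dict[str, Any]) -> float:
    # Placeholder for the real _handle_completed_batch() calculation.
    return 0.0


if __name__ == "__main__":
    # The polling path: nothing is logged.
    print(handle_retrieved_batch({"id": "batch_abc"}, {"batch_ignore_default_logging": True}))
    # The pre-computed path: cost is attached as-is.
    print(
        handle_retrieved_batch(
            {"id": "batch_abc"},
            {},
            batch_cost=0.42,
            batch_usage={"total_tokens": 1200},
            batch_models=["gpt-4o"],
        )
    )
```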
6 changes: 5 additions & 1 deletion litellm/proxy/_new_secret_config.yaml
@@ -18,4 +18,8 @@ model_list:
model: gpt-4o
api_key: os.environ/OPENAI_API_KEY_TEST_2
model_info:
id: 12345679
id: 12345679


general_settings:
check_managed_files_batch_cost: true
23 changes: 23 additions & 0 deletions litellm/proxy/proxy_server.py
@@ -3440,6 +3440,29 @@ async def initialize_scheduled_background_jobs(
verbose_proxy_logger.error(
"Invalid maximum_spend_logs_retention_interval value"
)
### CHECK BATCH COST ###
if llm_router is not None:
try:
from litellm_enterprise.proxy.common_utils.check_batch_cost import (
CheckBatchCost,
)

check_batch_cost_job = CheckBatchCost(
proxy_logging_obj=proxy_logging_obj,
prisma_client=prisma_client,
llm_router=llm_router,
)
scheduler.add_job(
check_batch_cost_job.check_batch_cost,
"interval",
seconds=3600, # these can run infrequently, as batch jobs take time to complete
)

except Exception:
verbose_proxy_logger.debug(
"Checking batch cost for LiteLLM Managed Files is an Enterprise Feature. Skipping..."
)
pass

scheduler.start()

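Finally, the proxy wires the check into its background scheduler as an hourly interval job and quietly skips it when the enterprise package cannot be imported. The snippet below shows the same registration pattern in isolation, assuming an APScheduler-style AsyncIOScheduler (which is what the scheduler.add_job call above suggests); run_check is a placeholder for check_batch_cost_job.check_batch_cost and is not part of the PR.

```python
import asyncio

from apscheduler.schedulers.asyncio import AsyncIOScheduler


async def run_check() -> None:
    # Placeholder for check_batch_cost_job.check_batch_cost in the real proxy.
    print("polling managed batch jobs for completed batches...")


async def main() -> None:
    scheduler = AsyncIOScheduler()
    # The proxy registers the job with seconds=3600: provider batch jobs take
    # a long time to finish, so hourly polling is plenty. A short interval is
    # used here only so the demo fires within a few seconds.
    scheduler.add_job(run_check, "interval", seconds=2)
    scheduler.start()
    await asyncio.sleep(5)
    scheduler.shutdown()


if __name__ == "__main__":
    asyncio.run(main())
```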