
Commit 8fcda2f

leseb and claude authored
fix(openai): clamp max_tokens to per-model limits to prevent overflow errors (#5696)
# What does this PR do?

Fixes `BadRequestError: max_tokens is too large` when clients (e.g. Claude Code) send `max_tokens` values that exceed what the target OpenAI model supports. For example, Claude Code requests `max_tokens: 32000`, but `gpt-4o-mini` only supports 16384.

Adds a static per-model `max_output_tokens` map to the OpenAI provider adapter and clamps incoming `max_tokens` at request time. Supports prefix matching for dated snapshot variants (e.g. `gpt-4o-2024-08-06` inherits from `gpt-4o`). Logs a warning once per unknown model so operators know the map needs updating when new models are released.

Also populates `max_output_tokens` in model metadata via `construct_model_from_identifier()`, exposing it through the `/v1/models` endpoint's `custom_metadata` field.

## Test Plan

```bash
uv run pytest tests/unit/providers/inference/test_remote_openai.py -v --tb=short
```

Output:

```
tests/unit/providers/inference/test_remote_openai.py::TestOpenAIMaxTokensClamping::test_clamps_when_request_exceeds_model_limit PASSED
tests/unit/providers/inference/test_remote_openai.py::TestOpenAIMaxTokensClamping::test_keeps_lower_request_value PASSED
tests/unit/providers/inference/test_remote_openai.py::TestOpenAIMaxTokensClamping::test_no_clamping_when_max_tokens_is_none PASSED
tests/unit/providers/inference/test_remote_openai.py::TestOpenAIMaxTokensClamping::test_does_not_mutate_original_params PASSED
tests/unit/providers/inference/test_remote_openai.py::TestOpenAIMaxTokensClamping::test_different_models_have_different_limits PASSED
tests/unit/providers/inference/test_remote_openai.py::TestOpenAIMaxTokensClamping::test_no_clamping_for_unknown_model PASSED
tests/unit/providers/inference/test_remote_openai.py::TestOpenAIMaxTokensClamping::test_dated_snapshot_model_uses_base_limit PASSED
tests/unit/providers/inference/test_remote_openai.py::TestOpenAIModelMetadata::test_construct_model_includes_max_output_tokens PASSED
tests/unit/providers/inference/test_remote_openai.py::TestOpenAIModelMetadata::test_construct_model_unknown_has_no_max_output_tokens PASSED
tests/unit/providers/inference/test_remote_openai.py::TestOpenAIModelMetadata::test_construct_model_embedding_unchanged PASSED
tests/unit/providers/inference/test_remote_openai.py::TestOpenAIMaxOutputTokensWarning::test_warns_once_for_unknown_model PASSED
tests/unit/providers/inference/test_remote_openai.py::TestOpenAIMaxOutputTokensWarning::test_all_known_models_have_limits PASSED

12 passed in 0.12s
```

---------

Signed-off-by: Sébastien Han <seb@redhat.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
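In effect, for models in the map the adapter now forwards `min(requested, limit)` instead of the raw request. A minimal sketch of that rule, reusing only the numbers quoted above:

```python
# Sketch of the clamping rule described above, using the gpt-4o-mini numbers
# from the commit message; the real adapter copies the request before mutating it.
requested = 32000          # e.g. what Claude Code sends as max_tokens
model_limit = 16384        # gpt-4o-mini's max output tokens
forwarded = min(requested, model_limit)
assert forwarded == 16384  # upstream OpenAI no longer rejects the request
```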
1 parent c1f57f6 commit 8fcda2f

2 files changed

Lines changed: 365 additions & 0 deletions

File tree

src/ogx/providers/remote/inference/openai/openai.py
tests/unit/providers/inference/test_remote_openai.py

src/ogx/providers/remote/inference/openai/openai.py

Lines changed: 92 additions & 0 deletions
```diff
@@ -4,13 +4,43 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+from collections.abc import AsyncIterator
+
 from ogx.log import get_logger
 from ogx.providers.utils.inference.openai_mixin import OpenAIMixin
+from ogx_api import (
+    Model,
+    ModelType,
+    OpenAIChatCompletion,
+    OpenAIChatCompletionChunk,
+    OpenAIChatCompletionRequestWithExtraBody,
+)

 from .config import OpenAIConfig

 logger = get_logger(name=__name__, category="inference::openai")

+# Max output tokens per OpenAI model. OpenAI's /v1/models endpoint does not
+# expose this, so we maintain the mapping statically.
+_MODEL_MAX_OUTPUT_TOKENS: dict[str, int] = {
+    "gpt-4.1": 32768,
+    "gpt-4.1-mini": 32768,
+    "gpt-4.1-nano": 32768,
+    "gpt-4o": 16384,
+    "gpt-4o-mini": 16384,
+    "gpt-4-turbo": 4096,
+    "gpt-4": 8192,
+    "o1": 100000,
+    "o1-mini": 65536,
+    "o1-pro": 100000,
+    "o3": 100000,
+    "o3-mini": 100000,
+    "o3-pro": 100000,
+    "o4-mini": 100000,
+}
+
+_WARNED_MODELS: set[str] = set()
+

 #
 # This OpenAI adapter implements Inference methods using OpenAIMixin
@@ -31,6 +61,68 @@ class OpenAIInferenceAdapter(OpenAIMixin):
         "text-embedding-3-large": {"embedding_dimension": 3072, "context_length": 8192},
     }

+    def _get_max_output_tokens(self, model: str) -> int | None:
+        if model in _MODEL_MAX_OUTPUT_TOKENS:
+            return _MODEL_MAX_OUTPUT_TOKENS[model]
+
+        # Try prefix matching for dated snapshot variants (e.g. gpt-4o-2024-08-06)
+        for base_model, limit in sorted(
+            _MODEL_MAX_OUTPUT_TOKENS.items(),
+            key=lambda item: len(item[0]),
+            reverse=True,
+        ):
+            if model.startswith(f"{base_model}-"):
+                return limit
+
+        if model not in _WARNED_MODELS:
+            _WARNED_MODELS.add(model)
+            logger.warning(
+                "Unknown max_output_tokens for model, requests will not be clamped",
+                model=model,
+            )
+        return None
+
+    def construct_model_from_identifier(self, identifier: str) -> Model:
+        if metadata := self.embedding_model_metadata.get(identifier):
+            return Model(
+                provider_id=self.__provider_id__,  # type: ignore[attr-defined]
+                provider_resource_id=identifier,
+                identifier=identifier,
+                model_type=ModelType.embedding,
+                metadata=metadata,
+            )
+
+        metadata = {}
+        max_output_tokens = self._get_max_output_tokens(identifier)
+        if max_output_tokens is not None:
+            metadata["max_output_tokens"] = max_output_tokens
+
+        return Model(
+            provider_id=self.__provider_id__,  # type: ignore[attr-defined]
+            provider_resource_id=identifier,
+            identifier=identifier,
+            model_type=ModelType.llm,
+            metadata=metadata,
+        )
+
+    async def openai_chat_completion(
+        self,
+        params: OpenAIChatCompletionRequestWithExtraBody,
+    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
+        max_output_tokens = self._get_max_output_tokens(params.model)
+        if max_output_tokens is not None:
+            updated_params = params
+            if params.max_tokens is not None and params.max_tokens > max_output_tokens:
+                updated_params = updated_params.model_copy()
+                updated_params.max_tokens = max_output_tokens
+            if params.max_completion_tokens is not None and params.max_completion_tokens > max_output_tokens:
+                if updated_params is params:
+                    updated_params = updated_params.model_copy()
+                updated_params.max_completion_tokens = max_output_tokens
+            params = updated_params
+
+        return await super().openai_chat_completion(params)
+
     async def openai_chat_completions_with_reasoning(self, params) -> None:
         raise ValueError(
             "OpenAI provider does not support reasoning. "
```
tests/unit/providers/inference/test_remote_openai.py

Lines changed: 273 additions & 0 deletions (new file)
```python
# Copyright (c) The OGX Contributors.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch

import pytest

from ogx.providers.remote.inference.openai.config import OpenAIConfig
from ogx.providers.remote.inference.openai.openai import (
    _MODEL_MAX_OUTPUT_TOKENS,
    _WARNED_MODELS,
    OpenAIInferenceAdapter,
)
from ogx_api import (
    OpenAIChatCompletion,
    OpenAIChatCompletionRequestWithExtraBody,
    OpenAIChatCompletionResponseMessage,
    OpenAIChoice,
)


@pytest.fixture
def mock_openai_response():
    return OpenAIChatCompletion(
        id="chatcmpl-abc123",
        created=1,
        model="gpt-4o-mini",
        choices=[
            OpenAIChoice(
                message=OpenAIChatCompletionResponseMessage(content="hello"),
                finish_reason="stop",
                index=0,
            )
        ],
    )


@pytest.fixture(autouse=True)
def _clear_warned_models():
    _WARNED_MODELS.clear()
    yield
    _WARNED_MODELS.clear()


def _make_adapter():
    config = OpenAIConfig(api_key="fake-key")
    adapter = OpenAIInferenceAdapter(config=config)
    adapter.model_store = AsyncMock()
    return adapter


class TestOpenAIMaxTokensClamping:
    async def test_clamps_when_request_exceeds_model_limit(self, mock_openai_response):
        adapter = _make_adapter()

        with patch.object(OpenAIInferenceAdapter, "client", new_callable=PropertyMock) as mock_client_prop:
            mock_client = MagicMock()
            mock_client.chat.completions.create = AsyncMock(return_value=mock_openai_response)
            mock_client_prop.return_value = mock_client

            params = OpenAIChatCompletionRequestWithExtraBody(
                model="gpt-4o-mini",
                messages=[{"role": "user", "content": "hi"}],
                stream=False,
                max_tokens=32000,
            )
            await adapter.openai_chat_completion(params)

            call_kwargs = mock_client.chat.completions.create.call_args.kwargs
            assert call_kwargs["max_tokens"] == 16384

    async def test_keeps_lower_request_value(self, mock_openai_response):
        adapter = _make_adapter()

        with patch.object(OpenAIInferenceAdapter, "client", new_callable=PropertyMock) as mock_client_prop:
            mock_client = MagicMock()
            mock_client.chat.completions.create = AsyncMock(return_value=mock_openai_response)
            mock_client_prop.return_value = mock_client

            params = OpenAIChatCompletionRequestWithExtraBody(
                model="gpt-4o-mini",
                messages=[{"role": "user", "content": "hi"}],
                stream=False,
                max_tokens=1000,
            )
            await adapter.openai_chat_completion(params)

            call_kwargs = mock_client.chat.completions.create.call_args.kwargs
            assert call_kwargs["max_tokens"] == 1000

    async def test_no_clamping_when_max_tokens_is_none(self, mock_openai_response):
        adapter = _make_adapter()

        with patch.object(OpenAIInferenceAdapter, "client", new_callable=PropertyMock) as mock_client_prop:
            mock_client = MagicMock()
            mock_client.chat.completions.create = AsyncMock(return_value=mock_openai_response)
            mock_client_prop.return_value = mock_client

            params = OpenAIChatCompletionRequestWithExtraBody(
                model="gpt-4o-mini",
                messages=[{"role": "user", "content": "hi"}],
                stream=False,
            )
            await adapter.openai_chat_completion(params)

            call_kwargs = mock_client.chat.completions.create.call_args.kwargs
            assert call_kwargs.get("max_tokens") is None

    async def test_does_not_mutate_original_params(self, mock_openai_response):
        adapter = _make_adapter()

        with patch.object(OpenAIInferenceAdapter, "client", new_callable=PropertyMock) as mock_client_prop:
            mock_client = MagicMock()
            mock_client.chat.completions.create = AsyncMock(return_value=mock_openai_response)
            mock_client_prop.return_value = mock_client

            params = OpenAIChatCompletionRequestWithExtraBody(
                model="gpt-4o-mini",
                messages=[{"role": "user", "content": "hi"}],
                stream=False,
                max_tokens=32000,
            )
            await adapter.openai_chat_completion(params)

            assert params.max_tokens == 32000

    async def test_different_models_have_different_limits(self, mock_openai_response):
        adapter = _make_adapter()

        with patch.object(OpenAIInferenceAdapter, "client", new_callable=PropertyMock) as mock_client_prop:
            mock_client = MagicMock()
            mock_client.chat.completions.create = AsyncMock(return_value=mock_openai_response)
            mock_client_prop.return_value = mock_client

            # gpt-4-turbo has a 4096 limit
            params = OpenAIChatCompletionRequestWithExtraBody(
                model="gpt-4-turbo",
                messages=[{"role": "user", "content": "hi"}],
                stream=False,
                max_tokens=32000,
            )
            await adapter.openai_chat_completion(params)

            call_kwargs = mock_client.chat.completions.create.call_args.kwargs
            assert call_kwargs["max_tokens"] == 4096

    async def test_no_clamping_for_unknown_model(self, mock_openai_response):
        adapter = _make_adapter()

        with patch.object(OpenAIInferenceAdapter, "client", new_callable=PropertyMock) as mock_client_prop:
            mock_client = MagicMock()
            mock_client.chat.completions.create = AsyncMock(return_value=mock_openai_response)
            mock_client_prop.return_value = mock_client

            params = OpenAIChatCompletionRequestWithExtraBody(
                model="some-future-model",
                messages=[{"role": "user", "content": "hi"}],
                stream=False,
                max_tokens=32000,
            )
            await adapter.openai_chat_completion(params)

            call_kwargs = mock_client.chat.completions.create.call_args.kwargs
            assert call_kwargs["max_tokens"] == 32000

    async def test_dated_snapshot_model_uses_base_limit(self, mock_openai_response):
        adapter = _make_adapter()

        with patch.object(OpenAIInferenceAdapter, "client", new_callable=PropertyMock) as mock_client_prop:
            mock_client = MagicMock()
            mock_client.chat.completions.create = AsyncMock(return_value=mock_openai_response)
            mock_client_prop.return_value = mock_client

            params = OpenAIChatCompletionRequestWithExtraBody(
                model="gpt-4o-2024-08-06",
                messages=[{"role": "user", "content": "hi"}],
                stream=False,
                max_tokens=32000,
            )
            await adapter.openai_chat_completion(params)

            call_kwargs = mock_client.chat.completions.create.call_args.kwargs
            assert call_kwargs["max_tokens"] == 16384

    async def test_clamps_max_completion_tokens_when_request_exceeds_model_limit(self, mock_openai_response):
        adapter = _make_adapter()

        with patch.object(OpenAIInferenceAdapter, "client", new_callable=PropertyMock) as mock_client_prop:
            mock_client = MagicMock()
            mock_client.chat.completions.create = AsyncMock(return_value=mock_openai_response)
            mock_client_prop.return_value = mock_client

            params = OpenAIChatCompletionRequestWithExtraBody(
                model="gpt-4o-mini",
                messages=[{"role": "user", "content": "hi"}],
                stream=False,
                max_completion_tokens=32000,
            )
            await adapter.openai_chat_completion(params)

            call_kwargs = mock_client.chat.completions.create.call_args.kwargs
            assert call_kwargs["max_completion_tokens"] == 16384

    async def test_clamps_both_max_token_fields_when_both_exceed_model_limit(self, mock_openai_response):
        adapter = _make_adapter()

        with patch.object(OpenAIInferenceAdapter, "client", new_callable=PropertyMock) as mock_client_prop:
            mock_client = MagicMock()
            mock_client.chat.completions.create = AsyncMock(return_value=mock_openai_response)
            mock_client_prop.return_value = mock_client

            params = OpenAIChatCompletionRequestWithExtraBody(
                model="gpt-4o-mini",
                messages=[{"role": "user", "content": "hi"}],
                stream=False,
                max_tokens=32000,
                max_completion_tokens=32000,
            )
            await adapter.openai_chat_completion(params)

            call_kwargs = mock_client.chat.completions.create.call_args.kwargs
            assert call_kwargs["max_tokens"] == 16384
            assert call_kwargs["max_completion_tokens"] == 16384


class TestOpenAIModelMetadata:
    def test_construct_model_includes_max_output_tokens(self):
        adapter = _make_adapter()
        adapter.__provider_id__ = "openai"

        model = adapter.construct_model_from_identifier("gpt-4o-mini")
        assert model.metadata["max_output_tokens"] == 16384

    def test_construct_model_unknown_has_no_max_output_tokens(self):
        adapter = _make_adapter()
        adapter.__provider_id__ = "openai"

        model = adapter.construct_model_from_identifier("some-future-model")
        assert "max_output_tokens" not in model.metadata

    def test_construct_model_embedding_unchanged(self):
        adapter = _make_adapter()
        adapter.__provider_id__ = "openai"

        model = adapter.construct_model_from_identifier("text-embedding-3-small")
        assert model.model_type.value == "embedding"
        assert model.metadata["embedding_dimension"] == 1536


class TestOpenAIMaxOutputTokensWarning:
    def test_warns_once_for_unknown_model(self, caplog):
        adapter = _make_adapter()

        with caplog.at_level("WARNING"):
            result1 = adapter._get_max_output_tokens("brand-new-model")
            result2 = adapter._get_max_output_tokens("brand-new-model")

        assert result1 is None
        assert result2 is None
        warning_count = sum(1 for r in caplog.records if "brand-new-model" in r.message)
        assert warning_count == 1

    def test_all_known_models_have_limits(self):
        adapter = _make_adapter()
        for model_id, expected_limit in _MODEL_MAX_OUTPUT_TOKENS.items():
            assert adapter._get_max_output_tokens(model_id) == expected_limit

    def test_prefix_matching_prefers_more_specific_model(self):
        adapter = _make_adapter()
        assert adapter._get_max_output_tokens("o1-mini-2024-09-12") == 65536
```
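For reference, a hedged usage sketch of the metadata path the `TestOpenAIModelMetadata` tests exercise. Constructing the adapter directly and assigning `__provider_id__` by hand mirrors the test setup rather than real provider wiring (in production the provider registry sets it), so treat this as illustrative:

```python
from ogx.providers.remote.inference.openai.config import OpenAIConfig
from ogx.providers.remote.inference.openai.openai import OpenAIInferenceAdapter

# Test-style setup: __provider_id__ is normally set by the provider registry.
adapter = OpenAIInferenceAdapter(config=OpenAIConfig(api_key="fake-key"))
adapter.__provider_id__ = "openai"

model = adapter.construct_model_from_identifier("gpt-4o-mini")
print(model.metadata["max_output_tokens"])  # 16384, surfaced via /v1/models custom_metadata
```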