Commit c1f57f6

leseb and claude authored
perf(responses): batch guardrail checks during streaming (#5664)
# What does this PR do?

Fixes a performance cliff when output guardrails are enabled on streaming responses. Previously, every streaming token triggered an O(n) string join, a `list_shields()` lookup, and a Safety API `run_moderation()` call. For a 1000-token response this meant ~1000 redundant API calls and quadratic string reconstruction.

This PR:

- Extracts `resolve_guardrail_model_ids()` to cache shield lookups once per request
- Batches guardrail checks every 200 characters (configurable via the `GUARDRAIL_BATCH_CHARS` env var) instead of every token
- Adds a final guardrail check at stream end for remaining buffered content
- Flushes reasoning-only deltas per chunk so they stream in real time

## Test Plan

1. Unit tests pass (`236 passed`):

   ```bash
   uv run pytest tests/unit/providers/responses/ -x --tb=short -q
   ```

2. New test verifies reasoning events stream without waiting for text accumulation:

   ```bash
   uv run pytest tests/unit/providers/inline/responses/builtin/responses/test_streaming.py::test_guardrailed_reasoning_streams_before_completion -v
   ```

3. Benchmark script for A/B testing against a running OGX server:

   ```bash
   # Start server with per-token checking (before):
   GUARDRAIL_BATCH_CHARS=1 SAFETY_MODEL=ollama/llama-guard3:1b uv run ogx stack run starter

   # Run benchmark:
   uv run python scripts/benchmark_guardrail_batching.py --model openai/gpt-4.1-nano

   # Restart server with batched checking (after, default):
   SAFETY_MODEL=ollama/llama-guard3:1b uv run ogx stack run starter

   # Run benchmark again and compare:
   uv run python scripts/benchmark_guardrail_batching.py --model openai/gpt-4.1-nano
   ```

---------

Signed-off-by: Sébastien Han <seb@redhat.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
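For orientation, the heart of the change is the buffering pattern below — a minimal standalone sketch under assumed names (`deltas`, `run_check`, and `GuardrailViolation` are illustrative stand-ins, not the orchestrator's actual API):

```python
import os

# Batch threshold; mirrors the PR's default of 200 characters.
BATCH_CHARS = int(os.environ.get("GUARDRAIL_BATCH_CHARS", "200"))


class GuardrailViolation(Exception):
    """Raised when a moderation check flags the accumulated text."""


async def stream_with_batched_guardrails(deltas, run_check):
    """Yield buffered events, checking guardrails every BATCH_CHARS characters.

    `deltas` is an async iterator of (event, text) pairs; `run_check` is an
    async callable that returns a violation message or None.
    """
    pending = []          # events held back until the next guardrail check
    accumulated = []      # all text so far; joined once per batch, not per token
    chars_since_check = 0
    async for event, text in deltas:
        pending.append(event)
        accumulated.append(text)
        chars_since_check += len(text)
        if chars_since_check >= BATCH_CHARS:
            if violation := await run_check("".join(accumulated)):
                pending.clear()
                raise GuardrailViolation(violation)
            for e in pending:
                yield e
            pending.clear()
            chars_since_check = 0
    # Final check on whatever is still buffered when the stream ends.
    if pending:
        if violation := await run_check("".join(accumulated)):
            raise GuardrailViolation(violation)
        for e in pending:
            yield e
```

The effect: the O(n) join and the moderation call run once per ~200 characters instead of once per token, and a tail that never reaches the threshold still gets exactly one check at stream end.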
1 parent 107a2bc · commit c1f57f6

4 files changed · 198 additions & 29 deletions


### src/ogx/providers/inline/responses/builtin/responses/streaming.py

Lines changed: 56 additions & 16 deletions
```diff
@@ -117,6 +117,7 @@
     convert_chat_choice_to_response_message,
     convert_mcp_tool_choice,
     is_function_tool_call,
+    resolve_guardrail_model_ids,
     run_guardrails,
     should_summarize_reasoning,
     summarize_reasoning,
@@ -129,6 +130,8 @@
 # Anything else is either a registered function tool (client-side) or a hallucinated name.
 _SERVER_SIDE_BUILTIN_TOOL_NAMES = frozenset({"web_search", "knowledge_search", "file_search"})
 
+_GUARDRAIL_BATCH_CHARS = 200
+
 # Maps OpenAI Chat Completions error codes to Responses API error codes
 _RESPONSES_API_ERROR_CODES = {
     "invalid_base64": "invalid_base64_image",
@@ -304,6 +307,7 @@ def __init__(
         self.accumulated_usage: OpenAIResponseUsage | None = None
         # Track if we've sent a refusal response
         self.violation_detected = False
+        self._guardrail_model_ids: list[str] = []
         # Track total calls made to built-in tools
         self.accumulated_builtin_tool_calls = 0
         # Track total output tokens generated across inference calls
@@ -411,8 +415,15 @@ async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]:
 
         # Input safety validation - check messages before processing
         if self.guardrail_ids:
+            if self.safety_api is not None:
+                self._guardrail_model_ids = await resolve_guardrail_model_ids(self.safety_api, self.guardrail_ids)
             combined_text = interleaved_content_as_str([msg.content for msg in self.ctx.messages])
-            input_violation_message = await run_guardrails(self.safety_api, combined_text, self.guardrail_ids)
+            input_violation_message = await run_guardrails(
+                self.safety_api,
+                combined_text,
+                self.guardrail_ids,
+                model_ids=self._guardrail_model_ids,
+            )
             if input_violation_message:
                 logger.info("Input guardrail violation", input_violation_message=input_violation_message)
                 yield await self._create_refusal_response(input_violation_message)
@@ -1038,6 +1049,8 @@ async def _process_streaming_chunks(
         message_output_index = len(output_messages)
         reasoning_text_accumulated = []
         refusal_text_accumulated = []
+        pending_guardrail_events: list[OpenAIResponseObjectStream] = []
+        chars_since_last_check = 0
 
         async for raw_chunk in completion_result:
             # Providers returning OpenAIChatCompletionChunkWithReasoning wrap
@@ -1059,9 +1072,6 @@
             # Accumulate usage from chunks (typically in final chunk with stream_options)
             self._accumulate_chunk_usage(chunk)
 
-            # Track deltas for this specific chunk for guardrail validation
-            chunk_events: list[OpenAIResponseObjectStream] = []
-
             for chunk_choice in chunk.choices:
                 # Collect logprobs if present
                 chunk_logprobs = None
@@ -1115,12 +1125,14 @@
                     )
                     # Buffer text delta events for guardrail check
                     if self.guardrail_ids:
-                        chunk_events.append(text_delta_event)
+                        pending_guardrail_events.append(text_delta_event)
                     else:
                         yield text_delta_event
 
                     # Collect content for final response
-                    chat_response_content.append(chunk_choice.delta.content or "")
+                    content_delta = chunk_choice.delta.content or ""
+                    chat_response_content.append(content_delta)
+                    chars_since_last_check += len(content_delta)
                 if chunk_choice.finish_reason:
                     chunk_finish_reason = chunk_choice.finish_reason
 
@@ -1137,7 +1149,7 @@
                 ):
                     # Buffer reasoning events for guardrail check
                     if self.guardrail_ids:
-                        chunk_events.append(event)
+                        pending_guardrail_events.append(event)
                     else:
                         yield event
                     reasoning_part_emitted = True
@@ -1232,21 +1244,49 @@
                             response_tool_call.function.arguments or ""
                         ) + tool_call.function.arguments
 
-            # Output Safety Validation for this chunk
-            if self.guardrail_ids:
-                # Check guardrails on accumulated text so far
+            # Batched output safety validation. If we have only buffered reasoning events and
+            # no assistant text yet, flush per chunk so reasoning can stream in real time.
+            guardrail_check_due = chars_since_last_check >= _GUARDRAIL_BATCH_CHARS
+            if pending_guardrail_events and not any(chat_response_content):
+                guardrail_check_due = True
+
+            if self.guardrail_ids and guardrail_check_due:
                 accumulated_text = "".join(chat_response_content)
-                violation_message = await run_guardrails(self.safety_api, accumulated_text, self.guardrail_ids)
+                violation_message = await run_guardrails(
+                    self.safety_api,
+                    accumulated_text,
+                    self.guardrail_ids,
+                    model_ids=self._guardrail_model_ids,
+                )
                 if violation_message:
                     logger.info("Output guardrail violation", violation_message=violation_message)
-                    chunk_events.clear()
+                    pending_guardrail_events.clear()
                     yield await self._create_refusal_response(violation_message)
                     self.violation_detected = True
                     return
-                else:
-                    # No violation detected, emit all content events for this chunk
-                    for event in chunk_events:
-                        yield event
+                for event in pending_guardrail_events:
+                    yield event
+                pending_guardrail_events.clear()
+                chars_since_last_check = 0
+
+        # Final guardrail check on remaining buffered content
+        if self.guardrail_ids and pending_guardrail_events:
+            accumulated_text = "".join(chat_response_content)
+            violation_message = await run_guardrails(
+                self.safety_api,
+                accumulated_text,
+                self.guardrail_ids,
+                model_ids=self._guardrail_model_ids,
+            )
+            if violation_message:
+                logger.info("Output guardrail violation", violation_message=violation_message)
+                pending_guardrail_events.clear()
+                yield await self._create_refusal_response(violation_message)
+                self.violation_detected = True
+                return
+            for event in pending_guardrail_events:
+                yield event
+            pending_guardrail_events.clear()
 
         # Emit arguments.done events for completed tool calls (differentiate between MCP and function calls)
         for tool_call_index in sorted(chat_response_tool_calls.keys()):
```
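One subtlety in the flush condition above: `chat_response_content` is a list of string deltas, so `any(chat_response_content)` stays false until the first non-empty text delta arrives. That is what lets reasoning-only chunks bypass the 200-character threshold and flush every chunk — a quick illustration:

```python
# Illustrative only: truthiness of the reasoning-only flush condition.
chat_response_content: list[str] = []
assert not any(chat_response_content)  # no deltas yet -> flush per chunk

chat_response_content.append("")       # reasoning-only chunk: empty text delta
assert not any(chat_response_content)  # still no assistant text -> keep flushing

chat_response_content.append("Hello")  # first real text delta
assert any(chat_response_content)      # now batch on chars_since_last_check
```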

### src/ogx/providers/inline/responses/builtin/responses/utils.py

Lines changed: 24 additions & 13 deletions
```diff
@@ -548,27 +548,38 @@ def is_function_tool_call(
     return False
 
 
-async def run_guardrails(safety_api: Safety | None, messages: str, guardrail_ids: list[str]) -> str | None:
-    """Run guardrails against messages and return violation message if blocked."""
-    if not messages:
-        return None
-
-    # If safety API is not available, skip guardrails
-    if safety_api is None:
-        return None
+async def resolve_guardrail_model_ids(safety_api: Safety, guardrail_ids: list[str]) -> list[str]:
+    """Resolve guardrail identifiers to concrete shield model IDs.
 
-    # Look up shields to get their provider_resource_id (actual model ID)
-    model_ids = []
+    Call once and pass the result to run_guardrails() to avoid repeated lookups.
+    """
     # TODO: list_shields not in Safety interface but available at runtime via API routing
     shields_list = await safety_api.routing_table.list_shields()  # type: ignore[attr-defined]
-
+    model_ids = []
     for guardrail_id in guardrail_ids:
         matching_shields = [shield for shield in shields_list.data if shield.identifier == guardrail_id]
         if matching_shields:
-            model_id = matching_shields[0].provider_resource_id
-            model_ids.append(model_id)
+            model_ids.append(matching_shields[0].provider_resource_id)
         else:
             raise ValueError(f"No shield found with identifier '{guardrail_id}'")
+    return model_ids
+
+
+async def run_guardrails(
+    safety_api: Safety | None,
+    messages: str,
+    guardrail_ids: list[str],
+    model_ids: list[str] | None = None,
+) -> str | None:
+    """Run guardrails against messages and return violation message if blocked."""
+    if not messages:
+        return None
+
+    if safety_api is None:
+        return None
+
+    if model_ids is None:
+        model_ids = await resolve_guardrail_model_ids(safety_api, guardrail_ids)
 
     guardrail_tasks = [
         safety_api.run_moderation(RunModerationRequest(input=messages, model=model_id)) for model_id in model_ids
```
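Taken together, the intended calling pattern is resolve once, check many. A sketch of a caller, assuming only the two signatures shown above (`check_output` and its parameters are hypothetical):

```python
async def check_output(safety_api, guardrail_ids, accumulated_chunks: list[str]) -> str | None:
    # Resolve shield identifiers to concrete model IDs once per request...
    model_ids = await resolve_guardrail_model_ids(safety_api, guardrail_ids)
    # ...then reuse them on every batched check: no list_shields() per batch.
    return await run_guardrails(
        safety_api,
        "".join(accumulated_chunks),
        guardrail_ids,
        model_ids=model_ids,
    )
```

Omitting `model_ids=` keeps the old behavior (`run_guardrails()` resolves internally), so existing callers remain source-compatible.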

### tests/integration/agents/recordings/72c53cb1f81e5b2835ea301e84dbd3431d5f46ba974ee9cb5df3ff3f14d90732.json

Lines changed: 62 additions & 0 deletions

Some generated files are not rendered by default.

### tests/unit/providers/inline/responses/builtin/responses/test_streaming.py

Lines changed: 56 additions & 0 deletions
```diff
@@ -4,6 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import asyncio
 from collections.abc import AsyncIterator
 from unittest.mock import AsyncMock, MagicMock
 
@@ -23,11 +24,15 @@
 from ogx_api.inference.models import (
     OpenAIAssistantMessageParam,
     OpenAIChatCompletion,
+    OpenAIChatCompletionChunk,
+    OpenAIChatCompletionChunkWithReasoning,
     OpenAIChatCompletionResponseMessage,
     OpenAIChatCompletionToolCall,
     OpenAIChatCompletionToolCallFunction,
     OpenAIChatCompletionUsage,
     OpenAIChoice,
+    OpenAIChoiceDelta,
+    OpenAIChunkChoice,
 )
 from ogx_api.openai_responses import (
     OpenAIResponseInputToolMCP,
@@ -577,3 +582,54 @@ async def test_uses_correct_summary_mode(self):
         call_args = mock_inference.openai_chat_completion.call_args[0][0]
         user_msg = call_args.messages[1].content
         assert "Preserve the key logical steps" in user_msg
+
+
+async def test_guardrailed_reasoning_streams_before_completion(mock_inference_api, mock_context, mock_safety_api):
+    """Guardrail batching should not buffer reasoning-only deltas until stream completion."""
+    mock_context.model = "test-model"
+    mock_context.temperature = None
+    mock_context.top_p = None
+    mock_context.frequency_penalty = None
+
+    orchestrator = StreamingResponseOrchestrator(
+        inference_api=mock_inference_api,
+        ctx=mock_context,
+        response_id="resp_reasoning_guardrails",
+        created_at=0,
+        text=MagicMock(),
+        max_infer_iters=1,
+        tool_executor=MagicMock(),
+        instructions=None,
+        safety_api=mock_safety_api,
+        guardrail_ids=["llama-guard"],
+    )
+
+    gate = asyncio.Event()
+
+    async def completion_result() -> AsyncIterator[OpenAIChatCompletionChunkWithReasoning]:
+        chunk = OpenAIChatCompletionChunk(
+            id="chatcmpl_reasoning",
+            choices=[
+                OpenAIChunkChoice(
+                    index=0,
+                    delta=OpenAIChoiceDelta(content=None, role="assistant"),
+                    finish_reason=None,
+                )
+            ],
+            created=1,
+            model="test-model",
+            object="chat.completion.chunk",
+        )
+        yield OpenAIChatCompletionChunkWithReasoning(chunk=chunk, reasoning_content="thinking...")
+
+        await gate.wait()
+
+    stream = orchestrator._process_streaming_chunks(completion_result(), output_messages=[])
+
+    # If reasoning is buffered until completion, this call will time out.
+    first_event = await asyncio.wait_for(anext(stream), timeout=0.5)
+    assert first_event.type in {"response.content_part.added", "response.reasoning_text.delta"}
+
+    gate.set()
+    async for _ in stream:
+        pass
```
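The `asyncio.Event` gate in this test generalizes to any "does the pipeline yield before its input finishes?" check: keep the source generator alive behind the gate, then bound the wait on the first event. A hypothetical helper capturing the pattern (note the `anext()` builtin requires Python 3.10+):

```python
import asyncio


async def first_event_or_timeout(stream, timeout: float = 0.5):
    # If the pipeline buffers output until its input completes, anext() never
    # resolves, and wait_for raises TimeoutError instead of hanging the suite.
    return await asyncio.wait_for(anext(stream), timeout)
```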
