This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 6d35c33

authored Mar 19, 2025
Introduce tool_use_behavior on agents (#203)
## Context

By default, the outputs of tools are sent to the LLM again. The LLM gets to read the outputs and produce a new response. There are cases where this is not desired:

1. Every tool call results in another round trip, and sometimes the output of the tool is enough.
2. If you force tool use (via model settings `tool_choice=required`), the agent will loop indefinitely.

This change enables different behavior, e.g. using the first tool output as the final output, or writing a custom function to process tool results and potentially produce an output.

## Test plan

Added new tests and ran existing tests. Also added examples.

Closes #117
2 parents 48ff99b + 10aa555 commit 6d35c33
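
The looping problem described in the commit message can be illustrated with a toy run loop. This is a minimal sketch and not the SDK's implementation: `ForcedToolModel` and `run_agent` are hypothetical names, and the fake model simply imitates an LLM that is forced to call a tool on every turn.

```python
from dataclasses import dataclass
from typing import Callable, List


@dataclass
class ForcedToolModel:
    """Toy stand-in for an LLM with tool_choice="required": every turn is a tool call."""

    turns: int = 0

    def next_action(self, history: List[str]) -> str:
        self.turns += 1
        return "call_tool"


def run_agent(
    model: ForcedToolModel,
    tool: Callable[[], str],
    tool_use_behavior: str,
    max_turns: int = 5,
) -> str:
    """Hypothetical, heavily simplified agent loop (assumption, not the real runner)."""
    history: List[str] = []
    for _ in range(max_turns):
        action = model.next_action(history)
        if action != "call_tool":
            return action  # the model produced a plain answer
        output = tool()
        if tool_use_behavior == "stop_on_first_tool":
            return output  # the tool output becomes the final output
        history.append(output)  # "run_llm_again": feed the result back and loop
    raise RuntimeError(f"no final output after {max_turns} turns")
```

With `"stop_on_first_tool"` the loop ends after one model turn. With `"run_llm_again"` the forced model never produces a plain answer, so only the turn limit stops it.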

File tree

12 files changed

+594
-26
lines changed

 

‎docs/agents.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,3 +130,16 @@ robot_agent = pirate_agent.clone(
130130
instructions="Write like a robot",
131131
)
132132
```
133+
134+
## Forcing tool use
135+
136+
Supplying a list of tools doesn't always mean the LLM will use a tool. You can force tool use by setting [`ModelSettings.tool_choice`][agents.model_settings.ModelSettings.tool_choice]. Valid values are:
137+
138+
1. `auto`, which allows the LLM to decide whether or not to use a tool.
139+
2. `required`, which requires the LLM to use a tool (but it can intelligently decide which tool).
140+
3. `none`, which requires the LLM to _not_ use a tool.
141+
4. Setting a specific string e.g. `my_tool`, which requires the LLM to use that specific tool.
142+
143+
!!! note
144+
145+
If requiring tool use, you should consider setting [`Agent.tool_use_behavior`] to stop the Agent from running when a tool output is produced. Otherwise, the Agent might run in an infinite loop: the LLM produces a tool call, the tool result is sent back to the LLM, and the cycle repeats because the LLM is always forced to use a tool.

examples/agent_patterns/forcing_tool_use.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
from typing import Any, Literal
5+
6+
from pydantic import BaseModel
7+
8+
from agents import (
9+
Agent,
10+
FunctionToolResult,
11+
ModelSettings,
12+
RunContextWrapper,
13+
Runner,
14+
ToolsToFinalOutputFunction,
15+
ToolsToFinalOutputResult,
16+
function_tool,
17+
)
18+
19+
"""
20+
This example shows how to force the agent to use a tool. It uses `ModelSettings(tool_choice="required")`
21+
to force the agent to use any tool.
22+
23+
You can run it with 3 options:
24+
1. `default`: The default behavior, which is to send the tool output to the LLM. In this case,
25+
`tool_choice` is not set, because otherwise it would result in an infinite loop - the LLM would
26+
call the tool, the tool would run and send the results to the LLM, and that would repeat
27+
(because the model is forced to use a tool every time.)
28+
2. `first_tool`: The first tool result is used as the final output.
29+
3. `custom`: A custom tool use behavior function is used. The custom function receives all the tool
30+
results, and chooses to use the first tool result to generate the final output.
31+
32+
Usage:
33+
python examples/agent_patterns/forcing_tool_use.py -t default
34+
python examples/agent_patterns/forcing_tool_use.py -t first_tool
35+
python examples/agent_patterns/forcing_tool_use.py -t custom
36+
"""
37+
38+
39+
class Weather(BaseModel):
40+
city: str
41+
temperature_range: str
42+
conditions: str
43+
44+
45+
@function_tool
46+
def get_weather(city: str) -> Weather:
47+
print("[debug] get_weather called")
48+
return Weather(city=city, temperature_range="14-20C", conditions="Sunny with wind")
49+
50+
51+
async def custom_tool_use_behavior(
52+
context: RunContextWrapper[Any], results: list[FunctionToolResult]
53+
) -> ToolsToFinalOutputResult:
54+
weather: Weather = results[0].output
55+
return ToolsToFinalOutputResult(
56+
is_final_output=True, final_output=f"{weather.city} is {weather.conditions}."
57+
)
58+
59+
60+
async def main(tool_use_behavior: Literal["default", "first_tool", "custom"] = "default"):
61+
if tool_use_behavior == "default":
62+
behavior: Literal["run_llm_again", "stop_on_first_tool"] | ToolsToFinalOutputFunction = (
63+
"run_llm_again"
64+
)
65+
elif tool_use_behavior == "first_tool":
66+
behavior = "stop_on_first_tool"
67+
elif tool_use_behavior == "custom":
68+
behavior = custom_tool_use_behavior
69+
70+
agent = Agent(
71+
name="Weather agent",
72+
instructions="You are a helpful agent.",
73+
tools=[get_weather],
74+
tool_use_behavior=behavior,
75+
model_settings=ModelSettings(
76+
tool_choice="required" if tool_use_behavior != "default" else None
77+
),
78+
)
79+
80+
result = await Runner.run(agent, input="What's the weather in Tokyo?")
81+
print(result.final_output)
82+
83+
84+
if __name__ == "__main__":
85+
import argparse
86+
87+
parser = argparse.ArgumentParser()
88+
parser.add_argument(
89+
"-t",
90+
"--tool-use-behavior",
91+
type=str,
92+
required=True,
93+
choices=["default", "first_tool", "custom"],
94+
help="The behavior to use for tool use. Default will cause tool outputs to be sent to the model. "
95+
"first_tool will cause the first tool result to be used as the final output. "
96+
"custom will use a custom tool use behavior function.",
97+
)
98+
args = parser.parse_args()
99+
asyncio.run(main(args.tool_use_behavior))

‎examples/basic/tools.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import asyncio
2+
3+
from pydantic import BaseModel
4+
5+
from agents import Agent, Runner, function_tool
6+
7+
8+
class Weather(BaseModel):
9+
city: str
10+
temperature_range: str
11+
conditions: str
12+
13+
14+
@function_tool
15+
def get_weather(city: str) -> Weather:
16+
print("[debug] get_weather called")
17+
return Weather(city=city, temperature_range="14-20C", conditions="Sunny with wind.")
18+
19+
20+
agent = Agent(
21+
name="Hello world",
22+
instructions="You are a helpful agent.",
23+
tools=[get_weather],
24+
)
25+
26+
27+
async def main():
28+
result = await Runner.run(agent, input="What's the weather in Tokyo?")
29+
print(result.final_output)
30+
# The weather in Tokyo is sunny.
31+
32+
33+
if __name__ == "__main__":
34+
asyncio.run(main())

‎src/agents/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from openai import AsyncOpenAI
66

77
from . import _config
8-
from .agent import Agent
8+
from .agent import Agent, ToolsToFinalOutputFunction, ToolsToFinalOutputResult
99
from .agent_output import AgentOutputSchema
1010
from .computer import AsyncComputer, Button, Computer, Environment
1111
from .exceptions import (
@@ -57,6 +57,7 @@
5757
ComputerTool,
5858
FileSearchTool,
5959
FunctionTool,
60+
FunctionToolResult,
6061
Tool,
6162
WebSearchTool,
6263
default_tool_error_function,
@@ -137,6 +138,8 @@ def enable_verbose_stdout_logging():
137138

138139
__all__ = [
139140
"Agent",
141+
"ToolsToFinalOutputFunction",
142+
"ToolsToFinalOutputResult",
140143
"Runner",
141144
"Model",
142145
"ModelProvider",
@@ -190,6 +193,7 @@ def enable_verbose_stdout_logging():
190193
"AgentUpdatedStreamEvent",
191194
"StreamEvent",
192195
"FunctionTool",
196+
"FunctionToolResult",
193197
"ComputerTool",
194198
"FileSearchTool",
195199
"Tool",

‎src/agents/_run_impl.py

Lines changed: 89 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from __future__ import annotations
22

33
import asyncio
4+
import inspect
5+
from collections.abc import Awaitable
46
from dataclasses import dataclass
5-
from typing import TYPE_CHECKING, Any
7+
from typing import TYPE_CHECKING, Any, cast
68

79
from openai.types.responses import (
810
ResponseComputerToolCall,
@@ -25,7 +27,7 @@
2527
from openai.types.responses.response_input_param import ComputerCallOutput
2628
from openai.types.responses.response_reasoning_item import ResponseReasoningItem
2729

28-
from .agent import Agent
30+
from .agent import Agent, ToolsToFinalOutputResult
2931
from .agent_output import AgentOutputSchema
3032
from .computer import AsyncComputer, Computer
3133
from .exceptions import AgentsException, ModelBehaviorError, UserError
@@ -48,7 +50,7 @@
4850
from .models.interface import ModelTracing
4951
from .run_context import RunContextWrapper, TContext
5052
from .stream_events import RunItemStreamEvent, StreamEvent
51-
from .tool import ComputerTool, FunctionTool
53+
from .tool import ComputerTool, FunctionTool, FunctionToolResult
5254
from .tracing import (
5355
SpanError,
5456
Trace,
@@ -70,6 +72,8 @@ class QueueCompleteSentinel:
7072

7173
QUEUE_COMPLETE_SENTINEL = QueueCompleteSentinel()
7274

75+
_NOT_FINAL_OUTPUT = ToolsToFinalOutputResult(is_final_output=False, final_output=None)
76+
7377

7478
@dataclass
7579
class ToolRunHandoff:
@@ -199,7 +203,7 @@ async def execute_tools_and_side_effects(
199203
config=run_config,
200204
),
201205
)
202-
new_step_items.extend(function_results)
206+
new_step_items.extend([result.run_item for result in function_results])
203207
new_step_items.extend(computer_results)
204208

205209
# Second, check if there are any handoffs
@@ -216,6 +220,36 @@ async def execute_tools_and_side_effects(
216220
run_config=run_config,
217221
)
218222

223+
# Third, we'll check if the tool use should result in a final output
224+
check_tool_use = await cls._check_for_final_output_from_tools(
225+
agent=agent,
226+
tool_results=function_results,
227+
context_wrapper=context_wrapper,
228+
config=run_config,
229+
)
230+
231+
if check_tool_use.is_final_output:
232+
# If the output type is str, then let's just stringify it
233+
if not agent.output_type or agent.output_type is str:
234+
check_tool_use.final_output = str(check_tool_use.final_output)
235+
236+
if check_tool_use.final_output is None:
237+
logger.error(
238+
"Model returned a final output of None. Not raising an error because we assume "
239+
"you know what you're doing."
240+
)
241+
242+
return await cls.execute_final_output(
243+
agent=agent,
244+
original_input=original_input,
245+
new_response=new_response,
246+
pre_step_items=pre_step_items,
247+
new_step_items=new_step_items,
248+
final_output=check_tool_use.final_output,
249+
hooks=hooks,
250+
context_wrapper=context_wrapper,
251+
)
252+
219253
# Now we can check if the model also produced a final output
220254
message_items = [item for item in new_step_items if isinstance(item, MessageOutputItem)]
221255

@@ -355,10 +389,10 @@ async def execute_function_tool_calls(
355389
hooks: RunHooks[TContext],
356390
context_wrapper: RunContextWrapper[TContext],
357391
config: RunConfig,
358-
) -> list[RunItem]:
392+
) -> list[FunctionToolResult]:
359393
async def run_single_tool(
360394
func_tool: FunctionTool, tool_call: ResponseFunctionToolCall
361-
) -> str:
395+
) -> Any:
362396
with function_span(func_tool.name) as span_fn:
363397
if config.trace_include_sensitive_data:
364398
span_fn.span_data.input = tool_call.arguments
@@ -404,10 +438,14 @@ async def run_single_tool(
404438
results = await asyncio.gather(*tasks)
405439

406440
return [
407-
ToolCallOutputItem(
408-
output=str(result),
409-
raw_item=ItemHelpers.tool_call_output_item(tool_run.tool_call, str(result)),
410-
agent=agent,
441+
FunctionToolResult(
442+
tool=tool_run.function_tool,
443+
output=result,
444+
run_item=ToolCallOutputItem(
445+
output=result,
446+
raw_item=ItemHelpers.tool_call_output_item(tool_run.tool_call, str(result)),
447+
agent=agent,
448+
),
411449
)
412450
for tool_run, result in zip(tool_runs, results)
413451
]
@@ -646,6 +684,47 @@ def stream_step_result_to_queue(
646684
if event:
647685
queue.put_nowait(event)
648686

687+
@classmethod
688+
async def _check_for_final_output_from_tools(
689+
cls,
690+
*,
691+
agent: Agent[TContext],
692+
tool_results: list[FunctionToolResult],
693+
context_wrapper: RunContextWrapper[TContext],
694+
config: RunConfig,
695+
) -> ToolsToFinalOutputResult:
696+
"""Returns a `ToolsToFinalOutputResult` saying whether the tool results are the final output."""
697+
if not tool_results:
698+
return _NOT_FINAL_OUTPUT
699+
700+
if agent.tool_use_behavior == "run_llm_again":
701+
return _NOT_FINAL_OUTPUT
702+
elif agent.tool_use_behavior == "stop_on_first_tool":
703+
return ToolsToFinalOutputResult(
704+
is_final_output=True, final_output=tool_results[0].output
705+
)
706+
elif isinstance(agent.tool_use_behavior, dict):
707+
names = agent.tool_use_behavior.get("stop_at_tool_names", [])
708+
for tool_result in tool_results:
709+
if tool_result.tool.name in names:
710+
return ToolsToFinalOutputResult(
711+
is_final_output=True, final_output=tool_result.output
712+
)
713+
return ToolsToFinalOutputResult(is_final_output=False, final_output=None)
714+
elif callable(agent.tool_use_behavior):
715+
if inspect.iscoroutinefunction(agent.tool_use_behavior):
716+
return await cast(
717+
Awaitable[ToolsToFinalOutputResult],
718+
agent.tool_use_behavior(context_wrapper, tool_results),
719+
)
720+
else:
721+
return cast(
722+
ToolsToFinalOutputResult, agent.tool_use_behavior(context_wrapper, tool_results)
723+
)
724+
725+
logger.error(f"Invalid tool_use_behavior: {agent.tool_use_behavior}")
726+
raise UserError(f"Invalid tool_use_behavior: {agent.tool_use_behavior}")
727+
649728

650729
class TraceCtxManager:
651730
"""Creates a trace only if there is no current trace, and manages the trace lifecycle."""

‎src/agents/agent.py

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
import inspect
55
from collections.abc import Awaitable
66
from dataclasses import dataclass, field
7-
from typing import TYPE_CHECKING, Any, Callable, Generic, cast
7+
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, cast
8+
9+
from typing_extensions import TypeAlias, TypedDict
810

911
from .guardrail import InputGuardrail, OutputGuardrail
1012
from .handoffs import Handoff
@@ -13,7 +15,7 @@
1315
from .model_settings import ModelSettings
1416
from .models.interface import Model
1517
from .run_context import RunContextWrapper, TContext
16-
from .tool import Tool, function_tool
18+
from .tool import FunctionToolResult, Tool, function_tool
1719
from .util import _transforms
1820
from .util._types import MaybeAwaitable
1921

@@ -22,6 +24,33 @@
2224
from .result import RunResult
2325

2426

27+
@dataclass
28+
class ToolsToFinalOutputResult:
29+
is_final_output: bool
30+
"""Whether this is the final output. If False, the LLM will run again and receive the tool call
31+
output.
32+
"""
33+
34+
final_output: Any | None = None
35+
"""The final output. Can be None if `is_final_output` is False, otherwise must match the
36+
`output_type` of the agent.
37+
"""
38+
39+
40+
ToolsToFinalOutputFunction: TypeAlias = Callable[
41+
[RunContextWrapper[TContext], list[FunctionToolResult]],
42+
MaybeAwaitable[ToolsToFinalOutputResult],
43+
]
44+
"""A function that takes a run context and a list of tool results, and returns a
45+
`ToolsToFinalOutputResult`.
46+
"""
47+
48+
49+
class StopAtTools(TypedDict):
50+
stop_at_tool_names: list[str]
51+
"""A list of tool names, any of which will stop the agent from running further."""
52+
53+
2554
@dataclass
2655
class Agent(Generic[TContext]):
2756
"""An agent is an AI model configured with instructions, tools, guardrails, handoffs and more.
@@ -95,6 +124,25 @@ class Agent(Generic[TContext]):
95124
"""A class that receives callbacks on various lifecycle events for this agent.
96125
"""
97126

127+
tool_use_behavior: (
128+
Literal["run_llm_again", "stop_on_first_tool"] | StopAtTools | ToolsToFinalOutputFunction
129+
) = "run_llm_again"
130+
"""This lets you configure how tool use is handled.
131+
- "run_llm_again": The default behavior. Tools are run, and then the LLM receives the results
132+
and gets to respond.
133+
- "stop_on_first_tool": The output of the first tool call is used as the final output. This
134+
means that the LLM does not process the result of the tool call.
135+
- A `StopAtTools` dict with a `stop_at_tool_names` list: The agent will stop running if any of the tools in the list are called.
136+
The final output will be the output of the first matching tool call. The LLM does not
137+
process the result of the tool call.
138+
- A function: If you pass a function, it will be called with the run context and the list of
139+
tool results. It must return a `ToolsToFinalOutputResult`, which determines whether the tool
140+
calls result in a final output.
141+
142+
NOTE: This configuration is specific to FunctionTools. Hosted tools, such as file search,
143+
web search, etc., are always processed by the LLM.
144+
"""
145+
98146
def clone(self, **kwargs: Any) -> Agent[TContext]:
99147
"""Make a copy of the agent, with the given arguments changed. For example, you could do:
100148
```
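
The four `tool_use_behavior` options documented in this diff can be sketched as a standalone dispatch function. This is a simplified model under stated assumptions: `ToolResult`, `Decision`, and `decide` are hypothetical stand-ins for `FunctionToolResult`, `ToolsToFinalOutputResult`, and the private check in `_run_impl.py`, and a plain `ValueError` replaces the SDK's `UserError`.

```python
import asyncio
import inspect
from dataclasses import dataclass
from typing import Any, List


@dataclass
class ToolResult:
    """Hypothetical stand-in for FunctionToolResult (tool name and output only)."""

    tool_name: str
    output: Any


@dataclass
class Decision:
    """Hypothetical stand-in for ToolsToFinalOutputResult."""

    is_final_output: bool
    final_output: Any = None


async def decide(behavior: Any, context: Any, results: List[ToolResult]) -> Decision:
    # No tool ran, or the default behavior: let the LLM run again.
    if not results or behavior == "run_llm_again":
        return Decision(False)
    if behavior == "stop_on_first_tool":
        return Decision(True, results[0].output)
    if isinstance(behavior, dict):
        names = behavior.get("stop_at_tool_names", [])
        for result in results:
            # The first matching result, in result order, wins.
            if result.tool_name in names:
                return Decision(True, result.output)
        return Decision(False)
    if callable(behavior):
        if inspect.iscoroutinefunction(behavior):
            return await behavior(context, results)
        return behavior(context, results)
    raise ValueError(f"Invalid tool_use_behavior: {behavior!r}")
```

A custom function sees every result from the turn, so it can combine or filter them before deciding whether to stop.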

‎src/agents/items.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,8 +129,10 @@ class ToolCallOutputItem(RunItemBase[Union[FunctionCallOutput, ComputerCallOutpu
129129
raw_item: FunctionCallOutput | ComputerCallOutput
130130
"""The raw item from the model."""
131131

132-
output: str
133-
"""The output of the tool call."""
132+
output: Any
133+
"""The output of the tool call. This is whatever the tool call returned; the `raw_item`
134+
contains a string representation of the output.
135+
"""
134136

135137
type: Literal["tool_call_output_item"] = "tool_call_output_item"
136138
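
The change to `ToolCallOutputItem.output` above means the typed tool result and its string form now travel together. A minimal sketch of that idea, assuming a hypothetical `OutputItem` in place of the real item class and a plain dataclass in place of the example's pydantic model:

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class Weather:
    city: str
    conditions: str


@dataclass
class OutputItem:
    """Hypothetical stand-in for ToolCallOutputItem."""

    output: Any  # whatever the tool returned, type preserved
    raw_output: str  # string representation that would be sent to the model


def make_item(result: Any) -> OutputItem:
    # Keep the original object for callers and stringify only the model-facing copy.
    return OutputItem(output=result, raw_output=str(result))
```

Callers such as a custom tool-use function can then read fields like `item.output.city` directly instead of parsing a string.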

‎src/agents/tool.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .computer import AsyncComputer, Computer
1616
from .exceptions import ModelBehaviorError
1717
from .function_schema import DocstringStyle, function_schema
18+
from .items import RunItem
1819
from .logger import logger
1920
from .run_context import RunContextWrapper
2021
from .tracing import SpanError
@@ -29,6 +30,18 @@
2930
ToolFunction = Union[ToolFunctionWithoutContext[ToolParams], ToolFunctionWithContext[ToolParams]]
3031

3132

33+
@dataclass
34+
class FunctionToolResult:
35+
tool: FunctionTool
36+
"""The tool that was run."""
37+
38+
output: Any
39+
"""The output of the tool."""
40+
41+
run_item: RunItem
42+
"""The run item that was produced as a result of the tool call."""
43+
44+
3245
@dataclass
3346
class FunctionTool:
3447
"""A tool that wraps a function. In most cases, you should use the `function_tool` helpers to
@@ -44,15 +57,15 @@ class FunctionTool:
4457
params_json_schema: dict[str, Any]
4558
"""The JSON schema for the tool's parameters."""
4659

47-
on_invoke_tool: Callable[[RunContextWrapper[Any], str], Awaitable[str]]
60+
on_invoke_tool: Callable[[RunContextWrapper[Any], str], Awaitable[Any]]
4861
"""A function that invokes the tool with the given context and parameters. The params passed
4962
are:
5063
1. The tool run context.
5164
2. The arguments from the LLM, as a JSON string.
5265
53-
You must return a string representation of the tool output. In case of errors, you can either
54-
raise an Exception (which will cause the run to fail) or return a string error message (which
55-
will be sent back to the LLM).
66+
You must return a string representation of the tool output, or something we can call `str()` on.
67+
In case of errors, you can either raise an Exception (which will cause the run to fail) or
68+
return a string error message (which will be sent back to the LLM).
5669
"""
5770

5871
strict_json_schema: bool = True
@@ -207,7 +220,7 @@ def _create_function_tool(the_func: ToolFunction[...]) -> FunctionTool:
207220
strict_json_schema=strict_mode,
208221
)
209222

210-
async def _on_invoke_tool_impl(ctx: RunContextWrapper[Any], input: str) -> str:
223+
async def _on_invoke_tool_impl(ctx: RunContextWrapper[Any], input: str) -> Any:
211224
try:
212225
json_data: dict[str, Any] = json.loads(input) if input else {}
213226
except Exception as e:
@@ -254,9 +267,9 @@ async def _on_invoke_tool_impl(ctx: RunContextWrapper[Any], input: str) -> str:
254267
else:
255268
logger.debug(f"Tool {schema.name} returned {result}")
256269

257-
return str(result)
270+
return result
258271

259-
async def _on_invoke_tool(ctx: RunContextWrapper[Any], input: str) -> str:
272+
async def _on_invoke_tool(ctx: RunContextWrapper[Any], input: str) -> Any:
260273
try:
261274
return await _on_invoke_tool_impl(ctx, input)
262275
except Exception as e:

‎src/agents/tracing/span_data.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def export(self) -> dict[str, Any]:
5151
class FunctionSpanData(SpanData):
5252
__slots__ = ("name", "input", "output")
5353

54-
def __init__(self, name: str, input: str | None, output: str | None):
54+
def __init__(self, name: str, input: str | None, output: Any | None):
5555
self.name = name
5656
self.input = input
5757
self.output = output
@@ -65,7 +65,7 @@ def export(self) -> dict[str, Any]:
6565
"type": self.type,
6666
"name": self.name,
6767
"input": self.input,
68-
"output": self.output,
68+
"output": str(self.output) if self.output else None,
6969
}
7070

7171

‎tests/test_agent_runner.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
UserError,
2222
handoff,
2323
)
24+
from agents.agent import ToolsToFinalOutputResult
25+
from agents.tool import FunctionToolResult, function_tool
2426

2527
from .fake_model import FakeModel
2628
from .test_responses import (
@@ -552,3 +554,83 @@ def guardrail_function(
552554

553555
with pytest.raises(OutputGuardrailTripwireTriggered):
554556
await Runner.run(agent, input="user_message")
557+
558+
559+
@function_tool
560+
def test_tool_one():
561+
return Foo(bar="tool_one_result")
562+
563+
564+
@function_tool
565+
def test_tool_two():
566+
return "tool_two_result"
567+
568+
569+
@pytest.mark.asyncio
570+
async def test_tool_use_behavior_first_output():
571+
model = FakeModel()
572+
agent = Agent(
573+
name="test",
574+
model=model,
575+
tools=[get_function_tool("foo", "tool_result"), test_tool_one, test_tool_two],
576+
tool_use_behavior="stop_on_first_tool",
577+
output_type=Foo,
578+
)
579+
580+
model.add_multiple_turn_outputs(
581+
[
582+
# First turn: a message and tool call
583+
[
584+
get_text_message("a_message"),
585+
get_function_tool_call("test_tool_one", None),
586+
get_function_tool_call("test_tool_two", None),
587+
],
588+
]
589+
)
590+
591+
result = await Runner.run(agent, input="user_message")
592+
593+
assert result.final_output == Foo(bar="tool_one_result"), (
594+
"should have used the first tool result"
595+
)
596+
597+
598+
def custom_tool_use_behavior(
599+
context: RunContextWrapper[Any], results: list[FunctionToolResult]
600+
) -> ToolsToFinalOutputResult:
601+
if "test_tool_one" in [result.tool.name for result in results]:
602+
return ToolsToFinalOutputResult(is_final_output=True, final_output="the_final_output")
603+
else:
604+
return ToolsToFinalOutputResult(is_final_output=False, final_output=None)
605+
606+
607+
@pytest.mark.asyncio
608+
async def test_tool_use_behavior_custom_function():
609+
model = FakeModel()
610+
agent = Agent(
611+
name="test",
612+
model=model,
613+
tools=[get_function_tool("foo", "tool_result"), test_tool_one, test_tool_two],
614+
tool_use_behavior=custom_tool_use_behavior,
615+
)
616+
617+
model.add_multiple_turn_outputs(
618+
[
619+
# First turn: a message and tool call
620+
[
621+
get_text_message("a_message"),
622+
get_function_tool_call("test_tool_two", None),
623+
],
624+
# Second turn: a message and tool call
625+
[
626+
get_text_message("a_message"),
627+
get_function_tool_call("test_tool_one", None),
628+
get_function_tool_call("test_tool_two", None),
629+
],
630+
]
631+
)
632+
633+
result = await Runner.run(agent, input="user_message")
634+
635+
assert len(result.raw_responses) == 2, "should have two model responses"
636+
assert result.final_output == "the_final_output", "should have used the custom function"

‎tests/test_function_tool.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,10 @@ async def test_simple_function():
4949
assert tool.name == "simple_function"
5050

5151
result = await tool.on_invoke_tool(RunContextWrapper(None), '{"a": 1}')
52-
assert result == "6"
52+
assert result == 6
5353

5454
result = await tool.on_invoke_tool(RunContextWrapper(None), '{"a": 1, "b": 2}')
55-
assert result == "3"
55+
assert result == 3
5656

5757
# Missing required argument should raise an error
5858
with pytest.raises(ModelBehaviorError):

‎tests/test_tool_use_behavior.py

Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
# Copyright
2+
3+
from __future__ import annotations
4+
5+
from typing import cast
6+
7+
import pytest
8+
from openai.types.responses.response_input_item_param import FunctionCallOutput
9+
10+
from agents import (
11+
Agent,
12+
FunctionToolResult,
13+
RunConfig,
14+
RunContextWrapper,
15+
ToolCallOutputItem,
16+
ToolsToFinalOutputResult,
17+
UserError,
18+
)
19+
from agents._run_impl import RunImpl
20+
21+
from .test_responses import get_function_tool
22+
23+
24+
def _make_function_tool_result(
25+
agent: Agent, output: str, tool_name: str | None = None
26+
) -> FunctionToolResult:
27+
# Construct a FunctionToolResult with the given output using a simple function tool.
28+
tool = get_function_tool(tool_name or "dummy", return_value=output)
29+
raw_item: FunctionCallOutput = cast(
30+
FunctionCallOutput,
31+
{
32+
"call_id": "1",
33+
"output": output,
34+
"type": "function_call_output",
35+
},
36+
)
37+
# For this test we don't care about the specific RunItem subclass, only the output field
38+
run_item = ToolCallOutputItem(agent=agent, raw_item=raw_item, output=output)
39+
return FunctionToolResult(tool=tool, output=output, run_item=run_item)
40+
41+
42+
@pytest.mark.asyncio
43+
async def test_no_tool_results_returns_not_final_output() -> None:
44+
# If there are no tool results at all, tool_use_behavior should not produce a final output.
45+
agent = Agent(name="test")
46+
result = await RunImpl._check_for_final_output_from_tools(
47+
agent=agent,
48+
tool_results=[],
49+
context_wrapper=RunContextWrapper(context=None),
50+
config=RunConfig(),
51+
)
52+
assert result.is_final_output is False
53+
assert result.final_output is None
54+
55+
56+
@pytest.mark.asyncio
57+
async def test_run_llm_again_behavior() -> None:
58+
# With the default run_llm_again behavior, even with tools we still expect to keep running.
59+
agent = Agent(name="test", tool_use_behavior="run_llm_again")
60+
tool_results = [_make_function_tool_result(agent, "ignored")]
61+
result = await RunImpl._check_for_final_output_from_tools(
62+
agent=agent,
63+
tool_results=tool_results,
64+
context_wrapper=RunContextWrapper(context=None),
65+
config=RunConfig(),
66+
)
67+
assert result.is_final_output is False
68+
assert result.final_output is None
69+
70+
71+
@pytest.mark.asyncio
72+
async def test_stop_on_first_tool_behavior() -> None:
73+
# When tool_use_behavior is stop_on_first_tool, we should surface first tool output as final.
74+
agent = Agent(name="test", tool_use_behavior="stop_on_first_tool")
75+
tool_results = [
76+
_make_function_tool_result(agent, "first_tool_output"),
77+
_make_function_tool_result(agent, "ignored"),
78+
]
79+
result = await RunImpl._check_for_final_output_from_tools(
80+
agent=agent,
81+
tool_results=tool_results,
82+
context_wrapper=RunContextWrapper(context=None),
83+
config=RunConfig(),
84+
)
85+
assert result.is_final_output is True
86+
assert result.final_output == "first_tool_output"
87+
88+
89+
@pytest.mark.asyncio
90+
async def test_custom_tool_use_behavior_sync() -> None:
91+
"""If tool_use_behavior is a sync function, we should call it and propagate its return."""
92+
93+
def behavior(
94+
context: RunContextWrapper, results: list[FunctionToolResult]
95+
) -> ToolsToFinalOutputResult:
96+
assert len(results) == 3
97+
return ToolsToFinalOutputResult(is_final_output=True, final_output="custom")
98+
99+
agent = Agent(name="test", tool_use_behavior=behavior)
100+
tool_results = [
101+
_make_function_tool_result(agent, "ignored1"),
102+
_make_function_tool_result(agent, "ignored2"),
103+
_make_function_tool_result(agent, "ignored3"),
104+
]
105+
result = await RunImpl._check_for_final_output_from_tools(
106+
agent=agent,
107+
tool_results=tool_results,
108+
context_wrapper=RunContextWrapper(context=None),
109+
config=RunConfig(),
110+
)
111+
assert result.is_final_output is True
112+
assert result.final_output == "custom"
113+
114+
115+
@pytest.mark.asyncio
116+
async def test_custom_tool_use_behavior_async() -> None:
117+
"""If tool_use_behavior is an async function, we should await it and propagate its return."""
118+
119+
async def behavior(
120+
context: RunContextWrapper, results: list[FunctionToolResult]
121+
) -> ToolsToFinalOutputResult:
122+
assert len(results) == 3
123+
return ToolsToFinalOutputResult(is_final_output=True, final_output="async_custom")
124+
125+
agent = Agent(name="test", tool_use_behavior=behavior)
126+
tool_results = [
127+
_make_function_tool_result(agent, "ignored1"),
128+
_make_function_tool_result(agent, "ignored2"),
129+
_make_function_tool_result(agent, "ignored3"),
130+
]
131+
result = await RunImpl._check_for_final_output_from_tools(
132+
agent=agent,
133+
tool_results=tool_results,
134+
context_wrapper=RunContextWrapper(context=None),
135+
config=RunConfig(),
136+
)
137+
assert result.is_final_output is True
138+
assert result.final_output == "async_custom"
139+
140+
141+
@pytest.mark.asyncio
142+
async def test_invalid_tool_use_behavior_raises() -> None:
143+
"""If tool_use_behavior is invalid, we should raise a UserError."""
144+
agent = Agent(name="test")
145+
# Force an invalid value; mypy will complain, so ignore the type here.
146+
agent.tool_use_behavior = "bad_value" # type: ignore[assignment]
147+
tool_results = [_make_function_tool_result(agent, "ignored")]
148+
with pytest.raises(UserError):
149+
await RunImpl._check_for_final_output_from_tools(
150+
agent=agent,
151+
tool_results=tool_results,
152+
context_wrapper=RunContextWrapper(context=None),
153+
config=RunConfig(),
154+
)
155+
156+
157+
@pytest.mark.asyncio
158+
async def test_tool_names_to_stop_at_behavior() -> None:
159+
agent = Agent(
160+
name="test",
161+
tools=[
162+
get_function_tool("tool1", return_value="tool1_output"),
163+
get_function_tool("tool2", return_value="tool2_output"),
164+
get_function_tool("tool3", return_value="tool3_output"),
165+
],
166+
tool_use_behavior={"stop_at_tool_names": ["tool1"]},
167+
)
168+
169+
tool_results = [
170+
_make_function_tool_result(agent, "ignored1", "tool2"),
171+
_make_function_tool_result(agent, "ignored3", "tool3"),
172+
]
173+
result = await RunImpl._check_for_final_output_from_tools(
174+
agent=agent,
175+
tool_results=tool_results,
176+
context_wrapper=RunContextWrapper(context=None),
177+
config=RunConfig(),
178+
)
179+
assert result.is_final_output is False, "We should not have stopped at tool1"
180+
181+
# Now test with a tool that matches the list
182+
tool_results = [
183+
_make_function_tool_result(agent, "output1", "tool1"),
184+
_make_function_tool_result(agent, "ignored2", "tool2"),
185+
_make_function_tool_result(agent, "ignored3", "tool3"),
186+
]
187+
result = await RunImpl._check_for_final_output_from_tools(
188+
agent=agent,
189+
tool_results=tool_results,
190+
context_wrapper=RunContextWrapper(context=None),
191+
config=RunConfig(),
192+
)
193+
assert result.is_final_output is True, "We should have stopped at tool1"
194+
assert result.final_output == "output1"
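
The sync and async tests above cover plain functions. As a side note, `inspect.iscoroutinefunction` does not recognize an object whose `__call__` is async, so a helper that inspects the returned value instead can treat all three shapes uniformly. This is a sketch of that alternative, not what the diff does; `call_behavior` and `AsyncCallableBehavior` are hypothetical names.

```python
import asyncio
import inspect
from typing import Any, Callable


class AsyncCallableBehavior:
    """Hypothetical behavior implemented as an object with an async __call__."""

    async def __call__(self, context: Any, results: list) -> str:
        return f"async object saw {len(results)} results"


def sync_behavior(context: Any, results: list) -> str:
    return f"sync function saw {len(results)} results"


async def async_behavior(context: Any, results: list) -> str:
    return f"async function saw {len(results)} results"


async def call_behavior(behavior: Callable[..., Any], context: Any, results: list) -> Any:
    # Call first, then await only if the call produced an awaitable.
    outcome = behavior(context, results)
    if inspect.isawaitable(outcome):
        outcome = await outcome
    return outcome
```

Checking the result rather than the callable avoids returning an un-awaited coroutine when the behavior is an async callable object.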
