Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 0743763

Browse files
committedMar 24, 2025
[2/n] Add MCP support to Runner
### Summary: This enables users to **use** MCP inside the SDK. 1. You add a list of MCP servers to `Agent`, via `mcp_server=[...]` 2. When an agent runs, we look up its MCP tools and add them to the list of tools. 3. When a tool call occurs, we call the relevant MCP server. Notes: 1. There's some refactoring to make sure we send the full list of tools to the Runner/Model etc. 2. Right now, you could have a locally defined tool that conflicts with an MCP defined tool. I didn't add errors for that, will do in a followup. ### Test Plan: See unit tests. Also has an end to end example next PR.
1 parent 37b5749 commit 0743763

14 files changed

+662
-35
lines changed
 

‎src/agents/_run_impl.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
from .models.interface import ModelTracing
5151
from .run_context import RunContextWrapper, TContext
5252
from .stream_events import RunItemStreamEvent, StreamEvent
53-
from .tool import ComputerTool, FunctionTool, FunctionToolResult
53+
from .tool import ComputerTool, FunctionTool, FunctionToolResult, Tool
5454
from .tracing import (
5555
SpanError,
5656
Trace,
@@ -301,6 +301,7 @@ def process_model_response(
301301
cls,
302302
*,
303303
agent: Agent[Any],
304+
all_tools: list[Tool],
304305
response: ModelResponse,
305306
output_schema: AgentOutputSchema | None,
306307
handoffs: list[Handoff],
@@ -312,8 +313,8 @@ def process_model_response(
312313
computer_actions = []
313314

314315
handoff_map = {handoff.tool_name: handoff for handoff in handoffs}
315-
function_map = {tool.name: tool for tool in agent.tools if isinstance(tool, FunctionTool)}
316-
computer_tool = next((tool for tool in agent.tools if isinstance(tool, ComputerTool)), None)
316+
function_map = {tool.name: tool for tool in all_tools if isinstance(tool, FunctionTool)}
317+
computer_tool = next((tool for tool in all_tools if isinstance(tool, ComputerTool)), None)
317318

318319
for output in response.output:
319320
if isinstance(output, ResponseOutputMessage):

‎src/agents/agent.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from .handoffs import Handoff
1313
from .items import ItemHelpers
1414
from .logger import logger
15+
from .mcp import MCPUtil
1516
from .model_settings import ModelSettings
1617
from .models.interface import Model
1718
from .run_context import RunContextWrapper, TContext
@@ -21,6 +22,7 @@
2122

2223
if TYPE_CHECKING:
2324
from .lifecycle import AgentHooks
25+
from .mcp import MCPServer
2426
from .result import RunResult
2527

2628

@@ -107,6 +109,16 @@ class Agent(Generic[TContext]):
107109
tools: list[Tool] = field(default_factory=list)
108110
"""A list of tools that the agent can use."""
109111

112+
mcp_servers: list[MCPServer] = field(default_factory=list)
113+
"""A list of [Model Context Protocol](https://modelcontextprotocol.io/) servers that
114+
the agent can use. Every time the agent runs, it will include tools from these servers in the
115+
list of available tools.
116+
117+
NOTE: You are expected to manage the lifecycle of these servers. Specifically, you must call
118+
`server.connect()` before passing it to the agent, and `server.cleanup()` when the server is no
119+
longer needed.
120+
"""
121+
110122
input_guardrails: list[InputGuardrail[TContext]] = field(default_factory=list)
111123
"""A list of checks that run in parallel to the agent's execution, before generating a
112124
response. Runs only if the agent is the first agent in the chain.
@@ -205,3 +217,11 @@ async def get_system_prompt(self, run_context: RunContextWrapper[TContext]) -> s
205217
logger.error(f"Instructions must be a string or a function, got {self.instructions}")
206218

207219
return None
220+
221+
async def get_mcp_tools(self) -> list[Tool]:
222+
"""Fetches the available tools from the MCP servers."""
223+
return await MCPUtil.get_all_function_tools(self.mcp_servers)
224+
225+
async def get_all_tools(self) -> list[Tool]:
226+
"""All agent tools, including MCP tools and function tools."""
227+
return await MCPUtil.get_all_function_tools(self.mcp_servers) + self.tools

‎src/agents/run.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77

88
from openai.types.responses import ResponseCompletedEvent
99

10+
from agents.tool import Tool
11+
1012
from ._run_impl import (
1113
NextStepFinalOutput,
1214
NextStepHandoff,
@@ -177,7 +179,8 @@ async def run(
177179
# agent changes, or if the agent loop ends.
178180
if current_span is None:
179181
handoff_names = [h.agent_name for h in cls._get_handoffs(current_agent)]
180-
tool_names = [t.name for t in current_agent.tools]
182+
all_tools = await cls._get_all_tools(current_agent)
183+
tool_names = [t.name for t in all_tools]
181184
if output_schema := cls._get_output_schema(current_agent):
182185
output_type_name = output_schema.output_type_name()
183186
else:
@@ -217,6 +220,7 @@ async def run(
217220
),
218221
cls._run_single_turn(
219222
agent=current_agent,
223+
all_tools=all_tools,
220224
original_input=original_input,
221225
generated_items=generated_items,
222226
hooks=hooks,
@@ -228,6 +232,7 @@ async def run(
228232
else:
229233
turn_result = await cls._run_single_turn(
230234
agent=current_agent,
235+
all_tools=all_tools,
231236
original_input=original_input,
232237
generated_items=generated_items,
233238
hooks=hooks,
@@ -627,7 +632,7 @@ async def _run_single_turn_streamed(
627632
system_prompt = await agent.get_system_prompt(context_wrapper)
628633

629634
handoffs = cls._get_handoffs(agent)
630-
635+
all_tools = await cls._get_all_tools(agent)
631636
model = cls._get_model(agent, run_config)
632637
model_settings = agent.model_settings.resolve(run_config.model_settings)
633638
final_response: ModelResponse | None = None
@@ -640,7 +645,7 @@ async def _run_single_turn_streamed(
640645
system_prompt,
641646
input,
642647
model_settings,
643-
agent.tools,
648+
all_tools,
644649
output_schema,
645650
handoffs,
646651
get_model_tracing_impl(
@@ -677,6 +682,7 @@ async def _run_single_turn_streamed(
677682
pre_step_items=streamed_result.new_items,
678683
new_response=final_response,
679684
output_schema=output_schema,
685+
all_tools=all_tools,
680686
handoffs=handoffs,
681687
hooks=hooks,
682688
context_wrapper=context_wrapper,
@@ -691,6 +697,7 @@ async def _run_single_turn(
691697
cls,
692698
*,
693699
agent: Agent[TContext],
700+
all_tools: list[Tool],
694701
original_input: str | list[TResponseInputItem],
695702
generated_items: list[RunItem],
696703
hooks: RunHooks[TContext],
@@ -721,6 +728,7 @@ async def _run_single_turn(
721728
system_prompt,
722729
input,
723730
output_schema,
731+
all_tools,
724732
handoffs,
725733
context_wrapper,
726734
run_config,
@@ -732,6 +740,7 @@ async def _run_single_turn(
732740
pre_step_items=generated_items,
733741
new_response=new_response,
734742
output_schema=output_schema,
743+
all_tools=all_tools,
735744
handoffs=handoffs,
736745
hooks=hooks,
737746
context_wrapper=context_wrapper,
@@ -743,6 +752,7 @@ async def _get_single_step_result_from_response(
743752
cls,
744753
*,
745754
agent: Agent[TContext],
755+
all_tools: list[Tool],
746756
original_input: str | list[TResponseInputItem],
747757
pre_step_items: list[RunItem],
748758
new_response: ModelResponse,
@@ -754,6 +764,7 @@ async def _get_single_step_result_from_response(
754764
) -> SingleStepResult:
755765
processed_response = RunImpl.process_model_response(
756766
agent=agent,
767+
all_tools=all_tools,
757768
response=new_response,
758769
output_schema=output_schema,
759770
handoffs=handoffs,
@@ -853,6 +864,7 @@ async def _get_new_response(
853864
system_prompt: str | None,
854865
input: list[TResponseInputItem],
855866
output_schema: AgentOutputSchema | None,
867+
all_tools: list[Tool],
856868
handoffs: list[Handoff],
857869
context_wrapper: RunContextWrapper[TContext],
858870
run_config: RunConfig,
@@ -863,7 +875,7 @@ async def _get_new_response(
863875
system_instructions=system_prompt,
864876
input=input,
865877
model_settings=model_settings,
866-
tools=agent.tools,
878+
tools=all_tools,
867879
output_schema=output_schema,
868880
handoffs=handoffs,
869881
tracing=get_model_tracing_impl(
@@ -892,6 +904,10 @@ def _get_handoffs(cls, agent: Agent[Any]) -> list[Handoff]:
892904
handoffs.append(handoff(handoff_item))
893905
return handoffs
894906

907+
@classmethod
908+
async def _get_all_tools(cls, agent: Agent[Any]) -> list[Tool]:
909+
return await agent.get_all_tools()
910+
895911
@classmethod
896912
def _get_model(cls, agent: Agent[Any], run_config: RunConfig) -> Model:
897913
if isinstance(run_config.model, Model):

‎tests/mcp/__init__.py

Whitespace-only changes.

‎tests/mcp/conftest.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import os
2+
import sys
3+
4+
5+
# Skip MCP tests on Python 3.9
6+
def pytest_ignore_collect(collection_path, config):
7+
if sys.version_info[:2] == (3, 9):
8+
this_dir = os.path.dirname(__file__)
9+
10+
if str(collection_path).startswith(this_dir):
11+
return True

‎tests/mcp/helpers.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import json
2+
import shutil
3+
from typing import Any
4+
5+
from mcp import Tool as MCPTool
6+
from mcp.types import CallToolResult, TextContent
7+
8+
from agents.mcp import MCPServer
9+
10+
tee = shutil.which("tee") or ""
11+
assert tee, "tee not found"
12+
13+
14+
# Added dummy stream classes for patching stdio_client to avoid real I/O during tests
15+
class DummyStream:
16+
async def send(self, msg):
17+
pass
18+
19+
async def receive(self):
20+
raise Exception("Dummy receive not implemented")
21+
22+
23+
class DummyStreamsContextManager:
24+
async def __aenter__(self):
25+
return (DummyStream(), DummyStream())
26+
27+
async def __aexit__(self, exc_type, exc_val, exc_tb):
28+
pass
29+
30+
31+
class FakeMCPServer(MCPServer):
32+
def __init__(self, tools: list[MCPTool] | None = None):
33+
self.tools: list[MCPTool] = tools or []
34+
self.tool_calls: list[str] = []
35+
self.tool_results: list[str] = []
36+
37+
def add_tool(self, name: str, input_schema: dict[str, Any]):
38+
self.tools.append(MCPTool(name=name, inputSchema=input_schema))
39+
40+
async def connect(self):
41+
pass
42+
43+
async def cleanup(self):
44+
pass
45+
46+
async def list_tools(self):
47+
return self.tools
48+
49+
async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> CallToolResult:
50+
self.tool_calls.append(tool_name)
51+
self.tool_results.append(f"result_{tool_name}_{json.dumps(arguments)}")
52+
return CallToolResult(
53+
content=[TextContent(text=self.tool_results[-1], type="text")],
54+
)

‎tests/mcp/test_caching.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
from unittest.mock import AsyncMock, patch
2+
3+
import pytest
4+
from mcp.types import ListToolsResult, Tool as MCPTool
5+
6+
from agents.mcp import MCPServerStdio
7+
8+
from .helpers import DummyStreamsContextManager, tee
9+
10+
11+
@pytest.mark.asyncio
12+
@patch("mcp.client.stdio.stdio_client", return_value=DummyStreamsContextManager())
13+
@patch("mcp.client.session.ClientSession.initialize", new_callable=AsyncMock, return_value=None)
14+
@patch("mcp.client.session.ClientSession.list_tools")
15+
async def test_server_caching_works(
16+
mock_list_tools: AsyncMock, mock_initialize: AsyncMock, mock_stdio_client
17+
):
18+
"""Test that if we turn caching on, the list of tools is cached and not fetched from the server
19+
on each call to `list_tools()`.
20+
"""
21+
server = MCPServerStdio(
22+
params={
23+
"command": tee,
24+
},
25+
cache_tools_list=True,
26+
)
27+
28+
tools = [
29+
MCPTool(name="tool1", inputSchema={}),
30+
MCPTool(name="tool2", inputSchema={}),
31+
]
32+
33+
mock_list_tools.return_value = ListToolsResult(tools=tools)
34+
35+
async with server:
36+
# Call list_tools() multiple times
37+
tools = await server.list_tools()
38+
assert tools == tools
39+
40+
assert mock_list_tools.call_count == 1, "list_tools() should have been called once"
41+
42+
# Call list_tools() again, should return the cached value
43+
tools = await server.list_tools()
44+
assert tools == tools
45+
46+
assert mock_list_tools.call_count == 1, "list_tools() should not have been called again"
47+
48+
# Invalidate the cache and call list_tools() again
49+
server.invalidate_tools_cache()
50+
tools = await server.list_tools()
51+
assert tools == tools
52+
53+
assert mock_list_tools.call_count == 2, "list_tools() should be called again"
54+
55+
# Without invalidating the cache, calling list_tools() again should return the cached value
56+
tools = await server.list_tools()
57+
assert tools == tools

‎tests/mcp/test_connect_disconnect.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
from unittest.mock import AsyncMock, patch
2+
3+
import pytest
4+
from mcp.types import ListToolsResult, Tool as MCPTool
5+
6+
from agents.mcp import MCPServerStdio
7+
8+
from .helpers import DummyStreamsContextManager, tee
9+
10+
11+
@pytest.mark.asyncio
12+
@patch("mcp.client.stdio.stdio_client", return_value=DummyStreamsContextManager())
13+
@patch("mcp.client.session.ClientSession.initialize", new_callable=AsyncMock, return_value=None)
14+
@patch("mcp.client.session.ClientSession.list_tools")
15+
async def test_async_ctx_manager_works(
16+
mock_list_tools: AsyncMock, mock_initialize: AsyncMock, mock_stdio_client
17+
):
18+
"""Test that the async context manager works."""
19+
server = MCPServerStdio(
20+
params={
21+
"command": tee,
22+
},
23+
cache_tools_list=True,
24+
)
25+
26+
tools = [
27+
MCPTool(name="tool1", inputSchema={}),
28+
MCPTool(name="tool2", inputSchema={}),
29+
]
30+
31+
mock_list_tools.return_value = ListToolsResult(tools=tools)
32+
33+
assert server.session is None, "Server should not be connected"
34+
35+
async with server:
36+
assert server.session is not None, "Server should be connected"
37+
38+
assert server.session is None, "Server should be disconnected"
39+
40+
41+
@pytest.mark.asyncio
42+
@patch("mcp.client.stdio.stdio_client", return_value=DummyStreamsContextManager())
43+
@patch("mcp.client.session.ClientSession.initialize", new_callable=AsyncMock, return_value=None)
44+
@patch("mcp.client.session.ClientSession.list_tools")
45+
async def test_manual_connect_disconnect_works(
46+
mock_list_tools: AsyncMock, mock_initialize: AsyncMock, mock_stdio_client
47+
):
48+
"""Test that the async context manager works."""
49+
server = MCPServerStdio(
50+
params={
51+
"command": tee,
52+
},
53+
cache_tools_list=True,
54+
)
55+
56+
tools = [
57+
MCPTool(name="tool1", inputSchema={}),
58+
MCPTool(name="tool2", inputSchema={}),
59+
]
60+
61+
mock_list_tools.return_value = ListToolsResult(tools=tools)
62+
63+
assert server.session is None, "Server should not be connected"
64+
65+
await server.connect()
66+
assert server.session is not None, "Server should be connected"
67+
68+
await server.cleanup()
69+
assert server.session is None, "Server should be disconnected"

‎tests/mcp/test_mcp_util.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
import logging
2+
from typing import Any
3+
4+
import pytest
5+
from mcp.types import Tool as MCPTool
6+
from pydantic import BaseModel
7+
8+
from agents import FunctionTool, RunContextWrapper
9+
from agents.exceptions import AgentsException, ModelBehaviorError
10+
from agents.mcp import MCPServer, MCPUtil
11+
12+
from .helpers import FakeMCPServer
13+
14+
15+
class Foo(BaseModel):
16+
bar: str
17+
baz: int
18+
19+
20+
class Bar(BaseModel):
21+
qux: str
22+
23+
24+
@pytest.mark.asyncio
25+
async def test_get_all_function_tools():
26+
"""Test that the get_all_function_tools function returns all function tools from a list of MCP
27+
servers.
28+
"""
29+
names = ["test_tool_1", "test_tool_2", "test_tool_3", "test_tool_4", "test_tool_5"]
30+
schemas = [
31+
{},
32+
{},
33+
{},
34+
Foo.model_json_schema(),
35+
Bar.model_json_schema(),
36+
]
37+
38+
server1 = FakeMCPServer()
39+
server1.add_tool(names[0], schemas[0])
40+
server1.add_tool(names[1], schemas[1])
41+
42+
server2 = FakeMCPServer()
43+
server2.add_tool(names[2], schemas[2])
44+
server2.add_tool(names[3], schemas[3])
45+
46+
server3 = FakeMCPServer()
47+
server3.add_tool(names[4], schemas[4])
48+
49+
servers: list[MCPServer] = [server1, server2, server3]
50+
tools = await MCPUtil.get_all_function_tools(servers)
51+
assert len(tools) == 5
52+
assert all(tool.name in names for tool in tools)
53+
54+
for idx, tool in enumerate(tools):
55+
assert isinstance(tool, FunctionTool)
56+
assert tool.params_json_schema == schemas[idx]
57+
assert tool.name == names[idx]
58+
59+
60+
@pytest.mark.asyncio
61+
async def test_invoke_mcp_tool():
62+
"""Test that the invoke_mcp_tool function invokes an MCP tool and returns the result."""
63+
server = FakeMCPServer()
64+
server.add_tool("test_tool_1", {})
65+
66+
ctx = RunContextWrapper(context=None)
67+
tool = MCPTool(name="test_tool_1", inputSchema={})
68+
69+
await MCPUtil.invoke_mcp_tool(server, tool, ctx, "")
70+
# Just making sure it doesn't crash
71+
72+
73+
@pytest.mark.asyncio
74+
async def test_mcp_invoke_bad_json_errors(caplog: pytest.LogCaptureFixture):
75+
caplog.set_level(logging.DEBUG)
76+
77+
"""Test that bad JSON input errors are logged and re-raised."""
78+
server = FakeMCPServer()
79+
server.add_tool("test_tool_1", {})
80+
81+
ctx = RunContextWrapper(context=None)
82+
tool = MCPTool(name="test_tool_1", inputSchema={})
83+
84+
with pytest.raises(ModelBehaviorError):
85+
await MCPUtil.invoke_mcp_tool(server, tool, ctx, "not_json")
86+
87+
assert "Invalid JSON input for tool test_tool_1" in caplog.text
88+
89+
90+
class CrashingFakeMCPServer(FakeMCPServer):
91+
async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None):
92+
raise Exception("Crash!")
93+
94+
95+
@pytest.mark.asyncio
96+
async def test_mcp_invocation_crash_causes_error(caplog: pytest.LogCaptureFixture):
97+
caplog.set_level(logging.DEBUG)
98+
99+
"""Test that bad JSON input errors are logged and re-raised."""
100+
server = CrashingFakeMCPServer()
101+
server.add_tool("test_tool_1", {})
102+
103+
ctx = RunContextWrapper(context=None)
104+
tool = MCPTool(name="test_tool_1", inputSchema={})
105+
106+
with pytest.raises(AgentsException):
107+
await MCPUtil.invoke_mcp_tool(server, tool, ctx, "")
108+
109+
assert "Error invoking MCP tool test_tool_1" in caplog.text

‎tests/mcp/test_runner_calls_mcp.py

Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
import json
2+
3+
import pytest
4+
from pydantic import BaseModel
5+
6+
from agents import Agent, ModelBehaviorError, Runner, UserError
7+
8+
from ..fake_model import FakeModel
9+
from ..test_responses import get_function_tool_call, get_text_message
10+
from .helpers import FakeMCPServer
11+
12+
13+
@pytest.mark.asyncio
14+
@pytest.mark.parametrize("streaming", [False, True])
15+
async def test_runner_calls_mcp_tool(streaming: bool):
16+
"""Test that the runner calls an MCP tool when the model produces a tool call."""
17+
server = FakeMCPServer()
18+
server.add_tool("test_tool_1", {})
19+
server.add_tool("test_tool_2", {})
20+
server.add_tool("test_tool_3", {})
21+
model = FakeModel()
22+
agent = Agent(
23+
name="test",
24+
model=model,
25+
mcp_servers=[server],
26+
)
27+
28+
model.add_multiple_turn_outputs(
29+
[
30+
# First turn: a message and tool call
31+
[get_text_message("a_message"), get_function_tool_call("test_tool_2", "")],
32+
# Second turn: text message
33+
[get_text_message("done")],
34+
]
35+
)
36+
37+
if streaming:
38+
result = Runner.run_streamed(agent, input="user_message")
39+
async for _ in result.stream_events():
40+
pass
41+
else:
42+
await Runner.run(agent, input="user_message")
43+
44+
assert server.tool_calls == ["test_tool_2"]
45+
46+
47+
@pytest.mark.asyncio
48+
@pytest.mark.parametrize("streaming", [False, True])
49+
async def test_runner_asserts_when_mcp_tool_not_found(streaming: bool):
50+
"""Test that the runner asserts when an MCP tool is not found."""
51+
server = FakeMCPServer()
52+
server.add_tool("test_tool_1", {})
53+
server.add_tool("test_tool_2", {})
54+
server.add_tool("test_tool_3", {})
55+
model = FakeModel()
56+
agent = Agent(
57+
name="test",
58+
model=model,
59+
mcp_servers=[server],
60+
)
61+
62+
model.add_multiple_turn_outputs(
63+
[
64+
# First turn: a message and tool call
65+
[get_text_message("a_message"), get_function_tool_call("test_tool_doesnt_exist", "")],
66+
# Second turn: text message
67+
[get_text_message("done")],
68+
]
69+
)
70+
71+
with pytest.raises(ModelBehaviorError):
72+
if streaming:
73+
result = Runner.run_streamed(agent, input="user_message")
74+
async for _ in result.stream_events():
75+
pass
76+
else:
77+
await Runner.run(agent, input="user_message")
78+
79+
80+
@pytest.mark.asyncio
81+
@pytest.mark.parametrize("streaming", [False, True])
82+
async def test_runner_works_with_multiple_mcp_servers(streaming: bool):
83+
"""Test that the runner works with multiple MCP servers."""
84+
server1 = FakeMCPServer()
85+
server1.add_tool("test_tool_1", {})
86+
87+
server2 = FakeMCPServer()
88+
server2.add_tool("test_tool_2", {})
89+
server2.add_tool("test_tool_3", {})
90+
91+
model = FakeModel()
92+
agent = Agent(
93+
name="test",
94+
model=model,
95+
mcp_servers=[server1, server2],
96+
)
97+
98+
model.add_multiple_turn_outputs(
99+
[
100+
# First turn: a message and tool call
101+
[get_text_message("a_message"), get_function_tool_call("test_tool_2", "")],
102+
# Second turn: text message
103+
[get_text_message("done")],
104+
]
105+
)
106+
107+
if streaming:
108+
result = Runner.run_streamed(agent, input="user_message")
109+
async for _ in result.stream_events():
110+
pass
111+
else:
112+
await Runner.run(agent, input="user_message")
113+
114+
assert server1.tool_calls == []
115+
assert server2.tool_calls == ["test_tool_2"]
116+
117+
118+
@pytest.mark.asyncio
119+
@pytest.mark.parametrize("streaming", [False, True])
120+
async def test_runner_errors_when_mcp_tools_clash(streaming: bool):
121+
"""Test that the runner errors when multiple servers have the same tool name."""
122+
server1 = FakeMCPServer()
123+
server1.add_tool("test_tool_1", {})
124+
server1.add_tool("test_tool_2", {})
125+
126+
server2 = FakeMCPServer()
127+
server2.add_tool("test_tool_2", {})
128+
server2.add_tool("test_tool_3", {})
129+
130+
model = FakeModel()
131+
agent = Agent(
132+
name="test",
133+
model=model,
134+
mcp_servers=[server1, server2],
135+
)
136+
137+
model.add_multiple_turn_outputs(
138+
[
139+
# First turn: a message and tool call
140+
[get_text_message("a_message"), get_function_tool_call("test_tool_3", "")],
141+
# Second turn: text message
142+
[get_text_message("done")],
143+
]
144+
)
145+
146+
with pytest.raises(UserError):
147+
if streaming:
148+
result = Runner.run_streamed(agent, input="user_message")
149+
async for _ in result.stream_events():
150+
pass
151+
else:
152+
await Runner.run(agent, input="user_message")
153+
154+
155+
class Foo(BaseModel):
156+
bar: str
157+
baz: int
158+
159+
160+
@pytest.mark.asyncio
161+
@pytest.mark.parametrize("streaming", [False, True])
162+
async def test_runner_calls_mcp_tool_with_args(streaming: bool):
163+
"""Test that the runner calls an MCP tool when the model produces a tool call."""
164+
server = FakeMCPServer()
165+
await server.connect()
166+
server.add_tool("test_tool_1", {})
167+
server.add_tool("test_tool_2", Foo.model_json_schema())
168+
server.add_tool("test_tool_3", {})
169+
model = FakeModel()
170+
agent = Agent(
171+
name="test",
172+
model=model,
173+
mcp_servers=[server],
174+
)
175+
176+
json_args = json.dumps(Foo(bar="baz", baz=1).model_dump())
177+
178+
model.add_multiple_turn_outputs(
179+
[
180+
# First turn: a message and tool call
181+
[get_text_message("a_message"), get_function_tool_call("test_tool_2", json_args)],
182+
# Second turn: text message
183+
[get_text_message("done")],
184+
]
185+
)
186+
187+
if streaming:
188+
result = Runner.run_streamed(agent, input="user_message")
189+
async for _ in result.stream_events():
190+
pass
191+
else:
192+
await Runner.run(agent, input="user_message")
193+
194+
assert server.tool_calls == ["test_tool_2"]
195+
assert server.tool_results == [f"result_test_tool_2_{json_args}"]
196+
197+
await server.cleanup()

‎tests/mcp/test_server_errors.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import pytest
2+
3+
from agents.exceptions import UserError
4+
from agents.mcp.server import _MCPServerWithClientSession
5+
6+
7+
class CrashingClientSessionServer(_MCPServerWithClientSession):
8+
def __init__(self):
9+
super().__init__(cache_tools_list=False)
10+
self.cleanup_called = False
11+
12+
def create_streams(self):
13+
raise ValueError("Crash!")
14+
15+
async def cleanup(self):
16+
self.cleanup_called = True
17+
await super().cleanup()
18+
19+
20+
@pytest.mark.asyncio
21+
async def test_server_errors_cause_error_and_cleanup_called():
22+
server = CrashingClientSessionServer()
23+
24+
with pytest.raises(ValueError):
25+
await server.connect()
26+
27+
assert server.cleanup_called
28+
29+
30+
@pytest.mark.asyncio
31+
async def test_not_calling_connect_causes_error():
32+
server = CrashingClientSessionServer()
33+
34+
with pytest.raises(UserError):
35+
await server.list_tools()
36+
37+
with pytest.raises(UserError):
38+
await server.call_tool("foo", {})

‎tests/test_run_step_execution.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,7 @@ async def get_execute_result(
290290

291291
processed_response = RunImpl.process_model_response(
292292
agent=agent,
293+
all_tools=await agent.get_all_tools(),
293294
response=response,
294295
output_schema=output_schema,
295296
handoffs=handoffs,

‎tests/test_run_step_processing.py

Lines changed: 78 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,11 @@ def test_empty_response():
4343
)
4444

4545
result = RunImpl.process_model_response(
46-
agent=agent, response=response, output_schema=None, handoffs=[]
46+
agent=agent,
47+
response=response,
48+
output_schema=None,
49+
handoffs=[],
50+
all_tools=[],
4751
)
4852
assert not result.handoffs
4953
assert not result.functions
@@ -57,13 +61,14 @@ def test_no_tool_calls():
5761
referenceable_id=None,
5862
)
5963
result = RunImpl.process_model_response(
60-
agent=agent, response=response, output_schema=None, handoffs=[]
64+
agent=agent, response=response, output_schema=None, handoffs=[], all_tools=[]
6165
)
6266
assert not result.handoffs
6367
assert not result.functions
6468

6569

66-
def test_single_tool_call():
70+
@pytest.mark.asyncio
71+
async def test_single_tool_call():
6772
agent = Agent(name="test", tools=[get_function_tool(name="test")])
6873
response = ModelResponse(
6974
output=[
@@ -74,7 +79,11 @@ def test_single_tool_call():
7479
referenceable_id=None,
7580
)
7681
result = RunImpl.process_model_response(
77-
agent=agent, response=response, output_schema=None, handoffs=[]
82+
agent=agent,
83+
response=response,
84+
output_schema=None,
85+
handoffs=[],
86+
all_tools=await agent.get_all_tools(),
7887
)
7988
assert not result.handoffs
8089
assert result.functions and len(result.functions) == 1
@@ -84,7 +93,8 @@ def test_single_tool_call():
8493
assert func.tool_call.arguments == ""
8594

8695

87-
def test_missing_tool_call_raises_error():
96+
@pytest.mark.asyncio
97+
async def test_missing_tool_call_raises_error():
8898
agent = Agent(name="test", tools=[get_function_tool(name="test")])
8999
response = ModelResponse(
90100
output=[
@@ -97,11 +107,16 @@ def test_missing_tool_call_raises_error():
97107

98108
with pytest.raises(ModelBehaviorError):
99109
RunImpl.process_model_response(
100-
agent=agent, response=response, output_schema=None, handoffs=[]
110+
agent=agent,
111+
response=response,
112+
output_schema=None,
113+
handoffs=[],
114+
all_tools=await agent.get_all_tools(),
101115
)
102116

103117

104-
def test_multiple_tool_calls():
118+
@pytest.mark.asyncio
119+
async def test_multiple_tool_calls():
105120
agent = Agent(
106121
name="test",
107122
tools=[
@@ -121,7 +136,11 @@ def test_multiple_tool_calls():
121136
)
122137

123138
result = RunImpl.process_model_response(
124-
agent=agent, response=response, output_schema=None, handoffs=[]
139+
agent=agent,
140+
response=response,
141+
output_schema=None,
142+
handoffs=[],
143+
all_tools=await agent.get_all_tools(),
125144
)
126145
assert not result.handoffs
127146
assert result.functions and len(result.functions) == 2
@@ -146,7 +165,11 @@ async def test_handoffs_parsed_correctly():
146165
referenceable_id=None,
147166
)
148167
result = RunImpl.process_model_response(
149-
agent=agent_3, response=response, output_schema=None, handoffs=[]
168+
agent=agent_3,
169+
response=response,
170+
output_schema=None,
171+
handoffs=[],
172+
all_tools=await agent_3.get_all_tools(),
150173
)
151174
assert not result.handoffs, "Shouldn't have a handoff here"
152175

@@ -160,6 +183,7 @@ async def test_handoffs_parsed_correctly():
160183
response=response,
161184
output_schema=None,
162185
handoffs=Runner._get_handoffs(agent_3),
186+
all_tools=await agent_3.get_all_tools(),
163187
)
164188
assert len(result.handoffs) == 1, "Should have a handoff here"
165189
handoff = result.handoffs[0]
@@ -189,10 +213,12 @@ async def test_missing_handoff_fails():
189213
response=response,
190214
output_schema=None,
191215
handoffs=Runner._get_handoffs(agent_3),
216+
all_tools=await agent_3.get_all_tools(),
192217
)
193218

194219

195-
def test_multiple_handoffs_doesnt_error():
220+
@pytest.mark.asyncio
221+
async def test_multiple_handoffs_doesnt_error():
196222
agent_1 = Agent(name="test_1")
197223
agent_2 = Agent(name="test_2")
198224
agent_3 = Agent(name="test_3", handoffs=[agent_1, agent_2])
@@ -210,6 +236,7 @@ def test_multiple_handoffs_doesnt_error():
210236
response=response,
211237
output_schema=None,
212238
handoffs=Runner._get_handoffs(agent_3),
239+
all_tools=await agent_3.get_all_tools(),
213240
)
214241
assert len(result.handoffs) == 2, "Should have multiple handoffs here"
215242

@@ -218,7 +245,8 @@ class Foo(BaseModel):
218245
bar: str
219246

220247

221-
def test_final_output_parsed_correctly():
248+
@pytest.mark.asyncio
249+
async def test_final_output_parsed_correctly():
222250
agent = Agent(name="test", output_type=Foo)
223251
response = ModelResponse(
224252
output=[
@@ -234,10 +262,12 @@ def test_final_output_parsed_correctly():
234262
response=response,
235263
output_schema=Runner._get_output_schema(agent),
236264
handoffs=[],
265+
all_tools=await agent.get_all_tools(),
237266
)
238267

239268

240-
def test_file_search_tool_call_parsed_correctly():
269+
@pytest.mark.asyncio
270+
async def test_file_search_tool_call_parsed_correctly():
241271
# Ensure that a ResponseFileSearchToolCall output is parsed into a ToolCallItem and that no tool
242272
# runs are scheduled.
243273

@@ -254,7 +284,11 @@ def test_file_search_tool_call_parsed_correctly():
254284
referenceable_id=None,
255285
)
256286
result = RunImpl.process_model_response(
257-
agent=agent, response=response, output_schema=None, handoffs=[]
287+
agent=agent,
288+
response=response,
289+
output_schema=None,
290+
handoffs=[],
291+
all_tools=await agent.get_all_tools(),
258292
)
259293
# The final item should be a ToolCallItem for the file search call
260294
assert any(
@@ -265,7 +299,8 @@ def test_file_search_tool_call_parsed_correctly():
265299
assert not result.handoffs
266300

267301

268-
def test_function_web_search_tool_call_parsed_correctly():
302+
@pytest.mark.asyncio
303+
async def test_function_web_search_tool_call_parsed_correctly():
269304
agent = Agent(name="test")
270305
web_search_call = ResponseFunctionWebSearch(id="w1", status="completed", type="web_search_call")
271306
response = ModelResponse(
@@ -274,7 +309,11 @@ def test_function_web_search_tool_call_parsed_correctly():
274309
referenceable_id=None,
275310
)
276311
result = RunImpl.process_model_response(
277-
agent=agent, response=response, output_schema=None, handoffs=[]
312+
agent=agent,
313+
response=response,
314+
output_schema=None,
315+
handoffs=[],
316+
all_tools=await agent.get_all_tools(),
278317
)
279318
assert any(
280319
isinstance(item, ToolCallItem) and item.raw_item is web_search_call
@@ -284,7 +323,8 @@ def test_function_web_search_tool_call_parsed_correctly():
284323
assert not result.handoffs
285324

286325

287-
def test_reasoning_item_parsed_correctly():
326+
@pytest.mark.asyncio
327+
async def test_reasoning_item_parsed_correctly():
288328
# Verify that a Reasoning output item is converted into a ReasoningItem.
289329

290330
reasoning = ResponseReasoningItem(
@@ -296,7 +336,11 @@ def test_reasoning_item_parsed_correctly():
296336
referenceable_id=None,
297337
)
298338
result = RunImpl.process_model_response(
299-
agent=Agent(name="test"), response=response, output_schema=None, handoffs=[]
339+
agent=Agent(name="test"),
340+
response=response,
341+
output_schema=None,
342+
handoffs=[],
343+
all_tools=await Agent(name="test").get_all_tools(),
300344
)
301345
assert any(
302346
isinstance(item, ReasoningItem) and item.raw_item is reasoning for item in result.new_items
@@ -342,7 +386,8 @@ def drag(self, path: list[tuple[int, int]]) -> None:
342386
return None # pragma: no cover
343387

344388

345-
def test_computer_tool_call_without_computer_tool_raises_error():
389+
@pytest.mark.asyncio
390+
async def test_computer_tool_call_without_computer_tool_raises_error():
346391
# If the agent has no ComputerTool in its tools, process_model_response should raise a
347392
# ModelBehaviorError when encountering a ResponseComputerToolCall.
348393
computer_call = ResponseComputerToolCall(
@@ -360,11 +405,16 @@ def test_computer_tool_call_without_computer_tool_raises_error():
360405
)
361406
with pytest.raises(ModelBehaviorError):
362407
RunImpl.process_model_response(
363-
agent=Agent(name="test"), response=response, output_schema=None, handoffs=[]
408+
agent=Agent(name="test"),
409+
response=response,
410+
output_schema=None,
411+
handoffs=[],
412+
all_tools=await Agent(name="test").get_all_tools(),
364413
)
365414

366415

367-
def test_computer_tool_call_with_computer_tool_parsed_correctly():
416+
@pytest.mark.asyncio
417+
async def test_computer_tool_call_with_computer_tool_parsed_correctly():
368418
# If the agent contains a ComputerTool, ensure that a ResponseComputerToolCall is parsed into a
369419
# ToolCallItem and scheduled to run in computer_actions.
370420
dummy_computer = DummyComputer()
@@ -383,7 +433,11 @@ def test_computer_tool_call_with_computer_tool_parsed_correctly():
383433
referenceable_id=None,
384434
)
385435
result = RunImpl.process_model_response(
386-
agent=agent, response=response, output_schema=None, handoffs=[]
436+
agent=agent,
437+
response=response,
438+
output_schema=None,
439+
handoffs=[],
440+
all_tools=await agent.get_all_tools(),
387441
)
388442
assert any(
389443
isinstance(item, ToolCallItem) and item.raw_item is computer_call
@@ -392,7 +446,8 @@ def test_computer_tool_call_with_computer_tool_parsed_correctly():
392446
assert result.computer_actions and result.computer_actions[0].tool_call == computer_call
393447

394448

395-
def test_tool_and_handoff_parsed_correctly():
449+
@pytest.mark.asyncio
450+
async def test_tool_and_handoff_parsed_correctly():
396451
agent_1 = Agent(name="test_1")
397452
agent_2 = Agent(name="test_2")
398453
agent_3 = Agent(
@@ -413,6 +468,7 @@ def test_tool_and_handoff_parsed_correctly():
413468
response=response,
414469
output_schema=None,
415470
handoffs=Runner._get_handoffs(agent_3),
471+
all_tools=await agent_3.get_all_tools(),
416472
)
417473
assert result.functions and len(result.functions) == 1
418474
assert len(result.handoffs) == 1, "Should have a handoff here"

‎tests/voice/test_openai_stt.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ def fake_time_func():
269269
async for _ in turns:
270270
pass
271271

272-
assert "Timeout waiting for transcription_session.created event" in str(exc_info.value)
272+
assert "Timeout waiting for transcription_session.created event" in str(exc_info.value)
273273

274274
await session.close()
275275

@@ -302,13 +302,11 @@ async def test_session_error_event():
302302
trace_include_sensitive_audio_data=False,
303303
)
304304

305-
with pytest.raises(STTWebsocketConnectionError) as exc_info:
305+
with pytest.raises(STTWebsocketConnectionError):
306306
turns = session.transcribe_turns()
307307
async for _ in turns:
308308
pass
309309

310-
assert "Simulated server error!" in str(exc_info.value)
311-
312310
await session.close()
313311

314312

@@ -362,8 +360,8 @@ async def test_inactivity_timeout():
362360
async for turn in session.transcribe_turns():
363361
collected_turns.append(turn)
364362

365-
assert "Timeout waiting for transcription_session" in str(exc_info.value)
363+
assert "Timeout waiting for transcription_session" in str(exc_info.value)
366364

367-
assert len(collected_turns) == 0, "No transcripts expected, but we got something?"
365+
assert len(collected_turns) == 0, "No transcripts expected, but we got something?"
368366

369367
await session.close()

0 commit comments

Comments
 (0)
Please sign in to comment.