Skip to content

Commit ac9d1e2

Browse files
authored
feat: update react with new agent api (#189)
# What does this PR do? - Upgrade client side ReAct lib with new API - Resolves issue with tool_prompt_format - Keeps backward compat [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) ## Test Plan ``` python -m examples.agents.react_agent localhost 8321 ``` [//]: # (## Documentation) [//]: # (- [ ] Added a Changelog entry if the change is significant)
1 parent 06db448 commit ac9d1e2

2 files changed

Lines changed: 188 additions & 65 deletions

File tree

src/llama_stack_client/lib/agents/agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from llama_stack_client.types.shared_params.response_format import ResponseFormat
1818
from llama_stack_client.types.shared_params.sampling_params import SamplingParams
1919

20-
from .client_tool import ClientTool, client_tool
20+
from .client_tool import client_tool, ClientTool
2121
from .tool_parser import ToolParser
2222

2323
DEFAULT_MAX_ITER = 10

src/llama_stack_client/lib/agents/react/agent.py

Lines changed: 187 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,105 @@
33
#
44
# This source code is licensed under the terms described in the LICENSE file in
55
# the root directory of this source tree.
6-
from typing import Optional, Tuple
6+
import logging
7+
from typing import Any, Callable, List, Optional, Tuple, Union
78

89
from llama_stack_client import LlamaStackClient
910
from llama_stack_client.types.agent_create_params import AgentConfig
11+
from llama_stack_client.types.agents.turn_create_params import Toolgroup
12+
from llama_stack_client.types.shared_params.agent_config import ToolConfig
13+
from llama_stack_client.types.shared_params.response_format import ResponseFormat
14+
from llama_stack_client.types.shared_params.sampling_params import SamplingParams
1015

11-
12-
from ..agent import Agent
16+
from ..agent import Agent, AgentUtils
1317
from ..client_tool import ClientTool
1418
from ..tool_parser import ToolParser
1519
from .prompts import DEFAULT_REACT_AGENT_SYSTEM_PROMPT_TEMPLATE
20+
from .tool_parser import ReActOutput, ReActToolParser
21+
22+
logger = logging.getLogger(__name__)
23+
24+
25+
def get_tool_defs(
26+
client: LlamaStackClient, builtin_toolgroups: Tuple[Toolgroup] = (), client_tools: Tuple[ClientTool] = ()
27+
):
28+
tool_defs = []
29+
for x in builtin_toolgroups:
30+
if isinstance(x, str):
31+
toolgroup_id = x
32+
else:
33+
toolgroup_id = x["name"]
34+
tool_defs.extend(
35+
[
36+
{
37+
"name": tool.identifier,
38+
"description": tool.description,
39+
"parameters": tool.parameters,
40+
}
41+
for tool in client.tools.list(toolgroup_id=toolgroup_id)
42+
]
43+
)
44+
45+
tool_defs.extend(
46+
[
47+
{
48+
"name": tool.get_name(),
49+
"description": tool.get_description(),
50+
"parameters": tool.get_params_definition(),
51+
}
52+
for tool in client_tools
53+
]
54+
)
55+
return tool_defs
56+
57+
58+
def get_default_react_instructions(
59+
client: LlamaStackClient, builtin_toolgroups: Tuple[str] = (), client_tools: Tuple[ClientTool] = ()
60+
):
61+
tool_defs = get_tool_defs(client, builtin_toolgroups, client_tools)
62+
tool_names = ", ".join([x["name"] for x in tool_defs])
63+
tool_descriptions = "\n".join([f"- {x['name']}: {x}" for x in tool_defs])
64+
instruction = DEFAULT_REACT_AGENT_SYSTEM_PROMPT_TEMPLATE.replace("<<tool_names>>", tool_names).replace(
65+
"<<tool_descriptions>>", tool_descriptions
66+
)
67+
return instruction
68+
69+
70+
def get_agent_config_DEPRECATED(
71+
client: LlamaStackClient,
72+
model: str,
73+
builtin_toolgroups: Tuple[str] = (),
74+
client_tools: Tuple[ClientTool] = (),
75+
json_response_format: bool = False,
76+
custom_agent_config: Optional[AgentConfig] = None,
77+
) -> AgentConfig:
78+
if custom_agent_config is None:
79+
instruction = get_default_react_instructions(client, builtin_toolgroups, client_tools)
80+
81+
# user default toolgroups
82+
agent_config = AgentConfig(
83+
model=model,
84+
instructions=instruction,
85+
toolgroups=builtin_toolgroups,
86+
client_tools=[client_tool.get_tool_definition() for client_tool in client_tools],
87+
tool_config={
88+
"tool_choice": "auto",
89+
"system_message_behavior": "replace",
90+
},
91+
input_shields=[],
92+
output_shields=[],
93+
enable_session_persistence=False,
94+
)
95+
else:
96+
agent_config = custom_agent_config
97+
98+
if json_response_format:
99+
agent_config["response_format"] = {
100+
"type": "json_schema",
101+
"json_schema": ReActOutput.model_json_schema(),
102+
}
16103

17-
from .tool_parser import ReActToolParser, ReActOutput
104+
return agent_config
18105

19106

20107
class ReActAgent(Agent):
@@ -27,73 +114,109 @@ def __init__(
27114
self,
28115
client: LlamaStackClient,
29116
model: str,
30-
builtin_toolgroups: Tuple[str] = (),
31-
client_tools: Tuple[ClientTool] = (),
32117
tool_parser: ToolParser = ReActToolParser(),
118+
instructions: Optional[str] = None,
119+
tools: Optional[List[Union[Toolgroup, ClientTool, Callable[..., Any]]]] = None,
120+
tool_config: Optional[ToolConfig] = None,
121+
sampling_params: Optional[SamplingParams] = None,
122+
max_infer_iters: Optional[int] = None,
123+
input_shields: Optional[List[str]] = None,
124+
output_shields: Optional[List[str]] = None,
125+
response_format: Optional[ResponseFormat] = None,
126+
enable_session_persistence: Optional[bool] = None,
33127
json_response_format: bool = False,
34-
custom_agent_config: Optional[AgentConfig] = None,
128+
builtin_toolgroups: Tuple[str] = (), # DEPRECATED
129+
client_tools: Tuple[ClientTool] = (), # DEPRECATED
130+
custom_agent_config: Optional[AgentConfig] = None, # DEPRECATED
35131
):
36-
def get_tool_defs():
37-
tool_defs = []
38-
for x in builtin_toolgroups:
39-
tool_defs.extend(
40-
[
41-
{
42-
"name": tool.identifier,
43-
"description": tool.description,
44-
"parameters": tool.parameters,
45-
}
46-
for tool in client.tools.list(toolgroup_id=x)
47-
]
48-
)
49-
tool_defs.extend(
50-
[
51-
{
52-
"name": tool.get_name(),
53-
"description": tool.get_description(),
54-
"parameters": tool.get_params_definition(),
55-
}
56-
for tool in client_tools
57-
]
132+
"""Construct an Agent with the given parameters.
133+
134+
:param client: The LlamaStackClient instance.
135+
:param custom_agent_config: The AgentConfig instance.
136+
::deprecated: use other parameters instead
137+
:param client_tools: A tuple of ClientTool instances.
138+
::deprecated: use tools instead
139+
:param builtin_toolgroups: A tuple of Toolgroup instances.
140+
::deprecated: use tools instead
141+
:param tool_parser: Custom logic that parses tool calls from a message.
142+
:param model: The model to use for the agent.
143+
:param instructions: The instructions for the agent.
144+
:param tools: A list of tools for the agent. Values can be one of the following:
145+
- dict representing a toolgroup/tool with arguments: e.g. {"name": "builtin::rag/knowledge_search", "args": {"vector_db_ids": [123]}}
146+
- a python function with a docstring. See @client_tool for more details.
147+
- str representing a tool within a toolgroup: e.g. "builtin::rag/knowledge_search"
148+
- str representing a toolgroup_id: e.g. "builtin::rag", "builtin::code_interpreter", where all tools in the toolgroup will be added to the agent
149+
- an instance of ClientTool: A client tool object.
150+
:param tool_config: The tool configuration for the agent.
151+
:param sampling_params: The sampling parameters for the agent.
152+
:param max_infer_iters: The maximum number of inference iterations.
153+
:param input_shields: The input shields for the agent.
154+
:param output_shields: The output shields for the agent.
155+
:param response_format: The response format for the agent.
156+
:param enable_session_persistence: Whether to enable session persistence.
157+
:param json_response_format: Whether to use the json response format with default ReAct output schema.
158+
::deprecated: use response_format instead
159+
"""
160+
use_deprecated_params = False
161+
if custom_agent_config is not None:
162+
logger.warning("`custom_agent_config` is deprecated. Use inlined parameters instead.")
163+
use_deprecated_params = True
164+
if client_tools != ():
165+
logger.warning("`client_tools` is deprecated. Use `tools` instead.")
166+
use_deprecated_params = True
167+
if builtin_toolgroups != ():
168+
logger.warning("`builtin_toolgroups` is deprecated. Use `tools` instead.")
169+
use_deprecated_params = True
170+
171+
if use_deprecated_params:
172+
agent_config = get_agent_config_DEPRECATED(
173+
client=client,
174+
model=model,
175+
builtin_toolgroups=builtin_toolgroups,
176+
client_tools=client_tools,
177+
json_response_format=json_response_format,
58178
)
59-
return tool_defs
60-
61-
if custom_agent_config is None:
62-
tool_names, tool_descriptions = "", ""
63-
tool_defs = get_tool_defs()
64-
tool_names = ", ".join([x["name"] for x in tool_defs])
65-
tool_descriptions = "\n".join([f"- {x['name']}: {x}" for x in tool_defs])
66-
instruction = DEFAULT_REACT_AGENT_SYSTEM_PROMPT_TEMPLATE.replace("<<tool_names>>", tool_names).replace(
67-
"<<tool_descriptions>>", tool_descriptions
179+
super().__init__(
180+
client=client,
181+
agent_config=agent_config,
182+
client_tools=client_tools,
183+
tool_parser=tool_parser,
68184
)
69185

70-
# user default toolgroups
71-
agent_config = AgentConfig(
72-
model=model,
73-
instructions=instruction,
74-
toolgroups=builtin_toolgroups,
75-
client_tools=[client_tool.get_tool_definition() for client_tool in client_tools],
76-
tool_config={
186+
else:
187+
# build REACT instructions
188+
client_tools = AgentUtils.get_client_tools(tools)
189+
builtin_toolgroups = [x for x in tools if isinstance(x, str) or isinstance(x, dict)]
190+
if not instructions:
191+
instructions = get_default_react_instructions(client, builtin_toolgroups, client_tools)
192+
if not tool_config:
193+
tool_config = {
77194
"tool_choice": "auto",
78-
"tool_prompt_format": "json" if "3.1" in model else "python_list",
79195
"system_message_behavior": "replace",
80-
},
81-
input_shields=[],
82-
output_shields=[],
83-
enable_session_persistence=False,
84-
)
85-
else:
86-
agent_config = custom_agent_config
196+
}
87197

88-
if json_response_format:
89-
agent_config["response_format"] = {
90-
"type": "json_schema",
91-
"json_schema": ReActOutput.model_json_schema(),
92-
}
198+
if json_response_format:
199+
if instructions is not None:
200+
logger.warning(
201+
"Using a custom instructions, but json_response_format is set. Please make sure instructions are"
202+
"compatible with the default ReAct output format."
203+
)
204+
response_format = {
205+
"type": "json_schema",
206+
"json_schema": ReActOutput.model_json_schema(),
207+
}
93208

94-
super().__init__(
95-
client=client,
96-
agent_config=agent_config,
97-
client_tools=client_tools,
98-
tool_parser=tool_parser,
99-
)
209+
super().__init__(
210+
client=client,
211+
model=model,
212+
tool_parser=tool_parser,
213+
instructions=instructions,
214+
tools=tools,
215+
tool_config=tool_config,
216+
sampling_params=sampling_params,
217+
max_infer_iters=max_infer_iters,
218+
input_shields=input_shields,
219+
output_shields=output_shields,
220+
response_format=response_format,
221+
enable_session_persistence=enable_session_persistence,
222+
)

0 commit comments

Comments
 (0)