33#
44# This source code is licensed under the terms described in the LICENSE file in
55# the root directory of this source tree.
6- from typing import Optional , Tuple
6+ import logging
7+ from typing import Any , Callable , List , Optional , Tuple , Union
78
89from llama_stack_client import LlamaStackClient
910from llama_stack_client .types .agent_create_params import AgentConfig
11+ from llama_stack_client .types .agents .turn_create_params import Toolgroup
12+ from llama_stack_client .types .shared_params .agent_config import ToolConfig
13+ from llama_stack_client .types .shared_params .response_format import ResponseFormat
14+ from llama_stack_client .types .shared_params .sampling_params import SamplingParams
1015
11-
12- from ..agent import Agent
16+ from ..agent import Agent , AgentUtils
1317from ..client_tool import ClientTool
1418from ..tool_parser import ToolParser
1519from .prompts import DEFAULT_REACT_AGENT_SYSTEM_PROMPT_TEMPLATE
20+ from .tool_parser import ReActOutput , ReActToolParser
21+
22+ logger = logging .getLogger (__name__ )
23+
24+
25+ def get_tool_defs (
26+ client : LlamaStackClient , builtin_toolgroups : Tuple [Toolgroup ] = (), client_tools : Tuple [ClientTool ] = ()
27+ ):
28+ tool_defs = []
29+ for x in builtin_toolgroups :
30+ if isinstance (x , str ):
31+ toolgroup_id = x
32+ else :
33+ toolgroup_id = x ["name" ]
34+ tool_defs .extend (
35+ [
36+ {
37+ "name" : tool .identifier ,
38+ "description" : tool .description ,
39+ "parameters" : tool .parameters ,
40+ }
41+ for tool in client .tools .list (toolgroup_id = toolgroup_id )
42+ ]
43+ )
44+
45+ tool_defs .extend (
46+ [
47+ {
48+ "name" : tool .get_name (),
49+ "description" : tool .get_description (),
50+ "parameters" : tool .get_params_definition (),
51+ }
52+ for tool in client_tools
53+ ]
54+ )
55+ return tool_defs
56+
57+
58+ def get_default_react_instructions (
59+ client : LlamaStackClient , builtin_toolgroups : Tuple [str ] = (), client_tools : Tuple [ClientTool ] = ()
60+ ):
61+ tool_defs = get_tool_defs (client , builtin_toolgroups , client_tools )
62+ tool_names = ", " .join ([x ["name" ] for x in tool_defs ])
63+ tool_descriptions = "\n " .join ([f"- { x ['name' ]} : { x } " for x in tool_defs ])
64+ instruction = DEFAULT_REACT_AGENT_SYSTEM_PROMPT_TEMPLATE .replace ("<<tool_names>>" , tool_names ).replace (
65+ "<<tool_descriptions>>" , tool_descriptions
66+ )
67+ return instruction
68+
69+
70+ def get_agent_config_DEPRECATED (
71+ client : LlamaStackClient ,
72+ model : str ,
73+ builtin_toolgroups : Tuple [str ] = (),
74+ client_tools : Tuple [ClientTool ] = (),
75+ json_response_format : bool = False ,
76+ custom_agent_config : Optional [AgentConfig ] = None ,
77+ ) -> AgentConfig :
78+ if custom_agent_config is None :
79+ instruction = get_default_react_instructions (client , builtin_toolgroups , client_tools )
80+
81+ # user default toolgroups
82+ agent_config = AgentConfig (
83+ model = model ,
84+ instructions = instruction ,
85+ toolgroups = builtin_toolgroups ,
86+ client_tools = [client_tool .get_tool_definition () for client_tool in client_tools ],
87+ tool_config = {
88+ "tool_choice" : "auto" ,
89+ "system_message_behavior" : "replace" ,
90+ },
91+ input_shields = [],
92+ output_shields = [],
93+ enable_session_persistence = False ,
94+ )
95+ else :
96+ agent_config = custom_agent_config
97+
98+ if json_response_format :
99+ agent_config ["response_format" ] = {
100+ "type" : "json_schema" ,
101+ "json_schema" : ReActOutput .model_json_schema (),
102+ }
16103
17- from . tool_parser import ReActToolParser , ReActOutput
104+ return agent_config
18105
19106
20107class ReActAgent (Agent ):
@@ -27,73 +114,109 @@ def __init__(
27114 self ,
28115 client : LlamaStackClient ,
29116 model : str ,
30- builtin_toolgroups : Tuple [str ] = (),
31- client_tools : Tuple [ClientTool ] = (),
32117 tool_parser : ToolParser = ReActToolParser (),
118+ instructions : Optional [str ] = None ,
119+ tools : Optional [List [Union [Toolgroup , ClientTool , Callable [..., Any ]]]] = None ,
120+ tool_config : Optional [ToolConfig ] = None ,
121+ sampling_params : Optional [SamplingParams ] = None ,
122+ max_infer_iters : Optional [int ] = None ,
123+ input_shields : Optional [List [str ]] = None ,
124+ output_shields : Optional [List [str ]] = None ,
125+ response_format : Optional [ResponseFormat ] = None ,
126+ enable_session_persistence : Optional [bool ] = None ,
33127 json_response_format : bool = False ,
34- custom_agent_config : Optional [AgentConfig ] = None ,
128+ builtin_toolgroups : Tuple [str ] = (), # DEPRECATED
129+ client_tools : Tuple [ClientTool ] = (), # DEPRECATED
130+ custom_agent_config : Optional [AgentConfig ] = None , # DEPRECATED
35131 ):
36- def get_tool_defs ():
37- tool_defs = []
38- for x in builtin_toolgroups :
39- tool_defs .extend (
40- [
41- {
42- "name" : tool .identifier ,
43- "description" : tool .description ,
44- "parameters" : tool .parameters ,
45- }
46- for tool in client .tools .list (toolgroup_id = x )
47- ]
48- )
49- tool_defs .extend (
50- [
51- {
52- "name" : tool .get_name (),
53- "description" : tool .get_description (),
54- "parameters" : tool .get_params_definition (),
55- }
56- for tool in client_tools
57- ]
132+ """Construct an Agent with the given parameters.
133+
134+ :param client: The LlamaStackClient instance.
135+ :param custom_agent_config: The AgentConfig instance.
136+ ::deprecated: use other parameters instead
137+ :param client_tools: A tuple of ClientTool instances.
138+ ::deprecated: use tools instead
139+ :param builtin_toolgroups: A tuple of Toolgroup instances.
140+ ::deprecated: use tools instead
141+ :param tool_parser: Custom logic that parses tool calls from a message.
142+ :param model: The model to use for the agent.
143+ :param instructions: The instructions for the agent.
144+ :param tools: A list of tools for the agent. Values can be one of the following:
145+ - dict representing a toolgroup/tool with arguments: e.g. {"name": "builtin::rag/knowledge_search", "args": {"vector_db_ids": [123]}}
146+ - a python function with a docstring. See @client_tool for more details.
147+ - str representing a tool within a toolgroup: e.g. "builtin::rag/knowledge_search"
148+ - str representing a toolgroup_id: e.g. "builtin::rag", "builtin::code_interpreter", where all tools in the toolgroup will be added to the agent
149+ - an instance of ClientTool: A client tool object.
150+ :param tool_config: The tool configuration for the agent.
151+ :param sampling_params: The sampling parameters for the agent.
152+ :param max_infer_iters: The maximum number of inference iterations.
153+ :param input_shields: The input shields for the agent.
154+ :param output_shields: The output shields for the agent.
155+ :param response_format: The response format for the agent.
156+ :param enable_session_persistence: Whether to enable session persistence.
157+ :param json_response_format: Whether to use the json response format with default ReAct output schema.
158+ ::deprecated: use response_format instead
159+ """
160+ use_deprecated_params = False
161+ if custom_agent_config is not None :
162+ logger .warning ("`custom_agent_config` is deprecated. Use inlined parameters instead." )
163+ use_deprecated_params = True
164+ if client_tools != ():
165+ logger .warning ("`client_tools` is deprecated. Use `tools` instead." )
166+ use_deprecated_params = True
167+ if builtin_toolgroups != ():
168+ logger .warning ("`builtin_toolgroups` is deprecated. Use `tools` instead." )
169+ use_deprecated_params = True
170+
171+ if use_deprecated_params :
172+ agent_config = get_agent_config_DEPRECATED (
173+ client = client ,
174+ model = model ,
175+ builtin_toolgroups = builtin_toolgroups ,
176+ client_tools = client_tools ,
177+ json_response_format = json_response_format ,
58178 )
59- return tool_defs
60-
61- if custom_agent_config is None :
62- tool_names , tool_descriptions = "" , ""
63- tool_defs = get_tool_defs ()
64- tool_names = ", " .join ([x ["name" ] for x in tool_defs ])
65- tool_descriptions = "\n " .join ([f"- { x ['name' ]} : { x } " for x in tool_defs ])
66- instruction = DEFAULT_REACT_AGENT_SYSTEM_PROMPT_TEMPLATE .replace ("<<tool_names>>" , tool_names ).replace (
67- "<<tool_descriptions>>" , tool_descriptions
179+ super ().__init__ (
180+ client = client ,
181+ agent_config = agent_config ,
182+ client_tools = client_tools ,
183+ tool_parser = tool_parser ,
68184 )
69185
70- # user default toolgroups
71- agent_config = AgentConfig (
72- model = model ,
73- instructions = instruction ,
74- toolgroups = builtin_toolgroups ,
75- client_tools = [client_tool .get_tool_definition () for client_tool in client_tools ],
76- tool_config = {
186+ else :
187+ # build REACT instructions
188+ client_tools = AgentUtils .get_client_tools (tools )
189+ builtin_toolgroups = [x for x in tools if isinstance (x , str ) or isinstance (x , dict )]
190+ if not instructions :
191+ instructions = get_default_react_instructions (client , builtin_toolgroups , client_tools )
192+ if not tool_config :
193+ tool_config = {
77194 "tool_choice" : "auto" ,
78- "tool_prompt_format" : "json" if "3.1" in model else "python_list" ,
79195 "system_message_behavior" : "replace" ,
80- },
81- input_shields = [],
82- output_shields = [],
83- enable_session_persistence = False ,
84- )
85- else :
86- agent_config = custom_agent_config
196+ }
87197
88- if json_response_format :
89- agent_config ["response_format" ] = {
90- "type" : "json_schema" ,
91- "json_schema" : ReActOutput .model_json_schema (),
92- }
198+ if json_response_format :
199+ if instructions is not None :
200+ logger .warning (
201+ "Using a custom instructions, but json_response_format is set. Please make sure instructions are"
202+ "compatible with the default ReAct output format."
203+ )
204+ response_format = {
205+ "type" : "json_schema" ,
206+ "json_schema" : ReActOutput .model_json_schema (),
207+ }
93208
94- super ().__init__ (
95- client = client ,
96- agent_config = agent_config ,
97- client_tools = client_tools ,
98- tool_parser = tool_parser ,
99- )
209+ super ().__init__ (
210+ client = client ,
211+ model = model ,
212+ tool_parser = tool_parser ,
213+ instructions = instructions ,
214+ tools = tools ,
215+ tool_config = tool_config ,
216+ sampling_params = sampling_params ,
217+ max_infer_iters = max_infer_iters ,
218+ input_shields = input_shields ,
219+ output_shields = output_shields ,
220+ response_format = response_format ,
221+ enable_session_persistence = enable_session_persistence ,
222+ )
0 commit comments