-
Notifications
You must be signed in to change notification settings - Fork 452
Implement createReActAgent #169
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 7 commits
c959235
def2047
ac3bb7a
9194467
0a9f31a
065510e
c2de065
eb25fc6
af25385
09fccd0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,187 @@ | ||
| import { BaseChatModel } from "@langchain/core/language_models/chat_models"; | ||
| import { | ||
| AIMessage, | ||
| BaseMessage, | ||
| BaseMessageChunk, | ||
| isAIMessage, | ||
| SystemMessage, | ||
| } from "@langchain/core/messages"; | ||
| import { | ||
| Runnable, | ||
| RunnableInterface, | ||
| RunnableLambda, | ||
| } from "@langchain/core/runnables"; | ||
| import { DynamicTool, StructuredTool } from "@langchain/core/tools"; | ||
|
|
||
| import { | ||
| BaseLanguageModelCallOptions, | ||
| BaseLanguageModelInput, | ||
| } from "@langchain/core/language_models/base"; | ||
| import { ChatPromptTemplate } from "@langchain/core/prompts"; | ||
| import { BaseCheckpointSaver } from "../checkpoint/base.js"; | ||
| import { END, START, StateGraph } from "../graph/index.js"; | ||
| import { MessagesState } from "../graph/message.js"; | ||
| import { CompiledStateGraph, StateGraphArgs } from "../graph/state.js"; | ||
| import { All } from "../pregel/types.js"; | ||
| import { ToolNode } from "./tool_node.js"; | ||
|
|
||
| export interface AgentState { | ||
| messages: BaseMessage[]; | ||
| // TODO: This won't be set until we | ||
| // implement managed values in LangGraphJS | ||
| is_last_step: boolean; | ||
hinthornw marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| } | ||
|
|
||
| export type N = typeof START | "agent" | "tools"; | ||
|
|
||
| /** | ||
| * Creates a StateGraph agent that relies on a chat model utilizing tool calling. | ||
| * @param model The chat model that can utilize OpenAI-style function calling. | ||
| * @param tools A list of tools or a ToolNode. | ||
| * @param messageModifier An optional message modifier to apply to messages before being passed to the LLM. | ||
| * Can be a SystemMessage, string, function that takes and returns a list of messages, or a Runnable. | ||
| * @param checkpointSaver An optional checkpoint saver to persist the agent's state. | ||
| * @param interruptBefore An optional list of node names to interrupt before running. | ||
| * @param interruptAfter An optional list of node names to interrupt after running. | ||
| * @returns A compiled agent as a LangChain Runnable. | ||
| */ | ||
| export function createReactAgent( | ||
| model: BaseChatModel, | ||
| tools: ToolNode<MessagesState> | StructuredTool[], | ||
| messageModifier?: | ||
| | SystemMessage | ||
| | string | ||
| | ((messages: BaseMessage[]) => BaseMessage[]) | ||
| | Runnable, | ||
| checkpointSaver?: BaseCheckpointSaver, | ||
| interruptBefore?: N[] | All, | ||
| interruptAfter?: N[] | All | ||
| ): CompiledStateGraph< | ||
| AgentState, | ||
| Partial<AgentState>, | ||
| typeof START | "agent" | "tools" | ||
| > { | ||
| const schema: StateGraphArgs<AgentState>["channels"] = { | ||
| messages: { | ||
| value: (left: BaseMessage[], right: BaseMessage[]) => left.concat(right), | ||
| default: () => [], | ||
| }, | ||
| is_last_step: { | ||
| value: (_?: boolean, right?: boolean) => right ?? false, | ||
| default: () => false, | ||
| }, | ||
| }; | ||
|
|
||
| let toolClasses: (StructuredTool | DynamicTool)[]; | ||
| if (!Array.isArray(tools)) { | ||
| toolClasses = tools.tools; | ||
| } else { | ||
| toolClasses = tools; | ||
| } | ||
| if (!("bindTools" in model) || typeof model.bindTools !== "function") { | ||
| throw new Error(`Model ${model} must define bindTools method.`); | ||
| } | ||
| const modelWithTools = model.bindTools(toolClasses); | ||
| const modelRunnable = _createModelWrapper(modelWithTools, messageModifier); | ||
|
|
||
| const shouldContinue = (state: AgentState) => { | ||
| const { messages } = state; | ||
| const lastMessage = messages[messages.length - 1]; | ||
| if ( | ||
| isAIMessage(lastMessage) && | ||
| (!lastMessage.tool_calls || lastMessage.tool_calls.length === 0) | ||
| ) { | ||
| return END; | ||
| } else { | ||
| return "continue"; | ||
| } | ||
| }; | ||
|
|
||
| const callModel = async (state: AgentState) => { | ||
| const { messages } = state; | ||
| // TODO: Stream | ||
|
||
| const response = (await modelRunnable.invoke(messages)) as AIMessage; | ||
| if ( | ||
| state.is_last_step && | ||
| response?.tool_calls && | ||
| response.tool_calls?.length > 0 | ||
| ) { | ||
| return { | ||
| messages: [ | ||
| new AIMessage( | ||
| "Sorry, more steps are needed to process this request." | ||
| ), | ||
| ], | ||
| }; | ||
| } | ||
| return { messages: [response] }; | ||
| }; | ||
|
|
||
| const workflow = new StateGraph<AgentState>({ | ||
| channels: schema, | ||
| }) | ||
| .addNode( | ||
| "agent", | ||
| new RunnableLambda({ func: callModel }).withConfig({ runName: "agent" }) | ||
| ) | ||
| .addNode("tools", new ToolNode<AgentState>(toolClasses)) | ||
| .addEdge(START, "agent") | ||
| .addConditionalEdges("agent", shouldContinue, { | ||
| continue: "tools", | ||
| end: END, | ||
| }) | ||
| .addEdge("tools", "agent"); | ||
|
|
||
| return workflow.compile({ | ||
| checkpointer: checkpointSaver, | ||
| interruptBefore, | ||
| interruptAfter, | ||
| }); | ||
| } | ||
|
|
||
| function _createModelWrapper( | ||
| modelWithTools: RunnableInterface< | ||
| BaseLanguageModelInput, | ||
| BaseMessageChunk, | ||
| BaseLanguageModelCallOptions | ||
| >, | ||
| messageModifier?: | ||
| | SystemMessage | ||
| | string | ||
| | ((messages: BaseMessage[]) => BaseMessage[]) | ||
| | Runnable | ||
| ) { | ||
| if (!messageModifier) { | ||
| return modelWithTools; | ||
| } | ||
| const endict = new RunnableLambda({ | ||
| func: (messages: BaseMessage[]) => ({ messages }), | ||
| }); | ||
| if (typeof messageModifier === "string") { | ||
| const systemMessage = new SystemMessage(messageModifier); | ||
| const prompt = ChatPromptTemplate.fromMessages([ | ||
| systemMessage, | ||
| ["placeholder", "{messages}"], | ||
| ]); | ||
| return endict.pipe(prompt).pipe(modelWithTools); | ||
| } | ||
| if (typeof messageModifier === "function") { | ||
| const lambda = new RunnableLambda({ func: messageModifier }).withConfig({ | ||
| runName: "message_modifier", | ||
| }); | ||
| return lambda.pipe(modelWithTools); | ||
| } | ||
| if (Runnable.isRunnable(messageModifier)) { | ||
| return messageModifier.pipe(modelWithTools); | ||
| } | ||
| if (messageModifier._getType() === "system") { | ||
| const prompt = ChatPromptTemplate.fromMessages([ | ||
| messageModifier, | ||
| ["placeholder", "{messages}"], | ||
| ]); | ||
| return endict.pipe(prompt).pipe(modelWithTools); | ||
| } | ||
| throw new Error( | ||
| `Unsupported message modifier type: ${typeof messageModifier}` | ||
| ); | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
to get the bindTools for testing