Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 29 additions & 15 deletions langgraph/src/prebuilt/react_agent_executor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,23 @@ export interface AgentState {

export type N = typeof START | "agent" | "tools";

export type CreateReactAgentParams = {
llm: BaseChatModel;
tools: ToolNode<MessagesState> | StructuredTool[];
messageModifier?:
| SystemMessage
| string
| ((messages: BaseMessage[]) => BaseMessage[])
| ((messages: BaseMessage[]) => Promise<BaseMessage[]>)
| Runnable;
checkpointSaver?: BaseCheckpointSaver;
interruptBefore?: N[] | All;
interruptAfter?: N[] | All;
};

/**
* Creates a StateGraph agent that relies on a chat model utilizing tool calling.
* @param model The chat model that can utilize OpenAI-style function calling.
* Creates a StateGraph agent that relies on a chat llm utilizing tool calling.
* @param llm The chat llm that can utilize OpenAI-style function calling.
* @param tools A list of tools or a ToolNode.
* @param messageModifier An optional message modifier to apply to messages before being passed to the LLM.
* Can be a SystemMessage, string, function that takes and returns a list of messages, or a Runnable.
Expand All @@ -47,21 +61,20 @@ export type N = typeof START | "agent" | "tools";
* @returns A compiled agent as a LangChain Runnable.
*/
export function createReactAgent(
model: BaseChatModel,
tools: ToolNode<MessagesState> | StructuredTool[],
messageModifier?:
| SystemMessage
| string
| ((messages: BaseMessage[]) => BaseMessage[])
| Runnable,
checkpointSaver?: BaseCheckpointSaver,
interruptBefore?: N[] | All,
interruptAfter?: N[] | All
props: CreateReactAgentParams
): CompiledStateGraph<
AgentState,
Partial<AgentState>,
typeof START | "agent" | "tools"
> {
const {
llm,
tools,
messageModifier,
checkpointSaver,
interruptBefore,
interruptAfter,
} = props;
const schema: StateGraphArgs<AgentState>["channels"] = {
messages: {
value: (left: BaseMessage[], right: BaseMessage[]) => left.concat(right),
Expand All @@ -75,10 +88,10 @@ export function createReactAgent(
} else {
toolClasses = tools;
}
if (!("bindTools" in model) || typeof model.bindTools !== "function") {
throw new Error(`Model ${model} must define bindTools method.`);
if (!("bindTools" in llm) || typeof llm.bindTools !== "function") {
throw new Error(`llm ${llm} must define bindTools method.`);
}
const modelWithTools = model.bindTools(toolClasses);
const modelWithTools = llm.bindTools(toolClasses);
const modelRunnable = _createModelWrapper(modelWithTools, messageModifier);

const shouldContinue = (state: AgentState) => {
Expand Down Expand Up @@ -132,6 +145,7 @@ function _createModelWrapper(
| SystemMessage
| string
| ((messages: BaseMessage[]) => BaseMessage[])
| ((messages: BaseMessage[]) => Promise<BaseMessage[]>)
| Runnable
) {
if (!messageModifier) {
Expand Down
4 changes: 2 additions & 2 deletions langgraph/src/tests/prebuilt.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ describe("createReactAgent", () => {
}
const tools = [new SanFranciscoWeatherTool()];

const reactAgent = createReactAgent(model, tools);
const reactAgent = createReactAgent({ llm: model, tools });

const response = await reactAgent.invoke({
messages: [new HumanMessage("What's the weather like in SF?")],
Expand Down Expand Up @@ -155,7 +155,7 @@ describe("createReactAgent", () => {
}
const tools = [new SanFranciscoWeatherTool()];

const reactAgent = createReactAgent(model, tools);
const reactAgent = createReactAgent({ llm: model, tools });

const stream = await reactAgent.stream(
{
Expand Down
52 changes: 45 additions & 7 deletions langgraph/src/tests/prebuilt.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,11 @@ describe("createReactAgent", () => {
],
});

const agent = createReactAgent(llm, tools, "You are a helpful assistant");
const agent = createReactAgent({
llm,
tools,
messageModifier: "You are a helpful assistant",
});

const result = await agent.invoke({
messages: [new HumanMessage("Hello Input!")],
Expand Down Expand Up @@ -345,16 +349,15 @@ describe("createReactAgent", () => {
],
});

const agent = createReactAgent(
const agent = createReactAgent({
llm,
tools,
new SystemMessage("You are a helpful assistant")
);
messageModifier: new SystemMessage("You are a helpful assistant"),
});

const result = await agent.invoke({
messages: [],
});
console.log("RESULT THING", result);
expect(result.messages).toEqual([
new AIMessage({
content: "result1",
Expand Down Expand Up @@ -388,7 +391,42 @@ describe("createReactAgent", () => {
...messages,
];

const agent = createReactAgent(llm, tools, messageModifier);
const agent = createReactAgent({ llm, tools, messageModifier });

const result = await agent.invoke({
messages: [new HumanMessage("Hello Input!")],
});

expect(result.messages).toEqual([
new HumanMessage("Hello Input!"),
aiM1,
new ToolMessage({
name: "search_api",
content: "result for foo",
tool_call_id: "tool_abcd123",
}),
aiM2,
]);
});

it("Can use async custom function message modifier", async () => {
const aiM1 = new AIMessage({
content: "result1",
tool_calls: [
{ name: "search_api", id: "tool_abcd123", args: { query: "foo" } },
],
});
const aiM2 = new AIMessage("result2");
const llm = new FakeToolCallingChatModel({
responses: [aiM1, aiM2],
});

const messageModifier = async (messages: BaseMessage[]) => [
new SystemMessage("You are a helpful assistant"),
...messages,
];

const agent = createReactAgent({ llm, tools, messageModifier });

const result = await agent.invoke({
messages: [new HumanMessage("Hello Input!")],
Expand Down Expand Up @@ -425,7 +463,7 @@ describe("createReactAgent", () => {
],
});

const agent = createReactAgent(llm, tools, messageModifier);
const agent = createReactAgent({ llm, tools, messageModifier });

const result = await agent.invoke({
messages: [
Expand Down