Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion langgraph/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
"devDependencies": {
"@jest/globals": "^29.5.0",
"@langchain/community": "^0.0.43",
"@langchain/openai": "^0.0.23",
"@langchain/openai": "latest",
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to get the bindTools for testing

"@langchain/scripts": "^0.0.13",
"@swc/core": "^1.3.90",
"@swc/jest": "^0.2.29",
Expand Down
2 changes: 2 additions & 0 deletions langgraph/src/prebuilt/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ export {
type FunctionCallingExecutorState,
createFunctionCallingExecutor,
} from "./chat_agent_executor.js";
export { type AgentState, createReactAgent } from "./react_agent_executor.js";

export {
type ToolExecutorArgs,
type ToolInvocationInterface,
Expand Down
187 changes: 187 additions & 0 deletions langgraph/src/prebuilt/react_agent_executor.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
import { BaseChatModel } from "@langchain/core/language_models/chat_models";
import {
AIMessage,
BaseMessage,
BaseMessageChunk,
isAIMessage,
SystemMessage,
} from "@langchain/core/messages";
import {
Runnable,
RunnableInterface,
RunnableLambda,
} from "@langchain/core/runnables";
import { DynamicTool, StructuredTool } from "@langchain/core/tools";

import {
BaseLanguageModelCallOptions,
BaseLanguageModelInput,
} from "@langchain/core/language_models/base";
import { ChatPromptTemplate } from "@langchain/core/prompts";
import { BaseCheckpointSaver } from "../checkpoint/base.js";
import { END, START, StateGraph } from "../graph/index.js";
import { MessagesState } from "../graph/message.js";
import { CompiledStateGraph, StateGraphArgs } from "../graph/state.js";
import { All } from "../pregel/types.js";
import { ToolNode } from "./tool_node.js";

/**
 * State channels threaded through the ReAct agent graph.
 * `messages` accumulates the conversation; see the channel reducers
 * defined in `createReactAgent`.
 */
export interface AgentState {
  // Conversation history shared between the "agent" and "tools" nodes.
  messages: BaseMessage[];
  // TODO: This won't be set until we
  // implement managed values in LangGraphJS
  is_last_step: boolean;
}

/** Node names reachable in the compiled agent graph. */
export type N = typeof START | "agent" | "tools";

/**
* Creates a StateGraph agent that relies on a chat model utilizing tool calling.
* @param model The chat model that can utilize OpenAI-style function calling.
* @param tools A list of tools or a ToolNode.
* @param messageModifier An optional message modifier to apply to messages before being passed to the LLM.
* Can be a SystemMessage, string, function that takes and returns a list of messages, or a Runnable.
* @param checkpointSaver An optional checkpoint saver to persist the agent's state.
* @param interruptBefore An optional list of node names to interrupt before running.
* @param interruptAfter An optional list of node names to interrupt after running.
* @returns A compiled agent as a LangChain Runnable.
*/
export function createReactAgent(
model: BaseChatModel,
tools: ToolNode<MessagesState> | StructuredTool[],
messageModifier?:
| SystemMessage
| string
| ((messages: BaseMessage[]) => BaseMessage[])
| Runnable,
checkpointSaver?: BaseCheckpointSaver,
interruptBefore?: N[] | All,
interruptAfter?: N[] | All
): CompiledStateGraph<
AgentState,
Partial<AgentState>,
typeof START | "agent" | "tools"
> {
const schema: StateGraphArgs<AgentState>["channels"] = {
messages: {
value: (left: BaseMessage[], right: BaseMessage[]) => left.concat(right),
default: () => [],
},
is_last_step: {
value: (_?: boolean, right?: boolean) => right ?? false,
default: () => false,
},
};

let toolClasses: (StructuredTool | DynamicTool)[];
if (!Array.isArray(tools)) {
toolClasses = tools.tools;
} else {
toolClasses = tools;
}
if (!("bindTools" in model) || typeof model.bindTools !== "function") {
throw new Error(`Model ${model} must define bindTools method.`);
}
const modelWithTools = model.bindTools(toolClasses);
const modelRunnable = _createModelWrapper(modelWithTools, messageModifier);

const shouldContinue = (state: AgentState) => {
const { messages } = state;
const lastMessage = messages[messages.length - 1];
if (
isAIMessage(lastMessage) &&
(!lastMessage.tool_calls || lastMessage.tool_calls.length === 0)
) {
return END;
} else {
return "continue";
}
};

const callModel = async (state: AgentState) => {
const { messages } = state;
// TODO: Stream
Copy link
Copy Markdown
Contributor Author

@hinthornw William FH (hinthornw) May 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we wanted to support token-wise streaming via streamEvents, i think i need to actually stream here and concatenate, though i need to check. been a while since ir eviewed the js chat model implementations

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because in js we don't auto promote invoke to stream like we do in py?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

afaict

const response = (await modelRunnable.invoke(messages)) as AIMessage;
if (
state.is_last_step &&
response?.tool_calls &&
response.tool_calls?.length > 0
) {
return {
messages: [
new AIMessage(
"Sorry, more steps are needed to process this request."
),
],
};
}
return { messages: [response] };
};

const workflow = new StateGraph<AgentState>({
channels: schema,
})
.addNode(
"agent",
new RunnableLambda({ func: callModel }).withConfig({ runName: "agent" })
)
.addNode("tools", new ToolNode<AgentState>(toolClasses))
.addEdge(START, "agent")
.addConditionalEdges("agent", shouldContinue, {
continue: "tools",
end: END,
})
.addEdge("tools", "agent");

return workflow.compile({
checkpointer: checkpointSaver,
interruptBefore,
interruptAfter,
});
}

/**
 * Wraps the tool-bound model so an optional message modifier is applied to the
 * message list before each invocation.
 *
 * @param modelWithTools The model runnable with tools already bound.
 * @param messageModifier A SystemMessage (or string) to prepend, a function
 * mapping messages to messages, or an arbitrary Runnable to pipe through.
 * @returns A runnable accepting `BaseMessage[]` and producing a model response.
 * @throws Error for unsupported modifier types.
 */
function _createModelWrapper(
  modelWithTools: RunnableInterface<
    BaseLanguageModelInput,
    BaseMessageChunk,
    BaseLanguageModelCallOptions
  >,
  messageModifier?:
    | SystemMessage
    | string
    | ((messages: BaseMessage[]) => BaseMessage[])
    | Runnable
) {
  // No modifier (including empty string): invoke on the raw message list.
  if (!messageModifier) {
    return modelWithTools;
  }

  // Builds a `system prompt -> model` pipeline. The leading lambda re-shapes
  // the bare message array into the { messages } dict the template expects.
  const withSystemPrompt = (systemMessage: SystemMessage) => {
    const toDict = new RunnableLambda({
      func: (messages: BaseMessage[]) => ({ messages }),
    });
    const prompt = ChatPromptTemplate.fromMessages([
      systemMessage,
      ["placeholder", "{messages}"],
    ]);
    return toDict.pipe(prompt).pipe(modelWithTools);
  };

  if (typeof messageModifier === "string") {
    return withSystemPrompt(new SystemMessage(messageModifier));
  }
  if (typeof messageModifier === "function") {
    return new RunnableLambda({ func: messageModifier })
      .withConfig({ runName: "message_modifier" })
      .pipe(modelWithTools);
  }
  if (Runnable.isRunnable(messageModifier)) {
    return messageModifier.pipe(modelWithTools);
  }
  if (messageModifier._getType() === "system") {
    return withSystemPrompt(messageModifier);
  }
  throw new Error(
    `Unsupported message modifier type: ${typeof messageModifier}`
  );
}
17 changes: 10 additions & 7 deletions langgraph/src/prebuilt/tool_node.ts
Original file line number Diff line number Diff line change
@@ -1,24 +1,27 @@
import { BaseMessage, ToolMessage, AIMessage } from "@langchain/core/messages";
import { RunnableConfig } from "@langchain/core/runnables";
import { Tool } from "@langchain/core/tools";
import { StructuredTool } from "@langchain/core/tools";
import { RunnableCallable } from "../utils.js";
import { END } from "../graph/graph.js";
import { MessagesState } from "../graph/message.js";

export class ToolNode extends RunnableCallable<
BaseMessage[] | MessagesState,
BaseMessage[] | MessagesState
> {
export class ToolNode<
T extends BaseMessage[] | MessagesState
> extends RunnableCallable<T, T> {
/**
A node that runs the tools requested in the last AIMessage. It can be used
either in StateGraph with a "messages" key or in MessageGraph. If multiple
tool calls are requested, they will be run in parallel. The output will be
a list of ToolMessages, one for each tool call.
*/

tools: Tool[];
tools: StructuredTool[];

constructor(tools: Tool[], name: string = "tools", tags: string[] = []) {
constructor(
tools: StructuredTool[],
name: string = "tools",
tags: string[] = []
) {
super({ name, tags, func: (input, config) => this.run(input, config) });
this.tools = tools;
}
Expand Down
115 changes: 97 additions & 18 deletions langgraph/src/tests/prebuilt.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,17 @@ import { Tool } from "@langchain/core/tools";
import { ChatOpenAI } from "@langchain/openai";
import { BaseMessage, HumanMessage } from "@langchain/core/messages";
import { END } from "../index.js";
import { createFunctionCallingExecutor } from "../prebuilt/index.js";
import {
createReactAgent,
createFunctionCallingExecutor,
} from "../prebuilt/index.js";

// Tracing slows down the tests
// Disable LangSmith tracing for the suite: tracing slows down the tests.
// (These were commented out during development — restored so the comment
// above and the behavior agree.)
beforeAll(() => {
  process.env.LANGCHAIN_TRACING_V2 = "false";
  process.env.LANGCHAIN_ENDPOINT = "";
  process.env.LANGCHAIN_API_KEY = "";
  process.env.LANGCHAIN_PROJECT = "";
});

describe("createFunctionCallingExecutor", () => {
Expand Down Expand Up @@ -43,7 +46,6 @@ describe("createFunctionCallingExecutor", () => {
messages: [new HumanMessage("What's the weather like in SF?")],
});

console.log(response);
// It needs at least one human message, one AI and one function message.
expect(response.messages.length > 3).toBe(true);
const firstFunctionMessage = (response.messages as Array<BaseMessage>).find(
Expand Down Expand Up @@ -78,26 +80,103 @@ describe("createFunctionCallingExecutor", () => {
tools,
});

const stream = await functionsAgentExecutor.stream({
messages: [new HumanMessage("What's the weather like in SF?")],
});
const stream = await functionsAgentExecutor.stream(
{
messages: [new HumanMessage("What's the weather like in SF?")],
},
{ streamMode: "values" }
);
const fullResponse = [];
for await (const item of stream) {
console.log(item);
console.log("-----\n");
fullResponse.push(item);
}

// Needs at least 3 llm calls, plus one `__end__` call.
expect(fullResponse.length >= 4).toBe(true);

const endMessage = fullResponse[fullResponse.length - 1];
expect(END in endMessage).toBe(true);
expect(endMessage[END].messages.length > 0).toBe(true);
// human -> agent -> action -> agent
expect(fullResponse.length).toEqual(4);

const functionCall = endMessage[END].messages.find(
const endState = fullResponse[fullResponse.length - 1];
// 1 human, 2 llm calls, 1 function call.
expect(endState.messages.length).toEqual(4);
const functionCall = endState.messages.find(
(message: BaseMessage) => message._getType() === "function"
);
expect(functionCall.content).toBe(weatherResponse);
});
});

describe("createReactAgent", () => {
  // Canned report the stub weather tool always returns.
  const weatherResponse = `Not too cold, not too hot 😎`;

  // Stub tool shared by both tests (previously duplicated per-test, each copy
  // with a redundant empty constructor).
  class SanFranciscoWeatherTool extends Tool {
    name = "current_weather";

    description = "Get the current weather report for San Francisco, CA";

    async _call(_: string): Promise<string> {
      return weatherResponse;
    }
  }

  it("can call a tool", async () => {
    const model = new ChatOpenAI();
    const tools = [new SanFranciscoWeatherTool()];

    const reactAgent = createReactAgent(model, tools);

    const response = await reactAgent.invoke({
      messages: [new HumanMessage("What's the weather like in SF?")],
    });

    // It needs at least one human message and one AI message.
    expect(response.messages.length > 1).toBe(true);
    const lastMessage = response.messages[response.messages.length - 1];
    expect(lastMessage._getType()).toBe("ai");
    expect(lastMessage.content.toLowerCase()).toContain("not too cold");
  });

  it("can stream a tool call", async () => {
    const model = new ChatOpenAI({
      streaming: true,
    });
    const tools = [new SanFranciscoWeatherTool()];

    const reactAgent = createReactAgent(model, tools);

    const stream = await reactAgent.stream(
      {
        messages: [new HumanMessage("What's the weather like in SF?")],
      },
      { streamMode: "values" }
    );
    const fullResponse = [];
    for await (const item of stream) {
      fullResponse.push(item);
    }

    // human -> agent -> action -> agent
    expect(fullResponse.length).toEqual(4);
    const endState = fullResponse[fullResponse.length - 1];
    // 1 human, 2 ai, 1 tool.
    expect(endState.messages.length).toEqual(4);

    const lastMessage = endState.messages[endState.messages.length - 1];
    expect(lastMessage._getType()).toBe("ai");
    expect(lastMessage.content.toLowerCase()).toContain("not too cold");
  });
});
Loading