Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions langgraph/src/graph/message.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,7 @@ export class MessageGraph extends StateGraph<BaseMessage[], Messages> {
});
}
}

export interface MessagesState {
messages: BaseMessage[];
}
4 changes: 2 additions & 2 deletions langgraph/src/graph/state.ts
Original file line number Diff line number Diff line change
Expand Up @@ -254,15 +254,15 @@ export class CompiledStateGraph<

// add node and output channel
if (key === START) {
this.nodes[key] = new PregelNode({
this.nodes[key] = new PregelNode<State, Update>({
tags: [TAG_HIDDEN],
triggers: [START],
channels: [START],
writers: [new ChannelWrite(stateWriteEntries, [TAG_HIDDEN])],
});
} else {
this.channels[key] = new EphemeralValue();
this.nodes[key] = new PregelNode({
this.nodes[key] = new PregelNode<State, Update>({
triggers: [],
// read state keys
channels:
Expand Down
1 change: 1 addition & 0 deletions langgraph/src/prebuilt/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ export {
type ToolInvocationInterface,
ToolExecutor,
} from "./tool_executor.js";
export { ToolNode, toolsCondition } from "./tool_node.js";
72 changes: 72 additions & 0 deletions langgraph/src/prebuilt/tool_node.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import { BaseMessage, ToolMessage, AIMessage } from "@langchain/core/messages";
import { RunnableConfig } from "@langchain/core/runnables";
import { Tool } from "@langchain/core/tools";
import { RunnableCallable } from "../utils.js";
import { END } from "../graph/graph.js";
import { MessagesState } from "../graph/message.js";

export class ToolNode extends RunnableCallable<
BaseMessage[] | MessagesState,
BaseMessage[] | MessagesState
> {
/**
A node that runs the tools requested in the last AIMessage. It can be used
either in StateGraph with a "messages" key or in MessageGraph. If multiple
tool calls are requested, they will be run in parallel. The output will be
a list of ToolMessages, one for each tool call.
*/

tools: Tool[];

constructor(tools: Tool[], name: string = "tools", tags: string[] = []) {
super({ name, tags, func: (input, config) => this.run(input, config) });
this.tools = tools;
}

private async run(
input: BaseMessage[] | MessagesState,
config: RunnableConfig
): Promise<BaseMessage[] | MessagesState> {
const message = Array.isArray(input)
? input[input.length - 1]
: input.messages[input.messages.length - 1];

if (message._getType() !== "ai") {
throw new Error("ToolNode only accepts AIMessages as input.");
}

const outputs = await Promise.all(
(message as AIMessage).tool_calls?.map(async (call) => {
const tool = this.tools.find((tool) => tool.name === call.name);
if (tool === undefined) {
throw new Error(`Tool ${call.name} not found.`);
}
const output = await tool.invoke(call.args, config);
return new ToolMessage({
name: tool.name,
content: typeof output === "string" ? output : JSON.stringify(output),
tool_call_id: call.id!,
});
}) ?? []
);

return Array.isArray(input) ? outputs : { messages: outputs };
}
}

export function toolsCondition(
state: BaseMessage[] | MessagesState
): "tools" | typeof END {
const message = Array.isArray(state)
? state[state.length - 1]
: state.messages[state.messages.length - 1];

if (
"tool_calls" in message &&
((message as AIMessage).tool_calls?.length ?? 0) > 0
) {
return "tools";
} else {
return END;
}
}
2 changes: 1 addition & 1 deletion langgraph/src/pregel/read.ts
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ interface PregelNodeArgs<RunInput, RunOutput>
triggers: Array<string>;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
mapper?: (args: any) => any;
writers?: Runnable<RunOutput, RunOutput>[];
writers?: Runnable<RunOutput, unknown>[];
tags?: string[];
bound?: Runnable<RunInput, RunOutput>;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
Expand Down
2 changes: 1 addition & 1 deletion langgraph/src/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ export interface RunnableCallableArgs extends Partial<any> {
recurse?: boolean;
}

export class RunnableCallable extends Runnable {
export class RunnableCallable<I = unknown, O = unknown> extends Runnable<I, O> {
lc_namespace: string[] = ["langgraph"];

// eslint-disable-next-line @typescript-eslint/no-explicit-any
Expand Down