Skip to content
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
80c2e9f
Implement Pregel types and update debug print methods.
andrewnguonly Apr 22, 2024
4aa4d54
Merge branch 'main' into migrate-pregel
andrewnguonly Apr 23, 2024
34b5251
Rename ChannelInvoke to PregelNode. Remove field. Update constructor…
andrewnguonly Apr 23, 2024
61fe2dd
Change type of PregelNodes.channels back to support string instead of…
andrewnguonly Apr 24, 2024
a53ceed
Update ChannelRead class.
andrewnguonly Apr 24, 2024
c25077a
Update ChannelWrite.
andrewnguonly Apr 24, 2024
6dea09e
Update PregelNode pipe() function.
andrewnguonly Apr 24, 2024
229b40e
Update mapInput() and readChannels() in io.ts.
andrewnguonly Apr 24, 2024
8532bf8
Add mapOutputValues() and mapOutputUpdates() to io.ts.
andrewnguonly Apr 24, 2024
4381301
Fix lint error.
andrewnguonly Apr 24, 2024
2ed1110
Remove when parameter from Channel.subscribeTo() function.
andrewnguonly Apr 24, 2024
915e233
Fix bug in readChannels(). Add unit tests for readChannel() and readC…
andrewnguonly Apr 25, 2024
fb74f00
Add unit tests for mapInput(), mapOutputValues(), mapOutputUpdates().
andrewnguonly Apr 25, 2024
8736ce5
Update _write() to match implementation in https://github.com/langcha…
andrewnguonly Apr 25, 2024
e594c1d
Fix lint errors.
andrewnguonly Apr 25, 2024
3b99228
Add (failing) unit test for invoking nested graph.
andrewnguonly Apr 26, 2024
5fbf72e
Fix minor typing issues.
andrewnguonly Apr 26, 2024
ba84592
Fix skipped unit test. Implement RunnableCallable to skip invocation …
andrewnguonly Apr 26, 2024
f64ef02
Fix unit test: Throw error immediately after calling method to test t…
andrewnguonly Apr 26, 2024
57b2cce
Fix unit test that wasn't implemented correctly to catch error thrown…
andrewnguonly Apr 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions langgraph/src/constants.ts
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
export const CONFIG_KEY_SEND = "__pregel_send";
export const CONFIG_KEY_READ = "__pregel_read";

export const TAG_HIDDEN = "langsmith:hidden";
4 changes: 2 additions & 2 deletions langgraph/src/graph/graph.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import {
RunnableLike,
_coerceToRunnable,
} from "@langchain/core/runnables";
import { ChannelInvoke } from "../pregel/read.js";
import { PregelNode } from "../pregel/read.js";
import { Channel, Pregel } from "../pregel/index.js";
import { BaseCheckpointSaver } from "../checkpoint/base.js";

Expand Down Expand Up @@ -179,7 +179,7 @@ export class Graph<
outgoingEdges[start].push(end !== END ? `${end}:inbox` : END);
});

const nodes: Record<string, ChannelInvoke<RunInput, RunOutput>> = {};
const nodes: Record<string, PregelNode<RunInput, RunOutput>> = {};
for (const [key, node] of Object.entries(this.nodes)) {
nodes[key] = Channel.subscribeTo(`${key}:inbox`)
.pipe(node)
Expand Down
25 changes: 13 additions & 12 deletions langgraph/src/graph/state.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ import { BaseChannel } from "../channels/base.js";
import { BinaryOperator, BinaryOperatorAggregate } from "../channels/binop.js";
import { END, Graph } from "./graph.js";
import { LastValue } from "../channels/last_value.js";
import { ChannelWrite, SKIP_WRITE } from "../pregel/write.js";
import { ChannelWrite, PASSTHROUGH, SKIP_WRITE } from "../pregel/write.js";
import { BaseCheckpointSaver } from "../checkpoint/base.js";
import { Pregel, Channel } from "../pregel/index.js";
import { ChannelInvoke, ChannelRead } from "../pregel/read.js";
import { PregelNode, ChannelRead } from "../pregel/read.js";
import { NamedBarrierValue } from "../channels/named_barrier_value.js";
import { AnyValue } from "../channels/any_value.js";
import { EphemeralValue } from "../channels/ephemeral_value.js";
Expand Down Expand Up @@ -136,12 +136,13 @@ export class StateGraph<
const updateChannels = Array.isArray(stateKeysRead)
? stateKeysRead.map((key) => ({
channel: key,
value: new RunnableLambda({
value: PASSTHROUGH,
skipNone: false,
mapper: new RunnableLambda({
func: (input) => getInputKey(key, input),
}),
skipNone: false,
}))
: [{ channel: "__root__", value: null, skipNone: true }];
: [{ channel: "__root__", value: PASSTHROUGH, skipNone: true }];

const waitingEdges: Set<[string, string[], string]> = new Set();
this.waitingEdges.forEach(([starts, end]) => {
Expand Down Expand Up @@ -170,7 +171,7 @@ export class StateGraph<
}
}

const nodes: Record<string, ChannelInvoke> = {};
const nodes: Record<string, PregelNode> = {};

for (const [key, node] of Object.entries(this.nodes)) {
const triggers = [
Expand All @@ -179,14 +180,14 @@ export class StateGraph<
.filter(([, , end]) => end === key)
.map(([chan]) => chan),
];
nodes[key] = new ChannelInvoke({
nodes[key] = new PregelNode({
triggers,
channels: stateChannels,
})
.pipe(node)
.pipe(
new ChannelWrite([
{ channel: key, value: null, skipNone: false },
{ channel: key, value: PASSTHROUGH, skipNone: false },
...updateChannels,
])
);
Expand All @@ -203,7 +204,7 @@ export class StateGraph<
const outgoing = outgoingEdges[key];
const edgesKey = `${key}:edges`;
if (outgoing || this.branches[key]) {
nodes[edgesKey] = new ChannelInvoke({
nodes[edgesKey] = new PregelNode({
triggers: [key],
tags: ["langsmith:hidden"],
channels: stateChannels,
Expand All @@ -214,7 +215,7 @@ export class StateGraph<
new ChannelWrite(
outgoing.map((dest) => ({
channel: dest,
value: dest === END ? null : key,
value: dest === END ? PASSTHROUGH : key,
skipNone: false,
}))
)
Expand All @@ -235,12 +236,12 @@ export class StateGraph<
tags: ["langsmith:hidden"],
}).pipe(
new ChannelWrite([
{ channel: START, value: null, skipNone: false },
{ channel: START, value: PASSTHROUGH, skipNone: false },
...updateChannels,
])
);

nodes[`${START}:edges`] = new ChannelInvoke({
nodes[`${START}:edges`] = new PregelNode({
triggers: [START],
tags: ["langsmith:hidden"],
channels: stateChannels,
Expand Down
11 changes: 5 additions & 6 deletions langgraph/src/pregel/debug.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { Runnable } from "@langchain/core/runnables";
import { BaseChannel, EmptyChannelError } from "../channels/base.js";
import { PregelExecutableTask } from "./types.js";

type ConsoleColors = {
start: string;
Expand All @@ -25,17 +25,16 @@ const wrap = (color: ConsoleColors, text: string): string =>

export function printStepStart(
step: number,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
nextTasks: Array<[Runnable, any, string]>
nextTasks: Array<PregelExecutableTask>
): void {
const nTasks = nextTasks.length;
console.log(
`${wrap(COLORS_MAP.blue, "[pregel/step]")}`,
`${wrap(COLORS_MAP.blue, "[langgraph/step]")}`,
`Starting step ${step} with ${nTasks} task${
nTasks === 1 ? "" : "s"
}. Next tasks:\n`,
`\n${nextTasks
.map(([_, val, name]) => `- ${name}(${JSON.stringify(val, null, 2)})`)
.map((task) => `${task.name}(${JSON.stringify(task.input, null, 2)})`)
.join("\n")}`
);
}
Expand All @@ -45,7 +44,7 @@ export function printCheckpoint<Value>(
channels: Record<string, BaseChannel<Value>>
) {
console.log(
`${wrap(COLORS_MAP.blue, "[pregel/checkpoint]")}`,
`${wrap(COLORS_MAP.blue, "[langgraph/checkpoint]")}`,
`Finishing step ${step}. Channel values:\n`,
`\n${JSON.stringify(
Object.fromEntries(_readChannels<Value>(channels)),
Expand Down
72 changes: 25 additions & 47 deletions langgraph/src/pregel/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@ import {
CheckpointAt,
emptyCheckpoint,
} from "../checkpoint/base.js";
import { ChannelInvoke } from "./read.js";
import { PregelNode } from "./read.js";
import { validateGraph } from "./validate.js";
import { ReservedChannelsMap } from "./reserved.js";
import { mapInput, mapOutput } from "./io.js";
import { ChannelWrite, ChannelWriteEntry } from "./write.js";
import { mapInput, mapOutput, readChannel } from "./io.js";
import { ChannelWrite, ChannelWriteEntry, PASSTHROUGH } from "./write.js";
import { CONFIG_KEY_READ, CONFIG_KEY_SEND } from "../constants.js";
import { initializeAsyncLocalStorageSingleton } from "../setup/async_local_storage.js";

Expand Down Expand Up @@ -58,32 +58,26 @@ export class Channel {
channels: string,
options?: {
key?: string;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
when?: (arg: any) => boolean;
tags?: string[];
}
): // eslint-disable-next-line @typescript-eslint/no-explicit-any
ChannelInvoke;
PregelNode;

static subscribeTo(
channels: string[],
options?: {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
when?: (arg: any) => boolean;
tags?: string[];
}
): ChannelInvoke;
): PregelNode;

static subscribeTo(
channels: string | string[],
options?: {
key?: string;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
when?: (arg: any) => boolean;
tags?: string[];
}
): ChannelInvoke {
const { key, when, tags } = options ?? {};
): PregelNode {
const { key, tags } = options ?? {};
if (Array.isArray(channels) && key !== undefined) {
throw new Error(
"Can't specify a key when subscribing to multiple channels"
Expand All @@ -106,10 +100,9 @@ export class Channel {

const triggers: string[] = Array.isArray(channels) ? channels : [channels];

return new ChannelInvoke({
return new PregelNode({
channels: channelMappingOrString,
triggers,
when,
tags,
});
}
Expand All @@ -124,20 +117,22 @@ export class Channel {
Object.entries(additionalArgs).forEach(([key, value]) => {
channelPairs.push({
channel: key,
value: _coerceWriteValue(value),
skipNone: false,
value: PASSTHROUGH,
skipNone: true,
mapper: _coerceWriteValue(value),
});
});
} else {
args.forEach((channel) => {
if (typeof channel === "string") {
channelPairs.push({ channel, value: undefined, skipNone: false });
channelPairs.push({ channel, value: PASSTHROUGH, skipNone: false });
} else if (typeof channel === "object") {
Object.entries(channel).forEach(([key, value]) => {
channelPairs.push({
channel: key,
value: _coerceWriteValue(value),
skipNone: false,
value: PASSTHROUGH,
skipNone: true,
mapper: _coerceWriteValue(value),
});
});
}
Expand Down Expand Up @@ -174,7 +169,7 @@ export interface PregelInterface {
*/
interrupt?: string[];

nodes: Record<string, ChannelInvoke>;
nodes: Record<string, PregelNode>;

// eslint-disable-next-line @typescript-eslint/no-explicit-any
checkpointer?: BaseCheckpointSaver<any>;
Expand Down Expand Up @@ -213,7 +208,7 @@ export class Pregel

debug: boolean = false;

nodes: Record<string, ChannelInvoke>;
nodes: Record<string, PregelNode>;

// eslint-disable-next-line @typescript-eslint/no-explicit-any
checkpointer?: BaseCheckpointSaver<any>;
Expand Down Expand Up @@ -293,7 +288,7 @@ export class Pregel

_applyWrites(checkpoint, channels, inputPendingWrites, config, 0);

const read = (chan: string) => _readChannel(channels, chan);
const read = (chan: string) => readChannel(channels, chan);

// Similarly to Bulk Synchronous Parallel / Pregel model
// computation proceeds in steps, while there are channel updates
Expand Down Expand Up @@ -458,21 +453,6 @@ async function executeTasks<RunOutput>(
}
}

function _readChannel(
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is moved to io.ts.

channels: Record<string, BaseChannel>,
chan: string
): unknown | null {
try {
return channels[chan].get();
// eslint-disable-next-line @typescript-eslint/no-explicit-any
} catch (e: any) {
if (e.name === EmptyChannelError.name) {
return null;
}
throw e;
}
}

function _applyWrites(
checkpoint: Checkpoint,
channels: Record<string, BaseChannel>,
Expand Down Expand Up @@ -536,7 +516,7 @@ function _applyWritesFromView(
values: Record<string, unknown>
) {
for (const [chan, val] of Object.entries(values)) {
if (val === _readChannel(channels, chan)) {
if (val === readChannel(channels, chan)) {
continue;
}

Expand All @@ -550,14 +530,14 @@ function _applyWritesFromView(

function _prepareNextTasks(
checkpoint: Checkpoint,
processes: Record<string, ChannelInvoke>,
processes: Record<string, PregelNode>,
channels: Record<string, BaseChannel>
): Array<[RunnableInterface, unknown, string]> {
const tasks: Array<[RunnableInterface, unknown, string]> = [];

// Check if any processes should be run in next step
// If so, prepare the values to be passed to them
for (const [name, proc] of Object.entries<ChannelInvoke>(processes)) {
for (const [name, proc] of Object.entries<PregelNode>(processes)) {
let seen: Record<string, number> = checkpoint.versionsSeen[name];
if (!seen) {
checkpoint.versionsSeen[name] = {};
Expand All @@ -575,10 +555,10 @@ function _prepareNextTasks(
// eslint-disable-next-line @typescript-eslint/no-explicit-any
let val: Record<string, any> = {};
if (typeof proc.channels === "string") {
val[proc.channels] = _readChannel(channels, proc.channels);
val[proc.channels] = readChannel(channels, proc.channels);
} else {
for (const [k, chan] of Object.entries(proc.channels)) {
val[k] = _readChannel(channels, chan);
val[k] = readChannel(channels, chan);
}
}

Expand All @@ -601,10 +581,8 @@ function _prepareNextTasks(
}
});

// skip if condition is not met
if (proc.when === undefined || proc.when(val)) {
tasks.push([proc, val, name]);
}
tasks.push([proc, val, name]);
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know if this is valid, but all of the existing tests are passing.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why would it not be valid?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was just comparing with the current Python implementation which relies on the for_execution API, which isn't implemented yet in JS.

https://github.com/langchain-ai/langgraph/blob/main/langgraph/pregel/__init__.py#L1287


// eslint-disable-next-line @typescript-eslint/no-explicit-any
} catch (error: any) {
if (error.name === EmptyChannelError.name) {
Expand Down
Loading