Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 91 additions & 0 deletions langgraph/src/channels/dynamic_barrier_value.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import { EmptyChannelError, InvalidUpdateError } from "../errors.js";
import { BaseChannel } from "./index.js";
import { areSetsEqual } from "./named_barrier_value.js";

export interface WaitForNames<Value> {
__names: Value[];
}

/**
A channel that switches between two states

- in the "priming" state it can't be read from.
- if it receives a WaitForNames update, it switches to the "waiting" state.
- in the "waiting" state it collects named values until all are received.
- once all named values are received, it can be read once, and it switches
back to the "priming" state.
*/
export class DynamicBarrierValue<Value> extends BaseChannel<
void,
Value | WaitForNames<Value>,
[Value[] | undefined, Value[]]
> {
lc_graph_name = "DynamicBarrierValue";

names?: Set<Value>; // Names of nodes that we want to wait for.

seen: Set<Value>;

constructor() {
super();
this.names = undefined;
this.seen = new Set<Value>();
}

fromCheckpoint(checkpoint?: [Value[] | undefined, Value[]]) {
const empty = new DynamicBarrierValue<Value>();
if (checkpoint) {
empty.names = new Set(checkpoint[0]);
empty.seen = new Set(checkpoint[1]);
}
return empty as this;
}

update(values: (Value | WaitForNames<Value>)[]): void {
// switch to priming state after reading it once
if (this.names && areSetsEqual(this.names, this.seen)) {
this.seen = new Set<Value>();
this.names = undefined;
}

const newNames = values.filter(
(v) =>
typeof v === "object" &&
!!v &&
"__names" in v &&
Object.keys(v).join(",") === "__names" &&
Array.isArray(v.__names)
) as WaitForNames<Value>[];

if (newNames.length > 1) {
throw new InvalidUpdateError(
`Expected at most one WaitForNames object, got ${newNames.length}`
);
} else if (newNames.length === 1) {
this.names = new Set(newNames[0].__names);
} else if (this.names) {
for (const value of values) {
if (this.names.has(value as Value)) {
this.seen.add(value as Value);
} else {
throw new InvalidUpdateError(
`Value ${value} not in names ${this.names}`
);
}
}
}
}

// If we have not yet seen all the node names we want to wait for,
// throw an error to prevent continuing.
get(): void {
if (!this.names || !areSetsEqual(this.names, this.seen)) {
throw new EmptyChannelError();
}
return undefined;
}

checkpoint(): [Value[] | undefined, Value[]] {
return [this.names ? [...this.names] : undefined, [...this.seen]];
}
}
23 changes: 12 additions & 11 deletions langgraph/src/channels/named_barrier_value.ts
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
import { EmptyChannelError } from "../errors.js";
import { BaseChannel } from "./index.js";

const areSetsEqual = (a: Set<unknown>, b: Set<unknown>) =>
export const areSetsEqual = <T>(a: Set<T>, b: Set<T>) =>
a.size === b.size && [...a].every((value) => b.has(value));

/**
* A channel that waits until all named values are received before making the value available.
*
* This ensures that if node N and node M both write to channel C, the value of C will not be updated
* until N and M have completed updating.
*/
export class NamedBarrierValue<Value> extends BaseChannel<
void,
Value,
Value,
Set<Value>
Value[]
> {
lc_graph_name = "NamedBarrierValue";

Expand All @@ -26,10 +27,10 @@ export class NamedBarrierValue<Value> extends BaseChannel<
this.seen = new Set<Value>();
}

fromCheckpoint(checkpoint?: Set<Value>) {
fromCheckpoint(checkpoint?: Value[]) {
const empty = new NamedBarrierValue<Value>(this.names);
if (checkpoint) {
empty.seen = checkpoint;
empty.seen = new Set(checkpoint);
}
return empty as this;
}
Expand All @@ -52,16 +53,16 @@ export class NamedBarrierValue<Value> extends BaseChannel<
}
}

// If we have not yet seen all the node names we want to wait for, throw an error to
// prevent continuing.
get(): Value {
// If we have not yet seen all the node names we want to wait for,
// throw an error to prevent continuing.
get(): void {
if (!areSetsEqual(this.names, this.seen)) {
throw new EmptyChannelError();
}
return undefined as Value;
return undefined;
}

checkpoint(): Set<Value> {
return this.seen;
checkpoint(): Value[] {
return [...this.seen];
}
}
12 changes: 7 additions & 5 deletions langgraph/src/channels/topic.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ function* flatten<Value>(
export class Topic<Value> extends BaseChannel<
Array<Value>,
Value | Value[],
[Set<Value>, Value[]]
[Value[], Value[]]
> {
lc_graph_name = "Topic";

Expand All @@ -37,13 +37,15 @@ export class Topic<Value> extends BaseChannel<
this.values = [];
}

public fromCheckpoint(checkpoint?: [Set<Value>, Value[]]) {
public fromCheckpoint(checkpoint?: [Value[], Value[]]) {
const empty = new Topic<Value>({
unique: this.unique,
accumulate: this.accumulate,
});
if (checkpoint) {
[empty.seen, empty.values] = checkpoint;
empty.seen = new Set(checkpoint[0]);
// eslint-disable-next-line prefer-destructuring
empty.values = checkpoint[1];
}
return empty as this;
}
Expand Down Expand Up @@ -71,7 +73,7 @@ export class Topic<Value> extends BaseChannel<
return this.values;
}

public checkpoint(): [Set<Value>, Array<Value>] {
return [this.seen, this.values];
public checkpoint(): [Value[], Value[]] {
return [[...this.seen], this.values];
}
}
82 changes: 52 additions & 30 deletions langgraph/src/graph/graph.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,13 @@ import { RunnableCallable } from "../utils.js";
export const START = "__start__";
export const END = "__end__";

export interface BranchOptions<IO, N extends string> {
source: N;
path: Branch<IO, N>["condition"];
pathMap?: Record<string, N | typeof END> | N[];
then?: N | typeof END;
}

export class Branch<IO, N extends string> {
condition: (
input: IO,
Expand All @@ -26,12 +33,17 @@ export class Branch<IO, N extends string> {

ends?: Record<string, N | typeof END>;

constructor(
condition: Branch<IO, N>["condition"],
ends?: Branch<IO, N>["ends"]
) {
this.condition = condition;
this.ends = ends;
then?: BranchOptions<IO, N>["then"];

constructor(options: Omit<BranchOptions<IO, N>, "source">) {
this.condition = options.path;
this.ends = Array.isArray(options.pathMap)
? options.pathMap.reduce((acc, n) => {
acc[n] = n;
return acc;
}, {} as Record<string, N | typeof END>)
: options.pathMap;
this.then = options.then;
}

compile(
Expand Down Expand Up @@ -144,30 +156,37 @@ export class Graph<
return this;
}

addConditionalEdges(source: BranchOptions<RunInput, N>): this;

addConditionalEdges(
startKey: N,
condition: Branch<RunInput, N>["condition"],
conditionalEdgeMapping?: Record<string, N | typeof END>
source: N,
path: Branch<RunInput, N>["condition"],
pathMap?: BranchOptions<RunInput, N>["pathMap"]
): this;

addConditionalEdges(
source: N | BranchOptions<RunInput, N>,
path?: Branch<RunInput, N>["condition"],
pathMap?: BranchOptions<RunInput, N>["pathMap"]
): this {
const options: BranchOptions<RunInput, N> =
typeof source === "object" ? source : { source, path: path!, pathMap };
this.warnIfCompiled(
"Adding an edge to a graph that has already been compiled. This will not be reflected in the compiled graph."
);
// find a name for condition
const name = condition.name || "condition";
const name = options.path.name || "condition";
// validate condition
if (this.branches[startKey] && this.branches[startKey][name]) {
if (this.branches[options.source] && this.branches[options.source][name]) {
throw new Error(
`Condition \`${name}\` already present for node \`${startKey}\``
`Condition \`${name}\` already present for node \`${source}\``
);
}
// save it
if (!this.branches[startKey]) {
this.branches[startKey] = {};
if (!this.branches[options.source]) {
this.branches[options.source] = {};
}
this.branches[startKey][name] = new Branch(
condition,
conditionalEdgeMapping
);
this.branches[options.source][name] = new Branch(options);
return this;
}

Expand Down Expand Up @@ -250,17 +269,18 @@ export class Graph<
for (const [start, branches] of Object.entries(this.branches)) {
allSources.add(start);
for (const branch of Object.values(branches)) {
// TODO revise when adding branch.then
if (branch.ends) {
for (const end of Object.values(branch.ends)) {
if (end !== END) {
allSources.add(end);
if (branch.then) {
if (branch.ends) {
for (const end of Object.values(branch.ends)) {
if (end !== END) {
allSources.add(end);
}
}
}
} else {
for (const node of Object.keys(this.nodes)) {
if (node !== start) {
allSources.add(node);
} else {
for (const node of Object.keys(this.nodes)) {
if (node !== start) {
allSources.add(node);
}
}
}
}
Expand All @@ -282,15 +302,17 @@ export class Graph<
const allTargets = new Set([...this.allEdges].map(([_, target]) => target));
for (const [start, branches] of Object.entries(this.branches)) {
for (const branch of Object.values(branches)) {
// TODO revise when adding branch.then
if (branch.then) {
allTargets.add(branch.then);
}
if (branch.ends) {
for (const end of Object.values(branch.ends)) {
allTargets.add(end);
}
} else {
allTargets.add(END);
for (const node of Object.keys(this.nodes)) {
if (node !== start) {
if (node !== start && node !== branch.then) {
allTargets.add(node);
}
}
Expand Down
25 changes: 22 additions & 3 deletions langgraph/src/graph/state.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import { RunnableCallable } from "../utils.js";
import { All } from "../pregel/types.js";
import { TAG_HIDDEN } from "../constants.js";
import { InvalidUpdateError } from "../errors.js";
import { DynamicBarrierValue } from "../channels/dynamic_barrier_value.js";

const ROOT = "__root__";

Expand Down Expand Up @@ -340,7 +341,12 @@ export class CompiledStateGraph<
channel: `branch:${start}:${name}:${dest}`,
value: start,
}));
// TODO implement branch.then
if (branch.then && branch.then !== END) {
writes.push({
channel: `branch:${start}:${name}:then`,
value: { __names: filteredDests },
});
}
return new ChannelWrite(writes, [TAG_HIDDEN]);
},
// reader
Expand All @@ -351,7 +357,7 @@ export class CompiledStateGraph<
// attach branch subscribers
const ends = branch.ends
? Object.values(branch.ends)
: Object.keys(this.builder.nodes);
: Object.keys(this.builder.nodes).filter((n) => n !== branch.then);
for (const end of ends) {
if (end === END) {
continue;
Expand All @@ -362,6 +368,19 @@ export class CompiledStateGraph<
this.nodes[end as N].triggers.push(channelName);
}

// TODO: implement branch.then
if (branch.then && branch.then !== END) {
const channelName = `branch:${start}:${name}:then`;
(this.channels as Record<string, BaseChannel>)[channelName] =
new DynamicBarrierValue();
this.nodes[branch.then].triggers.push(channelName);
for (const end of ends) {
if (end === END) {
continue;
}
this.nodes[end as N].writers.push(
new ChannelWrite([{ channel: channelName, value: end }], [TAG_HIDDEN])
);
}
}
}
}
4 changes: 2 additions & 2 deletions langgraph/src/prebuilt/agent_executor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import { Runnable, type RunnableConfig } from "@langchain/core/runnables";
import { Tool } from "@langchain/core/tools";
import { ToolExecutor } from "./tool_executor.js";
import { StateGraph } from "../graph/state.js";
import { END } from "../index.js";
import { END, START } from "../index.js";

interface Step {
action: AgentAction | AgentFinish;
Expand Down Expand Up @@ -82,7 +82,7 @@ export function createAgentExecutor({
.addNode("action", executeTools)
// Set the entrypoint as `agent`
// This means that this node is the first one called
.setEntryPoint("agent")
.addEdge(START, "agent")
// We now add a conditional edge
.addConditionalEdges(
// First, we define the start node. We use `agent`.
Expand Down
Loading