Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion langgraph/src/checkpoint/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ export interface CheckpointMetadata {
* -1 for the first "input" checkpoint.
* 0 for the first "loop" checkpoint.
* ... for the nth checkpoint afterwards. */
writes?: Record<string, unknown>;
writes: Record<string, unknown> | null;
/**
* The writes that were made between the previous checkpoint and this one.
* Mapping from node name to writes emitted by that node.
Expand Down
28 changes: 2 additions & 26 deletions langgraph/src/graph/graph.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ export interface BranchOptions<IO, N extends string> {
source: N;
path: Branch<IO, N>["condition"];
pathMap?: Record<string, N | typeof END> | N[];
then?: N | typeof END;
}

export class Branch<IO, N extends string> {
Expand All @@ -33,8 +32,6 @@ export class Branch<IO, N extends string> {

ends?: Record<string, N | typeof END>;

then?: BranchOptions<IO, N>["then"];

constructor(options: Omit<BranchOptions<IO, N>, "source">) {
this.condition = options.path;
this.ends = Array.isArray(options.pathMap)
Expand All @@ -43,7 +40,6 @@ export class Branch<IO, N extends string> {
return acc;
}, {} as Record<string, N | typeof END>)
: options.pathMap;
this.then = options.then;
}

compile(
Expand Down Expand Up @@ -266,25 +262,8 @@ export class Graph<
validate(interrupt?: string[]): void {
// assemble sources
const allSources = new Set([...this.allEdges].map(([src, _]) => src));
for (const [start, branches] of Object.entries(this.branches)) {
for (const [start] of Object.entries(this.branches)) {
allSources.add(start);
for (const branch of Object.values(branches)) {
if (branch.then) {
if (branch.ends) {
for (const end of Object.values(branch.ends)) {
if (end !== END) {
allSources.add(end);
}
}
} else {
for (const node of Object.keys(this.nodes)) {
if (node !== start) {
allSources.add(node);
}
}
}
}
}
}
// validate sources
for (const node of Object.keys(this.nodes)) {
Expand All @@ -302,17 +281,14 @@ export class Graph<
const allTargets = new Set([...this.allEdges].map(([_, target]) => target));
for (const [start, branches] of Object.entries(this.branches)) {
for (const branch of Object.values(branches)) {
if (branch.then) {
allTargets.add(branch.then);
}
if (branch.ends) {
for (const end of Object.values(branch.ends)) {
allTargets.add(end);
}
} else {
allTargets.add(END);
for (const node of Object.keys(this.nodes)) {
if (node !== start && node !== branch.then) {
if (node !== start) {
allTargets.add(node);
}
}
Expand Down
24 changes: 1 addition & 23 deletions langgraph/src/graph/state.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import { RunnableCallable } from "../utils.js";
import { All } from "../pregel/types.js";
import { TAG_HIDDEN } from "../constants.js";
import { InvalidUpdateError } from "../errors.js";
import { DynamicBarrierValue } from "../channels/dynamic_barrier_value.js";

const ROOT = "__root__";

Expand Down Expand Up @@ -341,12 +340,6 @@ export class CompiledStateGraph<
channel: `branch:${start}:${name}:${dest}`,
value: start,
}));
if (branch.then && branch.then !== END) {
writes.push({
channel: `branch:${start}:${name}:then`,
value: { __names: filteredDests },
});
}
return new ChannelWrite(writes, [TAG_HIDDEN]);
},
// reader
Expand All @@ -357,7 +350,7 @@ export class CompiledStateGraph<
// attach branch subscribers
const ends = branch.ends
? Object.values(branch.ends)
: Object.keys(this.builder.nodes).filter((n) => n !== branch.then);
: Object.keys(this.builder.nodes);
for (const end of ends) {
if (end === END) {
continue;
Expand All @@ -367,20 +360,5 @@ export class CompiledStateGraph<
new EphemeralValue();
this.nodes[end as N].triggers.push(channelName);
}

if (branch.then && branch.then !== END) {
const channelName = `branch:${start}:${name}:then`;
(this.channels as Record<string, BaseChannel>)[channelName] =
new DynamicBarrierValue();
this.nodes[branch.then].triggers.push(channelName);
for (const end of ends) {
if (end === END) {
continue;
}
this.nodes[end as N].writers.push(
new ChannelWrite([{ channel: channelName, value: end }], [TAG_HIDDEN])
);
}
}
}
}
2 changes: 1 addition & 1 deletion langgraph/src/pregel/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -709,7 +709,7 @@ export class Pregel<
source: "loop",
step,
writes: single(
streamMode === "values"
this.streamMode === "values"
? mapOutputValues(outputKeys, pendingWrites, channels)
: mapOutputUpdates(outputKeys, nextTasks)
),
Expand Down
4 changes: 2 additions & 2 deletions langgraph/src/pregel/io.ts
Original file line number Diff line number Diff line change
Expand Up @@ -151,10 +151,10 @@ export function* mapOutputUpdates<N extends PropertyKey, C extends PropertyKey>(
}
}

export function single<T>(iter: IterableIterator<T>): T | undefined {
export function single<T>(iter: IterableIterator<T>): T | null {
// eslint-disable-next-line no-unreachable-loop
for (const value of iter) {
return value;
}
return undefined;
return null;
}
7 changes: 4 additions & 3 deletions langgraph/src/tests/checkpoints.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ describe("MemorySaver", () => {
const runnableConfig = await memorySaver.put(
{ configurable: { thread_id: "1" } },
checkpoint1,
{ source: "update", step: -1 }
{ source: "update", step: -1, writes: null }
);
expect(runnableConfig).toEqual({
configurable: {
Expand All @@ -110,6 +110,7 @@ describe("MemorySaver", () => {
await memorySaver.put({ configurable: { thread_id: "1" } }, checkpoint2, {
source: "update",
step: -1,
writes: null,
});

// list checkpoints
Expand Down Expand Up @@ -143,7 +144,7 @@ describe("SqliteSaver", () => {
const runnableConfig = await sqliteSaver.put(
{ configurable: { thread_id: "1" } },
checkpoint1,
{ source: "update", step: -1 }
{ source: "update", step: -1, writes: null }
);
expect(runnableConfig).toEqual({
configurable: {
Expand Down Expand Up @@ -174,7 +175,7 @@ describe("SqliteSaver", () => {
},
},
checkpoint2,
{ source: "update", step: -1 }
{ source: "update", step: -1, writes: null }
);

// verify that parentTs is set and retrieved correctly for second checkpoint
Expand Down
117 changes: 49 additions & 68 deletions langgraph/src/tests/pregel.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2111,8 +2111,9 @@ it("StateGraph start branch then end", async () => {
source: START,
path: (state: State) =>
state.market === "DE" ? "tool_two_slow" : "tool_two_fast",
then: END,
});
})
.addEdge("tool_two_fast", END)
.addEdge("tool_two_slow", END);

const toolTwo = toolTwoBuilder.compile();

Expand All @@ -2134,73 +2135,52 @@ it("StateGraph start branch then end", async () => {
toolTwoWithCheckpointer.invoke({ my_key: "value", market: "DE" })
).rejects.toThrowError("thread_id");

// const thread1 = { configurable: { thread_id: "1" } }
// expect(toolTwoWithCheckpointer.invoke({ my_key: "value", market: "DE" }, thread1)).toEqual({ my_key: "value", market: "DE" })
// expect(toolTwoWithCheckpointer.getState(thread1)).toEqual({
// values: { my_key: "value", market: "DE" },
// next: ["tool_two_slow"],
// config: toolTwoWithCheckpointer.checkpointer.getTuple(thread1).config,
// metadata: { source: "loop", step: 0, writes: null },
// parentConfig: [...toolTwoWithCheckpointer.checkpointer.list(thread1, { limit: 2 })].pop().config
// })

// expect(toolTwoWithCheckpointer.invoke(null, thread1, { debug: 1 })).toEqual({ my_key: "value slow", market: "DE" })
// expect(toolTwoWithCheckpointer.getState(thread1)).toEqual({
// values: { my_key
// : "value slow", market: "DE" },
// next: [],
// config: (await toolTwoWithCheckpointer.checkpointer!.getTuple(thread1))!.config,
// metadata: { source: "loop", step: 1, writes: { tool_two_slow: { my_key: " slow" } } },
// parentConfig: [...toolTwoWithCheckpointer.checkpointer!.list(thread1, { limit: 2 })].pop().config
});
async function last<T>(iter: AsyncIterableIterator<T>): Promise<T> {
// eslint-disable-next-line no-undef-init
let value: T | undefined = undefined;
for await (value of iter) {
// do nothing
}
return value as T;
}

/**
* def test_branch_then_node(snapshot: SnapshotAssertion) -> None:
class State(TypedDict):
my_key: Annotated[str, operator.add]
market: str

# this graph is invalid because there is no path to "finish"
invalid_graph = StateGraph(State)
invalid_graph.set_entry_point("prepare")
invalid_graph.set_finish_point("finish")
invalid_graph.add_conditional_edges(
source="prepare",
path=lambda s: "tool_two_slow" if s["market"] == "DE" else "tool_two_fast",
path_map=["tool_two_slow", "tool_two_fast"],
)
invalid_graph.add_node("prepare", lambda s: {"my_key": " prepared"})
invalid_graph.add_node("tool_two_slow", lambda s: {"my_key": " slow"})
invalid_graph.add_node("tool_two_fast", lambda s: {"my_key": " fast"})
invalid_graph.add_node("finish", lambda s: {"my_key": " finished"})
with pytest.raises(ValueError):
invalid_graph.compile()

tool_two_graph = StateGraph(State)
tool_two_graph.set_entry_point("prepare")
tool_two_graph.set_finish_point("finish")
tool_two_graph.add_conditional_edges(
source="prepare",
path=lambda s: "tool_two_slow" if s["market"] == "DE" else "tool_two_fast",
then="finish",
const thread1 = { configurable: { thread_id: "1" } };
expect(
await toolTwoWithCheckpointer.invoke(
{ my_key: "value", market: "DE" },
thread1
)
tool_two_graph.add_node("prepare", lambda s: {"my_key": " prepared"})
tool_two_graph.add_node("tool_two_slow", lambda s: {"my_key": " slow"})
tool_two_graph.add_node("tool_two_fast", lambda s: {"my_key": " fast"})
tool_two_graph.add_node("finish", lambda s: {"my_key": " finished"})
tool_two = tool_two_graph.compile()
assert tool_two.get_graph().draw_mermaid(with_styles=False) == snapshot
assert tool_two.get_graph().draw_mermaid() == snapshot

assert tool_two.invoke({"my_key": "value", "market": "DE"}, debug=1) == {
"my_key": "value prepared slow finished",
"market": "DE",
}
assert tool_two.invoke({"my_key": "value", "market": "US"}) == {
"my_key": "value prepared fast finished",
"market": "US",
}
*/
).toEqual({ my_key: "value", market: "DE" });
expect(await toolTwoWithCheckpointer.getState(thread1)).toEqual({
values: { my_key: "value", market: "DE" },
next: ["tool_two_slow"],
config: (await toolTwoWithCheckpointer.checkpointer!.getTuple(thread1))!
.config,
metadata: { source: "loop", step: 0, writes: null },
parentConfig: (
await last(toolTwoWithCheckpointer.checkpointer!.list(thread1, 2))
).config,
});

expect(await toolTwoWithCheckpointer.invoke(null, thread1)).toEqual({
my_key: "value slow",
market: "DE",
});
expect(await toolTwoWithCheckpointer.getState(thread1)).toEqual({
values: { my_key: "value slow", market: "DE" },
next: [],
config: (await toolTwoWithCheckpointer.checkpointer!.getTuple(thread1))!
.config,
metadata: {
source: "loop",
step: 1,
writes: { tool_two_slow: { my_key: " slow" } },
},
parentConfig: (
await last(toolTwoWithCheckpointer.checkpointer!.list(thread1, 2))
).config,
});
});

it("StateGraph branch then node", async () => {
interface State {
Expand Down Expand Up @@ -2244,8 +2224,9 @@ it("StateGraph branch then node", async () => {
source: "prepare",
path: (state: State) =>
state.market === "DE" ? "tool_two_slow" : "tool_two_fast",
then: "finish",
})
.addEdge("tool_two_fast", "finish")
.addEdge("tool_two_slow", "finish")
.addEdge("finish", END);

const tool = toolBuilder.compile();
Expand Down