Skip to content

Commit 5fda2ad

Browse files
authored
Merge pull request #148 from langchain-ai/nc/16may/update-state
Implement getState, updateState, getStateHistory
2 parents 0dc8f85 + d56d097 commit 5fda2ad

File tree

5 files changed

+185
-12
lines changed

5 files changed

+185
-12
lines changed

langgraph/src/checkpoint/base.ts

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,11 @@ export abstract class BaseCheckpointSaver {
120120
config: RunnableConfig
121121
): Promise<CheckpointTuple | undefined>;
122122

123-
abstract list(config: RunnableConfig): AsyncGenerator<CheckpointTuple>;
123+
abstract list(
124+
config: RunnableConfig,
125+
limit?: number,
126+
before?: RunnableConfig
127+
): AsyncGenerator<CheckpointTuple>;
124128

125129
abstract put(
126130
config: RunnableConfig,

langgraph/src/checkpoint/memory.ts

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,14 +46,21 @@ export class MemorySaver extends BaseCheckpointSaver {
4646
return undefined;
4747
}
4848

49-
async *list(config: RunnableConfig): AsyncGenerator<CheckpointTuple> {
49+
async *list(
50+
config: RunnableConfig,
51+
limit?: number,
52+
before?: RunnableConfig
53+
): AsyncGenerator<CheckpointTuple> {
5054
const thread_id = config.configurable?.thread_id;
5155
const checkpoints = this.storage[thread_id] ?? {};
5256

5357
// sort in desc order
54-
for (const [checkpoint_id, checkpoint] of Object.entries(checkpoints).sort(
55-
(a, b) => b[0].localeCompare(a[0])
56-
)) {
58+
for (const [checkpoint_id, checkpoint] of Object.entries(checkpoints)
59+
.filter((c) =>
60+
before ? c[0] < before.configurable?.checkpoint_id : true
61+
)
62+
.sort((a, b) => b[0].localeCompare(a[0]))
63+
.slice(0, limit)) {
5764
yield {
5865
config: { configurable: { thread_id, checkpoint_id } },
5966
checkpoint: this.serde.parse(checkpoint[0]) as Checkpoint,

langgraph/src/checkpoint/sqlite.ts

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -120,16 +120,25 @@ CREATE TABLE IF NOT EXISTS checkpoints (
120120
return undefined;
121121
}
122122

123-
async *list(config: RunnableConfig): AsyncGenerator<CheckpointTuple> {
123+
async *list(
124+
config: RunnableConfig,
125+
limit?: number,
126+
before?: RunnableConfig
127+
): AsyncGenerator<CheckpointTuple> {
124128
this.setup();
125129
const thread_id = config.configurable?.thread_id;
130+
let sql = `SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM checkpoints WHERE thread_id = ? ${
131+
before ? "AND checkpoint_id < ?" : ""
132+
} ORDER BY checkpoint_id DESC`;
133+
if (limit) {
134+
sql += ` LIMIT ${limit}`;
135+
}
136+
const args = [thread_id, before?.configurable?.checkpoint_id].filter(
137+
Boolean
138+
);
126139

127140
try {
128-
const rows: Row[] = this.db
129-
.prepare(
130-
`SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM checkpoints WHERE thread_id = ? ORDER BY checkpoint_id DESC`
131-
)
132-
.all(thread_id) as Row[];
141+
const rows: Row[] = this.db.prepare(sql).all(...args) as Row[];
133142

134143
if (rows) {
135144
for (const row of rows) {

langgraph/src/pregel/index.ts

Lines changed: 149 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import {
44
RunnableConfig,
55
RunnableFunc,
66
RunnableLike,
7+
RunnableSequence,
78
_coerceToRunnable,
89
ensureConfig,
910
patchConfig,
@@ -40,7 +41,12 @@ import {
4041
TAG_HIDDEN,
4142
} from "../constants.js";
4243
import { initializeAsyncLocalStorageSingleton } from "../setup/async_local_storage.js";
43-
import { All, PregelExecutableTask, PregelTaskDescription } from "./types.js";
44+
import {
45+
All,
46+
PregelExecutableTask,
47+
PregelTaskDescription,
48+
StateSnapshot,
49+
} from "./types.js";
4450
import {
4551
EmptyChannelError,
4652
GraphRecursionError,
@@ -308,6 +314,148 @@ export class Pregel<
308314
}
309315
}
310316

317+
async getState(config: RunnableConfig): Promise<StateSnapshot> {
318+
if (!this.checkpointer) {
319+
throw new GraphValueError("No checkpointer set");
320+
}
321+
322+
const saved = await this.checkpointer.getTuple(config);
323+
const checkpoint = saved ? saved.checkpoint : emptyCheckpoint();
324+
const channels = emptyChannels(this.channels, checkpoint);
325+
// eslint-disable-next-line @typescript-eslint/no-unused-vars
326+
const [_, nextTasks] = _prepareNextTasks(
327+
checkpoint,
328+
this.nodes,
329+
channels,
330+
false
331+
);
332+
return {
333+
values: readChannels(channels, this.streamChannelsAsIs),
334+
next: nextTasks.map((task) => task.name),
335+
metadata: saved?.metadata,
336+
config: saved ? saved.config : config,
337+
parentConfig: saved?.parentConfig,
338+
};
339+
}
340+
341+
async *getStateHistory(
342+
config: RunnableConfig,
343+
limit?: number,
344+
before?: RunnableConfig
345+
): AsyncIterableIterator<StateSnapshot> {
346+
if (!this.checkpointer) {
347+
throw new GraphValueError("No checkpointer set");
348+
}
349+
for await (const saved of this.checkpointer.list(config, limit, before)) {
350+
const channels = emptyChannels(this.channels, saved.checkpoint);
351+
// eslint-disable-next-line @typescript-eslint/no-unused-vars
352+
const [_, nextTasks] = _prepareNextTasks(
353+
saved.checkpoint,
354+
this.nodes,
355+
channels,
356+
false
357+
);
358+
yield {
359+
values: readChannels(channels, this.streamChannelsAsIs),
360+
next: nextTasks.map((task) => task.name),
361+
metadata: saved.metadata,
362+
config: saved.config,
363+
parentConfig: saved.parentConfig,
364+
};
365+
}
366+
}
367+
368+
async updateState(
369+
config: RunnableConfig,
370+
values: Record<string, unknown> | unknown,
371+
asNode?: keyof Nn
372+
): Promise<RunnableConfig> {
373+
if (!this.checkpointer) {
374+
throw new GraphValueError("No checkpointer set");
375+
}
376+
377+
// Get the latest checkpoint
378+
const saved = await this.checkpointer.getTuple(config);
379+
const checkpoint = saved
380+
? copyCheckpoint(saved.checkpoint)
381+
: emptyCheckpoint();
382+
// Find last that updated the state, if not provided
383+
const maxSeens = Object.entries(checkpoint.versions_seen).reduce(
384+
(acc, [node, versions]) => {
385+
const maxSeen = Math.max(...Object.values(versions));
386+
if (maxSeen) {
387+
if (!acc[maxSeen]) {
388+
acc[maxSeen] = [];
389+
}
390+
acc[maxSeen].push(node);
391+
}
392+
return acc;
393+
},
394+
{} as Record<number, string[]>
395+
);
396+
if (!asNode && !Object.keys(maxSeens).length) {
397+
if (!Array.isArray(this.inputs) && this.inputs in this.nodes) {
398+
asNode = this.inputs as keyof Nn;
399+
}
400+
} else if (!asNode) {
401+
const maxSeen = Math.max(...Object.keys(maxSeens).map(Number));
402+
const nodes = maxSeens[maxSeen];
403+
if (nodes.length === 1) {
404+
asNode = nodes[0] as keyof Nn;
405+
}
406+
}
407+
if (!asNode) {
408+
throw new InvalidUpdateError("Ambiguous update, specify as_node");
409+
}
410+
// update channels
411+
const channels = emptyChannels(this.channels, checkpoint);
412+
// create task to run all writers of the chosen node
413+
const writers = this.nodes[asNode].getWriters();
414+
if (!writers.length) {
415+
throw new InvalidUpdateError(
416+
`No writers found for node ${asNode as string}`
417+
);
418+
}
419+
const task: PregelExecutableTask<keyof Nn, keyof Cc> = {
420+
name: asNode,
421+
input: values,
422+
proc:
423+
// eslint-disable-next-line @typescript-eslint/no-explicit-any
424+
writers.length > 1 ? RunnableSequence.from(writers as any) : writers[0],
425+
writes: [],
426+
config: undefined,
427+
};
428+
// execute task
429+
await task.proc.invoke(
430+
task.input,
431+
patchConfig(config, {
432+
runName: `${this.name}UpdateState`,
433+
configurable: {
434+
[CONFIG_KEY_SEND]: (items: [keyof Cc, unknown][]) =>
435+
task.writes.push(...items),
436+
[CONFIG_KEY_READ]: _localRead.bind(
437+
undefined,
438+
checkpoint,
439+
channels,
440+
task.writes as Array<[string, unknown]>
441+
),
442+
},
443+
})
444+
);
445+
// apply to checkpoint and save
446+
_applyWrites(checkpoint, channels, task.writes);
447+
const step = (saved?.metadata?.step ?? -2) + 1;
448+
return await this.checkpointer.put(
449+
saved?.config ?? config,
450+
createCheckpoint(checkpoint, channels, step),
451+
{
452+
source: "update",
453+
step,
454+
writes: { [asNode]: values },
455+
}
456+
);
457+
}
458+
311459
_defaults(config: PregelOptions<Nn, Cc>): [
312460
boolean, // debug
313461
StreamMode, // stream mode

langgraph/src/pregel/types.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import { Runnable, RunnableConfig } from "@langchain/core/runnables";
2+
import { CheckpointMetadata } from "../checkpoint/base.js";
23

34
export interface PregelTaskDescription {
45
readonly name: string;
@@ -30,6 +31,10 @@ export interface StateSnapshot {
3031
* Config used to fetch this snapshot
3132
*/
3233
readonly config: RunnableConfig;
34+
/**
35+
* Metadata about the checkpoint
36+
*/
37+
readonly metadata?: CheckpointMetadata;
3338
/**
3439
* Config used to fetch the parent snapshot, if any
3540
* @default undefined

0 commit comments

Comments
 (0)