|
4 | 4 | RunnableConfig, |
5 | 5 | RunnableFunc, |
6 | 6 | RunnableLike, |
| 7 | + RunnableSequence, |
7 | 8 | _coerceToRunnable, |
8 | 9 | ensureConfig, |
9 | 10 | patchConfig, |
@@ -40,7 +41,12 @@ import { |
40 | 41 | TAG_HIDDEN, |
41 | 42 | } from "../constants.js"; |
42 | 43 | import { initializeAsyncLocalStorageSingleton } from "../setup/async_local_storage.js"; |
43 | | -import { All, PregelExecutableTask, PregelTaskDescription } from "./types.js"; |
| 44 | +import { |
| 45 | + All, |
| 46 | + PregelExecutableTask, |
| 47 | + PregelTaskDescription, |
| 48 | + StateSnapshot, |
| 49 | +} from "./types.js"; |
44 | 50 | import { |
45 | 51 | EmptyChannelError, |
46 | 52 | GraphRecursionError, |
@@ -308,6 +314,148 @@ export class Pregel< |
308 | 314 | } |
309 | 315 | } |
310 | 316 |
|
| 317 | + async getState(config: RunnableConfig): Promise<StateSnapshot> { |
| 318 | + if (!this.checkpointer) { |
| 319 | + throw new GraphValueError("No checkpointer set"); |
| 320 | + } |
| 321 | + |
| 322 | + const saved = await this.checkpointer.getTuple(config); |
| 323 | + const checkpoint = saved ? saved.checkpoint : emptyCheckpoint(); |
| 324 | + const channels = emptyChannels(this.channels, checkpoint); |
| 325 | + // eslint-disable-next-line @typescript-eslint/no-unused-vars |
| 326 | + const [_, nextTasks] = _prepareNextTasks( |
| 327 | + checkpoint, |
| 328 | + this.nodes, |
| 329 | + channels, |
| 330 | + false |
| 331 | + ); |
| 332 | + return { |
| 333 | + values: readChannels(channels, this.streamChannelsAsIs), |
| 334 | + next: nextTasks.map((task) => task.name), |
| 335 | + metadata: saved?.metadata, |
| 336 | + config: saved ? saved.config : config, |
| 337 | + parentConfig: saved?.parentConfig, |
| 338 | + }; |
| 339 | + } |
| 340 | + |
| 341 | + async *getStateHistory( |
| 342 | + config: RunnableConfig, |
| 343 | + limit?: number, |
| 344 | + before?: RunnableConfig |
| 345 | + ): AsyncIterableIterator<StateSnapshot> { |
| 346 | + if (!this.checkpointer) { |
| 347 | + throw new GraphValueError("No checkpointer set"); |
| 348 | + } |
| 349 | + for await (const saved of this.checkpointer.list(config, limit, before)) { |
| 350 | + const channels = emptyChannels(this.channels, saved.checkpoint); |
| 351 | + // eslint-disable-next-line @typescript-eslint/no-unused-vars |
| 352 | + const [_, nextTasks] = _prepareNextTasks( |
| 353 | + saved.checkpoint, |
| 354 | + this.nodes, |
| 355 | + channels, |
| 356 | + false |
| 357 | + ); |
| 358 | + yield { |
| 359 | + values: readChannels(channels, this.streamChannelsAsIs), |
| 360 | + next: nextTasks.map((task) => task.name), |
| 361 | + metadata: saved.metadata, |
| 362 | + config: saved.config, |
| 363 | + parentConfig: saved.parentConfig, |
| 364 | + }; |
| 365 | + } |
| 366 | + } |
| 367 | + |
| 368 | + async updateState( |
| 369 | + config: RunnableConfig, |
| 370 | + values: Record<string, unknown> | unknown, |
| 371 | + asNode?: keyof Nn |
| 372 | + ): Promise<RunnableConfig> { |
| 373 | + if (!this.checkpointer) { |
| 374 | + throw new GraphValueError("No checkpointer set"); |
| 375 | + } |
| 376 | + |
| 377 | + // Get the latest checkpoint |
| 378 | + const saved = await this.checkpointer.getTuple(config); |
| 379 | + const checkpoint = saved |
| 380 | + ? copyCheckpoint(saved.checkpoint) |
| 381 | + : emptyCheckpoint(); |
| 382 | + // Find last that updated the state, if not provided |
| 383 | + const maxSeens = Object.entries(checkpoint.versions_seen).reduce( |
| 384 | + (acc, [node, versions]) => { |
| 385 | + const maxSeen = Math.max(...Object.values(versions)); |
| 386 | + if (maxSeen) { |
| 387 | + if (!acc[maxSeen]) { |
| 388 | + acc[maxSeen] = []; |
| 389 | + } |
| 390 | + acc[maxSeen].push(node); |
| 391 | + } |
| 392 | + return acc; |
| 393 | + }, |
| 394 | + {} as Record<number, string[]> |
| 395 | + ); |
| 396 | + if (!asNode && !Object.keys(maxSeens).length) { |
| 397 | + if (!Array.isArray(this.inputs) && this.inputs in this.nodes) { |
| 398 | + asNode = this.inputs as keyof Nn; |
| 399 | + } |
| 400 | + } else if (!asNode) { |
| 401 | + const maxSeen = Math.max(...Object.keys(maxSeens).map(Number)); |
| 402 | + const nodes = maxSeens[maxSeen]; |
| 403 | + if (nodes.length === 1) { |
| 404 | + asNode = nodes[0] as keyof Nn; |
| 405 | + } |
| 406 | + } |
| 407 | + if (!asNode) { |
| 408 | + throw new InvalidUpdateError("Ambiguous update, specify as_node"); |
| 409 | + } |
| 410 | + // update channels |
| 411 | + const channels = emptyChannels(this.channels, checkpoint); |
| 412 | + // create task to run all writers of the chosen node |
| 413 | + const writers = this.nodes[asNode].getWriters(); |
| 414 | + if (!writers.length) { |
| 415 | + throw new InvalidUpdateError( |
| 416 | + `No writers found for node ${asNode as string}` |
| 417 | + ); |
| 418 | + } |
| 419 | + const task: PregelExecutableTask<keyof Nn, keyof Cc> = { |
| 420 | + name: asNode, |
| 421 | + input: values, |
| 422 | + proc: |
| 423 | + // eslint-disable-next-line @typescript-eslint/no-explicit-any |
| 424 | + writers.length > 1 ? RunnableSequence.from(writers as any) : writers[0], |
| 425 | + writes: [], |
| 426 | + config: undefined, |
| 427 | + }; |
| 428 | + // execute task |
| 429 | + await task.proc.invoke( |
| 430 | + task.input, |
| 431 | + patchConfig(config, { |
| 432 | + runName: `${this.name}UpdateState`, |
| 433 | + configurable: { |
| 434 | + [CONFIG_KEY_SEND]: (items: [keyof Cc, unknown][]) => |
| 435 | + task.writes.push(...items), |
| 436 | + [CONFIG_KEY_READ]: _localRead.bind( |
| 437 | + undefined, |
| 438 | + checkpoint, |
| 439 | + channels, |
| 440 | + task.writes as Array<[string, unknown]> |
| 441 | + ), |
| 442 | + }, |
| 443 | + }) |
| 444 | + ); |
| 445 | + // apply to checkpoint and save |
| 446 | + _applyWrites(checkpoint, channels, task.writes); |
| 447 | + const step = (saved?.metadata?.step ?? -2) + 1; |
| 448 | + return await this.checkpointer.put( |
| 449 | + saved?.config ?? config, |
| 450 | + createCheckpoint(checkpoint, channels, step), |
| 451 | + { |
| 452 | + source: "update", |
| 453 | + step, |
| 454 | + writes: { [asNode]: values }, |
| 455 | + } |
| 456 | + ); |
| 457 | + } |
| 458 | + |
311 | 459 | _defaults(config: PregelOptions<Nn, Cc>): [ |
312 | 460 | boolean, // debug |
313 | 461 | StreamMode, // stream mode |
|
0 commit comments