Skip to content

Commit 1c8b34f

Browse files
committed
fix(langgraph): harden graph callback dispatch semantics
1 parent 7468378 commit 1c8b34f

File tree

4 files changed

+240
-55
lines changed

4 files changed

+240
-55
lines changed

libs/langgraph/langgraph/callbacks.py

Lines changed: 189 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
11
from __future__ import annotations
22

33
import asyncio
4-
import inspect
5-
from collections.abc import Coroutine, Sequence
4+
import atexit
5+
import functools
6+
import logging
7+
from collections.abc import Awaitable, Callable, Coroutine, Sequence
8+
from concurrent.futures import ThreadPoolExecutor
9+
from contextvars import copy_context
10+
from functools import partial
611
from typing import Any, TypeAlias, cast
712
from uuid import UUID
813

@@ -21,6 +26,42 @@
2126
)
2227

2328
GRAPH_CALLBACKS_KEY = "graph_callbacks"
29+
logger = logging.getLogger(__name__)
30+
31+
32+
@functools.lru_cache(maxsize=1)
33+
def _executor() -> ThreadPoolExecutor:
34+
executor = ThreadPoolExecutor(max_workers=10)
35+
atexit.register(executor.shutdown, wait=True)
36+
return executor
37+
38+
39+
def _coerce_coroutine(awaitable: Awaitable[Any]) -> Coroutine[Any, Any, Any]:
40+
if asyncio.iscoroutine(awaitable):
41+
return awaitable
42+
43+
async def _await() -> Any:
44+
return await awaitable
45+
46+
return _await()
47+
48+
49+
def _run_coroutine_with_new_loop(coro: Coroutine[Any, Any, Any]) -> None:
50+
if hasattr(asyncio, "Runner"):
51+
with asyncio.Runner() as runner:
52+
runner.run(coro)
53+
while pending := asyncio.all_tasks(runner.get_loop()):
54+
runner.run(asyncio.wait(pending))
55+
else:
56+
loop = asyncio.new_event_loop()
57+
try:
58+
asyncio.set_event_loop(loop)
59+
loop.run_until_complete(coro)
60+
while pending := asyncio.all_tasks(loop):
61+
loop.run_until_complete(asyncio.wait(pending))
62+
finally:
63+
asyncio.set_event_loop(None)
64+
loop.close()
2465

2566

2667
class GraphCallbackHandler(BaseCallbackHandler):
@@ -131,6 +172,74 @@ def configure(
131172
)
132173
return cls(handlers, run_id=run_id)
133174

175+
@staticmethod
176+
def _maybe_raise_handler_error(
177+
handler: GraphCallbackHandler,
178+
event_name: str,
179+
error: Exception,
180+
) -> None:
181+
if getattr(handler, "raise_error", False):
182+
raise error
183+
logger.exception(
184+
"Ignoring graph callback handler error in %s for %s",
185+
event_name,
186+
handler.__class__.__name__,
187+
exc_info=error,
188+
)
189+
190+
def _drain_awaitable(
    self,
    handler: GraphCallbackHandler,
    event_name: str,
    awaitable: Awaitable[Any],
) -> None:
    """Synchronously run an awaitable returned by a sync-dispatched callback.

    When no event loop is running in this thread, the awaitable is executed
    on a new loop right here. When a loop is already running, it is executed
    on a new loop in the shared executor, under a copy of the current
    ``contextvars`` context, and this call blocks until it finishes.

    Args:
        handler: The handler that produced ``awaitable``; consulted when
            deciding whether an error is raised or logged.
        event_name: Name of the callback event, used for error reporting.
        awaitable: The awaitable to run to completion.

    Raises:
        Exception: Errors from the awaitable, if
            ``_maybe_raise_handler_error`` decides to re-raise them.
    """
    coro = _coerce_coroutine(awaitable)

    # Only record the probe result here. Running the coroutine inside the
    # ``except RuntimeError`` block would attach the "no running event loop"
    # error as the implicit ``__context__`` of any handler exception, adding
    # a misleading "During handling of the above exception" section to
    # re-raised and logged tracebacks.
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        loop_is_running = False
    else:
        loop_is_running = True

    if not loop_is_running:
        try:
            _run_coroutine_with_new_loop(coro)
        except Exception as exc:
            self._maybe_raise_handler_error(handler, event_name, exc)
        return

    # A loop is already running in this thread, so a second loop cannot be
    # started here; run it on a worker thread. Submission stays outside the
    # ``try`` so executor failures are not misreported as handler errors.
    future = _executor().submit(
        cast(Callable[..., Any], copy_context().run),
        _run_coroutine_with_new_loop,
        coro,
    )
    try:
        future.result()
    except Exception as exc:
        self._maybe_raise_handler_error(handler, event_name, exc)
214+
215+
async def _ainvoke_handler(
216+
self,
217+
handler: GraphCallbackHandler,
218+
callback: Callable[..., Any],
219+
event_name: str,
220+
*args: Any,
221+
**kwargs: Any,
222+
) -> None:
223+
try:
224+
if asyncio.iscoroutinefunction(callback):
225+
await callback(*args, **kwargs)
226+
elif handler.run_inline:
227+
result = callback(*args, **kwargs)
228+
if isinstance(result, Awaitable):
229+
await result
230+
else:
231+
result = await asyncio.get_running_loop().run_in_executor(
232+
None,
233+
cast(
234+
Callable[..., Any],
235+
partial(copy_context().run, callback, *args, **kwargs),
236+
),
237+
)
238+
if isinstance(result, Awaitable):
239+
await result
240+
except Exception as exc:
241+
self._maybe_raise_handler_error(handler, event_name, exc)
242+
134243
def on_interrupt(
135244
self,
136245
interrupts: Sequence[Interrupt],
@@ -142,16 +251,19 @@ def on_interrupt(
142251
) -> None:
143252
for handler in self.handlers:
144253
graph_handler = cast(GraphCallbackHandler, handler)
145-
result = graph_handler.on_interrupt(
146-
interrupts,
147-
run_id=self.run_id,
148-
status=status,
149-
checkpoint_id=checkpoint_id,
150-
checkpoint_ns=checkpoint_ns,
151-
is_nested=is_nested,
152-
)
153-
if inspect.iscoroutine(result):
154-
self._drain_coroutine(result)
254+
try:
255+
result = graph_handler.on_interrupt(
256+
interrupts,
257+
run_id=self.run_id,
258+
status=status,
259+
checkpoint_id=checkpoint_id,
260+
checkpoint_ns=checkpoint_ns,
261+
is_nested=is_nested,
262+
)
263+
if isinstance(result, Awaitable):
264+
self._drain_awaitable(graph_handler, "on_interrupt", result)
265+
except Exception as exc:
266+
self._maybe_raise_handler_error(graph_handler, "on_interrupt", exc)
155267

156268
def on_resume(
157269
self,
@@ -163,15 +275,18 @@ def on_resume(
163275
) -> None:
164276
for handler in self.handlers:
165277
graph_handler = cast(GraphCallbackHandler, handler)
166-
result = graph_handler.on_resume(
167-
run_id=self.run_id,
168-
status=status,
169-
checkpoint_id=checkpoint_id,
170-
checkpoint_ns=checkpoint_ns,
171-
is_nested=is_nested,
172-
)
173-
if inspect.iscoroutine(result):
174-
self._drain_coroutine(result)
278+
try:
279+
result = graph_handler.on_resume(
280+
run_id=self.run_id,
281+
status=status,
282+
checkpoint_id=checkpoint_id,
283+
checkpoint_ns=checkpoint_ns,
284+
is_nested=is_nested,
285+
)
286+
if isinstance(result, Awaitable):
287+
self._drain_awaitable(graph_handler, "on_resume", result)
288+
except Exception as exc:
289+
self._maybe_raise_handler_error(graph_handler, "on_resume", exc)
175290

176291
async def aon_interrupt(
177292
self,
@@ -182,18 +297,40 @@ async def aon_interrupt(
182297
checkpoint_ns: tuple[str, ...],
183298
is_nested: bool,
184299
) -> None:
185-
for handler in self.handlers:
186-
graph_handler = cast(GraphCallbackHandler, handler)
187-
result = graph_handler.on_interrupt(
300+
inline_handlers = [
301+
cast(GraphCallbackHandler, handler)
302+
for handler in self.handlers
303+
if handler.run_inline
304+
]
305+
for graph_handler in inline_handlers:
306+
await self._ainvoke_handler(
307+
graph_handler,
308+
graph_handler.on_interrupt,
309+
"on_interrupt",
188310
interrupts,
189311
run_id=self.run_id,
190312
status=status,
191313
checkpoint_id=checkpoint_id,
192314
checkpoint_ns=checkpoint_ns,
193315
is_nested=is_nested,
194316
)
195-
if inspect.isawaitable(result):
196-
await result
317+
await asyncio.gather(
318+
*(
319+
self._ainvoke_handler(
320+
cast(GraphCallbackHandler, handler),
321+
cast(GraphCallbackHandler, handler).on_interrupt,
322+
"on_interrupt",
323+
interrupts,
324+
run_id=self.run_id,
325+
status=status,
326+
checkpoint_id=checkpoint_id,
327+
checkpoint_ns=checkpoint_ns,
328+
is_nested=is_nested,
329+
)
330+
for handler in self.handlers
331+
if not handler.run_inline
332+
)
333+
)
197334

198335
async def aon_resume(
199336
self,
@@ -203,26 +340,38 @@ async def aon_resume(
203340
checkpoint_ns: tuple[str, ...],
204341
is_nested: bool,
205342
) -> None:
206-
for handler in self.handlers:
207-
graph_handler = cast(GraphCallbackHandler, handler)
208-
result = graph_handler.on_resume(
343+
inline_handlers = [
344+
cast(GraphCallbackHandler, handler)
345+
for handler in self.handlers
346+
if handler.run_inline
347+
]
348+
for graph_handler in inline_handlers:
349+
await self._ainvoke_handler(
350+
graph_handler,
351+
graph_handler.on_resume,
352+
"on_resume",
209353
run_id=self.run_id,
210354
status=status,
211355
checkpoint_id=checkpoint_id,
212356
checkpoint_ns=checkpoint_ns,
213357
is_nested=is_nested,
214358
)
215-
if inspect.isawaitable(result):
216-
await result
217-
218-
@staticmethod
219-
def _drain_coroutine(coro: Coroutine[Any, Any, Any]) -> None:
220-
try:
221-
loop = asyncio.get_running_loop()
222-
except RuntimeError:
223-
asyncio.run(coro)
224-
else:
225-
loop.create_task(coro)
359+
await asyncio.gather(
360+
*(
361+
self._ainvoke_handler(
362+
cast(GraphCallbackHandler, handler),
363+
cast(GraphCallbackHandler, handler).on_resume,
364+
"on_resume",
365+
run_id=self.run_id,
366+
status=status,
367+
checkpoint_id=checkpoint_id,
368+
checkpoint_ns=checkpoint_ns,
369+
is_nested=is_nested,
370+
)
371+
for handler in self.handlers
372+
if not handler.run_inline
373+
)
374+
)
226375

227376

228377
GraphCallbacks: TypeAlias = (

libs/langgraph/langgraph/pregel/_loop.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ class PregelLoop:
215215
tasks: dict[str, PregelExecutableTask]
216216
output: None | dict[str, Any] | Any = None
217217
updated_channels: set[str] | None = None
218-
_graph_lifecycle_events: list[GraphLifecycleEvent]
218+
_graph_lifecycle_events: deque[GraphLifecycleEvent]
219219

220220
# public
221221

@@ -265,7 +265,7 @@ def __init__(
265265
self.retry_policy = retry_policy
266266
self.cache_policy = cache_policy
267267
self.durability = durability
268-
self._graph_lifecycle_events = []
268+
self._graph_lifecycle_events = deque()
269269
if self.stream is not None and CONFIG_KEY_STREAM in config[CONF]:
270270
self.stream = DuplexStream(self.stream, config[CONF][CONFIG_KEY_STREAM])
271271
scratchpad: PregelScratchpad | None = config[CONF].get(CONFIG_KEY_SCRATCHPAD)
@@ -334,10 +334,10 @@ def _push_graph_lifecycle_event(
334334
)
335335
)
336336

337-
def shift_graph_lifecycle_event(self) -> GraphLifecycleEvent | None:
337+
def _pop_lifecycle_event(self) -> GraphLifecycleEvent | None:
338338
if not self._graph_lifecycle_events:
339339
return None
340-
return self._graph_lifecycle_events.pop(0)
340+
return self._graph_lifecycle_events.popleft()
341341

342342
def put_writes(self, task_id: str, writes: WritesT) -> None:
343343
"""Put writes for a task, to be read by the next tick."""
@@ -923,9 +923,8 @@ def _suppress_interrupt(
923923
self._put_checkpoint(self.checkpoint_metadata)
924924
self._put_pending_writes()
925925
# suppress interrupt
926-
suppress = isinstance(exc_value, GraphInterrupt) and not self.is_nested
927-
if suppress:
928-
interrupt = cast(GraphInterrupt, exc_value)
926+
if isinstance(exc_value, GraphInterrupt) and not self.is_nested:
927+
interrupt = exc_value
929928
interrupts = tuple(interrupt.args[0]) if interrupt.args else ()
930929
self._push_graph_lifecycle_event("interrupt", interrupts=interrupts)
931930
# emit one last "values" event, with pending writes applied
@@ -954,12 +953,11 @@ def _suppress_interrupt(
954953
self.channels,
955954
)
956955
# emit INTERRUPT if exception is empty (otherwise emitted by put_writes)
957-
if exc_value is not None and (not exc_value.args or not exc_value.args[0]):
956+
if not interrupt.args or not interrupt.args[0]:
957+
interrupt_payload = interrupt.args[0] if interrupt.args else ()
958958
self._emit(
959959
"updates",
960-
lambda: iter(
961-
[{INTERRUPT: cast(GraphInterrupt, exc_value).args[0]}]
962-
),
960+
lambda: iter([{INTERRUPT: interrupt_payload}]),
963961
)
964962
# save final output
965963
self.output = read_channels(self.channels, self.output_keys)
@@ -1177,7 +1175,7 @@ def put_writes(self, task_id: str, writes: WritesT) -> None:
11771175
# context manager
11781176

11791177
def __enter__(self) -> Self:
1180-
self._graph_lifecycle_events = []
1178+
self._graph_lifecycle_events = deque()
11811179
if not self.checkpointer:
11821180
saved = None
11831181
elif self.checkpoint_config[CONF].get(CONFIG_KEY_CHECKPOINT_ID):
@@ -1377,7 +1375,7 @@ def put_writes(self, task_id: str, writes: WritesT) -> None:
13771375
# context manager
13781376

13791377
async def __aenter__(self) -> Self:
1380-
self._graph_lifecycle_events = []
1378+
self._graph_lifecycle_events = deque()
13811379
if not self.checkpointer:
13821380
saved = None
13831381
elif self.checkpoint_config[CONF].get(CONFIG_KEY_CHECKPOINT_ID):

libs/langgraph/langgraph/pregel/main.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2675,7 +2675,7 @@ def stream_writer(c: Any) -> None:
26752675
_state_mapper = self._state_mapper if version == "v2" else None
26762676

26772677
def emit_graph_lifecycle_events(loop: SyncPregelLoop) -> None:
2678-
while (event := loop.shift_graph_lifecycle_event()) is not None:
2678+
while (event := loop._pop_lifecycle_event()) is not None:
26792679
if event.kind == "resume":
26802680
graph_callback_manager.on_resume(
26812681
status=event.status,
@@ -3073,7 +3073,7 @@ def stream_writer(c: Any) -> None:
30733073
_state_mapper = self._state_mapper if version == "v2" else None
30743074

30753075
async def aemit_graph_lifecycle_events(loop: AsyncPregelLoop) -> None:
3076-
while (event := loop.shift_graph_lifecycle_event()) is not None:
3076+
while (event := loop._pop_lifecycle_event()) is not None:
30773077
if event.kind == "resume":
30783078
await graph_callback_manager.aon_resume(
30793079
status=event.status,

0 commit comments

Comments
 (0)