11from __future__ import annotations
22
33import asyncio
4- import inspect
5- from collections .abc import Coroutine , Sequence
4+ import atexit
5+ import functools
6+ import logging
7+ from collections .abc import Awaitable , Callable , Coroutine , Sequence
8+ from concurrent .futures import ThreadPoolExecutor
9+ from contextvars import copy_context
10+ from functools import partial
611from typing import Any , TypeAlias , cast
712from uuid import UUID
813
2126)
2227
2328GRAPH_CALLBACKS_KEY = "graph_callbacks"
29+ logger = logging .getLogger (__name__ )
30+
31+
32+ @functools .lru_cache (maxsize = 1 )
33+ def _executor () -> ThreadPoolExecutor :
34+ executor = ThreadPoolExecutor (max_workers = 10 )
35+ atexit .register (executor .shutdown , wait = True )
36+ return executor
37+
38+
39+ def _coerce_coroutine (awaitable : Awaitable [Any ]) -> Coroutine [Any , Any , Any ]:
40+ if asyncio .iscoroutine (awaitable ):
41+ return awaitable
42+
43+ async def _await () -> Any :
44+ return await awaitable
45+
46+ return _await ()
47+
48+
49+ def _run_coroutine_with_new_loop (coro : Coroutine [Any , Any , Any ]) -> None :
50+ if hasattr (asyncio , "Runner" ):
51+ with asyncio .Runner () as runner :
52+ runner .run (coro )
53+ while pending := asyncio .all_tasks (runner .get_loop ()):
54+ runner .run (asyncio .wait (pending ))
55+ else :
56+ loop = asyncio .new_event_loop ()
57+ try :
58+ asyncio .set_event_loop (loop )
59+ loop .run_until_complete (coro )
60+ while pending := asyncio .all_tasks (loop ):
61+ loop .run_until_complete (asyncio .wait (pending ))
62+ finally :
63+ asyncio .set_event_loop (None )
64+ loop .close ()
2465
2566
2667class GraphCallbackHandler (BaseCallbackHandler ):
@@ -131,6 +172,74 @@ def configure(
131172 )
132173 return cls (handlers , run_id = run_id )
133174
175+ @staticmethod
176+ def _maybe_raise_handler_error (
177+ handler : GraphCallbackHandler ,
178+ event_name : str ,
179+ error : Exception ,
180+ ) -> None :
181+ if getattr (handler , "raise_error" , False ):
182+ raise error
183+ logger .exception (
184+ "Ignoring graph callback handler error in %s for %s" ,
185+ event_name ,
186+ handler .__class__ .__name__ ,
187+ exc_info = error ,
188+ )
189+
190+ def _drain_awaitable (
191+ self ,
192+ handler : GraphCallbackHandler ,
193+ event_name : str ,
194+ awaitable : Awaitable [Any ],
195+ ) -> None :
196+ coro = _coerce_coroutine (awaitable )
197+ try :
198+ asyncio .get_running_loop ()
199+ except RuntimeError :
200+ try :
201+ _run_coroutine_with_new_loop (coro )
202+ except Exception as exc :
203+ self ._maybe_raise_handler_error (handler , event_name , exc )
204+ else :
205+ future = _executor ().submit (
206+ cast (Callable [..., Any ], copy_context ().run ),
207+ _run_coroutine_with_new_loop ,
208+ coro ,
209+ )
210+ try :
211+ future .result ()
212+ except Exception as exc :
213+ self ._maybe_raise_handler_error (handler , event_name , exc )
214+
215+ async def _ainvoke_handler (
216+ self ,
217+ handler : GraphCallbackHandler ,
218+ callback : Callable [..., Any ],
219+ event_name : str ,
220+ * args : Any ,
221+ ** kwargs : Any ,
222+ ) -> None :
223+ try :
224+ if asyncio .iscoroutinefunction (callback ):
225+ await callback (* args , ** kwargs )
226+ elif handler .run_inline :
227+ result = callback (* args , ** kwargs )
228+ if isinstance (result , Awaitable ):
229+ await result
230+ else :
231+ result = await asyncio .get_running_loop ().run_in_executor (
232+ None ,
233+ cast (
234+ Callable [..., Any ],
235+ partial (copy_context ().run , callback , * args , ** kwargs ),
236+ ),
237+ )
238+ if isinstance (result , Awaitable ):
239+ await result
240+ except Exception as exc :
241+ self ._maybe_raise_handler_error (handler , event_name , exc )
242+
134243 def on_interrupt (
135244 self ,
136245 interrupts : Sequence [Interrupt ],
@@ -142,16 +251,19 @@ def on_interrupt(
142251 ) -> None :
143252 for handler in self .handlers :
144253 graph_handler = cast (GraphCallbackHandler , handler )
145- result = graph_handler .on_interrupt (
146- interrupts ,
147- run_id = self .run_id ,
148- status = status ,
149- checkpoint_id = checkpoint_id ,
150- checkpoint_ns = checkpoint_ns ,
151- is_nested = is_nested ,
152- )
153- if inspect .iscoroutine (result ):
154- self ._drain_coroutine (result )
254+ try :
255+ result = graph_handler .on_interrupt (
256+ interrupts ,
257+ run_id = self .run_id ,
258+ status = status ,
259+ checkpoint_id = checkpoint_id ,
260+ checkpoint_ns = checkpoint_ns ,
261+ is_nested = is_nested ,
262+ )
263+ if isinstance (result , Awaitable ):
264+ self ._drain_awaitable (graph_handler , "on_interrupt" , result )
265+ except Exception as exc :
266+ self ._maybe_raise_handler_error (graph_handler , "on_interrupt" , exc )
155267
156268 def on_resume (
157269 self ,
@@ -163,15 +275,18 @@ def on_resume(
163275 ) -> None :
164276 for handler in self .handlers :
165277 graph_handler = cast (GraphCallbackHandler , handler )
166- result = graph_handler .on_resume (
167- run_id = self .run_id ,
168- status = status ,
169- checkpoint_id = checkpoint_id ,
170- checkpoint_ns = checkpoint_ns ,
171- is_nested = is_nested ,
172- )
173- if inspect .iscoroutine (result ):
174- self ._drain_coroutine (result )
278+ try :
279+ result = graph_handler .on_resume (
280+ run_id = self .run_id ,
281+ status = status ,
282+ checkpoint_id = checkpoint_id ,
283+ checkpoint_ns = checkpoint_ns ,
284+ is_nested = is_nested ,
285+ )
286+ if isinstance (result , Awaitable ):
287+ self ._drain_awaitable (graph_handler , "on_resume" , result )
288+ except Exception as exc :
289+ self ._maybe_raise_handler_error (graph_handler , "on_resume" , exc )
175290
176291 async def aon_interrupt (
177292 self ,
@@ -182,18 +297,40 @@ async def aon_interrupt(
182297 checkpoint_ns : tuple [str , ...],
183298 is_nested : bool ,
184299 ) -> None :
185- for handler in self .handlers :
186- graph_handler = cast (GraphCallbackHandler , handler )
187- result = graph_handler .on_interrupt (
300+ inline_handlers = [
301+ cast (GraphCallbackHandler , handler )
302+ for handler in self .handlers
303+ if handler .run_inline
304+ ]
305+ for graph_handler in inline_handlers :
306+ await self ._ainvoke_handler (
307+ graph_handler ,
308+ graph_handler .on_interrupt ,
309+ "on_interrupt" ,
188310 interrupts ,
189311 run_id = self .run_id ,
190312 status = status ,
191313 checkpoint_id = checkpoint_id ,
192314 checkpoint_ns = checkpoint_ns ,
193315 is_nested = is_nested ,
194316 )
195- if inspect .isawaitable (result ):
196- await result
317+ await asyncio .gather (
318+ * (
319+ self ._ainvoke_handler (
320+ cast (GraphCallbackHandler , handler ),
321+ cast (GraphCallbackHandler , handler ).on_interrupt ,
322+ "on_interrupt" ,
323+ interrupts ,
324+ run_id = self .run_id ,
325+ status = status ,
326+ checkpoint_id = checkpoint_id ,
327+ checkpoint_ns = checkpoint_ns ,
328+ is_nested = is_nested ,
329+ )
330+ for handler in self .handlers
331+ if not handler .run_inline
332+ )
333+ )
197334
198335 async def aon_resume (
199336 self ,
@@ -203,26 +340,38 @@ async def aon_resume(
203340 checkpoint_ns : tuple [str , ...],
204341 is_nested : bool ,
205342 ) -> None :
206- for handler in self .handlers :
207- graph_handler = cast (GraphCallbackHandler , handler )
208- result = graph_handler .on_resume (
343+ inline_handlers = [
344+ cast (GraphCallbackHandler , handler )
345+ for handler in self .handlers
346+ if handler .run_inline
347+ ]
348+ for graph_handler in inline_handlers :
349+ await self ._ainvoke_handler (
350+ graph_handler ,
351+ graph_handler .on_resume ,
352+ "on_resume" ,
209353 run_id = self .run_id ,
210354 status = status ,
211355 checkpoint_id = checkpoint_id ,
212356 checkpoint_ns = checkpoint_ns ,
213357 is_nested = is_nested ,
214358 )
215- if inspect .isawaitable (result ):
216- await result
217-
218- @staticmethod
219- def _drain_coroutine (coro : Coroutine [Any , Any , Any ]) -> None :
220- try :
221- loop = asyncio .get_running_loop ()
222- except RuntimeError :
223- asyncio .run (coro )
224- else :
225- loop .create_task (coro )
359+ await asyncio .gather (
360+ * (
361+ self ._ainvoke_handler (
362+ cast (GraphCallbackHandler , handler ),
363+ cast (GraphCallbackHandler , handler ).on_resume ,
364+ "on_resume" ,
365+ run_id = self .run_id ,
366+ status = status ,
367+ checkpoint_id = checkpoint_id ,
368+ checkpoint_ns = checkpoint_ns ,
369+ is_nested = is_nested ,
370+ )
371+ for handler in self .handlers
372+ if not handler .run_inline
373+ )
374+ )
226375
227376
228377GraphCallbacks : TypeAlias = (
0 commit comments