Skip to content

Commit 597ad79

Browse files
committed
Changes following review
1 parent 3b5be0f commit 597ad79

File tree

2 files changed

+10
-23
lines changed

2 files changed

+10
-23
lines changed

libs/core/langchain_core/runnables/base.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,6 @@
8080
indent_lines_after_first,
8181
is_async_callable,
8282
is_async_generator,
83-
task_with_context,
8483
)
8584
from langchain_core.tracers._streaming import _StreamingCallbackHandler
8685
from langchain_core.tracers.event_stream import (
@@ -3189,7 +3188,7 @@ async def ainvoke(
31893188
part = functools.partial(step.ainvoke, input_, config, **kwargs)
31903189
else:
31913190
part = functools.partial(step.ainvoke, input_, config)
3192-
input_ = await task_with_context(part(), context)
3191+
input_ = await coro_with_context(part(), context, create_task=True)
31933192
# finish the root run
31943193
except BaseException as e:
31953194
await run_manager.on_chain_error(e)
@@ -3911,8 +3910,8 @@ async def _ainvoke_step(
39113910
callbacks=run_manager.get_child(f"map:key:{key}"),
39123911
)
39133912
with set_config_context(child_config) as context:
3914-
return await task_with_context(
3915-
step.ainvoke(input_, child_config), context
3913+
return await coro_with_context(
3914+
step.ainvoke(input_, child_config), context, create_task=True
39163915
)
39173916

39183917
# gather results from all steps

libs/core/langchain_core/runnables/utils.py

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
TypeVar,
2121
)
2222

23-
from typing_extensions import override
23+
from typing_extensions import override, reveal_type
2424

2525
# Re-export create-model for backwards compatibility
2626
from langchain_core.utils.pydantic import create_model # noqa: F401
@@ -132,35 +132,23 @@ def asyncio_accepts_context() -> bool:
132132
_T = TypeVar("_T")
133133

134134

135-
def task_with_context(
136-
coro: Coroutine[Any, Any, _T], context: Context
137-
) -> asyncio.Task[_T]:
138-
"""Create a task with a context.
139-
140-
Args:
141-
coro: The coroutine to create a task for.
142-
context: The context to use.
143-
144-
Returns:
145-
The task with the context.
146-
"""
147-
if asyncio_accepts_context():
148-
return asyncio.create_task(coro, context=context) # type: ignore[arg-type,call-arg,unused-ignore]
149-
return asyncio.create_task(coro)
150-
151-
152-
def coro_with_context(coro: Awaitable[_T], context: Context) -> Awaitable[_T]:
135+
def coro_with_context(
136+
coro: Awaitable[_T], context: Context, *, create_task: bool = False
137+
) -> Awaitable[_T]:
153138
"""Await a coroutine with a context.
154139
155140
Args:
156141
coro: The coroutine to await.
157142
context: The context to use.
143+
create_task: Whether to create a task.
158144
159145
Returns:
160146
The coroutine with the context.
161147
"""
162148
if asyncio_accepts_context():
163149
return asyncio.create_task(coro, context=context) # type: ignore[arg-type,call-arg,unused-ignore]
150+
if create_task:
151+
return asyncio.create_task(coro) # type: ignore[arg-type]
164152
return coro
165153

166154

0 commit comments

Comments
 (0)