Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion libs/core/langchain_core/_api/internal.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import inspect
from typing import cast


def is_caller_internal(depth: int = 2) -> bool:
Expand All @@ -16,7 +17,7 @@ def is_caller_internal(depth: int = 2) -> bool:
return False
# Directly access the module name from the frame's global variables
module_globals = frame.f_globals
caller_module_name = module_globals.get("__name__", "")
caller_module_name = cast("str", module_globals.get("__name__", ""))
return caller_module_name.startswith("langchain")
finally:
del frame
3 changes: 2 additions & 1 deletion libs/core/langchain_core/language_models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
Literal,
TypeAlias,
TypeVar,
cast,
)

from pydantic import BaseModel, ConfigDict, Field, field_validator
Expand Down Expand Up @@ -92,7 +93,7 @@ def _get_token_ids_default_method(text: str) -> list[int]:
tokenizer = get_tokenizer()

# tokenize the text using the GPT-2 tokenizer
return tokenizer.encode(text)
return cast("list[int]", tokenizer.encode(text))


LanguageModelInput = PromptValue | str | Sequence[MessageLikeRepresentation]
Expand Down
7 changes: 5 additions & 2 deletions libs/core/langchain_core/language_models/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,10 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):

@cached_property
def _serialized(self) -> dict[str, Any]:
return dumpd(self)
# self is always a Serializable object in this case, thus the result is
# guaranteed to be a dict since dumps uses the default callback, which uses
# obj.to_json which always returns TypedDict subclasses
return cast("dict[str, Any]", dumpd(self))

# --- Runnable methods ---

Expand Down Expand Up @@ -462,7 +465,7 @@ def _should_stream(

# Check if a runtime streaming flag has been passed in.
if "stream" in kwargs:
return kwargs["stream"]
return bool(kwargs["stream"])
Copy link
Copy Markdown
Collaborator Author

@cbornet Christophe Bornet (cbornet) Dec 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Calling bool() here doesn't change the behavior, as _should_stream is always called in an if condition.


if "streaming" in self.model_fields_set:
streaming_value = getattr(self, "streaming", None)
Expand Down
5 changes: 4 additions & 1 deletion libs/core/langchain_core/language_models/llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,10 @@ class BaseLLM(BaseLanguageModel[str], ABC):

@functools.cached_property
def _serialized(self) -> dict[str, Any]:
return dumpd(self)
# self is always a Serializable object in this case, thus the result is
# guaranteed to be a dict since dumps uses the default callback, which uses
# obj.to_json which always returns TypedDict subclasses
return cast("dict[str, Any]", dumpd(self))

# --- Runnable methods ---

Expand Down
2 changes: 1 addition & 1 deletion libs/core/langchain_core/messages/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ def __add__(self, other: Any) -> ChatPromptTemplate:
from langchain_core.prompts.chat import ChatPromptTemplate # noqa: PLC0415

prompt = ChatPromptTemplate(messages=[self])
return prompt + other
return prompt.__add__(other)
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mypy doesn't figure out the + operator...


def pretty_repr(
self,
Expand Down
62 changes: 50 additions & 12 deletions libs/core/langchain_core/messages/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,19 @@
def _get_type(v: Any) -> str:
"""Get the type associated with the object for serialization purposes."""
if isinstance(v, dict) and "type" in v:
return v["type"]
if hasattr(v, "type"):
return v.type
msg = (
f"Expected either a dictionary with a 'type' key or an object "
f"with a 'type' attribute. Instead got type {type(v)}."
)
raise TypeError(msg)
result = v["type"]
elif hasattr(v, "type"):
result = v.type
else:
msg = (
f"Expected either a dictionary with a 'type' key or an object "
f"with a 'type' attribute. Instead got type {type(v)}."
)
raise TypeError(msg)
if not isinstance(result, str):
msg = f"Expected 'type' to be a str, got {type(result).__name__}"
raise TypeError(msg)
return result


AnyMessage = Annotated[
Expand Down Expand Up @@ -215,8 +220,11 @@ def message_chunk_to_message(chunk: BaseMessage) -> BaseMessage:
ignore_keys = ["type"]
if isinstance(chunk, AIMessageChunk):
ignore_keys.extend(["tool_call_chunks", "chunk_position"])
return chunk.__class__.__mro__[1](
**{k: v for k, v in chunk.__dict__.items() if k not in ignore_keys}
return cast(
"BaseMessage",
chunk.__class__.__mro__[1](
**{k: v for k, v in chunk.__dict__.items() if k not in ignore_keys}
),
)


Expand Down Expand Up @@ -1112,6 +1120,32 @@ def list_token_counter(messages: Sequence[BaseMessage]) -> int:
raise ValueError(msg)


_SingleMessage = BaseMessage | str | dict[str, Any]
_T = TypeVar("_T", bound=_SingleMessage)
# A sequence of _SingleMessage that is NOT a bare str
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is needed because otherwise str matches Sequence[str]

_MultipleMessages = Sequence[_T]


@overload
def convert_to_openai_messages(
messages: _SingleMessage,
*,
text_format: Literal["string", "block"] = "string",
include_id: bool = False,
pass_through_unknown_blocks: bool = True,
) -> dict: ...


@overload
def convert_to_openai_messages(
messages: _MultipleMessages,
*,
text_format: Literal["string", "block"] = "string",
include_id: bool = False,
pass_through_unknown_blocks: bool = True,
) -> list[dict]: ...


def convert_to_openai_messages(
messages: MessageLikeRepresentation | Sequence[MessageLikeRepresentation],
*,
Expand Down Expand Up @@ -1207,7 +1241,7 @@ def convert_to_openai_messages(
err = f"Unrecognized {text_format=}, expected one of 'string' or 'block'."
raise ValueError(err)

oai_messages: list = []
oai_messages: list[dict] = []

if is_single := isinstance(messages, (BaseMessage, dict, str)):
messages = [messages]
Expand Down Expand Up @@ -1774,7 +1808,11 @@ def _get_message_openai_role(message: BaseMessage) -> str:
if isinstance(message, ToolMessage):
return "tool"
if isinstance(message, SystemMessage):
return message.additional_kwargs.get("__openai_role__", "system")
role = message.additional_kwargs.get("__openai_role__", "system")
if not isinstance(role, str):
msg = f"Expected '__openai_role__' to be a str, got {type(role).__name__}"
raise TypeError(msg)
return role
if isinstance(message, FunctionMessage):
return "function"
if isinstance(message, ChatMessage):
Expand Down
5 changes: 3 additions & 2 deletions libs/core/langchain_core/output_parsers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
Any,
Generic,
TypeVar,
cast,
)

from typing_extensions import override
Expand Down Expand Up @@ -77,7 +78,7 @@ def OutputType(self) -> type[T]:
"""Return the output type for the parser."""
# even though mypy complains this isn't valid,
# it is good enough for pydantic to build the schema from
return T # type: ignore[misc]
return cast("type[T]", T) # type: ignore[misc]

@override
def invoke(
Expand Down Expand Up @@ -181,7 +182,7 @@ def OutputType(self) -> type[T]:
if hasattr(base, "__pydantic_generic_metadata__"):
metadata = base.__pydantic_generic_metadata__
if "args" in metadata and len(metadata["args"]) > 0:
return metadata["args"][0]
return cast("type[T]", metadata["args"][0])

msg = (
f"Runnable {self.__class__.__name__} doesn't have an inferable OutputType. "
Expand Down
14 changes: 12 additions & 2 deletions libs/core/langchain_core/output_parsers/pydantic.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Output parsers using Pydantic."""

import json
from typing import Annotated, Generic
from typing import Annotated, Generic, Literal, overload

import pydantic
from pydantic import SkipValidation
Expand Down Expand Up @@ -42,6 +42,16 @@ def _parser_exception(
msg = f"Failed to parse {name} from completion {json_string}. Got: {e}"
return OutputParserException(msg, llm_output=json_string)

@overload
def parse_result(
self, result: list[Generation], *, partial: Literal[False] = False
) -> TBaseModel: ...

@overload
def parse_result(
self, result: list[Generation], *, partial: bool = False
) -> TBaseModel | None: ...

def parse_result(
self, result: list[Generation], *, partial: bool = False
) -> TBaseModel | None:
Expand Down Expand Up @@ -77,7 +87,7 @@ def parse(self, text: str) -> TBaseModel:
Returns:
The parsed Pydantic object.
"""
return super().parse(text)
return self.parse_result([Generation(text=text)])

def get_format_instructions(self) -> str:
"""Return the format instructions for the JSON output.
Expand Down
22 changes: 11 additions & 11 deletions libs/core/langchain_core/prompts/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,7 @@
from collections.abc import Mapping # noqa: TC003
from functools import cached_property
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
Generic,
TypeVar,
)
from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast

import yaml
from pydantic import BaseModel, ConfigDict, Field, model_validator
Expand Down Expand Up @@ -122,7 +117,10 @@ def is_lc_serializable(cls) -> bool:

@cached_property
def _serialized(self) -> dict[str, Any]:
return dumpd(self)
# self is always a Serializable object in this case, thus the result is
# guaranteed to be a dict since dumpd uses the default callback, which uses
# obj.to_json which always returns TypedDict subclasses
return cast("dict[str, Any]", dumpd(self))

@property
@override
Expand Down Expand Up @@ -156,7 +154,7 @@ def _validate_input(self, inner_input: Any) -> dict:
if not isinstance(inner_input, dict):
if len(self.input_variables) == 1:
var_name = self.input_variables[0]
inner_input = {var_name: inner_input}
inner_input_ = {var_name: inner_input}

else:
msg = (
Expand All @@ -168,12 +166,14 @@ def _validate_input(self, inner_input: Any) -> dict:
message=msg, error_code=ErrorCode.INVALID_PROMPT_INPUT
)
)
missing = set(self.input_variables).difference(inner_input)
else:
inner_input_ = inner_input
missing = set(self.input_variables).difference(inner_input_)
if missing:
msg = (
f"Input to {self.__class__.__name__} is missing variables {missing}. "
f" Expected: {self.input_variables}"
f" Received: {list(inner_input.keys())}"
f" Received: {list(inner_input_.keys())}"
)
example_key = missing.pop()
msg += (
Expand All @@ -184,7 +184,7 @@ def _validate_input(self, inner_input: Any) -> dict:
raise KeyError(
create_message(message=msg, error_code=ErrorCode.INVALID_PROMPT_INPUT)
)
return inner_input
return inner_input_

def _format_prompt_with_error_handling(self, inner_input: dict) -> PromptValue:
inner_input_ = self._validate_input(inner_input)
Expand Down
11 changes: 7 additions & 4 deletions libs/core/langchain_core/prompts/dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import warnings
from functools import cached_property
from typing import Any, Literal
from typing import Any, Literal, cast

from typing_extensions import override

Expand Down Expand Up @@ -65,7 +65,10 @@ def _prompt_type(self) -> str:

@cached_property
def _serialized(self) -> dict[str, Any]:
return dumpd(self)
# self is always a Serializable object in this case, thus the result is
# guaranteed to be a dict since dumpd uses the default callback, which uses
# obj.to_json which always returns TypedDict subclasses
return cast("dict[str, Any]", dumpd(self))

@classmethod
def is_lc_serializable(cls) -> bool:
Expand Down Expand Up @@ -116,7 +119,7 @@ def _insert_input_variables(
inputs: dict[str, Any],
template_format: Literal["f-string", "mustache"],
) -> dict[str, Any]:
formatted = {}
formatted: dict[str, Any] = {}
formatter = DEFAULT_FORMATTER_MAPPING[template_format]
for k, v in template.items():
if isinstance(v, str):
Expand All @@ -132,7 +135,7 @@ def _insert_input_variables(
warnings.warn(msg, stacklevel=2)
formatted[k] = _insert_input_variables(v, inputs, template_format)
elif isinstance(v, (list, tuple)):
formatted_v = []
formatted_v: list[str | dict[str, Any]] = []
for x in v:
if isinstance(x, str):
formatted_v.append(formatter(x, **inputs))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pydantic import ConfigDict, model_validator
from typing_extensions import Self

from langchain_core.example_selectors import BaseExampleSelector
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.prompts.string import (
DEFAULT_FORMATTER_MAPPING,
Expand All @@ -21,7 +22,7 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
"""Examples to format into the prompt.
Either this or example_selector should be provided."""

example_selector: Any = None
example_selector: BaseExampleSelector | None = None
"""ExampleSelector to choose the examples to format into the prompt.
Either this or examples should be provided."""

Expand Down
4 changes: 2 additions & 2 deletions libs/core/langchain_core/prompts/image.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Image prompt template for a multimodal model."""

from typing import Any
from typing import Any, Literal, cast

from pydantic import Field

Expand Down Expand Up @@ -125,7 +125,7 @@ def format(
output: ImageURL = {"url": url}
if detail:
# Don't check literal values here: let the API check them
output["detail"] = detail
output["detail"] = cast("Literal['auto', 'low', 'high']", detail)
return output

async def aformat(self, **kwargs: Any) -> ImageURL:
Expand Down
2 changes: 1 addition & 1 deletion libs/core/langchain_core/prompts/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,4 +92,4 @@ def __add__(self, other: Any) -> ChatPromptTemplate:
from langchain_core.prompts.chat import ChatPromptTemplate # noqa: PLC0415

prompt = ChatPromptTemplate(messages=[self])
return prompt + other
return prompt.__add__(other)
Loading