Skip to content

Commit ebecddd

Browse files
michaelnchin and mdrxy authored
fix(core): add init validator and serialization mappings for Bedrock models (#34510)
Adds serialization mappings for `ChatBedrockConverse` and `BedrockLLM` to unblock standard tests on `langchain-core>=1.2.5` (context: [langchain-aws#821](langchain-ai/langchain-aws#821)). Also introduces a class-specific validator system in `langchain_core.load` that blocks deserialization of AWS Bedrock models when `endpoint_url` or `base_url` parameters are present, preventing SSRF attacks via crafted serialized payloads.

Closes #34645

## Changes

- Add `ChatBedrockConverse` and `BedrockLLM` entries to `SERIALIZABLE_MAPPING` in `mapping.py`, mapping legacy paths to their `langchain_aws` import locations
- Add `validators.py` with `_bedrock_validator` — rejects deserialization kwargs containing `endpoint_url` or `base_url` for all Bedrock-related classes (`ChatBedrock`, `BedrockChat`, `ChatBedrockConverse`, `ChatAnthropicBedrock`, `BedrockLLM`, `Bedrock`)
- `CLASS_INIT_VALIDATORS` registry covers both serialized (legacy) keys and resolved import paths from `ALL_SERIALIZABLE_MAPPINGS`, preventing bypass via direct-path payloads
- Move kwargs extraction and all validator checks (`CLASS_INIT_VALIDATORS` + `init_validator`) in `Reviver.__call__` to run **before** `importlib.import_module()` — fail fast on security violations before executing third-party code
- Class-specific validators are independent of `init_validator` and cannot be disabled by passing `init_validator=None`

## Testing

- `test_validator_registry_keys_in_serializable_mapping` — structural invariant test ensuring every `CLASS_INIT_VALIDATORS` key exists in `ALL_SERIALIZABLE_MAPPINGS`
- 10 end-to-end `load()` tests covering all Bedrock class paths (legacy aliases, resolved import paths, `ChatAnthropicBedrock`, `init_validator=None` bypass attempt)
- Unit tests for `_bedrock_validator` covering `endpoint_url`, `base_url`, both params, and safe kwargs

---------

Co-authored-by: Mason Daugherty <mason@langchain.dev>
Co-authored-by: Mason Daugherty <github@mdrxy.com>
1 parent e94cd41 commit ebecddd

File tree

4 files changed

+370
-7
lines changed

4 files changed

+370
-7
lines changed

libs/core/langchain_core/load/load.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@
109109
SERIALIZABLE_MAPPING,
110110
)
111111
from langchain_core.load.serializable import Serializable
112+
from langchain_core.load.validators import CLASS_INIT_VALIDATORS
112113

113114
DEFAULT_NAMESPACES = [
114115
"langchain",
@@ -480,6 +481,19 @@ def __call__(self, value: dict[str, Any]) -> Any:
480481
msg = f"Invalid namespace: {value}"
481482
raise ValueError(msg)
482483

484+
# We don't need to recurse on kwargs
485+
# as json.loads will do that for us.
486+
kwargs = value.get("kwargs", {})
487+
488+
# Run class-specific validators before the general init_validator.
489+
# These run before importing to fail fast on security violations.
490+
if mapping_key in CLASS_INIT_VALIDATORS:
491+
CLASS_INIT_VALIDATORS[mapping_key](mapping_key, kwargs)
492+
493+
# Also run general init_validator (e.g., jinja2 blocking)
494+
if self.init_validator is not None:
495+
self.init_validator(mapping_key, kwargs)
496+
483497
mod = importlib.import_module(".".join(import_dir))
484498

485499
cls = getattr(mod, name)
@@ -489,13 +503,6 @@ def __call__(self, value: dict[str, Any]) -> Any:
489503
msg = f"Invalid namespace: {value}"
490504
raise ValueError(msg)
491505

492-
# We don't need to recurse on kwargs
493-
# as json.loads will do that for us.
494-
kwargs = value.get("kwargs", {})
495-
496-
if self.init_validator is not None:
497-
self.init_validator(mapping_key, kwargs)
498-
499506
return cls(**kwargs)
500507

501508
return value

libs/core/langchain_core/load/mapping.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,12 @@
321321
"bedrock",
322322
"ChatBedrock",
323323
),
324+
("langchain_aws", "chat_models", "ChatBedrockConverse"): (
325+
"langchain_aws",
326+
"chat_models",
327+
"bedrock_converse",
328+
"ChatBedrockConverse",
329+
),
324330
("langchain_google_genai", "chat_models", "ChatGoogleGenerativeAI"): (
325331
"langchain_google_genai",
326332
"chat_models",
@@ -380,6 +386,12 @@
380386
"bedrock",
381387
"BedrockLLM",
382388
),
389+
("langchain", "llms", "bedrock", "BedrockLLM"): (
390+
"langchain_aws",
391+
"llms",
392+
"bedrock",
393+
"BedrockLLM",
394+
),
383395
("langchain", "llms", "fireworks", "Fireworks"): (
384396
"langchain_fireworks",
385397
"llms",
libs/core/langchain_core/load/validators.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
"""Init validators for deserialization security.
2+
3+
This module contains extra validators that are called during deserialization,
4+
e.g. to prevent security issues such as SSRF attacks.
5+
6+
Each validator is a callable matching the `InitValidator` protocol: it takes a
7+
class path tuple and kwargs dict, returns `None` on success, and raises
8+
`ValueError` if the deserialization should be blocked.
9+
"""
10+
11+
from typing import TYPE_CHECKING, Any
12+
13+
if TYPE_CHECKING:
14+
from langchain_core.load.load import InitValidator
15+
16+
17+
def _bedrock_validator(class_path: tuple[str, ...], kwargs: dict[str, Any]) -> None:
18+
"""Constructor kwargs validator for AWS Bedrock integrations.
19+
20+
Blocks deserialization if `endpoint_url` or `base_url` parameters are
21+
present, which could enable SSRF attacks.
22+
23+
Args:
24+
class_path: The class path tuple being deserialized.
25+
kwargs: The kwargs dict for the class constructor.
26+
27+
Raises:
28+
ValueError: If `endpoint_url` or `base_url` parameters are present.
29+
"""
30+
dangerous_params = ["endpoint_url", "base_url"]
31+
found_params = [p for p in dangerous_params if p in kwargs]
32+
33+
if found_params:
34+
class_name = class_path[-1] if class_path else "Unknown"
35+
param_str = ", ".join(found_params)
36+
msg = (
37+
f"Deserialization of {class_name} with {param_str} is not allowed "
38+
f"for security reasons. These parameters can enable Server-Side Request "
39+
f"Forgery (SSRF) attacks by directing network requests to arbitrary "
40+
f"endpoints during initialization. If you need to use a custom endpoint, "
41+
f"instantiate {class_name} directly rather than deserializing it."
42+
)
43+
raise ValueError(msg)
44+
45+
46+
# Registry of class-specific init validators, keyed by class path tuple.
#
# Keys must cover both serialized IDs (SERIALIZABLE_MAPPING keys) and resolved
# import paths (SERIALIZABLE_MAPPING values) to prevent bypass via direct paths.
#
# The reviver in `load.py` looks a payload's mapping key up here and calls the
# matching validator before importing the target module, separately from the
# optional general `init_validator`.
#
# NOTE(review): the mapping.py hunk shows a resolved path ending in
# ("bedrock", "ChatBedrock"); confirm whether
# ("langchain_aws", "chat_models", "bedrock", "ChatBedrock") also needs an
# entry here, since only the 3-element ChatBedrock path is registered below.
CLASS_INIT_VALIDATORS: dict[tuple[str, ...], "InitValidator"] = {
    # Serialized (legacy) keys
    ("langchain", "chat_models", "bedrock", "BedrockChat"): _bedrock_validator,
    ("langchain", "chat_models", "bedrock", "ChatBedrock"): _bedrock_validator,
    (
        "langchain",
        "chat_models",
        "anthropic_bedrock",
        "ChatAnthropicBedrock",
    ): _bedrock_validator,
    ("langchain_aws", "chat_models", "ChatBedrockConverse"): _bedrock_validator,
    ("langchain", "llms", "bedrock", "Bedrock"): _bedrock_validator,
    ("langchain", "llms", "bedrock", "BedrockLLM"): _bedrock_validator,
    # Resolved import paths (from ALL_SERIALIZABLE_MAPPINGS values) to defend
    # against payloads that use the target tuple directly as the "id".
    (
        "langchain_aws",
        "chat_models",
        "bedrock_converse",
        "ChatBedrockConverse",
    ): _bedrock_validator,
    (
        "langchain_aws",
        "chat_models",
        "anthropic",
        "ChatAnthropicBedrock",
    ): _bedrock_validator,
    ("langchain_aws", "chat_models", "ChatBedrock"): _bedrock_validator,
    ("langchain_aws", "llms", "bedrock", "BedrockLLM"): _bedrock_validator,
}

0 commit comments

Comments
 (0)