Skip to content

Commit e87f39a

Browse files
authored
feat: annotated_types with CustomField support (#120)
1 parent 3d7e6f8 commit e87f39a

File tree

4 files changed

+57
-34
lines changed

4 files changed

+57
-34
lines changed

fast_depends/__about__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
"""FastDepends - extracted and cleared from HTTP domain FastAPI Dependency Injection System"""
22

3-
__version__ = "2.4.7"
3+
__version__ = "2.4.8"

fast_depends/core/build.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,14 @@ def build_call_model(
8585
elif get_origin(param.annotation) is Annotated:
8686
annotated_args = get_args(param.annotation)
8787
type_annotation = annotated_args[0]
88-
custom_annotations = [
89-
arg for arg in annotated_args[1:] if isinstance(arg, CUSTOM_ANNOTATIONS)
90-
]
88+
89+
custom_annotations = []
90+
regular_annotations = []
91+
for arg in annotated_args[1:]:
92+
if isinstance(arg, CUSTOM_ANNOTATIONS):
93+
custom_annotations.append(arg)
94+
else:
95+
regular_annotations.append(arg)
9196

9297
assert (
9398
len(custom_annotations) <= 1
@@ -102,7 +107,10 @@ def build_call_model(
102107
else: # pragma: no cover
103108
raise AssertionError("unreachable")
104109

105-
annotation = type_annotation
110+
if regular_annotations:
111+
annotation = param.annotation
112+
else:
113+
annotation = type_annotation
106114
else:
107115
annotation = param.annotation
108116
else:
@@ -113,23 +121,22 @@ def build_call_model(
113121
default = ()
114122
elif param_name == "kwargs":
115123
default = {}
124+
elif param.default is inspect.Parameter.empty:
125+
default = Ellipsis
116126
else:
117127
default = param.default
118128

119129
if isinstance(default, Depends):
120130
assert (
121131
not dep
122132
), "You can not use `Depends` with `Annotated` and default both"
123-
dep = default
133+
dep, default = default, Ellipsis
124134

125135
elif isinstance(default, CustomField):
126136
assert (
127137
not custom
128138
), "You can not use `CustomField` with `Annotated` and default both"
129-
custom = default
130-
131-
elif default is inspect.Parameter.empty:
132-
class_fields[param_name] = (annotation, ...)
139+
custom, default = default, Ellipsis
133140

134141
else:
135142
class_fields[param_name] = (annotation, default)
@@ -147,7 +154,7 @@ def build_call_model(
147154
)
148155

149156
if dep.cast is True:
150-
class_fields[param_name] = (annotation, ...)
157+
class_fields[param_name] = (annotation, Ellipsis)
151158

152159
keyword_args.append(param_name)
153160

@@ -163,7 +170,7 @@ def build_call_model(
163170
annotation = Any
164171

165172
if custom.required:
166-
class_fields[param_name] = (annotation, ...)
173+
class_fields[param_name] = (annotation, default)
167174

168175
else:
169176
class_fields[param_name] = class_fields.get(param_name, (Optional[annotation], None))
@@ -184,10 +191,10 @@ def build_call_model(
184191

185192
response_model: Optional[Type[ResponseModel[T]]] = None
186193
if cast and return_annotation and return_annotation is not inspect.Parameter.empty:
187-
response_model = create_model(
194+
response_model = create_model( # type: ignore[call-overload]
188195
"ResponseModel",
189-
__config__=get_config_base(pydantic_config), # type: ignore[assignment]
190-
response=(return_annotation, ...),
196+
__config__=get_config_base(pydantic_config),
197+
response=(return_annotation, Ellipsis),
191198
)
192199

193200
return CallModel(

tests/async/test_depends.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ async def another_dep_func(b: int, a: int = 3) -> dict: # pragma: no cover
7272

7373
@inject
7474
async def some_func(
75-
b: int, c=Depends(dep_func), d=Depends(another_dep_func)
75+
b: int, c=Depends(dep_func), d=Depends(another_dep_func)
7676
) -> int: # pragma: no cover
7777
assert c is None
7878
return b
@@ -108,17 +108,17 @@ async def dep_func(a):
108108

109109
@inject
110110
async def some_func(
111-
a: int,
112-
b: int,
113-
c: "Annotated[int, Depends(dep_func)]",
111+
a: int,
112+
b: int,
113+
c: "Annotated[int, Depends(dep_func)]",
114114
) -> float:
115115
assert isinstance(c, int)
116116
return a + b + c
117117

118118
@inject
119119
async def another_func(
120-
a: int,
121-
c: "Annotated[int, Depends(dep_func)]",
120+
a: int,
121+
c: "Annotated[int, Depends(dep_func)]",
122122
):
123123
return a + c
124124

@@ -133,17 +133,17 @@ async def adep_func(a):
133133

134134
@inject
135135
async def some_func(
136-
a: int,
137-
b: int,
138-
c: Annotated["float", Depends(adep_func)],
136+
a: int,
137+
b: int,
138+
c: Annotated["float", Depends(adep_func)],
139139
) -> float:
140140
assert isinstance(c, float)
141141
return a + b + c
142142

143143
@inject
144144
async def another_func(
145-
a: int,
146-
c: Annotated["float", Depends(adep_func)],
145+
a: int,
146+
c: Annotated["float", Depends(adep_func)],
147147
):
148148
return a + c
149149

@@ -184,8 +184,8 @@ async def dep_func(a=Depends(nested_dep_func, use_cache=False)):
184184

185185
@inject
186186
async def some_func(
187-
a=Depends(dep_func, use_cache=False),
188-
b=Depends(nested_dep_func, use_cache=False),
187+
a=Depends(dep_func, use_cache=False),
188+
b=Depends(nested_dep_func, use_cache=False),
189189
):
190190
assert a is b
191191
return a + b
@@ -361,9 +361,9 @@ async def get_logger() -> logging.Logger:
361361

362362
@inject
363363
async def some_func(
364-
b,
365-
a: A = Depends(dep, cast=False),
366-
logger: logging.Logger = Depends(get_logger, cast=False),
364+
b,
365+
a: A = Depends(dep, cast=False),
366+
logger: logging.Logger = Depends(get_logger, cast=False),
367367
):
368368
assert a.a == 1
369369
assert logger
@@ -386,9 +386,9 @@ async def get_logger() -> logging.Logger:
386386

387387
@inject(cast=False)
388388
async def some_func(
389-
b: str,
390-
a: A = Depends(dep),
391-
logger: logging.Logger = Depends(get_logger),
389+
b: str,
390+
a: A = Depends(dep),
391+
logger: logging.Logger = Depends(get_logger),
392392
) -> str:
393393
assert a.a == 1
394394
assert logger

tests/library/test_custom.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,12 @@
55
import anyio
66
import pydantic
77
import pytest
8+
from annotated_types import Ge
89
from typing_extensions import Annotated
910

1011
from fast_depends import Depends, inject
1112
from fast_depends.library import CustomField
13+
from tests.marks import pydanticV2
1214

1315

1416
class Header(CustomField):
@@ -128,6 +130,20 @@ def sync_catch(key: Annotated[int, Header()]):
128130
assert sync_catch(headers={"key": "1"}) == 1
129131

130132

133+
@pydanticV2
134+
def test_annotated_header_with_meta():
135+
@inject
136+
def sync_catch(key: Annotated[int, Header(), Ge(3)] = 3): # noqa: B008
137+
return key
138+
139+
with pytest.raises(pydantic.ValidationError):
140+
sync_catch(headers={"key": "2"})
141+
142+
assert sync_catch(headers={"key": "4"}) == 4
143+
144+
assert sync_catch(headers={}) == 3
145+
146+
131147
def test_header_required():
132148
@inject
133149
def sync_catch(key2=Header()): # pragma: no cover # noqa: B008

0 commit comments

Comments
 (0)