Skip to content

Commit e5fd485

Browse files
committed
support args and kwargs
1 parent e83b1b8 commit e5fd485

File tree

6 files changed

+184
-18
lines changed

6 files changed

+184
-18
lines changed

fast_depends/__about__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
"""FastDepends - extracted and cleared from HTTP domain FastAPI Dependency Injection System"""
22

3-
__version__ = "2.0.5"
3+
__version__ = "2.1.0"

fast_depends/core/build.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,15 @@
11
import inspect
2-
from typing import Any, Awaitable, Callable, Dict, Optional, Sequence, Tuple, Union
2+
from typing import (
3+
Any,
4+
Awaitable,
5+
Callable,
6+
Dict,
7+
List,
8+
Optional,
9+
Sequence,
10+
Tuple,
11+
Union,
12+
)
313

414
from typing_extensions import (
515
Annotated,
@@ -49,7 +59,14 @@ def build_call_model(
4959
class_fields: Dict[str, Tuple[Any, Any]] = {}
5060
dependencies: Dict[str, CallModel[..., Any]] = {}
5161
custom_fields: Dict[str, CustomField] = {}
62+
positional_args: List[str] = []
63+
keyword_args: List[str] = []
5264
for param in typed_params:
65+
if param.kind is param.KEYWORD_ONLY:
66+
keyword_args.append(param.name)
67+
elif param.name not in ("args", "kwargs"):
68+
positional_args.append(param.name)
69+
5370
dep: Optional[Depends] = None
5471
custom: Optional[CustomField] = None
5572

@@ -82,7 +99,13 @@ def build_call_model(
8299
else:
83100
annotation = param.annotation
84101

85-
default = param.default
102+
if param.name == "args":
103+
default = ()
104+
elif param.name == "kwargs":
105+
default = {}
106+
else:
107+
default = param.default
108+
86109
if isinstance(default, Depends):
87110
assert (
88111
not dep
@@ -146,6 +169,8 @@ def build_call_model(
146169
is_async=is_call_async,
147170
dependencies=dependencies,
148171
custom_fields=custom_fields,
172+
positional_args=positional_args,
173+
keyword_args=keyword_args,
149174
extra_dependencies=[
150175
build_call_model(
151176
d.dependency,

fast_depends/core/model.py

Lines changed: 42 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
Iterable,
1010
List,
1111
Optional,
12+
Tuple,
1213
Type,
1314
Union,
1415
)
@@ -47,11 +48,30 @@ class CallModel(Generic[P, T]):
4748
dependencies: Dict[str, "CallModel[..., Any]"]
4849
extra_dependencies: Iterable["CallModel[..., Any]"]
4950
custom_fields: Dict[str, CustomField]
51+
keyword_args: Tuple[str]
52+
positional_args: Tuple[str]
5053

5154
# Dependencies and custom fields
5255
use_cache: bool
5356
cast: bool
5457

58+
__slots__ = (
59+
"call",
60+
"is_async",
61+
"is_generator",
62+
"model",
63+
"response_model",
64+
"params",
65+
"alias_arguments",
66+
"keyword_args",
67+
"positional_args",
68+
"dependencies",
69+
"extra_dependencies",
70+
"custom_fields",
71+
"use_cache",
72+
"cast",
73+
)
74+
5575
@property
5676
def call_name(self) -> str:
5777
return getattr(self.call, "__name__", type(self.call).__name__)
@@ -83,6 +103,8 @@ def __init__(
83103
is_async: bool = False,
84104
dependencies: Optional[Dict[str, "CallModel[..., Any]"]] = None,
85105
extra_dependencies: Optional[Iterable["CallModel[..., Any]"]] = None,
106+
keyword_args: Optional[List[str]] = None,
107+
positional_args: Optional[List[str]] = None,
86108
custom_fields: Optional[Dict[str, CustomField]] = None,
87109
):
88110
self.call = call
@@ -100,6 +122,8 @@ def __init__(
100122
self.custom_fields = custom_fields or {}
101123

102124
self.alias_arguments = [f.alias or name for name, f in fields.items()]
125+
self.keyword_args = tuple(keyword_args or [])
126+
self.positional_args = tuple(positional_args or [])
103127

104128
self.params = fields.copy()
105129
for name in self.dependencies.keys():
@@ -152,13 +176,20 @@ def _solve(
152176

153177
casted_model = self.model(**solved_kw)
154178

155-
casted_kw = {
179+
args_ = [
180+
getattr(casted_model, arg, solved_kw.get(arg))
181+
for arg in self.positional_args
182+
]
183+
args_.extend(getattr(casted_model, "args", ()))
184+
185+
kwargs_ = {
156186
arg: getattr(casted_model, arg, solved_kw.get(arg))
157-
for arg in (*self.params.keys(), *self.dependencies.keys())
187+
for arg in self.keyword_args
158188
}
189+
kwargs_.update(getattr(casted_model, "kwargs", {}))
159190

160191
response: T
161-
response = yield casted_kw
192+
response = yield args_, kwargs_
162193

163194
if self.cast is True and self.response_model is not None:
164195
casted_resp = self.response_model(response=response)
@@ -225,16 +256,17 @@ def solve(
225256
for custom in self.custom_fields.values():
226257
kwargs = custom.use(**kwargs)
227258

228-
final_kw = cast_gen.send(kwargs)
259+
final_args, final_kwargs = cast_gen.send(kwargs)
229260

230261
if self.is_generator:
231262
response = solve_generator_sync(
263+
*final_args,
232264
call=self.call,
233265
stack=stack,
234-
**final_kw,
266+
**final_kwargs,
235267
)
236268
else:
237-
response = self.call(**final_kw)
269+
response = self.call(*final_args, **final_kwargs)
238270

239271
try:
240272
cast_gen.send(response)
@@ -300,16 +332,17 @@ async def asolve(
300332
for custom in self.custom_fields.values():
301333
kwargs = await run_async(custom.use, **kwargs)
302334

303-
final_kw = cast_gen.send(kwargs)
335+
final_args, final_kwargs = cast_gen.send(kwargs)
304336

305337
if self.is_generator:
306338
response = await solve_generator_async(
339+
*final_args,
307340
call=self.call,
308341
stack=stack,
309-
**final_kw,
342+
**final_kwargs,
310343
)
311344
else:
312-
response = await run_async(self.call, **final_kw)
345+
response = await run_async(self.call, *final_args, **final_kwargs)
313346

314347
try:
315348
cast_gen.send(response)

fast_depends/utils.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,38 +48,53 @@ async def run_in_threadpool(
4848

4949

5050
async def solve_generator_async(
51-
*, call: Callable[..., Any], stack: AsyncExitStack, **sub_values: Any
51+
*sub_args: Any, call: Callable[..., Any], stack: AsyncExitStack, **sub_values: Any
5252
) -> Any:
5353
if is_gen_callable(call):
5454
cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values))
5555
elif is_async_gen_callable(call): # pragma: no branch
56-
cm = asynccontextmanager(call)(**sub_values)
56+
cm = asynccontextmanager(call)(*sub_args, **sub_values)
5757
return await stack.enter_async_context(cm)
5858

5959

6060
def solve_generator_sync(
61-
*, call: Callable[..., Any], stack: ExitStack, **sub_values: Any
61+
*sub_args: Any, call: Callable[..., Any], stack: ExitStack, **sub_values: Any
6262
) -> Any:
63-
cm = contextmanager(call)(**sub_values)
63+
cm = contextmanager(call)(*sub_args, **sub_values)
6464
return stack.enter_context(cm)
6565

6666

6767
def args_to_kwargs(
6868
arguments: Iterable[str], *args: Any, **kwargs: Any
6969
) -> Dict[str, Any]:
70+
arguments = tuple(filter(lambda i: i not in ("args", "kwargs"), arguments))
71+
7072
if not args:
7173
return kwargs
7274

73-
unused = filter(lambda x: x not in kwargs, arguments)
75+
merged = {"kwargs": kwargs.get("kwargs", {})}
76+
77+
for arg, v in kwargs.items():
78+
if arg not in arguments:
79+
merged["kwargs"][arg] = v
80+
else:
81+
merged[arg] = v
82+
83+
for arg in filter(lambda x: x not in merged, arguments):
84+
if args:
85+
merged[arg], args = args[0], args[1:]
7486

75-
return dict((*zip(unused, args), *kwargs.items()))
87+
merged["args"] = args
88+
89+
return merged
7690

7791

7892
def get_typed_signature(
7993
call: Callable[..., Any]
8094
) -> Tuple[List[inspect.Parameter], Any]:
8195
signature = inspect.signature(call)
8296
globalns = getattr(call, "__globals__", {})
97+
8398
return [
8499
inspect.Parameter(
85100
name=param.name,

tests/async/test_cast.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Tuple, Dict
2+
13
import pytest
24
from pydantic import BaseModel, Field, ValidationError
35
from typing_extensions import Annotated
@@ -104,3 +106,50 @@ async def some_func(a: A) -> float:
104106
return a
105107

106108
assert isinstance(await some_func(b="2"), float)
109+
110+
111+
@pytest.mark.asyncio
112+
async def test_args_kwargs_1():
113+
@inject
114+
async def simple_func(
115+
a: int,
116+
*args: Tuple[float, ...],
117+
b: int,
118+
**kwargs: Dict[str, int],
119+
):
120+
return a, args, b, kwargs
121+
122+
assert (1, (2.0, 3.0), 3, {"key": 1}) == await simple_func(
123+
1.0, 2.0, 3, b=3.0, key=1.0
124+
)
125+
126+
127+
@pytest.mark.asyncio
128+
async def test_args_kwargs_2():
129+
@inject
130+
async def simple_func(
131+
a: int,
132+
*args: Tuple[float, ...],
133+
b: int,
134+
):
135+
return a, args, b
136+
137+
assert (1, (2.0, 3.0), 3) == await simple_func(
138+
1.0,
139+
2.0,
140+
3,
141+
b=3.0,
142+
)
143+
144+
145+
@pytest.mark.asyncio
146+
async def test_args_kwargs_3():
147+
@inject
148+
async def simple_func(a: int, *, b: int):
149+
return a, b
150+
151+
assert (1, 3) == await simple_func(
152+
1.0,
153+
4,
154+
b=3.0,
155+
)

tests/sync/test_cast.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Tuple, Dict
2+
13
import pytest
24
from pydantic import BaseModel, Field, ValidationError
35
from typing_extensions import Annotated
@@ -107,3 +109,45 @@ def some_func(a: A) -> float:
107109
return a
108110

109111
assert isinstance(some_func(b="2"), float)
112+
113+
114+
def test_args_kwargs_1():
115+
@inject
116+
def simple_func(
117+
a: int,
118+
*args: Tuple[float, ...],
119+
b: int,
120+
**kwargs: Dict[str, int],
121+
):
122+
return a, args, b, kwargs
123+
124+
assert (1, (2.0, 3.0), 3, {"key": 1}) == simple_func(1.0, 2.0, 3, b=3.0, key=1.0)
125+
126+
127+
def test_args_kwargs_2():
128+
@inject
129+
def simple_func(
130+
a: int,
131+
*args: Tuple[float, ...],
132+
b: int,
133+
):
134+
return a, args, b
135+
136+
assert (1, (2.0, 3.0), 3) == simple_func(
137+
1.0,
138+
2.0,
139+
3,
140+
b=3.0,
141+
)
142+
143+
144+
def test_args_kwargs_3():
145+
@inject
146+
def simple_func(a: int, *, b: int):
147+
return a, b
148+
149+
assert (1, 3) == simple_func(
150+
1.0,
151+
4,
152+
b=3.0,
153+
)

0 commit comments

Comments
 (0)