Skip to content

feat (#34): inject to classes #35

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion fast_depends/__about__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
"""FastDepends - extracted and cleared from HTTP domain FastAPI Dependency Injection System"""

__version__ = "2.2.3"
__version__ = "2.2.4"
146 changes: 73 additions & 73 deletions fast_depends/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,25 +78,25 @@ class CallModel(Generic[P, T]):
)

@property
def call_name(self) -> str:
return getattr(self.call, "__name__", type(self.call).__name__)
def call_name(self__) -> str:
return getattr(self__.call, "__name__", type(self__.call).__name__)

@property
def real_params(self) -> Dict[str, FieldInfo]:
params = self.params.copy()
for name in self.custom_fields.keys():
def real_params(self__) -> Dict[str, FieldInfo]:
params = self__.params.copy()
for name in self__.custom_fields.keys():
params.pop(name, None)
return params

@property
def flat_params(self) -> Dict[str, FieldInfo]:
params = self.real_params
for d in self.dependencies.values():
def flat_params(self__) -> Dict[str, FieldInfo]:
params = self__.real_params
for d in self__.dependencies.values():
params.update(d.flat_params)
return params

def __init__(
self,
self__,
call: Union[
Callable[P, T],
Callable[P, Awaitable[T]],
Expand All @@ -112,39 +112,39 @@ def __init__(
positional_args: Optional[List[str]] = None,
custom_fields: Optional[Dict[str, CustomField]] = None,
):
self.call = call
self.model = model
self.response_model = response_model
self__.call = call
self__.model = model
self__.response_model = response_model

fields: Dict[str, FieldInfo]
if PYDANTIC_V2:
fields = self.model.model_fields
fields = self__.model.model_fields
else:
fields = self.model.__fields__ # type: ignore
fields = self__.model.__fields__ # type: ignore

self.dependencies = dependencies or {}
self.extra_dependencies = extra_dependencies or []
self.custom_fields = custom_fields or {}
self__.dependencies = dependencies or {}
self__.extra_dependencies = extra_dependencies or []
self__.custom_fields = custom_fields or {}

self.alias_arguments = [f.alias or name for name, f in fields.items()]
self.keyword_args = tuple(keyword_args or ())
self.positional_args = tuple(positional_args or ())
self__.alias_arguments = [f.alias or name for name, f in fields.items()]
self__.keyword_args = tuple(keyword_args or ())
self__.positional_args = tuple(positional_args or ())

self.params = fields.copy()
for name in self.dependencies.keys():
self.params.pop(name, None)
self__.params = fields.copy()
for name in self__.dependencies.keys():
self__.params.pop(name, None)

self.use_cache = use_cache
self.cast = cast
self.is_async = (
is_async or is_coroutine_callable(call) or is_async_gen_callable(self.call)
self__.use_cache = use_cache
self__.cast = cast
self__.is_async = (
is_async or is_coroutine_callable(call) or is_async_gen_callable(self__.call)
)
self.is_generator = is_gen_callable(self.call) or is_async_gen_callable(
self.call
self__.is_generator = is_gen_callable(self__.call) or is_async_gen_callable(
self__.call
)

def _solve(
self,
self__,
*args: P.args,
cache_dependencies: Dict[
Union[
Expand All @@ -168,56 +168,56 @@ def _solve(
**kwargs: P.kwargs,
) -> Generator[Tuple[Iterable[Any], Dict[str, Any]], Any, T]:
if dependency_overrides:
self.call = dependency_overrides.get(self.call, self.call)
assert self.is_async or not is_coroutine_callable(
self.call
), f"You cannot use async dependency `{self.call_name}` at sync main"
self__.call = dependency_overrides.get(self__.call, self__.call)
assert self__.is_async or not is_coroutine_callable(
self__.call
), f"You cannot use async dependency `{self__.call_name}` at sync main"

if self.use_cache and self.call in cache_dependencies:
return cache_dependencies[self.call]
if self__.use_cache and self__.call in cache_dependencies:
return cache_dependencies[self__.call]

kw = {}

for arg in self.keyword_args:
for arg in self__.keyword_args:
v = kwargs.pop(arg, inspect._empty)
if v is not inspect._empty:
kw[arg] = v

if "kwargs" in self.alias_arguments:
if "kwargs" in self__.alias_arguments:
kw["kwargs"] = kwargs

else:
kw.update(kwargs)

has_args = "args" in self.alias_arguments
has_args = "args" in self__.alias_arguments

for arg in self.positional_args:
for arg in self__.positional_args:
if args:
kw[arg], args = args[0], args[1:]

if has_args:
kw["args"] = args

else:
for arg in self.keyword_args:
for arg in self__.keyword_args:
if args:
kw[arg], args = args[0], args[1:]

solved_kw: Dict[str, Any]
solved_kw = yield (), kw

casted_model: object
if self.cast:
casted_model = self.model(**solved_kw)
if self__.cast:
casted_model = self__.model(**solved_kw)
else:
casted_model = object()

kwargs_ = {
arg: getattr(casted_model, arg, solved_kw.get(arg))
for arg in (
self.keyword_args + self.positional_args
self__.keyword_args + self__.positional_args
if not has_args
else self.keyword_args
else self__.keyword_args
)
}
kwargs_.update(getattr(casted_model, "kwargs", {}))
Expand All @@ -226,7 +226,7 @@ def _solve(
if has_args:
args_ = [
getattr(casted_model, arg, solved_kw.get(arg))
for arg in self.positional_args
for arg in self__.positional_args
]
args_.extend(getattr(casted_model, "args", ()))
else:
Expand All @@ -235,22 +235,22 @@ def _solve(
response: T
response = yield args_, kwargs_

if self.cast and not self.is_generator:
response = self._cast_response(response)
if self__.cast and not self__.is_generator:
response = self__._cast_response(response)

if self.use_cache: # pragma: no branch
cache_dependencies[self.call] = response
if self__.use_cache: # pragma: no branch
cache_dependencies[self__.call] = response

return response

def _cast_response(self, value: Any) -> Any:
if self.response_model is not None and self.cast:
return self.response_model(response=value).response
def _cast_response(self__, value: Any) -> Any:
if self__.response_model is not None and self__.cast:
return self__.response_model(response=value).response
else:
return value

def solve(
self,
self__,
*args: P.args,
stack: ExitStack,
cache_dependencies: Dict[
Expand All @@ -275,7 +275,7 @@ def solve(
nested: bool = False,
**kwargs: P.kwargs,
) -> T:
cast_gen = self._solve(
cast_gen = self__._solve(
*args,
cache_dependencies=cache_dependencies,
dependency_overrides=dependency_overrides,
Expand All @@ -287,7 +287,7 @@ def solve(
cached_value: T = e.value
return cached_value

for dep in self.extra_dependencies:
for dep in self__.extra_dependencies:
dep.solve(
stack=stack,
cache_dependencies=cache_dependencies,
Expand All @@ -296,7 +296,7 @@ def solve(
**kwargs,
)

for dep_arg, dep in self.dependencies.items():
for dep_arg, dep in self__.dependencies.items():
kwargs[dep_arg] = dep.solve(
stack=stack,
cache_dependencies=cache_dependencies,
Expand All @@ -305,37 +305,37 @@ def solve(
**kwargs,
)

for custom in self.custom_fields.values():
for custom in self__.custom_fields.values():
kwargs = custom.use(**kwargs)

final_args, final_kwargs = cast_gen.send(kwargs)

if self.is_generator and nested:
if self__.is_generator and nested:
response = solve_generator_sync(
*final_args,
call=self.call,
call=self__.call,
stack=stack,
**final_kwargs,
)

else:
response = self.call(*final_args, **final_kwargs)
response = self__.call(*final_args, **final_kwargs)

try:
cast_gen.send(response)
except StopIteration as e:
value: T = e.value

if not self.cast or nested or not self.is_generator:
if not self__.cast or nested or not self__.is_generator:
return value

else:
return map(self._cast_response, value) # type: ignore[no-any-return, call-overload]
return map(self__._cast_response, value) # type: ignore[no-any-return, call-overload]

assert_never(response) # pragma: no cover

async def asolve(
self,
self__,
*args: P.args,
stack: AsyncExitStack,
cache_dependencies: Dict[
Expand All @@ -360,7 +360,7 @@ async def asolve(
nested: bool = False,
**kwargs: P.kwargs,
) -> T:
cast_gen = self._solve(
cast_gen = self__._solve(
*args,
cache_dependencies=cache_dependencies,
dependency_overrides=dependency_overrides,
Expand All @@ -372,7 +372,7 @@ async def asolve(
cached_value: T = e.value
return cached_value

for dep in self.extra_dependencies:
for dep in self__.extra_dependencies:
await dep.asolve(
stack=stack,
cache_dependencies=cache_dependencies,
Expand All @@ -381,7 +381,7 @@ async def asolve(
**kwargs,
)

for dep_arg, dep in self.dependencies.items():
for dep_arg, dep in self__.dependencies.items():
kwargs[dep_arg] = await dep.asolve(
stack=stack,
cache_dependencies=cache_dependencies,
Expand All @@ -390,30 +390,30 @@ async def asolve(
**kwargs,
)

for custom in self.custom_fields.values():
for custom in self__.custom_fields.values():
kwargs = await run_async(custom.use, **kwargs)

final_args, final_kwargs = cast_gen.send(kwargs)

if self.is_generator and nested:
if self__.is_generator and nested:
response = await solve_generator_async(
*final_args,
call=self.call,
call=self__.call,
stack=stack,
**final_kwargs,
)
else:
response = await run_async(self.call, *final_args, **final_kwargs)
response = await run_async(self__.call, *final_args, **final_kwargs)

try:
cast_gen.send(response)
except StopIteration as e:
value: T = e.value

if not self.cast or nested or not self.is_generator:
if not self__.cast or nested or not self__.is_generator:
return value

else:
return async_map(self._cast_response, value) # type: ignore[return-value, arg-type]
return async_map(self__._cast_response, value) # type: ignore[return-value, arg-type]

assert_never(response) # pragma: no cover
22 changes: 22 additions & 0 deletions tests/async/test_class.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import pytest

from fast_depends import Depends, inject


def _get_var():
return 1


class Class:
@inject
def __init__(self, a = Depends(_get_var)) -> None:
self.a = a

@inject
async def calc(self, a = Depends(_get_var)) -> int:
return a + self.a


@pytest.mark.anyio
async def test_class():
assert await Class().calc() == 2
19 changes: 19 additions & 0 deletions tests/sync/test_class.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from fast_depends import Depends, inject


def _get_var():
return 1


class Class:
@inject
def __init__(self, a = Depends(_get_var)) -> None:
self.a = a

@inject
def calc(self, a = Depends(_get_var)) -> int:
return a + self.a


def test_class():
assert Class().calc() == 2