From fe3aa6ee75b5f2a0fe177ade3da6cc3e6e0acef8 Mon Sep 17 00:00:00 2001 From: Viicos <65306057+Viicos@users.noreply.github.com> Date: Tue, 20 Dec 2022 10:32:04 +0100 Subject: [PATCH 1/5] Use `ASGIApp` as a type hint for middlewares --- starlette/applications.py | 6 +++--- starlette/middleware/__init__.py | 4 +++- tests/middleware/test_base.py | 3 ++- tests/middleware/test_middleware.py | 4 +++- 4 files changed, 11 insertions(+), 6 deletions(-) diff --git a/starlette/applications.py b/starlette/applications.py index 013364be3..0b4b31619 100644 --- a/starlette/applications.py +++ b/starlette/applications.py @@ -134,9 +134,9 @@ def host( ) -> None: # pragma: no cover self.router.host(host, app=app, name=name) - def add_middleware(self, middleware_class: type, **options: typing.Any) -> None: - if self.middleware_stack is not None: # pragma: no cover - raise RuntimeError("Cannot add middleware after an application has started") + def add_middleware( + self, middleware_class: typing.Type[ASGIApp], **options: typing.Any + ) -> None: # pragma: no cover self.user_middleware.insert(0, Middleware(middleware_class, **options)) def add_exception_handler( diff --git a/starlette/middleware/__init__.py b/starlette/middleware/__init__.py index 5ac5b96c8..38037dd76 100644 --- a/starlette/middleware/__init__.py +++ b/starlette/middleware/__init__.py @@ -1,8 +1,10 @@ import typing +from starlette.types import ASGIApp + class Middleware: - def __init__(self, cls: type, **options: typing.Any) -> None: + def __init__(self, cls: typing.Type[ASGIApp], **options: typing.Any) -> None: self.cls = cls self.options = options diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index ed0734bd3..ccf1cebcd 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -1,4 +1,5 @@ import contextvars +import typing from contextlib import AsyncExitStack import anyio @@ -193,7 +194,7 @@ async def dispatch(self, request, call_next): ), ], ) -def test_contextvars(test_client_factory, middleware_cls: type): +def test_contextvars(test_client_factory, middleware_cls: typing.Type[ASGIApp]): # this has to be an async endpoint because Starlette calls run_in_threadpool # on sync endpoints which has it's own set of peculiarities w.r.t propagating # contextvars (it propagates them forwards but not backwards) diff --git a/tests/middleware/test_middleware.py b/tests/middleware/test_middleware.py index f4d7a32f0..3a0fa0131 100644 --- a/tests/middleware/test_middleware.py +++ b/tests/middleware/test_middleware.py @@ -1,8 +1,10 @@ from starlette.middleware import Middleware +from starlette.types import Receive, Scope, Send class CustomMiddleware: - pass + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + return None # pragma: no cover def test_middleware_repr(): From 119118885cc83078149cb059c808cc3b1d4a32ce Mon Sep 17 00:00:00 2001 From: Viicos <65306057+Viicos@users.noreply.github.com> Date: Sat, 1 Apr 2023 19:33:53 +0200 Subject: [PATCH 2/5] Define `MiddlewareType` to match both classes and callables --- starlette/applications.py | 4 ++-- starlette/middleware/__init__.py | 4 ++-- starlette/types.py | 18 ++++++++++++++++++ tests/middleware/test_base.py | 5 ++--- 4 files changed, 24 insertions(+), 7 deletions(-) diff --git a/starlette/applications.py b/starlette/applications.py index 0b4b31619..66cd8cd3e 100644 --- a/starlette/applications.py +++ b/starlette/applications.py @@ -9,7 +9,7 @@ from starlette.requests import Request from starlette.responses import Response from starlette.routing import BaseRoute, Router -from starlette.types import ASGIApp, Lifespan, Receive, Scope, Send +from starlette.types import ASGIApp, Lifespan, MiddlewareType, Receive, Scope, Send AppType = typing.TypeVar("AppType", bound="Starlette") @@ -135,7 +135,7 @@ def host( self.router.host(host, app=app, name=name) def add_middleware( - self, middleware_class: typing.Type[ASGIApp], **options: typing.Any + self, middleware_class: MiddlewareType, **options: typing.Any ) -> None: # pragma: no cover self.user_middleware.insert(0, Middleware(middleware_class, **options)) diff --git a/starlette/middleware/__init__.py b/starlette/middleware/__init__.py index 38037dd76..3e7db265b 100644 --- a/starlette/middleware/__init__.py +++ b/starlette/middleware/__init__.py @@ -1,10 +1,10 @@ import typing -from starlette.types import ASGIApp +from starlette.types import MiddlewareType class Middleware: - def __init__(self, cls: typing.Type[ASGIApp], **options: typing.Any) -> None: + def __init__(self, cls: MiddlewareType, **options: typing.Any) -> None: self.cls = cls self.options = options diff --git a/starlette/types.py b/starlette/types.py index 713d18a80..0f94b9f4c 100644 --- a/starlette/types.py +++ b/starlette/types.py @@ -1,5 +1,11 @@ +import sys import typing +if sys.version_info < (3, 8): # pragma: no cover + from typing_extensions import Protocol +else: # pragma: no cover + from typing import Protocol + AppType = typing.TypeVar("AppType") Scope = typing.MutableMapping[str, typing.Any] @@ -15,3 +21,15 @@ [AppType], typing.AsyncContextManager[typing.Mapping[str, typing.Any]] ] Lifespan = typing.Union[StatelessLifespan[AppType], StatefulLifespan[AppType]] + + +# This callable protocol can both be used to represent a function returning +# an ASGIApp, or a class with an __init__ method matching this __call__ signature +# and a __call__ method matching the ASGIApp signature. +class MiddlewareType(Protocol): + __name__: str + + def __call__( + self, *args: typing.Any, **kwargs: typing.Any + ) -> ASGIApp: # pragma: no cover + ... diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index ccf1cebcd..ba2f557d3 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -1,5 +1,4 @@ import contextvars -import typing from contextlib import AsyncExitStack import anyio @@ -11,7 +10,7 @@ from starlette.middleware.base import BaseHTTPMiddleware from starlette.responses import PlainTextResponse, StreamingResponse from starlette.routing import Route, WebSocketRoute -from starlette.types import ASGIApp, Receive, Scope, Send +from starlette.types import ASGIApp, MiddlewareType, Receive, Scope, Send class CustomMiddleware(BaseHTTPMiddleware): @@ -194,7 +193,7 @@ async def dispatch(self, request, call_next): ), ], ) -def test_contextvars(test_client_factory, middleware_cls: typing.Type[ASGIApp]): +def test_contextvars(test_client_factory, middleware_cls: MiddlewareType): # this has to be an async endpoint because Starlette calls run_in_threadpool # on sync endpoints which has it's own set of peculiarities w.r.t propagating # contextvars (it propagates them forwards but not backwards) From 6ef22f4b4234d2b7dc017e65ff6af695108d7cb4 Mon Sep 17 00:00:00 2001 From: Viicos <65306057+Viicos@users.noreply.github.com> Date: Sun, 23 Apr 2023 18:06:37 +0200 Subject: [PATCH 3/5] failed rebase --- starlette/applications.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/starlette/applications.py b/starlette/applications.py index 66cd8cd3e..8edc2250f 100644 --- a/starlette/applications.py +++ b/starlette/applications.py @@ -134,9 +134,9 @@ def host( ) -> None: # pragma: no cover self.router.host(host, app=app, name=name) - def add_middleware( - self, middleware_class: MiddlewareType, **options: typing.Any - ) -> None: # pragma: no cover + def add_middleware(self, middleware_class: type, **options: typing.Any) -> None: + if self.middleware_stack is not None: # pragma: no cover + raise RuntimeError("Cannot add middleware after an application has started") self.user_middleware.insert(0, Middleware(middleware_class, **options)) def add_exception_handler( From 6a532614e1b3910ffe2ce5307bd9a490314edf58 Mon Sep 17 00:00:00 2001 From: Viicos <65306057+Viicos@users.noreply.github.com> Date: Sun, 23 Apr 2023 18:09:29 +0200 Subject: [PATCH 4/5] Missing type --- starlette/applications.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/starlette/applications.py b/starlette/applications.py index 8edc2250f..dbb5495c9 100644 --- a/starlette/applications.py +++ b/starlette/applications.py @@ -134,7 +134,7 @@ def host( ) -> None: # pragma: no cover self.router.host(host, app=app, name=name) - def add_middleware(self, middleware_class: type, **options: typing.Any) -> None: + def add_middleware(self, middleware_class: MiddlewareType, **options: typing.Any) -> None: if self.middleware_stack is not None: # pragma: no cover raise RuntimeError("Cannot add middleware after an application has started") self.user_middleware.insert(0, Middleware(middleware_class, **options)) From 46e1414601eab0ea738889829ced088431d7a024 Mon Sep 17 00:00:00 2001 From: Viicos <65306057+Viicos@users.noreply.github.com> Date: Sun, 23 Apr 2023 18:16:15 +0200 Subject: [PATCH 5/5] black --- starlette/applications.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/starlette/applications.py b/starlette/applications.py index dbb5495c9..1c4789e12 100644 --- a/starlette/applications.py +++ b/starlette/applications.py @@ -134,7 +134,9 @@ def host( ) -> None: # pragma: no cover self.router.host(host, app=app, name=name) - def add_middleware(self, middleware_class: MiddlewareType, **options: typing.Any) -> None: + def add_middleware( + self, middleware_class: MiddlewareType, **options: typing.Any + ) -> None: if self.middleware_stack is not None: # pragma: no cover raise RuntimeError("Cannot add middleware after an application has started") self.user_middleware.insert(0, Middleware(middleware_class, **options))