Skip to content

Commit 1191188

Browse files
committed
Define MiddlewareType to match both classes and callables
1 parent fe3aa6e commit 1191188

File tree

4 files changed

+24
-7
lines changed

4 files changed

+24
-7
lines changed

starlette/applications.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from starlette.requests import Request
1010
from starlette.responses import Response
1111
from starlette.routing import BaseRoute, Router
12-
from starlette.types import ASGIApp, Lifespan, Receive, Scope, Send
12+
from starlette.types import ASGIApp, Lifespan, MiddlewareType, Receive, Scope, Send
1313

1414
AppType = typing.TypeVar("AppType", bound="Starlette")
1515

@@ -135,7 +135,7 @@ def host(
135135
self.router.host(host, app=app, name=name)
136136

137137
def add_middleware(
138-
self, middleware_class: typing.Type[ASGIApp], **options: typing.Any
138+
self, middleware_class: MiddlewareType, **options: typing.Any
139139
) -> None: # pragma: no cover
140140
self.user_middleware.insert(0, Middleware(middleware_class, **options))
141141

starlette/middleware/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import typing
22

3-
from starlette.types import ASGIApp
3+
from starlette.types import MiddlewareType
44

55

66
class Middleware:
7-
def __init__(self, cls: typing.Type[ASGIApp], **options: typing.Any) -> None:
7+
def __init__(self, cls: MiddlewareType, **options: typing.Any) -> None:
88
self.cls = cls
99
self.options = options
1010

starlette/types.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
1+
import sys
12
import typing
23

4+
if sys.version_info < (3, 8): # pragma: no cover
5+
from typing_extensions import Protocol
6+
else: # pragma: no cover
7+
from typing import Protocol
8+
39
AppType = typing.TypeVar("AppType")
410

511
Scope = typing.MutableMapping[str, typing.Any]
@@ -15,3 +21,15 @@
1521
[AppType], typing.AsyncContextManager[typing.Mapping[str, typing.Any]]
1622
]
1723
Lifespan = typing.Union[StatelessLifespan[AppType], StatefulLifespan[AppType]]
24+
25+
26+
# This callable protocol can both be used to represent a function returning
27+
# an ASGIApp, or a class with an __init__ method matching this __call__ signature
28+
# and a __call__ method matching the ASGIApp signature.
29+
class MiddlewareType(Protocol):
30+
__name__: str
31+
32+
def __call__(
33+
self, *args: typing.Any, **kwargs: typing.Any
34+
) -> ASGIApp: # pragma: no cover
35+
...

tests/middleware/test_base.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import contextvars
2-
import typing
32
from contextlib import AsyncExitStack
43

54
import anyio
@@ -11,7 +10,7 @@
1110
from starlette.middleware.base import BaseHTTPMiddleware
1211
from starlette.responses import PlainTextResponse, StreamingResponse
1312
from starlette.routing import Route, WebSocketRoute
14-
from starlette.types import ASGIApp, Receive, Scope, Send
13+
from starlette.types import ASGIApp, MiddlewareType, Receive, Scope, Send
1514

1615

1716
class CustomMiddleware(BaseHTTPMiddleware):
@@ -194,7 +193,7 @@ async def dispatch(self, request, call_next):
194193
),
195194
],
196195
)
197-
def test_contextvars(test_client_factory, middleware_cls: typing.Type[ASGIApp]):
196+
def test_contextvars(test_client_factory, middleware_cls: MiddlewareType):
198197
# this has to be an async endpoint because Starlette calls run_in_threadpool
199198
# on sync endpoints which has it's own set of peculiarities w.r.t propagating
200199
# contextvars (it propagates them forwards but not backwards)

0 commit comments

Comments
 (0)