Skip to content

Commit 504ca53

Browse files
authored
Use core loop for reverse proxy async IO operations (#675)
* Make reverse proxy plugin use proxy.py core loop for async io operations * Address lint errors * Deprecate on_websocket_close and replace with on_client_connection_close * Lint fixes * Retry on SSLWantReadError and SSLWantWriteError
1 parent 880c3c8 commit 504ca53

File tree

9 files changed

+154
-28
lines changed

9 files changed

+154
-28
lines changed

proxy/common/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@
8282
DEFAULT_TIMEOUT = 10
8383
DEFAULT_VERSION = False
8484
DEFAULT_HTTP_PORT = 80
85+
DEFAULT_HTTPS_PORT = 443
8586
DEFAULT_MAX_SEND_SIZE = 16 * 1024
8687

8788
DEFAULT_DATA_DIRECTORY_PATH = os.path.join(str(pathlib.Path.home()), '.proxy')

proxy/core/acceptor/threadless.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,8 @@ def run(self) -> None:
198198
self.loop = asyncio.get_event_loop()
199199
while not self.running.is_set():
200200
self.run_once()
201+
except KeyboardInterrupt:
202+
pass
201203
finally:
202204
assert self.selector is not None
203205
self.selector.unregister(self.client_queue)

proxy/dashboard/dashboard.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,10 @@ def handle_request(self, request: HttpParser) -> None:
6363
if request.path == b'/dashboard/':
6464
self.client.queue(
6565
HttpWebServerPlugin.read_and_build_static_file_response(
66-
os.path.join(self.flags.static_server_dir, 'dashboard', 'proxy.html'),
66+
os.path.join(
67+
self.flags.static_server_dir,
68+
'dashboard', 'proxy.html',
69+
),
6770
),
6871
)
6972
elif request.path in (
@@ -105,7 +108,7 @@ def on_websocket_message(self, frame: WebsocketFrame) -> None:
105108
logger.info(frame.opcode)
106109
self.reply({'id': message['id'], 'response': 'not_implemented'})
107110

108-
def on_websocket_close(self) -> None:
111+
def on_client_connection_close(self) -> None:
109112
logger.info('app ws closed')
110113
# TODO(abhinavsingh): unsubscribe
111114

proxy/http/inspector/devtools.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,9 @@ def handle_request(self, request: HttpParser) -> None:
6060

6161
def on_websocket_open(self) -> None:
6262
self.subscriber.subscribe(
63-
lambda event: CoreEventsToDevtoolsProtocol.transformer(self.client, event),
63+
lambda event: CoreEventsToDevtoolsProtocol.transformer(
64+
self.client, event,
65+
),
6466
)
6567

6668
def on_websocket_message(self, frame: WebsocketFrame) -> None:
@@ -73,7 +75,7 @@ def on_websocket_message(self, frame: WebsocketFrame) -> None:
7375
return
7476
self.handle_devtools_message(message)
7577

76-
def on_websocket_close(self) -> None:
78+
def on_client_connection_close(self) -> None:
7779
self.subscriber.unsubscribe()
7880

7981
def handle_devtools_message(self, message: Dict[str, Any]) -> None:

proxy/http/server/pac_plugin.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def on_websocket_open(self) -> None:
6262
def on_websocket_message(self, frame: WebsocketFrame) -> None:
6363
pass # pragma: no cover
6464

65-
def on_websocket_close(self) -> None:
65+
def on_client_connection_close(self) -> None:
6666
pass # pragma: no cover
6767

6868
def cache_pac_file_response(self) -> None:

proxy/http/server/plugin.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,13 @@ def __init__(
3838
self.client = client
3939
self.event_queue = event_queue
4040

41+
def name(self) -> str:
42+
"""A unique name for your plugin.
43+
44+
Defaults to name of the class. This helps plugin developers to directly
45+
access a specific plugin by its name."""
46+
return self.__class__.__name__ # pragma: no cover
47+
4148
# TODO(abhinavsingh): get_descriptors, write_to_descriptors, read_from_descriptors
4249
# can be placed into their own abstract class which can then be shared by
4350
# HttpProxyBasePlugin, HttpWebServerBasePlugin and HttpProtocolHandlerPlugin class.
@@ -79,6 +86,10 @@ def handle_request(self, request: HttpParser) -> None:
7986
"""Handle the request and serve response."""
8087
raise NotImplementedError() # pragma: no cover
8188

89+
def on_client_connection_close(self) -> None:
90+
"""Client has closed the connection, do any clean up task now."""
91+
pass
92+
8293
@abstractmethod
8394
def on_websocket_open(self) -> None:
8495
"""Called when websocket handshake has finished."""
@@ -89,7 +100,14 @@ def on_websocket_message(self, frame: WebsocketFrame) -> None:
89100
"""Handle websocket frame."""
90101
raise NotImplementedError() # pragma: no cover
91102

92-
@abstractmethod
93-
def on_websocket_close(self) -> None:
94-
"""Called when websocket connection has been closed."""
95-
raise NotImplementedError() # pragma: no cover
103+
# Deprecated since v2.4.0
104+
#
105+
# Instead use on_client_connection_close.
106+
#
107+
# This callback is no longer invoked. Kindly
108+
# update your plugin before upgrading to v2.4.0.
109+
#
110+
# @abstractmethod
111+
# def on_websocket_close(self) -> None:
112+
# """Called when websocket connection has been closed."""
113+
# raise NotImplementedError() # pragma: no cover

proxy/http/server/web.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ def __init__(
8383
}
8484
self.route: Optional[HttpWebServerBasePlugin] = None
8585

86+
self.plugins: Dict[str, HttpWebServerBasePlugin] = {}
8687
if b'HttpWebServerBasePlugin' in self.flags.plugins:
8788
for klass in self.flags.plugins[b'HttpWebServerBasePlugin']:
8889
instance: HttpWebServerBasePlugin = klass(
@@ -91,6 +92,7 @@ def __init__(
9192
self.client,
9293
self.event_queue,
9394
)
95+
self.plugins[instance.name()] = instance
9496
for (protocol, route) in instance.routes():
9597
self.routes[protocol][re.compile(route)] = instance
9698

@@ -201,16 +203,28 @@ def on_request_complete(self) -> Union[socket.socket, bool]:
201203
self.client.queue(self.DEFAULT_404_RESPONSE)
202204
return True
203205

204-
# TODO(abhinavsingh): Call plugin get/read/write descriptor callbacks
205206
def get_descriptors(
206207
self,
207208
) -> Tuple[List[socket.socket], List[socket.socket]]:
208-
return [], []
209+
r, w = [], []
210+
for plugin in self.plugins.values():
211+
r1, w1 = plugin.get_descriptors()
212+
r.extend(r1)
213+
w.extend(w1)
214+
return r, w
209215

210216
def write_to_descriptors(self, w: Writables) -> bool:
217+
for plugin in self.plugins.values():
218+
teardown = plugin.write_to_descriptors(w)
219+
if teardown:
220+
return True
211221
return False
212222

213223
def read_from_descriptors(self, r: Readables) -> bool:
224+
for plugin in self.plugins.values():
225+
teardown = plugin.read_from_descriptors(r)
226+
if teardown:
227+
return True
214228
return False
215229

216230
def on_client_data(self, raw: memoryview) -> Optional[memoryview]:
@@ -260,12 +274,12 @@ def on_response_chunk(self, chunk: List[memoryview]) -> List[memoryview]:
260274
def on_client_connection_close(self) -> None:
261275
if self.request.has_host():
262276
return
263-
if self.switched_protocol:
264-
# Invoke plugin.on_websocket_close
265-
assert self.route
266-
self.route.on_websocket_close()
277+
if self.route:
278+
self.route.on_client_connection_close()
267279
self.access_log()
268280

281+
# TODO: Allow plugins to customize access_log, similar
282+
# to how proxy server plugins are able to do it.
269283
def access_log(self) -> None:
270284
logger.info(
271285
'%s:%s - %s %s - %.2f ms' %

proxy/plugin/reverse_proxy.py

Lines changed: 98 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,36 @@
88
:copyright: (c) 2013-present by Abhinav Singh and contributors.
99
:license: BSD, see LICENSE for more details.
1010
"""
11+
import ssl
1112
import random
12-
from typing import List, Tuple
13+
import socket
14+
import logging
15+
import sysconfig
16+
17+
from pathlib import Path
18+
from typing import List, Optional, Tuple, Any
1319
from urllib import parse as urlparse
1420

15-
from ..common.constants import DEFAULT_BUFFER_SIZE, DEFAULT_HTTP_PORT
16-
from ..common.utils import socket_connection, text_
21+
from ..common.utils import text_
22+
from ..common.constants import DEFAULT_HTTPS_PORT, DEFAULT_HTTP_PORT
23+
from ..common.types import Readables, Writables
24+
from ..core.connection import TcpServerConnection
25+
from ..http.exception import HttpProtocolException
1726
from ..http.parser import HttpParser
1827
from ..http.websocket import WebsocketFrame
1928
from ..http.server import HttpWebServerBasePlugin, httpProtocolTypes
2029

30+
logger = logging.getLogger(__name__)
31+
32+
# We need CA bundle to verify TLS connection to upstream servers
33+
PURE_LIB = sysconfig.get_path('purelib')
34+
assert PURE_LIB
35+
CACERT_PEM_PATH = Path(PURE_LIB) / 'certifi' / 'cacert.pem'
2136

37+
38+
# TODO: ReverseProxyPlugin and ProxyPoolPlugin are implementing
39+
# a similar behavior. Abstract that particular logic out into its
40+
# own class.
2241
class ReverseProxyPlugin(HttpWebServerBasePlugin):
2342
"""Extend in-built Web Server to add Reverse Proxy capabilities.
2443
@@ -39,35 +58,102 @@ class ReverseProxyPlugin(HttpWebServerBasePlugin):
3958
"User-Agent": "curl/7.64.1"
4059
},
4160
"origin": "1.2.3.4, 5.6.7.8",
42-
"url": "https://localhost/get"
61+
"url": "http://localhost/get"
4362
}
4463
"""
4564

65+
# TODO: We must use nginx python parser and
66+
# make this plugin nginx.conf complaint.
4667
REVERSE_PROXY_LOCATION: str = r'/get$'
68+
# Randomly choose either http or https upstream endpoint.
69+
#
70+
# This is just to demonstrate that both http and https upstream
71+
# reverse proxy works.
4772
REVERSE_PROXY_PASS = [
4873
b'http://httpbin.org/get',
74+
b'https://httpbin.org/get',
4975
]
5076

77+
def __init__(self, *args: Any, **kwargs: Any):
78+
super().__init__(*args, **kwargs)
79+
self.upstream: Optional[TcpServerConnection] = None
80+
5181
def routes(self) -> List[Tuple[int, str]]:
5282
return [
5383
(httpProtocolTypes.HTTP, ReverseProxyPlugin.REVERSE_PROXY_LOCATION),
5484
(httpProtocolTypes.HTTPS, ReverseProxyPlugin.REVERSE_PROXY_LOCATION),
5585
]
5686

57-
# TODO(abhinavsingh): Upgrade to use non-blocking get/read/write API.
87+
def get_descriptors(self) -> Tuple[List[socket.socket], List[socket.socket]]:
88+
if not self.upstream:
89+
return [], []
90+
return [self.upstream.connection], [self.upstream.connection] if self.upstream.has_buffer() else []
91+
92+
def read_from_descriptors(self, r: Readables) -> bool:
93+
if self.upstream and self.upstream.connection in r:
94+
try:
95+
raw = self.upstream.recv(self.flags.server_recvbuf_size)
96+
if raw is not None:
97+
self.client.queue(raw)
98+
else:
99+
return True # Teardown because upstream server closed the connection
100+
except ssl.SSLWantReadError:
101+
logger.info('Upstream server SSLWantReadError, will retry')
102+
return False
103+
except ConnectionResetError:
104+
logger.debug('Connection reset by upstream server')
105+
return True
106+
return super().read_from_descriptors(r)
107+
108+
def write_to_descriptors(self, w: Writables) -> bool:
109+
if self.upstream and self.upstream.connection in w and self.upstream.has_buffer():
110+
try:
111+
self.upstream.flush()
112+
except ssl.SSLWantWriteError:
113+
logger.info('Upstream server SSLWantWriteError, will retry')
114+
return False
115+
except BrokenPipeError:
116+
logger.debug(
117+
'BrokenPipeError when flushing to upstream server',
118+
)
119+
return True
120+
return super().write_to_descriptors(w)
121+
58122
def handle_request(self, request: HttpParser) -> None:
59-
upstream = random.choice(ReverseProxyPlugin.REVERSE_PROXY_PASS)
60-
url = urlparse.urlsplit(upstream)
123+
url = urlparse.urlsplit(
124+
random.choice(ReverseProxyPlugin.REVERSE_PROXY_PASS),
125+
)
61126
assert url.hostname
62-
with socket_connection((text_(url.hostname), url.port if url.port else DEFAULT_HTTP_PORT)) as conn:
63-
conn.send(request.build())
64-
self.client.queue(memoryview(conn.recv(DEFAULT_BUFFER_SIZE)))
127+
port = url.port or (
128+
DEFAULT_HTTP_PORT if url.scheme ==
129+
b'http' else DEFAULT_HTTPS_PORT
130+
)
131+
self.upstream = TcpServerConnection(text_(url.hostname), port)
132+
try:
133+
self.upstream.connect()
134+
if url.scheme == b'https':
135+
self.upstream.wrap(
136+
text_(
137+
url.hostname,
138+
), ca_file=str(CACERT_PEM_PATH),
139+
)
140+
self.upstream.queue(memoryview(request.build()))
141+
except ConnectionRefusedError:
142+
logger.info(
143+
'Connection refused by upstream server {0}:{1}'.format(
144+
text_(url.hostname), port,
145+
),
146+
)
147+
raise HttpProtocolException()
65148

66149
def on_websocket_open(self) -> None:
67150
pass
68151

69152
def on_websocket_message(self, frame: WebsocketFrame) -> None:
70153
pass
71154

72-
def on_websocket_close(self) -> None:
73-
pass
155+
def on_client_connection_close(self) -> None:
156+
if self.upstream and not self.upstream.closed:
157+
logger.debug('Closing upstream server connection')
158+
self.upstream.close()
159+
self.upstream = None

proxy/plugin/web_server_route.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,5 +54,5 @@ def on_websocket_open(self) -> None:
5454
def on_websocket_message(self, frame: WebsocketFrame) -> None:
5555
logger.info(frame.data)
5656

57-
def on_websocket_close(self) -> None:
57+
def on_client_connection_close(self) -> None:
5858
logger.info('Websocket close')

0 commit comments

Comments
 (0)