Skip to content

Top-level notion of work not client #695

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 7, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/https_connect_tunnel.py
Original file line number Diff line number Diff line change
@@ -53,7 +53,7 @@ def handle_data(self, data: memoryview) -> Optional[bool]:

# Drop the request if not a CONNECT request
if self.request.method != httpMethods.CONNECT:
self.client.queue(
self.work.queue(
HttpsConnectTunnelHandler.PROXY_TUNNEL_UNSUPPORTED_SCHEME,
)
return True
@@ -66,7 +66,7 @@ def handle_data(self, data: memoryview) -> Optional[bool]:
self.connect_upstream()

# Queue tunnel established response to client
self.client.queue(
self.work.queue(
HttpsConnectTunnelHandler.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT,
)

8 changes: 4 additions & 4 deletions examples/ssl_echo_server.py
Original file line number Diff line number Diff line change
@@ -27,19 +27,19 @@ def initialize(self) -> None:
# here using wrap_socket() utility.
assert self.flags.keyfile is not None and self.flags.certfile is not None
conn = wrap_socket(
self.client.connection,
self.work.connection,
self.flags.keyfile,
self.flags.certfile,
)
conn.setblocking(False)
# Upgrade plain TcpClientConnection to SSL connection object
self.client = TcpClientConnection(
conn=conn, addr=self.client.addr,
self.work = TcpClientConnection(
conn=conn, addr=self.work.addr,
)

def handle_data(self, data: memoryview) -> Optional[bool]:
# echo back to client
self.client.queue(data)
self.work.queue(data)
return None


4 changes: 2 additions & 2 deletions examples/tcp_echo_server.py
Original file line number Diff line number Diff line change
@@ -20,11 +20,11 @@ class EchoServerHandler(BaseTcpServerHandler):
"""Sets client socket to non-blocking during initialization."""

def initialize(self) -> None:
self.client.connection.setblocking(False)
self.work.connection.setblocking(False)

def handle_data(self, data: memoryview) -> Optional[bool]:
# echo back to client
self.client.queue(data)
self.work.queue(data)
return None


9 changes: 6 additions & 3 deletions proxy/core/acceptor/work.py
Original file line number Diff line number Diff line change
@@ -25,15 +25,18 @@ class Work(ABC):

def __init__(
self,
client: TcpClientConnection,
work: TcpClientConnection,
flags: argparse.Namespace,
event_queue: Optional[EventQueue] = None,
uid: Optional[UUID] = None,
) -> None:
self.client = client
# Work uuid
self.uid: UUID = uid if uid is not None else uuid4()
self.flags = flags
# Eventing core queue
self.event_queue = event_queue
self.uid: UUID = uid if uid is not None else uuid4()
# Accept work
self.work = work

@abstractmethod
def get_events(self) -> Dict[socket.socket, int]:
34 changes: 17 additions & 17 deletions proxy/core/base/tcp_server.py
Original file line number Diff line number Diff line change
@@ -45,7 +45,7 @@ class BaseTcpServerHandler(Work):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.must_flush_before_shutdown = False
logger.debug('Connection accepted from {0}'.format(self.client.addr))
logger.debug('Connection accepted from {0}'.format(self.work.addr))

@abstractmethod
def handle_data(self, data: memoryview) -> Optional[bool]:
@@ -57,14 +57,14 @@ def get_events(self) -> Dict[socket.socket, int]:
# We always want to read from client
# Register for EVENT_READ events
if self.must_flush_before_shutdown is False:
events[self.client.connection] = selectors.EVENT_READ
events[self.work.connection] = selectors.EVENT_READ
# If there is pending buffer for client
# also register for EVENT_WRITE events
if self.client.has_buffer():
if self.client.connection in events:
events[self.client.connection] |= selectors.EVENT_WRITE
if self.work.has_buffer():
if self.work.connection in events:
events[self.work.connection] |= selectors.EVENT_WRITE
else:
events[self.client.connection] = selectors.EVENT_WRITE
events[self.work.connection] = selectors.EVENT_WRITE
return events

def handle_events(
@@ -79,32 +79,32 @@ def handle_events(
if teardown:
logger.debug(
'Shutting down client {0} connection'.format(
self.client.addr,
self.work.addr,
),
)
return teardown

def handle_writables(self, writables: Writables) -> bool:
teardown = False
if self.client.connection in writables and self.client.has_buffer():
if self.work.connection in writables and self.work.has_buffer():
logger.debug(
'Flushing buffer to client {0}'.format(self.client.addr),
'Flushing buffer to client {0}'.format(self.work.addr),
)
self.client.flush()
self.work.flush()
if self.must_flush_before_shutdown is True:
if not self.client.has_buffer():
if not self.work.has_buffer():
teardown = True
self.must_flush_before_shutdown = False
return teardown

def handle_readables(self, readables: Readables) -> bool:
teardown = False
if self.client.connection in readables:
data = self.client.recv(self.flags.client_recvbuf_size)
if self.work.connection in readables:
data = self.work.recv(self.flags.client_recvbuf_size)
if data is None:
logger.debug(
'Connection closed by client {0}'.format(
self.client.addr,
self.work.addr,
),
)
teardown = True
@@ -113,13 +113,13 @@ def handle_readables(self, readables: Readables) -> bool:
if isinstance(r, bool) and r is True:
logger.debug(
'Implementation signaled shutdown for client {0}'.format(
self.client.addr,
self.work.addr,
),
)
if self.client.has_buffer():
if self.work.has_buffer():
logger.debug(
'Client {0} has pending buffer, will be flushed before shutting down'.format(
self.client.addr,
self.work.addr,
),
)
self.must_flush_before_shutdown = True
4 changes: 2 additions & 2 deletions proxy/core/base/tcp_tunnel.py
Original file line number Diff line number Diff line change
@@ -43,7 +43,7 @@ def handle_data(self, data: memoryview) -> Optional[bool]:
pass # pragma: no cover

def initialize(self) -> None:
self.client.connection.setblocking(False)
self.work.connection.setblocking(False)

def shutdown(self) -> None:
if self.upstream:
@@ -87,7 +87,7 @@ def handle_events(
print('Connection closed by server')
return True
# tunnel data to client
self.client.queue(data)
self.work.queue(data)
if self.upstream and self.upstream.connection in writables:
self.upstream.flush()
return False
46 changes: 23 additions & 23 deletions proxy/http/handler.py
Original file line number Diff line number Diff line change
@@ -89,25 +89,25 @@ def __init__(self, *args: Any, **kwargs: Any):

def initialize(self) -> None:
"""Optionally upgrades connection to HTTPS, set conn in non-blocking mode and initializes plugins."""
conn = self._optionally_wrap_socket(self.client.connection)
conn = self._optionally_wrap_socket(self.work.connection)
conn.setblocking(False)
# Update client connection reference if connection was wrapped
if self._encryption_enabled():
self.client = TcpClientConnection(conn=conn, addr=self.client.addr)
self.work = TcpClientConnection(conn=conn, addr=self.work.addr)
if b'HttpProtocolHandlerPlugin' in self.flags.plugins:
for klass in self.flags.plugins[b'HttpProtocolHandlerPlugin']:
instance: HttpProtocolHandlerPlugin = klass(
self.uid,
self.flags,
self.client,
self.work,
self.request,
self.event_queue,
)
self.plugins[instance.name()] = instance
logger.debug('Handling connection %r' % self.client.connection)
logger.debug('Handling connection %r' % self.work.connection)

def is_inactive(self) -> bool:
if not self.client.has_buffer() and \
if not self.work.has_buffer() and \
self._connection_inactive_for() > self.flags.timeout:
return True
return False
@@ -127,20 +127,20 @@ def shutdown(self) -> None:
logger.debug(
'Closing client connection %r '
'at address %r has buffer %s' %
(self.client.connection, self.client.addr, self.client.has_buffer()),
(self.work.connection, self.work.addr, self.work.has_buffer()),
)

conn = self.client.connection
conn = self.work.connection
# Unwrap if wrapped before shutdown.
if self._encryption_enabled() and \
isinstance(self.client.connection, ssl.SSLSocket):
conn = self.client.connection.unwrap()
isinstance(self.work.connection, ssl.SSLSocket):
conn = self.work.connection.unwrap()
conn.shutdown(socket.SHUT_WR)
logger.debug('Client connection shutdown successful')
except OSError:
pass
finally:
self.client.connection.close()
self.work.connection.close()
logger.debug('Client connection closed')
super().shutdown()

@@ -196,7 +196,7 @@ def handle_events(
def handle_data(self, data: memoryview) -> Optional[bool]:
if data is None:
logger.debug('Client closed connection, tearing down...')
self.client.closed = True
self.work.closed = True
return True

try:
@@ -227,7 +227,7 @@ def handle_data(self, data: memoryview) -> Optional[bool]:
logger.debug(
'Updated client conn to %s', upgraded_sock,
)
self.client._conn = upgraded_sock
self.work._conn = upgraded_sock
for plugin_ in self.plugins.values():
if plugin_ != plugin:
plugin_.client._conn = upgraded_sock
@@ -237,20 +237,20 @@ def handle_data(self, data: memoryview) -> Optional[bool]:
logger.debug('HttpProtocolException raised')
response: Optional[memoryview] = e.response(self.request)
if response:
self.client.queue(response)
self.work.queue(response)
return True
return False

def handle_writables(self, writables: Writables) -> bool:
if self.client.connection in writables and self.client.has_buffer():
if self.work.connection in writables and self.work.has_buffer():
logger.debug('Client is ready for writes, flushing buffer')
self.last_activity = time.time()

# TODO(abhinavsingh): This hook could just reside within server recv block
# instead of invoking when flushed to client.
#
# Invoke plugin.on_response_chunk
chunk = self.client.buffer
chunk = self.work.buffer
for plugin in self.plugins.values():
chunk = plugin.on_response_chunk(chunk)
if chunk is None:
@@ -272,7 +272,7 @@ def handle_writables(self, writables: Writables) -> bool:
return False

def handle_readables(self, readables: Readables) -> bool:
if self.client.connection in readables:
if self.work.connection in readables:
logger.debug('Client is ready for reads, reading')
self.last_activity = time.time()
try:
@@ -290,7 +290,7 @@ def handle_readables(self, readables: Readables) -> bool:
else:
logger.exception(
'Exception while receiving from %s connection %r with reason %r' %
(self.client.tag, self.client.connection, e),
(self.work.tag, self.work.connection, e),
)
return True
return False
@@ -324,7 +324,7 @@ def run(self) -> None:
except Exception as e:
logger.exception(
'Exception while handling connection %r' %
self.client.connection, exc_info=e,
self.work.connection, exc_info=e,
)
finally:
self.shutdown()
@@ -377,24 +377,24 @@ def _run_once(self) -> bool:

def _flush(self) -> None:
assert self.selector
if not self.client.has_buffer():
if not self.work.has_buffer():
return
try:
self.selector.register(
self.client.connection,
self.work.connection,
selectors.EVENT_WRITE,
)
while self.client.has_buffer():
while self.work.has_buffer():
ev: List[
Tuple[selectors.SelectorKey, int]
] = self.selector.select(timeout=1)
if len(ev) == 0:
continue
self.client.flush()
self.work.flush()
except BrokenPipeError:
pass
finally:
self.selector.unregister(self.client.connection)
self.selector.unregister(self.work.connection)

def _connection_inactive_for(self) -> float:
return time.time() - self.last_activity
12 changes: 6 additions & 6 deletions tests/http/exceptions/test_http_proxy_auth_failed.py
Original file line number Diff line number Diff line change
@@ -63,9 +63,9 @@ def test_proxy_auth_fails_without_cred(self, mock_server_conn: mock.Mock) -> Non

self.protocol_handler._run_once()
mock_server_conn.assert_not_called()
self.assertEqual(self.protocol_handler.client.has_buffer(), True)
self.assertEqual(self.protocol_handler.work.has_buffer(), True)
self.assertEqual(
self.protocol_handler.client.buffer[0], ProxyAuthenticationFailed.RESPONSE_PKT,
self.protocol_handler.work.buffer[0], ProxyAuthenticationFailed.RESPONSE_PKT,
)
self._conn.send.assert_not_called()

@@ -92,9 +92,9 @@ def test_proxy_auth_fails_with_invalid_cred(self, mock_server_conn: mock.Mock) -

self.protocol_handler._run_once()
mock_server_conn.assert_not_called()
self.assertEqual(self.protocol_handler.client.has_buffer(), True)
self.assertEqual(self.protocol_handler.work.has_buffer(), True)
self.assertEqual(
self.protocol_handler.client.buffer[0], ProxyAuthenticationFailed.RESPONSE_PKT,
self.protocol_handler.work.buffer[0], ProxyAuthenticationFailed.RESPONSE_PKT,
)
self._conn.send.assert_not_called()

@@ -121,7 +121,7 @@ def test_proxy_auth_works_with_valid_cred(self, mock_server_conn: mock.Mock) ->

self.protocol_handler._run_once()
mock_server_conn.assert_called_once()
self.assertEqual(self.protocol_handler.client.has_buffer(), False)
self.assertEqual(self.protocol_handler.work.has_buffer(), False)

@mock.patch('proxy.http.proxy.server.TcpServerConnection')
def test_proxy_auth_works_with_mixed_case_basic_string(self, mock_server_conn: mock.Mock) -> None:
@@ -146,4 +146,4 @@ def test_proxy_auth_works_with_mixed_case_basic_string(self, mock_server_conn: m

self.protocol_handler._run_once()
mock_server_conn.assert_called_once()
self.assertEqual(self.protocol_handler.client.has_buffer(), False)
self.assertEqual(self.protocol_handler.work.has_buffer(), False)
2 changes: 1 addition & 1 deletion tests/http/test_http_proxy_tls_interception.py
Original file line number Diff line number Diff line change
@@ -201,7 +201,7 @@ def mock_connection() -> Any:
)
self.assertEqual(self._conn.setblocking.call_count, 2)
self.assertEqual(
self.protocol_handler.client.connection,
self.protocol_handler.work.connection,
self.mock_ssl_wrap.return_value,
)

10 changes: 5 additions & 5 deletions tests/http/test_protocol_handler.py
Original file line number Diff line number Diff line change
@@ -102,7 +102,7 @@ def assert_tunnel_response(
).upstream is not None,
)
self.assertEqual(
self.protocol_handler.client.buffer[0],
self.protocol_handler.work.buffer[0],
HttpProxyPlugin.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT,
)
mock_server_connection.assert_called_once()
@@ -111,7 +111,7 @@ def assert_tunnel_response(
server.closed = False

parser = HttpParser(httpParserTypes.RESPONSE_PARSER)
parser.parse(self.protocol_handler.client.buffer[0].tobytes())
parser.parse(self.protocol_handler.work.buffer[0].tobytes())
self.assertEqual(parser.state, httpParserStates.COMPLETE)
assert parser.code is not None
self.assertEqual(int(parser.code), 200)
@@ -199,7 +199,7 @@ def test_proxy_connection_failed(self) -> None:
])
self.protocol_handler._run_once()
self.assertEqual(
self.protocol_handler.client.buffer[0],
self.protocol_handler.work.buffer[0],
ProxyConnectionFailed.RESPONSE_PKT,
)

@@ -231,7 +231,7 @@ def test_proxy_authentication_failed(
])
self.protocol_handler._run_once()
self.assertEqual(
self.protocol_handler.client.buffer[0],
self.protocol_handler.work.buffer[0],
ProxyAuthenticationFailed.RESPONSE_PKT,
)

@@ -328,7 +328,7 @@ def test_authenticated_proxy_http_tunnel(
CRLF,
])
self.assert_tunnel_response(mock_server_connection, server)
self.protocol_handler.client.flush()
self.protocol_handler.work.flush()
self.assert_data_queued_to_server(server)

self.protocol_handler._run_once()
2 changes: 1 addition & 1 deletion tests/http/test_web_server.py
Original file line number Diff line number Diff line change
@@ -132,7 +132,7 @@ def test_default_web_server_returns_404(
httpParserStates.COMPLETE,
)
self.assertEqual(
self.protocol_handler.client.buffer[0],
self.protocol_handler.work.buffer[0],
HttpWebServerPlugin.DEFAULT_404_RESPONSE,
)

8 changes: 4 additions & 4 deletions tests/plugin/test_http_proxy_plugins.py
Original file line number Diff line number Diff line change
@@ -139,7 +139,7 @@ def test_proposed_rest_api_plugin(

mock_server_conn.assert_not_called()
self.assertEqual(
self.protocol_handler.client.buffer[0].tobytes(),
self.protocol_handler.work.buffer[0].tobytes(),
build_http_response(
httpStatusCodes.OK, reason=b'OK',
headers={b'Content-Type': b'application/json'},
@@ -215,7 +215,7 @@ def test_filter_by_upstream_host_plugin(

mock_server_conn.assert_not_called()
self.assertEqual(
self.protocol_handler.client.buffer[0].tobytes(),
self.protocol_handler.work.buffer[0].tobytes(),
build_http_response(
status_code=httpStatusCodes.I_AM_A_TEAPOT,
reason=b'I\'m a tea pot',
@@ -305,7 +305,7 @@ def closed() -> bool:
)
self.protocol_handler._run_once()
self.assertEqual(
self.protocol_handler.client.buffer[0].tobytes(),
self.protocol_handler.work.buffer[0].tobytes(),
build_http_response(
httpStatusCodes.OK,
reason=b'OK', body=b'Hello from man in the middle',
@@ -337,7 +337,7 @@ def test_filter_by_url_regex_plugin(
self.protocol_handler._run_once()

self.assertEqual(
self.protocol_handler.client.buffer[0].tobytes(),
self.protocol_handler.work.buffer[0].tobytes(),
build_http_response(
status_code=httpStatusCodes.NOT_FOUND,
reason=b'Blocked',
6 changes: 3 additions & 3 deletions tests/plugin/test_http_proxy_plugins_with_tls_interception.py
Original file line number Diff line number Diff line change
@@ -170,14 +170,14 @@ def send(raw: bytes) -> int:
self.mock_server_conn.assert_called_once_with('uni.corn', 443)
self.server.connect.assert_called()
self.assertEqual(
self.protocol_handler.client.connection,
self.protocol_handler.work.connection,
self.client_ssl_connection,
)
self.assertEqual(self.server.connection, self.server_ssl_connection)
self._conn.send.assert_called_with(
HttpProxyPlugin.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT,
)
self.assertFalse(self.protocol_handler.client.has_buffer())
self.assertFalse(self.protocol_handler.work.has_buffer())

def test_modify_post_data_plugin(self) -> None:
original = b'{"key": "value"}'
@@ -229,7 +229,7 @@ def test_man_in_the_middle_plugin(self) -> None:
)
self.protocol_handler._run_once()
self.assertEqual(
self.protocol_handler.client.buffer[0].tobytes(),
self.protocol_handler.work.buffer[0].tobytes(),
build_http_response(
httpStatusCodes.OK,
reason=b'OK', body=b'Hello from man in the middle',