Skip to content
Merged

Pool #694

Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 18 additions & 6 deletions examples/README.md
Original file line number Diff line number Diff line change
@@ -6,6 +6,7 @@ Looking for `proxy.py` plugin examples? Check [proxy/plugin](https://github.com

Table of Contents
=================
* [Generic Work Acceptor and Executor](#generic-work-acceptor-and-executor)
* [WebSocket Client](#websocket-client)
* [TCP Echo Server](#tcp-echo-server)
* [TCP Echo Client](#tcp-echo-client)
@@ -14,6 +15,17 @@ Table of Contents
* [PubSub Eventing](#pubsub-eventing)
* [Https Connect Tunnel](#https-connect-tunnel)

## Generic Work Acceptor and Executor

1. Makes use of `proxy.core.acceptor.AcceptorPool` and `proxy.core.acceptor.Work`
2. Demonstrates how to perform generic work using `proxy.py` core.

Start `web_scraper.py` as:

```console
PYTHONPATH=. python examples/web_scraper.py
```

## WebSocket Client

1. Makes use of `proxy.http.websocket.WebsocketClient` which is built on-top of `asyncio`
@@ -22,7 +34,7 @@ Table of Contents

Start `websocket_client.py` as:

```bash
```console
PYTHONPATH=. python examples/websocket_client.py
Received b'hello' after 306 millisec
Received b'hello' after 308 millisec
@@ -44,7 +56,7 @@ Received b'hello' after 309 millisec

Start `tcp_echo_server.py` as:

```bash
```console
PYTHONPATH=. python examples/tcp_echo_server.py
Connection accepted from ('::1', 53285, 0, 0)
Connection closed by client ('::1', 53285, 0, 0)
@@ -57,7 +69,7 @@ Connection closed by client ('::1', 53285, 0, 0)

Start `tcp_echo_client.py` as:

```bash
```console
PYTHONPATH=. python examples/tcp_echo_client.py
b'hello'
b'hello'
@@ -81,7 +93,7 @@ KeyboardInterrupt

Start `ssl_echo_server.py` as:

```bash
```console
PYTHONPATH=. python examples/ssl_echo_server.py
```

@@ -92,7 +104,7 @@ Start `ssl_echo_server.py` as:

Start `ssl_echo_client.py` as:

```bash
```console
PYTHONPATH=. python examples/ssl_echo_client.py
```

@@ -107,7 +119,7 @@ Start `ssl_echo_client.py` as:

Start `pubsub_eventing.py` as:

```bash
```console
PYTHONPATH=. python examples/pubsub_eventing.py
DEBUG:proxy.core.event.subscriber:Subscribed relay sub id 5eb22010764f4d44900f41e2fb408ca6 from core events
publisher starting
70 changes: 70 additions & 0 deletions examples/web_scraper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# -*- coding: utf-8 -*-
"""
proxy.py
~~~~~~~~
⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on
Network monitoring, controls & Application development, testing, debugging.
:copyright: (c) 2013-present by Abhinav Singh and contributors.
:license: BSD, see LICENSE for more details.
"""
import time
import socket

from typing import Dict

from proxy.proxy import Proxy
from proxy.core.acceptor import Work, AcceptorPool
from proxy.common.types import Readables, Writables


class WebScraper(Work):
    """Example `Work` implementation for the generic acceptor/executor core.

    Shows how a generic "accept work, then execute it" workflow can be
    orchestrated using proxy.py core.

    Intended design (the methods below are still placeholders):

    * Work arrives from a file on disk, one URL per line.
    * Each received URL is scraped by this class.
    * Scraped results are published to the eventing core, where one or
      more result subscribers can handle them as necessary.
    * Result subscribers write newly discovered URLs back to the file on
      disk, creating a feedback loop that lets `WebScraper` run endlessly.

    NOTE: No loop detection is performed currently.

    NOTE: The file descriptor need not point to a file on disk. For example,
    it could be a database connection; for simplicity, imagine a Redis
    server connection handling only the PUBSUB protocol.
    """

    def get_events(self) -> Dict[socket.socket, int]:
        """Return the sockets, and the read/write events on each, to watch.

        The placeholder implementation registers interest in nothing.
        """
        interests: Dict[socket.socket, int] = {}
        return interests

    def handle_events(
            self,
            readables: Readables,
            writables: Writables,
    ) -> bool:
        """Handle sockets that became readable or writable.

        Returns True to request shutdown of this work; the placeholder
        implementation always returns False.
        """
        teardown = False
        return teardown


if __name__ == '__main__':
    # Build the flags up front so the pool construction below stays short.
    scraper_flags = Proxy.initialize(
        port=12345,
        num_workers=1,
        threadless=True,
        keyfile='https-key.pem',
        certfile='https-signed-cert.pem',
    )
    # AcceptorPool is used as a context manager: entering it sets the pool
    # up and leaving it (for example on an exception) shuts it down.
    with AcceptorPool(flags=scraper_flags, work_klass=WebScraper) as pool:
        # Keep the main process alive while the acceptors do the work.
        while True:
            time.sleep(1)
8 changes: 4 additions & 4 deletions proxy/core/acceptor/acceptor.py
Original file line number Diff line number Diff line change
@@ -74,17 +74,17 @@ def __init__(
) -> None:
super().__init__()
self.flags = flags
# Eventing core queue
self.event_queue = event_queue
# Index assigned by `AcceptorPool`
self.idd = idd
# Lock shared by all acceptor processes
# to avoid concurrent accept over server socket
self.lock = lock
# Index assigned by `AcceptorPool`
self.idd = idd
# Queue over which server socket fd is received on start-up
self.work_queue: connection.Connection = work_queue
# Worker class
self.work_klass = work_klass
# Eventing core queue
self.event_queue = event_queue
# Selector & threadless states
self.running = multiprocessing.Event()
self.selector: Optional[selectors.DefaultSelector] = None
90 changes: 48 additions & 42 deletions proxy/core/acceptor/pool.py
Original file line number Diff line number Diff line change
@@ -25,11 +25,13 @@
from ..event import EventQueue

from ...common.flag import flags
from ...common.constants import DEFAULT_BACKLOG, DEFAULT_IPV6_HOSTNAME, DEFAULT_NUM_WORKERS, DEFAULT_PORT
from ...common.constants import DEFAULT_BACKLOG, DEFAULT_IPV6_HOSTNAME
from ...common.constants import DEFAULT_NUM_WORKERS, DEFAULT_PORT

logger = logging.getLogger(__name__)

# Lock shared by worker processes
# Lock shared by acceptors for
# sequential acceptance of work.
LOCK = multiprocessing.Lock()


@@ -61,20 +63,18 @@


class AcceptorPool:
"""AcceptorPool pre-spawns worker processes to utilize all cores available on the system.
A server socket is initialized and dispatched over a pipe to these workers.
Each worker process then concurrently accepts new client connection over
the initialized server socket.
"""AcceptorPool is a helper class which pre-spawns `Acceptor` processes
to utilize all available CPU cores for accepting new work.
A file descriptor to consume work from is shared with `Acceptor` processes
over a pipe. Each `Acceptor` process then concurrently accepts new work over
the shared file descriptor.
Example usage:
pool = AcceptorPool(flags=..., work_klass=...)
try:
pool.setup()
with AcceptorPool(flags=..., work_klass=...) as pool:
while True:
time.sleep(1)
finally:
pool.shutdown()
`work_klass` must implement `work.Work` class.
"""
@@ -84,11 +84,16 @@ def __init__(
work_klass: Type[Work], event_queue: Optional[EventQueue] = None,
) -> None:
self.flags = flags
# Eventing core queue
self.event_queue: Optional[EventQueue] = event_queue
# File descriptor to use for accepting new work
self.socket: Optional[socket.socket] = None
# Acceptor process instances
self.acceptors: List[Acceptor] = []
# Work queue used to share file descriptor with acceptor processes
self.work_queues: List[connection.Connection] = []
# Work class implementation
self.work_klass = work_klass
self.event_queue: Optional[EventQueue] = event_queue

def __enter__(self) -> 'AcceptorPool':
self.setup()
@@ -102,19 +107,43 @@ def __exit__(
) -> None:
self.shutdown()

def listen(self) -> None:
def setup(self) -> None:
    """Listen on the configured port and hand the socket to acceptors.

    Steps, in order:

    1. Create and bind the listening socket (`_listen`).
    2. Record the port actually bound in `flags.port`.
    3. Start the acceptor processes (`_start_acceptors`).
    4. Send the socket's file descriptor to each acceptor over its pipe,
       closing this process's end of each pipe once sent.
    5. Close this process's reference to the listening socket.
    """
    self._listen()
    # Override flags.port to match the actual port
    # we are listening upon. This is necessary to preserve
    # the server port when `--port=0` is used.
    assert self.socket
    self.flags.port = self.socket.getsockname()[1]
    self._start_acceptors()
    # Send file descriptor to all acceptor processes.
    assert self.socket is not None
    for index in range(self.flags.num_workers):
        pipe = self.work_queues[index]
        send_handle(pipe, self.socket.fileno(), self.acceptors[index].pid)
        # Nothing further is sent to this acceptor over the pipe.
        pipe.close()
    self.socket.close()

def shutdown(self) -> None:
    """Signal every acceptor process and wait for all of them to exit.

    The `running` event of every acceptor is set first and only then are
    the processes joined, so no acceptor is waited on before all of them
    have been signalled.
    """
    worker_count = self.flags.num_workers
    logger.info('Shutting down %d workers' % worker_count)
    # NOTE(review): presumably an acceptor leaves its loop once its
    # `running` event is set -- confirm against Acceptor.run.
    for proc in self.acceptors:
        proc.running.set()
    for proc in self.acceptors:
        proc.join()
    logger.debug('Acceptors shutdown')

def _listen(self) -> None:
self.socket = socket.socket(self.flags.family, socket.SOCK_STREAM)
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.socket.bind((str(self.flags.hostname), self.flags.port))
self.socket.listen(self.flags.backlog)
self.socket.setblocking(False)
# Override flags.port to match the actual port
# we are listening upon. This is necessary to preserve
# the server port when `--port=0` is used.
self.flags.port = self.socket.getsockname()[1]

def start_workers(self) -> None:
"""Start worker processes."""
def _start_acceptors(self) -> None:
"""Start acceptor processes."""
for acceptor_id in range(self.flags.num_workers):
work_queue = multiprocessing.Pipe()
acceptor = Acceptor(
@@ -134,26 +163,3 @@ def start_workers(self) -> None:
self.acceptors.append(acceptor)
self.work_queues.append(work_queue[0])
logger.info('Started %d workers' % self.flags.num_workers)

def shutdown(self) -> None:
logger.info('Shutting down %d workers' % self.flags.num_workers)
for acceptor in self.acceptors:
acceptor.running.set()
for acceptor in self.acceptors:
acceptor.join()
logger.debug('Acceptors shutdown')

def setup(self) -> None:
"""Listen on port, setup workers and pass server socket to workers."""
self.listen()
self.start_workers()
# Send server socket to all acceptor processes.
assert self.socket is not None
for index in range(self.flags.num_workers):
send_handle(
self.work_queues[index],
self.socket.fileno(),
self.acceptors[index].pid,
)
self.work_queues[index].close()
self.socket.close()
4 changes: 2 additions & 2 deletions proxy/core/base/tcp_server.py
Original file line number Diff line number Diff line change
@@ -15,8 +15,8 @@
from abc import abstractmethod
from typing import Dict, Any, Optional

from proxy.core.acceptor import Work
from proxy.common.types import Readables, Writables
from ...core.acceptor import Work
from ...common.types import Readables, Writables

logger = logging.getLogger(__name__)

3 changes: 2 additions & 1 deletion proxy/core/base/tcp_tunnel.py
Original file line number Diff line number Diff line change
@@ -8,9 +8,10 @@
:copyright: (c) 2013-present by Abhinav Singh and contributors.
:license: BSD, see LICENSE for more details.
"""
from abc import abstractmethod
import socket
import selectors

from abc import abstractmethod
from typing import Any, Optional, Dict

from ...http.parser import HttpParser, httpParserTypes
2 changes: 2 additions & 0 deletions proxy/core/connection/__init__.py
Original file line number Diff line number Diff line change
@@ -11,11 +11,13 @@
from .connection import TcpConnection, TcpConnectionUninitializedException, tcpConnectionTypes
from .client import TcpClientConnection
from .server import TcpServerConnection
from .pool import ConnectionPool

__all__ = [
'TcpConnection',
'TcpConnectionUninitializedException',
'TcpServerConnection',
'TcpClientConnection',
'tcpConnectionTypes',
'ConnectionPool',
]
2 changes: 1 addition & 1 deletion proxy/core/connection/client.py
Original file line number Diff line number Diff line change
@@ -22,7 +22,7 @@ def __init__(
self,
conn: Union[ssl.SSLSocket, socket.socket],
addr: Tuple[str, int],
):
) -> None:
super().__init__(tcpConnectionTypes.CLIENT)
self._conn: Optional[Union[ssl.SSLSocket, socket.socket]] = conn
self.addr: Tuple[str, int] = addr
19 changes: 16 additions & 3 deletions proxy/core/connection/connection.py
Original file line number Diff line number Diff line change
@@ -39,12 +39,14 @@ class TcpConnection(ABC):
when reading and writing into the socket.
Implement the connection property abstract method to return
a socket connection object."""
a socket connection object.
"""

def __init__(self, tag: int):
def __init__(self, tag: int) -> None:
self.tag: str = 'server' if tag == tcpConnectionTypes.SERVER else 'client'
self.buffer: List[memoryview] = []
self.closed: bool = False
self.tag: str = 'server' if tag == tcpConnectionTypes.SERVER else 'client'
self._reusable: bool = False

@property
@abstractmethod
@@ -95,3 +97,14 @@ def flush(self) -> int:
del mv
logger.debug('flushed %d bytes to %s' % (sent, self.tag))
return sent

def is_reusable(self) -> bool:
    """Return the current value of the `_reusable` flag.

    The flag is set by `reset()` and cleared by `mark_inuse()`.
    """
    return self._reusable

def mark_inuse(self) -> None:
    """Clear the `_reusable` flag so `is_reusable()` returns False."""
    self._reusable = False

def reset(self) -> None:
    """Prepare this open connection for reuse.

    Drops any buffered data and sets the `_reusable` flag. Must only be
    called on a connection that has not been closed.
    """
    assert not self.closed
    # Start from an empty buffer before flagging the connection reusable.
    self.buffer = []
    self._reusable = True
110 changes: 110 additions & 0 deletions proxy/core/connection/pool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# -*- coding: utf-8 -*-
"""
proxy.py
~~~~~~~~
⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on
Network monitoring, controls & Application development, testing, debugging.
:copyright: (c) 2013-present by Abhinav Singh and contributors.
:license: BSD, see LICENSE for more details.
"""
import logging

from typing import Set, Dict, Tuple

from ...common.flag import flags
from .server import TcpServerConnection

logger = logging.getLogger(__name__)


# Register the opt-in `--enable-conn-pool` command line switch.
# `store_true` with `default=False` keeps upstream connection pooling
# disabled unless the flag is passed explicitly.
flags.add_argument(
    '--enable-conn-pool',
    action='store_true',
    default=False,
    help='Default: False. (WIP) Enable upstream connection pooling.',
)


class ConnectionPool:
    """Manages connection pool to upstream servers.

    `ConnectionPool` avoids the need to reconnect with the upstream
    servers repeatedly when a reusable connection is available
    in the pool.

    A separate pool is maintained for each upstream server,
    keyed by its `(host, port)` address. So internally, it's a
    pool of pools.

    This class performs no locking of its own.

    TODO: Listen for read events from the connections
    to remove them from the pool when the peer closes the
    connection. This can also be achieved lazily by
    the pool users. Example, if the acquired connection
    is stale, reacquire.

    TODO: Ideally, ConnectionPool must be shared across
    all cores to make the SSL session cache also work
    without additional out-of-band synchronization.

    TODO: ConnectionPool currently WON'T work for
    HTTPS connections. This is because of missing support for
    session cache, session ticket, abbreviated TLS handshake
    and other necessary features to make it work.

    NOTE: However, for all HTTP only connections, ConnectionPool
    can be used to save upon connection setup time and
    speed-up performance of requests.
    """

    def __init__(self) -> None:
        # Pools of connection per upstream server
        self.pools: Dict[Tuple[str, int], Set[TcpServerConnection]] = {}

    def acquire(self, host: str, port: int) -> Tuple[bool, TcpServerConnection]:
        """Return a connection for use with the upstream server.

        A pooled connection for `(host, port)` that reports itself as
        reusable is marked in-use and returned. Otherwise a new
        `TcpServerConnection` object is created and added to the pool;
        this method does not call `connect()` on it.

        Returns a `(created, connection)` tuple where `created` is True
        only when a new connection object was created.
        """
        addr = (host, port)
        # Return a reusable connection if available
        for old_conn in self.pools.get(addr, ()):
            if old_conn.is_reusable():
                old_conn.mark_inuse()
                logger.debug(
                    'Reusing connection#{2} for upstream {0}:{1}'.format(
                        host, port, id(old_conn),
                    ),
                )
                return False, old_conn
        # Create new connection
        new_conn = TcpServerConnection(*addr)
        self.pools.setdefault(addr, set()).add(new_conn)
        logger.debug(
            'Created new connection#{2} for upstream {0}:{1}'.format(
                host, port, id(new_conn),
            ),
        )
        return True, new_conn

    def release(self, conn: TcpServerConnection) -> None:
        """Release the connection.

        If the connection has been closed it is dropped from the pool.
        If it has not been closed, it is reset and retained in the pool
        for reusability.

        Releasing an already removed closed connection is a no-op, so a
        closed connection may safely be released more than once (for
        example once when connecting fails and again when the client
        connection is torn down).
        """
        if conn.closed:
            logger.debug(
                'Removing connection#{2} from pool from upstream {0}:{1}'.format(
                    conn.addr[0], conn.addr[1], id(conn),
                ),
            )
            # `discard` instead of `remove`: the connection (or even its
            # address) may no longer be tracked if it was released before.
            pooled = self.pools.get(conn.addr)
            if pooled is not None:
                pooled.discard(conn)
            return
        logger.debug(
            'Retaining connection#{2} to upstream {0}:{1}'.format(
                conn.addr[0], conn.addr[1], id(conn),
            ),
        )
        # An open connection being released must currently be in use.
        assert not conn.is_reusable()
        # Reset for reusability
        conn.reset()
4 changes: 2 additions & 2 deletions proxy/core/connection/server.py
Original file line number Diff line number Diff line change
@@ -21,7 +21,7 @@
class TcpServerConnection(TcpConnection):
"""Establishes connection to upstream server."""

def __init__(self, host: str, port: int):
def __init__(self, host: str, port: int) -> None:
super().__init__(tcpConnectionTypes.SERVER)
self._conn: Optional[Union[ssl.SSLSocket, socket.socket]] = None
self.addr: Tuple[str, int] = (host, int(port))
@@ -38,7 +38,7 @@ def connect(self, addr: Optional[Tuple[str, int]] = None, source_address: Option
self._conn = new_socket_connection(
addr or self.addr, source_address=source_address,
)
self.closed = False
self.closed = False

def wrap(self, hostname: str, ca_file: Optional[str]) -> None:
ctx = ssl.create_default_context(
81 changes: 66 additions & 15 deletions proxy/http/proxy/server.py
Original file line number Diff line number Diff line change
@@ -36,7 +36,8 @@
from ...common.pki import gen_public_key, gen_csr, sign_csr

from ...core.event import eventNames
from ...core.connection import TcpServerConnection, TcpConnectionUninitializedException
from ...core.connection import TcpServerConnection, ConnectionPool
from ...core.connection import TcpConnectionUninitializedException
from ...common.flag import flags

logger = logging.getLogger(__name__)
@@ -112,9 +113,13 @@ class HttpProxyPlugin(HttpProtocolHandlerPlugin):
),
)

# Used to synchronization during certificate generation.
# Used for synchronization during certificate generation and
# connection pool operations.
lock = threading.Lock()

# Shared connection pool
pool = ConnectionPool()

def __init__(
self,
*args: Any, **kwargs: Any,
@@ -175,6 +180,15 @@ def get_descriptors(

return r, w

def _close_and_release(self) -> bool:
    """Give up the current upstream connection; always returns True.

    The return value lets callers write `return self._close_and_release()`
    where they previously returned True.

    When connection pooling is disabled this does nothing else. When it
    is enabled, the upstream connection is flagged as closed, released to
    the shared pool while holding the class lock, and the reference to it
    is cleared. Only the `closed` flag is flipped here; this method makes
    no call that closes the underlying socket.
    """
    if not self.flags.enable_conn_pool:
        return True
    assert self.upstream and not self.upstream.closed
    self.upstream.closed = True
    with self.lock:
        self.pool.release(self.upstream)
    self.upstream = None
    return True

def write_to_descriptors(self, w: Writables) -> bool:
if (self.upstream and self.upstream.connection not in w) or not self.upstream:
# Currently, we just call write/read block of each plugins. It is
@@ -200,12 +214,12 @@ def write_to_descriptors(self, w: Writables) -> bool:
logger.error(
'BrokenPipeError when flushing buffer for server',
)
return True
return self._close_and_release()
except OSError as e:
logger.exception(
'OSError when flushing buffer to server', exc_info=e,
)
return True
return self._close_and_release()
return False

def read_from_descriptors(self, r: Readables) -> bool:
@@ -229,6 +243,7 @@ def read_from_descriptors(self, r: Readables) -> bool:
try:
raw = self.upstream.recv(self.flags.server_recvbuf_size)
except TimeoutError as e:
self._close_and_release()
if e.errno == errno.ETIMEDOUT:
logger.warning(
'%s:%d timed out on recv' %
@@ -245,19 +260,22 @@ def read_from_descriptors(self, r: Readables) -> bool:
'%s:%d unreachable on recv' %
self.upstream.addr,
)
return True
if e.errno == errno.ECONNRESET:
logger.warning('Connection reset by upstream: %r' % e)
logger.warning(
'Connection reset by upstream: {0}:{1}'.format(
*self.upstream.addr,
),
)
else:
logger.exception(
'Exception while receiving from %s connection %r with reason %r' %
(self.upstream.tag, self.upstream.connection, e),
)
return True
return self._close_and_release()

if raw is None:
logger.debug('Server closed connection, tearing down...')
return True
return self._close_and_release()

for plugin in self.plugins.values():
raw = plugin.handle_upstream_chunk(raw)
@@ -324,10 +342,16 @@ def on_client_connection_close(self) -> None:
for plugin in self.plugins.values():
plugin.on_upstream_connection_close()

# If server was never initialized, return
# If server was never initialized or was already closed and released
if self.upstream is None:
return

if self.flags.enable_conn_pool:
# Release the connection for reusability
with self.lock:
self.pool.release(self.upstream)
return

try:
try:
self.upstream.connection.shutdown(socket.SHUT_WR)
@@ -516,10 +540,31 @@ def handle_pipeline_response(self, raw: memoryview) -> None:
def connect_upstream(self) -> None:
host, port = self.request.host, self.request.port
if host and port:
self.upstream = TcpServerConnection(text_(host), port)
if self.flags.enable_conn_pool:
with self.lock:
created, self.upstream = self.pool.acquire(
text_(host), port,
)
else:
created, self.upstream = True, TcpServerConnection(
text_(host), port,
)
if not created:
# NOTE: Acquired connection might be in an unusable state.
#
# This can only be confirmed by reading from connection.
# For stale connections, we will receive None, indicating
# to drop the connection.
#
# If that happens, we must acquire a fresh connection.
logger.info(
'Reusing connection to upstream %s:%d' %
(text_(host), port),
)
return
try:
logger.debug(
'Connecting to upstream %s:%s' %
'Connecting to upstream %s:%d' %
(text_(host), port),
)
# Invoke plugin.resolve_dns
@@ -543,11 +588,17 @@ def connect_upstream(self) -> None:
(text_(host), port),
)
except Exception as e: # TimeoutError, socket.gaierror
logger.exception(
'Unable to connect with upstream server', exc_info=e,
logger.warning(
'Unable to connect with upstream %s:%d due to %s' % (
text_(host), port, str(e),
),
)
self.upstream.closed = True
raise ProxyConnectionFailed(text_(host), port, repr(e)) from e
if self.flags.enable_conn_pool:
with self.lock:
self.pool.release(self.upstream)
raise ProxyConnectionFailed(
text_(host), port, repr(e),
) from e
else:
logger.exception('Both host and port must exist')
raise HttpProtocolException()
73 changes: 73 additions & 0 deletions tests/core/test_conn_pool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# -*- coding: utf-8 -*-
"""
proxy.py
~~~~~~~~
⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on
Network monitoring, controls & Application development, testing, debugging.
:copyright: (c) 2013-present by Abhinav Singh and contributors.
:license: BSD, see LICENSE for more details.
"""
import unittest

from unittest import mock

from proxy.core.connection import ConnectionPool


class TestConnectionPool(unittest.TestCase):
    """Unit tests for `ConnectionPool` acquire/release bookkeeping.

    `TcpServerConnection` is patched inside the pool module, so every
    connection handed out by the pool is the same mock object and no
    network connection is attempted.
    """

    @mock.patch('proxy.core.connection.pool.TcpServerConnection')
    def test_acquire_and_release_and_reacquire(self, mock_tcp_server_connection: mock.Mock) -> None:
        """An open connection stays pooled on release and is reused."""
        pool = ConnectionPool()
        addr = ('localhost', 1234)
        # Mock: the values below are consumed in call order. The pool asks
        # is_reusable() once during release (must be False, i.e. in use)
        # and once during the re-acquire (must be True, i.e. was reset).
        # NOTE(review): the trailing True looks unused -- confirm.
        mock_conn = mock_tcp_server_connection.return_value
        mock_conn.is_reusable.side_effect = [
            False, True, True,
        ]
        mock_conn.closed = False
        # Acquire: nothing pooled yet, so a new connection is created
        created, conn = pool.acquire(*addr)
        self.assertTrue(created)
        mock_tcp_server_connection.assert_called_once_with(*addr)
        self.assertEqual(conn, mock_conn)
        self.assertEqual(len(pool.pools[addr]), 1)
        self.assertTrue(conn in pool.pools[addr])
        # Release (connection must be retained because not closed)
        pool.release(conn)
        self.assertEqual(len(pool.pools[addr]), 1)
        self.assertTrue(conn in pool.pools[addr])
        # Reacquire: the retained connection is handed back, not a new one
        created, conn = pool.acquire(*addr)
        self.assertFalse(created)
        mock_conn.reset.assert_called_once()
        self.assertEqual(conn, mock_conn)
        self.assertEqual(len(pool.pools[addr]), 1)
        self.assertTrue(conn in pool.pools[addr])

    @mock.patch('proxy.core.connection.pool.TcpServerConnection')
    def test_closed_connections_are_removed_on_release(
        self, mock_tcp_server_connection: mock.Mock,
    ) -> None:
        """A closed connection is dropped on release and never reused."""
        pool = ConnectionPool()
        addr = ('localhost', 1234)
        # Mock: a connection that is already closed; release() looks the
        # pool up via `conn.addr`, so it must match the acquired address.
        mock_conn = mock_tcp_server_connection.return_value
        mock_conn.closed = True
        mock_conn.addr = addr
        # Acquire
        created, conn = pool.acquire(*addr)
        self.assertTrue(created)
        mock_tcp_server_connection.assert_called_once_with(*addr)
        self.assertEqual(conn, mock_conn)
        self.assertEqual(len(pool.pools[addr]), 1)
        self.assertTrue(conn in pool.pools[addr])
        # Release: the per-address pool is kept but is now empty
        pool.release(conn)
        self.assertEqual(len(pool.pools[addr]), 0)
        # Acquire: with nothing pooled, a second connection is created
        # without consulting is_reusable()
        created, conn = pool.acquire(*addr)
        self.assertTrue(created)
        self.assertEqual(mock_tcp_server_connection.call_count, 2)
        mock_conn.is_reusable.assert_not_called()