Skip to content

Commit 98e113e

Browse files
fantix1st1
andauthored
Refactor SSL shutdown process (#385)
Co-authored-by: Yury Selivanov <[email protected]>
1 parent cdd2218 commit 98e113e

File tree

4 files changed

+247
-109
lines changed

4 files changed

+247
-109
lines changed

tests/test_tcp.py

Lines changed: 143 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -2609,14 +2609,18 @@ async def client(addr):
26092609

26102610
def test_remote_shutdown_receives_trailing_data(self):
26112611
if self.implementation == 'asyncio':
2612+
# this is an issue in asyncio
26122613
raise unittest.SkipTest()
26132614

2614-
CHUNK = 1024 * 128
2615-
SIZE = 32
2615+
CHUNK = 1024 * 16
2616+
SIZE = 8
2617+
count = 0
26162618

26172619
sslctx = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY)
26182620
client_sslctx = self._create_client_ssl_context()
26192621
future = None
2622+
filled = threading.Lock()
2623+
eof_received = threading.Lock()
26202624

26212625
def server(sock):
26222626
incoming = ssl.MemoryBIO()
@@ -2647,68 +2651,71 @@ def server(sock):
26472651
sslobj.write(b'pong')
26482652
sock.send(outgoing.read())
26492653

2650-
time.sleep(0.2) # wait for the peer to fill its backlog
2651-
2652-
# send close_notify but don't wait for response
2653-
with self.assertRaises(ssl.SSLWantReadError):
2654-
sslobj.unwrap()
2655-
sock.send(outgoing.read())
2656-
2657-
# should receive all data
26582654
data_len = 0
2659-
while True:
2660-
try:
2661-
chunk = len(sslobj.read(16384))
2662-
data_len += chunk
2663-
except ssl.SSLWantReadError:
2664-
incoming.write(sock.recv(16384))
2665-
except ssl.SSLZeroReturnError:
2666-
break
2667-
2668-
self.assertEqual(data_len, CHUNK * SIZE)
2669-
2670-
# verify that close_notify is received
2671-
sslobj.unwrap()
26722655

2673-
sock.close()
2656+
with filled:
2657+
# trigger peer's resume_writing()
2658+
incoming.write(sock.recv(65536 * 4))
2659+
while True:
2660+
try:
2661+
chunk = len(sslobj.read(16384))
2662+
data_len += chunk
2663+
except ssl.SSLWantReadError:
2664+
break
26742665

2675-
def eof_server(sock):
2676-
sock.starttls(sslctx, server_side=True)
2677-
self.assertEqual(sock.recv_all(4), b'ping')
2678-
sock.send(b'pong')
2666+
# send close_notify but don't wait for response
2667+
with self.assertRaises(ssl.SSLWantReadError):
2668+
sslobj.unwrap()
2669+
sock.send(outgoing.read())
26792670

2680-
time.sleep(0.2) # wait for the peer to fill its backlog
2671+
with eof_received:
2672+
# should receive all data
2673+
while True:
2674+
try:
2675+
chunk = len(sslobj.read(16384))
2676+
data_len += chunk
2677+
except ssl.SSLWantReadError:
2678+
incoming.write(sock.recv(16384))
2679+
except ssl.SSLZeroReturnError:
2680+
break
26812681

2682-
# send EOF
2683-
sock.shutdown(socket.SHUT_WR)
2682+
self.assertEqual(data_len, CHUNK * count)
26842683

2685-
# should receive all data
2686-
data = sock.recv_all(CHUNK * SIZE)
2687-
self.assertEqual(len(data), CHUNK * SIZE)
2684+
# verify that close_notify is received
2685+
sslobj.unwrap()
26882686

26892687
sock.close()
26902688

26912689
async def client(addr):
2692-
nonlocal future
2690+
nonlocal future, count
26932691
future = self.loop.create_future()
26942692

2695-
reader, writer = await asyncio.open_connection(
2696-
*addr,
2697-
ssl=client_sslctx,
2698-
server_hostname='')
2699-
writer.write(b'ping')
2700-
data = await reader.readexactly(4)
2701-
self.assertEqual(data, b'pong')
2702-
2703-
# fill write backlog in a hacky way - renegotiation won't help
2704-
for _ in range(SIZE):
2705-
writer.transport._test__append_write_backlog(b'x' * CHUNK)
2693+
with eof_received:
2694+
with filled:
2695+
reader, writer = await asyncio.open_connection(
2696+
*addr,
2697+
ssl=client_sslctx,
2698+
server_hostname='')
2699+
writer.write(b'ping')
2700+
data = await reader.readexactly(4)
2701+
self.assertEqual(data, b'pong')
2702+
2703+
count = 0
2704+
try:
2705+
while True:
2706+
writer.write(b'x' * CHUNK)
2707+
count += 1
2708+
await asyncio.wait_for(
2709+
asyncio.ensure_future(writer.drain()), 0.5)
2710+
except asyncio.TimeoutError:
2711+
# fill write backlog in a hacky way
2712+
for _ in range(SIZE):
2713+
writer.transport._test__append_write_backlog(
2714+
b'x' * CHUNK)
2715+
count += 1
27062716

2707-
try:
27082717
data = await reader.read()
27092718
self.assertEqual(data, b'')
2710-
except (BrokenPipeError, ConnectionResetError):
2711-
pass
27122719

27132720
await future
27142721

@@ -2728,9 +2735,6 @@ def wrapper(sock):
27282735
with self.tcp_server(run(server)) as srv:
27292736
self.loop.run_until_complete(client(srv.addr))
27302737

2731-
with self.tcp_server(run(eof_server)) as srv:
2732-
self.loop.run_until_complete(client(srv.addr))
2733-
27342738
def test_connect_timeout_warning(self):
27352739
s = socket.socket(socket.AF_INET)
27362740
s.bind(('127.0.0.1', 0))
@@ -2842,7 +2846,7 @@ def server(sock):
28422846
sock.shutdown(socket.SHUT_WR)
28432847
loop.call_soon_threadsafe(eof.set)
28442848
# make sure we have enough time to reproduce the issue
2845-
assert sock.recv(1024) == b''
2849+
self.assertEqual(sock.recv(1024), b'')
28462850
sock.close()
28472851

28482852
class Protocol(asyncio.Protocol):
@@ -2875,7 +2879,92 @@ async def client(addr):
28752879
tr.resume_reading()
28762880
await pr.fut
28772881
tr.close()
2878-
assert extra == b'extra bytes'
2882+
if self.implementation != 'asyncio':
2883+
# extra data received after transport.close() should be
2884+
# ignored - this is likely a bug in asyncio
2885+
self.assertIsNone(extra)
2886+
2887+
with self.tcp_server(server) as srv:
2888+
loop.run_until_complete(client(srv.addr))
2889+
2890+
def test_shutdown_while_pause_reading(self):
2891+
if self.implementation == 'asyncio':
2892+
raise unittest.SkipTest()
2893+
2894+
loop = self.loop
2895+
conn_made = loop.create_future()
2896+
eof_recvd = loop.create_future()
2897+
conn_lost = loop.create_future()
2898+
data_recv = False
2899+
2900+
def server(sock):
2901+
sslctx = self._create_server_ssl_context(self.ONLYCERT,
2902+
self.ONLYKEY)
2903+
incoming = ssl.MemoryBIO()
2904+
outgoing = ssl.MemoryBIO()
2905+
sslobj = sslctx.wrap_bio(incoming, outgoing, server_side=True)
2906+
2907+
while True:
2908+
try:
2909+
sslobj.do_handshake()
2910+
sslobj.write(b'trailing data')
2911+
break
2912+
except ssl.SSLWantReadError:
2913+
if outgoing.pending:
2914+
sock.send(outgoing.read())
2915+
incoming.write(sock.recv(16384))
2916+
if outgoing.pending:
2917+
sock.send(outgoing.read())
2918+
2919+
while True:
2920+
try:
2921+
self.assertEqual(sslobj.read(), b'') # close_notify
2922+
break
2923+
except ssl.SSLWantReadError:
2924+
incoming.write(sock.recv(16384))
2925+
2926+
while True:
2927+
try:
2928+
sslobj.unwrap()
2929+
except ssl.SSLWantReadError:
2930+
if outgoing.pending:
2931+
sock.send(outgoing.read())
2932+
# incoming.write(sock.recv(16384))
2933+
else:
2934+
if outgoing.pending:
2935+
sock.send(outgoing.read())
2936+
break
2937+
2938+
self.assertEqual(sock.recv(16384), b'') # socket closed
2939+
2940+
class Protocol(asyncio.Protocol):
2941+
def connection_made(self, transport):
2942+
conn_made.set_result(None)
2943+
2944+
def data_received(self, data):
2945+
nonlocal data_recv
2946+
data_recv = True
2947+
2948+
def eof_received(self):
2949+
eof_recvd.set_result(None)
2950+
2951+
def connection_lost(self, exc):
2952+
if exc is None:
2953+
conn_lost.set_result(None)
2954+
else:
2955+
conn_lost.set_exception(exc)
2956+
2957+
async def client(addr):
2958+
ctx = self._create_client_ssl_context()
2959+
tr, _ = await loop.create_connection(Protocol, *addr, ssl=ctx)
2960+
await conn_made
2961+
self.assertFalse(data_recv)
2962+
2963+
tr.pause_reading()
2964+
tr.close()
2965+
2966+
await eof_recvd
2967+
await conn_lost
28792968

28802969
with self.tcp_server(server) as srv:
28812970
loop.run_until_complete(client(srv.addr))

uvloop/includes/stdlib.pxi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ cdef ssl_MemoryBIO = ssl.MemoryBIO
129129
cdef ssl_create_default_context = ssl.create_default_context
130130
cdef ssl_SSLError = ssl.SSLError
131131
cdef ssl_SSLAgainErrors = (ssl.SSLWantReadError, ssl.SSLSyscallError)
132+
cdef ssl_SSLZeroReturnError = ssl.SSLZeroReturnError
132133
cdef ssl_CertificateError = ssl.CertificateError
133134
cdef int ssl_SSL_ERROR_WANT_READ = ssl.SSL_ERROR_WANT_READ
134135
cdef int ssl_SSL_ERROR_WANT_WRITE = ssl.SSL_ERROR_WANT_WRITE

uvloop/sslproto.pxd

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ cdef enum AppProtocolState:
2424

2525
cdef class _SSLProtocolTransport:
2626
cdef:
27-
object _loop
27+
Loop _loop
2828
SSLProtocol _ssl_protocol
2929
bint _closed
3030

@@ -41,7 +41,7 @@ cdef class SSLProtocol:
4141
size_t _write_buffer_size
4242

4343
object _waiter
44-
object _loop
44+
Loop _loop
4545
_SSLProtocolTransport _app_transport
4646
bint _app_transport_created
4747

@@ -65,7 +65,6 @@ cdef class SSLProtocol:
6565

6666
bint _ssl_writing_paused
6767
bint _app_reading_paused
68-
bint _eof_received
6968

7069
size_t _incoming_high_water
7170
size_t _incoming_low_water
@@ -100,6 +99,7 @@ cdef class SSLProtocol:
10099

101100
cdef _start_shutdown(self)
102101
cdef _check_shutdown_timeout(self)
102+
cdef _do_read_into_void(self)
103103
cdef _do_flush(self)
104104
cdef _do_shutdown(self)
105105
cdef _on_shutdown_complete(self, shutdown_exc)

0 commit comments

Comments
 (0)