Skip to content

Commit 12feaab

Browse files
authored
follow-up: kestrel tls listener callback (#62266)
1 parent adb9b44 commit 12feaab

File tree

2 files changed

+44
-33
lines changed

2 files changed

+44
-33
lines changed

src/Servers/Kestrel/Core/test/TlsListenerTests.cs

Lines changed: 5 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,7 @@ public async Task RunTlsClientHelloCallbackTest_WithExtraShortLastingToken()
7070
var cts = new CancellationTokenSource(TimeSpan.FromMilliseconds(3));
7171

7272
await writer.WriteAsync(new byte[1] { 0x16 });
73-
await VerifyThrowsAnyAsync(
74-
async () => await listener.OnTlsClientHelloAsync(transportConnection, cts.Token),
75-
typeof(OperationCanceledException), typeof(TaskCanceledException));
73+
await Assert.ThrowsAnyAsync<OperationCanceledException>(() => listener.OnTlsClientHelloAsync(transportConnection, cts.Token));
7674
Assert.False(tlsClientHelloCallbackInvoked);
7775
}
7876

@@ -95,9 +93,7 @@ public async Task RunTlsClientHelloCallbackTest_WithPreCanceledToken()
9593
cts.Cancel();
9694

9795
await writer.WriteAsync(new byte[1] { 0x16 });
98-
await VerifyThrowsAnyAsync(
99-
async () => await listener.OnTlsClientHelloAsync(transportConnection, cts.Token),
100-
typeof(OperationCanceledException), typeof(TaskCanceledException));
96+
await Assert.ThrowsAnyAsync<OperationCanceledException>(() => listener.OnTlsClientHelloAsync(transportConnection, cts.Token));
10197
Assert.False(tlsClientHelloCallbackInvoked);
10298
}
10399

@@ -122,7 +118,7 @@ public async Task RunTlsClientHelloCallbackTest_WithPendingCancellation()
122118
await writer.WriteAsync(new byte[2] { 0x03, 0x01 });
123119
cts.Cancel();
124120

125-
await Assert.ThrowsAsync<OperationCanceledException>(async () => await listenerTask);
121+
await Assert.ThrowsAnyAsync<OperationCanceledException>(() => listenerTask);
126122
Assert.False(tlsClientHelloCallbackInvoked);
127123
}
128124

@@ -158,8 +154,8 @@ public async Task RunTlsClientHelloCallbackTest_DeterministicallyReads()
158154
Assert.Equal(5, readResult.Buffer.Length);
159155

160156
// ensuring that we have read limited number of times
161-
Assert.True(reader.ReadAsyncCounter is >= 2 && reader.ReadAsyncCounter is <= 4,
162-
$"Expected ReadAsync() to happen about 2-4 times. Actually happened {reader.ReadAsyncCounter} times.");
157+
Assert.True(reader.ReadAsyncCounter is >= 2 && reader.ReadAsyncCounter is <= 5,
158+
$"Expected ReadAsync() to happen about 2-5 times. Actually happened {reader.ReadAsyncCounter} times.");
163159
}
164160

165161
private async Task RunTlsClientHelloCallbackTest_WithMultipleSegments(
@@ -623,28 +619,4 @@ public static IEnumerable<object[]> InvalidClientHelloData_Segmented()
623619
_invalidTlsClientHelloHeader, _invalid3BytesMessage, _invalid9BytesMessage,
624620
_invalidUnknownProtocolVersion1, _invalidUnknownProtocolVersion2, _invalidIncorrectHandshakeMessageType
625621
};
626-
627-
static async Task VerifyThrowsAnyAsync(Func<Task> code, params Type[] exceptionTypes)
628-
{
629-
if (exceptionTypes == null || exceptionTypes.Length == 0)
630-
{
631-
throw new ArgumentException("At least one exception type must be provided.", nameof(exceptionTypes));
632-
}
633-
634-
try
635-
{
636-
await code();
637-
}
638-
catch (Exception ex)
639-
{
640-
if (exceptionTypes.Any(type => type.IsInstanceOfType(ex)))
641-
{
642-
return;
643-
}
644-
645-
throw ThrowsException.ForIncorrectExceptionType(exceptionTypes.First(), ex);
646-
}
647-
648-
throw ThrowsException.ForNoException(exceptionTypes.First());
649-
}
650622
}

src/Servers/Kestrel/test/InMemory.FunctionalTests/TlsListenerTests.cs

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
using System;
55
using System.Collections.Generic;
6+
using System.IO.Pipelines;
67
using System.Net;
78
using System.Net.Security;
89
using System.Security.Authentication;
@@ -18,6 +19,8 @@
1819
using Microsoft.Extensions.DependencyInjection;
1920
using Microsoft.Extensions.Hosting;
2021
using Microsoft.Extensions.Logging;
22+
using Newtonsoft.Json.Linq;
23+
using Xunit.Sdk;
2124

2225
namespace InMemory.FunctionalTests;
2326

@@ -66,4 +69,40 @@ await sslStream.AuthenticateAsClientAsync(new SslClientAuthenticationOptions
6669

6770
Assert.True(tlsClientHelloCallbackInvoked);
6871
}
72+
73+
[Fact]
74+
public async Task TlsClientHelloBytesCallback_UsesOptionsTimeout()
75+
{
76+
var tlsClientHelloCallbackInvoked = false;
77+
var testContext = new TestServiceContext(LoggerFactory);
78+
await using (var server = new TestServer(context => Task.CompletedTask,
79+
testContext,
80+
listenOptions =>
81+
{
82+
listenOptions.UseHttps(_x509Certificate2, httpsOptions =>
83+
{
84+
httpsOptions.HandshakeTimeout = TimeSpan.FromMilliseconds(1);
85+
86+
httpsOptions.TlsClientHelloBytesCallback = (connection, clientHelloBytes) =>
87+
{
88+
Logger.LogDebug("[Received TlsClientHelloBytesCallback] Connection: {0}; TLS client hello buffer: {1}", connection.ConnectionId, clientHelloBytes.Length);
89+
tlsClientHelloCallbackInvoked = true;
90+
Assert.True(clientHelloBytes.Length > 32);
91+
Assert.NotNull(connection);
92+
};
93+
});
94+
}))
95+
{
96+
using (var connection = server.CreateConnection())
97+
{
98+
await connection.TransportConnection.Input.WriteAsync(new byte[] { 0x16 });
99+
var readResult = await connection.TransportConnection.Output.ReadAsync();
100+
101+
// HttpsConnectionMiddleware catches the exception, so we can only check the effects of the timeout here
102+
Assert.True(readResult.IsCompleted);
103+
}
104+
}
105+
106+
Assert.False(tlsClientHelloCallbackInvoked);
107+
}
69108
}

0 commit comments

Comments
 (0)