diff --git a/src/Middleware/WebSockets/src/HandshakeHelpers.cs b/src/Middleware/WebSockets/src/HandshakeHelpers.cs
index 4ecc685090c6..05c5ac5363a3 100644
--- a/src/Middleware/WebSockets/src/HandshakeHelpers.cs
+++ b/src/Middleware/WebSockets/src/HandshakeHelpers.cs
@@ -2,7 +2,6 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
-using System.Collections.Generic;
using System.Security.Cryptography;
using System.Text;
using Microsoft.AspNetCore.Http;
@@ -12,17 +11,6 @@ namespace Microsoft.AspNetCore.WebSockets
{
internal static class HandshakeHelpers
{
- ///
- /// Gets request headers needed process the handshake on the server.
- ///
- public static readonly string[] NeededHeaders = new[]
- {
- HeaderNames.Upgrade,
- HeaderNames.Connection,
- HeaderNames.SecWebSocketKey,
- HeaderNames.SecWebSocketVersion
- };
-
// "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
// This uses C# compiler's ability to refer to static data directly. For more information see https://vcsjones.dev/2019/02/01/csharp-readonly-span-bytes-static
private static ReadOnlySpan EncodedWebSocketKey => new byte[]
@@ -34,69 +22,14 @@ internal static class HandshakeHelpers
};
// Verify Method, Upgrade, Connection, version, key, etc..
- public static bool CheckSupportedWebSocketRequest(string method, List> interestingHeaders, IHeaderDictionary requestHeaders)
- {
- bool validUpgrade = false, validConnection = false, validKey = false, validVersion = false;
-
- if (!string.Equals("GET", method, StringComparison.OrdinalIgnoreCase))
- {
- return false;
- }
-
- foreach (var pair in interestingHeaders)
- {
- if (string.Equals(HeaderNames.Connection, pair.Key, StringComparison.OrdinalIgnoreCase))
- {
- if (string.Equals(HeaderNames.Upgrade, pair.Value, StringComparison.OrdinalIgnoreCase))
- {
- validConnection = true;
- }
- }
- else if (string.Equals(HeaderNames.Upgrade, pair.Key, StringComparison.OrdinalIgnoreCase))
- {
- if (string.Equals(Constants.Headers.UpgradeWebSocket, pair.Value, StringComparison.OrdinalIgnoreCase))
- {
- validUpgrade = true;
- }
- }
- else if (string.Equals(HeaderNames.SecWebSocketVersion, pair.Key, StringComparison.OrdinalIgnoreCase))
- {
- if (string.Equals(Constants.Headers.SupportedVersion, pair.Value, StringComparison.OrdinalIgnoreCase))
- {
- validVersion = true;
- }
- }
- else if (string.Equals(HeaderNames.SecWebSocketKey, pair.Key, StringComparison.OrdinalIgnoreCase))
- {
- validKey = IsRequestKeyValid(pair.Value);
- }
- }
-
- // WebSockets are long lived; so if the header values are valid we switch them out for the interned versions.
- if (validConnection && requestHeaders[HeaderNames.Connection].Count == 1)
- {
- requestHeaders[HeaderNames.Connection] = HeaderNames.Upgrade;
- }
- if (validUpgrade && requestHeaders[HeaderNames.Upgrade].Count == 1)
- {
- requestHeaders[HeaderNames.Upgrade] = Constants.Headers.UpgradeWebSocket;
- }
- if (validVersion && requestHeaders[HeaderNames.SecWebSocketVersion].Count == 1)
- {
- requestHeaders[HeaderNames.SecWebSocketVersion] = Constants.Headers.SupportedVersion;
- }
-
- return validConnection && validUpgrade && validVersion && validKey;
- }
-
public static void GenerateResponseHeaders(string key, string? subProtocol, IHeaderDictionary headers)
{
- headers[HeaderNames.Connection] = HeaderNames.Upgrade;
- headers[HeaderNames.Upgrade] = Constants.Headers.UpgradeWebSocket;
- headers[HeaderNames.SecWebSocketAccept] = CreateResponseKey(key);
+ headers.Connection = HeaderNames.Upgrade;
+ headers.Upgrade = Constants.Headers.UpgradeWebSocket;
+ headers.SecWebSocketAccept = CreateResponseKey(key);
if (!string.IsNullOrWhiteSpace(subProtocol))
{
- headers[HeaderNames.SecWebSocketProtocol] = subProtocol;
+ headers.SecWebSocketProtocol = subProtocol;
}
}
@@ -128,7 +61,7 @@ public static string CreateResponseKey(string requestKey)
// so this can be hardcoded to 60 bytes for the requestKey + static websocket string
Span mergedBytes = stackalloc byte[60];
Encoding.UTF8.GetBytes(requestKey, mergedBytes);
- EncodedWebSocketKey.CopyTo(mergedBytes.Slice(24));
+ EncodedWebSocketKey.CopyTo(mergedBytes[24..]);
Span hashedBytes = stackalloc byte[20];
var written = SHA1.HashData(mergedBytes, hashedBytes);
diff --git a/src/Middleware/WebSockets/src/WebSocketMiddleware.cs b/src/Middleware/WebSockets/src/WebSocketMiddleware.cs
index e6d198d73c2c..72fbc077b33f 100644
--- a/src/Middleware/WebSockets/src/WebSocketMiddleware.cs
+++ b/src/Middleware/WebSockets/src/WebSocketMiddleware.cs
@@ -118,16 +118,7 @@ public bool IsWebSocketRequest
}
else
{
- var requestHeaders = _context.Request.Headers;
- var interestingHeaders = new List>();
- foreach (var headerName in HandshakeHelpers.NeededHeaders)
- {
- foreach (var value in requestHeaders.GetCommaSeparatedValues(headerName))
- {
- interestingHeaders.Add(new KeyValuePair(headerName, value));
- }
- }
- _isWebSocketRequest = HandshakeHelpers.CheckSupportedWebSocketRequest(_context.Request.Method, interestingHeaders, requestHeaders);
+ _isWebSocketRequest = CheckSupportedWebSocketRequest(_context.Request.Method, _context.Request.Headers);
}
}
return _isWebSocketRequest.Value;
@@ -148,8 +139,7 @@ public async Task AcceptAsync(WebSocketAcceptContext acceptContext)
}
TimeSpan keepAliveInterval = _options.KeepAliveInterval;
- var advancedAcceptContext = acceptContext as ExtendedWebSocketAcceptContext;
- if (advancedAcceptContext != null)
+ if (acceptContext is ExtendedWebSocketAcceptContext advancedAcceptContext)
{
if (advancedAcceptContext.KeepAliveInterval.HasValue)
{
@@ -165,6 +155,77 @@ public async Task AcceptAsync(WebSocketAcceptContext acceptContext)
return WebSocket.CreateFromStream(opaqueTransport, isServer: true, subProtocol: subProtocol, keepAliveInterval: keepAliveInterval);
}
+
+ public static bool CheckSupportedWebSocketRequest(string method, IHeaderDictionary requestHeaders)
+ {
+ if (!HttpMethods.IsGet(method))
+ {
+ return false;
+ }
+
+ var foundHeader = false;
+
+ var values = requestHeaders.GetCommaSeparatedValues(HeaderNames.SecWebSocketVersion);
+ foreach (var value in values)
+ {
+ if (string.Equals(value, Constants.Headers.SupportedVersion, StringComparison.OrdinalIgnoreCase))
+ {
+ // WebSockets are long lived; so if the header values are valid we switch them out for the interned versions.
+ if (values.Length == 1)
+ {
+ requestHeaders.SecWebSocketVersion = Constants.Headers.SupportedVersion;
+ }
+ foundHeader = true;
+ break;
+ }
+ }
+ if (!foundHeader)
+ {
+ return false;
+ }
+ foundHeader = false;
+
+ values = requestHeaders.GetCommaSeparatedValues(HeaderNames.Connection);
+ foreach (var value in values)
+ {
+ if (string.Equals(value, HeaderNames.Upgrade, StringComparison.OrdinalIgnoreCase))
+ {
+ // WebSockets are long lived; so if the header values are valid we switch them out for the interned versions.
+ if (values.Length == 1)
+ {
+ requestHeaders.Connection = HeaderNames.Upgrade;
+ }
+ foundHeader = true;
+ break;
+ }
+ }
+ if (!foundHeader)
+ {
+ return false;
+ }
+ foundHeader = false;
+
+ values = requestHeaders.GetCommaSeparatedValues(HeaderNames.Upgrade);
+ foreach (var value in values)
+ {
+ if (string.Equals(value, Constants.Headers.UpgradeWebSocket, StringComparison.OrdinalIgnoreCase))
+ {
+ // WebSockets are long lived; so if the header values are valid we switch them out for the interned versions.
+ if (values.Length == 1)
+ {
+ requestHeaders.Upgrade = Constants.Headers.UpgradeWebSocket;
+ }
+ foundHeader = true;
+ break;
+ }
+ }
+ if (!foundHeader)
+ {
+ return false;
+ }
+
+ return HandshakeHelpers.IsRequestKeyValid(requestHeaders.SecWebSocketKey.ToString());
+ }
}
}
}
diff --git a/src/Middleware/WebSockets/test/UnitTests/KestrelWebSocketHelpers.cs b/src/Middleware/WebSockets/test/UnitTests/KestrelWebSocketHelpers.cs
index 73182e7cf00d..0631a4f5846c 100644
--- a/src/Middleware/WebSockets/test/UnitTests/KestrelWebSocketHelpers.cs
+++ b/src/Middleware/WebSockets/test/UnitTests/KestrelWebSocketHelpers.cs
@@ -3,6 +3,7 @@
using System;
using System.Net;
+using System.Runtime.ExceptionServices;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Hosting;
@@ -18,6 +19,7 @@ public class KestrelWebSocketHelpers
{
public static IDisposable CreateServer(ILoggerFactory loggerFactory, out int port, Func app, Action configure = null)
{
+ Exception exceptionFromApp = null;
configure = configure ?? (o => { });
Action startup = builder =>
{
@@ -31,6 +33,8 @@ public static IDisposable CreateServer(ILoggerFactory loggerFactory, out int por
}
catch (Exception ex)
{
+ // capture the exception from the app, we'll throw this at the end of the test when the server is disposed
+ exceptionFromApp = ex;
if (ct.Response.HasStarted)
{
throw;
@@ -64,12 +68,37 @@ public static IDisposable CreateServer(ILoggerFactory loggerFactory, out int por
options.Listen(IPAddress.Loopback, 0);
})
.Configure(startup);
+ }).ConfigureHostOptions(o =>
+ {
+ o.ShutdownTimeout = TimeSpan.FromSeconds(30);
}).Build();
host.Start();
port = host.GetPort();
- return host;
+ return new Disposable(() =>
+ {
+ host.Dispose();
+ if (exceptionFromApp is not null)
+ {
+ ExceptionDispatchInfo.Throw(exceptionFromApp);
+ }
+ });
+ }
+
+ private class Disposable : IDisposable
+ {
+ private readonly Action _dispose;
+
+ public Disposable(Action dispose)
+ {
+ _dispose = dispose;
+ }
+
+ public void Dispose()
+ {
+ _dispose();
+ }
}
}
}
diff --git a/src/Middleware/WebSockets/test/UnitTests/WebSocketMiddlewareTests.cs b/src/Middleware/WebSockets/test/UnitTests/WebSocketMiddlewareTests.cs
index 5ef3dfec0868..576217bf56e6 100644
--- a/src/Middleware/WebSockets/test/UnitTests/WebSocketMiddlewareTests.cs
+++ b/src/Middleware/WebSockets/test/UnitTests/WebSocketMiddlewareTests.cs
@@ -9,7 +9,6 @@
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Testing;
-using Microsoft.Extensions.Logging.Testing;
using Microsoft.Net.Http.Headers;
using Xunit;
@@ -138,6 +137,7 @@ public async Task SendMediumData_Success()
[Fact]
public async Task SendLongData_Success()
{
+ var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
var orriginalData = Encoding.UTF8.GetBytes(new string('a', 0x1FFFF));
using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context =>
{
@@ -146,29 +146,22 @@ public async Task SendLongData_Success()
var serverBuffer = new byte[orriginalData.Length];
var result = await webSocket.ReceiveAsync(new ArraySegment(serverBuffer), CancellationToken.None);
- int intermediateCount = result.Count;
- Assert.False(result.EndOfMessage);
- Assert.Equal(WebSocketMessageType.Text, result.MessageType);
-
- result = await webSocket.ReceiveAsync(new ArraySegment(serverBuffer, intermediateCount, orriginalData.Length - intermediateCount), CancellationToken.None);
- intermediateCount += result.Count;
- Assert.False(result.EndOfMessage);
- Assert.Equal(WebSocketMessageType.Text, result.MessageType);
-
- result = await webSocket.ReceiveAsync(new ArraySegment(serverBuffer, intermediateCount, orriginalData.Length - intermediateCount), CancellationToken.None);
- intermediateCount += result.Count;
Assert.True(result.EndOfMessage);
- Assert.Equal(orriginalData.Length, intermediateCount);
- Assert.Equal(WebSocketMessageType.Text, result.MessageType);
+ Assert.Equal(WebSocketMessageType.Binary, result.MessageType);
Assert.Equal(orriginalData, serverBuffer);
+
+ tcs.SetResult();
}))
{
using (var client = new ClientWebSocket())
{
await client.ConnectAsync(new Uri($"ws://127.0.0.1:{port}/"), CancellationToken.None);
await client.SendAsync(new ArraySegment(orriginalData), WebSocketMessageType.Binary, true, CancellationToken.None);
+
}
+ // Wait to close the server otherwise the app could throw if it takes longer than the shutdown timeout
+ await tcs.Task;
}
}
@@ -176,6 +169,7 @@ public async Task SendLongData_Success()
public async Task SendFragmentedData_Success()
{
var orriginalData = Encoding.UTF8.GetBytes("Hello World");
+ var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context =>
{
Assert.True(context.WebSockets.IsWebSocketRequest);
@@ -187,6 +181,7 @@ public async Task SendFragmentedData_Success()
Assert.Equal(2, result.Count);
int totalReceived = result.Count;
Assert.Equal(WebSocketMessageType.Binary, result.MessageType);
+ tcs.SetResult();
result = await webSocket.ReceiveAsync(
new ArraySegment(serverBuffer, totalReceived, serverBuffer.Length - totalReceived), CancellationToken.None);
@@ -194,6 +189,7 @@ public async Task SendFragmentedData_Success()
Assert.Equal(2, result.Count);
totalReceived += result.Count;
Assert.Equal(WebSocketMessageType.Binary, result.MessageType);
+ tcs.SetResult();
result = await webSocket.ReceiveAsync(
new ArraySegment(serverBuffer, totalReceived, serverBuffer.Length - totalReceived), CancellationToken.None);
@@ -209,7 +205,11 @@ public async Task SendFragmentedData_Success()
{
await client.ConnectAsync(new Uri($"ws://127.0.0.1:{port}/"), CancellationToken.None);
await client.SendAsync(new ArraySegment(orriginalData, 0, 2), WebSocketMessageType.Binary, false, CancellationToken.None);
+ await tcs.Task;
+ tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
await client.SendAsync(new ArraySegment(orriginalData, 2, 2), WebSocketMessageType.Binary, false, CancellationToken.None);
+ await tcs.Task;
+ tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
await client.SendAsync(new ArraySegment(orriginalData, 4, 7), WebSocketMessageType.Binary, true, CancellationToken.None);
}
}
@@ -574,5 +574,62 @@ public async Task OriginIsNotValidatedForNonWebSocketRequests()
}
}
}
+
+ [Fact]
+ public async Task CommonHeadersAreSetToInternedStrings()
+ {
+ using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context =>
+ {
+ Assert.True(context.WebSockets.IsWebSocketRequest);
+ var webSocket = await context.WebSockets.AcceptWebSocketAsync();
+
+ // Use ReferenceEquals and test against the constants
+ Assert.Same(HeaderNames.Upgrade, context.Request.Headers.Connection.ToString());
+ Assert.Same(Constants.Headers.UpgradeWebSocket, context.Request.Headers.Upgrade.ToString());
+ Assert.Same(Constants.Headers.SupportedVersion, context.Request.Headers.SecWebSocketVersion.ToString());
+ }))
+ {
+ using (var client = new ClientWebSocket())
+ {
+ await client.ConnectAsync(new Uri($"ws://127.0.0.1:{port}/"), CancellationToken.None);
+ }
+ }
+ }
+
+ [Fact]
+ public async Task MultipleValueHeadersNotOverridden()
+ {
+ using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context =>
+ {
+ Assert.True(context.WebSockets.IsWebSocketRequest);
+ var webSocket = await context.WebSockets.AcceptWebSocketAsync();
+
+ Assert.Equal("Upgrade, keep-alive", context.Request.Headers.Connection.ToString());
+ Assert.Equal("websocket, example", context.Request.Headers.Upgrade.ToString());
+ }))
+ {
+ using (var client = new HttpClient())
+ {
+ var uri = new UriBuilder(new Uri($"ws://127.0.0.1:{port}/"));
+ uri.Scheme = "http";
+
+ // Craft a valid WebSocket Upgrade request
+ using (var request = new HttpRequestMessage(HttpMethod.Get, uri.ToString()))
+ {
+ request.Headers.Connection.Clear();
+ request.Headers.Connection.Add("Upgrade");
+ request.Headers.Connection.Add("keep-alive");
+ request.Headers.Upgrade.Add(new System.Net.Http.Headers.ProductHeaderValue("websocket"));
+ request.Headers.Upgrade.Add(new System.Net.Http.Headers.ProductHeaderValue("example"));
+ request.Headers.Add(HeaderNames.SecWebSocketVersion, "13");
+ // SecWebSocketKey required to be 16 bytes
+ request.Headers.Add(HeaderNames.SecWebSocketKey, Convert.ToBase64String(new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 }, Base64FormattingOptions.None));
+
+ var response = await client.SendAsync(request);
+ Assert.Equal(HttpStatusCode.SwitchingProtocols, response.StatusCode);
+ }
+ }
+ }
+ }
}
}