diff --git a/src/Middleware/WebSockets/src/HandshakeHelpers.cs b/src/Middleware/WebSockets/src/HandshakeHelpers.cs index 4ecc685090c6..05c5ac5363a3 100644 --- a/src/Middleware/WebSockets/src/HandshakeHelpers.cs +++ b/src/Middleware/WebSockets/src/HandshakeHelpers.cs @@ -2,7 +2,6 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using System.Collections.Generic; using System.Security.Cryptography; using System.Text; using Microsoft.AspNetCore.Http; @@ -12,17 +11,6 @@ namespace Microsoft.AspNetCore.WebSockets { internal static class HandshakeHelpers { - /// - /// Gets request headers needed process the handshake on the server. - /// - public static readonly string[] NeededHeaders = new[] - { - HeaderNames.Upgrade, - HeaderNames.Connection, - HeaderNames.SecWebSocketKey, - HeaderNames.SecWebSocketVersion - }; - // "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" // This uses C# compiler's ability to refer to static data directly. For more information see https://vcsjones.dev/2019/02/01/csharp-readonly-span-bytes-static private static ReadOnlySpan EncodedWebSocketKey => new byte[] @@ -34,69 +22,14 @@ internal static class HandshakeHelpers }; // Verify Method, Upgrade, Connection, version, key, etc.. - public static bool CheckSupportedWebSocketRequest(string method, List> interestingHeaders, IHeaderDictionary requestHeaders) - { - bool validUpgrade = false, validConnection = false, validKey = false, validVersion = false; - - if (!string.Equals("GET", method, StringComparison.OrdinalIgnoreCase)) - { - return false; - } - - foreach (var pair in interestingHeaders) - { - if (string.Equals(HeaderNames.Connection, pair.Key, StringComparison.OrdinalIgnoreCase)) - { - if (string.Equals(HeaderNames.Upgrade, pair.Value, StringComparison.OrdinalIgnoreCase)) - { - validConnection = true; - } - } - else if (string.Equals(HeaderNames.Upgrade, pair.Key, StringComparison.OrdinalIgnoreCase)) - { - if (string.Equals(Constants.Headers.UpgradeWebSocket, pair.Value, StringComparison.OrdinalIgnoreCase)) - { - validUpgrade = true; - } - } - else if (string.Equals(HeaderNames.SecWebSocketVersion, pair.Key, StringComparison.OrdinalIgnoreCase)) - { - if (string.Equals(Constants.Headers.SupportedVersion, pair.Value, StringComparison.OrdinalIgnoreCase)) - { - validVersion = true; - } - } - else if (string.Equals(HeaderNames.SecWebSocketKey, pair.Key, StringComparison.OrdinalIgnoreCase)) - { - validKey = IsRequestKeyValid(pair.Value); - } - } - - // WebSockets are long lived; so if the header values are valid we switch them out for the interned versions. - if (validConnection && requestHeaders[HeaderNames.Connection].Count == 1) - { - requestHeaders[HeaderNames.Connection] = HeaderNames.Upgrade; - } - if (validUpgrade && requestHeaders[HeaderNames.Upgrade].Count == 1) - { - requestHeaders[HeaderNames.Upgrade] = Constants.Headers.UpgradeWebSocket; - } - if (validVersion && requestHeaders[HeaderNames.SecWebSocketVersion].Count == 1) - { - requestHeaders[HeaderNames.SecWebSocketVersion] = Constants.Headers.SupportedVersion; - } - - return validConnection && validUpgrade && validVersion && validKey; - } - public static void GenerateResponseHeaders(string key, string? subProtocol, IHeaderDictionary headers) { - headers[HeaderNames.Connection] = HeaderNames.Upgrade; - headers[HeaderNames.Upgrade] = Constants.Headers.UpgradeWebSocket; - headers[HeaderNames.SecWebSocketAccept] = CreateResponseKey(key); + headers.Connection = HeaderNames.Upgrade; + headers.Upgrade = Constants.Headers.UpgradeWebSocket; + headers.SecWebSocketAccept = CreateResponseKey(key); if (!string.IsNullOrWhiteSpace(subProtocol)) { - headers[HeaderNames.SecWebSocketProtocol] = subProtocol; + headers.SecWebSocketProtocol = subProtocol; } } @@ -128,7 +61,7 @@ public static string CreateResponseKey(string requestKey) // so this can be hardcoded to 60 bytes for the requestKey + static websocket string Span mergedBytes = stackalloc byte[60]; Encoding.UTF8.GetBytes(requestKey, mergedBytes); - EncodedWebSocketKey.CopyTo(mergedBytes.Slice(24)); + EncodedWebSocketKey.CopyTo(mergedBytes[24..]); Span hashedBytes = stackalloc byte[20]; var written = SHA1.HashData(mergedBytes, hashedBytes); diff --git a/src/Middleware/WebSockets/src/WebSocketMiddleware.cs b/src/Middleware/WebSockets/src/WebSocketMiddleware.cs index e6d198d73c2c..72fbc077b33f 100644 --- a/src/Middleware/WebSockets/src/WebSocketMiddleware.cs +++ b/src/Middleware/WebSockets/src/WebSocketMiddleware.cs @@ -118,16 +118,7 @@ public bool IsWebSocketRequest } else { - var requestHeaders = _context.Request.Headers; - var interestingHeaders = new List>(); - foreach (var headerName in HandshakeHelpers.NeededHeaders) - { - foreach (var value in requestHeaders.GetCommaSeparatedValues(headerName)) - { - interestingHeaders.Add(new KeyValuePair(headerName, value)); - } - } - _isWebSocketRequest = HandshakeHelpers.CheckSupportedWebSocketRequest(_context.Request.Method, interestingHeaders, requestHeaders); + _isWebSocketRequest = CheckSupportedWebSocketRequest(_context.Request.Method, _context.Request.Headers); } } return _isWebSocketRequest.Value; @@ -148,8 +139,7 @@ public async Task AcceptAsync(WebSocketAcceptContext acceptContext) } TimeSpan keepAliveInterval = _options.KeepAliveInterval; - var advancedAcceptContext = acceptContext as ExtendedWebSocketAcceptContext; - if (advancedAcceptContext != null) + if (acceptContext is ExtendedWebSocketAcceptContext advancedAcceptContext) { if (advancedAcceptContext.KeepAliveInterval.HasValue) { @@ -165,6 +155,77 @@ public async Task AcceptAsync(WebSocketAcceptContext acceptContext) return WebSocket.CreateFromStream(opaqueTransport, isServer: true, subProtocol: subProtocol, keepAliveInterval: keepAliveInterval); } + + public static bool CheckSupportedWebSocketRequest(string method, IHeaderDictionary requestHeaders) + { + if (!HttpMethods.IsGet(method)) + { + return false; + } + + var foundHeader = false; + + var values = requestHeaders.GetCommaSeparatedValues(HeaderNames.SecWebSocketVersion); + foreach (var value in values) + { + if (string.Equals(value, Constants.Headers.SupportedVersion, StringComparison.OrdinalIgnoreCase)) + { + // WebSockets are long lived; so if the header values are valid we switch them out for the interned versions. + if (values.Length == 1) + { + requestHeaders.SecWebSocketVersion = Constants.Headers.SupportedVersion; + } + foundHeader = true; + break; + } + } + if (!foundHeader) + { + return false; + } + foundHeader = false; + + values = requestHeaders.GetCommaSeparatedValues(HeaderNames.Connection); + foreach (var value in values) + { + if (string.Equals(value, HeaderNames.Upgrade, StringComparison.OrdinalIgnoreCase)) + { + // WebSockets are long lived; so if the header values are valid we switch them out for the interned versions. + if (values.Length == 1) + { + requestHeaders.Connection = HeaderNames.Upgrade; + } + foundHeader = true; + break; + } + } + if (!foundHeader) + { + return false; + } + foundHeader = false; + + values = requestHeaders.GetCommaSeparatedValues(HeaderNames.Upgrade); + foreach (var value in values) + { + if (string.Equals(value, Constants.Headers.UpgradeWebSocket, StringComparison.OrdinalIgnoreCase)) + { + // WebSockets are long lived; so if the header values are valid we switch them out for the interned versions. + if (values.Length == 1) + { + requestHeaders.Upgrade = Constants.Headers.UpgradeWebSocket; + } + foundHeader = true; + break; + } + } + if (!foundHeader) + { + return false; + } + + return HandshakeHelpers.IsRequestKeyValid(requestHeaders.SecWebSocketKey.ToString()); + } } } } diff --git a/src/Middleware/WebSockets/test/UnitTests/KestrelWebSocketHelpers.cs b/src/Middleware/WebSockets/test/UnitTests/KestrelWebSocketHelpers.cs index 73182e7cf00d..0631a4f5846c 100644 --- a/src/Middleware/WebSockets/test/UnitTests/KestrelWebSocketHelpers.cs +++ b/src/Middleware/WebSockets/test/UnitTests/KestrelWebSocketHelpers.cs @@ -3,6 +3,7 @@ using System; using System.Net; +using System.Runtime.ExceptionServices; using System.Threading.Tasks; using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Hosting; @@ -18,6 +19,7 @@ public class KestrelWebSocketHelpers { public static IDisposable CreateServer(ILoggerFactory loggerFactory, out int port, Func app, Action configure = null) { + Exception exceptionFromApp = null; configure = configure ?? (o => { }); Action startup = builder => { @@ -31,6 +33,8 @@ public static IDisposable CreateServer(ILoggerFactory loggerFactory, out int por } catch (Exception ex) { + // capture the exception from the app, we'll throw this at the end of the test when the server is disposed + exceptionFromApp = ex; if (ct.Response.HasStarted) { throw; @@ -64,12 +68,37 @@ public static IDisposable CreateServer(ILoggerFactory loggerFactory, out int por options.Listen(IPAddress.Loopback, 0); }) .Configure(startup); + }).ConfigureHostOptions(o => + { + o.ShutdownTimeout = TimeSpan.FromSeconds(30); }).Build(); host.Start(); port = host.GetPort(); - return host; + return new Disposable(() => + { + host.Dispose(); + if (exceptionFromApp is not null) + { + ExceptionDispatchInfo.Throw(exceptionFromApp); + } + }); + } + + private class Disposable : IDisposable + { + private readonly Action _dispose; + + public Disposable(Action dispose) + { + _dispose = dispose; + } + + public void Dispose() + { + _dispose(); + } } } } diff --git a/src/Middleware/WebSockets/test/UnitTests/WebSocketMiddlewareTests.cs b/src/Middleware/WebSockets/test/UnitTests/WebSocketMiddlewareTests.cs index 5ef3dfec0868..576217bf56e6 100644 --- a/src/Middleware/WebSockets/test/UnitTests/WebSocketMiddlewareTests.cs +++ b/src/Middleware/WebSockets/test/UnitTests/WebSocketMiddlewareTests.cs @@ -9,7 +9,6 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Testing; -using Microsoft.Extensions.Logging.Testing; using Microsoft.Net.Http.Headers; using Xunit; @@ -138,6 +137,7 @@ public async Task SendMediumData_Success() [Fact] public async Task SendLongData_Success() { + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); var orriginalData = Encoding.UTF8.GetBytes(new string('a', 0x1FFFF)); using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context => { @@ -146,29 +146,22 @@ public async Task SendLongData_Success() var serverBuffer = new byte[orriginalData.Length]; var result = await webSocket.ReceiveAsync(new ArraySegment(serverBuffer), CancellationToken.None); - int intermediateCount = result.Count; - Assert.False(result.EndOfMessage); - Assert.Equal(WebSocketMessageType.Text, result.MessageType); - - result = await webSocket.ReceiveAsync(new ArraySegment(serverBuffer, intermediateCount, orriginalData.Length - intermediateCount), CancellationToken.None); - intermediateCount += result.Count; - Assert.False(result.EndOfMessage); - Assert.Equal(WebSocketMessageType.Text, result.MessageType); - - result = await webSocket.ReceiveAsync(new ArraySegment(serverBuffer, intermediateCount, orriginalData.Length - intermediateCount), CancellationToken.None); - intermediateCount += result.Count; Assert.True(result.EndOfMessage); - Assert.Equal(orriginalData.Length, intermediateCount); - Assert.Equal(WebSocketMessageType.Text, result.MessageType); + Assert.Equal(WebSocketMessageType.Binary, result.MessageType); Assert.Equal(orriginalData, serverBuffer); + + tcs.SetResult(); })) { using (var client = new ClientWebSocket()) { await client.ConnectAsync(new Uri($"ws://127.0.0.1:{port}/"), CancellationToken.None); await client.SendAsync(new ArraySegment(orriginalData), WebSocketMessageType.Binary, true, CancellationToken.None); + } + // Wait to close the server otherwise the app could throw if it takes longer than the shutdown timeout + await tcs.Task; } } @@ -176,6 +169,7 @@ public async Task SendLongData_Success() public async Task SendFragmentedData_Success() { var orriginalData = Encoding.UTF8.GetBytes("Hello World"); + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context => { Assert.True(context.WebSockets.IsWebSocketRequest); @@ -187,6 +181,7 @@ public async Task SendFragmentedData_Success() Assert.Equal(2, result.Count); int totalReceived = result.Count; Assert.Equal(WebSocketMessageType.Binary, result.MessageType); + tcs.SetResult(); result = await webSocket.ReceiveAsync( new ArraySegment(serverBuffer, totalReceived, serverBuffer.Length - totalReceived), CancellationToken.None); @@ -194,6 +189,7 @@ public async Task SendFragmentedData_Success() Assert.Equal(2, result.Count); totalReceived += result.Count; Assert.Equal(WebSocketMessageType.Binary, result.MessageType); + tcs.SetResult(); result = await webSocket.ReceiveAsync( new ArraySegment(serverBuffer, totalReceived, serverBuffer.Length - totalReceived), CancellationToken.None); @@ -209,7 +205,11 @@ public async Task SendFragmentedData_Success() { await client.ConnectAsync(new Uri($"ws://127.0.0.1:{port}/"), CancellationToken.None); await client.SendAsync(new ArraySegment(orriginalData, 0, 2), WebSocketMessageType.Binary, false, CancellationToken.None); + await tcs.Task; + tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); await client.SendAsync(new ArraySegment(orriginalData, 2, 2), WebSocketMessageType.Binary, false, CancellationToken.None); + await tcs.Task; + tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); await client.SendAsync(new ArraySegment(orriginalData, 4, 7), WebSocketMessageType.Binary, true, CancellationToken.None); } } @@ -574,5 +574,62 @@ public async Task OriginIsNotValidatedForNonWebSocketRequests() } } } + + [Fact] + public async Task CommonHeadersAreSetToInternedStrings() + { + using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context => + { + Assert.True(context.WebSockets.IsWebSocketRequest); + var webSocket = await context.WebSockets.AcceptWebSocketAsync(); + + // Use ReferenceEquals and test against the constants + Assert.Same(HeaderNames.Upgrade, context.Request.Headers.Connection.ToString()); + Assert.Same(Constants.Headers.UpgradeWebSocket, context.Request.Headers.Upgrade.ToString()); + Assert.Same(Constants.Headers.SupportedVersion, context.Request.Headers.SecWebSocketVersion.ToString()); + })) + { + using (var client = new ClientWebSocket()) + { + await client.ConnectAsync(new Uri($"ws://127.0.0.1:{port}/"), CancellationToken.None); + } + } + } + + [Fact] + public async Task MultipleValueHeadersNotOverridden() + { + using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context => + { + Assert.True(context.WebSockets.IsWebSocketRequest); + var webSocket = await context.WebSockets.AcceptWebSocketAsync(); + + Assert.Equal("Upgrade, keep-alive", context.Request.Headers.Connection.ToString()); + Assert.Equal("websocket, example", context.Request.Headers.Upgrade.ToString()); + })) + { + using (var client = new HttpClient()) + { + var uri = new UriBuilder(new Uri($"ws://127.0.0.1:{port}/")); + uri.Scheme = "http"; + + // Craft a valid WebSocket Upgrade request + using (var request = new HttpRequestMessage(HttpMethod.Get, uri.ToString())) + { + request.Headers.Connection.Clear(); + request.Headers.Connection.Add("Upgrade"); + request.Headers.Connection.Add("keep-alive"); + request.Headers.Upgrade.Add(new System.Net.Http.Headers.ProductHeaderValue("websocket")); + request.Headers.Upgrade.Add(new System.Net.Http.Headers.ProductHeaderValue("example")); + request.Headers.Add(HeaderNames.SecWebSocketVersion, "13"); + // SecWebSocketKey required to be 16 bytes + request.Headers.Add(HeaderNames.SecWebSocketKey, Convert.ToBase64String(new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 }, Base64FormattingOptions.None)); + + var response = await client.SendAsync(request); + Assert.Equal(HttpStatusCode.SwitchingProtocols, response.StatusCode); + } + } + } + } } }