Skip to content

Commit

Permalink
Websocket handshake refactoring. (#31506)
Browse files Browse the repository at this point in the history
  • Loading branch information
ShreyasJejurkar committed Apr 29, 2021
1 parent 3e46e68 commit db8a649
Show file tree
Hide file tree
Showing 4 changed files with 179 additions and 99 deletions.
77 changes: 5 additions & 72 deletions src/Middleware/WebSockets/src/HandshakeHelpers.cs
Expand Up @@ -2,7 +2,6 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.Security.Cryptography;
using System.Text;
using Microsoft.AspNetCore.Http;
Expand All @@ -12,17 +11,6 @@ namespace Microsoft.AspNetCore.WebSockets
{
internal static class HandshakeHelpers
{
/// <summary>
/// Gets request headers needed process the handshake on the server.
/// </summary>
public static readonly string[] NeededHeaders = new[]
{
HeaderNames.Upgrade,
HeaderNames.Connection,
HeaderNames.SecWebSocketKey,
HeaderNames.SecWebSocketVersion
};

// "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
// This uses C# compiler's ability to refer to static data directly. For more information see https://vcsjones.dev/2019/02/01/csharp-readonly-span-bytes-static
private static ReadOnlySpan<byte> EncodedWebSocketKey => new byte[]
Expand All @@ -34,69 +22,14 @@ internal static class HandshakeHelpers
};

// Verify Method, Upgrade, Connection, version, key, etc..
public static bool CheckSupportedWebSocketRequest(string method, List<KeyValuePair<string, string>> interestingHeaders, IHeaderDictionary requestHeaders)
{
bool validUpgrade = false, validConnection = false, validKey = false, validVersion = false;

if (!string.Equals("GET", method, StringComparison.OrdinalIgnoreCase))
{
return false;
}

foreach (var pair in interestingHeaders)
{
if (string.Equals(HeaderNames.Connection, pair.Key, StringComparison.OrdinalIgnoreCase))
{
if (string.Equals(HeaderNames.Upgrade, pair.Value, StringComparison.OrdinalIgnoreCase))
{
validConnection = true;
}
}
else if (string.Equals(HeaderNames.Upgrade, pair.Key, StringComparison.OrdinalIgnoreCase))
{
if (string.Equals(Constants.Headers.UpgradeWebSocket, pair.Value, StringComparison.OrdinalIgnoreCase))
{
validUpgrade = true;
}
}
else if (string.Equals(HeaderNames.SecWebSocketVersion, pair.Key, StringComparison.OrdinalIgnoreCase))
{
if (string.Equals(Constants.Headers.SupportedVersion, pair.Value, StringComparison.OrdinalIgnoreCase))
{
validVersion = true;
}
}
else if (string.Equals(HeaderNames.SecWebSocketKey, pair.Key, StringComparison.OrdinalIgnoreCase))
{
validKey = IsRequestKeyValid(pair.Value);
}
}

// WebSockets are long lived; so if the header values are valid we switch them out for the interned versions.
if (validConnection && requestHeaders[HeaderNames.Connection].Count == 1)
{
requestHeaders[HeaderNames.Connection] = HeaderNames.Upgrade;
}
if (validUpgrade && requestHeaders[HeaderNames.Upgrade].Count == 1)
{
requestHeaders[HeaderNames.Upgrade] = Constants.Headers.UpgradeWebSocket;
}
if (validVersion && requestHeaders[HeaderNames.SecWebSocketVersion].Count == 1)
{
requestHeaders[HeaderNames.SecWebSocketVersion] = Constants.Headers.SupportedVersion;
}

return validConnection && validUpgrade && validVersion && validKey;
}

public static void GenerateResponseHeaders(string key, string? subProtocol, IHeaderDictionary headers)
{
headers[HeaderNames.Connection] = HeaderNames.Upgrade;
headers[HeaderNames.Upgrade] = Constants.Headers.UpgradeWebSocket;
headers[HeaderNames.SecWebSocketAccept] = CreateResponseKey(key);
headers.Connection = HeaderNames.Upgrade;
headers.Upgrade = Constants.Headers.UpgradeWebSocket;
headers.SecWebSocketAccept = CreateResponseKey(key);
if (!string.IsNullOrWhiteSpace(subProtocol))
{
headers[HeaderNames.SecWebSocketProtocol] = subProtocol;
headers.SecWebSocketProtocol = subProtocol;
}
}

Expand Down Expand Up @@ -128,7 +61,7 @@ public static string CreateResponseKey(string requestKey)
// so this can be hardcoded to 60 bytes for the requestKey + static websocket string
Span<byte> mergedBytes = stackalloc byte[60];
Encoding.UTF8.GetBytes(requestKey, mergedBytes);
EncodedWebSocketKey.CopyTo(mergedBytes.Slice(24));
EncodedWebSocketKey.CopyTo(mergedBytes[24..]);

Span<byte> hashedBytes = stackalloc byte[20];
var written = SHA1.HashData(mergedBytes, hashedBytes);
Expand Down
85 changes: 73 additions & 12 deletions src/Middleware/WebSockets/src/WebSocketMiddleware.cs
Expand Up @@ -118,16 +118,7 @@ public bool IsWebSocketRequest
}
else
{
var requestHeaders = _context.Request.Headers;
var interestingHeaders = new List<KeyValuePair<string, string>>();
foreach (var headerName in HandshakeHelpers.NeededHeaders)
{
foreach (var value in requestHeaders.GetCommaSeparatedValues(headerName))
{
interestingHeaders.Add(new KeyValuePair<string, string>(headerName, value));
}
}
_isWebSocketRequest = HandshakeHelpers.CheckSupportedWebSocketRequest(_context.Request.Method, interestingHeaders, requestHeaders);
_isWebSocketRequest = CheckSupportedWebSocketRequest(_context.Request.Method, _context.Request.Headers);
}
}
return _isWebSocketRequest.Value;
Expand All @@ -148,8 +139,7 @@ public async Task<WebSocket> AcceptAsync(WebSocketAcceptContext acceptContext)
}

TimeSpan keepAliveInterval = _options.KeepAliveInterval;
var advancedAcceptContext = acceptContext as ExtendedWebSocketAcceptContext;
if (advancedAcceptContext != null)
if (acceptContext is ExtendedWebSocketAcceptContext advancedAcceptContext)
{
if (advancedAcceptContext.KeepAliveInterval.HasValue)
{
Expand All @@ -165,6 +155,77 @@ public async Task<WebSocket> AcceptAsync(WebSocketAcceptContext acceptContext)

return WebSocket.CreateFromStream(opaqueTransport, isServer: true, subProtocol: subProtocol, keepAliveInterval: keepAliveInterval);
}

public static bool CheckSupportedWebSocketRequest(string method, IHeaderDictionary requestHeaders)
{
if (!HttpMethods.IsGet(method))
{
return false;
}

var foundHeader = false;

var values = requestHeaders.GetCommaSeparatedValues(HeaderNames.SecWebSocketVersion);
foreach (var value in values)
{
if (string.Equals(value, Constants.Headers.SupportedVersion, StringComparison.OrdinalIgnoreCase))
{
// WebSockets are long lived; so if the header values are valid we switch them out for the interned versions.
if (values.Length == 1)
{
requestHeaders.SecWebSocketVersion = Constants.Headers.SupportedVersion;
}
foundHeader = true;
break;
}
}
if (!foundHeader)
{
return false;
}
foundHeader = false;

values = requestHeaders.GetCommaSeparatedValues(HeaderNames.Connection);
foreach (var value in values)
{
if (string.Equals(value, HeaderNames.Upgrade, StringComparison.OrdinalIgnoreCase))
{
// WebSockets are long lived; so if the header values are valid we switch them out for the interned versions.
if (values.Length == 1)
{
requestHeaders.Connection = HeaderNames.Upgrade;
}
foundHeader = true;
break;
}
}
if (!foundHeader)
{
return false;
}
foundHeader = false;

values = requestHeaders.GetCommaSeparatedValues(HeaderNames.Upgrade);
foreach (var value in values)
{
if (string.Equals(value, Constants.Headers.UpgradeWebSocket, StringComparison.OrdinalIgnoreCase))
{
// WebSockets are long lived; so if the header values are valid we switch them out for the interned versions.
if (values.Length == 1)
{
requestHeaders.Upgrade = Constants.Headers.UpgradeWebSocket;
}
foundHeader = true;
break;
}
}
if (!foundHeader)
{
return false;
}

return HandshakeHelpers.IsRequestKeyValid(requestHeaders.SecWebSocketKey.ToString());
}
}
}
}
Expand Up @@ -3,6 +3,7 @@

using System;
using System.Net;
using System.Runtime.ExceptionServices;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Hosting;
Expand All @@ -18,6 +19,7 @@ public class KestrelWebSocketHelpers
{
public static IDisposable CreateServer(ILoggerFactory loggerFactory, out int port, Func<HttpContext, Task> app, Action<WebSocketOptions> configure = null)
{
Exception exceptionFromApp = null;
configure = configure ?? (o => { });
Action<IApplicationBuilder> startup = builder =>
{
Expand All @@ -31,6 +33,8 @@ public static IDisposable CreateServer(ILoggerFactory loggerFactory, out int por
}
catch (Exception ex)
{
// capture the exception from the app, we'll throw this at the end of the test when the server is disposed
exceptionFromApp = ex;
if (ct.Response.HasStarted)
{
throw;
Expand Down Expand Up @@ -64,12 +68,37 @@ public static IDisposable CreateServer(ILoggerFactory loggerFactory, out int por
options.Listen(IPAddress.Loopback, 0);
})
.Configure(startup);
}).ConfigureHostOptions(o =>
{
o.ShutdownTimeout = TimeSpan.FromSeconds(30);
}).Build();

host.Start();
port = host.GetPort();

return host;
return new Disposable(() =>
{
host.Dispose();
if (exceptionFromApp is not null)
{
ExceptionDispatchInfo.Throw(exceptionFromApp);
}
});
}

private class Disposable : IDisposable
{
private readonly Action _dispose;

public Disposable(Action dispose)
{
_dispose = dispose;
}

public void Dispose()
{
_dispose();
}
}
}
}
Expand Down

0 comments on commit db8a649

Please sign in to comment.