From 68f95ecc2f711f73db9ac0d889b55bc1b77cfff7 Mon Sep 17 00:00:00 2001 From: Andrey Mizurov Date: Thu, 21 Apr 2022 18:54:00 +0200 Subject: [PATCH] Remove ORIGIN header from ws opening handshake v13, clean up code (#12293) Motivation: In Netty 5 we only left the latest 13 version of websocket, so we can clean up the code a bit. Modification: - Remove `origin` header from handshake request, it can be set with `customHeaders`, see #9673 - Calculate expected value of `sec-websocket-accept` only when we received a response - Remove unused code - Add more tests for clientHandshaker13 Result: Removed not used code and fixed wrong behaviour with origin header #9673. --- .../websocketx/WebSocketClientHandshaker.java | 30 --- .../WebSocketClientHandshaker13.java | 97 ++++----- .../WebSocketClientHandshakerFactory.java | 4 +- .../WebSocketServerHandshaker13.java | 16 +- .../codec/http/websocketx/WebSocketUtil.java | 55 ++--- .../WebSocketClientHandshaker13Test.java | 112 +++++++++- .../WebSocketClientHandshakerTest.java | 203 +++++++----------- .../WebSocketHandshakeHandOverTest.java | 2 +- .../http/websocketx/WebSocketUtilTest.java | 48 +++-- 9 files changed, 287 insertions(+), 280 deletions(-) diff --git a/codec-http/src/main/java/io/netty5/handler/codec/http/websocketx/WebSocketClientHandshaker.java b/codec-http/src/main/java/io/netty5/handler/codec/http/websocketx/WebSocketClientHandshaker.java index bcf71950021..62f30105636 100644 --- a/codec-http/src/main/java/io/netty5/handler/codec/http/websocketx/WebSocketClientHandshaker.java +++ b/codec-http/src/main/java/io/netty5/handler/codec/http/websocketx/WebSocketClientHandshaker.java @@ -40,7 +40,6 @@ import java.net.URI; import java.nio.channels.ClosedChannelException; -import java.util.Locale; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; @@ -51,8 +50,6 @@ */ public abstract class WebSocketClientHandshaker { - private static final String HTTP_SCHEME_PREFIX = HttpScheme.HTTP + "://"; - private static final String HTTPS_SCHEME_PREFIX = HttpScheme.HTTPS + "://"; protected static final int DEFAULT_FORCE_CLOSE_TIMEOUT_MILLIS = 10000; private final URI uri; @@ -548,31 +545,4 @@ static CharSequence websocketHostValue(URI wsURL) { // See https://tools.ietf.org/html/rfc6454#section-6.2 return NetUtil.toSocketAddressString(host, port); } - - static CharSequence websocketOriginValue(URI wsURL) { - String scheme = wsURL.getScheme(); - final String schemePrefix; - int port = wsURL.getPort(); - final int defaultPort; - if (WebSocketScheme.WSS.name().contentEquals(scheme) - || HttpScheme.HTTPS.name().contentEquals(scheme) - || (scheme == null && port == WebSocketScheme.WSS.port())) { - - schemePrefix = HTTPS_SCHEME_PREFIX; - defaultPort = WebSocketScheme.WSS.port(); - } else { - schemePrefix = HTTP_SCHEME_PREFIX; - defaultPort = WebSocketScheme.WS.port(); - } - - // Convert uri-host to lower case (by RFC 6454, chapter 4 "Origin of a URI") - String host = wsURL.getHost().toLowerCase(Locale.US); - - if (port != defaultPort && port != -1) { - // if the port is not standard (80/443) its needed to add the port to the header. - // See https://tools.ietf.org/html/rfc6454#section-6.2 - return schemePrefix + NetUtil.toSocketAddressString(host, port); - } - return schemePrefix + host; - } } diff --git a/codec-http/src/main/java/io/netty5/handler/codec/http/websocketx/WebSocketClientHandshaker13.java b/codec-http/src/main/java/io/netty5/handler/codec/http/websocketx/WebSocketClientHandshaker13.java index c4585872456..9fbbf62c679 100644 --- a/codec-http/src/main/java/io/netty5/handler/codec/http/websocketx/WebSocketClientHandshaker13.java +++ b/codec-http/src/main/java/io/netty5/handler/codec/http/websocketx/WebSocketClientHandshaker13.java @@ -25,39 +25,30 @@ import io.netty5.handler.codec.http.HttpMethod; import io.netty5.handler.codec.http.HttpResponseStatus; import io.netty5.handler.codec.http.HttpVersion; -import io.netty5.util.CharsetUtil; -import io.netty5.util.internal.logging.InternalLogger; -import io.netty5.util.internal.logging.InternalLoggerFactory; +import io.netty5.util.internal.StringUtil; import java.net.URI; /** *

- * Performs client side opening and closing handshakes for web socket specification version draft-ietf-hybi-thewebsocketprotocol- - * 17 + * Performs client side opening and closing handshakes for web socket specification version + * websocketprotocol-v13 *

*/ public class WebSocketClientHandshaker13 extends WebSocketClientHandshaker { - private static final InternalLogger logger = InternalLoggerFactory.getInstance(WebSocketClientHandshaker13.class); - - public static final String MAGIC_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; - - private String expectedChallengeResponseString; - private final boolean allowExtensions; private final boolean performMasking; private final boolean allowMaskMismatch; + private volatile String sentNonce; + /** * Creates a new instance. * * @param webSocketURL * URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be * sent to this URL. - * @param version - * Version of web socket specification to use to connect to the server * @param subprotocol * Sub protocol request sent to the server. * @param allowExtensions @@ -67,9 +58,9 @@ public class WebSocketClientHandshaker13 extends WebSocketClientHandshaker { * @param maxFramePayloadLength * Maximum length of a frame's payload */ - public WebSocketClientHandshaker13(URI webSocketURL, WebSocketVersion version, String subprotocol, + public WebSocketClientHandshaker13(URI webSocketURL, String subprotocol, boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength) { - this(webSocketURL, version, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength, + this(webSocketURL, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength, true, false); } @@ -79,8 +70,6 @@ public WebSocketClientHandshaker13(URI webSocketURL, WebSocketVersion version, S * @param webSocketURL * URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be * sent to this URL. - * @param version - * Version of web socket specification to use to connect to the server * @param subprotocol * Sub protocol request sent to the server. * @param allowExtensions @@ -97,10 +86,10 @@ public WebSocketClientHandshaker13(URI webSocketURL, WebSocketVersion version, S * When set to true, frames which are not masked properly according to the standard will still be * accepted. */ - public WebSocketClientHandshaker13(URI webSocketURL, WebSocketVersion version, String subprotocol, + public WebSocketClientHandshaker13(URI webSocketURL, String subprotocol, boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength, boolean performMasking, boolean allowMaskMismatch) { - this(webSocketURL, version, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength, + this(webSocketURL, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength, performMasking, allowMaskMismatch, DEFAULT_FORCE_CLOSE_TIMEOUT_MILLIS); } @@ -110,8 +99,6 @@ public WebSocketClientHandshaker13(URI webSocketURL, WebSocketVersion version, S * @param webSocketURL * URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be * sent to this URL. - * @param version - * Version of web socket specification to use to connect to the server * @param subprotocol * Sub protocol request sent to the server. * @param allowExtensions @@ -130,11 +117,11 @@ public WebSocketClientHandshaker13(URI webSocketURL, WebSocketVersion version, S * @param forceCloseTimeoutMillis * Close the connection if it was not closed by the server after timeout specified. */ - public WebSocketClientHandshaker13(URI webSocketURL, WebSocketVersion version, String subprotocol, + public WebSocketClientHandshaker13(URI webSocketURL, String subprotocol, boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength, boolean performMasking, boolean allowMaskMismatch, long forceCloseTimeoutMillis) { - this(webSocketURL, version, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength, performMasking, + this(webSocketURL, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength, performMasking, allowMaskMismatch, forceCloseTimeoutMillis, false); } @@ -144,8 +131,6 @@ public WebSocketClientHandshaker13(URI webSocketURL, WebSocketVersion version, S * @param webSocketURL * URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be * sent to this URL. - * @param version - * Version of web socket specification to use to connect to the server * @param subprotocol * Sub protocol request sent to the server. * @param allowExtensions @@ -167,12 +152,12 @@ public WebSocketClientHandshaker13(URI webSocketURL, WebSocketVersion version, S * Use an absolute url for the Upgrade request, typically when connecting through an HTTP proxy over * clear HTTP */ - WebSocketClientHandshaker13(URI webSocketURL, WebSocketVersion version, String subprotocol, + WebSocketClientHandshaker13(URI webSocketURL, String subprotocol, boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength, boolean performMasking, boolean allowMaskMismatch, long forceCloseTimeoutMillis, boolean absoluteUpgradeUrl) { - super(webSocketURL, version, subprotocol, customHeaders, maxFramePayloadLength, forceCloseTimeoutMillis, - absoluteUpgradeUrl); + super(webSocketURL, WebSocketVersion.V13, subprotocol, customHeaders, maxFramePayloadLength, + forceCloseTimeoutMillis, absoluteUpgradeUrl); this.allowExtensions = allowExtensions; this.performMasking = performMasking; this.allowMaskMismatch = allowMaskMismatch; @@ -190,7 +175,6 @@ public WebSocketClientHandshaker13(URI webSocketURL, WebSocketVersion version, S * Upgrade: websocket * Connection: Upgrade * Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ== - * Origin: http://example.com * Sec-WebSocket-Protocol: chat, superchat * Sec-WebSocket-Version: 13 * @@ -199,22 +183,6 @@ public WebSocketClientHandshaker13(URI webSocketURL, WebSocketVersion version, S @Override protected FullHttpRequest newHandshakeRequest(BufferAllocator allocator) { URI wsURL = uri(); - - // Get 16 bit nonce and base 64 encode it - byte[] nonce = WebSocketUtil.randomBytes(16); - String key = WebSocketUtil.base64(nonce); - - String acceptSeed = key + MAGIC_GUID; - byte[] sha1 = WebSocketUtil.sha1(acceptSeed.getBytes(CharsetUtil.US_ASCII)); - expectedChallengeResponseString = WebSocketUtil.base64(sha1); - - if (logger.isDebugEnabled()) { - logger.debug( - "WebSocket version 13 client handshake key: {}, expected response: {}", - key, expectedChallengeResponseString); - } - - // Format request FullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, upgradeUrl(wsURL), allocator.allocate(0)); HttpHeaders headers = request.headers(); @@ -223,24 +191,21 @@ protected FullHttpRequest newHandshakeRequest(BufferAllocator allocator) { headers.add(customHeaders); if (!headers.contains(HttpHeaderNames.HOST)) { // Only add HOST header if customHeaders did not contain it. - // - // See https://github.com/netty/netty/issues/10101 + // See https://github.com/netty/netty/issues/10101. headers.set(HttpHeaderNames.HOST, websocketHostValue(wsURL)); } } else { headers.set(HttpHeaderNames.HOST, websocketHostValue(wsURL)); } + String nonce = createNonce(); headers.set(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET) .set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE) - .set(HttpHeaderNames.SEC_WEBSOCKET_KEY, key); - - if (!headers.contains(HttpHeaderNames.ORIGIN)) { - headers.set(HttpHeaderNames.ORIGIN, websocketOriginValue(wsURL)); - } + .set(HttpHeaderNames.SEC_WEBSOCKET_KEY, nonce); + sentNonce = nonce; String expectedSubprotocol = expectedSubprotocol(); - if (expectedSubprotocol != null && !expectedSubprotocol.isEmpty()) { + if (!StringUtil.isNullOrEmpty(expectedSubprotocol)) { headers.set(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, expectedSubprotocol); } @@ -269,7 +234,7 @@ protected FullHttpRequest newHandshakeRequest(BufferAllocator allocator) { protected void verify(FullHttpResponse response) { HttpResponseStatus status = response.status(); if (!HttpResponseStatus.SWITCHING_PROTOCOLS.equals(status)) { - throw new WebSocketClientHandshakeException("Invalid handshake response getStatus: " + status, response); + throw new WebSocketClientHandshakeException("Invalid handshake response status: " + status, response); } HttpHeaders headers = response.headers(); @@ -283,10 +248,16 @@ protected void verify(FullHttpResponse response) { + headers.get(HttpHeaderNames.CONNECTION), response); } - CharSequence accept = headers.get(HttpHeaderNames.SEC_WEBSOCKET_ACCEPT); - if (accept == null || !accept.equals(expectedChallengeResponseString)) { - throw new WebSocketClientHandshakeException(String.format( - "Invalid challenge. Actual: %s. Expected: %s", accept, expectedChallengeResponseString), response); + String accept = headers.get(HttpHeaderNames.SEC_WEBSOCKET_ACCEPT); + if (accept == null) { + throw new WebSocketClientHandshakeException("Invalid handshake response sec-websocket-accept: null", + response); + } + + String expectedAccept = WebSocketUtil.calculateV13Accept(sentNonce); + if (!expectedAccept.equals(accept.trim())) { + throw new WebSocketClientHandshakeException("Invalid handshake response sec-websocket-accept: " + accept + + ", expected: " + expectedAccept, response); } } @@ -306,4 +277,12 @@ public WebSocketClientHandshaker13 setForceCloseTimeoutMillis(long forceCloseTim return this; } + /** + * Creates a nonce consisting of a randomly selected 16-byte value + * that has been base64-encoded. + */ + private static String createNonce() { + var nonce = WebSocketUtil.randomBytes(16); + return WebSocketUtil.base64(nonce); + } } diff --git a/codec-http/src/main/java/io/netty5/handler/codec/http/websocketx/WebSocketClientHandshakerFactory.java b/codec-http/src/main/java/io/netty5/handler/codec/http/websocketx/WebSocketClientHandshakerFactory.java index 152b57be752..29065020b3d 100644 --- a/codec-http/src/main/java/io/netty5/handler/codec/http/websocketx/WebSocketClientHandshakerFactory.java +++ b/codec-http/src/main/java/io/netty5/handler/codec/http/websocketx/WebSocketClientHandshakerFactory.java @@ -144,7 +144,7 @@ public static WebSocketClientHandshaker newHandshaker( boolean performMasking, boolean allowMaskMismatch, long forceCloseTimeoutMillis) { if (version == V13) { return new WebSocketClientHandshaker13( - webSocketURL, V13, subprotocol, allowExtensions, customHeaders, + webSocketURL, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength, performMasking, allowMaskMismatch, forceCloseTimeoutMillis); } @@ -187,7 +187,7 @@ public static WebSocketClientHandshaker newHandshaker( boolean performMasking, boolean allowMaskMismatch, long forceCloseTimeoutMillis, boolean absoluteUpgradeUrl) { if (version == V13) { return new WebSocketClientHandshaker13( - webSocketURL, V13, subprotocol, allowExtensions, customHeaders, + webSocketURL, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength, performMasking, allowMaskMismatch, forceCloseTimeoutMillis, absoluteUpgradeUrl); } diff --git a/codec-http/src/main/java/io/netty5/handler/codec/http/websocketx/WebSocketServerHandshaker13.java b/codec-http/src/main/java/io/netty5/handler/codec/http/websocketx/WebSocketServerHandshaker13.java index 125172dff65..e70844b6d7f 100644 --- a/codec-http/src/main/java/io/netty5/handler/codec/http/websocketx/WebSocketServerHandshaker13.java +++ b/codec-http/src/main/java/io/netty5/handler/codec/http/websocketx/WebSocketServerHandshaker13.java @@ -23,9 +23,8 @@ import io.netty5.handler.codec.http.HttpHeaderValues; import io.netty5.handler.codec.http.HttpHeaders; import io.netty5.handler.codec.http.HttpResponseStatus; -import io.netty5.util.CharsetUtil; -import static io.netty5.handler.codec.http.HttpVersion.*; +import static io.netty5.handler.codec.http.HttpVersion.HTTP_1_1; /** *

@@ -35,8 +34,6 @@ */ public class WebSocketServerHandshaker13 extends WebSocketServerHandshaker { - public static final String WEBSOCKET_13_ACCEPT_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; - /** * Constructor specifying the destination web socket location * @@ -136,7 +133,7 @@ public WebSocketServerHandshaker13( @Override protected FullHttpResponse newHandshakeResponse(BufferAllocator allocator, FullHttpRequest req, HttpHeaders headers) { - CharSequence key = req.headers().get(HttpHeaderNames.SEC_WEBSOCKET_KEY); + String key = req.headers().get(HttpHeaderNames.SEC_WEBSOCKET_KEY); if (key == null) { throw new WebSocketServerHandshakeException("not a WebSocket request: missing key", req); } @@ -147,14 +144,7 @@ protected FullHttpResponse newHandshakeResponse(BufferAllocator allocator, FullH res.headers().add(headers); } - String acceptSeed = key + WEBSOCKET_13_ACCEPT_GUID; - byte[] sha1 = WebSocketUtil.sha1(acceptSeed.getBytes(CharsetUtil.US_ASCII)); - String accept = WebSocketUtil.base64(sha1); - - if (logger.isDebugEnabled()) { - logger.debug("WebSocket version 13 server handshake key: {}, response: {}", key, accept); - } - + String accept = WebSocketUtil.calculateV13Accept(key); res.headers().set(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET) .set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE) .set(HttpHeaderNames.SEC_WEBSOCKET_ACCEPT, accept); diff --git a/codec-http/src/main/java/io/netty5/handler/codec/http/websocketx/WebSocketUtil.java b/codec-http/src/main/java/io/netty5/handler/codec/http/websocketx/WebSocketUtil.java index 988eacd9a10..ebbb7a5104b 100644 --- a/codec-http/src/main/java/io/netty5/handler/codec/http/websocketx/WebSocketUtil.java +++ b/codec-http/src/main/java/io/netty5/handler/codec/http/websocketx/WebSocketUtil.java @@ -15,6 +15,7 @@ */ package io.netty5.handler.codec.http.websocketx; +import io.netty5.util.CharsetUtil; import io.netty5.util.concurrent.FastThreadLocal; import java.security.MessageDigest; @@ -23,67 +24,67 @@ import java.util.concurrent.ThreadLocalRandom; /** - * A utility class mainly for use by web sockets + * A utility class mainly for use by web sockets. */ final class WebSocketUtil { - private static final FastThreadLocal SHA1 = new FastThreadLocal() { + private static final String V13_ACCEPT_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + + private static final FastThreadLocal SHA1 = new FastThreadLocal<>() { @Override protected MessageDigest initialValue() throws Exception { try { - //Try to get a MessageDigest that uses SHA1 - //Suppress a warning about weak hash algorithm - //since it's defined in draft-ietf-hybi-thewebsocketprotocol-00 - return MessageDigest.getInstance("SHA1"); // lgtm [java/weak-cryptographic-algorithm] + // Try to get a MessageDigest that uses SHA1. + // Suppress a warning about weak hash algorithm + // since it's defined in https://datatracker.ietf.org/doc/html/rfc6455#section-10.8. + return MessageDigest.getInstance("SHA-1"); // lgtm [java/weak-cryptographic-algorithm] } catch (NoSuchAlgorithmException e) { - //This shouldn't happen! How old is the computer? - throw new InternalError("SHA-1 not supported on this platform - Outdated?"); + // This shouldn't happen! How old is the computer ? + throw new InternalError("SHA-1 not supported on this platform - Outdated ?"); } } }; /** - * Performs a SHA-1 hash on the specified data + * Performs a SHA-1 hash on the specified data. * * @param data The data to hash - * @return The hashed data + * @return the hashed data */ static byte[] sha1(byte[] data) { - // TODO(normanmaurer): Create sha1 method that not need MessageDigest. - return digest(SHA1, data); - } - - private static byte[] digest(FastThreadLocal digestFastThreadLocal, byte[] data) { - MessageDigest digest = digestFastThreadLocal.get(); - digest.reset(); - return digest.digest(data); + MessageDigest sha1Digest = SHA1.get(); + sha1Digest.reset(); + return sha1Digest.digest(data); } /** - * Performs base64 encoding on the specified data + * Performs base64 encoding on the specified data. * * @param data The data to encode - * @return An encoded string containing the data + * @return an encoded string containing the data */ static String base64(byte[] data) { return Base64.getEncoder().encodeToString(data); } + /** - * Creates an arbitrary number of random bytes + * Creates an arbitrary number of random bytes. * * @param size the number of random bytes to create - * @return An array of random bytes + * @return an array of random bytes */ static byte[] randomBytes(int size) { - byte[] bytes = new byte[size]; + var bytes = new byte[size]; ThreadLocalRandom.current().nextBytes(bytes); return bytes; } - /** - * A private constructor to ensure that instances of this class cannot be made - */ + static String calculateV13Accept(String nonce) { + String concat = nonce + V13_ACCEPT_GUID; + byte[] sha1 = WebSocketUtil.sha1(concat.getBytes(CharsetUtil.US_ASCII)); + return WebSocketUtil.base64(sha1); + } + private WebSocketUtil() { - // Unused } } diff --git a/codec-http/src/test/java/io/netty5/handler/codec/http/websocketx/WebSocketClientHandshaker13Test.java b/codec-http/src/test/java/io/netty5/handler/codec/http/websocketx/WebSocketClientHandshaker13Test.java index 59bd514c6f3..e177324e52f 100644 --- a/codec-http/src/test/java/io/netty5/handler/codec/http/websocketx/WebSocketClientHandshaker13Test.java +++ b/codec-http/src/test/java/io/netty5/handler/codec/http/websocketx/WebSocketClientHandshaker13Test.java @@ -15,24 +15,30 @@ */ package io.netty5.handler.codec.http.websocketx; +import io.netty5.handler.codec.http.DefaultFullHttpResponse; +import io.netty5.handler.codec.http.FullHttpResponse; import io.netty5.handler.codec.http.HttpHeaderNames; +import io.netty5.handler.codec.http.HttpHeaderValues; import io.netty5.handler.codec.http.HttpHeaders; +import io.netty5.handler.codec.http.HttpResponseStatus; +import io.netty5.handler.codec.http.HttpVersion; +import org.junit.jupiter.api.Test; import java.net.URI; +import static io.netty5.buffer.api.DefaultBufferAllocators.preferredAllocator; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; + public class WebSocketClientHandshaker13Test extends WebSocketClientHandshakerTest { @Override protected WebSocketClientHandshaker newHandshaker(URI uri, String subprotocol, HttpHeaders headers, boolean absoluteUpgradeUrl) { - return new WebSocketClientHandshaker13(uri, WebSocketVersion.V13, subprotocol, false, headers, - 1024, true, true, 10000, - absoluteUpgradeUrl); - } - - @Override - protected CharSequence getOriginHeaderName() { - return HttpHeaderNames.ORIGIN; + return new WebSocketClientHandshaker13(uri, subprotocol, false, headers, + 1024, true, true, 10000, + absoluteUpgradeUrl); } @Override @@ -43,11 +49,99 @@ protected CharSequence getProtocolHeaderName() { @Override protected CharSequence[] getHandshakeRequiredHeaderNames() { return new CharSequence[] { + HttpHeaderNames.HOST, HttpHeaderNames.UPGRADE, HttpHeaderNames.CONNECTION, HttpHeaderNames.SEC_WEBSOCKET_KEY, - HttpHeaderNames.HOST, HttpHeaderNames.SEC_WEBSOCKET_VERSION, }; } + + @Test + void testWebSocketClientInvalidUpgrade() { + var handshaker = newHandshaker(URI.create("ws://localhost:9999/ws"), null, + null, false); + var response = websocketUpgradeResponse(); + response.headers().remove(HttpHeaderNames.UPGRADE); + + final WebSocketClientHandshakeException exception; + try (response) { + exception = assertThrows(WebSocketClientHandshakeException.class, + () -> handshaker.finishHandshake(null, response)); + } + + assertEquals("Invalid handshake response upgrade: null", exception.getMessage()); + assertNotNull(exception.response()); + assertEquals(response.headers(), exception.response().headers()); + } + + @Test + void testWebSocketClientInvalidConnection() { + var handshaker = newHandshaker(URI.create("ws://localhost:9999/ws"), null, + null, false); + var response = websocketUpgradeResponse(); + response.headers().set(HttpHeaderNames.CONNECTION, "Close"); + + final WebSocketClientHandshakeException exception; + try (response) { + exception = assertThrows(WebSocketClientHandshakeException.class, + () -> handshaker.finishHandshake(null, response)); + } + + assertEquals("Invalid handshake response connection: Close", exception.getMessage()); + assertNotNull(exception.response()); + assertEquals(response.headers(), exception.response().headers()); + } + + @Test + void testWebSocketClientInvalidNullAccept() { + var handshaker = newHandshaker(URI.create("ws://localhost:9999/ws"), null, + null, false); + var response = websocketUpgradeResponse(); + + final WebSocketClientHandshakeException exception; + try (response) { + exception = assertThrows(WebSocketClientHandshakeException.class, + () -> handshaker.finishHandshake(null, response)); + } + + assertEquals("Invalid handshake response sec-websocket-accept: null", exception.getMessage()); + assertNotNull(exception.response()); + assertEquals(response.headers(), exception.response().headers()); + } + + @Test + void testWebSocketClientInvalidExpectedAccept() { + var handshaker = newHandshaker(URI.create("ws://localhost:9999/ws"), null, + null, false); + final String sentNonce; + try (var request = handshaker.newHandshakeRequest(preferredAllocator())) { + sentNonce = request.headers().get(HttpHeaderNames.SEC_WEBSOCKET_KEY); + } + + String fakeAccept = WebSocketUtil.base64(WebSocketUtil.randomBytes(16)); + var response = websocketUpgradeResponse(); + response.headers().set(HttpHeaderNames.SEC_WEBSOCKET_ACCEPT, fakeAccept); + + final WebSocketClientHandshakeException exception; + try (response) { + exception = assertThrows(WebSocketClientHandshakeException.class, + () -> handshaker.finishHandshake(null, response)); + } + + String expectedAccept = WebSocketUtil.calculateV13Accept(sentNonce); + assertEquals("Invalid handshake response sec-websocket-accept: " + fakeAccept + ", expected: " + + expectedAccept, exception.getMessage()); + assertNotNull(exception.response()); + assertEquals(response.headers(), exception.response().headers()); + } + + private static FullHttpResponse websocketUpgradeResponse() { + var response = new DefaultFullHttpResponse( + HttpVersion.HTTP_1_1, HttpResponseStatus.SWITCHING_PROTOCOLS, preferredAllocator().allocate(0)); + response.headers() + .set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE) + .set(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET); + return response; + } } diff --git a/codec-http/src/test/java/io/netty5/handler/codec/http/websocketx/WebSocketClientHandshakerTest.java b/codec-http/src/test/java/io/netty5/handler/codec/http/websocketx/WebSocketClientHandshakerTest.java index a93015f6364..7abb66f4db3 100644 --- a/codec-http/src/test/java/io/netty5/handler/codec/http/websocketx/WebSocketClientHandshakerTest.java +++ b/codec-http/src/test/java/io/netty5/handler/codec/http/websocketx/WebSocketClientHandshakerTest.java @@ -36,6 +36,7 @@ import io.netty5.handler.codec.http.HttpResponseStatus; import io.netty5.handler.codec.http.HttpVersion; import io.netty5.util.CharsetUtil; +import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; @@ -45,6 +46,8 @@ import static io.netty5.buffer.api.DefaultBufferAllocators.preferredAllocator; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertTrue; public abstract class WebSocketClientHandshakerTest { @@ -55,16 +58,14 @@ protected WebSocketClientHandshaker newHandshaker(URI uri) { return newHandshaker(uri, null, null, false); } - protected abstract CharSequence getOriginHeaderName(); - protected abstract CharSequence getProtocolHeaderName(); protected abstract CharSequence[] getHandshakeRequiredHeaderNames(); @Test - public void hostHeaderWs() { - for (String scheme : new String[]{"ws://", "http://"}) { - for (String host : new String[]{"localhost", "127.0.0.1", "[::1]", "Netty.io"}) { + void hostHeaderWs() { + for (String scheme : new String[] { "ws://", "http://" }) { + for (String host : new String[] { "localhost", "127.0.0.1", "[::1]", "Netty.io" }) { String enter = scheme + host; testHostHeader(enter, host); @@ -81,9 +82,9 @@ public void hostHeaderWs() { } @Test - public void hostHeaderWss() { - for (String scheme : new String[]{"wss://", "https://"}) { - for (String host : new String[]{"localhost", "127.0.0.1", "[::1]", "Netty.io"}) { + void hostHeaderWss() { + for (String scheme : new String[] { "wss://", "https://" }) { + for (String host : new String[] { "localhost", "127.0.0.1", "[::1]", "Netty.io" }) { String enter = scheme + host; testHostHeader(enter, host); @@ -100,7 +101,7 @@ public void hostHeaderWss() { } @Test - public void hostHeaderWithoutScheme() { + void hostHeaderWithoutScheme() { testHostHeader("//localhost/", "localhost"); testHostHeader("//localhost/path", "localhost"); testHostHeader("//localhost:80/", "localhost:80"); @@ -109,93 +110,7 @@ public void hostHeaderWithoutScheme() { } @Test - public void originHeaderWs() { - for (String scheme : new String[]{"ws://", "http://"}) { - for (String host : new String[]{"localhost", "127.0.0.1", "[::1]", "NETTY.IO"}) { - String enter = scheme + host; - String expect = "http://" + host.toLowerCase(); - - testOriginHeader(enter, expect); - testOriginHeader(enter + '/', expect); - testOriginHeader(enter + ":80", expect); - testOriginHeader(enter + ":443", expect + ":443"); - testOriginHeader(enter + ":9999", expect + ":9999"); - testOriginHeader(enter + "/path%20with%20ws", expect); - testOriginHeader(enter + ":80/path%20with%20ws", expect); - testOriginHeader(enter + ":443/path%20with%20ws", expect + ":443"); - testOriginHeader(enter + ":9999/path%20with%20ws", expect + ":9999"); - } - } - } - - @Test - public void originHeaderWss() { - for (String scheme : new String[]{"wss://", "https://"}) { - for (String host : new String[]{"localhost", "127.0.0.1", "[::1]", "NETTY.IO"}) { - String enter = scheme + host; - String expect = "https://" + host.toLowerCase(); - - testOriginHeader(enter, expect); - testOriginHeader(enter + '/', expect); - testOriginHeader(enter + ":80", expect + ":80"); - testOriginHeader(enter + ":443", expect); - testOriginHeader(enter + ":9999", expect + ":9999"); - testOriginHeader(enter + "/path%20with%20ws", expect); - testOriginHeader(enter + ":80/path%20with%20ws", expect + ":80"); - testOriginHeader(enter + ":443/path%20with%20ws", expect); - testOriginHeader(enter + ":9999/path%20with%20ws", expect + ":9999"); - } - } - } - - @Test - public void originHeaderWithoutScheme() { - testOriginHeader("//localhost/", "http://localhost"); - testOriginHeader("//localhost/path", "http://localhost"); - - // http scheme by port - testOriginHeader("//localhost:80/", "http://localhost"); - testOriginHeader("//localhost:80/path", "http://localhost"); - - // https scheme by port - testOriginHeader("//localhost:443/", "https://localhost"); - testOriginHeader("//localhost:443/path", "https://localhost"); - - // http scheme for non standard port - testOriginHeader("//localhost:9999/", "http://localhost:9999"); - testOriginHeader("//localhost:9999/path", "http://localhost:9999"); - - // convert host to lower case - testOriginHeader("//LOCALHOST/", "http://localhost"); - } - - @Test - public void testSetOriginFromCustomHeaders() { - HttpHeaders customHeaders = new DefaultHttpHeaders().set(getOriginHeaderName(), "http://example.com"); - WebSocketClientHandshaker handshaker = newHandshaker(URI.create("ws://server.example.com/chat"), null, - customHeaders, false); - try (FullHttpRequest request = handshaker.newHandshakeRequest(preferredAllocator())) { - assertEquals("http://example.com", request.headers().get(getOriginHeaderName())); - } - } - - private void testHostHeader(String uri, String expected) { - testHeaderDefaultHttp(uri, HttpHeaderNames.HOST, expected); - } - - private void testOriginHeader(String uri, String expected) { - testHeaderDefaultHttp(uri, getOriginHeaderName(), expected); - } - - protected void testHeaderDefaultHttp(String uri, CharSequence header, String expectedValue) { - WebSocketClientHandshaker handshaker = newHandshaker(URI.create(uri)); - try (FullHttpRequest request = handshaker.newHandshakeRequest(preferredAllocator())) { - assertEquals(expectedValue, request.headers().get(header)); - } - } - - @Test - public void testUpgradeUrl() { + void testUpgradeUrl() { URI uri = URI.create("ws://localhost:9999/path%20with%20ws"); WebSocketClientHandshaker handshaker = newHandshaker(uri); try (FullHttpRequest request = handshaker.newHandshakeRequest(preferredAllocator())) { @@ -204,7 +119,7 @@ public void testUpgradeUrl() { } @Test - public void testUpgradeUrlWithQuery() { + void testUpgradeUrlWithQuery() { URI uri = URI.create("ws://localhost:9999/path%20with%20ws?a=b%20c"); WebSocketClientHandshaker handshaker = newHandshaker(uri); try (FullHttpRequest request = handshaker.newHandshakeRequest(preferredAllocator())) { @@ -213,7 +128,7 @@ public void testUpgradeUrlWithQuery() { } @Test - public void testUpgradeUrlWithoutPath() { + void testUpgradeUrlWithoutPath() { URI uri = URI.create("ws://localhost:9999"); WebSocketClientHandshaker handshaker = newHandshaker(uri); try (FullHttpRequest request = handshaker.newHandshakeRequest(preferredAllocator())) { @@ -222,7 +137,7 @@ public void testUpgradeUrlWithoutPath() { } @Test - public void testUpgradeUrlWithoutPathWithQuery() { + void testUpgradeUrlWithoutPathWithQuery() { URI uri = URI.create("ws://localhost:9999?a=b%20c"); WebSocketClientHandshaker handshaker = newHandshaker(uri); try (FullHttpRequest request = handshaker.newHandshakeRequest(preferredAllocator())) { @@ -231,7 +146,7 @@ public void testUpgradeUrlWithoutPathWithQuery() { } @Test - public void testAbsoluteUpgradeUrlWithQuery() { + void testAbsoluteUpgradeUrlWithQuery() { URI uri = URI.create("ws://localhost:9999/path%20with%20ws?a=b%20c"); WebSocketClientHandshaker handshaker = newHandshaker(uri, null, null, true); try (FullHttpRequest request = handshaker.newHandshakeRequest(preferredAllocator())) { @@ -241,13 +156,13 @@ public void testAbsoluteUpgradeUrlWithQuery() { @Test @Timeout(value = 3000, unit = TimeUnit.MILLISECONDS) - public void testHttpResponseAndFrameInSameBuffer() { + void testHttpResponseAndFrameInSameBuffer() { testHttpResponseAndFrameInSameBuffer(false); } @Test @Timeout(value = 3000, unit = TimeUnit.MILLISECONDS) - public void testHttpResponseAndFrameInSameBufferCodec() { + void testHttpResponseAndFrameInSameBufferCodec() { testHttpResponseAndFrameInSameBuffer(true); } @@ -288,11 +203,11 @@ protected WebSocketFrameEncoder newWebSocketEncoder() { WebSocketServerHandshaker socketServerHandshaker = factory.newHandshaker(request); request.close(); EmbeddedChannel websocketChannel = new EmbeddedChannel(socketServerHandshaker.newWebSocketEncoder(), - socketServerHandshaker.newWebsocketDecoder()); + socketServerHandshaker.newWebsocketDecoder()); assertTrue(websocketChannel.writeOutbound( new BinaryWebSocketFrame(websocketChannel.bufferAllocator().copyOf(data)))); - byte[] bytes = "HTTP/1.1 101 Switching Protocols\r\nContent-Length: 0\r\n\r\n".getBytes(CharsetUtil.US_ASCII); + byte[] bytes = "HTTP/1.1 101 Switching Protocols\r\n\r\n".getBytes(CharsetUtil.US_ASCII); CompositeBuffer compositeBuffer = CompositeBuffer.compose(websocketChannel.bufferAllocator()); compositeBuffer.extendWith(websocketChannel.bufferAllocator().allocate(bytes.length).writeBytes(bytes).send()); @@ -305,13 +220,14 @@ protected WebSocketFrameEncoder newWebSocketEncoder() { } EmbeddedChannel ch = new EmbeddedChannel(new HttpObjectAggregator(Integer.MAX_VALUE), - new SimpleChannelInboundHandler() { - @Override - protected void messageReceived(ChannelHandlerContext ctx, FullHttpResponse msg) { - handshaker.finishHandshake(ctx.channel(), msg); - ctx.pipeline().remove(this); - } - }); + new SimpleChannelInboundHandler() { + @Override + protected void messageReceived(ChannelHandlerContext ctx, + FullHttpResponse msg) { + handshaker.finishHandshake(ctx.channel(), msg); + ctx.pipeline().remove(this); + } + }); if (codec) { ch.pipeline().addFirst(new HttpClientCodec()); } else { @@ -340,7 +256,7 @@ protected void messageReceived(ChannelHandlerContext ctx, FullHttpResponse msg) } @Test - public void testDuplicateWebsocketHandshakeHeaders() { + void testDuplicateWebsocketHandshakeHeaders() { URI uri = URI.create("ws://localhost:9999/foo"); HttpHeaders inputHeaders = new DefaultHttpHeaders(); @@ -374,21 +290,60 @@ public void testDuplicateWebsocketHandshakeHeaders() { } @Test - public void testWebSocketClientHandshakeException() { - URI uri = URI.create("ws://localhost:9999/exception"); - WebSocketClientHandshaker handshaker = newHandshaker(uri, null, null, false); - FullHttpResponse response = new DefaultFullHttpResponse( + void testSetHostHeaderIfNoPresentInCustomHeaders() { + var customHeaders = new DefaultHttpHeaders(); + customHeaders.set(HttpHeaderNames.HOST, "custom-host"); + var handshaker = newHandshaker(URI.create("ws://localhost:9999/ws"), null, + customHeaders, false); + try (var request = handshaker.newHandshakeRequest(preferredAllocator())) { + assertEquals("custom-host", request.headers().get(HttpHeaderNames.HOST)); + } + } + + @Test + void testNoOriginHeaderInHandshakeRequest() { + var handshaker = newHandshaker(URI.create("ws://localhost:9999/ws"), null, + null, false); + try (var request = handshaker.newHandshakeRequest(preferredAllocator())) { + assertNull(request.headers().get(HttpHeaderNames.ORIGIN)); + } + } + + @Test + void testSetOriginFromCustomHeaders() { + var customHeaders = new DefaultHttpHeaders().set(HttpHeaderNames.ORIGIN, "http://example.com"); + var handshaker = newHandshaker(URI.create("ws://server.example.com/chat"), null, + customHeaders, false); + try (var request = handshaker.newHandshakeRequest(preferredAllocator())) { + assertEquals("http://example.com", request.headers().get(HttpHeaderNames.ORIGIN)); + } + } + + @Test + void testWebSocketClientHandshakeExceptionContainsResponse() { + var handshaker = newHandshaker(URI.create("ws://localhost:9999/ws"), null, + null, false); + var response = new DefaultFullHttpResponse( HttpVersion.HTTP_1_1, HttpResponseStatus.UNAUTHORIZED, preferredAllocator().allocate(0)); + response.headers().set(HttpHeaderNames.WWW_AUTHENTICATE, "realm = access token required"); + final WebSocketClientHandshakeException exception; try (response) { - response.headers().set(HttpHeaderNames.WWW_AUTHENTICATE, "realm = access token required"); - handshaker.finishHandshake(null, response); - } catch (WebSocketClientHandshakeException exception) { - assertEquals("Invalid handshake response getStatus: 401 Unauthorized", exception.getMessage()); - assertEquals(HttpResponseStatus.UNAUTHORIZED, exception.response().status()); - assertTrue(exception.response().headers().contains(HttpHeaderNames.WWW_AUTHENTICATE, - "realm = access token required", false)); + exception = Assertions.assertThrows(WebSocketClientHandshakeException.class, + () -> handshaker.finishHandshake(null, response)); } + assertEquals("Invalid handshake response status: 401 Unauthorized", exception.getMessage()); + assertNotNull(exception.response()); + assertEquals(HttpResponseStatus.UNAUTHORIZED, exception.response().status()); + assertEquals(1, exception.response().headers().size()); + assertTrue(exception.response().headers().contains(HttpHeaderNames.WWW_AUTHENTICATE, + "realm = access token required", false)); } -} + private void testHostHeader(String uri, String expected) { + var handshaker = newHandshaker(URI.create(uri)); + try (var request = handshaker.newHandshakeRequest(preferredAllocator())) { + assertEquals(expected, request.headers().get(HttpHeaderNames.HOST)); + } + } +} diff --git a/codec-http/src/test/java/io/netty5/handler/codec/http/websocketx/WebSocketHandshakeHandOverTest.java b/codec-http/src/test/java/io/netty5/handler/codec/http/websocketx/WebSocketHandshakeHandOverTest.java index 44c5dd53873..7468bd4b6f0 100644 --- a/codec-http/src/test/java/io/netty5/handler/codec/http/websocketx/WebSocketHandshakeHandOverTest.java +++ b/codec-http/src/test/java/io/netty5/handler/codec/http/websocketx/WebSocketHandshakeHandOverTest.java @@ -119,7 +119,7 @@ protected void messageReceived(ChannelHandlerContext ctx, Object msg) { assertTrue(serverReceivedHandshake); assertNotNull(serverHandshakeComplete); assertEquals("/test", serverHandshakeComplete.requestUri()); - assertEquals(8, serverHandshakeComplete.requestHeaders().size()); + assertEquals(7, serverHandshakeComplete.requestHeaders().size()); assertEquals("test-proto-2", serverHandshakeComplete.selectedSubprotocol()); // Transfer the handshake response and the websocket message to the client diff --git a/codec-http/src/test/java/io/netty5/handler/codec/http/websocketx/WebSocketUtilTest.java b/codec-http/src/test/java/io/netty5/handler/codec/http/websocketx/WebSocketUtilTest.java index f41206e6e9a..3dae2af4abe 100644 --- a/codec-http/src/test/java/io/netty5/handler/codec/http/websocketx/WebSocketUtilTest.java +++ b/codec-http/src/test/java/io/netty5/handler/codec/http/websocketx/WebSocketUtilTest.java @@ -16,30 +16,48 @@ package io.netty5.handler.codec.http.websocketx; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import java.nio.charset.StandardCharsets; +import java.security.MessageDigest; +import java.util.Arrays; +import java.util.Base64; import java.util.concurrent.ThreadLocalRandom; -import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; public class WebSocketUtilTest { - // how many times do we want to run each random variable checker - private static final int NUM_ITERATIONS = 1000; - - private static void assertRandomWithinBoundaries(int min, int max) { - int r = ThreadLocalRandom.current().nextInt(min, max + 1); - assertTrue(min <= r && r <= max); + @ParameterizedTest + @ValueSource(ints = {2, 4, 8, 16, 24, 48, 64, 128 }) + void testRandomBytes(int size) { + var random1 = WebSocketUtil.randomBytes(size); + var random2 = WebSocketUtil.randomBytes(size); + assertEquals(size, random1.length); + assertEquals(size, random2.length); + assertFalse(Arrays.equals(random1, random2)); } @Test - public void testRandomNumberGenerator() { - int iteration = 0; - while (++iteration < NUM_ITERATIONS) { - assertRandomWithinBoundaries(0, 1); - assertRandomWithinBoundaries(0, 1); - assertRandomWithinBoundaries(-1, 1); - assertRandomWithinBoundaries(-1, 0); - } + void testBase64() { + var random = WebSocketUtil.randomBytes(8); + String expected = Base64.getEncoder().encodeToString(random); + assertEquals(expected, WebSocketUtil.base64(random)); } + @Test + void testCalculateV13Accept() throws Exception { + var random = new byte[16]; + ThreadLocalRandom.current().nextBytes(random); + String nonce = Base64.getEncoder().encodeToString(random); + + String concat = nonce + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + MessageDigest sha1Digest = MessageDigest.getInstance("SHA-1"); + byte[] sha1 = sha1Digest.digest(concat.getBytes(StandardCharsets.US_ASCII)); + String expectedAccept = Base64.getEncoder().encodeToString(sha1); + + assertEquals(expectedAccept, WebSocketUtil.calculateV13Accept(nonce)); + } }