diff --git a/codec-http/src/main/java/io/netty5/handler/codec/http/websocketx/WebSocketClientHandshaker.java b/codec-http/src/main/java/io/netty5/handler/codec/http/websocketx/WebSocketClientHandshaker.java
index bcf71950021..62f30105636 100644
--- a/codec-http/src/main/java/io/netty5/handler/codec/http/websocketx/WebSocketClientHandshaker.java
+++ b/codec-http/src/main/java/io/netty5/handler/codec/http/websocketx/WebSocketClientHandshaker.java
@@ -40,7 +40,6 @@
import java.net.URI;
import java.nio.channels.ClosedChannelException;
-import java.util.Locale;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;
@@ -51,8 +50,6 @@
*/
public abstract class WebSocketClientHandshaker {
- private static final String HTTP_SCHEME_PREFIX = HttpScheme.HTTP + "://";
- private static final String HTTPS_SCHEME_PREFIX = HttpScheme.HTTPS + "://";
protected static final int DEFAULT_FORCE_CLOSE_TIMEOUT_MILLIS = 10000;
private final URI uri;
@@ -548,31 +545,4 @@ static CharSequence websocketHostValue(URI wsURL) {
// See https://tools.ietf.org/html/rfc6454#section-6.2
return NetUtil.toSocketAddressString(host, port);
}
-
- static CharSequence websocketOriginValue(URI wsURL) {
- String scheme = wsURL.getScheme();
- final String schemePrefix;
- int port = wsURL.getPort();
- final int defaultPort;
- if (WebSocketScheme.WSS.name().contentEquals(scheme)
- || HttpScheme.HTTPS.name().contentEquals(scheme)
- || (scheme == null && port == WebSocketScheme.WSS.port())) {
-
- schemePrefix = HTTPS_SCHEME_PREFIX;
- defaultPort = WebSocketScheme.WSS.port();
- } else {
- schemePrefix = HTTP_SCHEME_PREFIX;
- defaultPort = WebSocketScheme.WS.port();
- }
-
- // Convert uri-host to lower case (by RFC 6454, chapter 4 "Origin of a URI")
- String host = wsURL.getHost().toLowerCase(Locale.US);
-
- if (port != defaultPort && port != -1) {
- // if the port is not standard (80/443) its needed to add the port to the header.
- // See https://tools.ietf.org/html/rfc6454#section-6.2
- return schemePrefix + NetUtil.toSocketAddressString(host, port);
- }
- return schemePrefix + host;
- }
}
diff --git a/codec-http/src/main/java/io/netty5/handler/codec/http/websocketx/WebSocketClientHandshaker13.java b/codec-http/src/main/java/io/netty5/handler/codec/http/websocketx/WebSocketClientHandshaker13.java
index c4585872456..9fbbf62c679 100644
--- a/codec-http/src/main/java/io/netty5/handler/codec/http/websocketx/WebSocketClientHandshaker13.java
+++ b/codec-http/src/main/java/io/netty5/handler/codec/http/websocketx/WebSocketClientHandshaker13.java
@@ -25,39 +25,30 @@
import io.netty5.handler.codec.http.HttpMethod;
import io.netty5.handler.codec.http.HttpResponseStatus;
import io.netty5.handler.codec.http.HttpVersion;
-import io.netty5.util.CharsetUtil;
-import io.netty5.util.internal.logging.InternalLogger;
-import io.netty5.util.internal.logging.InternalLoggerFactory;
+import io.netty5.util.internal.StringUtil;
import java.net.URI;
/**
*
- * Performs client side opening and closing handshakes for web socket specification version draft-ietf-hybi-thewebsocketprotocol-
- * 17
+ * Performs client side opening and closing handshakes for web socket specification version
+ * websocketprotocol-v13
*
*/
public class WebSocketClientHandshaker13 extends WebSocketClientHandshaker {
- private static final InternalLogger logger = InternalLoggerFactory.getInstance(WebSocketClientHandshaker13.class);
-
- public static final String MAGIC_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
-
- private String expectedChallengeResponseString;
-
private final boolean allowExtensions;
private final boolean performMasking;
private final boolean allowMaskMismatch;
+ private volatile String sentNonce;
+
/**
* Creates a new instance.
*
* @param webSocketURL
* URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
* sent to this URL.
- * @param version
- * Version of web socket specification to use to connect to the server
* @param subprotocol
* Sub protocol request sent to the server.
* @param allowExtensions
@@ -67,9 +58,9 @@ public class WebSocketClientHandshaker13 extends WebSocketClientHandshaker {
* @param maxFramePayloadLength
* Maximum length of a frame's payload
*/
- public WebSocketClientHandshaker13(URI webSocketURL, WebSocketVersion version, String subprotocol,
+ public WebSocketClientHandshaker13(URI webSocketURL, String subprotocol,
boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength) {
- this(webSocketURL, version, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength,
+ this(webSocketURL, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength,
true, false);
}
@@ -79,8 +70,6 @@ public WebSocketClientHandshaker13(URI webSocketURL, WebSocketVersion version, S
* @param webSocketURL
* URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
* sent to this URL.
- * @param version
- * Version of web socket specification to use to connect to the server
* @param subprotocol
* Sub protocol request sent to the server.
* @param allowExtensions
@@ -97,10 +86,10 @@ public WebSocketClientHandshaker13(URI webSocketURL, WebSocketVersion version, S
* When set to true, frames which are not masked properly according to the standard will still be
* accepted.
*/
- public WebSocketClientHandshaker13(URI webSocketURL, WebSocketVersion version, String subprotocol,
+ public WebSocketClientHandshaker13(URI webSocketURL, String subprotocol,
boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength,
boolean performMasking, boolean allowMaskMismatch) {
- this(webSocketURL, version, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength,
+ this(webSocketURL, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength,
performMasking, allowMaskMismatch, DEFAULT_FORCE_CLOSE_TIMEOUT_MILLIS);
}
@@ -110,8 +99,6 @@ public WebSocketClientHandshaker13(URI webSocketURL, WebSocketVersion version, S
* @param webSocketURL
* URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
* sent to this URL.
- * @param version
- * Version of web socket specification to use to connect to the server
* @param subprotocol
* Sub protocol request sent to the server.
* @param allowExtensions
@@ -130,11 +117,11 @@ public WebSocketClientHandshaker13(URI webSocketURL, WebSocketVersion version, S
* @param forceCloseTimeoutMillis
* Close the connection if it was not closed by the server after timeout specified.
*/
- public WebSocketClientHandshaker13(URI webSocketURL, WebSocketVersion version, String subprotocol,
+ public WebSocketClientHandshaker13(URI webSocketURL, String subprotocol,
boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength,
boolean performMasking, boolean allowMaskMismatch,
long forceCloseTimeoutMillis) {
- this(webSocketURL, version, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength, performMasking,
+ this(webSocketURL, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength, performMasking,
allowMaskMismatch, forceCloseTimeoutMillis, false);
}
@@ -144,8 +131,6 @@ public WebSocketClientHandshaker13(URI webSocketURL, WebSocketVersion version, S
* @param webSocketURL
* URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
* sent to this URL.
- * @param version
- * Version of web socket specification to use to connect to the server
* @param subprotocol
* Sub protocol request sent to the server.
* @param allowExtensions
@@ -167,12 +152,12 @@ public WebSocketClientHandshaker13(URI webSocketURL, WebSocketVersion version, S
* Use an absolute url for the Upgrade request, typically when connecting through an HTTP proxy over
* clear HTTP
*/
- WebSocketClientHandshaker13(URI webSocketURL, WebSocketVersion version, String subprotocol,
+ WebSocketClientHandshaker13(URI webSocketURL, String subprotocol,
boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength,
boolean performMasking, boolean allowMaskMismatch,
long forceCloseTimeoutMillis, boolean absoluteUpgradeUrl) {
- super(webSocketURL, version, subprotocol, customHeaders, maxFramePayloadLength, forceCloseTimeoutMillis,
- absoluteUpgradeUrl);
+ super(webSocketURL, WebSocketVersion.V13, subprotocol, customHeaders, maxFramePayloadLength,
+ forceCloseTimeoutMillis, absoluteUpgradeUrl);
this.allowExtensions = allowExtensions;
this.performMasking = performMasking;
this.allowMaskMismatch = allowMaskMismatch;
@@ -190,7 +175,6 @@ public WebSocketClientHandshaker13(URI webSocketURL, WebSocketVersion version, S
* Upgrade: websocket
* Connection: Upgrade
* Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
- * Origin: http://example.com
* Sec-WebSocket-Protocol: chat, superchat
* Sec-WebSocket-Version: 13
*
@@ -199,22 +183,6 @@ public WebSocketClientHandshaker13(URI webSocketURL, WebSocketVersion version, S
@Override
protected FullHttpRequest newHandshakeRequest(BufferAllocator allocator) {
URI wsURL = uri();
-
- // Get 16 bit nonce and base 64 encode it
- byte[] nonce = WebSocketUtil.randomBytes(16);
- String key = WebSocketUtil.base64(nonce);
-
- String acceptSeed = key + MAGIC_GUID;
- byte[] sha1 = WebSocketUtil.sha1(acceptSeed.getBytes(CharsetUtil.US_ASCII));
- expectedChallengeResponseString = WebSocketUtil.base64(sha1);
-
- if (logger.isDebugEnabled()) {
- logger.debug(
- "WebSocket version 13 client handshake key: {}, expected response: {}",
- key, expectedChallengeResponseString);
- }
-
- // Format request
FullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, upgradeUrl(wsURL),
allocator.allocate(0));
HttpHeaders headers = request.headers();
@@ -223,24 +191,21 @@ protected FullHttpRequest newHandshakeRequest(BufferAllocator allocator) {
headers.add(customHeaders);
if (!headers.contains(HttpHeaderNames.HOST)) {
// Only add HOST header if customHeaders did not contain it.
- //
- // See https://github.com/netty/netty/issues/10101
+ // See https://github.com/netty/netty/issues/10101.
headers.set(HttpHeaderNames.HOST, websocketHostValue(wsURL));
}
} else {
headers.set(HttpHeaderNames.HOST, websocketHostValue(wsURL));
}
+ String nonce = createNonce();
headers.set(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET)
.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE)
- .set(HttpHeaderNames.SEC_WEBSOCKET_KEY, key);
-
- if (!headers.contains(HttpHeaderNames.ORIGIN)) {
- headers.set(HttpHeaderNames.ORIGIN, websocketOriginValue(wsURL));
- }
+ .set(HttpHeaderNames.SEC_WEBSOCKET_KEY, nonce);
+ sentNonce = nonce;
String expectedSubprotocol = expectedSubprotocol();
- if (expectedSubprotocol != null && !expectedSubprotocol.isEmpty()) {
+ if (!StringUtil.isNullOrEmpty(expectedSubprotocol)) {
headers.set(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, expectedSubprotocol);
}
@@ -269,7 +234,7 @@ protected FullHttpRequest newHandshakeRequest(BufferAllocator allocator) {
protected void verify(FullHttpResponse response) {
HttpResponseStatus status = response.status();
if (!HttpResponseStatus.SWITCHING_PROTOCOLS.equals(status)) {
- throw new WebSocketClientHandshakeException("Invalid handshake response getStatus: " + status, response);
+ throw new WebSocketClientHandshakeException("Invalid handshake response status: " + status, response);
}
HttpHeaders headers = response.headers();
@@ -283,10 +248,16 @@ protected void verify(FullHttpResponse response) {
+ headers.get(HttpHeaderNames.CONNECTION), response);
}
- CharSequence accept = headers.get(HttpHeaderNames.SEC_WEBSOCKET_ACCEPT);
- if (accept == null || !accept.equals(expectedChallengeResponseString)) {
- throw new WebSocketClientHandshakeException(String.format(
- "Invalid challenge. Actual: %s. Expected: %s", accept, expectedChallengeResponseString), response);
+ String accept = headers.get(HttpHeaderNames.SEC_WEBSOCKET_ACCEPT);
+ if (accept == null) {
+ throw new WebSocketClientHandshakeException("Invalid handshake response sec-websocket-accept: null",
+ response);
+ }
+
+ String expectedAccept = WebSocketUtil.calculateV13Accept(sentNonce);
+ if (!expectedAccept.equals(accept.trim())) {
+ throw new WebSocketClientHandshakeException("Invalid handshake response sec-websocket-accept: " + accept +
+ ", expected: " + expectedAccept, response);
}
}
@@ -306,4 +277,12 @@ public WebSocketClientHandshaker13 setForceCloseTimeoutMillis(long forceCloseTim
return this;
}
+ /**
+ * Creates a nonce consisting of a randomly selected 16-byte value
+ * that has been base64-encoded.
+ */
+ private static String createNonce() {
+ var nonce = WebSocketUtil.randomBytes(16);
+ return WebSocketUtil.base64(nonce);
+ }
}
diff --git a/codec-http/src/main/java/io/netty5/handler/codec/http/websocketx/WebSocketClientHandshakerFactory.java b/codec-http/src/main/java/io/netty5/handler/codec/http/websocketx/WebSocketClientHandshakerFactory.java
index 152b57be752..29065020b3d 100644
--- a/codec-http/src/main/java/io/netty5/handler/codec/http/websocketx/WebSocketClientHandshakerFactory.java
+++ b/codec-http/src/main/java/io/netty5/handler/codec/http/websocketx/WebSocketClientHandshakerFactory.java
@@ -144,7 +144,7 @@ public static WebSocketClientHandshaker newHandshaker(
boolean performMasking, boolean allowMaskMismatch, long forceCloseTimeoutMillis) {
if (version == V13) {
return new WebSocketClientHandshaker13(
- webSocketURL, V13, subprotocol, allowExtensions, customHeaders,
+ webSocketURL, subprotocol, allowExtensions, customHeaders,
maxFramePayloadLength, performMasking, allowMaskMismatch, forceCloseTimeoutMillis);
}
@@ -187,7 +187,7 @@ public static WebSocketClientHandshaker newHandshaker(
boolean performMasking, boolean allowMaskMismatch, long forceCloseTimeoutMillis, boolean absoluteUpgradeUrl) {
if (version == V13) {
return new WebSocketClientHandshaker13(
- webSocketURL, V13, subprotocol, allowExtensions, customHeaders,
+ webSocketURL, subprotocol, allowExtensions, customHeaders,
maxFramePayloadLength, performMasking, allowMaskMismatch, forceCloseTimeoutMillis, absoluteUpgradeUrl);
}
diff --git a/codec-http/src/main/java/io/netty5/handler/codec/http/websocketx/WebSocketServerHandshaker13.java b/codec-http/src/main/java/io/netty5/handler/codec/http/websocketx/WebSocketServerHandshaker13.java
index 125172dff65..e70844b6d7f 100644
--- a/codec-http/src/main/java/io/netty5/handler/codec/http/websocketx/WebSocketServerHandshaker13.java
+++ b/codec-http/src/main/java/io/netty5/handler/codec/http/websocketx/WebSocketServerHandshaker13.java
@@ -23,9 +23,8 @@
import io.netty5.handler.codec.http.HttpHeaderValues;
import io.netty5.handler.codec.http.HttpHeaders;
import io.netty5.handler.codec.http.HttpResponseStatus;
-import io.netty5.util.CharsetUtil;
-import static io.netty5.handler.codec.http.HttpVersion.*;
+import static io.netty5.handler.codec.http.HttpVersion.HTTP_1_1;
/**
*
@@ -35,8 +34,6 @@
*/
public class WebSocketServerHandshaker13 extends WebSocketServerHandshaker {
- public static final String WEBSOCKET_13_ACCEPT_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
-
/**
* Constructor specifying the destination web socket location
*
@@ -136,7 +133,7 @@ public WebSocketServerHandshaker13(
@Override
protected FullHttpResponse newHandshakeResponse(BufferAllocator allocator, FullHttpRequest req,
HttpHeaders headers) {
- CharSequence key = req.headers().get(HttpHeaderNames.SEC_WEBSOCKET_KEY);
+ String key = req.headers().get(HttpHeaderNames.SEC_WEBSOCKET_KEY);
if (key == null) {
throw new WebSocketServerHandshakeException("not a WebSocket request: missing key", req);
}
@@ -147,14 +144,7 @@ protected FullHttpResponse newHandshakeResponse(BufferAllocator allocator, FullH
res.headers().add(headers);
}
- String acceptSeed = key + WEBSOCKET_13_ACCEPT_GUID;
- byte[] sha1 = WebSocketUtil.sha1(acceptSeed.getBytes(CharsetUtil.US_ASCII));
- String accept = WebSocketUtil.base64(sha1);
-
- if (logger.isDebugEnabled()) {
- logger.debug("WebSocket version 13 server handshake key: {}, response: {}", key, accept);
- }
-
+ String accept = WebSocketUtil.calculateV13Accept(key);
res.headers().set(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET)
.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE)
.set(HttpHeaderNames.SEC_WEBSOCKET_ACCEPT, accept);
diff --git a/codec-http/src/main/java/io/netty5/handler/codec/http/websocketx/WebSocketUtil.java b/codec-http/src/main/java/io/netty5/handler/codec/http/websocketx/WebSocketUtil.java
index 988eacd9a10..ebbb7a5104b 100644
--- a/codec-http/src/main/java/io/netty5/handler/codec/http/websocketx/WebSocketUtil.java
+++ b/codec-http/src/main/java/io/netty5/handler/codec/http/websocketx/WebSocketUtil.java
@@ -15,6 +15,7 @@
*/
package io.netty5.handler.codec.http.websocketx;
+import io.netty5.util.CharsetUtil;
import io.netty5.util.concurrent.FastThreadLocal;
import java.security.MessageDigest;
@@ -23,67 +24,67 @@
import java.util.concurrent.ThreadLocalRandom;
/**
- * A utility class mainly for use by web sockets
+ * A utility class mainly for use by web sockets.
*/
final class WebSocketUtil {
- private static final FastThreadLocal SHA1 = new FastThreadLocal() {
+ private static final String V13_ACCEPT_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
+
+ private static final FastThreadLocal SHA1 = new FastThreadLocal<>() {
@Override
protected MessageDigest initialValue() throws Exception {
try {
- //Try to get a MessageDigest that uses SHA1
- //Suppress a warning about weak hash algorithm
- //since it's defined in draft-ietf-hybi-thewebsocketprotocol-00
- return MessageDigest.getInstance("SHA1"); // lgtm [java/weak-cryptographic-algorithm]
+ // Try to get a MessageDigest that uses SHA1.
+ // Suppress a warning about weak hash algorithm
+ // since it's defined in https://datatracker.ietf.org/doc/html/rfc6455#section-10.8.
+ return MessageDigest.getInstance("SHA-1"); // lgtm [java/weak-cryptographic-algorithm]
} catch (NoSuchAlgorithmException e) {
- //This shouldn't happen! How old is the computer?
- throw new InternalError("SHA-1 not supported on this platform - Outdated?");
+ // This shouldn't happen! How old is the computer ?
+ throw new InternalError("SHA-1 not supported on this platform - Outdated ?");
}
}
};
/**
- * Performs a SHA-1 hash on the specified data
+ * Performs a SHA-1 hash on the specified data.
*
* @param data The data to hash
- * @return The hashed data
+ * @return the hashed data
*/
static byte[] sha1(byte[] data) {
- // TODO(normanmaurer): Create sha1 method that not need MessageDigest.
- return digest(SHA1, data);
- }
-
- private static byte[] digest(FastThreadLocal digestFastThreadLocal, byte[] data) {
- MessageDigest digest = digestFastThreadLocal.get();
- digest.reset();
- return digest.digest(data);
+ MessageDigest sha1Digest = SHA1.get();
+ sha1Digest.reset();
+ return sha1Digest.digest(data);
}
/**
- * Performs base64 encoding on the specified data
+ * Performs base64 encoding on the specified data.
*
* @param data The data to encode
- * @return An encoded string containing the data
+ * @return an encoded string containing the data
*/
static String base64(byte[] data) {
return Base64.getEncoder().encodeToString(data);
}
+
/**
- * Creates an arbitrary number of random bytes
+ * Creates an arbitrary number of random bytes.
*
* @param size the number of random bytes to create
- * @return An array of random bytes
+ * @return an array of random bytes
*/
static byte[] randomBytes(int size) {
- byte[] bytes = new byte[size];
+ var bytes = new byte[size];
ThreadLocalRandom.current().nextBytes(bytes);
return bytes;
}
- /**
- * A private constructor to ensure that instances of this class cannot be made
- */
+ static String calculateV13Accept(String nonce) {
+ String concat = nonce + V13_ACCEPT_GUID;
+ byte[] sha1 = WebSocketUtil.sha1(concat.getBytes(CharsetUtil.US_ASCII));
+ return WebSocketUtil.base64(sha1);
+ }
+
private WebSocketUtil() {
- // Unused
}
}
diff --git a/codec-http/src/test/java/io/netty5/handler/codec/http/websocketx/WebSocketClientHandshaker13Test.java b/codec-http/src/test/java/io/netty5/handler/codec/http/websocketx/WebSocketClientHandshaker13Test.java
index 59bd514c6f3..e177324e52f 100644
--- a/codec-http/src/test/java/io/netty5/handler/codec/http/websocketx/WebSocketClientHandshaker13Test.java
+++ b/codec-http/src/test/java/io/netty5/handler/codec/http/websocketx/WebSocketClientHandshaker13Test.java
@@ -15,24 +15,30 @@
*/
package io.netty5.handler.codec.http.websocketx;
+import io.netty5.handler.codec.http.DefaultFullHttpResponse;
+import io.netty5.handler.codec.http.FullHttpResponse;
import io.netty5.handler.codec.http.HttpHeaderNames;
+import io.netty5.handler.codec.http.HttpHeaderValues;
import io.netty5.handler.codec.http.HttpHeaders;
+import io.netty5.handler.codec.http.HttpResponseStatus;
+import io.netty5.handler.codec.http.HttpVersion;
+import org.junit.jupiter.api.Test;
import java.net.URI;
+import static io.netty5.buffer.api.DefaultBufferAllocators.preferredAllocator;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+
public class WebSocketClientHandshaker13Test extends WebSocketClientHandshakerTest {
@Override
protected WebSocketClientHandshaker newHandshaker(URI uri, String subprotocol, HttpHeaders headers,
boolean absoluteUpgradeUrl) {
- return new WebSocketClientHandshaker13(uri, WebSocketVersion.V13, subprotocol, false, headers,
- 1024, true, true, 10000,
- absoluteUpgradeUrl);
- }
-
- @Override
- protected CharSequence getOriginHeaderName() {
- return HttpHeaderNames.ORIGIN;
+ return new WebSocketClientHandshaker13(uri, subprotocol, false, headers,
+ 1024, true, true, 10000,
+ absoluteUpgradeUrl);
}
@Override
@@ -43,11 +49,99 @@ protected CharSequence getProtocolHeaderName() {
@Override
protected CharSequence[] getHandshakeRequiredHeaderNames() {
return new CharSequence[] {
+ HttpHeaderNames.HOST,
HttpHeaderNames.UPGRADE,
HttpHeaderNames.CONNECTION,
HttpHeaderNames.SEC_WEBSOCKET_KEY,
- HttpHeaderNames.HOST,
HttpHeaderNames.SEC_WEBSOCKET_VERSION,
};
}
+
+ @Test
+ void testWebSocketClientInvalidUpgrade() {
+ var handshaker = newHandshaker(URI.create("ws://localhost:9999/ws"), null,
+ null, false);
+ var response = websocketUpgradeResponse();
+ response.headers().remove(HttpHeaderNames.UPGRADE);
+
+ final WebSocketClientHandshakeException exception;
+ try (response) {
+ exception = assertThrows(WebSocketClientHandshakeException.class,
+ () -> handshaker.finishHandshake(null, response));
+ }
+
+ assertEquals("Invalid handshake response upgrade: null", exception.getMessage());
+ assertNotNull(exception.response());
+ assertEquals(response.headers(), exception.response().headers());
+ }
+
+ @Test
+ void testWebSocketClientInvalidConnection() {
+ var handshaker = newHandshaker(URI.create("ws://localhost:9999/ws"), null,
+ null, false);
+ var response = websocketUpgradeResponse();
+ response.headers().set(HttpHeaderNames.CONNECTION, "Close");
+
+ final WebSocketClientHandshakeException exception;
+ try (response) {
+ exception = assertThrows(WebSocketClientHandshakeException.class,
+ () -> handshaker.finishHandshake(null, response));
+ }
+
+ assertEquals("Invalid handshake response connection: Close", exception.getMessage());
+ assertNotNull(exception.response());
+ assertEquals(response.headers(), exception.response().headers());
+ }
+
+ @Test
+ void testWebSocketClientInvalidNullAccept() {
+ var handshaker = newHandshaker(URI.create("ws://localhost:9999/ws"), null,
+ null, false);
+ var response = websocketUpgradeResponse();
+
+ final WebSocketClientHandshakeException exception;
+ try (response) {
+ exception = assertThrows(WebSocketClientHandshakeException.class,
+ () -> handshaker.finishHandshake(null, response));
+ }
+
+ assertEquals("Invalid handshake response sec-websocket-accept: null", exception.getMessage());
+ assertNotNull(exception.response());
+ assertEquals(response.headers(), exception.response().headers());
+ }
+
+ @Test
+ void testWebSocketClientInvalidExpectedAccept() {
+ var handshaker = newHandshaker(URI.create("ws://localhost:9999/ws"), null,
+ null, false);
+ final String sentNonce;
+ try (var request = handshaker.newHandshakeRequest(preferredAllocator())) {
+ sentNonce = request.headers().get(HttpHeaderNames.SEC_WEBSOCKET_KEY);
+ }
+
+ String fakeAccept = WebSocketUtil.base64(WebSocketUtil.randomBytes(16));
+ var response = websocketUpgradeResponse();
+ response.headers().set(HttpHeaderNames.SEC_WEBSOCKET_ACCEPT, fakeAccept);
+
+ final WebSocketClientHandshakeException exception;
+ try (response) {
+ exception = assertThrows(WebSocketClientHandshakeException.class,
+ () -> handshaker.finishHandshake(null, response));
+ }
+
+ String expectedAccept = WebSocketUtil.calculateV13Accept(sentNonce);
+ assertEquals("Invalid handshake response sec-websocket-accept: " + fakeAccept + ", expected: "
+ + expectedAccept, exception.getMessage());
+ assertNotNull(exception.response());
+ assertEquals(response.headers(), exception.response().headers());
+ }
+
+ private static FullHttpResponse websocketUpgradeResponse() {
+ var response = new DefaultFullHttpResponse(
+ HttpVersion.HTTP_1_1, HttpResponseStatus.SWITCHING_PROTOCOLS, preferredAllocator().allocate(0));
+ response.headers()
+ .set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE)
+ .set(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET);
+ return response;
+ }
}
diff --git a/codec-http/src/test/java/io/netty5/handler/codec/http/websocketx/WebSocketClientHandshakerTest.java b/codec-http/src/test/java/io/netty5/handler/codec/http/websocketx/WebSocketClientHandshakerTest.java
index a93015f6364..7abb66f4db3 100644
--- a/codec-http/src/test/java/io/netty5/handler/codec/http/websocketx/WebSocketClientHandshakerTest.java
+++ b/codec-http/src/test/java/io/netty5/handler/codec/http/websocketx/WebSocketClientHandshakerTest.java
@@ -36,6 +36,7 @@
import io.netty5.handler.codec.http.HttpResponseStatus;
import io.netty5.handler.codec.http.HttpVersion;
import io.netty5.util.CharsetUtil;
+import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
@@ -45,6 +46,8 @@
import static io.netty5.buffer.api.DefaultBufferAllocators.preferredAllocator;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
public abstract class WebSocketClientHandshakerTest {
@@ -55,16 +58,14 @@ protected WebSocketClientHandshaker newHandshaker(URI uri) {
return newHandshaker(uri, null, null, false);
}
- protected abstract CharSequence getOriginHeaderName();
-
protected abstract CharSequence getProtocolHeaderName();
protected abstract CharSequence[] getHandshakeRequiredHeaderNames();
@Test
- public void hostHeaderWs() {
- for (String scheme : new String[]{"ws://", "http://"}) {
- for (String host : new String[]{"localhost", "127.0.0.1", "[::1]", "Netty.io"}) {
+ void hostHeaderWs() {
+ for (String scheme : new String[] { "ws://", "http://" }) {
+ for (String host : new String[] { "localhost", "127.0.0.1", "[::1]", "Netty.io" }) {
String enter = scheme + host;
testHostHeader(enter, host);
@@ -81,9 +82,9 @@ public void hostHeaderWs() {
}
@Test
- public void hostHeaderWss() {
- for (String scheme : new String[]{"wss://", "https://"}) {
- for (String host : new String[]{"localhost", "127.0.0.1", "[::1]", "Netty.io"}) {
+ void hostHeaderWss() {
+ for (String scheme : new String[] { "wss://", "https://" }) {
+ for (String host : new String[] { "localhost", "127.0.0.1", "[::1]", "Netty.io" }) {
String enter = scheme + host;
testHostHeader(enter, host);
@@ -100,7 +101,7 @@ public void hostHeaderWss() {
}
@Test
- public void hostHeaderWithoutScheme() {
+ void hostHeaderWithoutScheme() {
testHostHeader("//localhost/", "localhost");
testHostHeader("//localhost/path", "localhost");
testHostHeader("//localhost:80/", "localhost:80");
@@ -109,93 +110,7 @@ public void hostHeaderWithoutScheme() {
}
@Test
- public void originHeaderWs() {
- for (String scheme : new String[]{"ws://", "http://"}) {
- for (String host : new String[]{"localhost", "127.0.0.1", "[::1]", "NETTY.IO"}) {
- String enter = scheme + host;
- String expect = "http://" + host.toLowerCase();
-
- testOriginHeader(enter, expect);
- testOriginHeader(enter + '/', expect);
- testOriginHeader(enter + ":80", expect);
- testOriginHeader(enter + ":443", expect + ":443");
- testOriginHeader(enter + ":9999", expect + ":9999");
- testOriginHeader(enter + "/path%20with%20ws", expect);
- testOriginHeader(enter + ":80/path%20with%20ws", expect);
- testOriginHeader(enter + ":443/path%20with%20ws", expect + ":443");
- testOriginHeader(enter + ":9999/path%20with%20ws", expect + ":9999");
- }
- }
- }
-
- @Test
- public void originHeaderWss() {
- for (String scheme : new String[]{"wss://", "https://"}) {
- for (String host : new String[]{"localhost", "127.0.0.1", "[::1]", "NETTY.IO"}) {
- String enter = scheme + host;
- String expect = "https://" + host.toLowerCase();
-
- testOriginHeader(enter, expect);
- testOriginHeader(enter + '/', expect);
- testOriginHeader(enter + ":80", expect + ":80");
- testOriginHeader(enter + ":443", expect);
- testOriginHeader(enter + ":9999", expect + ":9999");
- testOriginHeader(enter + "/path%20with%20ws", expect);
- testOriginHeader(enter + ":80/path%20with%20ws", expect + ":80");
- testOriginHeader(enter + ":443/path%20with%20ws", expect);
- testOriginHeader(enter + ":9999/path%20with%20ws", expect + ":9999");
- }
- }
- }
-
- @Test
- public void originHeaderWithoutScheme() {
- testOriginHeader("//localhost/", "http://localhost");
- testOriginHeader("//localhost/path", "http://localhost");
-
- // http scheme by port
- testOriginHeader("//localhost:80/", "http://localhost");
- testOriginHeader("//localhost:80/path", "http://localhost");
-
- // https scheme by port
- testOriginHeader("//localhost:443/", "https://localhost");
- testOriginHeader("//localhost:443/path", "https://localhost");
-
- // http scheme for non standard port
- testOriginHeader("//localhost:9999/", "http://localhost:9999");
- testOriginHeader("//localhost:9999/path", "http://localhost:9999");
-
- // convert host to lower case
- testOriginHeader("//LOCALHOST/", "http://localhost");
- }
-
- @Test
- public void testSetOriginFromCustomHeaders() {
- HttpHeaders customHeaders = new DefaultHttpHeaders().set(getOriginHeaderName(), "http://example.com");
- WebSocketClientHandshaker handshaker = newHandshaker(URI.create("ws://server.example.com/chat"), null,
- customHeaders, false);
- try (FullHttpRequest request = handshaker.newHandshakeRequest(preferredAllocator())) {
- assertEquals("http://example.com", request.headers().get(getOriginHeaderName()));
- }
- }
-
- private void testHostHeader(String uri, String expected) {
- testHeaderDefaultHttp(uri, HttpHeaderNames.HOST, expected);
- }
-
- private void testOriginHeader(String uri, String expected) {
- testHeaderDefaultHttp(uri, getOriginHeaderName(), expected);
- }
-
- protected void testHeaderDefaultHttp(String uri, CharSequence header, String expectedValue) {
- WebSocketClientHandshaker handshaker = newHandshaker(URI.create(uri));
- try (FullHttpRequest request = handshaker.newHandshakeRequest(preferredAllocator())) {
- assertEquals(expectedValue, request.headers().get(header));
- }
- }
-
- @Test
- public void testUpgradeUrl() {
+ void testUpgradeUrl() {
URI uri = URI.create("ws://localhost:9999/path%20with%20ws");
WebSocketClientHandshaker handshaker = newHandshaker(uri);
try (FullHttpRequest request = handshaker.newHandshakeRequest(preferredAllocator())) {
@@ -204,7 +119,7 @@ public void testUpgradeUrl() {
}
@Test
- public void testUpgradeUrlWithQuery() {
+ void testUpgradeUrlWithQuery() {
URI uri = URI.create("ws://localhost:9999/path%20with%20ws?a=b%20c");
WebSocketClientHandshaker handshaker = newHandshaker(uri);
try (FullHttpRequest request = handshaker.newHandshakeRequest(preferredAllocator())) {
@@ -213,7 +128,7 @@ public void testUpgradeUrlWithQuery() {
}
@Test
- public void testUpgradeUrlWithoutPath() {
+ void testUpgradeUrlWithoutPath() {
URI uri = URI.create("ws://localhost:9999");
WebSocketClientHandshaker handshaker = newHandshaker(uri);
try (FullHttpRequest request = handshaker.newHandshakeRequest(preferredAllocator())) {
@@ -222,7 +137,7 @@ public void testUpgradeUrlWithoutPath() {
}
@Test
- public void testUpgradeUrlWithoutPathWithQuery() {
+ void testUpgradeUrlWithoutPathWithQuery() {
URI uri = URI.create("ws://localhost:9999?a=b%20c");
WebSocketClientHandshaker handshaker = newHandshaker(uri);
try (FullHttpRequest request = handshaker.newHandshakeRequest(preferredAllocator())) {
@@ -231,7 +146,7 @@ public void testUpgradeUrlWithoutPathWithQuery() {
}
@Test
- public void testAbsoluteUpgradeUrlWithQuery() {
+ void testAbsoluteUpgradeUrlWithQuery() {
URI uri = URI.create("ws://localhost:9999/path%20with%20ws?a=b%20c");
WebSocketClientHandshaker handshaker = newHandshaker(uri, null, null, true);
try (FullHttpRequest request = handshaker.newHandshakeRequest(preferredAllocator())) {
@@ -241,13 +156,13 @@ public void testAbsoluteUpgradeUrlWithQuery() {
@Test
@Timeout(value = 3000, unit = TimeUnit.MILLISECONDS)
- public void testHttpResponseAndFrameInSameBuffer() {
+ void testHttpResponseAndFrameInSameBuffer() {
testHttpResponseAndFrameInSameBuffer(false);
}
@Test
@Timeout(value = 3000, unit = TimeUnit.MILLISECONDS)
- public void testHttpResponseAndFrameInSameBufferCodec() {
+ void testHttpResponseAndFrameInSameBufferCodec() {
testHttpResponseAndFrameInSameBuffer(true);
}
@@ -288,11 +203,11 @@ protected WebSocketFrameEncoder newWebSocketEncoder() {
WebSocketServerHandshaker socketServerHandshaker = factory.newHandshaker(request);
request.close();
EmbeddedChannel websocketChannel = new EmbeddedChannel(socketServerHandshaker.newWebSocketEncoder(),
- socketServerHandshaker.newWebsocketDecoder());
+ socketServerHandshaker.newWebsocketDecoder());
assertTrue(websocketChannel.writeOutbound(
new BinaryWebSocketFrame(websocketChannel.bufferAllocator().copyOf(data))));
- byte[] bytes = "HTTP/1.1 101 Switching Protocols\r\nContent-Length: 0\r\n\r\n".getBytes(CharsetUtil.US_ASCII);
+ byte[] bytes = "HTTP/1.1 101 Switching Protocols\r\n\r\n".getBytes(CharsetUtil.US_ASCII);
CompositeBuffer compositeBuffer = CompositeBuffer.compose(websocketChannel.bufferAllocator());
compositeBuffer.extendWith(websocketChannel.bufferAllocator().allocate(bytes.length).writeBytes(bytes).send());
@@ -305,13 +220,14 @@ protected WebSocketFrameEncoder newWebSocketEncoder() {
}
EmbeddedChannel ch = new EmbeddedChannel(new HttpObjectAggregator(Integer.MAX_VALUE),
- new SimpleChannelInboundHandler() {
- @Override
- protected void messageReceived(ChannelHandlerContext ctx, FullHttpResponse msg) {
- handshaker.finishHandshake(ctx.channel(), msg);
- ctx.pipeline().remove(this);
- }
- });
+ new SimpleChannelInboundHandler() {
+ @Override
+ protected void messageReceived(ChannelHandlerContext ctx,
+ FullHttpResponse msg) {
+ handshaker.finishHandshake(ctx.channel(), msg);
+ ctx.pipeline().remove(this);
+ }
+ });
if (codec) {
ch.pipeline().addFirst(new HttpClientCodec());
} else {
@@ -340,7 +256,7 @@ protected void messageReceived(ChannelHandlerContext ctx, FullHttpResponse msg)
}
@Test
- public void testDuplicateWebsocketHandshakeHeaders() {
+ void testDuplicateWebsocketHandshakeHeaders() {
URI uri = URI.create("ws://localhost:9999/foo");
HttpHeaders inputHeaders = new DefaultHttpHeaders();
@@ -374,21 +290,60 @@ public void testDuplicateWebsocketHandshakeHeaders() {
}
@Test
- public void testWebSocketClientHandshakeException() {
- URI uri = URI.create("ws://localhost:9999/exception");
- WebSocketClientHandshaker handshaker = newHandshaker(uri, null, null, false);
- FullHttpResponse response = new DefaultFullHttpResponse(
+ void testSetHostHeaderIfNoPresentInCustomHeaders() {
+ var customHeaders = new DefaultHttpHeaders();
+ customHeaders.set(HttpHeaderNames.HOST, "custom-host");
+ var handshaker = newHandshaker(URI.create("ws://localhost:9999/ws"), null,
+ customHeaders, false);
+ try (var request = handshaker.newHandshakeRequest(preferredAllocator())) {
+ assertEquals("custom-host", request.headers().get(HttpHeaderNames.HOST));
+ }
+ }
+
+ @Test
+ void testNoOriginHeaderInHandshakeRequest() {
+ var handshaker = newHandshaker(URI.create("ws://localhost:9999/ws"), null,
+ null, false);
+ try (var request = handshaker.newHandshakeRequest(preferredAllocator())) {
+ assertNull(request.headers().get(HttpHeaderNames.ORIGIN));
+ }
+ }
+
+ @Test
+ void testSetOriginFromCustomHeaders() {
+ var customHeaders = new DefaultHttpHeaders().set(HttpHeaderNames.ORIGIN, "http://example.com");
+ var handshaker = newHandshaker(URI.create("ws://server.example.com/chat"), null,
+ customHeaders, false);
+ try (var request = handshaker.newHandshakeRequest(preferredAllocator())) {
+ assertEquals("http://example.com", request.headers().get(HttpHeaderNames.ORIGIN));
+ }
+ }
+
+ @Test
+ void testWebSocketClientHandshakeExceptionContainsResponse() {
+ var handshaker = newHandshaker(URI.create("ws://localhost:9999/ws"), null,
+ null, false);
+ var response = new DefaultFullHttpResponse(
HttpVersion.HTTP_1_1, HttpResponseStatus.UNAUTHORIZED, preferredAllocator().allocate(0));
+ response.headers().set(HttpHeaderNames.WWW_AUTHENTICATE, "realm = access token required");
+ final WebSocketClientHandshakeException exception;
try (response) {
- response.headers().set(HttpHeaderNames.WWW_AUTHENTICATE, "realm = access token required");
- handshaker.finishHandshake(null, response);
- } catch (WebSocketClientHandshakeException exception) {
- assertEquals("Invalid handshake response getStatus: 401 Unauthorized", exception.getMessage());
- assertEquals(HttpResponseStatus.UNAUTHORIZED, exception.response().status());
- assertTrue(exception.response().headers().contains(HttpHeaderNames.WWW_AUTHENTICATE,
- "realm = access token required", false));
+ exception = Assertions.assertThrows(WebSocketClientHandshakeException.class,
+ () -> handshaker.finishHandshake(null, response));
}
+ assertEquals("Invalid handshake response status: 401 Unauthorized", exception.getMessage());
+ assertNotNull(exception.response());
+ assertEquals(HttpResponseStatus.UNAUTHORIZED, exception.response().status());
+ assertEquals(1, exception.response().headers().size());
+ assertTrue(exception.response().headers().contains(HttpHeaderNames.WWW_AUTHENTICATE,
+ "realm = access token required", false));
}
-}
+ private void testHostHeader(String uri, String expected) {
+ var handshaker = newHandshaker(URI.create(uri));
+ try (var request = handshaker.newHandshakeRequest(preferredAllocator())) {
+ assertEquals(expected, request.headers().get(HttpHeaderNames.HOST));
+ }
+ }
+}
diff --git a/codec-http/src/test/java/io/netty5/handler/codec/http/websocketx/WebSocketHandshakeHandOverTest.java b/codec-http/src/test/java/io/netty5/handler/codec/http/websocketx/WebSocketHandshakeHandOverTest.java
index 44c5dd53873..7468bd4b6f0 100644
--- a/codec-http/src/test/java/io/netty5/handler/codec/http/websocketx/WebSocketHandshakeHandOverTest.java
+++ b/codec-http/src/test/java/io/netty5/handler/codec/http/websocketx/WebSocketHandshakeHandOverTest.java
@@ -119,7 +119,7 @@ protected void messageReceived(ChannelHandlerContext ctx, Object msg) {
assertTrue(serverReceivedHandshake);
assertNotNull(serverHandshakeComplete);
assertEquals("/test", serverHandshakeComplete.requestUri());
- assertEquals(8, serverHandshakeComplete.requestHeaders().size());
+ assertEquals(7, serverHandshakeComplete.requestHeaders().size());
assertEquals("test-proto-2", serverHandshakeComplete.selectedSubprotocol());
// Transfer the handshake response and the websocket message to the client
diff --git a/codec-http/src/test/java/io/netty5/handler/codec/http/websocketx/WebSocketUtilTest.java b/codec-http/src/test/java/io/netty5/handler/codec/http/websocketx/WebSocketUtilTest.java
index f41206e6e9a..3dae2af4abe 100644
--- a/codec-http/src/test/java/io/netty5/handler/codec/http/websocketx/WebSocketUtilTest.java
+++ b/codec-http/src/test/java/io/netty5/handler/codec/http/websocketx/WebSocketUtilTest.java
@@ -16,30 +16,48 @@
package io.netty5.handler.codec.http.websocketx;
import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.ValueSource;
+import java.nio.charset.StandardCharsets;
+import java.security.MessageDigest;
+import java.util.Arrays;
+import java.util.Base64;
import java.util.concurrent.ThreadLocalRandom;
-import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
public class WebSocketUtilTest {
- // how many times do we want to run each random variable checker
- private static final int NUM_ITERATIONS = 1000;
-
- private static void assertRandomWithinBoundaries(int min, int max) {
- int r = ThreadLocalRandom.current().nextInt(min, max + 1);
- assertTrue(min <= r && r <= max);
+ @ParameterizedTest
+ @ValueSource(ints = {2, 4, 8, 16, 24, 48, 64, 128 })
+ void testRandomBytes(int size) {
+ var random1 = WebSocketUtil.randomBytes(size);
+ var random2 = WebSocketUtil.randomBytes(size);
+ assertEquals(size, random1.length);
+ assertEquals(size, random2.length);
+ assertFalse(Arrays.equals(random1, random2));
}
@Test
- public void testRandomNumberGenerator() {
- int iteration = 0;
- while (++iteration < NUM_ITERATIONS) {
- assertRandomWithinBoundaries(0, 1);
- assertRandomWithinBoundaries(0, 1);
- assertRandomWithinBoundaries(-1, 1);
- assertRandomWithinBoundaries(-1, 0);
- }
+ void testBase64() {
+ var random = WebSocketUtil.randomBytes(8);
+ String expected = Base64.getEncoder().encodeToString(random);
+ assertEquals(expected, WebSocketUtil.base64(random));
}
+ @Test
+ void testCalculateV13Accept() throws Exception {
+ var random = new byte[16];
+ ThreadLocalRandom.current().nextBytes(random);
+ String nonce = Base64.getEncoder().encodeToString(random);
+
+ String concat = nonce + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
+ MessageDigest sha1Digest = MessageDigest.getInstance("SHA-1");
+ byte[] sha1 = sha1Digest.digest(concat.getBytes(StandardCharsets.US_ASCII));
+ String expectedAccept = Base64.getEncoder().encodeToString(sha1);
+
+ assertEquals(expectedAccept, WebSocketUtil.calculateV13Accept(nonce));
+ }
}