Skip to content

Commit

Permalink
Remove ORIGIN header from ws opening handshake v13, clean up code (#1…
Browse files Browse the repository at this point in the history
…2293)

Motivation:

In Netty 5 we only left the latest 13 version of websocket, so we can clean up the code a bit.

Modification:

- Remove `origin` header from handshake request, it can be set with `customHeaders`, see #9673
- Calculate expected value of `sec-websocket-accept` only when we received a response
- Remove unused code
- Add more tests for clientHandshaker13

Result:
Removed not used code and fixed wrong behaviour with origin header #9673.
  • Loading branch information
amizurov committed Apr 21, 2022
1 parent f6cc9a2 commit 68f95ec
Show file tree
Hide file tree
Showing 9 changed files with 287 additions and 280 deletions.
Expand Up @@ -40,7 +40,6 @@

import java.net.URI;
import java.nio.channels.ClosedChannelException;
import java.util.Locale;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;

Expand All @@ -51,8 +50,6 @@
*/
public abstract class WebSocketClientHandshaker {

private static final String HTTP_SCHEME_PREFIX = HttpScheme.HTTP + "://";
private static final String HTTPS_SCHEME_PREFIX = HttpScheme.HTTPS + "://";
protected static final int DEFAULT_FORCE_CLOSE_TIMEOUT_MILLIS = 10000;

private final URI uri;
Expand Down Expand Up @@ -548,31 +545,4 @@ static CharSequence websocketHostValue(URI wsURL) {
// See https://tools.ietf.org/html/rfc6454#section-6.2
return NetUtil.toSocketAddressString(host, port);
}

static CharSequence websocketOriginValue(URI wsURL) {
String scheme = wsURL.getScheme();
final String schemePrefix;
int port = wsURL.getPort();
final int defaultPort;
if (WebSocketScheme.WSS.name().contentEquals(scheme)
|| HttpScheme.HTTPS.name().contentEquals(scheme)
|| (scheme == null && port == WebSocketScheme.WSS.port())) {

schemePrefix = HTTPS_SCHEME_PREFIX;
defaultPort = WebSocketScheme.WSS.port();
} else {
schemePrefix = HTTP_SCHEME_PREFIX;
defaultPort = WebSocketScheme.WS.port();
}

// Convert uri-host to lower case (by RFC 6454, chapter 4 "Origin of a URI")
String host = wsURL.getHost().toLowerCase(Locale.US);

if (port != defaultPort && port != -1) {
// if the port is not standard (80/443) its needed to add the port to the header.
// See https://tools.ietf.org/html/rfc6454#section-6.2
return schemePrefix + NetUtil.toSocketAddressString(host, port);
}
return schemePrefix + host;
}
}
Expand Up @@ -25,39 +25,30 @@
import io.netty5.handler.codec.http.HttpMethod;
import io.netty5.handler.codec.http.HttpResponseStatus;
import io.netty5.handler.codec.http.HttpVersion;
import io.netty5.util.CharsetUtil;
import io.netty5.util.internal.logging.InternalLogger;
import io.netty5.util.internal.logging.InternalLoggerFactory;
import io.netty5.util.internal.StringUtil;

import java.net.URI;

/**
* <p>
* Performs client side opening and closing handshakes for web socket specification version <a
* href="https://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-17" >draft-ietf-hybi-thewebsocketprotocol-
* 17</a>
* Performs client side opening and closing handshakes for web socket specification version
* <a href="https://datatracker.ietf.org/doc/html/rfc6455">websocketprotocol-v13</a>
* </p>
*/
public class WebSocketClientHandshaker13 extends WebSocketClientHandshaker {

private static final InternalLogger logger = InternalLoggerFactory.getInstance(WebSocketClientHandshaker13.class);

public static final String MAGIC_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

private String expectedChallengeResponseString;

private final boolean allowExtensions;
private final boolean performMasking;
private final boolean allowMaskMismatch;

private volatile String sentNonce;

/**
* Creates a new instance.
*
* @param webSocketURL
* URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
* sent to this URL.
* @param version
* Version of web socket specification to use to connect to the server
* @param subprotocol
* Sub protocol request sent to the server.
* @param allowExtensions
Expand All @@ -67,9 +58,9 @@ public class WebSocketClientHandshaker13 extends WebSocketClientHandshaker {
* @param maxFramePayloadLength
* Maximum length of a frame's payload
*/
public WebSocketClientHandshaker13(URI webSocketURL, WebSocketVersion version, String subprotocol,
public WebSocketClientHandshaker13(URI webSocketURL, String subprotocol,
boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength) {
this(webSocketURL, version, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength,
this(webSocketURL, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength,
true, false);
}

Expand All @@ -79,8 +70,6 @@ public WebSocketClientHandshaker13(URI webSocketURL, WebSocketVersion version, S
* @param webSocketURL
* URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
* sent to this URL.
* @param version
* Version of web socket specification to use to connect to the server
* @param subprotocol
* Sub protocol request sent to the server.
* @param allowExtensions
Expand All @@ -97,10 +86,10 @@ public WebSocketClientHandshaker13(URI webSocketURL, WebSocketVersion version, S
* When set to true, frames which are not masked properly according to the standard will still be
* accepted.
*/
public WebSocketClientHandshaker13(URI webSocketURL, WebSocketVersion version, String subprotocol,
public WebSocketClientHandshaker13(URI webSocketURL, String subprotocol,
boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength,
boolean performMasking, boolean allowMaskMismatch) {
this(webSocketURL, version, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength,
this(webSocketURL, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength,
performMasking, allowMaskMismatch, DEFAULT_FORCE_CLOSE_TIMEOUT_MILLIS);
}

Expand All @@ -110,8 +99,6 @@ public WebSocketClientHandshaker13(URI webSocketURL, WebSocketVersion version, S
* @param webSocketURL
* URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
* sent to this URL.
* @param version
* Version of web socket specification to use to connect to the server
* @param subprotocol
* Sub protocol request sent to the server.
* @param allowExtensions
Expand All @@ -130,11 +117,11 @@ public WebSocketClientHandshaker13(URI webSocketURL, WebSocketVersion version, S
* @param forceCloseTimeoutMillis
* Close the connection if it was not closed by the server after timeout specified.
*/
public WebSocketClientHandshaker13(URI webSocketURL, WebSocketVersion version, String subprotocol,
public WebSocketClientHandshaker13(URI webSocketURL, String subprotocol,
boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength,
boolean performMasking, boolean allowMaskMismatch,
long forceCloseTimeoutMillis) {
this(webSocketURL, version, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength, performMasking,
this(webSocketURL, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength, performMasking,
allowMaskMismatch, forceCloseTimeoutMillis, false);
}

Expand All @@ -144,8 +131,6 @@ public WebSocketClientHandshaker13(URI webSocketURL, WebSocketVersion version, S
* @param webSocketURL
* URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
* sent to this URL.
* @param version
* Version of web socket specification to use to connect to the server
* @param subprotocol
* Sub protocol request sent to the server.
* @param allowExtensions
Expand All @@ -167,12 +152,12 @@ public WebSocketClientHandshaker13(URI webSocketURL, WebSocketVersion version, S
* Use an absolute url for the Upgrade request, typically when connecting through an HTTP proxy over
* clear HTTP
*/
WebSocketClientHandshaker13(URI webSocketURL, WebSocketVersion version, String subprotocol,
WebSocketClientHandshaker13(URI webSocketURL, String subprotocol,
boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength,
boolean performMasking, boolean allowMaskMismatch,
long forceCloseTimeoutMillis, boolean absoluteUpgradeUrl) {
super(webSocketURL, version, subprotocol, customHeaders, maxFramePayloadLength, forceCloseTimeoutMillis,
absoluteUpgradeUrl);
super(webSocketURL, WebSocketVersion.V13, subprotocol, customHeaders, maxFramePayloadLength,
forceCloseTimeoutMillis, absoluteUpgradeUrl);
this.allowExtensions = allowExtensions;
this.performMasking = performMasking;
this.allowMaskMismatch = allowMaskMismatch;
Expand All @@ -190,7 +175,6 @@ public WebSocketClientHandshaker13(URI webSocketURL, WebSocketVersion version, S
* Upgrade: websocket
* Connection: Upgrade
* Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
* Origin: http://example.com
* Sec-WebSocket-Protocol: chat, superchat
* Sec-WebSocket-Version: 13
* </pre>
Expand All @@ -199,22 +183,6 @@ public WebSocketClientHandshaker13(URI webSocketURL, WebSocketVersion version, S
@Override
protected FullHttpRequest newHandshakeRequest(BufferAllocator allocator) {
URI wsURL = uri();

// Get 16 bit nonce and base 64 encode it
byte[] nonce = WebSocketUtil.randomBytes(16);
String key = WebSocketUtil.base64(nonce);

String acceptSeed = key + MAGIC_GUID;
byte[] sha1 = WebSocketUtil.sha1(acceptSeed.getBytes(CharsetUtil.US_ASCII));
expectedChallengeResponseString = WebSocketUtil.base64(sha1);

if (logger.isDebugEnabled()) {
logger.debug(
"WebSocket version 13 client handshake key: {}, expected response: {}",
key, expectedChallengeResponseString);
}

// Format request
FullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, upgradeUrl(wsURL),
allocator.allocate(0));
HttpHeaders headers = request.headers();
Expand All @@ -223,24 +191,21 @@ protected FullHttpRequest newHandshakeRequest(BufferAllocator allocator) {
headers.add(customHeaders);
if (!headers.contains(HttpHeaderNames.HOST)) {
// Only add HOST header if customHeaders did not contain it.
//
// See https://github.com/netty/netty/issues/10101
// See https://github.com/netty/netty/issues/10101.
headers.set(HttpHeaderNames.HOST, websocketHostValue(wsURL));
}
} else {
headers.set(HttpHeaderNames.HOST, websocketHostValue(wsURL));
}

String nonce = createNonce();
headers.set(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET)
.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE)
.set(HttpHeaderNames.SEC_WEBSOCKET_KEY, key);

if (!headers.contains(HttpHeaderNames.ORIGIN)) {
headers.set(HttpHeaderNames.ORIGIN, websocketOriginValue(wsURL));
}
.set(HttpHeaderNames.SEC_WEBSOCKET_KEY, nonce);

sentNonce = nonce;
String expectedSubprotocol = expectedSubprotocol();
if (expectedSubprotocol != null && !expectedSubprotocol.isEmpty()) {
if (!StringUtil.isNullOrEmpty(expectedSubprotocol)) {
headers.set(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, expectedSubprotocol);
}

Expand Down Expand Up @@ -269,7 +234,7 @@ protected FullHttpRequest newHandshakeRequest(BufferAllocator allocator) {
protected void verify(FullHttpResponse response) {
HttpResponseStatus status = response.status();
if (!HttpResponseStatus.SWITCHING_PROTOCOLS.equals(status)) {
throw new WebSocketClientHandshakeException("Invalid handshake response getStatus: " + status, response);
throw new WebSocketClientHandshakeException("Invalid handshake response status: " + status, response);
}

HttpHeaders headers = response.headers();
Expand All @@ -283,10 +248,16 @@ protected void verify(FullHttpResponse response) {
+ headers.get(HttpHeaderNames.CONNECTION), response);
}

CharSequence accept = headers.get(HttpHeaderNames.SEC_WEBSOCKET_ACCEPT);
if (accept == null || !accept.equals(expectedChallengeResponseString)) {
throw new WebSocketClientHandshakeException(String.format(
"Invalid challenge. Actual: %s. Expected: %s", accept, expectedChallengeResponseString), response);
String accept = headers.get(HttpHeaderNames.SEC_WEBSOCKET_ACCEPT);
if (accept == null) {
throw new WebSocketClientHandshakeException("Invalid handshake response sec-websocket-accept: null",
response);
}

String expectedAccept = WebSocketUtil.calculateV13Accept(sentNonce);
if (!expectedAccept.equals(accept.trim())) {
throw new WebSocketClientHandshakeException("Invalid handshake response sec-websocket-accept: " + accept +
", expected: " + expectedAccept, response);
}
}

Expand All @@ -306,4 +277,12 @@ public WebSocketClientHandshaker13 setForceCloseTimeoutMillis(long forceCloseTim
return this;
}

/**
* Creates a nonce consisting of a randomly selected 16-byte value
* that has been base64-encoded.
*/
private static String createNonce() {
var nonce = WebSocketUtil.randomBytes(16);
return WebSocketUtil.base64(nonce);
}
}
Expand Up @@ -144,7 +144,7 @@ public static WebSocketClientHandshaker newHandshaker(
boolean performMasking, boolean allowMaskMismatch, long forceCloseTimeoutMillis) {
if (version == V13) {
return new WebSocketClientHandshaker13(
webSocketURL, V13, subprotocol, allowExtensions, customHeaders,
webSocketURL, subprotocol, allowExtensions, customHeaders,
maxFramePayloadLength, performMasking, allowMaskMismatch, forceCloseTimeoutMillis);
}

Expand Down Expand Up @@ -187,7 +187,7 @@ public static WebSocketClientHandshaker newHandshaker(
boolean performMasking, boolean allowMaskMismatch, long forceCloseTimeoutMillis, boolean absoluteUpgradeUrl) {
if (version == V13) {
return new WebSocketClientHandshaker13(
webSocketURL, V13, subprotocol, allowExtensions, customHeaders,
webSocketURL, subprotocol, allowExtensions, customHeaders,
maxFramePayloadLength, performMasking, allowMaskMismatch, forceCloseTimeoutMillis, absoluteUpgradeUrl);
}

Expand Down
Expand Up @@ -23,9 +23,8 @@
import io.netty5.handler.codec.http.HttpHeaderValues;
import io.netty5.handler.codec.http.HttpHeaders;
import io.netty5.handler.codec.http.HttpResponseStatus;
import io.netty5.util.CharsetUtil;

import static io.netty5.handler.codec.http.HttpVersion.*;
import static io.netty5.handler.codec.http.HttpVersion.HTTP_1_1;

/**
* <p>
Expand All @@ -35,8 +34,6 @@
*/
public class WebSocketServerHandshaker13 extends WebSocketServerHandshaker {

public static final String WEBSOCKET_13_ACCEPT_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

/**
* Constructor specifying the destination web socket location
*
Expand Down Expand Up @@ -136,7 +133,7 @@ public WebSocketServerHandshaker13(
@Override
protected FullHttpResponse newHandshakeResponse(BufferAllocator allocator, FullHttpRequest req,
HttpHeaders headers) {
CharSequence key = req.headers().get(HttpHeaderNames.SEC_WEBSOCKET_KEY);
String key = req.headers().get(HttpHeaderNames.SEC_WEBSOCKET_KEY);
if (key == null) {
throw new WebSocketServerHandshakeException("not a WebSocket request: missing key", req);
}
Expand All @@ -147,14 +144,7 @@ protected FullHttpResponse newHandshakeResponse(BufferAllocator allocator, FullH
res.headers().add(headers);
}

String acceptSeed = key + WEBSOCKET_13_ACCEPT_GUID;
byte[] sha1 = WebSocketUtil.sha1(acceptSeed.getBytes(CharsetUtil.US_ASCII));
String accept = WebSocketUtil.base64(sha1);

if (logger.isDebugEnabled()) {
logger.debug("WebSocket version 13 server handshake key: {}, response: {}", key, accept);
}

String accept = WebSocketUtil.calculateV13Accept(key);
res.headers().set(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET)
.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE)
.set(HttpHeaderNames.SEC_WEBSOCKET_ACCEPT, accept);
Expand Down

0 comments on commit 68f95ec

Please sign in to comment.