Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove ORIGIN header from ws opening handshake v13, clean up code #12293

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -40,7 +40,6 @@

import java.net.URI;
import java.nio.channels.ClosedChannelException;
import java.util.Locale;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;

Expand All @@ -51,8 +50,6 @@
*/
public abstract class WebSocketClientHandshaker {

private static final String HTTP_SCHEME_PREFIX = HttpScheme.HTTP + "://";
private static final String HTTPS_SCHEME_PREFIX = HttpScheme.HTTPS + "://";
protected static final int DEFAULT_FORCE_CLOSE_TIMEOUT_MILLIS = 10000;

private final URI uri;
Expand Down Expand Up @@ -548,31 +545,4 @@ static CharSequence websocketHostValue(URI wsURL) {
// See https://tools.ietf.org/html/rfc6454#section-6.2
return NetUtil.toSocketAddressString(host, port);
}

static CharSequence websocketOriginValue(URI wsURL) {
String scheme = wsURL.getScheme();
final String schemePrefix;
int port = wsURL.getPort();
final int defaultPort;
if (WebSocketScheme.WSS.name().contentEquals(scheme)
|| HttpScheme.HTTPS.name().contentEquals(scheme)
|| (scheme == null && port == WebSocketScheme.WSS.port())) {

schemePrefix = HTTPS_SCHEME_PREFIX;
defaultPort = WebSocketScheme.WSS.port();
} else {
schemePrefix = HTTP_SCHEME_PREFIX;
defaultPort = WebSocketScheme.WS.port();
}

// Convert uri-host to lower case (by RFC 6454, chapter 4 "Origin of a URI")
String host = wsURL.getHost().toLowerCase(Locale.US);

if (port != defaultPort && port != -1) {
// if the port is not standard (80/443) its needed to add the port to the header.
// See https://tools.ietf.org/html/rfc6454#section-6.2
return schemePrefix + NetUtil.toSocketAddressString(host, port);
}
return schemePrefix + host;
}
}
Expand Up @@ -25,39 +25,30 @@
import io.netty5.handler.codec.http.HttpMethod;
import io.netty5.handler.codec.http.HttpResponseStatus;
import io.netty5.handler.codec.http.HttpVersion;
import io.netty5.util.CharsetUtil;
import io.netty5.util.internal.logging.InternalLogger;
import io.netty5.util.internal.logging.InternalLoggerFactory;
import io.netty5.util.internal.StringUtil;

import java.net.URI;

/**
* <p>
* Performs client side opening and closing handshakes for web socket specification version <a
* href="https://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-17" >draft-ietf-hybi-thewebsocketprotocol-
* 17</a>
* Performs client side opening and closing handshakes for web socket specification version
* <a href="https://datatracker.ietf.org/doc/html/rfc6455">websocketprotocol-v13</a>
* </p>
*/
public class WebSocketClientHandshaker13 extends WebSocketClientHandshaker {

private static final InternalLogger logger = InternalLoggerFactory.getInstance(WebSocketClientHandshaker13.class);

public static final String MAGIC_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

private String expectedChallengeResponseString;

private final boolean allowExtensions;
private final boolean performMasking;
private final boolean allowMaskMismatch;

private volatile String sentNonce;

/**
* Creates a new instance.
*
* @param webSocketURL
* URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
* sent to this URL.
* @param version
* Version of web socket specification to use to connect to the server
* @param subprotocol
* Sub protocol request sent to the server.
* @param allowExtensions
Expand All @@ -67,9 +58,9 @@ public class WebSocketClientHandshaker13 extends WebSocketClientHandshaker {
* @param maxFramePayloadLength
* Maximum length of a frame's payload
*/
public WebSocketClientHandshaker13(URI webSocketURL, WebSocketVersion version, String subprotocol,
public WebSocketClientHandshaker13(URI webSocketURL, String subprotocol,
boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength) {
this(webSocketURL, version, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength,
this(webSocketURL, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength,
true, false);
}

Expand All @@ -79,8 +70,6 @@ public WebSocketClientHandshaker13(URI webSocketURL, WebSocketVersion version, S
* @param webSocketURL
* URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
* sent to this URL.
* @param version
* Version of web socket specification to use to connect to the server
* @param subprotocol
* Sub protocol request sent to the server.
* @param allowExtensions
Expand All @@ -97,10 +86,10 @@ public WebSocketClientHandshaker13(URI webSocketURL, WebSocketVersion version, S
* When set to true, frames which are not masked properly according to the standard will still be
* accepted.
*/
public WebSocketClientHandshaker13(URI webSocketURL, WebSocketVersion version, String subprotocol,
public WebSocketClientHandshaker13(URI webSocketURL, String subprotocol,
boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength,
boolean performMasking, boolean allowMaskMismatch) {
this(webSocketURL, version, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength,
this(webSocketURL, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength,
performMasking, allowMaskMismatch, DEFAULT_FORCE_CLOSE_TIMEOUT_MILLIS);
}

Expand All @@ -110,8 +99,6 @@ public WebSocketClientHandshaker13(URI webSocketURL, WebSocketVersion version, S
* @param webSocketURL
* URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
* sent to this URL.
* @param version
* Version of web socket specification to use to connect to the server
* @param subprotocol
* Sub protocol request sent to the server.
* @param allowExtensions
Expand All @@ -130,11 +117,11 @@ public WebSocketClientHandshaker13(URI webSocketURL, WebSocketVersion version, S
* @param forceCloseTimeoutMillis
* Close the connection if it was not closed by the server after timeout specified.
*/
public WebSocketClientHandshaker13(URI webSocketURL, WebSocketVersion version, String subprotocol,
public WebSocketClientHandshaker13(URI webSocketURL, String subprotocol,
boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength,
boolean performMasking, boolean allowMaskMismatch,
long forceCloseTimeoutMillis) {
this(webSocketURL, version, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength, performMasking,
this(webSocketURL, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength, performMasking,
allowMaskMismatch, forceCloseTimeoutMillis, false);
}

Expand All @@ -144,8 +131,6 @@ public WebSocketClientHandshaker13(URI webSocketURL, WebSocketVersion version, S
* @param webSocketURL
* URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
* sent to this URL.
* @param version
* Version of web socket specification to use to connect to the server
* @param subprotocol
* Sub protocol request sent to the server.
* @param allowExtensions
Expand All @@ -167,12 +152,12 @@ public WebSocketClientHandshaker13(URI webSocketURL, WebSocketVersion version, S
* Use an absolute url for the Upgrade request, typically when connecting through an HTTP proxy over
* clear HTTP
*/
WebSocketClientHandshaker13(URI webSocketURL, WebSocketVersion version, String subprotocol,
WebSocketClientHandshaker13(URI webSocketURL, String subprotocol,
boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength,
boolean performMasking, boolean allowMaskMismatch,
long forceCloseTimeoutMillis, boolean absoluteUpgradeUrl) {
super(webSocketURL, version, subprotocol, customHeaders, maxFramePayloadLength, forceCloseTimeoutMillis,
absoluteUpgradeUrl);
super(webSocketURL, WebSocketVersion.V13, subprotocol, customHeaders, maxFramePayloadLength,
forceCloseTimeoutMillis, absoluteUpgradeUrl);
this.allowExtensions = allowExtensions;
this.performMasking = performMasking;
this.allowMaskMismatch = allowMaskMismatch;
Expand All @@ -190,7 +175,6 @@ public WebSocketClientHandshaker13(URI webSocketURL, WebSocketVersion version, S
* Upgrade: websocket
* Connection: Upgrade
* Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
* Origin: http://example.com
* Sec-WebSocket-Protocol: chat, superchat
* Sec-WebSocket-Version: 13
* </pre>
Expand All @@ -199,22 +183,6 @@ public WebSocketClientHandshaker13(URI webSocketURL, WebSocketVersion version, S
@Override
protected FullHttpRequest newHandshakeRequest(BufferAllocator allocator) {
URI wsURL = uri();

// Get 16 bit nonce and base 64 encode it
byte[] nonce = WebSocketUtil.randomBytes(16);
String key = WebSocketUtil.base64(nonce);

String acceptSeed = key + MAGIC_GUID;
byte[] sha1 = WebSocketUtil.sha1(acceptSeed.getBytes(CharsetUtil.US_ASCII));
expectedChallengeResponseString = WebSocketUtil.base64(sha1);

if (logger.isDebugEnabled()) {
logger.debug(
"WebSocket version 13 client handshake key: {}, expected response: {}",
key, expectedChallengeResponseString);
}

// Format request
FullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, upgradeUrl(wsURL),
allocator.allocate(0));
HttpHeaders headers = request.headers();
Expand All @@ -223,24 +191,21 @@ protected FullHttpRequest newHandshakeRequest(BufferAllocator allocator) {
headers.add(customHeaders);
if (!headers.contains(HttpHeaderNames.HOST)) {
// Only add HOST header if customHeaders did not contain it.
//
// See https://github.com/netty/netty/issues/10101
// See https://github.com/netty/netty/issues/10101.
headers.set(HttpHeaderNames.HOST, websocketHostValue(wsURL));
}
} else {
headers.set(HttpHeaderNames.HOST, websocketHostValue(wsURL));
}

String nonce = createNonce();
headers.set(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET)
.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE)
.set(HttpHeaderNames.SEC_WEBSOCKET_KEY, key);

if (!headers.contains(HttpHeaderNames.ORIGIN)) {
headers.set(HttpHeaderNames.ORIGIN, websocketOriginValue(wsURL));
}
.set(HttpHeaderNames.SEC_WEBSOCKET_KEY, nonce);

sentNonce = nonce;
String expectedSubprotocol = expectedSubprotocol();
if (expectedSubprotocol != null && !expectedSubprotocol.isEmpty()) {
if (!StringUtil.isNullOrEmpty(expectedSubprotocol)) {
headers.set(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, expectedSubprotocol);
}

Expand Down Expand Up @@ -269,7 +234,7 @@ protected FullHttpRequest newHandshakeRequest(BufferAllocator allocator) {
protected void verify(FullHttpResponse response) {
HttpResponseStatus status = response.status();
if (!HttpResponseStatus.SWITCHING_PROTOCOLS.equals(status)) {
throw new WebSocketClientHandshakeException("Invalid handshake response getStatus: " + status, response);
throw new WebSocketClientHandshakeException("Invalid handshake response status: " + status, response);
}

HttpHeaders headers = response.headers();
Expand All @@ -283,10 +248,16 @@ protected void verify(FullHttpResponse response) {
+ headers.get(HttpHeaderNames.CONNECTION), response);
}

CharSequence accept = headers.get(HttpHeaderNames.SEC_WEBSOCKET_ACCEPT);
if (accept == null || !accept.equals(expectedChallengeResponseString)) {
throw new WebSocketClientHandshakeException(String.format(
"Invalid challenge. Actual: %s. Expected: %s", accept, expectedChallengeResponseString), response);
String accept = headers.get(HttpHeaderNames.SEC_WEBSOCKET_ACCEPT);
if (accept == null) {
throw new WebSocketClientHandshakeException("Invalid handshake response sec-websocket-accept: null",
response);
}

String expectedAccept = WebSocketUtil.calculateV13Accept(sentNonce);
if (!expectedAccept.equals(accept.trim())) {
throw new WebSocketClientHandshakeException("Invalid handshake response sec-websocket-accept: " + accept +
", expected: " + expectedAccept, response);
}
}

Expand All @@ -306,4 +277,12 @@ public WebSocketClientHandshaker13 setForceCloseTimeoutMillis(long forceCloseTim
return this;
}

/**
* Creates a nonce consisting of a randomly selected 16-byte value
* that has been base64-encoded.
*/
private static String createNonce() {
var nonce = WebSocketUtil.randomBytes(16);
return WebSocketUtil.base64(nonce);
}
}
Expand Up @@ -144,7 +144,7 @@ public static WebSocketClientHandshaker newHandshaker(
boolean performMasking, boolean allowMaskMismatch, long forceCloseTimeoutMillis) {
if (version == V13) {
return new WebSocketClientHandshaker13(
webSocketURL, V13, subprotocol, allowExtensions, customHeaders,
webSocketURL, subprotocol, allowExtensions, customHeaders,
maxFramePayloadLength, performMasking, allowMaskMismatch, forceCloseTimeoutMillis);
}

Expand Down Expand Up @@ -187,7 +187,7 @@ public static WebSocketClientHandshaker newHandshaker(
boolean performMasking, boolean allowMaskMismatch, long forceCloseTimeoutMillis, boolean absoluteUpgradeUrl) {
if (version == V13) {
return new WebSocketClientHandshaker13(
webSocketURL, V13, subprotocol, allowExtensions, customHeaders,
webSocketURL, subprotocol, allowExtensions, customHeaders,
maxFramePayloadLength, performMasking, allowMaskMismatch, forceCloseTimeoutMillis, absoluteUpgradeUrl);
}

Expand Down
Expand Up @@ -23,9 +23,8 @@
import io.netty5.handler.codec.http.HttpHeaderValues;
import io.netty5.handler.codec.http.HttpHeaders;
import io.netty5.handler.codec.http.HttpResponseStatus;
import io.netty5.util.CharsetUtil;

import static io.netty5.handler.codec.http.HttpVersion.*;
import static io.netty5.handler.codec.http.HttpVersion.HTTP_1_1;

/**
* <p>
Expand All @@ -35,8 +34,6 @@
*/
public class WebSocketServerHandshaker13 extends WebSocketServerHandshaker {

public static final String WEBSOCKET_13_ACCEPT_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

/**
* Constructor specifying the destination web socket location
*
Expand Down Expand Up @@ -136,7 +133,7 @@ public WebSocketServerHandshaker13(
@Override
protected FullHttpResponse newHandshakeResponse(BufferAllocator allocator, FullHttpRequest req,
HttpHeaders headers) {
CharSequence key = req.headers().get(HttpHeaderNames.SEC_WEBSOCKET_KEY);
String key = req.headers().get(HttpHeaderNames.SEC_WEBSOCKET_KEY);
if (key == null) {
throw new WebSocketServerHandshakeException("not a WebSocket request: missing key", req);
}
Expand All @@ -147,14 +144,7 @@ protected FullHttpResponse newHandshakeResponse(BufferAllocator allocator, FullH
res.headers().add(headers);
}

String acceptSeed = key + WEBSOCKET_13_ACCEPT_GUID;
byte[] sha1 = WebSocketUtil.sha1(acceptSeed.getBytes(CharsetUtil.US_ASCII));
String accept = WebSocketUtil.base64(sha1);

if (logger.isDebugEnabled()) {
logger.debug("WebSocket version 13 server handshake key: {}, response: {}", key, accept);
}

String accept = WebSocketUtil.calculateV13Accept(key);
res.headers().set(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET)
.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE)
.set(HttpHeaderNames.SEC_WEBSOCKET_ACCEPT, accept);
Expand Down