Skip to content

Commit

Permalink
Fix generating the Origin header value for websocket handshake requ…
Browse files Browse the repository at this point in the history
…est (#12941)

Motivation:

We have the old erroneous behavior of generating the `Origin| Sec-WebSocket-Origin` for client websocket handshake request (#9673). In Netty5 this fixed and auto-generation has been deleted at all, only if the client passed the `Origin` header via custom headers. The same we can do for Netty4 but it could potentially break some clients (unlikely), or introduce an additional parameter to disable or enable this behavior.

Modification:

Introduce new `generateOriginHeader` parameter in client config and generate the `Origin|Sec-WebSocket-Origin` header value only if it enabled. Add additional check for webSocketURI if it contains host or passed through `customHeaders` to prevent NPE in `newHandshakeRequest()`.

Result:

Fixes #9673 #12933

Co-authored-by: Norman Maurer <norman_maurer@apple.com>
  • Loading branch information
amizurov and normanmaurer committed Nov 10, 2022
1 parent e437e21 commit f3c27ae
Show file tree
Hide file tree
Showing 7 changed files with 253 additions and 37 deletions.
Expand Up @@ -79,6 +79,8 @@ public abstract class WebSocketClientHandshaker {

private final boolean absoluteUpgradeUrl;

protected final boolean generateOriginHeader;

/**
* Base constructor
*
Expand Down Expand Up @@ -145,13 +147,44 @@ protected WebSocketClientHandshaker(URI uri, WebSocketVersion version, String su
protected WebSocketClientHandshaker(URI uri, WebSocketVersion version, String subprotocol,
HttpHeaders customHeaders, int maxFramePayloadLength,
long forceCloseTimeoutMillis, boolean absoluteUpgradeUrl) {
this(uri, version, subprotocol, customHeaders, maxFramePayloadLength, forceCloseTimeoutMillis,
absoluteUpgradeUrl, true);
}

/**
* Base constructor
*
* @param uri
* URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
* sent to this URL.
* @param version
* Version of web socket specification to use to connect to the server
* @param subprotocol
* Sub protocol request sent to the server.
* @param customHeaders
* Map of custom headers to add to the client request
* @param maxFramePayloadLength
* Maximum length of a frame's payload
* @param forceCloseTimeoutMillis
* Close the connection if it was not closed by the server after timeout specified
* @param absoluteUpgradeUrl
* Use an absolute url for the Upgrade request, typically when connecting through an HTTP proxy over
* clear HTTP
* @param generateOriginHeader
* Allows to generate the `Origin`|`Sec-WebSocket-Origin` header value for handshake request
* according to the given webSocketURL
*/
protected WebSocketClientHandshaker(URI uri, WebSocketVersion version, String subprotocol,
HttpHeaders customHeaders, int maxFramePayloadLength,
long forceCloseTimeoutMillis, boolean absoluteUpgradeUrl, boolean generateOriginHeader) {
this.uri = uri;
this.version = version;
expectedSubprotocol = subprotocol;
this.customHeaders = customHeaders;
this.maxFramePayloadLength = maxFramePayloadLength;
this.forceCloseTimeoutMillis = forceCloseTimeoutMillis;
this.absoluteUpgradeUrl = absoluteUpgradeUrl;
this.generateOriginHeader = generateOriginHeader;
}

/**
Expand Down Expand Up @@ -247,6 +280,21 @@ public Future<Void> handshake(Channel channel) {
}
}

if (uri.getHost() == null) {
if (customHeaders == null || !customHeaders.contains(HttpHeaderNames.HOST)) {
return channel.newFailedFuture(new IllegalArgumentException("Cannot generate the 'host' header value," +
" webSocketURI should contain host or passed through customHeaders"));
}

if (generateOriginHeader && !customHeaders.contains(HttpHeaderNames.ORIGIN)) {
final String originName = HttpHeaderNames.ORIGIN.toString();
return channel.newFailedFuture(
new IllegalArgumentException("Cannot generate the '" + originName + "' header" +
" value, webSocketURI should contain host or disable generateOriginHeader or pass value" +
" through customHeaders"));
}
}

FullHttpRequest request = newHandshakeRequest(channel.bufferAllocator());

Promise<Void> promise = channel.newPromise();
Expand Down
Expand Up @@ -153,12 +153,53 @@ public WebSocketClientHandshaker13(URI webSocketURL, String subprotocol,
* Use an absolute url for the Upgrade request, typically when connecting through an HTTP proxy over
* clear HTTP
*/
WebSocketClientHandshaker13(URI webSocketURL, String subprotocol,
boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength,
boolean performMasking, boolean allowMaskMismatch,
long forceCloseTimeoutMillis, boolean absoluteUpgradeUrl) {
this(webSocketURL, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength, performMasking,
allowMaskMismatch, forceCloseTimeoutMillis, absoluteUpgradeUrl, true);
}

/**
* Creates a new instance.
*
* @param webSocketURL
* URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
* sent to this URL.
* @param version
* Version of web socket specification to use to connect to the server
* @param subprotocol
* Sub protocol request sent to the server.
* @param allowExtensions
* Allow extensions to be used in the reserved bits of the web socket frame
* @param customHeaders
* Map of custom headers to add to the client request
* @param maxFramePayloadLength
* Maximum length of a frame's payload
* @param performMasking
* Whether to mask all written websocket frames. This must be set to true in order to be fully compatible
* with the websocket specifications. Client applications that communicate with a non-standard server
* which doesn't require masking might set this to false to achieve a higher performance.
* @param allowMaskMismatch
* When set to true, frames which are not masked properly according to the standard will still be
* accepted
* @param forceCloseTimeoutMillis
* Close the connection if it was not closed by the server after timeout specified.
* @param absoluteUpgradeUrl
* Use an absolute url for the Upgrade request, typically when connecting through an HTTP proxy over
* clear HTTP
* @param generateOriginHeader
* Allows to generate the `Origin` header value for handshake request
* according to the given webSocketURL
*/
WebSocketClientHandshaker13(URI webSocketURL, String subprotocol,
boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength,
boolean performMasking, boolean allowMaskMismatch,
long forceCloseTimeoutMillis, boolean absoluteUpgradeUrl) {
long forceCloseTimeoutMillis, boolean absoluteUpgradeUrl,
boolean generateOriginHeader) {
super(webSocketURL, WebSocketVersion.V13, subprotocol, customHeaders, maxFramePayloadLength,
forceCloseTimeoutMillis, absoluteUpgradeUrl);
forceCloseTimeoutMillis, absoluteUpgradeUrl, generateOriginHeader);
this.allowExtensions = allowExtensions;
this.performMasking = performMasking;
this.allowMaskMismatch = allowMaskMismatch;
Expand Down Expand Up @@ -204,6 +245,10 @@ protected FullHttpRequest newHandshakeRequest(BufferAllocator allocator) {
.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE)
.set(HttpHeaderNames.SEC_WEBSOCKET_KEY, nonce);

if (generateOriginHeader && !headers.contains(HttpHeaderNames.ORIGIN)) {
headers.set(HttpHeaderNames.ORIGIN, websocketHostValue(wsURL));
}

sentNonce = nonce;
String expectedSubprotocol = expectedSubprotocol();
if (!StringUtil.isNullOrEmpty(expectedSubprotocol)) {
Expand Down
Expand Up @@ -193,4 +193,48 @@ public static WebSocketClientHandshaker newHandshaker(

throw new WebSocketClientHandshakeException("Protocol version " + version + " not supported.");
}

/**
* Creates a new handshaker.
*
* @param webSocketURL
* URL for web socket communications. e.g "ws://myhost.com/mypath".
* Subsequent web socket frames will be sent to this URL.
* @param version
* Version of web socket specification to use to connect to the server
* @param subprotocol
* Sub protocol request sent to the server. Null if no sub-protocol support is required.
* @param allowExtensions
* Allow extensions to be used in the reserved bits of the web socket frame
* @param customHeaders
* Custom HTTP headers to send during the handshake
* @param maxFramePayloadLength
* Maximum allowable frame payload length. Setting this value to your application's
* requirement may reduce denial of service attacks using long data frames.
* @param performMasking
* Whether to mask all written websocket frames. This must be set to true in order to be fully compatible
* with the websocket specifications. Client applications that communicate with a non-standard server
* which doesn't require masking might set this to false to achieve a higher performance.
* @param allowMaskMismatch
* When set to true, frames which are not masked properly according to the standard will still be
* accepted.
* @param forceCloseTimeoutMillis
* Close the connection if it was not closed by the server after timeout specified
* @param absoluteUpgradeUrl
* Use an absolute url for the Upgrade request, typically when connecting through an HTTP proxy over
* clear HTTP
* @param generateOriginHeader
* Allows to generate the `Origin`|`Sec-WebSocket-Origin` header value for handshake request
* according to the given webSocketURL
*/
public static WebSocketClientHandshaker newHandshaker(
URI webSocketURL, WebSocketVersion version, String subprotocol,
boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength,
boolean performMasking, boolean allowMaskMismatch, long forceCloseTimeoutMillis,
boolean absoluteUpgradeUrl, boolean generateOriginHeader) {
return new WebSocketClientHandshaker13(
webSocketURL, subprotocol, allowExtensions, customHeaders,
maxFramePayloadLength, performMasking, allowMaskMismatch, forceCloseTimeoutMillis,
absoluteUpgradeUrl, generateOriginHeader);
}
}
Expand Up @@ -32,6 +32,7 @@ public final class WebSocketClientProtocolConfig {
static final boolean DEFAULT_ALLOW_MASK_MISMATCH = false;
static final boolean DEFAULT_HANDLE_CLOSE_FRAMES = true;
static final boolean DEFAULT_DROP_PONG_FRAMES = true;
static final boolean DEFAULT_GENERATE_ORIGIN_HEADER = true;

private final URI webSocketUri;
private final String subprotocol;
Expand All @@ -47,6 +48,7 @@ public final class WebSocketClientProtocolConfig {
private final long handshakeTimeoutMillis;
private final long forceCloseTimeoutMillis;
private final boolean absoluteUpgradeUrl;
private final boolean generateOriginHeader;

private WebSocketClientProtocolConfig(
URI webSocketUri,
Expand All @@ -62,7 +64,8 @@ private WebSocketClientProtocolConfig(
boolean dropPongFrames,
long handshakeTimeoutMillis,
long forceCloseTimeoutMillis,
boolean absoluteUpgradeUrl
boolean absoluteUpgradeUrl,
boolean generateOriginHeader
) {
this.webSocketUri = webSocketUri;
this.subprotocol = subprotocol;
Expand All @@ -78,6 +81,7 @@ private WebSocketClientProtocolConfig(
this.dropPongFrames = dropPongFrames;
this.handshakeTimeoutMillis = checkPositive(handshakeTimeoutMillis, "handshakeTimeoutMillis");
this.absoluteUpgradeUrl = absoluteUpgradeUrl;
this.generateOriginHeader = generateOriginHeader;
}

public URI webSocketUri() {
Expand Down Expand Up @@ -136,24 +140,29 @@ public boolean absoluteUpgradeUrl() {
return absoluteUpgradeUrl;
}

public boolean generateOriginHeader() {
return generateOriginHeader;
}

@Override
public String toString() {
return "WebSocketClientProtocolConfig" +
" {webSocketUri=" + webSocketUri +
", subprotocol=" + subprotocol +
", version=" + version +
", allowExtensions=" + allowExtensions +
", customHeaders=" + customHeaders +
", maxFramePayloadLength=" + maxFramePayloadLength +
", performMasking=" + performMasking +
", allowMaskMismatch=" + allowMaskMismatch +
", handleCloseFrames=" + handleCloseFrames +
", sendCloseFrame=" + sendCloseFrame +
", dropPongFrames=" + dropPongFrames +
", handshakeTimeoutMillis=" + handshakeTimeoutMillis +
", forceCloseTimeoutMillis=" + forceCloseTimeoutMillis +
", absoluteUpgradeUrl=" + absoluteUpgradeUrl +
"}";
" {webSocketUri=" + webSocketUri +
", subprotocol=" + subprotocol +
", version=" + version +
", allowExtensions=" + allowExtensions +
", customHeaders=" + customHeaders +
", maxFramePayloadLength=" + maxFramePayloadLength +
", performMasking=" + performMasking +
", allowMaskMismatch=" + allowMaskMismatch +
", handleCloseFrames=" + handleCloseFrames +
", sendCloseFrame=" + sendCloseFrame +
", dropPongFrames=" + dropPongFrames +
", handshakeTimeoutMillis=" + handshakeTimeoutMillis +
", forceCloseTimeoutMillis=" + forceCloseTimeoutMillis +
", absoluteUpgradeUrl=" + absoluteUpgradeUrl +
", generateOriginHeader=" + generateOriginHeader +
"}";
}

public Builder toBuilder() {
Expand All @@ -175,7 +184,8 @@ public static Builder newBuilder() {
DEFAULT_DROP_PONG_FRAMES,
DEFAULT_HANDSHAKE_TIMEOUT_MILLIS,
-1,
false);
false,
DEFAULT_GENERATE_ORIGIN_HEADER);
}

public static final class Builder {
Expand All @@ -193,6 +203,7 @@ public static final class Builder {
private long handshakeTimeoutMillis;
private long forceCloseTimeoutMillis;
private boolean absoluteUpgradeUrl;
private boolean generateOriginHeader;

private Builder(WebSocketClientProtocolConfig clientConfig) {
this(Objects.requireNonNull(clientConfig, "clientConfig").webSocketUri(),
Expand All @@ -208,7 +219,8 @@ private Builder(WebSocketClientProtocolConfig clientConfig) {
clientConfig.dropPongFrames(),
clientConfig.handshakeTimeoutMillis(),
clientConfig.forceCloseTimeoutMillis(),
clientConfig.absoluteUpgradeUrl());
clientConfig.absoluteUpgradeUrl(),
clientConfig.generateOriginHeader());
}

private Builder(URI webSocketUri,
Expand All @@ -224,7 +236,8 @@ private Builder(URI webSocketUri,
boolean dropPongFrames,
long handshakeTimeoutMillis,
long forceCloseTimeoutMillis,
boolean absoluteUpgradeUrl) {
boolean absoluteUpgradeUrl,
boolean generateOriginHeader) {
this.webSocketUri = webSocketUri;
this.subprotocol = subprotocol;
this.version = version;
Expand All @@ -239,6 +252,7 @@ private Builder(URI webSocketUri,
this.handshakeTimeoutMillis = handshakeTimeoutMillis;
this.forceCloseTimeoutMillis = forceCloseTimeoutMillis;
this.absoluteUpgradeUrl = absoluteUpgradeUrl;
this.generateOriginHeader = generateOriginHeader;
}

/**
Expand Down Expand Up @@ -365,6 +379,16 @@ public Builder absoluteUpgradeUrl(boolean absoluteUpgradeUrl) {
return this;
}

/**
* Allows to generate the `Origin`|`Sec-WebSocket-Origin` header value for handshake request
* according the given webSocketURI. Usually it's not necessary and can be disabled,
* but for backward compatibility is set to {@code true} as default.
*/
public Builder generateOriginHeader(boolean generateOriginHeader) {
this.generateOriginHeader = generateOriginHeader;
return this;
}

/**
* Build unmodifiable client protocol configuration.
*/
Expand All @@ -383,7 +407,8 @@ public WebSocketClientProtocolConfig build() {
dropPongFrames,
handshakeTimeoutMillis,
forceCloseTimeoutMillis,
absoluteUpgradeUrl
absoluteUpgradeUrl,
generateOriginHeader
);
}
}
Expand Down
Expand Up @@ -74,7 +74,8 @@ public WebSocketClientProtocolHandler(WebSocketClientProtocolConfig clientConfig
clientConfig.performMasking(),
clientConfig.allowMaskMismatch(),
clientConfig.forceCloseTimeoutMillis(),
clientConfig.absoluteUpgradeUrl()
clientConfig.absoluteUpgradeUrl(),
clientConfig.generateOriginHeader()
);
this.clientConfig = clientConfig;
}
Expand Down

0 comments on commit f3c27ae

Please sign in to comment.