diff --git a/codec-http/src/main/java/io/netty5/handler/codec/http/websocketx/WebSocketClientHandshaker.java b/codec-http/src/main/java/io/netty5/handler/codec/http/websocketx/WebSocketClientHandshaker.java index e70a880be89..efc1101b36e 100644 --- a/codec-http/src/main/java/io/netty5/handler/codec/http/websocketx/WebSocketClientHandshaker.java +++ b/codec-http/src/main/java/io/netty5/handler/codec/http/websocketx/WebSocketClientHandshaker.java @@ -79,6 +79,8 @@ public abstract class WebSocketClientHandshaker { private final boolean absoluteUpgradeUrl; + protected final boolean generateOriginHeader; + /** * Base constructor * @@ -145,6 +147,36 @@ protected WebSocketClientHandshaker(URI uri, WebSocketVersion version, String su protected WebSocketClientHandshaker(URI uri, WebSocketVersion version, String subprotocol, HttpHeaders customHeaders, int maxFramePayloadLength, long forceCloseTimeoutMillis, boolean absoluteUpgradeUrl) { + this(uri, version, subprotocol, customHeaders, maxFramePayloadLength, forceCloseTimeoutMillis, + absoluteUpgradeUrl, true); + } + + /** + * Base constructor + * + * @param uri + * URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be + * sent to this URL. + * @param version + * Version of web socket specification to use to connect to the server + * @param subprotocol + * Sub protocol request sent to the server. + * @param customHeaders + * Map of custom headers to add to the client request + * @param maxFramePayloadLength + * Maximum length of a frame's payload + * @param forceCloseTimeoutMillis + * Close the connection if it was not closed by the server after timeout specified + * @param absoluteUpgradeUrl + * Use an absolute url for the Upgrade request, typically when connecting through an HTTP proxy over + * clear HTTP + * @param generateOriginHeader + * Allows to generate the `Origin`|`Sec-WebSocket-Origin` header value for handshake request + * according to the given webSocketURL + */ + protected WebSocketClientHandshaker(URI uri, WebSocketVersion version, String subprotocol, + HttpHeaders customHeaders, int maxFramePayloadLength, + long forceCloseTimeoutMillis, boolean absoluteUpgradeUrl, boolean generateOriginHeader) { this.uri = uri; this.version = version; expectedSubprotocol = subprotocol; @@ -152,6 +184,7 @@ protected WebSocketClientHandshaker(URI uri, WebSocketVersion version, String su this.maxFramePayloadLength = maxFramePayloadLength; this.forceCloseTimeoutMillis = forceCloseTimeoutMillis; this.absoluteUpgradeUrl = absoluteUpgradeUrl; + this.generateOriginHeader = generateOriginHeader; } /** @@ -247,6 +280,21 @@ public Future handshake(Channel channel) { } } + if (uri.getHost() == null) { + if (customHeaders == null || !customHeaders.contains(HttpHeaderNames.HOST)) { + return channel.newFailedFuture(new IllegalArgumentException("Cannot generate the 'host' header value," + + " webSocketURI should contain host or passed through customHeaders")); + } + + if (generateOriginHeader && !customHeaders.contains(HttpHeaderNames.ORIGIN)) { + final String originName = HttpHeaderNames.ORIGIN.toString(); + return channel.newFailedFuture( + new IllegalArgumentException("Cannot generate the '" + originName + "' header" + + " value, webSocketURI should contain host or disable generateOriginHeader or pass value" + + " through customHeaders")); + } + } + FullHttpRequest request = newHandshakeRequest(channel.bufferAllocator()); Promise promise = channel.newPromise(); diff --git a/codec-http/src/main/java/io/netty5/handler/codec/http/websocketx/WebSocketClientHandshaker13.java b/codec-http/src/main/java/io/netty5/handler/codec/http/websocketx/WebSocketClientHandshaker13.java index 04128d3f675..4550097e9a4 100644 --- a/codec-http/src/main/java/io/netty5/handler/codec/http/websocketx/WebSocketClientHandshaker13.java +++ b/codec-http/src/main/java/io/netty5/handler/codec/http/websocketx/WebSocketClientHandshaker13.java @@ -153,12 +153,53 @@ public WebSocketClientHandshaker13(URI webSocketURL, String subprotocol, * Use an absolute url for the Upgrade request, typically when connecting through an HTTP proxy over * clear HTTP */ + WebSocketClientHandshaker13(URI webSocketURL, String subprotocol, + boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength, + boolean performMasking, boolean allowMaskMismatch, + long forceCloseTimeoutMillis, boolean absoluteUpgradeUrl) { + this(webSocketURL, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength, performMasking, + allowMaskMismatch, forceCloseTimeoutMillis, absoluteUpgradeUrl, true); + } + + /** + * Creates a new instance. + * + * @param webSocketURL + * URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be + * sent to this URL. + * @param version + * Version of web socket specification to use to connect to the server + * @param subprotocol + * Sub protocol request sent to the server. + * @param allowExtensions + * Allow extensions to be used in the reserved bits of the web socket frame + * @param customHeaders + * Map of custom headers to add to the client request + * @param maxFramePayloadLength + * Maximum length of a frame's payload + * @param performMasking + * Whether to mask all written websocket frames. This must be set to true in order to be fully compatible + * with the websocket specifications. Client applications that communicate with a non-standard server + * which doesn't require masking might set this to false to achieve a higher performance. + * @param allowMaskMismatch + * When set to true, frames which are not masked properly according to the standard will still be + * accepted + * @param forceCloseTimeoutMillis + * Close the connection if it was not closed by the server after timeout specified. + * @param absoluteUpgradeUrl + * Use an absolute url for the Upgrade request, typically when connecting through an HTTP proxy over + * clear HTTP + * @param generateOriginHeader + * Allows to generate the `Origin` header value for handshake request + * according to the given webSocketURL + */ WebSocketClientHandshaker13(URI webSocketURL, String subprotocol, boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength, boolean performMasking, boolean allowMaskMismatch, - long forceCloseTimeoutMillis, boolean absoluteUpgradeUrl) { + long forceCloseTimeoutMillis, boolean absoluteUpgradeUrl, + boolean generateOriginHeader) { super(webSocketURL, WebSocketVersion.V13, subprotocol, customHeaders, maxFramePayloadLength, - forceCloseTimeoutMillis, absoluteUpgradeUrl); + forceCloseTimeoutMillis, absoluteUpgradeUrl, generateOriginHeader); this.allowExtensions = allowExtensions; this.performMasking = performMasking; this.allowMaskMismatch = allowMaskMismatch; @@ -204,6 +245,10 @@ protected FullHttpRequest newHandshakeRequest(BufferAllocator allocator) { .set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE) .set(HttpHeaderNames.SEC_WEBSOCKET_KEY, nonce); + if (generateOriginHeader && !headers.contains(HttpHeaderNames.ORIGIN)) { + headers.set(HttpHeaderNames.ORIGIN, websocketHostValue(wsURL)); + } + sentNonce = nonce; String expectedSubprotocol = expectedSubprotocol(); if (!StringUtil.isNullOrEmpty(expectedSubprotocol)) { diff --git a/codec-http/src/main/java/io/netty5/handler/codec/http/websocketx/WebSocketClientHandshakerFactory.java b/codec-http/src/main/java/io/netty5/handler/codec/http/websocketx/WebSocketClientHandshakerFactory.java index 72767fea180..ee177811702 100644 --- a/codec-http/src/main/java/io/netty5/handler/codec/http/websocketx/WebSocketClientHandshakerFactory.java +++ b/codec-http/src/main/java/io/netty5/handler/codec/http/websocketx/WebSocketClientHandshakerFactory.java @@ -193,4 +193,48 @@ public static WebSocketClientHandshaker newHandshaker( throw new WebSocketClientHandshakeException("Protocol version " + version + " not supported."); } + + /** + * Creates a new handshaker. + * + * @param webSocketURL + * URL for web socket communications. e.g "ws://myhost.com/mypath". + * Subsequent web socket frames will be sent to this URL. + * @param version + * Version of web socket specification to use to connect to the server + * @param subprotocol + * Sub protocol request sent to the server. Null if no sub-protocol support is required. + * @param allowExtensions + * Allow extensions to be used in the reserved bits of the web socket frame + * @param customHeaders + * Custom HTTP headers to send during the handshake + * @param maxFramePayloadLength + * Maximum allowable frame payload length. Setting this value to your application's + * requirement may reduce denial of service attacks using long data frames. + * @param performMasking + * Whether to mask all written websocket frames. This must be set to true in order to be fully compatible + * with the websocket specifications. Client applications that communicate with a non-standard server + * which doesn't require masking might set this to false to achieve a higher performance. + * @param allowMaskMismatch + * When set to true, frames which are not masked properly according to the standard will still be + * accepted. + * @param forceCloseTimeoutMillis + * Close the connection if it was not closed by the server after timeout specified + * @param absoluteUpgradeUrl + * Use an absolute url for the Upgrade request, typically when connecting through an HTTP proxy over + * clear HTTP + * @param generateOriginHeader + * Allows to generate the `Origin`|`Sec-WebSocket-Origin` header value for handshake request + * according to the given webSocketURL + */ + public static WebSocketClientHandshaker newHandshaker( + URI webSocketURL, WebSocketVersion version, String subprotocol, + boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength, + boolean performMasking, boolean allowMaskMismatch, long forceCloseTimeoutMillis, + boolean absoluteUpgradeUrl, boolean generateOriginHeader) { + return new WebSocketClientHandshaker13( + webSocketURL, subprotocol, allowExtensions, customHeaders, + maxFramePayloadLength, performMasking, allowMaskMismatch, forceCloseTimeoutMillis, + absoluteUpgradeUrl, generateOriginHeader); + } } diff --git a/codec-http/src/main/java/io/netty5/handler/codec/http/websocketx/WebSocketClientProtocolConfig.java b/codec-http/src/main/java/io/netty5/handler/codec/http/websocketx/WebSocketClientProtocolConfig.java index 4802e394259..295f1d98128 100644 --- a/codec-http/src/main/java/io/netty5/handler/codec/http/websocketx/WebSocketClientProtocolConfig.java +++ b/codec-http/src/main/java/io/netty5/handler/codec/http/websocketx/WebSocketClientProtocolConfig.java @@ -32,6 +32,7 @@ public final class WebSocketClientProtocolConfig { static final boolean DEFAULT_ALLOW_MASK_MISMATCH = false; static final boolean DEFAULT_HANDLE_CLOSE_FRAMES = true; static final boolean DEFAULT_DROP_PONG_FRAMES = true; + static final boolean DEFAULT_GENERATE_ORIGIN_HEADER = true; private final URI webSocketUri; private final String subprotocol; @@ -47,6 +48,7 @@ public final class WebSocketClientProtocolConfig { private final long handshakeTimeoutMillis; private final long forceCloseTimeoutMillis; private final boolean absoluteUpgradeUrl; + private final boolean generateOriginHeader; private WebSocketClientProtocolConfig( URI webSocketUri, @@ -62,7 +64,8 @@ private WebSocketClientProtocolConfig( boolean dropPongFrames, long handshakeTimeoutMillis, long forceCloseTimeoutMillis, - boolean absoluteUpgradeUrl + boolean absoluteUpgradeUrl, + boolean generateOriginHeader ) { this.webSocketUri = webSocketUri; this.subprotocol = subprotocol; @@ -78,6 +81,7 @@ private WebSocketClientProtocolConfig( this.dropPongFrames = dropPongFrames; this.handshakeTimeoutMillis = checkPositive(handshakeTimeoutMillis, "handshakeTimeoutMillis"); this.absoluteUpgradeUrl = absoluteUpgradeUrl; + this.generateOriginHeader = generateOriginHeader; } public URI webSocketUri() { @@ -136,24 +140,29 @@ public boolean absoluteUpgradeUrl() { return absoluteUpgradeUrl; } + public boolean generateOriginHeader() { + return generateOriginHeader; + } + @Override public String toString() { return "WebSocketClientProtocolConfig" + - " {webSocketUri=" + webSocketUri + - ", subprotocol=" + subprotocol + - ", version=" + version + - ", allowExtensions=" + allowExtensions + - ", customHeaders=" + customHeaders + - ", maxFramePayloadLength=" + maxFramePayloadLength + - ", performMasking=" + performMasking + - ", allowMaskMismatch=" + allowMaskMismatch + - ", handleCloseFrames=" + handleCloseFrames + - ", sendCloseFrame=" + sendCloseFrame + - ", dropPongFrames=" + dropPongFrames + - ", handshakeTimeoutMillis=" + handshakeTimeoutMillis + - ", forceCloseTimeoutMillis=" + forceCloseTimeoutMillis + - ", absoluteUpgradeUrl=" + absoluteUpgradeUrl + - "}"; + " {webSocketUri=" + webSocketUri + + ", subprotocol=" + subprotocol + + ", version=" + version + + ", allowExtensions=" + allowExtensions + + ", customHeaders=" + customHeaders + + ", maxFramePayloadLength=" + maxFramePayloadLength + + ", performMasking=" + performMasking + + ", allowMaskMismatch=" + allowMaskMismatch + + ", handleCloseFrames=" + handleCloseFrames + + ", sendCloseFrame=" + sendCloseFrame + + ", dropPongFrames=" + dropPongFrames + + ", handshakeTimeoutMillis=" + handshakeTimeoutMillis + + ", forceCloseTimeoutMillis=" + forceCloseTimeoutMillis + + ", absoluteUpgradeUrl=" + absoluteUpgradeUrl + + ", generateOriginHeader=" + generateOriginHeader + + "}"; } public Builder toBuilder() { @@ -175,7 +184,8 @@ public static Builder newBuilder() { DEFAULT_DROP_PONG_FRAMES, DEFAULT_HANDSHAKE_TIMEOUT_MILLIS, -1, - false); + false, + DEFAULT_GENERATE_ORIGIN_HEADER); } public static final class Builder { @@ -193,6 +203,7 @@ public static final class Builder { private long handshakeTimeoutMillis; private long forceCloseTimeoutMillis; private boolean absoluteUpgradeUrl; + private boolean generateOriginHeader; private Builder(WebSocketClientProtocolConfig clientConfig) { this(Objects.requireNonNull(clientConfig, "clientConfig").webSocketUri(), @@ -208,7 +219,8 @@ private Builder(WebSocketClientProtocolConfig clientConfig) { clientConfig.dropPongFrames(), clientConfig.handshakeTimeoutMillis(), clientConfig.forceCloseTimeoutMillis(), - clientConfig.absoluteUpgradeUrl()); + clientConfig.absoluteUpgradeUrl(), + clientConfig.generateOriginHeader()); } private Builder(URI webSocketUri, @@ -224,7 +236,8 @@ private Builder(URI webSocketUri, boolean dropPongFrames, long handshakeTimeoutMillis, long forceCloseTimeoutMillis, - boolean absoluteUpgradeUrl) { + boolean absoluteUpgradeUrl, + boolean generateOriginHeader) { this.webSocketUri = webSocketUri; this.subprotocol = subprotocol; this.version = version; @@ -239,6 +252,7 @@ private Builder(URI webSocketUri, this.handshakeTimeoutMillis = handshakeTimeoutMillis; this.forceCloseTimeoutMillis = forceCloseTimeoutMillis; this.absoluteUpgradeUrl = absoluteUpgradeUrl; + this.generateOriginHeader = generateOriginHeader; } /** @@ -365,6 +379,16 @@ public Builder absoluteUpgradeUrl(boolean absoluteUpgradeUrl) { return this; } + /** + * Allows to generate the `Origin`|`Sec-WebSocket-Origin` header value for handshake request + * according the given webSocketURI. Usually it's not necessary and can be disabled, + * but for backward compatibility is set to {@code true} as default. + */ + public Builder generateOriginHeader(boolean generateOriginHeader) { + this.generateOriginHeader = generateOriginHeader; + return this; + } + /** * Build unmodifiable client protocol configuration. */ @@ -383,7 +407,8 @@ public WebSocketClientProtocolConfig build() { dropPongFrames, handshakeTimeoutMillis, forceCloseTimeoutMillis, - absoluteUpgradeUrl + absoluteUpgradeUrl, + generateOriginHeader ); } } diff --git a/codec-http/src/main/java/io/netty5/handler/codec/http/websocketx/WebSocketClientProtocolHandler.java b/codec-http/src/main/java/io/netty5/handler/codec/http/websocketx/WebSocketClientProtocolHandler.java index 1a8f9d9e683..a45ec98eb5f 100644 --- a/codec-http/src/main/java/io/netty5/handler/codec/http/websocketx/WebSocketClientProtocolHandler.java +++ b/codec-http/src/main/java/io/netty5/handler/codec/http/websocketx/WebSocketClientProtocolHandler.java @@ -74,7 +74,8 @@ public WebSocketClientProtocolHandler(WebSocketClientProtocolConfig clientConfig clientConfig.performMasking(), clientConfig.allowMaskMismatch(), clientConfig.forceCloseTimeoutMillis(), - clientConfig.absoluteUpgradeUrl() + clientConfig.absoluteUpgradeUrl(), + clientConfig.generateOriginHeader() ); this.clientConfig = clientConfig; } diff --git a/codec-http/src/test/java/io/netty5/handler/codec/http/websocketx/WebSocketClientHandshaker13Test.java b/codec-http/src/test/java/io/netty5/handler/codec/http/websocketx/WebSocketClientHandshaker13Test.java index 99fdda6fe49..1cf8ba057a7 100644 --- a/codec-http/src/test/java/io/netty5/handler/codec/http/websocketx/WebSocketClientHandshaker13Test.java +++ b/codec-http/src/test/java/io/netty5/handler/codec/http/websocketx/WebSocketClientHandshaker13Test.java @@ -35,10 +35,15 @@ public class WebSocketClientHandshaker13Test extends WebSocketClientHandshakerTe @Override protected WebSocketClientHandshaker newHandshaker(URI uri, String subprotocol, HttpHeaders headers, - boolean absoluteUpgradeUrl) { + boolean absoluteUpgradeUrl, boolean generateOriginHeader) { return new WebSocketClientHandshaker13(uri, subprotocol, false, headers, 1024, true, true, 10000, - absoluteUpgradeUrl); + absoluteUpgradeUrl, generateOriginHeader); + } + + @Override + protected CharSequence getOriginHeaderName() { + return HttpHeaderNames.ORIGIN; } @Override @@ -60,7 +65,7 @@ protected CharSequence[] getHandshakeRequiredHeaderNames() { @Test void testWebSocketClientInvalidUpgrade() { var handshaker = newHandshaker(URI.create("ws://localhost:9999/ws"), null, - null, false); + null, false, true); var response = websocketUpgradeResponse(); response.headers().remove(HttpHeaderNames.UPGRADE); @@ -78,7 +83,7 @@ void testWebSocketClientInvalidUpgrade() { @Test void testWebSocketClientInvalidConnection() { var handshaker = newHandshaker(URI.create("ws://localhost:9999/ws"), null, - null, false); + null, false, true); var response = websocketUpgradeResponse(); response.headers().set(HttpHeaderNames.CONNECTION, "Close"); @@ -96,7 +101,7 @@ void testWebSocketClientInvalidConnection() { @Test void testWebSocketClientInvalidNullAccept() { var handshaker = newHandshaker(URI.create("ws://localhost:9999/ws"), null, - null, false); + null, false, true); var response = websocketUpgradeResponse(); final WebSocketClientHandshakeException exception; @@ -113,7 +118,7 @@ void testWebSocketClientInvalidNullAccept() { @Test void testWebSocketClientInvalidExpectedAccept() { var handshaker = newHandshaker(URI.create("ws://localhost:9999/ws"), null, - null, false); + null, false, true); final CharSequence sentNonce; try (var request = handshaker.newHandshakeRequest(preferredAllocator())) { sentNonce = request.headers().get(HttpHeaderNames.SEC_WEBSOCKET_KEY); diff --git a/codec-http/src/test/java/io/netty5/handler/codec/http/websocketx/WebSocketClientHandshakerTest.java b/codec-http/src/test/java/io/netty5/handler/codec/http/websocketx/WebSocketClientHandshakerTest.java index 27002decbfa..04d5529bcd6 100644 --- a/codec-http/src/test/java/io/netty5/handler/codec/http/websocketx/WebSocketClientHandshakerTest.java +++ b/codec-http/src/test/java/io/netty5/handler/codec/http/websocketx/WebSocketClientHandshakerTest.java @@ -33,6 +33,7 @@ import io.netty5.handler.codec.http.HttpResponseStatus; import io.netty5.handler.codec.http.HttpVersion; import io.netty5.handler.codec.http.headers.HttpHeaders; +import io.netty5.util.concurrent.Future; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; @@ -44,6 +45,8 @@ import static io.netty5.buffer.DefaultBufferAllocators.preferredAllocator; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; import static org.junit.jupiter.api.Assertions.assertNotEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNull; @@ -51,12 +54,15 @@ public abstract class WebSocketClientHandshakerTest { protected abstract WebSocketClientHandshaker newHandshaker(URI uri, String subprotocol, HttpHeaders headers, - boolean absoluteUpgradeUrl); + boolean absoluteUpgradeUrl, + boolean generateOriginHeader); protected WebSocketClientHandshaker newHandshaker(URI uri) { - return newHandshaker(uri, null, null, false); + return newHandshaker(uri, null, null, false, true); } + protected abstract CharSequence getOriginHeaderName(); + protected abstract CharSequence getProtocolHeaderName(); protected abstract CharSequence[] getHandshakeRequiredHeaderNames(); @@ -147,7 +153,7 @@ void testUpgradeUrlWithoutPathWithQuery() { @Test void testAbsoluteUpgradeUrlWithQuery() { URI uri = URI.create("ws://localhost:9999/path%20with%20ws?a=b%20c"); - WebSocketClientHandshaker handshaker = newHandshaker(uri, null, null, true); + WebSocketClientHandshaker handshaker = newHandshaker(uri, null, null, true, true); try (FullHttpRequest request = handshaker.newHandshakeRequest(preferredAllocator())) { assertEquals("ws://localhost:9999/path%20with%20ws?a=b%20c", request.uri()); } @@ -271,7 +277,7 @@ void testDuplicateWebsocketHandshakeHeaders() { inputHeaders.add(getProtocolHeaderName(), bogusSubProtocol); String realSubProtocol = "realSubProtocol"; - WebSocketClientHandshaker handshaker = newHandshaker(uri, realSubProtocol, inputHeaders, false); + WebSocketClientHandshaker handshaker = newHandshaker(uri, realSubProtocol, inputHeaders, false, true); FullHttpRequest request = handshaker.newHandshakeRequest(preferredAllocator()); HttpHeaders outputHeaders = request.headers(); @@ -293,7 +299,7 @@ void testSetHostHeaderIfNoPresentInCustomHeaders() { var customHeaders = HttpHeaders.newHeaders(); customHeaders.set(HttpHeaderNames.HOST, "custom-host"); var handshaker = newHandshaker(URI.create("ws://localhost:9999/ws"), null, - customHeaders, false); + customHeaders, false, true); try (var request = handshaker.newHandshakeRequest(preferredAllocator())) { assertEquals("custom-host", request.headers().get(HttpHeaderNames.HOST)); } @@ -302,7 +308,7 @@ void testSetHostHeaderIfNoPresentInCustomHeaders() { @Test void testNoOriginHeaderInHandshakeRequest() { var handshaker = newHandshaker(URI.create("ws://localhost:9999/ws"), null, - null, false); + null, false, false); try (var request = handshaker.newHandshakeRequest(preferredAllocator())) { assertNull(request.headers().get(HttpHeaderNames.ORIGIN)); } @@ -312,7 +318,7 @@ void testNoOriginHeaderInHandshakeRequest() { void testSetOriginFromCustomHeaders() { var customHeaders = HttpHeaders.newHeaders().set(HttpHeaderNames.ORIGIN, "http://example.com"); var handshaker = newHandshaker(URI.create("ws://server.example.com/chat"), null, - customHeaders, false); + customHeaders, false, true); try (var request = handshaker.newHandshakeRequest(preferredAllocator())) { assertEquals("http://example.com", request.headers().get(HttpHeaderNames.ORIGIN)); } @@ -321,7 +327,7 @@ void testSetOriginFromCustomHeaders() { @Test void testWebSocketClientHandshakeExceptionContainsResponse() { var handshaker = newHandshaker(URI.create("ws://localhost:9999/ws"), null, - null, false); + null, false, true); var response = new DefaultFullHttpResponse( HttpVersion.HTTP_1_1, HttpResponseStatus.UNAUTHORIZED, preferredAllocator().allocate(0)); response.headers().set(HttpHeaderNames.WWW_AUTHENTICATE, "realm = access token required"); @@ -339,6 +345,48 @@ void testWebSocketClientHandshakeExceptionContainsResponse() { "realm = access token required")); } + @Test + public void testOriginHeaderIsAbsentWhenGeneratingDisable() { + URI uri = URI.create("http://example.com/ws"); + WebSocketClientHandshaker handshaker = newHandshaker(uri, null, null, false, false); + + try (var request = handshaker.newHandshakeRequest(preferredAllocator())) { + assertFalse(request.headers().contains(getOriginHeaderName())); + assertEquals("/ws", request.uri()); + } + } + + @Test + public void testInvalidHostWhenIncorrectWebSocketURI() { + URI uri = URI.create("/ws"); + EmbeddedChannel channel = new EmbeddedChannel(new HttpClientCodec()); + final WebSocketClientHandshaker handshaker = newHandshaker(uri, null, null, false, true); + final Future handshakeFuture = handshaker.handshake(channel); + + assertFalse(handshakeFuture.isSuccess()); + assertInstanceOf(IllegalArgumentException.class, handshakeFuture.cause()); + assertEquals("Cannot generate the 'host' header value, webSocketURI should contain host" + + " or passed through customHeaders", handshakeFuture.cause().getMessage()); + assertFalse(channel.finish()); + } + + @Test + public void testInvalidOriginWhenIncorrectWebSocketURI() { + URI uri = URI.create("/ws"); + EmbeddedChannel channel = new EmbeddedChannel(new HttpClientCodec()); + HttpHeaders headers = HttpHeaders.newHeaders(); + headers.set(HttpHeaderNames.HOST, "localhost:80"); + final WebSocketClientHandshaker handshaker = newHandshaker(uri, null, headers, false, true); + final Future handshakeFuture = handshaker.handshake(channel); + + assertFalse(handshakeFuture.isSuccess()); + assertInstanceOf(IllegalArgumentException.class, handshakeFuture.cause()); + assertEquals("Cannot generate the '" + getOriginHeaderName() + "' header value," + + " webSocketURI should contain host or disable generateOriginHeader" + + " or pass value through customHeaders", handshakeFuture.cause().getMessage()); + assertFalse(channel.finish()); + } + private void testHostHeader(String uri, String expected) { var handshaker = newHandshaker(URI.create(uri)); try (var request = handshaker.newHandshakeRequest(preferredAllocator())) {