Skip to content

Commit

Permalink
Set the ORIGIN header from a custom headers if present (#9435)
Browse files Browse the repository at this point in the history
Motivation:

Allow to set the ORIGIN header value from custom headers in WebSocketClientHandshaker

Modification:

Only override header if not present already

Result:

More flexible handshaker usage
  • Loading branch information
amizurov authored and normanmaurer committed Aug 11, 2019
1 parent fedcc40 commit b8ac02d
Show file tree
Hide file tree
Showing 7 changed files with 37 additions and 14 deletions.
Expand Up @@ -189,10 +189,13 @@ protected FullHttpRequest newHandshakeRequest() {
headers.set(HttpHeaderNames.UPGRADE, WEBSOCKET)
.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE)
.set(HttpHeaderNames.HOST, websocketHostValue(wsURL))
.set(HttpHeaderNames.ORIGIN, websocketOriginValue(wsURL))
.set(HttpHeaderNames.SEC_WEBSOCKET_KEY1, key1)
.set(HttpHeaderNames.SEC_WEBSOCKET_KEY2, key2);

if (!headers.contains(HttpHeaderNames.ORIGIN)) {
headers.set(HttpHeaderNames.ORIGIN, websocketOriginValue(wsURL));
}

String expectedSubprotocol = expectedSubprotocol();
if (expectedSubprotocol != null && !expectedSubprotocol.isEmpty()) {
headers.set(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, expectedSubprotocol);
Expand Down
Expand Up @@ -223,8 +223,11 @@ protected FullHttpRequest newHandshakeRequest() {
headers.set(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET)
.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE)
.set(HttpHeaderNames.SEC_WEBSOCKET_KEY, key)
.set(HttpHeaderNames.HOST, websocketHostValue(wsURL))
.set(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN, websocketOriginValue(wsURL));
.set(HttpHeaderNames.HOST, websocketHostValue(wsURL));

if (!headers.contains(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN)) {
headers.set(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN, websocketOriginValue(wsURL));
}

String expectedSubprotocol = expectedSubprotocol();
if (expectedSubprotocol != null && !expectedSubprotocol.isEmpty()) {
Expand Down
Expand Up @@ -225,8 +225,11 @@ protected FullHttpRequest newHandshakeRequest() {
headers.set(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET)
.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE)
.set(HttpHeaderNames.SEC_WEBSOCKET_KEY, key)
.set(HttpHeaderNames.HOST, websocketHostValue(wsURL))
.set(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN, websocketOriginValue(wsURL));
.set(HttpHeaderNames.HOST, websocketHostValue(wsURL));

if (!headers.contains(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN)) {
headers.set(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN, websocketOriginValue(wsURL));
}

String expectedSubprotocol = expectedSubprotocol();
if (expectedSubprotocol != null && !expectedSubprotocol.isEmpty()) {
Expand Down
Expand Up @@ -226,8 +226,11 @@ protected FullHttpRequest newHandshakeRequest() {
headers.set(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET)
.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE)
.set(HttpHeaderNames.SEC_WEBSOCKET_KEY, key)
.set(HttpHeaderNames.HOST, websocketHostValue(wsURL))
.set(HttpHeaderNames.ORIGIN, websocketOriginValue(wsURL));
.set(HttpHeaderNames.HOST, websocketHostValue(wsURL));

if (!headers.contains(HttpHeaderNames.ORIGIN)) {
headers.set(HttpHeaderNames.ORIGIN, websocketOriginValue(wsURL));
}

String expectedSubprotocol = expectedSubprotocol();
if (expectedSubprotocol != null && !expectedSubprotocol.isEmpty()) {
Expand Down
Expand Up @@ -39,12 +39,11 @@ protected CharSequence getProtocolHeaderName() {
}

@Override
protected CharSequence[] getHandshakeHeaderNames() {
protected CharSequence[] getHandshakeRequiredHeaderNames() {
return new CharSequence[] {
HttpHeaderNames.CONNECTION,
HttpHeaderNames.UPGRADE,
HttpHeaderNames.HOST,
HttpHeaderNames.ORIGIN,
HttpHeaderNames.SEC_WEBSOCKET_KEY1,
HttpHeaderNames.SEC_WEBSOCKET_KEY2,
};
Expand Down
Expand Up @@ -40,13 +40,12 @@ protected CharSequence getProtocolHeaderName() {
}

@Override
protected CharSequence[] getHandshakeHeaderNames() {
protected CharSequence[] getHandshakeRequiredHeaderNames() {
return new CharSequence[] {
HttpHeaderNames.UPGRADE,
HttpHeaderNames.CONNECTION,
HttpHeaderNames.SEC_WEBSOCKET_KEY,
HttpHeaderNames.HOST,
getOriginHeaderName(),
HttpHeaderNames.SEC_WEBSOCKET_VERSION,
};
}
Expand Down
Expand Up @@ -50,7 +50,7 @@ protected WebSocketClientHandshaker newHandshaker(URI uri) {

protected abstract CharSequence getProtocolHeaderName();

protected abstract CharSequence[] getHandshakeHeaderNames();
protected abstract CharSequence[] getHandshakeRequiredHeaderNames();

@Test
public void hostHeaderWs() {
Expand Down Expand Up @@ -160,6 +160,19 @@ public void originHeaderWithoutScheme() {
testOriginHeader("//LOCALHOST/", "http://localhost");
}

@Test
public void testSetOriginFromCustomHeaders() {
HttpHeaders customHeaders = new DefaultHttpHeaders().set(getOriginHeaderName(), "http://example.com");
WebSocketClientHandshaker handshaker = newHandshaker(URI.create("ws://server.example.com/chat"), null,
customHeaders, false);
FullHttpRequest request = handshaker.newHandshakeRequest();
try {
assertEquals("http://example.com", request.headers().get(getOriginHeaderName()));
} finally {
request.release();
}
}

private void testHostHeader(String uri, String expected) {
testHeaderDefaultHttp(uri, HttpHeaderNames.HOST, expected);
}
Expand Down Expand Up @@ -325,7 +338,7 @@ public void testDuplicateWebsocketHandshakeHeaders() {
String bogusHeaderValue = "bogusHeaderValue";

// add values for the headers that are reserved for use in the websockets handshake
for (CharSequence header : getHandshakeHeaderNames()) {
for (CharSequence header : getHandshakeRequiredHeaderNames()) {
inputHeaders.add(header, bogusHeaderValue);
}
inputHeaders.add(getProtocolHeaderName(), bogusSubProtocol);
Expand All @@ -336,7 +349,7 @@ public void testDuplicateWebsocketHandshakeHeaders() {
HttpHeaders outputHeaders = request.headers();

// the header values passed in originally have been replaced with values generated by the Handshaker
for (CharSequence header : getHandshakeHeaderNames()) {
for (CharSequence header : getHandshakeRequiredHeaderNames()) {
assertEquals(1, outputHeaders.getAll(header).size());
assertNotEquals(bogusHeaderValue, outputHeaders.get(header));
}
Expand Down

0 comments on commit b8ac02d

Please sign in to comment.