Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Provide a way to pass through a certain HTTP upgrade request #11267

Merged
merged 3 commits into from May 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -224,13 +224,24 @@ public HttpServerUpgradeHandler(SourceCodec sourceCodec, UpgradeCodecFactory upg
@Override
protected void decode(ChannelHandlerContext ctx, HttpObject msg, List<Object> out)
throws Exception {
// Determine if we're already handling an upgrade request or just starting a new one.
handlingUpgrade |= isUpgradeRequest(msg);

if (!handlingUpgrade) {
// Not handling an upgrade request, just pass it to the next handler.
ReferenceCountUtil.retain(msg);
out.add(msg);
return;
// Not handling an upgrade request yet. Check if we received a new upgrade request.
if (msg instanceof HttpRequest) {
HttpRequest req = (HttpRequest) msg;
if (req.headers().contains(HttpHeaderNames.UPGRADE) &&
shouldHandleUpgradeRequest(req)) {
handlingUpgrade = true;
} else {
ReferenceCountUtil.retain(msg);
ctx.fireChannelRead(msg);
return;
}
} else {
ReferenceCountUtil.retain(msg);
ctx.fireChannelRead(msg);
return;
}
}

FullHttpRequest fullRequest;
Expand Down Expand Up @@ -264,10 +275,20 @@ protected void decode(ChannelHandlerContext ctx, HttpObject msg, List<Object> ou
}

/**
* Determines whether or not the message is an HTTP upgrade request.
* Determines whether the specified upgrade {@link HttpRequest} should be handled by this handler or not.
* This method will be invoked only when the request contains an {@code Upgrade} header.
* It always returns {@code true} by default, which means any request with an {@code Upgrade} header
* will be handled. You can override this method to ignore certain {@code Upgrade} headers, for example:
* <pre>{@code
* @Override
* protected boolean isUpgradeRequest(HttpRequest req) {
* // Do not handle WebSocket upgrades.
* return !req.headers().contains(HttpHeaderNames.UPGRADE, "websocket", false);
* }
* }</pre>
*/
private static boolean isUpgradeRequest(HttpObject msg) {
return msg instanceof HttpRequest && ((HttpRequest) msg).headers().get(HttpHeaderNames.UPGRADE) != null;
protected boolean shouldHandleUpgradeRequest(HttpRequest req) {
return true;
}

/**
Expand Down
Expand Up @@ -141,4 +141,48 @@ public void operationComplete(ChannelFuture future) {
assertTrue(upgradeMessage.release());
assertFalse(channel.finishAndReleaseAll());
}

@Test
public void skippedUpgrade() {
final HttpServerCodec httpServerCodec = new HttpServerCodec();
final UpgradeCodecFactory factory = new UpgradeCodecFactory() {
@Override
public UpgradeCodec newUpgradeCodec(CharSequence protocol) {
fail("Should never be invoked");
return null;
}
};

HttpServerUpgradeHandler upgradeHandler = new HttpServerUpgradeHandler(httpServerCodec, factory) {
@Override
protected boolean shouldHandleUpgradeRequest(HttpRequest req) {
return !req.headers().contains(HttpHeaderNames.UPGRADE, "do-not-upgrade", false);
}
};

EmbeddedChannel channel = new EmbeddedChannel(httpServerCodec, upgradeHandler);

String upgradeString = "GET / HTTP/1.1\r\n" +
"Host: example.com\r\n" +
"Connection: Upgrade\r\n" +
"Upgrade: do-not-upgrade\r\n\r\n";
ByteBuf upgrade = Unpooled.copiedBuffer(upgradeString, CharsetUtil.US_ASCII);

// The upgrade request should not be passed to the next handler without any processing.
assertTrue(channel.writeInbound(upgrade));
assertNotNull(channel.pipeline().get(HttpServerCodec.class));
assertNull(channel.pipeline().get("marker"));

HttpRequest req = channel.readInbound();
assertFalse(req instanceof FullHttpRequest); // Should not be aggregated.
assertTrue(req.headers().contains(HttpHeaderNames.CONNECTION, "Upgrade", false));
assertTrue(req.headers().contains(HttpHeaderNames.UPGRADE, "do-not-upgrade", false));
assertTrue(channel.readInbound() instanceof LastHttpContent);
assertNull(channel.readInbound());

// No response should be written because we're just passing through.
channel.flushOutbound();
assertNull(channel.readOutbound());
assertFalse(channel.finishAndReleaseAll());
}
}