Skip to content

Commit

Permalink
Provide a way to pass through a certain HTTP upgrade request (netty#1…
Browse files Browse the repository at this point in the history
…1267)


Motivation:

A user might want to handle a certain HTTP upgrade request differently
than what `HttpServerUpgradeHandler` does by default. For example, a
user could let `HttpServerUpgradeHandler` handle HTTP/2 upgrades but
not WebSocket upgrades.

Modifications:

- Added `HttpServerUpgradeHandler.isUpgrade(HttpRequest)` so a user can
  tell `HttpServerUpgradeHandler` to pass the request as it is to the
  next handler.

Result:

- A user can handle a certain upgrade request specially.
  • Loading branch information
trustin authored and 夏无影 committed Jul 8, 2022
1 parent 5b5ad55 commit bf27b3f
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 9 deletions.
Expand Up @@ -224,13 +224,24 @@ public HttpServerUpgradeHandler(SourceCodec sourceCodec, UpgradeCodecFactory upg
@Override
protected void decode(ChannelHandlerContext ctx, HttpObject msg, List<Object> out)
throws Exception {
// Determine if we're already handling an upgrade request or just starting a new one.
handlingUpgrade |= isUpgradeRequest(msg);

if (!handlingUpgrade) {
// Not handling an upgrade request, just pass it to the next handler.
ReferenceCountUtil.retain(msg);
out.add(msg);
return;
// Not handling an upgrade request yet. Check if we received a new upgrade request.
if (msg instanceof HttpRequest) {
HttpRequest req = (HttpRequest) msg;
if (req.headers().contains(HttpHeaderNames.UPGRADE) &&
shouldHandleUpgradeRequest(req)) {
handlingUpgrade = true;
} else {
ReferenceCountUtil.retain(msg);
ctx.fireChannelRead(msg);
return;
}
} else {
ReferenceCountUtil.retain(msg);
ctx.fireChannelRead(msg);
return;
}
}

FullHttpRequest fullRequest;
Expand Down Expand Up @@ -264,10 +275,20 @@ protected void decode(ChannelHandlerContext ctx, HttpObject msg, List<Object> ou
}

/**
* Determines whether or not the message is an HTTP upgrade request.
* Determines whether the specified upgrade {@link HttpRequest} should be handled by this handler or not.
* This method will be invoked only when the request contains an {@code Upgrade} header.
* It always returns {@code true} by default, which means any request with an {@code Upgrade} header
* will be handled. You can override this method to ignore certain {@code Upgrade} headers, for example:
* <pre>{@code
* @Override
* protected boolean isUpgradeRequest(HttpRequest req) {
* // Do not handle WebSocket upgrades.
* return !req.headers().contains(HttpHeaderNames.UPGRADE, "websocket", false);
* }
* }</pre>
*/
private static boolean isUpgradeRequest(HttpObject msg) {
return msg instanceof HttpRequest && ((HttpRequest) msg).headers().get(HttpHeaderNames.UPGRADE) != null;
protected boolean shouldHandleUpgradeRequest(HttpRequest req) {
return true;
}

/**
Expand Down
Expand Up @@ -141,4 +141,48 @@ public void operationComplete(ChannelFuture future) {
assertTrue(upgradeMessage.release());
assertFalse(channel.finishAndReleaseAll());
}

@Test
public void skippedUpgrade() {
final HttpServerCodec httpServerCodec = new HttpServerCodec();
final UpgradeCodecFactory factory = new UpgradeCodecFactory() {
@Override
public UpgradeCodec newUpgradeCodec(CharSequence protocol) {
fail("Should never be invoked");
return null;
}
};

HttpServerUpgradeHandler upgradeHandler = new HttpServerUpgradeHandler(httpServerCodec, factory) {
@Override
protected boolean shouldHandleUpgradeRequest(HttpRequest req) {
return !req.headers().contains(HttpHeaderNames.UPGRADE, "do-not-upgrade", false);
}
};

EmbeddedChannel channel = new EmbeddedChannel(httpServerCodec, upgradeHandler);

String upgradeString = "GET / HTTP/1.1\r\n" +
"Host: example.com\r\n" +
"Connection: Upgrade\r\n" +
"Upgrade: do-not-upgrade\r\n\r\n";
ByteBuf upgrade = Unpooled.copiedBuffer(upgradeString, CharsetUtil.US_ASCII);

// The upgrade request should not be passed to the next handler without any processing.
assertTrue(channel.writeInbound(upgrade));
assertNotNull(channel.pipeline().get(HttpServerCodec.class));
assertNull(channel.pipeline().get("marker"));

HttpRequest req = channel.readInbound();
assertFalse(req instanceof FullHttpRequest); // Should not be aggregated.
assertTrue(req.headers().contains(HttpHeaderNames.CONNECTION, "Upgrade", false));
assertTrue(req.headers().contains(HttpHeaderNames.UPGRADE, "do-not-upgrade", false));
assertTrue(channel.readInbound() instanceof LastHttpContent);
assertNull(channel.readInbound());

// No response should be written because we're just passing through.
channel.flushOutbound();
assertNull(channel.readOutbound());
assertFalse(channel.finishAndReleaseAll());
}
}

0 comments on commit bf27b3f

Please sign in to comment.