Skip to content

Commit

Permalink
Support handshake timeout in SniHandler. (#13041)
Browse files Browse the repository at this point in the history
Motivation: 
The SslHandler has a configurable timeout that is triggered when the handshake is not completed within that limit. The SniHandler should also have a similar timeout, such timeout would be fired when the the client has not sent enough data to trigger the SNI completion and propagated to the SslHandler it creates.

Modifications:
Added a handshake timeout in the SniHandler that creates a timer task to fire a failed sni completion event, as well as setting the handshake timeout on the SslHandler when it creates it.

Result:
Consequently, SniHandler supports handshake timeouts.

Co-authored-by: Norman Maurer <norman_maurer@apple.com>
  • Loading branch information
vietj and normanmaurer committed Dec 19, 2022
1 parent 0a68907 commit 0bcc6c8
Show file tree
Hide file tree
Showing 4 changed files with 178 additions and 1 deletion.
50 changes: 50 additions & 0 deletions handler/src/main/java/io/netty/handler/ssl/AbstractSniHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,12 @@
import io.netty.channel.ChannelHandlerContext;
import io.netty.util.CharsetUtil;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.ScheduledFuture;

import java.util.Locale;
import java.util.concurrent.TimeUnit;

import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero;

/**
* <p>Enables <a href="https://tools.ietf.org/html/rfc3546#section-3.1">SNI
Expand Down Expand Up @@ -117,8 +121,51 @@ private static String extractSniHostname(ByteBuf in) {
return null;
}

protected final long handshakeTimeoutMillis;
private ScheduledFuture<?> timeoutFuture;
private String hostname;

/**
* @param handshakeTimeoutMillis the handshake timeout in milliseconds
*/
protected AbstractSniHandler(long handshakeTimeoutMillis) {
this.handshakeTimeoutMillis = checkPositiveOrZero(handshakeTimeoutMillis, "handshakeTimeoutMillis");
}

public AbstractSniHandler() {
this(0L);
}

@Override
public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
if (ctx.channel().isActive()) {
checkStartTimeout(ctx);
}
}

@Override
public void channelActive(ChannelHandlerContext ctx) throws Exception {
ctx.fireChannelActive();
checkStartTimeout(ctx);
}

private void checkStartTimeout(final ChannelHandlerContext ctx) {
if (handshakeTimeoutMillis <= 0 || timeoutFuture != null) {
return;
}
timeoutFuture = ctx.executor().schedule(new Runnable() {
@Override
public void run() {
if (ctx.channel().isActive()) {
SslHandshakeTimeoutException exception = new SslHandshakeTimeoutException(
"handshake timed out after " + handshakeTimeoutMillis + "ms");
ctx.fireUserEventTriggered(new SniCompletionEvent(exception));
ctx.close();
}
}
}, handshakeTimeoutMillis, TimeUnit.MILLISECONDS);
}

@Override
protected Future<T> lookup(ChannelHandlerContext ctx, ByteBuf clientHello) throws Exception {
hostname = clientHello == null ? null : extractSniHostname(clientHello);
Expand All @@ -128,6 +175,9 @@ protected Future<T> lookup(ChannelHandlerContext ctx, ByteBuf clientHello) throw

@Override
protected void onLookupComplete(ChannelHandlerContext ctx, Future<T> future) throws Exception {
if (timeoutFuture != null) {
timeoutFuture.cancel(false);
}
try {
onLookupComplete(ctx, hostname, future);
} finally {
Expand Down
28 changes: 27 additions & 1 deletion handler/src/main/java/io/netty/handler/ssl/SniHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,17 @@ public SniHandler(Mapping<? super String, ? extends SslContext> mapping) {
this(new AsyncMappingAdapter(mapping));
}

/**
* Creates a SNI detection handler with configured {@link SslContext}
* maintained by {@link Mapping}
*
* @param mapping the mapping of domain name to {@link SslContext}
* @param handshakeTimeoutMillis the handshake timeout in milliseconds
*/
public SniHandler(Mapping<? super String, ? extends SslContext> mapping, long handshakeTimeoutMillis) {
this(new AsyncMappingAdapter(mapping), handshakeTimeoutMillis);
}

/**
* Creates a SNI detection handler with configured {@link SslContext}
* maintained by {@link DomainNameMapping}
Expand All @@ -69,6 +80,19 @@ public SniHandler(DomainNameMapping<? extends SslContext> mapping) {
*/
@SuppressWarnings("unchecked")
public SniHandler(AsyncMapping<? super String, ? extends SslContext> mapping) {
this(mapping, 0L);
}

/**
* Creates a SNI detection handler with configured {@link SslContext}
* maintained by {@link AsyncMapping}
*
* @param mapping the mapping of domain name to {@link SslContext}
* @param handshakeTimeoutMillis the handshake timeout in milliseconds
*/
@SuppressWarnings("unchecked")
public SniHandler(AsyncMapping<? super String, ? extends SslContext> mapping, long handshakeTimeoutMillis) {
super(handshakeTimeoutMillis);
this.mapping = (AsyncMapping<String, SslContext>) ObjectUtil.checkNotNull(mapping, "mapping");
}

Expand Down Expand Up @@ -148,7 +172,9 @@ protected void replaceHandler(ChannelHandlerContext ctx, String hostname, SslCon
* Users may override this method to implement custom behavior.
*/
protected SslHandler newSslHandler(SslContext context, ByteBufAllocator allocator) {
return context.newHandler(allocator);
SslHandler sslHandler = context.newHandler(allocator);
sslHandler.setHandshakeTimeoutMillis(handshakeTimeoutMillis);
return sslHandler;
}

private static final class AsyncMappingAdapter implements AsyncMapping<String, SslContext> {
Expand Down
95 changes: 95 additions & 0 deletions handler/src/test/java/io/netty/handler/ssl/SniHandlerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;

import javax.net.ssl.HandshakeCompletedEvent;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLException;

Expand Down Expand Up @@ -63,6 +65,7 @@
import io.netty.util.internal.ResourcesUtil;
import io.netty.util.internal.StringUtil;
import org.hamcrest.CoreMatchers;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.api.function.Executable;
import org.junit.jupiter.params.ParameterizedTest;
Expand Down Expand Up @@ -711,4 +714,96 @@ private static List<ByteBuf> split(ByteBuf clientHello, int maxSize) {
clientHello.release();
return result;
}

@Test
public void testSniHandlerFiresHandshakeTimeout() throws Exception {
SniHandler handler = new SniHandler(new Mapping<String, SslContext>() {
@Override
public SslContext map(String input) {
throw new UnsupportedOperationException("Should not be called");
}
}, 10);

final AtomicReference<SniCompletionEvent> completionEventRef =
new AtomicReference<SniCompletionEvent>();
EmbeddedChannel ch = new EmbeddedChannel(handler, new ChannelInboundHandlerAdapter() {
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
if (evt instanceof SniCompletionEvent) {
completionEventRef.set((SniCompletionEvent) evt);
}
}
});
try {
while (completionEventRef.get() == null) {
Thread.sleep(100);
// We need to run all pending tasks as the handshake timeout is scheduled on the EventLoop.
ch.runPendingTasks();
}
SniCompletionEvent completionEvent = completionEventRef.get();
assertNotNull(completionEvent);
assertNotNull(completionEvent.cause());
assertEquals(SslHandshakeTimeoutException.class, completionEvent.cause().getClass());
} finally {
ch.finishAndReleaseAll();
}
}

@ParameterizedTest(name = "{index}: sslProvider={0}")
@MethodSource("data")
public void testSslHandlerFiresHandshakeTimeout(SslProvider provider) throws Exception {
final SslContext context = makeSslContext(provider, false);
SniHandler handler = new SniHandler(new Mapping<String, SslContext>() {
@Override
public SslContext map(String input) {
return context;
}
}, 100);

final AtomicReference<SniCompletionEvent> sniCompletionEventRef =
new AtomicReference<SniCompletionEvent>();
final AtomicReference<SslHandshakeCompletionEvent> handshakeCompletionEventRef =
new AtomicReference<SslHandshakeCompletionEvent>();
EmbeddedChannel ch = new EmbeddedChannel(handler, new ChannelInboundHandlerAdapter() {
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
if (evt instanceof SniCompletionEvent) {
sniCompletionEventRef.set((SniCompletionEvent) evt);
} else if (evt instanceof SslHandshakeCompletionEvent) {
handshakeCompletionEventRef.set((SslHandshakeCompletionEvent) evt);
}
}
});
try {
// Send enough data to add the SslHandler and let the handshake incomplete
// Client Hello with "host1" server name
ch.writeInbound(Unpooled.wrappedBuffer(StringUtil.decodeHexDump(
"16030301800100017c0303478ae7e536aa7a9debad1f873121862d2d3d3173e0ef42975c31007faeb2" +
"52522047f55f81fc84fe58951e2af14026147d6178498fde551fcbafc636462c016ec9005a13011302" +
"c02cc02bc030009dc02ec032009f00a3c02f009cc02dc031009e00a2c024c028003dc026c02a006b00" +
"6ac00ac0140035c005c00f00390038c023c027003cc025c02900670040c009c013002fc004c00e0033" +
"003200ff010000d90000000a0008000005686f737431000500050100000000000a00160014001d0017" +
"00180019001e01000101010201030104000b00020100000d0028002604030503060308040805080608" +
"09080a080b040105010601040203030301030202030201020200320028002604030503060308040805" +
"08060809080a080b040105010601040203030301030202030201020200110009000702000400000000" +
"00170000002b00050403040303002d00020101003300260024001d00200bbc37375e214c1e4e7cb90f" +
"869e131dc983a21f8205ba24456177f340904935")));

while (handshakeCompletionEventRef.get() == null) {
Thread.sleep(10);
// We need to run all pending tasks as the handshake timeout is scheduled on the EventLoop.
ch.runPendingTasks();
}
SniCompletionEvent sniCompletionEvent = sniCompletionEventRef.get();
assertNotNull(sniCompletionEvent);
assertEquals("host1", sniCompletionEvent.hostname());
SslCompletionEvent handshakeCompletionEvent = handshakeCompletionEventRef.get();
assertNotNull(handshakeCompletionEvent);
assertNotNull(handshakeCompletionEvent.cause());
assertEquals(SslHandshakeTimeoutException.class, handshakeCompletionEvent.cause().getClass());
} finally {
ch.finishAndReleaseAll();
releaseAll(context);
}
}
}
6 changes: 6 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -1205,6 +1205,12 @@
<superClass>io.netty.handler.codec.http.multipart.AbstractMixedHttpData&lt;io.netty.handler.codec.http.multipart.Attribute&gt;</superClass>
<justification>Acceptable incompatibility for required change</justification>
</item>
<item>
<ignore>true</ignore>
<code>java.annotation.removed</code>
<annotation>@io.netty.channel.ChannelHandlerMask.Skip</annotation>
<justification>No change in compatibility</justification>
</item>
</differences>
</revapi.differences>
</analysisConfiguration>
Expand Down

0 comments on commit 0bcc6c8

Please sign in to comment.