Skip to content

Commit

Permalink
Force secure WebSocket connections to use http/1.1 (#10754)
Browse files Browse the repository at this point in the history
* Force secure WebSocket connections to use http/1.1

Connection manager is updated to use a separate SSLContext for
WebSocket connections that will only advertise http/1.1 in the list
of supported protocols in the ALPN section of the TLS handshake.

WebSocket is not currently supported over HTTP 2 connections, thus
if an HTTP 2 connection is established through ALPN, the subsequent
upgrade to the WebSocket protocol would fail.

This resolves #10744

* Enable ALPN with Websocket and ensure SSLContext released.

* Lazily create websocket SSL context per connection

* Use lazily intialized field for WebSocket SSL context.

* Actually initialize in the initializer.
  • Loading branch information
jeremyg484 committed Apr 25, 2024
1 parent 6a78f0c commit dc3e41f
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 1 deletion.
Expand Up @@ -63,6 +63,12 @@ public final class HttpVersionSelection {
true
);

private static final HttpVersionSelection WEBSOCKET_1 = new HttpVersionSelection(
HttpVersionSelection.PlaintextMode.HTTP_1,
true,
new String[]{HttpVersionSelection.ALPN_HTTP_1},
false);

private final PlaintextMode plaintextMode;
private final boolean alpn;
private final String[] alpnSupportedProtocols;
Expand Down Expand Up @@ -100,6 +106,17 @@ public static HttpVersionSelection forLegacyVersion(@NonNull HttpVersion httpVer
}
}

/**
* Get the {@link HttpVersionSelection} to be used for a WebSocket connection, which will enable
* ALPN but constrain the mode to HTTP 1.1.
*
* @return The version selection for WebSocket
*/
@NonNull
public static HttpVersionSelection forWebsocket() {
return WEBSOCKET_1;
}

/**
* Construct a version selection from the given client configuration.
*
Expand Down
Expand Up @@ -144,6 +144,7 @@ public class ConnectionManager {
private final HttpClientConfiguration configuration;
private volatile SslContext sslContext;
private volatile /* QuicSslContext */ Object http3SslContext;
private volatile SslContext websocketSslContext;
private final NettyClientCustomizer clientCustomizer;
private final String informationalServiceId;

Expand All @@ -165,6 +166,7 @@ public class ConnectionManager {
this.configuration = from.configuration;
this.sslContext = from.sslContext;
this.http3SslContext = from.http3SslContext;
this.websocketSslContext = from.websocketSslContext;
this.clientCustomizer = from.clientCustomizer;
this.informationalServiceId = from.informationalServiceId;
this.nettyClientSslBuilder = from.nettyClientSslBuilder;
Expand Down Expand Up @@ -209,6 +211,8 @@ public class ConnectionManager {

final void refresh() {
SslContext oldSslContext = sslContext;
SslContext oldWebsocketSslContext = websocketSslContext;
websocketSslContext = null;
if (configuration.getSslConfiguration().isEnabled()) {
sslContext = nettyClientSslBuilder.build(configuration.getSslConfiguration(), httpVersion);
} else {
Expand All @@ -224,6 +228,7 @@ final void refresh() {
pool.forEachConnection(c -> ((Pool.ConnectionHolder) c).windDownConnection());
}
ReferenceCountUtil.release(oldSslContext);
ReferenceCountUtil.release(oldWebsocketSslContext);
}

/**
Expand Down Expand Up @@ -369,7 +374,9 @@ public final void shutdown() {
}
}
ReferenceCountUtil.release(sslContext);
ReferenceCountUtil.release(websocketSslContext);
sslContext = null;
websocketSslContext = null;
}

/**
Expand Down Expand Up @@ -432,6 +439,32 @@ public final Mono<PoolHandle> connect(DefaultHttpClient.RequestKey requestKey, @
return pools.computeIfAbsent(requestKey, Pool::new).acquire(blockHint);
}

/**
* Builds an {@link SslContext} for the given WebSocket URI if necessary.
*
* @return The {@link SslContext} instance
*/
@Nullable
private SslContext buildWebsocketSslContext(DefaultHttpClient.RequestKey requestKey) {
SslContext sslCtx = websocketSslContext;
if (requestKey.isSecure()) {
if (configuration.getSslConfiguration().isEnabled()) {
if (sslCtx == null) {
synchronized (this) {
sslCtx = websocketSslContext;
if (sslCtx == null) {
sslCtx = nettyClientSslBuilder.build(configuration.getSslConfiguration(), HttpVersionSelection.forWebsocket());
websocketSslContext = sslCtx;
}
}
}
} else if (configuration.getProxyAddress().isEmpty()){
throw decorate(new HttpClientException("Cannot send WSS request. SSL is disabled"));
}
}
return sslCtx;
}

/**
* Connect to a remote websocket. The given {@link ChannelHandler} is added to the pipeline
* when the handshakes complete.
Expand All @@ -448,7 +481,7 @@ final Mono<?> connectForWebsocket(DefaultHttpClient.RequestKey requestKey, Chann
protected void initChannel(@NonNull Channel ch) {
addLogHandler(ch);

SslContext sslContext = buildSslContext(requestKey);
SslContext sslContext = buildWebsocketSslContext(requestKey);
if (sslContext != null) {
ch.pipeline().addLast(configureSslHandler(sslContext.newHandler(ch.alloc(), requestKey.getHost(), requestKey.getPort())));
}
Expand Down
Expand Up @@ -11,6 +11,7 @@ import io.micronaut.websocket.exceptions.WebSocketClientException
import jakarta.inject.Inject
import jakarta.inject.Singleton
import reactor.core.publisher.Mono
import spock.lang.Issue
import spock.lang.Specification

import java.util.concurrent.ExecutionException
Expand Down Expand Up @@ -38,6 +39,47 @@ class ClientWebsocketSpec extends Specification {
client.close()
}

@Issue("https://github.com/micronaut-projects/micronaut-core/issues/10744")
void 'websocket bean can connect to echo server over SSL with wss scheme'() {
given:
def ctx = ApplicationContext.run(['spec.name': 'ClientWebsocketSpec'])
def client = ctx.getBean(WebSocketClient)
def registry = ctx.getBean(ClientBeanRegistry)
def mono = Mono.from(client.connect(ClientBean.class, 'wss://websocket-echo.com'))

when:
mono.toFuture().get()

then:
registry.clientBeans.size() == 1
registry.clientBeans[0].opened
!registry.clientBeans[0].autoClosed
!registry.clientBeans[0].onClosed

cleanup:
client.close()
}

void 'websocket bean can connect to echo server over SSL with https scheme'() {
given:
def ctx = ApplicationContext.run(['spec.name': 'ClientWebsocketSpec'])//, "micronaut.http.client.alpn-modes":"http/1.1"])
def client = ctx.getBean(WebSocketClient)
def registry = ctx.getBean(ClientBeanRegistry)
def mono = Mono.from(client.connect(ClientBean.class, 'https://websocket-echo.com'))

when:
mono.toFuture().get()

then:
registry.clientBeans.size() == 1
registry.clientBeans[0].opened
!registry.clientBeans[0].autoClosed
!registry.clientBeans[0].onClosed

cleanup:
client.close()
}

@Singleton
@Requires(property = 'spec.name', value = 'ClientWebsocketSpec')
static class ClientBeanRegistry {
Expand Down

0 comments on commit dc3e41f

Please sign in to comment.