Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Force secure WebSocket connections to use http/1.1 #10754

Merged
merged 5 commits into from Apr 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -63,6 +63,12 @@ public final class HttpVersionSelection {
true
);

private static final HttpVersionSelection WEBSOCKET_1 = new HttpVersionSelection(
HttpVersionSelection.PlaintextMode.HTTP_1,
true,
new String[]{HttpVersionSelection.ALPN_HTTP_1},
false);

private final PlaintextMode plaintextMode;
private final boolean alpn;
private final String[] alpnSupportedProtocols;
Expand Down Expand Up @@ -100,6 +106,17 @@ public static HttpVersionSelection forLegacyVersion(@NonNull HttpVersion httpVer
}
}

/**
* Get the {@link HttpVersionSelection} to be used for a WebSocket connection, which will enable
* ALPN but constrain the mode to HTTP 1.1.
*
* @return The version selection for WebSocket
*/
@NonNull
public static HttpVersionSelection forWebsocket() {
return WEBSOCKET_1;
}

/**
* Construct a version selection from the given client configuration.
*
Expand Down
Expand Up @@ -144,6 +144,7 @@ public class ConnectionManager {
private final HttpClientConfiguration configuration;
private volatile SslContext sslContext;
private volatile /* QuicSslContext */ Object http3SslContext;
private volatile SslContext websocketSslContext;
private final NettyClientCustomizer clientCustomizer;
private final String informationalServiceId;

Expand All @@ -165,6 +166,7 @@ public class ConnectionManager {
this.configuration = from.configuration;
this.sslContext = from.sslContext;
this.http3SslContext = from.http3SslContext;
this.websocketSslContext = from.websocketSslContext;
this.clientCustomizer = from.clientCustomizer;
this.informationalServiceId = from.informationalServiceId;
this.nettyClientSslBuilder = from.nettyClientSslBuilder;
Expand Down Expand Up @@ -209,6 +211,8 @@ public class ConnectionManager {

final void refresh() {
SslContext oldSslContext = sslContext;
SslContext oldWebsocketSslContext = websocketSslContext;
websocketSslContext = null;
if (configuration.getSslConfiguration().isEnabled()) {
sslContext = nettyClientSslBuilder.build(configuration.getSslConfiguration(), httpVersion);
} else {
Expand All @@ -224,6 +228,7 @@ final void refresh() {
pool.forEachConnection(c -> ((Pool.ConnectionHolder) c).windDownConnection());
}
ReferenceCountUtil.release(oldSslContext);
ReferenceCountUtil.release(oldWebsocketSslContext);
}

/**
Expand Down Expand Up @@ -369,7 +374,9 @@ public final void shutdown() {
}
}
ReferenceCountUtil.release(sslContext);
ReferenceCountUtil.release(websocketSslContext);
sslContext = null;
websocketSslContext = null;
}

/**
Expand Down Expand Up @@ -432,6 +439,32 @@ public final Mono<PoolHandle> connect(DefaultHttpClient.RequestKey requestKey, @
return pools.computeIfAbsent(requestKey, Pool::new).acquire(blockHint);
}

/**
* Builds an {@link SslContext} for the given WebSocket URI if necessary.
*
* @return The {@link SslContext} instance
*/
@Nullable
private SslContext buildWebsocketSslContext(DefaultHttpClient.RequestKey requestKey) {
SslContext sslCtx = websocketSslContext;
if (requestKey.isSecure()) {
if (configuration.getSslConfiguration().isEnabled()) {
if (sslCtx == null) {
synchronized (this) {
sslCtx = websocketSslContext;
if (sslCtx == null) {
sslCtx = nettyClientSslBuilder.build(configuration.getSslConfiguration(), HttpVersionSelection.forWebsocket());
websocketSslContext = sslCtx;
}
}
}
} else if (configuration.getProxyAddress().isEmpty()){
throw decorate(new HttpClientException("Cannot send WSS request. SSL is disabled"));
}
}
return sslCtx;
}

/**
* Connect to a remote websocket. The given {@link ChannelHandler} is added to the pipeline
* when the handshakes complete.
Expand All @@ -448,7 +481,7 @@ final Mono<?> connectForWebsocket(DefaultHttpClient.RequestKey requestKey, Chann
protected void initChannel(@NonNull Channel ch) {
addLogHandler(ch);

SslContext sslContext = buildSslContext(requestKey);
SslContext sslContext = buildWebsocketSslContext(requestKey);
if (sslContext != null) {
ch.pipeline().addLast(configureSslHandler(sslContext.newHandler(ch.alloc(), requestKey.getHost(), requestKey.getPort())));
}
Expand Down
Expand Up @@ -11,6 +11,7 @@ import io.micronaut.websocket.exceptions.WebSocketClientException
import jakarta.inject.Inject
import jakarta.inject.Singleton
import reactor.core.publisher.Mono
import spock.lang.Issue
import spock.lang.Specification

import java.util.concurrent.ExecutionException
Expand Down Expand Up @@ -38,6 +39,47 @@ class ClientWebsocketSpec extends Specification {
client.close()
}

@Issue("https://github.com/micronaut-projects/micronaut-core/issues/10744")
void 'websocket bean can connect to echo server over SSL with wss scheme'() {
given:
def ctx = ApplicationContext.run(['spec.name': 'ClientWebsocketSpec'])
def client = ctx.getBean(WebSocketClient)
def registry = ctx.getBean(ClientBeanRegistry)
def mono = Mono.from(client.connect(ClientBean.class, 'wss://websocket-echo.com'))

when:
mono.toFuture().get()

then:
registry.clientBeans.size() == 1
registry.clientBeans[0].opened
!registry.clientBeans[0].autoClosed
!registry.clientBeans[0].onClosed

cleanup:
client.close()
}

void 'websocket bean can connect to echo server over SSL with https scheme'() {
given:
def ctx = ApplicationContext.run(['spec.name': 'ClientWebsocketSpec'])//, "micronaut.http.client.alpn-modes":"http/1.1"])
def client = ctx.getBean(WebSocketClient)
def registry = ctx.getBean(ClientBeanRegistry)
def mono = Mono.from(client.connect(ClientBean.class, 'https://websocket-echo.com'))

when:
mono.toFuture().get()

then:
registry.clientBeans.size() == 1
registry.clientBeans[0].opened
!registry.clientBeans[0].autoClosed
!registry.clientBeans[0].onClosed

cleanup:
client.close()
}

@Singleton
@Requires(property = 'spec.name', value = 'ClientWebsocketSpec')
static class ClientBeanRegistry {
Expand Down