Skip to content

Commit

Permalink
Forward early websocket client errors to client future, not the webso…
Browse files Browse the repository at this point in the history
…cket bean (#8300)

When the websocket is closed, but the session has not yet been initialized (and OnOpen not been called), change handleCloseReason to instead emit an error on the publisher that is returned by the `WebSocketClient.connect` method. This means the publisher won't get stuck, and the OnClose method won't be called without a session being available.

Additionally, I refactored `AbstractNettyWebSocketHandler` and the subclasses a bit. The bodyArgument/pongArgument code, and the old callOpenMethod, were only used for the server handler.

Hopefully fixes #7921
  • Loading branch information
yawkat committed Nov 23, 2022
1 parent 8d8d7c0 commit e45b5ea
Show file tree
Hide file tree
Showing 4 changed files with 215 additions and 171 deletions.
Expand Up @@ -16,34 +16,28 @@
package io.micronaut.http.client.netty.websocket;

import io.micronaut.core.annotation.Internal;
import io.micronaut.core.async.publisher.Publishers;
import io.micronaut.core.bind.BoundExecutable;
import io.micronaut.core.bind.DefaultExecutableBinder;
import io.micronaut.core.bind.ExecutableBinder;
import io.micronaut.core.convert.ConversionService;
import io.micronaut.core.convert.value.ConvertibleValues;
import io.micronaut.core.type.Argument;
import io.micronaut.http.MutableHttpRequest;
import io.micronaut.http.bind.DefaultRequestBinderRegistry;
import io.micronaut.http.bind.RequestBinderRegistry;
import io.micronaut.http.codec.MediaTypeCodecRegistry;
import io.micronaut.http.netty.websocket.AbstractNettyWebSocketHandler;
import io.micronaut.http.netty.websocket.NettyWebSocketSession;
import io.micronaut.http.uri.UriMatchInfo;
import io.micronaut.http.uri.UriMatchTemplate;
import io.micronaut.inject.MethodExecutionHandle;
import io.micronaut.websocket.CloseReason;
import io.micronaut.websocket.WebSocketPongMessage;
import io.micronaut.websocket.annotation.ClientWebSocket;
import io.micronaut.websocket.bind.WebSocketState;
import io.micronaut.websocket.bind.WebSocketStateBinderRegistry;
import io.micronaut.websocket.context.WebSocketBean;
import io.micronaut.websocket.exceptions.WebSocketClientException;
import io.micronaut.websocket.exceptions.WebSocketSessionException;
import io.micronaut.websocket.interceptor.WebSocketSessionAware;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
Expand All @@ -52,14 +46,12 @@
import io.netty.handler.ssl.SslHandler;
import io.netty.handler.timeout.IdleState;
import io.netty.handler.timeout.IdleStateEvent;
import org.reactivestreams.Publisher;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.publisher.Sinks;

import java.util.Collections;
import java.util.List;
import java.util.Optional;

/**
* Handler for WebSocket clients.
Expand All @@ -78,9 +70,7 @@ public class NettyWebSocketClientHandler<T> extends AbstractNettyWebSocketHandle
private final Sinks.One<T> completion = Sinks.one();
private final UriMatchInfo matchInfo;
private final MediaTypeCodecRegistry codecRegistry;
private ChannelPromise handshakeFuture;
private NettyWebSocketSession clientSession;
private final WebSocketStateBinderRegistry webSocketStateBinderRegistry;
private FullHttpResponse handshakeResponse;
private Argument<?> clientBodyArgument;
private Argument<?> clientPongArgument;
Expand All @@ -103,12 +93,9 @@ public NettyWebSocketClientHandler(
this.codecRegistry = mediaTypeCodecRegistry;
this.handshaker = handshaker;
this.genericWebSocketBean = webSocketBean;
this.webSocketStateBinderRegistry = new WebSocketStateBinderRegistry(requestBinderRegistry != null ? requestBinderRegistry : new DefaultRequestBinderRegistry(ConversionService.SHARED));
String clientPath = webSocketBean.getBeanDefinition().stringValue(ClientWebSocket.class).orElse("");
UriMatchTemplate matchTemplate = UriMatchTemplate.of(clientPath);
this.matchInfo = matchTemplate.match(request.getPath()).orElse(null);

callOpenMethod(null);
}

@Override
Expand Down Expand Up @@ -139,14 +126,16 @@ public NettyWebSocketSession getSession() {
return clientSession;
}

@Override
public void handlerAdded(final ChannelHandlerContext ctx) {
handshakeFuture = ctx.newPromise();
}

@Override
public void channelActive(final ChannelHandlerContext ctx) {
handshaker.handshake(ctx.channel());
handshaker.handshake(ctx.channel()).addListener(future -> {
if (future.isSuccess()) {
ctx.channel().config().setAutoRead(true);
ctx.read();
} else {
completion.tryEmitError(future.cause());
}
});
}

@Override
Expand All @@ -168,7 +157,6 @@ protected void channelRead0(ChannelHandlerContext ctx, Object msg) {
}
return;
}
handshakeFuture.setSuccess();

this.clientSession = createWebSocketSession(ctx);

Expand All @@ -178,7 +166,6 @@ protected void channelRead0(ChannelHandlerContext ctx, Object msg) {
((WebSocketSessionAware) targetBean).setWebSocketSession(clientSession);
}


ExecutableBinder<WebSocketState> binder = new DefaultExecutableBinder<>();
BoundExecutable<?, ?> bound = binder.tryBind(messageHandler.getExecutableMethod(), webSocketBinder, new WebSocketState(clientSession, originatingRequest));
List<Argument<?>> unboundArguments = bound.getUnboundArguments();
Expand Down Expand Up @@ -218,37 +205,11 @@ protected void channelRead0(ChannelHandlerContext ctx, Object msg) {
}
}

Optional<? extends MethodExecutionHandle<?, ?>> opt = webSocketBean.openMethod();
if (opt.isPresent()) {
MethodExecutionHandle<?, ?> openMethod = opt.get();

WebSocketState webSocketState = new WebSocketState(clientSession, originatingRequest);
try {
BoundExecutable openMethodBound = binder.bind(openMethod.getExecutableMethod(), webSocketStateBinderRegistry, webSocketState);
Object target = openMethod.getTarget();
Object result = openMethodBound.invoke(target);

if (Publishers.isConvertibleToPublisher(result)) {
Publisher<?> reactiveSequence = Publishers.convertPublisher(result, Publisher.class);
Flux.from(reactiveSequence).subscribe(
o -> { },
error -> completion.tryEmitError(new WebSocketSessionException("Error opening WebSocket client session: " + error.getMessage(), error)),
() -> {
completion.tryEmitValue(targetBean);
}
);
} else {
completion.tryEmitValue(targetBean);
}
} catch (Throwable e) {
completion.tryEmitError(new WebSocketClientException("Error opening WebSocket client session: " + e.getMessage(), e));
if (getSession().isOpen()) {
getSession().close(CloseReason.INTERNAL_ERROR);
}
}
} else {
completion.tryEmitValue(targetBean);
}
Flux.from(callOpenMethod(ctx)).subscribe(
o -> { },
error -> completion.tryEmitError(new WebSocketSessionException("Error opening WebSocket client session: " + error.getMessage(), error)),
() -> completion.tryEmitValue(targetBean)
);
return;
}

Expand All @@ -257,8 +218,6 @@ protected void channelRead0(ChannelHandlerContext ctx, Object msg) {
} else {
ctx.fireChannelRead(msg);
}


}

@Override
Expand Down Expand Up @@ -286,14 +245,20 @@ public ConvertibleValues<Object> getUriVariables() {

@Override
public void exceptionCaught(final ChannelHandlerContext ctx, final Throwable cause) {
if (!handshakeFuture.isDone()) {
handshakeFuture.setFailure(cause);
}

completion.tryEmitError(cause);
super.exceptionCaught(ctx, cause);
}

public final Mono<T> getHandshakeCompletedMono() {
return completion.asMono();
}

@Override
protected void handleCloseReason(ChannelHandlerContext ctx, CloseReason cr, boolean writeCloseReason) {
if (!handshaker.isHandshakeComplete()) {
completion.tryEmitError(new WebSocketClientException("Error opening WebSocket client session: " + cr.getReason()));
return;
}
super.handleCloseReason(ctx, cr, writeCloseReason);
}
}
@@ -0,0 +1,77 @@
package io.micronaut.http.client.websocket

import io.micronaut.context.ApplicationContext
import io.micronaut.context.annotation.Requires
import io.micronaut.websocket.WebSocketClient
import io.micronaut.websocket.annotation.ClientWebSocket
import io.micronaut.websocket.annotation.OnClose
import io.micronaut.websocket.annotation.OnMessage
import io.micronaut.websocket.annotation.OnOpen
import io.micronaut.websocket.exceptions.WebSocketClientException
import jakarta.inject.Inject
import jakarta.inject.Singleton
import reactor.core.publisher.Mono
import spock.lang.Specification

import java.util.concurrent.ExecutionException

class ClientWebsocketSpec extends Specification {
void 'websocket bean should not open if there is a connection error'() {
given:
def ctx = ApplicationContext.run(['spec.name': 'ClientWebsocketSpec'])
def client = ctx.getBean(WebSocketClient)
def registry = ctx.getBean(ClientBeanRegistry)
def mono = Mono.from(client.connect(ClientBean.class, 'http://does-not-exist'))

when:
mono.toFuture().get()
then:
def e = thrown ExecutionException
e.cause instanceof WebSocketClientException

registry.clientBeans.size() == 1
!registry.clientBeans[0].opened
!registry.clientBeans[0].autoClosed
!registry.clientBeans[0].onClosed

cleanup:
client.close()
}

@Singleton
@Requires(property = 'spec.name', value = 'ClientWebsocketSpec')
static class ClientBeanRegistry {
List<ClientBean> clientBeans = new ArrayList<>()
}

@ClientWebSocket
static class ClientBean implements AutoCloseable {
boolean opened = false
boolean onClosed = false
boolean autoClosed = false

@Inject
ClientBean(ClientBeanRegistry registry) {
registry.clientBeans.add(this)
}

@OnOpen
void open() {
opened = true
}

@OnMessage
void onMessage(String text) {
}

@OnClose
void onClose() {
onClosed = true
}

@Override
void close() throws Exception {
autoClosed = true
}
}
}

0 comments on commit e45b5ea

Please sign in to comment.