From e45b5ea72350fe54c88f6ded3abb461be5263ff3 Mon Sep 17 00:00:00 2001 From: Jonas Konrad Date: Wed, 23 Nov 2022 12:10:52 +0100 Subject: [PATCH] Forward early websocket client errors to client future, not the websocket bean (#8300) When the websocket is closed, but the session has not yet been initialized (and OnOpen not been called), change handleCloseReason to instead emit an error on the publisher that is returned by the `WebSocketClient.connect` method. This means the publisher won't get stuck, and the OnClose method won't be called without a session being available. Additionally, I refactored `AbstractNettyWebSocketHandler` and the subclasses a bit. The bodyArgument/pongArgument code, and the old callOpenMethod, were only used for the server handler. Hopefully fixes #7921 --- .../NettyWebSocketClientHandler.java | 81 +++------- .../websocket/ClientWebsocketSpec.groovy | 77 ++++++++++ .../AbstractNettyWebSocketHandler.java | 142 +++++------------- .../NettyServerWebSocketHandler.java | 86 ++++++++++- 4 files changed, 215 insertions(+), 171 deletions(-) create mode 100644 http-client/src/test/groovy/io/micronaut/http/client/websocket/ClientWebsocketSpec.groovy diff --git a/http-client/src/main/java/io/micronaut/http/client/netty/websocket/NettyWebSocketClientHandler.java b/http-client/src/main/java/io/micronaut/http/client/netty/websocket/NettyWebSocketClientHandler.java index 5e825b38400..9aa4b5772ad 100644 --- a/http-client/src/main/java/io/micronaut/http/client/netty/websocket/NettyWebSocketClientHandler.java +++ b/http-client/src/main/java/io/micronaut/http/client/netty/websocket/NettyWebSocketClientHandler.java @@ -16,34 +16,28 @@ package io.micronaut.http.client.netty.websocket; import io.micronaut.core.annotation.Internal; -import io.micronaut.core.async.publisher.Publishers; import io.micronaut.core.bind.BoundExecutable; import io.micronaut.core.bind.DefaultExecutableBinder; import io.micronaut.core.bind.ExecutableBinder; -import io.micronaut.core.convert.ConversionService; import io.micronaut.core.convert.value.ConvertibleValues; import io.micronaut.core.type.Argument; import io.micronaut.http.MutableHttpRequest; -import io.micronaut.http.bind.DefaultRequestBinderRegistry; import io.micronaut.http.bind.RequestBinderRegistry; import io.micronaut.http.codec.MediaTypeCodecRegistry; import io.micronaut.http.netty.websocket.AbstractNettyWebSocketHandler; import io.micronaut.http.netty.websocket.NettyWebSocketSession; import io.micronaut.http.uri.UriMatchInfo; import io.micronaut.http.uri.UriMatchTemplate; -import io.micronaut.inject.MethodExecutionHandle; import io.micronaut.websocket.CloseReason; import io.micronaut.websocket.WebSocketPongMessage; import io.micronaut.websocket.annotation.ClientWebSocket; import io.micronaut.websocket.bind.WebSocketState; -import io.micronaut.websocket.bind.WebSocketStateBinderRegistry; import io.micronaut.websocket.context.WebSocketBean; import io.micronaut.websocket.exceptions.WebSocketClientException; import io.micronaut.websocket.exceptions.WebSocketSessionException; import io.micronaut.websocket.interceptor.WebSocketSessionAware; import io.netty.channel.Channel; import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelPromise; import io.netty.handler.codec.http.FullHttpResponse; import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame; @@ -52,14 +46,12 @@ import io.netty.handler.ssl.SslHandler; import io.netty.handler.timeout.IdleState; import io.netty.handler.timeout.IdleStateEvent; -import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.publisher.Sinks; import java.util.Collections; import java.util.List; -import java.util.Optional; /** * Handler for WebSocket clients. @@ -78,9 +70,7 @@ public class NettyWebSocketClientHandler extends AbstractNettyWebSocketHandle private final Sinks.One completion = Sinks.one(); private final UriMatchInfo matchInfo; private final MediaTypeCodecRegistry codecRegistry; - private ChannelPromise handshakeFuture; private NettyWebSocketSession clientSession; - private final WebSocketStateBinderRegistry webSocketStateBinderRegistry; private FullHttpResponse handshakeResponse; private Argument clientBodyArgument; private Argument clientPongArgument; @@ -103,12 +93,9 @@ public NettyWebSocketClientHandler( this.codecRegistry = mediaTypeCodecRegistry; this.handshaker = handshaker; this.genericWebSocketBean = webSocketBean; - this.webSocketStateBinderRegistry = new WebSocketStateBinderRegistry(requestBinderRegistry != null ? requestBinderRegistry : new DefaultRequestBinderRegistry(ConversionService.SHARED)); String clientPath = webSocketBean.getBeanDefinition().stringValue(ClientWebSocket.class).orElse(""); UriMatchTemplate matchTemplate = UriMatchTemplate.of(clientPath); this.matchInfo = matchTemplate.match(request.getPath()).orElse(null); - - callOpenMethod(null); } @Override @@ -139,14 +126,16 @@ public NettyWebSocketSession getSession() { return clientSession; } - @Override - public void handlerAdded(final ChannelHandlerContext ctx) { - handshakeFuture = ctx.newPromise(); - } - @Override public void channelActive(final ChannelHandlerContext ctx) { - handshaker.handshake(ctx.channel()); + handshaker.handshake(ctx.channel()).addListener(future -> { + if (future.isSuccess()) { + ctx.channel().config().setAutoRead(true); + ctx.read(); + } else { + completion.tryEmitError(future.cause()); + } + }); } @Override @@ -168,7 +157,6 @@ protected void channelRead0(ChannelHandlerContext ctx, Object msg) { } return; } - handshakeFuture.setSuccess(); this.clientSession = createWebSocketSession(ctx); @@ -178,7 +166,6 @@ protected void channelRead0(ChannelHandlerContext ctx, Object msg) { ((WebSocketSessionAware) targetBean).setWebSocketSession(clientSession); } - ExecutableBinder binder = new DefaultExecutableBinder<>(); BoundExecutable bound = binder.tryBind(messageHandler.getExecutableMethod(), webSocketBinder, new WebSocketState(clientSession, originatingRequest)); List> unboundArguments = bound.getUnboundArguments(); @@ -218,37 +205,11 @@ protected void channelRead0(ChannelHandlerContext ctx, Object msg) { } } - Optional> opt = webSocketBean.openMethod(); - if (opt.isPresent()) { - MethodExecutionHandle openMethod = opt.get(); - - WebSocketState webSocketState = new WebSocketState(clientSession, originatingRequest); - try { - BoundExecutable openMethodBound = binder.bind(openMethod.getExecutableMethod(), webSocketStateBinderRegistry, webSocketState); - Object target = openMethod.getTarget(); - Object result = openMethodBound.invoke(target); - - if (Publishers.isConvertibleToPublisher(result)) { - Publisher reactiveSequence = Publishers.convertPublisher(result, Publisher.class); - Flux.from(reactiveSequence).subscribe( - o -> { }, - error -> completion.tryEmitError(new WebSocketSessionException("Error opening WebSocket client session: " + error.getMessage(), error)), - () -> { - completion.tryEmitValue(targetBean); - } - ); - } else { - completion.tryEmitValue(targetBean); - } - } catch (Throwable e) { - completion.tryEmitError(new WebSocketClientException("Error opening WebSocket client session: " + e.getMessage(), e)); - if (getSession().isOpen()) { - getSession().close(CloseReason.INTERNAL_ERROR); - } - } - } else { - completion.tryEmitValue(targetBean); - } + Flux.from(callOpenMethod(ctx)).subscribe( + o -> { }, + error -> completion.tryEmitError(new WebSocketSessionException("Error opening WebSocket client session: " + error.getMessage(), error)), + () -> completion.tryEmitValue(targetBean) + ); return; } @@ -257,8 +218,6 @@ protected void channelRead0(ChannelHandlerContext ctx, Object msg) { } else { ctx.fireChannelRead(msg); } - - } @Override @@ -286,14 +245,20 @@ public ConvertibleValues getUriVariables() { @Override public void exceptionCaught(final ChannelHandlerContext ctx, final Throwable cause) { - if (!handshakeFuture.isDone()) { - handshakeFuture.setFailure(cause); - } - + completion.tryEmitError(cause); super.exceptionCaught(ctx, cause); } public final Mono getHandshakeCompletedMono() { return completion.asMono(); } + + @Override + protected void handleCloseReason(ChannelHandlerContext ctx, CloseReason cr, boolean writeCloseReason) { + if (!handshaker.isHandshakeComplete()) { + completion.tryEmitError(new WebSocketClientException("Error opening WebSocket client session: " + cr.getReason())); + return; + } + super.handleCloseReason(ctx, cr, writeCloseReason); + } } diff --git a/http-client/src/test/groovy/io/micronaut/http/client/websocket/ClientWebsocketSpec.groovy b/http-client/src/test/groovy/io/micronaut/http/client/websocket/ClientWebsocketSpec.groovy new file mode 100644 index 00000000000..71957bfc1cf --- /dev/null +++ b/http-client/src/test/groovy/io/micronaut/http/client/websocket/ClientWebsocketSpec.groovy @@ -0,0 +1,77 @@ +package io.micronaut.http.client.websocket + +import io.micronaut.context.ApplicationContext +import io.micronaut.context.annotation.Requires +import io.micronaut.websocket.WebSocketClient +import io.micronaut.websocket.annotation.ClientWebSocket +import io.micronaut.websocket.annotation.OnClose +import io.micronaut.websocket.annotation.OnMessage +import io.micronaut.websocket.annotation.OnOpen +import io.micronaut.websocket.exceptions.WebSocketClientException +import jakarta.inject.Inject +import jakarta.inject.Singleton +import reactor.core.publisher.Mono +import spock.lang.Specification + +import java.util.concurrent.ExecutionException + +class ClientWebsocketSpec extends Specification { + void 'websocket bean should not open if there is a connection error'() { + given: + def ctx = ApplicationContext.run(['spec.name': 'ClientWebsocketSpec']) + def client = ctx.getBean(WebSocketClient) + def registry = ctx.getBean(ClientBeanRegistry) + def mono = Mono.from(client.connect(ClientBean.class, 'http://does-not-exist')) + + when: + mono.toFuture().get() + then: + def e = thrown ExecutionException + e.cause instanceof WebSocketClientException + + registry.clientBeans.size() == 1 + !registry.clientBeans[0].opened + !registry.clientBeans[0].autoClosed + !registry.clientBeans[0].onClosed + + cleanup: + client.close() + } + + @Singleton + @Requires(property = 'spec.name', value = 'ClientWebsocketSpec') + static class ClientBeanRegistry { + List clientBeans = new ArrayList<>() + } + + @ClientWebSocket + static class ClientBean implements AutoCloseable { + boolean opened = false + boolean onClosed = false + boolean autoClosed = false + + @Inject + ClientBean(ClientBeanRegistry registry) { + registry.clientBeans.add(this) + } + + @OnOpen + void open() { + opened = true + } + + @OnMessage + void onMessage(String text) { + } + + @OnClose + void onClose() { + onClosed = true + } + + @Override + void close() throws Exception { + autoClosed = true + } + } +} diff --git a/http-netty/src/main/java/io/micronaut/http/netty/websocket/AbstractNettyWebSocketHandler.java b/http-netty/src/main/java/io/micronaut/http/netty/websocket/AbstractNettyWebSocketHandler.java index 405515fba9c..c3b6b0e6a9a 100644 --- a/http-netty/src/main/java/io/micronaut/http/netty/websocket/AbstractNettyWebSocketHandler.java +++ b/http-netty/src/main/java/io/micronaut/http/netty/websocket/AbstractNettyWebSocketHandler.java @@ -35,6 +35,7 @@ import io.micronaut.inject.MethodExecutionHandle; import io.micronaut.websocket.CloseReason; import io.micronaut.websocket.WebSocketPongMessage; +import io.micronaut.websocket.WebSocketSession; import io.micronaut.websocket.bind.WebSocketState; import io.micronaut.websocket.bind.WebSocketStateBinderRegistry; import io.micronaut.websocket.context.WebSocketBean; @@ -55,6 +56,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; import reactor.core.scheduler.Schedulers; import java.io.IOException; @@ -89,13 +91,10 @@ public abstract class AbstractNettyWebSocketHandler extends SimpleChannelInbound protected final HttpRequest originatingRequest; protected final MethodExecutionHandle messageHandler; protected final MethodExecutionHandle pongHandler; - protected final NettyWebSocketSession session; protected final MediaTypeCodecRegistry mediaTypeCodecRegistry; protected final WebSocketVersion webSocketVersion; protected final String subProtocol; protected final WebSocketSessionRepository webSocketSessionRepository; - private final Argument bodyArgument; - private final Argument pongArgument; private final AtomicBoolean closed = new AtomicBoolean(false); private AtomicReference frameBuffer = new AtomicReference<>(); @@ -132,138 +131,68 @@ protected AbstractNettyWebSocketHandler( this.pongHandler = webSocketBean.pongMethod().orElse(null); this.mediaTypeCodecRegistry = mediaTypeCodecRegistry; this.webSocketVersion = version; - this.session = createWebSocketSession(ctx); - - if (session != null) { - - ExecutableBinder binder = new DefaultExecutableBinder<>(); - - if (messageHandler != null) { - BoundExecutable bound = binder.tryBind(messageHandler.getExecutableMethod(), webSocketBinder, new WebSocketState(session, originatingRequest)); - List> unboundArguments = bound.getUnboundArguments(); - - if (unboundArguments.size() == 1) { - this.bodyArgument = unboundArguments.iterator().next(); - } else { - this.bodyArgument = null; - if (LOG.isErrorEnabled()) { - LOG.error("WebSocket @OnMessage method " + webSocketBean.getTarget() + "." + messageHandler.getExecutableMethod() + " should define exactly 1 message parameter, but found 2 possible candidates: " + unboundArguments); - } - - if (session.isOpen()) { - session.close(CloseReason.INTERNAL_ERROR); - } - } - } else { - this.bodyArgument = null; - } - - if (pongHandler != null) { - BoundExecutable bound = binder.tryBind(pongHandler.getExecutableMethod(), webSocketBinder, new WebSocketState(session, originatingRequest)); - List> unboundArguments = bound.getUnboundArguments(); - if (unboundArguments.size() == 1 && unboundArguments.get(0).isAssignableFrom(WebSocketPongMessage.class)) { - this.pongArgument = unboundArguments.get(0); - } else { - this.pongArgument = null; - if (LOG.isErrorEnabled()) { - LOG.error("WebSocket @OnMessage pong handler method " + webSocketBean.getTarget() + "." + pongHandler.getExecutableMethod() + " should define exactly 1 message parameter assignable from a WebSocketPongMessage, but found: " + unboundArguments); - } - - if (session.isOpen()) { - session.close(CloseReason.INTERNAL_ERROR); - } - } - } else { - this.pongArgument = null; - } - } else { - this.bodyArgument = null; - this.pongArgument = null; - } } /** * Calls the open method of the websocket bean. * - * @param ctx THe handler context + * @param ctx The handler context + * @return Publisher for any errors, or the result of the open method */ - protected void callOpenMethod(ChannelHandlerContext ctx) { - if (session == null) { - return; - } + protected Publisher callOpenMethod(ChannelHandlerContext ctx) { + WebSocketSession session = getSession(); Optional> executionHandle = webSocketBean.openMethod(); if (executionHandle.isPresent()) { MethodExecutionHandle openMethod = executionHandle.get(); - BoundExecutable boundExecutable = null; + + BoundExecutable boundExecutable; try { boundExecutable = bindMethod(originatingRequest, webSocketBinder, openMethod, Collections.emptyList()); } catch (Throwable e) { - if (LOG.isErrorEnabled()) { - LOG.error("Error Binding method @OnOpen for WebSocket [" + webSocketBean + "]: " + e.getMessage(), e); - } - if (session.isOpen()) { session.close(CloseReason.INTERNAL_ERROR); } + return Mono.error(e); } - if (boundExecutable != null) { - try { - BoundExecutable finalBoundExecutable = boundExecutable; - Object result = invokeExecutable(finalBoundExecutable, openMethod); - if (Publishers.isConvertibleToPublisher(result)) { - Flux flowable = Flux.from(instrumentPublisher(ctx, result)); - flowable.subscribe( - o -> { - }, - error -> { - if (LOG.isErrorEnabled()) { - LOG.error("Error Opening WebSocket [" + webSocketBean + "]: " + error.getMessage(), error); - } - if (session.isOpen()) { - session.close(CloseReason.INTERNAL_ERROR); - } - }, - () -> { - } - ); - } - } catch (Throwable e) { - forwardErrorToUser(ctx, t -> { - if (LOG.isErrorEnabled()) { - LOG.error("Error Opening WebSocket [" + webSocketBean + "]: " + t.getMessage(), t); + try { + Object result = invokeExecutable(boundExecutable, openMethod); + if (Publishers.isConvertibleToPublisher(result)) { + return Flux.from(instrumentPublisher(ctx, result)).doOnError(t -> { + if (session.isOpen()) { + session.close(CloseReason.INTERNAL_ERROR); } - }, e); - // since we failed to call onOpen, we should always close here - if (session.isOpen()) { - session.close(CloseReason.INTERNAL_ERROR); - } + }); + } else { + return Mono.empty(); + } + } catch (Throwable e) { + // since we failed to call onOpen, we should always close here + if (session.isOpen()) { + session.close(CloseReason.INTERNAL_ERROR); } + return Mono.error(e); } + } else { + return Mono.empty(); } } /** * @return The body argument for the message handler */ - public Argument getBodyArgument() { - return bodyArgument; - } + public abstract Argument getBodyArgument(); /** * @return The pong argument for the pong handler */ - public Argument getPongArgument() { - return pongArgument; - } + public abstract Argument getPongArgument(); /** * @return The session */ - public NettyWebSocketSession getSession() { - return session; - } + public abstract NettyWebSocketSession getSession(); @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { @@ -271,7 +200,7 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { forwardErrorToUser(ctx, e -> handleUnexpected(ctx, e), cause); } - private void forwardErrorToUser(ChannelHandlerContext ctx, Consumer fallback, Throwable cause) { + protected final void forwardErrorToUser(ChannelHandlerContext ctx, Consumer fallback, Throwable cause) { Optional> opt = webSocketBean.errorMethod(); if (opt.isPresent()) { @@ -443,10 +372,10 @@ protected void handleWebSocketFrame(ChannelHandlerContext ctx, WebSocketFrame ms o -> { }, error -> messageProcessingException(ctx, error), - () -> messageHandled(ctx, session, v) + () -> messageHandled(ctx, v) ); } else { - messageHandled(ctx, session, v); + messageHandled(ctx, v); } } catch (Throwable e) { messageProcessingException(ctx, e); @@ -528,10 +457,9 @@ private void messageProcessingException(ChannelHandlerContext ctx, Throwable e) * Method called once a message has been handled by the handler. * * @param ctx The channel handler context - * @param session The session * @param message The message that was handled */ - protected void messageHandled(ChannelHandlerContext ctx, NettyWebSocketSession session, Object message) { + protected void messageHandled(ChannelHandlerContext ctx, Object message) { // no-op } @@ -547,12 +475,12 @@ protected void writeCloseFrameAndTerminate(ChannelHandlerContext ctx, CloseReaso } /** - * Used to close thee session with a given reason. + * Used to close the session with a given reason. * @param ctx The context * @param cr The reason * @param writeCloseReason Whether to allow writing the close reason to the remote */ - private void handleCloseReason(ChannelHandlerContext ctx, CloseReason cr, boolean writeCloseReason) { + protected void handleCloseReason(ChannelHandlerContext ctx, CloseReason cr, boolean writeCloseReason) { cleanupBuffer(); if (closed.compareAndSet(false, true)) { if (LOG.isDebugEnabled()) { diff --git a/http-server-netty/src/main/java/io/micronaut/http/server/netty/websocket/NettyServerWebSocketHandler.java b/http-server-netty/src/main/java/io/micronaut/http/server/netty/websocket/NettyServerWebSocketHandler.java index 2f40262ab9f..817124f45f2 100644 --- a/http-server-netty/src/main/java/io/micronaut/http/server/netty/websocket/NettyServerWebSocketHandler.java +++ b/http-server-netty/src/main/java/io/micronaut/http/server/netty/websocket/NettyServerWebSocketHandler.java @@ -20,7 +20,10 @@ import io.micronaut.core.annotation.Nullable; import io.micronaut.core.async.publisher.Publishers; import io.micronaut.core.bind.BoundExecutable; +import io.micronaut.core.bind.DefaultExecutableBinder; +import io.micronaut.core.bind.ExecutableBinder; import io.micronaut.core.convert.value.ConvertibleValues; +import io.micronaut.core.type.Argument; import io.micronaut.core.type.Executable; import io.micronaut.core.util.KotlinUtils; import io.micronaut.http.HttpAttributes; @@ -36,7 +39,9 @@ import io.micronaut.inject.MethodExecutionHandle; import io.micronaut.web.router.UriRouteMatch; import io.micronaut.websocket.CloseReason; +import io.micronaut.websocket.WebSocketPongMessage; import io.micronaut.websocket.WebSocketSession; +import io.micronaut.websocket.bind.WebSocketState; import io.micronaut.websocket.context.WebSocketBean; import io.micronaut.websocket.event.WebSocketMessageProcessedEvent; import io.micronaut.websocket.event.WebSocketSessionClosedEvent; @@ -56,6 +61,7 @@ import reactor.core.scheduler.Schedulers; import java.security.Principal; +import java.util.List; import java.util.Optional; import java.util.Set; import java.util.function.Function; @@ -77,10 +83,14 @@ public class NettyServerWebSocketHandler extends AbstractNettyWebSocketHandler { */ public static final String ID = "websocket-handler"; + private final NettyWebSocketSession serverSession; private final NettyEmbeddedServices nettyEmbeddedServices; @Nullable private final CoroutineHelper coroutineHelper; + private final Argument bodyArgument; + private final Argument pongArgument; + /** * Default constructor. * @@ -113,18 +123,67 @@ public class NettyServerWebSocketHandler extends AbstractNettyWebSocketHandler { handshaker.selectedSubprotocol(), webSocketSessionRepository); + this.serverSession = createWebSocketSession(ctx); + + ExecutableBinder binder = new DefaultExecutableBinder<>(); + + if (messageHandler != null) { + BoundExecutable bound = binder.tryBind(messageHandler.getExecutableMethod(), webSocketBinder, new WebSocketState(serverSession, originatingRequest)); + List> unboundArguments = bound.getUnboundArguments(); + + if (unboundArguments.size() == 1) { + this.bodyArgument = unboundArguments.iterator().next(); + } else { + this.bodyArgument = null; + if (LOG.isErrorEnabled()) { + LOG.error("WebSocket @OnMessage method " + webSocketBean.getTarget() + "." + messageHandler.getExecutableMethod() + " should define exactly 1 message parameter, but found 2 possible candidates: " + unboundArguments); + } + + if (serverSession.isOpen()) { + serverSession.close(CloseReason.INTERNAL_ERROR); + } + } + } else { + this.bodyArgument = null; + } + + if (pongHandler != null) { + BoundExecutable bound = binder.tryBind(pongHandler.getExecutableMethod(), webSocketBinder, new WebSocketState(serverSession, originatingRequest)); + List> unboundArguments = bound.getUnboundArguments(); + if (unboundArguments.size() == 1 && unboundArguments.get(0).isAssignableFrom(WebSocketPongMessage.class)) { + this.pongArgument = unboundArguments.get(0); + } else { + this.pongArgument = null; + if (LOG.isErrorEnabled()) { + LOG.error("WebSocket @OnMessage pong handler method " + webSocketBean.getTarget() + "." + pongHandler.getExecutableMethod() + " should define exactly 1 message parameter assignable from a WebSocketPongMessage, but found: " + unboundArguments); + } + + if (serverSession.isOpen()) { + serverSession.close(CloseReason.INTERNAL_ERROR); + } + } + } else { + this.pongArgument = null; + } + this.nettyEmbeddedServices = nettyEmbeddedServices; this.coroutineHelper = coroutineHelper; request.setAttribute(HttpAttributes.ROUTE_MATCH, routeMatch); request.setAttribute(HttpAttributes.ROUTE, routeMatch.getRoute()); - callOpenMethod(ctx); + Flux.from(callOpenMethod(ctx)).subscribe(v -> { }, t -> { + forwardErrorToUser(ctx, e -> { + if (LOG.isErrorEnabled()) { + LOG.error("Error Opening WebSocket [" + webSocketBean + "]: " + e.getMessage(), e); + } + }, t); + }); ApplicationEventPublisher eventPublisher = nettyEmbeddedServices.getEventPublisher(WebSocketSessionOpenEvent.class); try { - eventPublisher.publishEvent(new WebSocketSessionOpenEvent(session)); + eventPublisher.publishEvent(new WebSocketSessionOpenEvent(serverSession)); } catch (Exception e) { if (LOG.isErrorEnabled()) { LOG.error("Error publishing WebSocket opened event: " + e.getMessage(), e); @@ -132,6 +191,21 @@ public class NettyServerWebSocketHandler extends AbstractNettyWebSocketHandler { } } + @Override + public NettyWebSocketSession getSession() { + return serverSession; + } + + @Override + public Argument getBodyArgument() { + return bodyArgument; + } + + @Override + public Argument getPongArgument() { + return pongArgument; + } + @Override public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { if (evt instanceof IdleStateEvent) { @@ -276,11 +350,11 @@ private Object invokeExecutable0(BoundExecutable boundExecutable, MethodExecutio } @Override - protected void messageHandled(ChannelHandlerContext ctx, NettyWebSocketSession session, Object message) { + protected void messageHandled(ChannelHandlerContext ctx, Object message) { ctx.executor().execute(() -> { try { nettyEmbeddedServices.getEventPublisher(WebSocketMessageProcessedEvent.class) - .publishEvent(new WebSocketMessageProcessedEvent<>(session, message)); + .publishEvent(new WebSocketMessageProcessedEvent<>(getSession(), message)); } catch (Exception e) { if (LOG.isErrorEnabled()) { LOG.error("Error publishing WebSocket message processed event: " + e.getMessage(), e); @@ -294,12 +368,12 @@ public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { Channel channel = ctx.channel(); channel.attr(NettyWebSocketSession.WEB_SOCKET_SESSION_KEY).set(null); if (LOG.isDebugEnabled()) { - LOG.debug("Removing WebSocket Server session: " + session); + LOG.debug("Removing WebSocket Server session: " + serverSession); } webSocketSessionRepository.removeChannel(channel); try { nettyEmbeddedServices.getEventPublisher(WebSocketSessionClosedEvent.class) - .publishEvent(new WebSocketSessionClosedEvent(session)); + .publishEvent(new WebSocketSessionClosedEvent(serverSession)); } catch (Exception e) { if (LOG.isErrorEnabled()) { LOG.error("Error publishing WebSocket closed event: " + e.getMessage(), e);