Skip to content

Commit

Permalink
Support ExecutOn annotation with ServerWebSocket
Browse files Browse the repository at this point in the history
NettyServerWebSocketHandler is updated to check for the ExecuteOn
annotation when invoking any of the callback methods on a
ServerWebSocket annotated class, and to use the specified
ExecutorService when invoking the methods.

A test is added to verify the enhanced behavior.

This resolves #10758
  • Loading branch information
jeremyg484 committed Apr 27, 2024
1 parent 336907a commit b9c9e03
Show file tree
Hide file tree
Showing 3 changed files with 257 additions and 10 deletions.
Expand Up @@ -17,6 +17,7 @@

import io.micronaut.context.event.ApplicationEventPublisher;
import io.micronaut.core.annotation.Internal;
import io.micronaut.core.annotation.NonNull;
import io.micronaut.core.annotation.Nullable;
import io.micronaut.core.async.publisher.Publishers;
import io.micronaut.core.bind.BoundExecutable;
Expand All @@ -26,6 +27,7 @@
import io.micronaut.core.propagation.PropagatedContext;
import io.micronaut.core.type.Argument;
import io.micronaut.core.type.Executable;
import io.micronaut.core.type.ReturnType;
import io.micronaut.core.util.KotlinUtils;
import io.micronaut.http.HttpAttributes;
import io.micronaut.http.HttpRequest;
Expand All @@ -38,6 +40,8 @@
import io.micronaut.http.server.netty.NettyEmbeddedServices;
import io.micronaut.inject.ExecutableMethod;
import io.micronaut.inject.MethodExecutionHandle;
import io.micronaut.scheduling.executor.ExecutorSelector;
import io.micronaut.scheduling.executor.ThreadSelection;
import io.micronaut.web.router.UriRouteMatch;
import io.micronaut.websocket.CloseReason;
import io.micronaut.websocket.WebSocketPongMessage;
Expand Down Expand Up @@ -65,6 +69,8 @@
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
Expand All @@ -91,6 +97,8 @@ public class NettyServerWebSocketHandler extends AbstractNettyWebSocketHandler {

private final Argument<?> bodyArgument;
private final Argument<?> pongArgument;
private final ThreadSelection threadSelection;
private final ExecutorSelector executorSelector;

/**
* Default constructor.
Expand All @@ -102,17 +110,20 @@ public class NettyServerWebSocketHandler extends AbstractNettyWebSocketHandler {
* @param request The request used to create the websocket
* @param routeMatch The route match
* @param ctx The channel handler context
* @param executorSelector
* @param coroutineHelper Helper for kotlin coroutines
*/
NettyServerWebSocketHandler(
NettyEmbeddedServices nettyEmbeddedServices,
WebSocketSessionRepository webSocketSessionRepository,
WebSocketServerHandshaker handshaker,
WebSocketBean<?> webSocketBean,
HttpRequest<?> request,
UriRouteMatch<Object, Object> routeMatch,
ChannelHandlerContext ctx,
@Nullable CoroutineHelper coroutineHelper) {
NettyEmbeddedServices nettyEmbeddedServices,
WebSocketSessionRepository webSocketSessionRepository,
WebSocketServerHandshaker handshaker,
WebSocketBean<?> webSocketBean,
HttpRequest<?> request,
UriRouteMatch<Object, Object> routeMatch,
ChannelHandlerContext ctx,
ThreadSelection threadSelection,
ExecutorSelector executorSelector,
@Nullable CoroutineHelper coroutineHelper) {
super(
ctx,
nettyEmbeddedServices.getRequestArgumentSatisfier().getBinderRegistry(),
Expand All @@ -125,6 +136,9 @@ public class NettyServerWebSocketHandler extends AbstractNettyWebSocketHandler {
webSocketSessionRepository,
nettyEmbeddedServices.getApplicationContext().getConversionService());

this.threadSelection = threadSelection;
this.executorSelector = executorSelector;

this.serverSession = createWebSocketSession(ctx);

ExecutableBinder<WebSocketState> binder = new DefaultExecutableBinder<>();
Expand Down Expand Up @@ -345,8 +359,25 @@ protected Object invokeExecutable(BoundExecutable boundExecutable, MethodExecuti
}

private Object invokeExecutable0(BoundExecutable boundExecutable, MethodExecutionHandle<?, ?> messageHandler) {
return ServerRequestContext.with(originatingRequest,
(Supplier<Object>) () -> boundExecutable.invoke(messageHandler.getTarget()));
return this.executorSelector.select(messageHandler.getExecutableMethod(), threadSelection)
.map(
executorService -> {
ReturnType<?> returnType = messageHandler.getExecutableMethod().getReturnType();
if (returnType.isReactive()) {
return Mono.from((Publisher<?>) boundExecutable.invoke(messageHandler.getTarget()))
.subscribeOn(Schedulers.fromExecutor(executorService))
.contextWrite(reactorContext -> reactorContext.put(ServerRequestContext.KEY, originatingRequest));
} else {
return executorService.submit(() -> ServerRequestContext.with(originatingRequest,
(Supplier<Object>) () -> boundExecutable.invoke(messageHandler.getTarget())));
}
}
).orElseGet(invokeWithContext(boundExecutable, messageHandler));
}

private Supplier<Object> invokeWithContext(BoundExecutable boundExecutable, MethodExecutionHandle<?, ?> messageHandler) {
return () -> ServerRequestContext.with(originatingRequest,
(Supplier<Object>) () -> boundExecutable.invoke(messageHandler.getTarget()));
}

@Override
Expand Down
Expand Up @@ -199,6 +199,8 @@ private void writeResponse(ChannelHandlerContext ctx,
msg,
routeMatch,
ctx,
serverConfiguration.getThreadSelection(),
routeExecutor.getExecutorSelector(),
routeExecutor.getCoroutineHelper().orElse(null));
pipeline.addBefore(ctx.name(), NettyServerWebSocketHandler.ID, webSocketHandler);

Expand Down
@@ -0,0 +1,214 @@
package io.micronaut.websocket

import io.micronaut.context.annotation.Property
import io.micronaut.context.annotation.Requires
import io.micronaut.runtime.server.EmbeddedServer
import io.micronaut.scheduling.LoomSupport
import io.micronaut.scheduling.TaskExecutors
import io.micronaut.scheduling.annotation.ExecuteOn
import io.micronaut.test.extensions.spock.annotation.MicronautTest
import io.micronaut.websocket.annotation.*
import jakarta.inject.Inject
import org.reactivestreams.Publisher
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import reactor.core.publisher.Flux
import reactor.core.publisher.Mono
import spock.lang.Specification
import spock.lang.Unroll
import spock.util.concurrent.PollingConditions

import java.util.concurrent.Future
import java.util.function.Predicate
import java.util.function.Supplier
import java.util.stream.Collectors

@Property(name = "spec.name", value = "WebsocketExecuteOnSpec")
@MicronautTest
class WebsocketExecuteOnSpec extends Specification {

static final Logger LOG = LoggerFactory.getLogger(WebsocketExecuteOnSpec.class)

@Inject
EmbeddedServer embeddedServer

@Unroll
void "#type websocket server methods can run outside of the event loop with ExecuteOn"() {
given:
WebSocketClient wsClient = embeddedServer.applicationContext.createBean(WebSocketClient.class, embeddedServer.getURL())
String threadName = (LoomSupport.isSupported() ? "virtual" : TaskExecutors.IO) + "-executor"
String expectedJoined = "joined on thread " + threadName
String expectedEcho = "Hello from thread " + threadName

expect:
wsClient

when:
EchoClientWebSocket echoClientWebSocket = Flux.from(wsClient.connect(EchoClientWebSocket, "/echo/${type}")).blockFirst()

then:
noExceptionThrown()
new PollingConditions().eventually {
echoClientWebSocket.receivedMessages() == [expectedJoined]
}

when:
echoClientWebSocket.send('Hello')

then:
new PollingConditions().eventually {
echoClientWebSocket.receivedMessages() == [expectedJoined, expectedEcho]
}

cleanup:
echoClientWebSocket.close()

where:
type | _
"sync" | _
"reactive" | _
"async" | _
}

@Requires(property = "spec.name", value = "WebsocketExecuteOnSpec")
@ServerWebSocket("/echo/sync")
@ExecuteOn(TaskExecutors.BLOCKING)
static class SynchronousEchoServerWebSocket {
public static final String JOINED = "joined on thread %s"
public static final String DISCONNECTED = "disconnected on thread %s"
public static final String ECHO = "%s from thread %s"

@Inject
WebSocketBroadcaster broadcaster

@OnOpen
void onOpen(WebSocketSession session) {
broadcaster.broadcastSync(JOINED.formatted(Thread.currentThread().getName()), isValid(session))
}

@OnMessage
void onMessage(String message, WebSocketSession session) {
broadcaster.broadcastSync(ECHO.formatted(message, Thread.currentThread().getName()), isValid(session))
}

@OnClose
void onClose(WebSocketSession session) {
broadcaster.broadcastSync(DISCONNECTED.formatted(Thread.currentThread().getName()), isValid(session))
}

private static Predicate<WebSocketSession> isValid(WebSocketSession session) {
return { s -> s == session }
}
}

@Requires(property = "spec.name", value = "WebsocketExecuteOnSpec")
@ServerWebSocket("/echo/reactive")
@ExecuteOn(TaskExecutors.BLOCKING)
static class ReactiveEchoServerWebSocket {
public static final String JOINED = "joined on thread %s"
public static final String DISCONNECTED = "disconnected on thread %s"
public static final String ECHO = " from thread %s"

@Inject
WebSocketBroadcaster broadcaster

Supplier<String> formatMessage(String message) {
() -> message.formatted(Thread.currentThread().getName())
}

@OnOpen
Publisher<String> onOpen(WebSocketSession session) {
Mono.fromSupplier(formatMessage(JOINED))
.flatMap(message -> Mono.from(broadcaster.broadcast(message)))
}

@OnMessage
Publisher<String> onMessage(String message, WebSocketSession session) {
Mono.fromSupplier(formatMessage(message + ECHO))
.flatMap(m -> Mono.from(broadcaster.broadcast(m)))
}

@OnClose
Publisher<String> onClose(WebSocketSession session) {
Mono.just(session)
.flatMap(s -> {
LOG.info(DISCONNECTED.formatted(Thread.currentThread().getName()))
return Mono.just("closed")
})
}
}

@Requires(property = "spec.name", value = "WebsocketExecuteOnSpec")
@ServerWebSocket("/echo/async")
@ExecuteOn(TaskExecutors.BLOCKING)
static class AsyncEchoServerWebSocket {
public static final String JOINED = "joined on thread %s"
public static final String DISCONNECTED = "disconnected on thread %s"
public static final String ECHO = " from thread %s"

@Inject
WebSocketBroadcaster broadcaster

Supplier<String> formatMessage(String message) {
() -> message.formatted(Thread.currentThread().getName())
}

@OnOpen
Future<String> onOpen(WebSocketSession session) {
Mono.fromSupplier(formatMessage(JOINED))
.flatMap(message -> Mono.from(broadcaster.broadcast(message))).toFuture();
}

@OnMessage
Future<String> onMessage(String message, WebSocketSession session) {
Mono.fromSupplier(formatMessage(message + ECHO))
.flatMap(m -> Mono.from(broadcaster.broadcast(m))).toFuture()
}

@OnClose
Future<String> onClose(WebSocketSession session) {
Mono.just(session)
.flatMap(s -> {
LOG.info(DISCONNECTED.formatted(Thread.currentThread().getName()))
return Mono.just("closed")
}).toFuture()
}
}

@Requires(property = "spec.name", value = "WebsocketExecuteOnSpec")
@ClientWebSocket
static abstract class EchoClientWebSocket implements AutoCloseable {

static final String RECEIVED = "RECEIVED:"

private WebSocketSession session
private List<String> replies = new ArrayList<>()

@OnOpen
void onOpen(WebSocketSession session) {
this.session = session
}
List<String> getReplies() {
return replies
}

@OnMessage
void onMessage(String message) {
replies.add(RECEIVED + message)
}

abstract void send(String message)

List<String> receivedMessages() {
return filterMessagesByType(RECEIVED)
}

List<String> filterMessagesByType(String type) {
replies.stream()
.filter(str -> str.contains(type))
.map(str -> str.replaceAll(type, ""))
.map(str -> str.substring(0, str.length()-(1)).replace("-thread-", ""))
.collect(Collectors.toList())
}
}
}

0 comments on commit b9c9e03

Please sign in to comment.