Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support ExecuteOn annotation with ServerWebSocket #10772

Merged
merged 4 commits into from May 2, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -26,6 +26,7 @@
import io.micronaut.core.propagation.PropagatedContext;
import io.micronaut.core.type.Argument;
import io.micronaut.core.type.Executable;
import io.micronaut.core.type.ReturnType;
import io.micronaut.core.util.KotlinUtils;
import io.micronaut.http.HttpAttributes;
import io.micronaut.http.HttpRequest;
Expand All @@ -38,6 +39,8 @@
import io.micronaut.http.server.netty.NettyEmbeddedServices;
import io.micronaut.inject.ExecutableMethod;
import io.micronaut.inject.MethodExecutionHandle;
import io.micronaut.scheduling.executor.ExecutorSelector;
import io.micronaut.scheduling.executor.ThreadSelection;
import io.micronaut.web.router.UriRouteMatch;
import io.micronaut.websocket.CloseReason;
import io.micronaut.websocket.WebSocketPongMessage;
Expand Down Expand Up @@ -91,6 +94,8 @@ public class NettyServerWebSocketHandler extends AbstractNettyWebSocketHandler {

private final Argument<?> bodyArgument;
private final Argument<?> pongArgument;
private final ThreadSelection threadSelection;
private final ExecutorSelector executorSelector;

/**
* Default constructor.
Expand All @@ -102,17 +107,20 @@ public class NettyServerWebSocketHandler extends AbstractNettyWebSocketHandler {
* @param request The request used to create the websocket
* @param routeMatch The route match
* @param ctx The channel handler context
* @param executorSelector
* @param coroutineHelper Helper for kotlin coroutines
*/
NettyServerWebSocketHandler(
NettyEmbeddedServices nettyEmbeddedServices,
WebSocketSessionRepository webSocketSessionRepository,
WebSocketServerHandshaker handshaker,
WebSocketBean<?> webSocketBean,
HttpRequest<?> request,
UriRouteMatch<Object, Object> routeMatch,
ChannelHandlerContext ctx,
@Nullable CoroutineHelper coroutineHelper) {
NettyEmbeddedServices nettyEmbeddedServices,
WebSocketSessionRepository webSocketSessionRepository,
WebSocketServerHandshaker handshaker,
WebSocketBean<?> webSocketBean,
HttpRequest<?> request,
UriRouteMatch<Object, Object> routeMatch,
ChannelHandlerContext ctx,
ThreadSelection threadSelection,
ExecutorSelector executorSelector,
@Nullable CoroutineHelper coroutineHelper) {
super(
ctx,
nettyEmbeddedServices.getRequestArgumentSatisfier().getBinderRegistry(),
Expand All @@ -125,6 +133,9 @@ public class NettyServerWebSocketHandler extends AbstractNettyWebSocketHandler {
webSocketSessionRepository,
nettyEmbeddedServices.getApplicationContext().getConversionService());

this.threadSelection = threadSelection;
this.executorSelector = executorSelector;

this.serverSession = createWebSocketSession(ctx);

ExecutableBinder<WebSocketState> binder = new DefaultExecutableBinder<>();
Expand Down Expand Up @@ -345,8 +356,25 @@ protected Object invokeExecutable(BoundExecutable boundExecutable, MethodExecuti
}

private Object invokeExecutable0(BoundExecutable boundExecutable, MethodExecutionHandle<?, ?> messageHandler) {
return ServerRequestContext.with(originatingRequest,
(Supplier<Object>) () -> boundExecutable.invoke(messageHandler.getTarget()));
return this.executorSelector.select(messageHandler.getExecutableMethod(), threadSelection)
.map(
executorService -> {
ReturnType<?> returnType = messageHandler.getExecutableMethod().getReturnType();
if (returnType.isReactive()) {
return Mono.from((Publisher<?>) boundExecutable.invoke(messageHandler.getTarget()))
.subscribeOn(Schedulers.fromExecutor(executorService))
.contextWrite(reactorContext -> reactorContext.put(ServerRequestContext.KEY, originatingRequest));
} else {
return executorService.submit(() -> ServerRequestContext.with(originatingRequest,
jeremyg484 marked this conversation as resolved.
Show resolved Hide resolved
(Supplier<Object>) () -> boundExecutable.invoke(messageHandler.getTarget())));
}
}
).orElseGet(invokeWithContext(boundExecutable, messageHandler));
}

private Supplier<Object> invokeWithContext(BoundExecutable boundExecutable, MethodExecutionHandle<?, ?> messageHandler) {
return () -> ServerRequestContext.with(originatingRequest,
(Supplier<Object>) () -> boundExecutable.invoke(messageHandler.getTarget()));
}

@Override
Expand Down
Expand Up @@ -199,6 +199,8 @@ private void writeResponse(ChannelHandlerContext ctx,
msg,
routeMatch,
ctx,
serverConfiguration.getThreadSelection(),
routeExecutor.getExecutorSelector(),
routeExecutor.getCoroutineHelper().orElse(null));
pipeline.addBefore(ctx.name(), NettyServerWebSocketHandler.ID, webSocketHandler);

Expand Down
@@ -0,0 +1,214 @@
package io.micronaut.websocket

import io.micronaut.context.annotation.Property
import io.micronaut.context.annotation.Requires
import io.micronaut.runtime.server.EmbeddedServer
import io.micronaut.scheduling.LoomSupport
import io.micronaut.scheduling.TaskExecutors
import io.micronaut.scheduling.annotation.ExecuteOn
import io.micronaut.test.extensions.spock.annotation.MicronautTest
import io.micronaut.websocket.annotation.*
import jakarta.inject.Inject
import org.reactivestreams.Publisher
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import reactor.core.publisher.Flux
import reactor.core.publisher.Mono
import spock.lang.Specification
import spock.lang.Unroll
import spock.util.concurrent.PollingConditions

import java.util.concurrent.Future
import java.util.function.Predicate
import java.util.function.Supplier
import java.util.stream.Collectors

@Property(name = "spec.name", value = "WebsocketExecuteOnSpec")
@MicronautTest
class WebsocketExecuteOnSpec extends Specification {

static final Logger LOG = LoggerFactory.getLogger(WebsocketExecuteOnSpec.class)

@Inject
EmbeddedServer embeddedServer

@Unroll
void "#type websocket server methods can run outside of the event loop with ExecuteOn"() {
given:
WebSocketClient wsClient = embeddedServer.applicationContext.createBean(WebSocketClient.class, embeddedServer.getURL())
String threadName = (LoomSupport.isSupported() ? "virtual" : TaskExecutors.IO) + "-executor"
String expectedJoined = "joined on thread " + threadName
String expectedEcho = "Hello from thread " + threadName

expect:
wsClient

when:
EchoClientWebSocket echoClientWebSocket = Flux.from(wsClient.connect(EchoClientWebSocket, "/echo/${type}")).blockFirst()

then:
noExceptionThrown()
new PollingConditions().eventually {
echoClientWebSocket.receivedMessages() == [expectedJoined]
}

when:
echoClientWebSocket.send('Hello')

then:
new PollingConditions().eventually {
echoClientWebSocket.receivedMessages() == [expectedJoined, expectedEcho]
}

cleanup:
echoClientWebSocket.close()

where:
type | _
"sync" | _
"reactive" | _
"async" | _
}

@Requires(property = "spec.name", value = "WebsocketExecuteOnSpec")
@ServerWebSocket("/echo/sync")
@ExecuteOn(TaskExecutors.BLOCKING)
static class SynchronousEchoServerWebSocket {
public static final String JOINED = "joined on thread %s"
public static final String DISCONNECTED = "disconnected on thread %s"
public static final String ECHO = "%s from thread %s"

@Inject
WebSocketBroadcaster broadcaster

@OnOpen
void onOpen(WebSocketSession session) {
broadcaster.broadcastSync(JOINED.formatted(Thread.currentThread().getName()), isValid(session))
}

@OnMessage
void onMessage(String message, WebSocketSession session) {
broadcaster.broadcastSync(ECHO.formatted(message, Thread.currentThread().getName()), isValid(session))
}

@OnClose
void onClose(WebSocketSession session) {
broadcaster.broadcastSync(DISCONNECTED.formatted(Thread.currentThread().getName()), isValid(session))
}

private static Predicate<WebSocketSession> isValid(WebSocketSession session) {
return { s -> s == session }
}
}

@Requires(property = "spec.name", value = "WebsocketExecuteOnSpec")
@ServerWebSocket("/echo/reactive")
@ExecuteOn(TaskExecutors.BLOCKING)
static class ReactiveEchoServerWebSocket {
public static final String JOINED = "joined on thread %s"
public static final String DISCONNECTED = "disconnected on thread %s"
public static final String ECHO = " from thread %s"

@Inject
WebSocketBroadcaster broadcaster

Supplier<String> formatMessage(String message) {
() -> message.formatted(Thread.currentThread().getName())
}

@OnOpen
Publisher<String> onOpen(WebSocketSession session) {
Mono.fromSupplier(formatMessage(JOINED))
.flatMap(message -> Mono.from(broadcaster.broadcast(message)))
}

@OnMessage
Publisher<String> onMessage(String message, WebSocketSession session) {
Mono.fromSupplier(formatMessage(message + ECHO))
.flatMap(m -> Mono.from(broadcaster.broadcast(m)))
}

@OnClose
Publisher<String> onClose(WebSocketSession session) {
Mono.just(session)
.flatMap(s -> {
LOG.info(DISCONNECTED.formatted(Thread.currentThread().getName()))
return Mono.just("closed")
})
}
}

@Requires(property = "spec.name", value = "WebsocketExecuteOnSpec")
@ServerWebSocket("/echo/async")
@ExecuteOn(TaskExecutors.BLOCKING)
static class AsyncEchoServerWebSocket {
public static final String JOINED = "joined on thread %s"
public static final String DISCONNECTED = "disconnected on thread %s"
public static final String ECHO = " from thread %s"

@Inject
WebSocketBroadcaster broadcaster

Supplier<String> formatMessage(String message) {
() -> message.formatted(Thread.currentThread().getName())
}

@OnOpen
Future<String> onOpen(WebSocketSession session) {
Mono.fromSupplier(formatMessage(JOINED))
.flatMap(message -> Mono.from(broadcaster.broadcast(message))).toFuture();
}

@OnMessage
Future<String> onMessage(String message, WebSocketSession session) {
Mono.fromSupplier(formatMessage(message + ECHO))
.flatMap(m -> Mono.from(broadcaster.broadcast(m))).toFuture()
}

@OnClose
Future<String> onClose(WebSocketSession session) {
Mono.just(session)
.flatMap(s -> {
LOG.info(DISCONNECTED.formatted(Thread.currentThread().getName()))
return Mono.just("closed")
}).toFuture()
}
}

@Requires(property = "spec.name", value = "WebsocketExecuteOnSpec")
@ClientWebSocket
static abstract class EchoClientWebSocket implements AutoCloseable {

static final String RECEIVED = "RECEIVED:"

private WebSocketSession session
private List<String> replies = new ArrayList<>()

@OnOpen
void onOpen(WebSocketSession session) {
this.session = session
}
List<String> getReplies() {
return replies
}

@OnMessage
void onMessage(String message) {
replies.add(RECEIVED + message)
}

abstract void send(String message)

List<String> receivedMessages() {
return filterMessagesByType(RECEIVED)
}

List<String> filterMessagesByType(String type) {
replies.stream()
.filter(str -> str.contains(type))
.map(str -> str.replaceAll(type, ""))
.map(str -> str.substring(0, str.length()-(1)).replace("-thread-", ""))
.collect(Collectors.toList())
}
}
}