Skip to content

Commit

Permalink
Support ExecuteOn annotation with ServerWebSocket (#10772)
Browse files Browse the repository at this point in the history
NettyServerWebSocketHandler is updated to check for the ExecuteOn
annotation when invoking any of the callback methods on a
ServerWebSocket annotated class, and to use the specified
ExecutorService when invoking the methods.

A test is added to verify the enhanced behavior.

This resolves #10758
  • Loading branch information
jeremyg484 committed May 2, 2024
1 parent e9554c6 commit 86914a7
Show file tree
Hide file tree
Showing 5 changed files with 340 additions and 13 deletions.
Expand Up @@ -220,7 +220,7 @@ protected final void forwardErrorToUser(ChannelHandlerContext ctx, Consumer<Thro
Object target = errorMethod.getTarget();
Object result;
try {
result = boundExecutable.invoke(target);
result = invokeExecutable(boundExecutable, errorMethod);
} catch (Exception e) {

if (LOG.isErrorEnabled()) {
Expand All @@ -230,8 +230,8 @@ protected final void forwardErrorToUser(ChannelHandlerContext ctx, Consumer<Thro
return;
}
if (Publishers.isConvertibleToPublisher(result)) {
Flux<?> flowable = Flux.from(instrumentPublisher(ctx, result));
flowable.collectList().subscribe(objects -> fallback.accept(cause), throwable -> {
Mono<?> unhandled = Mono.from(instrumentPublisher(ctx, result));
unhandled.subscribe(unhandledResult -> fallback.accept(cause), throwable -> {
if (throwable != null && LOG.isErrorEnabled()) {
LOG.error("Error subscribing to @OnError handler {}.{}: {}", target.getClass().getSimpleName(), errorMethod.getExecutableMethod(), throwable.getMessage(), throwable);
}
Expand Down
Expand Up @@ -26,6 +26,7 @@
import io.micronaut.core.propagation.PropagatedContext;
import io.micronaut.core.type.Argument;
import io.micronaut.core.type.Executable;
import io.micronaut.core.type.ReturnType;
import io.micronaut.core.util.KotlinUtils;
import io.micronaut.http.HttpAttributes;
import io.micronaut.http.HttpRequest;
Expand All @@ -38,6 +39,8 @@
import io.micronaut.http.server.netty.NettyEmbeddedServices;
import io.micronaut.inject.ExecutableMethod;
import io.micronaut.inject.MethodExecutionHandle;
import io.micronaut.scheduling.executor.ExecutorSelector;
import io.micronaut.scheduling.executor.ThreadSelection;
import io.micronaut.web.router.UriRouteMatch;
import io.micronaut.websocket.CloseReason;
import io.micronaut.websocket.WebSocketPongMessage;
Expand Down Expand Up @@ -65,6 +68,7 @@
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
Expand All @@ -91,6 +95,8 @@ public class NettyServerWebSocketHandler extends AbstractNettyWebSocketHandler {

private final Argument<?> bodyArgument;
private final Argument<?> pongArgument;
private final ThreadSelection threadSelection;
private final ExecutorSelector executorSelector;

/**
* Default constructor.
Expand All @@ -102,17 +108,20 @@ public class NettyServerWebSocketHandler extends AbstractNettyWebSocketHandler {
* @param request The request used to create the websocket
* @param routeMatch The route match
* @param ctx The channel handler context
* @param executorSelector
* @param coroutineHelper Helper for kotlin coroutines
*/
NettyServerWebSocketHandler(
NettyEmbeddedServices nettyEmbeddedServices,
WebSocketSessionRepository webSocketSessionRepository,
WebSocketServerHandshaker handshaker,
WebSocketBean<?> webSocketBean,
HttpRequest<?> request,
UriRouteMatch<Object, Object> routeMatch,
ChannelHandlerContext ctx,
@Nullable CoroutineHelper coroutineHelper) {
NettyEmbeddedServices nettyEmbeddedServices,
WebSocketSessionRepository webSocketSessionRepository,
WebSocketServerHandshaker handshaker,
WebSocketBean<?> webSocketBean,
HttpRequest<?> request,
UriRouteMatch<Object, Object> routeMatch,
ChannelHandlerContext ctx,
ThreadSelection threadSelection,
ExecutorSelector executorSelector,
@Nullable CoroutineHelper coroutineHelper) {
super(
ctx,
nettyEmbeddedServices.getRequestArgumentSatisfier().getBinderRegistry(),
Expand All @@ -125,6 +134,9 @@ public class NettyServerWebSocketHandler extends AbstractNettyWebSocketHandler {
webSocketSessionRepository,
nettyEmbeddedServices.getApplicationContext().getConversionService());

this.threadSelection = threadSelection;
this.executorSelector = executorSelector;

this.serverSession = createWebSocketSession(ctx);

ExecutableBinder<WebSocketState> binder = new DefaultExecutableBinder<>();
Expand Down Expand Up @@ -345,8 +357,27 @@ protected Object invokeExecutable(BoundExecutable boundExecutable, MethodExecuti
}

private Object invokeExecutable0(BoundExecutable boundExecutable, MethodExecutionHandle<?, ?> messageHandler) {
return ServerRequestContext.with(originatingRequest,
(Supplier<Object>) () -> boundExecutable.invoke(messageHandler.getTarget()));
return this.executorSelector.select(messageHandler.getExecutableMethod(), threadSelection)
.map(
executorService -> {
ReturnType<?> returnType = messageHandler.getExecutableMethod().getReturnType();
Mono<?> result;
if (returnType.isReactive()) {
result = Mono.from((Publisher<?>) boundExecutable.invoke(messageHandler.getTarget()))
.contextWrite(reactorContext -> reactorContext.put(ServerRequestContext.KEY, originatingRequest));;
} else if (returnType.isAsync()) {
result = Mono.fromFuture((Supplier<CompletableFuture<?>>) invokeWithContext(boundExecutable, messageHandler));
} else {
result = Mono.fromSupplier(invokeWithContext(boundExecutable, messageHandler));
}
return (Object) result.subscribeOn(Schedulers.fromExecutor(executorService));
}
).orElseGet(invokeWithContext(boundExecutable, messageHandler));
}

private Supplier<?> invokeWithContext(BoundExecutable boundExecutable, MethodExecutionHandle<?, ?> messageHandler) {
return () -> ServerRequestContext.with(originatingRequest,
(Supplier<Object>) () -> boundExecutable.invoke(messageHandler.getTarget()));
}

@Override
Expand Down
Expand Up @@ -199,6 +199,8 @@ private void writeResponse(ChannelHandlerContext ctx,
msg,
routeMatch,
ctx,
serverConfiguration.getThreadSelection(),
routeExecutor.getExecutorSelector(),
routeExecutor.getCoroutineHelper().orElse(null));
pipeline.addBefore(ctx.name(), NettyServerWebSocketHandler.ID, webSocketHandler);

Expand Down

0 comments on commit 86914a7

Please sign in to comment.