Skip to content

Commit

Permalink
Merge branch '4.5.x' into 5.0.x
Browse files Browse the repository at this point in the history
  • Loading branch information
sdelamo committed May 3, 2024
2 parents e8e3211 + 4521d51 commit 6ef098e
Show file tree
Hide file tree
Showing 23 changed files with 706 additions and 32 deletions.
Expand Up @@ -239,7 +239,7 @@ public void collect(Collection<S> values) {

@Override
public void collect(Collection<S> values, boolean allowFork) {
if (allowFork) {
if (allowFork && ForkJoinPool.getCommonPoolParallelism() > 1) {
ForkJoinPool.commonPool().invoke(this);
for (RecursiveActionValuesCollector<S> task : tasks) {
task.join();
Expand Down
Expand Up @@ -220,7 +220,7 @@ protected final void forwardErrorToUser(ChannelHandlerContext ctx, Consumer<Thro
Object target = errorMethod.getTarget();
Object result;
try {
result = boundExecutable.invoke(target);
result = invokeExecutable(boundExecutable, errorMethod);
} catch (Exception e) {

if (LOG.isErrorEnabled()) {
Expand All @@ -230,8 +230,8 @@ protected final void forwardErrorToUser(ChannelHandlerContext ctx, Consumer<Thro
return;
}
if (Publishers.isConvertibleToPublisher(result)) {
Flux<?> flowable = Flux.from(instrumentPublisher(ctx, result));
flowable.collectList().subscribe(objects -> fallback.accept(cause), throwable -> {
Mono<?> unhandled = Mono.from(instrumentPublisher(ctx, result));
unhandled.subscribe(unhandledResult -> fallback.accept(cause), throwable -> {
if (throwable != null && LOG.isErrorEnabled()) {
LOG.error("Error subscribing to @OnError handler {}.{}: {}", target.getClass().getSimpleName(), errorMethod.getExecutableMethod(), throwable.getMessage(), throwable);
}
Expand Down
Expand Up @@ -506,12 +506,22 @@ void configureForH2cSupport() {

final Http2FrameCodec frameCodec;
final Http2ConnectionHandler connectionHandler;
Http2MultiplexHandler multiplexHandler;
if (server.getServerConfiguration().isLegacyMultiplexHandlers()) {
frameCodec = createHttp2FrameCodec();
connectionHandler = frameCodec;
multiplexHandler = new Http2MultiplexHandler(new ChannelInitializer<Http2StreamChannel>() {
@Override
protected void initChannel(@NonNull Http2StreamChannel ch) {
StreamPipeline streamPipeline = new StreamPipeline(ch, sslHandler, connectionCustomizer.specializeForChannel(ch, NettyServerCustomizer.ChannelRole.REQUEST_STREAM));
streamPipeline.insertHttp2FrameHandlers();
streamPipeline.streamCustomizer.onStreamPipelineBuilt();
}
});
} else {
connectionHandler = createHttp2ServerHandler(false);
frameCodec = null;
multiplexHandler = null;
}
final String fallbackHandlerName = "http1-fallback-handler";
HttpServerUpgradeHandler.UpgradeCodecFactory upgradeCodecFactory = protocol -> {
Expand All @@ -537,14 +547,7 @@ public void upgradeTo(ChannelHandlerContext ctx, FullHttpRequest upgradeRequest)
if (frameCodec == null) {
return new Http2ServerUpgradeCodecImpl(connectionHandler);
} else {
return new Http2ServerUpgradeCodecImpl(frameCodec, new Http2MultiplexHandler(new ChannelInitializer<Http2StreamChannel>() {
@Override
protected void initChannel(@NonNull Http2StreamChannel ch) {
StreamPipeline streamPipeline = new StreamPipeline(ch, sslHandler, connectionCustomizer.specializeForChannel(ch, NettyServerCustomizer.ChannelRole.REQUEST_STREAM));
streamPipeline.insertHttp2FrameHandlers();
streamPipeline.streamCustomizer.onStreamPipelineBuilt();
}
}));
return new Http2ServerUpgradeCodecImpl(frameCodec, multiplexHandler);
}
} else {
return null;
Expand All @@ -557,8 +560,14 @@ protected void initChannel(@NonNull Http2StreamChannel ch) {
upgradeCodecFactory,
server.getServerConfiguration().getMaxH2cUpgradeRequestSize()
);
ChannelHandler priorKnowledgeHandler = frameCodec == null ? connectionHandler : new ChannelInitializer<>() {
@Override
protected void initChannel(@NonNull Channel ch) {
ch.pipeline().addLast(connectionHandler, multiplexHandler);
}
};
final CleartextHttp2ServerUpgradeHandler cleartextHttp2ServerUpgradeHandler =
new CleartextHttp2ServerUpgradeHandler(sourceCodec, upgradeHandler, connectionHandler);
new CleartextHttp2ServerUpgradeHandler(sourceCodec, upgradeHandler, priorKnowledgeHandler);

pipeline.addLast(cleartextHttp2ServerUpgradeHandler);
pipeline.addLast(fallbackHandlerName, new SimpleChannelInboundHandler<HttpMessage>() {
Expand Down
Expand Up @@ -295,7 +295,8 @@ public HttpVersion getHttpVersion() {
if (pipeline != null) {
return pipeline.httpVersion;
}
return HttpVersion.HTTP_1_1;
// Http2ServerHandler case
return findConnectionHandler() == null ? HttpVersion.HTTP_1_1 : HttpVersion.HTTP_2_0;
}

@Override
Expand Down
Expand Up @@ -26,6 +26,7 @@
import io.micronaut.core.propagation.PropagatedContext;
import io.micronaut.core.type.Argument;
import io.micronaut.core.type.Executable;
import io.micronaut.core.type.ReturnType;
import io.micronaut.core.util.KotlinUtils;
import io.micronaut.http.HttpAttributes;
import io.micronaut.http.HttpRequest;
Expand All @@ -38,6 +39,8 @@
import io.micronaut.http.server.netty.NettyEmbeddedServices;
import io.micronaut.inject.ExecutableMethod;
import io.micronaut.inject.MethodExecutionHandle;
import io.micronaut.scheduling.executor.ExecutorSelector;
import io.micronaut.scheduling.executor.ThreadSelection;
import io.micronaut.web.router.UriRouteMatch;
import io.micronaut.websocket.CloseReason;
import io.micronaut.websocket.WebSocketPongMessage;
Expand Down Expand Up @@ -65,6 +68,7 @@
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
Expand All @@ -91,6 +95,8 @@ public class NettyServerWebSocketHandler extends AbstractNettyWebSocketHandler {

private final Argument<?> bodyArgument;
private final Argument<?> pongArgument;
private final ThreadSelection threadSelection;
private final ExecutorSelector executorSelector;

/**
* Default constructor.
Expand All @@ -102,17 +108,20 @@ public class NettyServerWebSocketHandler extends AbstractNettyWebSocketHandler {
* @param request The request used to create the websocket
* @param routeMatch The route match
* @param ctx The channel handler context
* @param executorSelector
* @param coroutineHelper Helper for kotlin coroutines
*/
NettyServerWebSocketHandler(
NettyEmbeddedServices nettyEmbeddedServices,
WebSocketSessionRepository webSocketSessionRepository,
WebSocketServerHandshaker handshaker,
WebSocketBean<?> webSocketBean,
HttpRequest<?> request,
UriRouteMatch<Object, Object> routeMatch,
ChannelHandlerContext ctx,
@Nullable CoroutineHelper coroutineHelper) {
NettyEmbeddedServices nettyEmbeddedServices,
WebSocketSessionRepository webSocketSessionRepository,
WebSocketServerHandshaker handshaker,
WebSocketBean<?> webSocketBean,
HttpRequest<?> request,
UriRouteMatch<Object, Object> routeMatch,
ChannelHandlerContext ctx,
ThreadSelection threadSelection,
ExecutorSelector executorSelector,
@Nullable CoroutineHelper coroutineHelper) {
super(
ctx,
nettyEmbeddedServices.getRequestArgumentSatisfier().getBinderRegistry(),
Expand All @@ -125,6 +134,9 @@ public class NettyServerWebSocketHandler extends AbstractNettyWebSocketHandler {
webSocketSessionRepository,
nettyEmbeddedServices.getApplicationContext().getConversionService());

this.threadSelection = threadSelection;
this.executorSelector = executorSelector;

this.serverSession = createWebSocketSession(ctx);

ExecutableBinder<WebSocketState> binder = new DefaultExecutableBinder<>();
Expand Down Expand Up @@ -345,8 +357,27 @@ protected Object invokeExecutable(BoundExecutable boundExecutable, MethodExecuti
}

private Object invokeExecutable0(BoundExecutable boundExecutable, MethodExecutionHandle<?, ?> messageHandler) {
return ServerRequestContext.with(originatingRequest,
(Supplier<Object>) () -> boundExecutable.invoke(messageHandler.getTarget()));
return this.executorSelector.select(messageHandler.getExecutableMethod(), threadSelection)
.map(
executorService -> {
ReturnType<?> returnType = messageHandler.getExecutableMethod().getReturnType();
Mono<?> result;
if (returnType.isReactive()) {
result = Mono.from((Publisher<?>) boundExecutable.invoke(messageHandler.getTarget()))
.contextWrite(reactorContext -> reactorContext.put(ServerRequestContext.KEY, originatingRequest));;
} else if (returnType.isAsync()) {
result = Mono.fromFuture((Supplier<CompletableFuture<?>>) invokeWithContext(boundExecutable, messageHandler));
} else {
result = Mono.fromSupplier(invokeWithContext(boundExecutable, messageHandler));
}
return (Object) result.subscribeOn(Schedulers.fromExecutor(executorService));
}
).orElseGet(invokeWithContext(boundExecutable, messageHandler));
}

private Supplier<?> invokeWithContext(BoundExecutable boundExecutable, MethodExecutionHandle<?, ?> messageHandler) {
return () -> ServerRequestContext.with(originatingRequest,
(Supplier<Object>) () -> boundExecutable.invoke(messageHandler.getTarget()));
}

@Override
Expand Down
Expand Up @@ -199,6 +199,8 @@ private void writeResponse(ChannelHandlerContext ctx,
msg,
routeMatch,
ctx,
serverConfiguration.getThreadSelection(),
routeExecutor.getExecutorSelector(),
routeExecutor.getCoroutineHelper().orElse(null));
pipeline.addBefore(ctx.name(), NettyServerWebSocketHandler.ID, webSocketHandler);

Expand Down
Expand Up @@ -9,6 +9,7 @@ class Http2CompressionSpec extends CompressionSpec {

'micronaut.server.http-version': '2.0',
'micronaut.server.ssl.enabled': true,
'micronaut.server.ssl.port': 0,
'micronaut.server.ssl.build-self-signed': true,
] as Map<String, Object>
}
Expand Down
Expand Up @@ -23,11 +23,11 @@ class ContextURISpec extends Specification {
void "test getContextURI returns the base URI when context path is not set"() {
when:
EmbeddedServer embeddedServer = ApplicationContext.run(EmbeddedServer, [
'micronaut.server.port': 60006
'micronaut.server.port': 60007
])

then:
embeddedServer.getContextURI().toString() == 'http://localhost:60006'
embeddedServer.getContextURI().toString() == 'http://localhost:60007'

cleanup:
embeddedServer.close()
Expand All @@ -37,11 +37,11 @@ class ContextURISpec extends Specification {
when:
EmbeddedServer embeddedServer = ApplicationContext.run(EmbeddedServer, [
'micronaut.server.context-path': '',
'micronaut.server.port': 60006
'micronaut.server.port': 60008
])

then:
embeddedServer.getContextURI().toString() == 'http://localhost:60006'
embeddedServer.getContextURI().toString() == 'http://localhost:60008'

cleanup:
embeddedServer.close()
Expand Down
Expand Up @@ -229,6 +229,72 @@ class H2cSpec extends Specification {
content.release()
}

def 'prior knowledge'() {
given:
def responseFuture = new CompletableFuture()

def group = new NioEventLoopGroup(1)
def bootstrap = new Bootstrap()
.remoteAddress(embeddedServer.host, embeddedServer.port)
.group(group)
.channel(NioSocketChannel.class)
.handler(new ChannelInitializer<SocketChannel>() {
@Override
protected void initChannel(@NonNull SocketChannel ch) throws Exception {
def http2Connection = new DefaultHttp2Connection(false)
def inboundAdapter = new InboundHttp2ToHttpAdapterBuilder(http2Connection)
.maxContentLength(1000000)
.validateHttpHeaders(true)
.propagateSettings(true)
.build()
def connectionHandler = new HttpToHttp2ConnectionHandlerBuilder()
.connection(http2Connection)
.frameListener(new DelegatingDecompressorFrameListener(http2Connection, inboundAdapter))
.build()

ch.pipeline()
.addLast(connectionHandler)
.addLast(new ChannelInboundHandlerAdapter() {
@Override
void channelRead(@NonNull ChannelHandlerContext ctx, @NonNull Object msg) throws Exception {
ctx.read()
if (msg instanceof HttpMessage) {
if (msg.headers().getInt(HttpConversionUtil.ExtensionHeaderNames.STREAM_ID.text(), -1) != 3) {
responseFuture.completeExceptionally(new AssertionError("Response must be on stream 3"));
}
responseFuture.complete(ReferenceCountUtil.retain(msg))
}
super.channelRead(ctx, msg)
}

@Override
void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
super.exceptionCaught(ctx, cause)
cause.printStackTrace()
responseFuture.completeExceptionally(cause)
}
})

}
})

def channel = (SocketChannel) bootstrap.connect().await().channel()

def request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, '/h2c/test')
request.headers().set(HttpConversionUtil.ExtensionHeaderNames.SCHEME.text(), "http")
channel.writeAndFlush(request)
channel.read()

expect:
def resp = responseFuture.get(10, TimeUnit.SECONDS)
resp != null

cleanup:
channel.close()
resp.release()
group.shutdownGracefully()
}

@Controller("/h2c")
@Requires(property = "spec.name", value = "H2cSpec")
static class TestController {
Expand Down

0 comments on commit 6ef098e

Please sign in to comment.