From b88f2ed29631d564b673473ab2e7eb9e36eb77fc Mon Sep 17 00:00:00 2001 From: rsinukov Date: Fri, 2 Sep 2022 16:55:27 +0200 Subject: [PATCH] KTOR-578 KTOR-800 Fix Netty HTTP/2 --- .../ktor-server-netty/build.gradle.kts | 5 +++ .../server/netty/NettyChannelInitializer.kt | 10 ++++- .../server/netty/NettyHttpHandlerState.kt | 24 +++++++++++ .../netty/cio/NettyHttpResponsePipeline.kt | 21 ++++------ .../server/netty/http1/NettyHttp1Handler.kt | 42 +++++-------------- .../server/netty/http2/NettyHttp2Handler.kt | 38 +++++++++++++---- 6 files changed, 84 insertions(+), 56 deletions(-) create mode 100644 ktor-server/ktor-server-netty/jvm/src/io/ktor/server/netty/NettyHttpHandlerState.kt diff --git a/ktor-server/ktor-server-netty/build.gradle.kts b/ktor-server/ktor-server-netty/build.gradle.kts index 5dd9868d56..5b5e4c7608 100644 --- a/ktor-server/ktor-server-netty/build.gradle.kts +++ b/ktor-server/ktor-server-netty/build.gradle.kts @@ -43,3 +43,8 @@ kotlin.sourceSets { } } } + +val jvmTest: org.jetbrains.kotlin.gradle.targets.jvm.tasks.KotlinJvmTest by tasks +jvmTest.apply { + systemProperty("enable.http2", "true") +} diff --git a/ktor-server/ktor-server-netty/jvm/src/io/ktor/server/netty/NettyChannelInitializer.kt b/ktor-server/ktor-server-netty/jvm/src/io/ktor/server/netty/NettyChannelInitializer.kt index 782bbc875e..6199822c4b 100644 --- a/ktor-server/ktor-server-netty/jvm/src/io/ktor/server/netty/NettyChannelInitializer.kt +++ b/ktor-server/ktor-server-netty/jvm/src/io/ktor/server/netty/NettyChannelInitializer.kt @@ -5,9 +5,9 @@ package io.ktor.server.netty import io.ktor.server.engine.* -import io.ktor.server.netty.cio.* import io.ktor.server.netty.http1.* import io.ktor.server.netty.http2.* +import io.ktor.util.logging.* import io.netty.channel.* import io.netty.channel.socket.SocketChannel import io.netty.handler.codec.http.* @@ -103,7 +103,13 @@ public class NettyChannelInitializer( private fun configurePipeline(pipeline: ChannelPipeline, protocol: String) { when (protocol) { ApplicationProtocolNames.HTTP_2 -> { - val handler = NettyHttp2Handler(enginePipeline, environment.application, callEventGroup, userContext) + val handler = NettyHttp2Handler( + enginePipeline, + environment.application, + callEventGroup, + userContext, + runningLimit + ) @Suppress("DEPRECATION") pipeline.addLast(Http2MultiplexCodecBuilder.forServer(handler).build()) pipeline.channel().closeFuture().addListener { diff --git a/ktor-server/ktor-server-netty/jvm/src/io/ktor/server/netty/NettyHttpHandlerState.kt b/ktor-server/ktor-server-netty/jvm/src/io/ktor/server/netty/NettyHttpHandlerState.kt new file mode 100644 index 0000000000..ac8cc1b37f --- /dev/null +++ b/ktor-server/ktor-server-netty/jvm/src/io/ktor/server/netty/NettyHttpHandlerState.kt @@ -0,0 +1,24 @@ +/* + * Copyright 2014-2022 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package io.ktor.server.netty + +import io.netty.channel.* +import kotlinx.atomicfu.* + +internal class NettyHttpHandlerState(private val runningLimit: Int) { + + internal val activeRequests: AtomicLong = atomic(0L) + internal val isCurrentRequestFullyRead: AtomicBoolean = atomic(false) + internal val isChannelReadCompleted: AtomicBoolean = atomic(false) + internal val skippedRead: AtomicBoolean = atomic(false) + + internal fun onLastResponseMessage(context: ChannelHandlerContext) { + activeRequests.decrementAndGet() + + if (skippedRead.compareAndSet(expect = false, update = true) && activeRequests.value < runningLimit) { + context.read() + } + } +} diff --git a/ktor-server/ktor-server-netty/jvm/src/io/ktor/server/netty/cio/NettyHttpResponsePipeline.kt b/ktor-server/ktor-server-netty/jvm/src/io/ktor/server/netty/cio/NettyHttpResponsePipeline.kt index c78efa5ea5..73a3d3c7ee 100644 --- a/ktor-server/ktor-server-netty/jvm/src/io/ktor/server/netty/cio/NettyHttpResponsePipeline.kt +++ b/ktor-server/ktor-server-netty/jvm/src/io/ktor/server/netty/cio/NettyHttpResponsePipeline.kt @@ -22,16 +22,10 @@ private const val UNFLUSHED_LIMIT = 65536 /** * Contains methods for handling http request with Netty - * @param context - * @param coroutineContext - * @param activeRequests - * @param isCurrentRequestFullyRead - * @param isChannelReadCompleted */ -@OptIn(InternalAPI::class) internal class NettyHttpResponsePipeline constructor( private val context: ChannelHandlerContext, - private val httpHandler: NettyHttp1Handler, + private val httpHandlerState: NettyHttpHandlerState, override val coroutineContext: CoroutineContext ) : CoroutineScope { /** @@ -56,8 +50,8 @@ internal class NettyHttpResponsePipeline constructor( internal fun flushIfNeeded() { if ( isDataNotFlushed.value && - httpHandler.isChannelReadCompleted.value && - httpHandler.activeRequests.value == 0L + httpHandlerState.isChannelReadCompleted.value && + httpHandlerState.activeRequests.value == 0L ) { context.flush() isDataNotFlushed.compareAndSet(expect = true, update = false) @@ -145,7 +139,7 @@ internal class NettyHttpResponsePipeline constructor( null } - httpHandler.onLastResponseMessage(context) + httpHandlerState.onLastResponseMessage(context) call.finishedEvent.setSuccess() lastMessageFuture?.addListener { @@ -232,9 +226,9 @@ internal class NettyHttpResponsePipeline constructor( * True if client is waiting for response header, false otherwise */ private fun isHeaderFlushNeeded(): Boolean { - val activeRequestsValue = httpHandler.activeRequests.value - return httpHandler.isChannelReadCompleted.value && - !httpHandler.isCurrentRequestFullyRead.value && + val activeRequestsValue = httpHandlerState.activeRequests.value + return httpHandlerState.isChannelReadCompleted.value && + !httpHandlerState.isCurrentRequestFullyRead.value && activeRequestsValue == 1L } @@ -365,7 +359,6 @@ internal class NettyHttpResponsePipeline constructor( } } -@OptIn(InternalAPI::class) private fun NettyApplicationResponse.isUpgradeResponse() = status()?.value == HttpStatusCode.SwitchingProtocols.value diff --git a/ktor-server/ktor-server-netty/jvm/src/io/ktor/server/netty/http1/NettyHttp1Handler.kt b/ktor-server/ktor-server-netty/jvm/src/io/ktor/server/netty/http1/NettyHttp1Handler.kt index 00a5eff906..396be4bbdd 100644 --- a/ktor-server/ktor-server-netty/jvm/src/io/ktor/server/netty/http1/NettyHttp1Handler.kt +++ b/ktor-server/ktor-server-netty/jvm/src/io/ktor/server/netty/http1/NettyHttp1Handler.kt @@ -31,29 +31,15 @@ internal class NettyHttp1Handler( override val coroutineContext: CoroutineContext get() = handlerJob private var skipEmpty = false - private val skippedRead: AtomicBoolean = atomic(false) private lateinit var responseWriter: NettyHttpResponsePipeline - /** - * Represents current number of processing requests - */ - internal val activeRequests: AtomicLong = atomic(0L) - - /** - * True if current request's last http content is read, false otherwise. - */ - internal val isCurrentRequestFullyRead: AtomicBoolean = atomic(false) - - /** - * True if [channelReadComplete] was invoked for the current request, false otherwise - */ - internal val isChannelReadCompleted: AtomicBoolean = atomic(false) + private val state = NettyHttpHandlerState(runningLimit) override fun channelActive(context: ChannelHandlerContext) { responseWriter = NettyHttpResponsePipeline( context, - this, + state, coroutineContext ) @@ -68,16 +54,16 @@ internal class NettyHttp1Handler( override fun channelRead(context: ChannelHandlerContext, message: Any) { if (message is LastHttpContent) { - isCurrentRequestFullyRead.compareAndSet(expect = false, update = true) + state.isCurrentRequestFullyRead.compareAndSet(expect = false, update = true) } when { message is HttpRequest -> { if (message !is LastHttpContent) { - isCurrentRequestFullyRead.compareAndSet(expect = true, update = false) + state.isCurrentRequestFullyRead.compareAndSet(expect = true, update = false) } - isChannelReadCompleted.compareAndSet(expect = true, update = false) - activeRequests.incrementAndGet() + state.isChannelReadCompleted.compareAndSet(expect = true, update = false) + state.activeRequests.incrementAndGet() handleRequest(context, message) callReadIfNeeded(context) @@ -110,7 +96,7 @@ internal class NettyHttp1Handler( } override fun channelReadComplete(context: ChannelHandlerContext?) { - isChannelReadCompleted.compareAndSet(expect = false, update = true) + state.isChannelReadCompleted.compareAndSet(expect = false, update = true) responseWriter.flushIfNeeded() super.channelReadComplete(context) } @@ -165,19 +151,11 @@ internal class NettyHttp1Handler( } private fun callReadIfNeeded(context: ChannelHandlerContext) { - if (activeRequests.value < runningLimit) { + if (state.activeRequests.value < runningLimit) { context.read() - skippedRead.value = false + state.skippedRead.value = false } else { - skippedRead.value = true - } - } - - internal fun onLastResponseMessage(context: ChannelHandlerContext) { - activeRequests.decrementAndGet() - - if (skippedRead.compareAndSet(expect = false, update = true) && activeRequests.value < runningLimit) { - context.read() + state.skippedRead.value = true } } } diff --git a/ktor-server/ktor-server-netty/jvm/src/io/ktor/server/netty/http2/NettyHttp2Handler.kt b/ktor-server/ktor-server-netty/jvm/src/io/ktor/server/netty/http2/NettyHttp2Handler.kt index 1ac5cabb9c..915eb9120f 100644 --- a/ktor-server/ktor-server-netty/jvm/src/io/ktor/server/netty/http2/NettyHttp2Handler.kt +++ b/ktor-server/ktor-server-netty/jvm/src/io/ktor/server/netty/http2/NettyHttp2Handler.kt @@ -8,11 +8,11 @@ import io.ktor.http.* import io.ktor.server.application.* import io.ktor.server.engine.* import io.ktor.server.netty.* +import io.ktor.server.netty.cio.* import io.ktor.server.response.* -import io.ktor.util.* import io.netty.channel.* import io.netty.handler.codec.http2.* -import io.netty.util.AttributeKey +import io.netty.util.* import io.netty.util.concurrent.* import kotlinx.coroutines.* import java.lang.reflect.* @@ -24,16 +24,22 @@ internal class NettyHttp2Handler( private val enginePipeline: EnginePipeline, private val application: Application, private val callEventGroup: EventExecutorGroup, - private val userCoroutineContext: CoroutineContext + private val userCoroutineContext: CoroutineContext, + runningLimit: Int ) : ChannelInboundHandlerAdapter(), CoroutineScope { private val handlerJob = SupervisorJob(userCoroutineContext[Job]) + private val state = NettyHttpHandlerState(runningLimit) + private lateinit var responseWriter: NettyHttpResponsePipeline + override val coroutineContext: CoroutineContext get() = handlerJob - override fun channelRead(context: ChannelHandlerContext, message: Any?) { + override fun channelRead(context: ChannelHandlerContext, message: Any) { when (message) { is Http2HeadersFrame -> { + state.isChannelReadCompleted.compareAndSet(expect = true, update = false) + state.activeRequests.incrementAndGet() startHttp2(context, message.headers()) } is Http2DataFrame -> { @@ -42,6 +48,9 @@ internal class NettyHttp2Handler( contentActor.trySend(message).isSuccess if (eof) { contentActor.close() + state.isCurrentRequestFullyRead.compareAndSet(expect = false, update = true) + } else { + state.isCurrentRequestFullyRead.compareAndSet(expect = true, update = false) } } ?: message.release() } @@ -55,12 +64,23 @@ internal class NettyHttp2Handler( } } - override fun channelRegistered(ctx: ChannelHandlerContext?) { - super.channelRegistered(ctx) + override fun channelActive(context: ChannelHandlerContext) { + responseWriter = NettyHttpResponsePipeline( + context, + state, + coroutineContext + ) - ctx?.pipeline()?.apply { + context.pipeline()?.apply { addLast(callEventGroup, NettyApplicationCallHandler(userCoroutineContext, enginePipeline)) } + context.fireChannelActive() + } + + override fun channelReadComplete(context: ChannelHandlerContext) { + state.isChannelReadCompleted.compareAndSet(expect = false, update = true) + responseWriter.flushIfNeeded() + context.fireChannelReadComplete() } @Suppress("OverridingDeprecatedMember") @@ -68,7 +88,6 @@ internal class NettyHttp2Handler( ctx.close() } - @OptIn(InternalAPI::class) private fun startHttp2(context: ChannelHandlerContext, headers: Http2Headers) { val call = NettyHttp2ApplicationCall( application, @@ -79,6 +98,9 @@ internal class NettyHttp2Handler( userCoroutineContext ) context.applicationCall = call + + context.fireChannelRead(call) + responseWriter.processResponse(call) } @Suppress("DEPRECATION")