Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

KTOR-578 KTOR-800 Fix Netty HTTP/2 #3152

Merged
merged 1 commit into from Sep 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 5 additions & 0 deletions ktor-server/ktor-server-netty/build.gradle.kts
Expand Up @@ -43,3 +43,8 @@ kotlin.sourceSets {
}
}
}

val jvmTest: org.jetbrains.kotlin.gradle.targets.jvm.tasks.KotlinJvmTest by tasks
jvmTest.apply {
systemProperty("enable.http2", "true")
}
Expand Up @@ -5,9 +5,9 @@
package io.ktor.server.netty

import io.ktor.server.engine.*
import io.ktor.server.netty.cio.*
import io.ktor.server.netty.http1.*
import io.ktor.server.netty.http2.*
import io.ktor.util.logging.*
import io.netty.channel.*
import io.netty.channel.socket.SocketChannel
import io.netty.handler.codec.http.*
Expand Down Expand Up @@ -103,7 +103,13 @@ public class NettyChannelInitializer(
private fun configurePipeline(pipeline: ChannelPipeline, protocol: String) {
when (protocol) {
ApplicationProtocolNames.HTTP_2 -> {
val handler = NettyHttp2Handler(enginePipeline, environment.application, callEventGroup, userContext)
val handler = NettyHttp2Handler(
enginePipeline,
environment.application,
callEventGroup,
userContext,
runningLimit
)
@Suppress("DEPRECATION")
pipeline.addLast(Http2MultiplexCodecBuilder.forServer(handler).build())
pipeline.channel().closeFuture().addListener {
Expand Down
@@ -0,0 +1,24 @@
/*
* Copyright 2014-2022 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
*/

package io.ktor.server.netty

import io.netty.channel.*
import kotlinx.atomicfu.*

internal class NettyHttpHandlerState(private val runningLimit: Int) {

internal val activeRequests: AtomicLong = atomic(0L)
internal val isCurrentRequestFullyRead: AtomicBoolean = atomic(false)
internal val isChannelReadCompleted: AtomicBoolean = atomic(false)
internal val skippedRead: AtomicBoolean = atomic(false)

internal fun onLastResponseMessage(context: ChannelHandlerContext) {
activeRequests.decrementAndGet()

if (skippedRead.compareAndSet(expect = false, update = true) && activeRequests.value < runningLimit) {
context.read()
}
}
}
Expand Up @@ -22,16 +22,10 @@ private const val UNFLUSHED_LIMIT = 65536

/**
* Contains methods for handling http request with Netty
* @param context
* @param coroutineContext
* @param activeRequests
* @param isCurrentRequestFullyRead
* @param isChannelReadCompleted
*/
@OptIn(InternalAPI::class)
internal class NettyHttpResponsePipeline constructor(
private val context: ChannelHandlerContext,
private val httpHandler: NettyHttp1Handler,
private val httpHandlerState: NettyHttpHandlerState,
override val coroutineContext: CoroutineContext
) : CoroutineScope {
/**
Expand All @@ -56,8 +50,8 @@ internal class NettyHttpResponsePipeline constructor(
internal fun flushIfNeeded() {
if (
isDataNotFlushed.value &&
httpHandler.isChannelReadCompleted.value &&
httpHandler.activeRequests.value == 0L
httpHandlerState.isChannelReadCompleted.value &&
httpHandlerState.activeRequests.value == 0L
) {
context.flush()
isDataNotFlushed.compareAndSet(expect = true, update = false)
Expand Down Expand Up @@ -145,7 +139,7 @@ internal class NettyHttpResponsePipeline constructor(
null
}

httpHandler.onLastResponseMessage(context)
httpHandlerState.onLastResponseMessage(context)
call.finishedEvent.setSuccess()

lastMessageFuture?.addListener {
Expand Down Expand Up @@ -232,9 +226,9 @@ internal class NettyHttpResponsePipeline constructor(
* True if client is waiting for response header, false otherwise
*/
private fun isHeaderFlushNeeded(): Boolean {
val activeRequestsValue = httpHandler.activeRequests.value
return httpHandler.isChannelReadCompleted.value &&
!httpHandler.isCurrentRequestFullyRead.value &&
val activeRequestsValue = httpHandlerState.activeRequests.value
return httpHandlerState.isChannelReadCompleted.value &&
!httpHandlerState.isCurrentRequestFullyRead.value &&
activeRequestsValue == 1L
}

Expand Down Expand Up @@ -365,7 +359,6 @@ internal class NettyHttpResponsePipeline constructor(
}
}

@OptIn(InternalAPI::class)
private fun NettyApplicationResponse.isUpgradeResponse() =
status()?.value == HttpStatusCode.SwitchingProtocols.value

Expand Down
Expand Up @@ -31,29 +31,15 @@ internal class NettyHttp1Handler(
override val coroutineContext: CoroutineContext get() = handlerJob

private var skipEmpty = false
private val skippedRead: AtomicBoolean = atomic(false)

private lateinit var responseWriter: NettyHttpResponsePipeline

/**
* Represents current number of processing requests
*/
internal val activeRequests: AtomicLong = atomic(0L)

/**
* True if current request's last http content is read, false otherwise.
*/
internal val isCurrentRequestFullyRead: AtomicBoolean = atomic(false)

/**
* True if [channelReadComplete] was invoked for the current request, false otherwise
*/
internal val isChannelReadCompleted: AtomicBoolean = atomic(false)
private val state = NettyHttpHandlerState(runningLimit)

override fun channelActive(context: ChannelHandlerContext) {
responseWriter = NettyHttpResponsePipeline(
context,
this,
state,
coroutineContext
)

Expand All @@ -68,16 +54,16 @@ internal class NettyHttp1Handler(

override fun channelRead(context: ChannelHandlerContext, message: Any) {
if (message is LastHttpContent) {
isCurrentRequestFullyRead.compareAndSet(expect = false, update = true)
state.isCurrentRequestFullyRead.compareAndSet(expect = false, update = true)
}

when {
message is HttpRequest -> {
if (message !is LastHttpContent) {
isCurrentRequestFullyRead.compareAndSet(expect = true, update = false)
state.isCurrentRequestFullyRead.compareAndSet(expect = true, update = false)
}
isChannelReadCompleted.compareAndSet(expect = true, update = false)
activeRequests.incrementAndGet()
state.isChannelReadCompleted.compareAndSet(expect = true, update = false)
state.activeRequests.incrementAndGet()

handleRequest(context, message)
callReadIfNeeded(context)
Expand Down Expand Up @@ -110,7 +96,7 @@ internal class NettyHttp1Handler(
}

override fun channelReadComplete(context: ChannelHandlerContext?) {
isChannelReadCompleted.compareAndSet(expect = false, update = true)
state.isChannelReadCompleted.compareAndSet(expect = false, update = true)
responseWriter.flushIfNeeded()
super.channelReadComplete(context)
}
Expand Down Expand Up @@ -165,19 +151,11 @@ internal class NettyHttp1Handler(
}

private fun callReadIfNeeded(context: ChannelHandlerContext) {
if (activeRequests.value < runningLimit) {
if (state.activeRequests.value < runningLimit) {
context.read()
skippedRead.value = false
state.skippedRead.value = false
} else {
skippedRead.value = true
}
}

internal fun onLastResponseMessage(context: ChannelHandlerContext) {
activeRequests.decrementAndGet()

if (skippedRead.compareAndSet(expect = false, update = true) && activeRequests.value < runningLimit) {
context.read()
state.skippedRead.value = true
}
}
}
Expand Up @@ -8,11 +8,11 @@ import io.ktor.http.*
import io.ktor.server.application.*
import io.ktor.server.engine.*
import io.ktor.server.netty.*
import io.ktor.server.netty.cio.*
import io.ktor.server.response.*
import io.ktor.util.*
import io.netty.channel.*
import io.netty.handler.codec.http2.*
import io.netty.util.AttributeKey
import io.netty.util.*
import io.netty.util.concurrent.*
import kotlinx.coroutines.*
import java.lang.reflect.*
Expand All @@ -24,16 +24,22 @@ internal class NettyHttp2Handler(
private val enginePipeline: EnginePipeline,
private val application: Application,
private val callEventGroup: EventExecutorGroup,
private val userCoroutineContext: CoroutineContext
private val userCoroutineContext: CoroutineContext,
runningLimit: Int
) : ChannelInboundHandlerAdapter(), CoroutineScope {
private val handlerJob = SupervisorJob(userCoroutineContext[Job])

private val state = NettyHttpHandlerState(runningLimit)
private lateinit var responseWriter: NettyHttpResponsePipeline

override val coroutineContext: CoroutineContext
get() = handlerJob

override fun channelRead(context: ChannelHandlerContext, message: Any?) {
override fun channelRead(context: ChannelHandlerContext, message: Any) {
when (message) {
is Http2HeadersFrame -> {
state.isChannelReadCompleted.compareAndSet(expect = true, update = false)
state.activeRequests.incrementAndGet()
startHttp2(context, message.headers())
}
is Http2DataFrame -> {
Expand All @@ -42,6 +48,9 @@ internal class NettyHttp2Handler(
contentActor.trySend(message).isSuccess
if (eof) {
contentActor.close()
state.isCurrentRequestFullyRead.compareAndSet(expect = false, update = true)
} else {
state.isCurrentRequestFullyRead.compareAndSet(expect = true, update = false)
}
} ?: message.release()
}
Expand All @@ -55,20 +64,30 @@ internal class NettyHttp2Handler(
}
}

override fun channelRegistered(ctx: ChannelHandlerContext?) {
super.channelRegistered(ctx)
override fun channelActive(context: ChannelHandlerContext) {
responseWriter = NettyHttpResponsePipeline(
context,
state,
coroutineContext
)

ctx?.pipeline()?.apply {
context.pipeline()?.apply {
addLast(callEventGroup, NettyApplicationCallHandler(userCoroutineContext, enginePipeline))
}
context.fireChannelActive()
}

override fun channelReadComplete(context: ChannelHandlerContext) {
state.isChannelReadCompleted.compareAndSet(expect = false, update = true)
responseWriter.flushIfNeeded()
context.fireChannelReadComplete()
}

@Suppress("OverridingDeprecatedMember")
override fun exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable) {
ctx.close()
}

@OptIn(InternalAPI::class)
private fun startHttp2(context: ChannelHandlerContext, headers: Http2Headers) {
val call = NettyHttp2ApplicationCall(
application,
Expand All @@ -79,6 +98,9 @@ internal class NettyHttp2Handler(
userCoroutineContext
)
context.applicationCall = call

context.fireChannelRead(call)
responseWriter.processResponse(call)
}

@Suppress("DEPRECATION")
Expand Down