From ae3629402c10974a75fb74ebec7d1408d76e40dd Mon Sep 17 00:00:00 2001 From: "leonid.stashevsky" Date: Thu, 29 Sep 2022 11:37:15 +0200 Subject: [PATCH 1/3] KTOR-2588 Fix long line reading --- .../src/io/ktor/utils/io/ByteBufferChannel.kt | 172 ++++++------------ .../utils/io/charsets/ByteChannelTextTest.kt | 21 +++ 2 files changed, 76 insertions(+), 117 deletions(-) diff --git a/ktor-io/jvm/src/io/ktor/utils/io/ByteBufferChannel.kt b/ktor-io/jvm/src/io/ktor/utils/io/ByteBufferChannel.kt index 2ecd3fefa6..27239eb5e1 100644 --- a/ktor-io/jvm/src/io/ktor/utils/io/ByteBufferChannel.kt +++ b/ktor-io/jvm/src/io/ktor/utils/io/ByteBufferChannel.kt @@ -10,7 +10,6 @@ import io.ktor.utils.io.pool.* import kotlinx.atomicfu.* import kotlinx.coroutines.* import kotlinx.coroutines.channels.* -import java.io.EOFException import java.lang.Double.* import java.lang.Float.* import java.nio.* @@ -226,19 +225,23 @@ internal open class ByteBufferChannel( allocatedState?.let { releaseBuffer(it) } return null } + closed != null -> { allocatedState?.let { releaseBuffer(it) } rethrowClosed(closed!!.sendException) } + state === ReadWriteBufferState.IdleEmpty -> { val allocated = allocatedState ?: newBuffer().also { allocatedState = it } allocated.startWriting() } + state === ReadWriteBufferState.Terminated -> { allocatedState?.let { releaseBuffer(it) } if (joining != null) return null rethrowClosed(closed!!.sendException) } + else -> { state.startWriting() } @@ -290,6 +293,7 @@ internal open class ByteBufferChannel( val cause = closed?.cause ?: return null rethrowClosed(cause) } + else -> { closed?.cause?.let { rethrowClosed(it) } if (state.capacity.availableForRead == 0) return null @@ -406,11 +410,13 @@ internal open class ByteBufferChannel( toRelease = state.initial ReadWriteBufferState.Terminated } + forceTermination && state is ReadWriteBufferState.IdleNonEmpty && state.capacity.tryLockForRelease() -> { toRelease = state.initial ReadWriteBufferState.Terminated } + else -> return false } } @@ -672,6 +678,7 @@ internal open class ByteBufferChannel( -1 } } + consumed > 0 || length == 0 -> consumed else -> readAvailableSuspend(dst, offset, length) } @@ -688,6 +695,7 @@ internal open class ByteBufferChannel( -1 } } + consumed > 0 || !dst.hasRemaining() -> consumed else -> readAvailableSuspend(dst) } @@ -704,6 +712,7 @@ internal open class ByteBufferChannel( -1 } } + consumed > 0 || !dst.canWrite() -> consumed else -> readAvailableSuspend(dst) } @@ -1907,18 +1916,6 @@ internal open class ByteBufferChannel( return rc } - private suspend fun consumeEachBufferRangeSuspend(visitor: (buffer: ByteBuffer, last: Boolean) -> Boolean) { - var last = false - - do { - if (consumeEachBufferRangeFast(last, visitor)) return - if (last) return - if (!readSuspend(1)) { - last = true - } - } while (true) - } - private fun afterBufferVisited(buffer: ByteBuffer, capacity: RingBufferCapacity): Int { val consumed = buffer.position() - readPosition if (consumed > 0) { @@ -1934,134 +1931,75 @@ internal open class ByteBufferChannel( private suspend fun readUTF8LineToAscii(out: Appendable, limit: Int): Boolean { if (state === ReadWriteBufferState.Terminated) { val cause = closedCause - if (cause != null) { - throw cause - } - + if (cause != null) throw cause return false } - var consumed = 0 - - val array = CharArray(8192) - val buffer = CharBuffer.wrap(array) - var eol = false - - lookAhead { - eol = readLineLoop( - out, - array, - buffer, - await = { expected -> availableForRead >= expected }, - addConsumed = { consumed += it }, - decode = { - it.decodeASCIILine(array, 0, minOf(array.size, limit - consumed)) - } - ) - } - - if (eol) return true - if (consumed == 0 && isClosedForRead) return false - - return readUTF8LineToUtf8Suspend(out, limit - consumed, array, buffer, consumed) + return readUTF8LineToUtf8Suspend(out, limit) } - private inline fun LookAheadSession.readLineLoop( + private suspend fun readUTF8LineToUtf8Suspend( out: Appendable, - ca: CharArray, - cb: CharBuffer, - await: (Int) -> Boolean, - addConsumed: (Int) -> Unit, - decode: (ByteBuffer) -> Long + limit: Int ): Boolean { - // number of bytes required for the next character, <= 0 when no characters required anymore (exit loop) + var consumed = 0 var required = 1 + var caret = false + var newLine = false - do { - if (!await(required)) break - val buffer = request(0, 1) ?: break + val output = CharArray(8 * 1024) + while (!isClosedForRead && !newLine && !caret && (limit == Int.MAX_VALUE || consumed <= limit)) { + read(required) { + val readLimit = if (limit == Int.MAX_VALUE) output.size else minOf(output.size, limit - consumed) + val decodeResult = it.decodeUTF8Line(output, 0, readLimit) - val before = buffer.position() - if (buffer.remaining() < required) { - buffer.rollBytes(required) - } + val decoded = (decodeResult shr 32).toInt() + val requiredBytes = (decodeResult and 0xffffffffL).toInt() - val rc = decode(buffer) + required = kotlin.math.max(1, requiredBytes) - val after = buffer.position() - consumed(after - before) + if (requiredBytes == -1) { + newLine = true + } - val decoded = (rc shr 32).toInt() - val rcRequired = (rc and 0xffffffffL).toInt() + if (it.hasRemaining() && it[it.position()] == '\r'.code.toByte()) { + it.position(it.position() + 1) + caret = true + } - required = when { - // EOL - rcRequired == -1 -> 0 - // no EOL, no demands but untouched bytes - // for ascii decoder that could mean that there was non-ASCII character encountered - rcRequired == 0 && buffer.hasRemaining() -> -1 - else -> maxOf(1, rcRequired) - } + if (it.hasRemaining() && it[it.position()] == '\n'.code.toByte()) { + it.position(it.position() + 1) + newLine = true + } - addConsumed(decoded) + if (out is StringBuilder) { + out.append(output, 0, decoded) + } else { + val buffer = CharBuffer.wrap(output, 0, decoded) + out.append(buffer, 0, decoded) + } - if (out is StringBuilder) { - out.append(ca, 0, decoded) - } else { - out.append(cb, 0, decoded) - } - } while (required > 0) + consumed += decoded - return when (required) { - 0 -> true - else -> false + if (limit != Int.MAX_VALUE && consumed >= limit && !newLine) { + throw TooLongLineException("Line is longer than limit") + } + } } - } - - private suspend fun readUTF8LineToUtf8Suspend( - out: Appendable, - limit: Int, - ca: CharArray, - cb: CharBuffer, - consumed0: Int - ): Boolean { - var consumed1 = 0 - var result = true - - lookAheadSuspend { - val rc = readLineLoop( - out, - ca, - cb, - await = { awaitAtLeast(it) }, - addConsumed = { consumed1 += it }, - decode = { it.decodeUTF8Line(ca, 0, minOf(ca.size, limit - consumed1)) } - ) - - if (rc || !isClosedForWrite) { - return@lookAheadSuspend - } - val buffer = request(0, 1) - when { - buffer != null -> { - if (buffer.get() != '\r'.code.toByte()) { - buffer.position(buffer.position() - 1) - throw TooLongLineException("Line is longer than limit") - } - - consumed(1) - if (buffer.hasRemaining()) { - throw MalformedInputException("Illegal trailing bytes: ${buffer.remaining()}") + if (!isClosedForRead && caret && !newLine) { + try { + read(1) { + if (it[it.position()] == '\n'.code.toByte()) { + it.position(it.position() + 1) + newLine = true } } - consumed1 == 0 && consumed0 == 0 -> { - result = false - } + } catch (_: EOFException) { } } - return result + return (isClosedForRead && consumed > 0) || (newLine || caret) } override suspend fun readUTF8LineTo(out: A, limit: Int): Boolean = diff --git a/ktor-io/jvm/test/io/ktor/utils/io/charsets/ByteChannelTextTest.kt b/ktor-io/jvm/test/io/ktor/utils/io/charsets/ByteChannelTextTest.kt index aba4b31034..fc7bcee1c3 100644 --- a/ktor-io/jvm/test/io/ktor/utils/io/charsets/ByteChannelTextTest.kt +++ b/ktor-io/jvm/test/io/ktor/utils/io/charsets/ByteChannelTextTest.kt @@ -21,4 +21,25 @@ class ByteChannelTextTest { channel.readUTF8Line(50) } } + + @Test + fun testReadUtf8Line32k() = runBlocking { + val line = "x".repeat(32 * 1024) + val bytes = line.encodeToByteArray() + val channel = ByteReadChannel(bytes) + + val result = channel.readUTF8Line() + assertEquals(line, result) + } + + @Test + fun testReadLineUtf8Chunks() = runBlocking { + val line = "x".repeat(32 * 1024) + val channel = writer { + channel.writeStringUtf8(line) + }.channel + + val result = channel.readUTF8Line() + assertEquals(line, result) + } } From 1555a8854a7cdf8624acb2529a7edba7db4eed3d Mon Sep 17 00:00:00 2001 From: "leonid.stashevsky" Date: Mon, 10 Oct 2022 11:40:21 +0200 Subject: [PATCH 2/3] fixup! KTOR-2588 Fix long line reading --- .../jvm/src/io/ktor/utils/io/ByteBufferChannel.kt | 4 ++-- .../ktor/utils/io/charsets/ByteChannelTextTest.kt | 14 ++++++++++++++ 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/ktor-io/jvm/src/io/ktor/utils/io/ByteBufferChannel.kt b/ktor-io/jvm/src/io/ktor/utils/io/ByteBufferChannel.kt index 27239eb5e1..f0eb60c8f2 100644 --- a/ktor-io/jvm/src/io/ktor/utils/io/ByteBufferChannel.kt +++ b/ktor-io/jvm/src/io/ktor/utils/io/ByteBufferChannel.kt @@ -1962,12 +1962,12 @@ internal open class ByteBufferChannel( newLine = true } - if (it.hasRemaining() && it[it.position()] == '\r'.code.toByte()) { + if (requiredBytes != -1 && it.hasRemaining() && it[it.position()] == '\r'.code.toByte()) { it.position(it.position() + 1) caret = true } - if (it.hasRemaining() && it[it.position()] == '\n'.code.toByte()) { + if (requiredBytes != -1 && it.hasRemaining() && it[it.position()] == '\n'.code.toByte()) { it.position(it.position() + 1) newLine = true } diff --git a/ktor-io/jvm/test/io/ktor/utils/io/charsets/ByteChannelTextTest.kt b/ktor-io/jvm/test/io/ktor/utils/io/charsets/ByteChannelTextTest.kt index fc7bcee1c3..1158a6c45c 100644 --- a/ktor-io/jvm/test/io/ktor/utils/io/charsets/ByteChannelTextTest.kt +++ b/ktor-io/jvm/test/io/ktor/utils/io/charsets/ByteChannelTextTest.kt @@ -42,4 +42,18 @@ class ByteChannelTextTest { val result = channel.readUTF8Line() assertEquals(line, result) } + + @Test + fun test2EmptyLines() { + val text = ByteReadChannel("\r\n\r\n") + + runBlocking { + assertEquals(4, text.availableForRead) + assertEquals("", text.readUTF8Line()) + assertEquals(2, text.availableForRead) + assertEquals(2, text.totalBytesRead) + assertEquals("", text.readUTF8Line()) + assertNull(text.readUTF8Line()) + } + } } From 47d21b6138e325b2c143a65c2713cda77a052981 Mon Sep 17 00:00:00 2001 From: "leonid.stashevsky" Date: Mon, 10 Oct 2022 13:00:46 +0200 Subject: [PATCH 3/3] fixup! fixup! KTOR-2588 Fix long line reading --- .../ktor/server/cio/backend/ServerPipeline.kt | 21 ++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/ktor-server/ktor-server-cio/jvmAndNix/src/io/ktor/server/cio/backend/ServerPipeline.kt b/ktor-server/ktor-server-cio/jvmAndNix/src/io/ktor/server/cio/backend/ServerPipeline.kt index 7b8db46225..50b8463f99 100644 --- a/ktor-server/ktor-server-cio/jvmAndNix/src/io/ktor/server/cio/backend/ServerPipeline.kt +++ b/ktor-server/ktor-server-cio/jvmAndNix/src/io/ktor/server/cio/backend/ServerPipeline.kt @@ -12,6 +12,7 @@ import io.ktor.server.cio.internal.* import io.ktor.util.* import io.ktor.util.cio.* import io.ktor.utils.io.* +import io.ktor.utils.io.charsets.* import io.ktor.utils.io.errors.* import kotlinx.coroutines.* import kotlinx.coroutines.CancellationException @@ -55,18 +56,15 @@ public fun CoroutineScope.startServerConnectionPipeline( while (true) { // parse requests loop val request = try { parseRequest(connection.input) ?: break + } catch (cause: TooLongLineException) { + respondBadRequest(actorChannel) + break // end pipeline loop } catch (io: IOException) { throw io } catch (cancelled: CancellationException) { throw cancelled } catch (parseFailed: Throwable) { // try to write 400 Bad Request - // TODO log parseFailed? - val bc = ByteChannel() - if (actorChannel.trySend(bc).isSuccess) { - bc.writePacket(BadRequestPacket.copy()) - bc.close() - } - actorChannel.close() + respondBadRequest(actorChannel) break // end pipeline loop } @@ -183,6 +181,15 @@ public fun CoroutineScope.startServerConnectionPipeline( } } +private suspend fun respondBadRequest(actorChannel: Channel) { + val bc = ByteChannel() + if (actorChannel.trySend(bc).isSuccess) { + bc.writePacket(BadRequestPacket.copy()) + bc.close() + } + actorChannel.close() +} + @OptIn(InternalAPI::class) private suspend fun pipelineWriterLoop( channel: ReceiveChannel,