Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

KTOR-2588 Fix long line reading #3182

Merged
merged 3 commits into from Oct 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
172 changes: 55 additions & 117 deletions ktor-io/jvm/src/io/ktor/utils/io/ByteBufferChannel.kt
Expand Up @@ -10,7 +10,6 @@ import io.ktor.utils.io.pool.*
import kotlinx.atomicfu.*
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.*
import java.io.EOFException
import java.lang.Double.*
import java.lang.Float.*
import java.nio.*
Expand Down Expand Up @@ -226,19 +225,23 @@ internal open class ByteBufferChannel(
allocatedState?.let { releaseBuffer(it) }
return null
}

closed != null -> {
allocatedState?.let { releaseBuffer(it) }
rethrowClosed(closed!!.sendException)
}

state === ReadWriteBufferState.IdleEmpty -> {
val allocated = allocatedState ?: newBuffer().also { allocatedState = it }
allocated.startWriting()
}

state === ReadWriteBufferState.Terminated -> {
allocatedState?.let { releaseBuffer(it) }
if (joining != null) return null
rethrowClosed(closed!!.sendException)
}

else -> {
state.startWriting()
}
Expand Down Expand Up @@ -290,6 +293,7 @@ internal open class ByteBufferChannel(
val cause = closed?.cause ?: return null
rethrowClosed(cause)
}

else -> {
closed?.cause?.let { rethrowClosed(it) }
if (state.capacity.availableForRead == 0) return null
Expand Down Expand Up @@ -406,11 +410,13 @@ internal open class ByteBufferChannel(
toRelease = state.initial
ReadWriteBufferState.Terminated
}

forceTermination && state is ReadWriteBufferState.IdleNonEmpty &&
state.capacity.tryLockForRelease() -> {
toRelease = state.initial
ReadWriteBufferState.Terminated
}

else -> return false
}
}
Expand Down Expand Up @@ -672,6 +678,7 @@ internal open class ByteBufferChannel(
-1
}
}

consumed > 0 || length == 0 -> consumed
else -> readAvailableSuspend(dst, offset, length)
}
Expand All @@ -688,6 +695,7 @@ internal open class ByteBufferChannel(
-1
}
}

consumed > 0 || !dst.hasRemaining() -> consumed
else -> readAvailableSuspend(dst)
}
Expand All @@ -704,6 +712,7 @@ internal open class ByteBufferChannel(
-1
}
}

consumed > 0 || !dst.canWrite() -> consumed
else -> readAvailableSuspend(dst)
}
Expand Down Expand Up @@ -1907,18 +1916,6 @@ internal open class ByteBufferChannel(
return rc
}

private suspend fun consumeEachBufferRangeSuspend(visitor: (buffer: ByteBuffer, last: Boolean) -> Boolean) {
var last = false

do {
if (consumeEachBufferRangeFast(last, visitor)) return
if (last) return
if (!readSuspend(1)) {
last = true
}
} while (true)
}

private fun afterBufferVisited(buffer: ByteBuffer, capacity: RingBufferCapacity): Int {
val consumed = buffer.position() - readPosition
if (consumed > 0) {
Expand All @@ -1934,134 +1931,75 @@ internal open class ByteBufferChannel(
private suspend fun readUTF8LineToAscii(out: Appendable, limit: Int): Boolean {
if (state === ReadWriteBufferState.Terminated) {
val cause = closedCause
if (cause != null) {
throw cause
}

if (cause != null) throw cause
return false
}

var consumed = 0

val array = CharArray(8192)
val buffer = CharBuffer.wrap(array)
var eol = false

lookAhead {
eol = readLineLoop(
out,
array,
buffer,
await = { expected -> availableForRead >= expected },
addConsumed = { consumed += it },
decode = {
it.decodeASCIILine(array, 0, minOf(array.size, limit - consumed))
}
)
}

if (eol) return true
if (consumed == 0 && isClosedForRead) return false

return readUTF8LineToUtf8Suspend(out, limit - consumed, array, buffer, consumed)
return readUTF8LineToUtf8Suspend(out, limit)
}

private inline fun LookAheadSession.readLineLoop(
private suspend fun readUTF8LineToUtf8Suspend(
out: Appendable,
ca: CharArray,
cb: CharBuffer,
await: (Int) -> Boolean,
addConsumed: (Int) -> Unit,
decode: (ByteBuffer) -> Long
limit: Int
): Boolean {
// number of bytes required for the next character, <= 0 when no characters required anymore (exit loop)
var consumed = 0
var required = 1
var caret = false
var newLine = false

do {
if (!await(required)) break
val buffer = request(0, 1) ?: break
val output = CharArray(8 * 1024)
while (!isClosedForRead && !newLine && !caret && (limit == Int.MAX_VALUE || consumed <= limit)) {
read(required) {
val readLimit = if (limit == Int.MAX_VALUE) output.size else minOf(output.size, limit - consumed)
val decodeResult = it.decodeUTF8Line(output, 0, readLimit)

val before = buffer.position()
if (buffer.remaining() < required) {
buffer.rollBytes(required)
}
val decoded = (decodeResult shr 32).toInt()
val requiredBytes = (decodeResult and 0xffffffffL).toInt()

val rc = decode(buffer)
required = kotlin.math.max(1, requiredBytes)

val after = buffer.position()
consumed(after - before)
if (requiredBytes == -1) {
newLine = true
}

val decoded = (rc shr 32).toInt()
val rcRequired = (rc and 0xffffffffL).toInt()
if (requiredBytes != -1 && it.hasRemaining() && it[it.position()] == '\r'.code.toByte()) {
it.position(it.position() + 1)
caret = true
}

required = when {
// EOL
rcRequired == -1 -> 0
// no EOL, no demands but untouched bytes
// for ascii decoder that could mean that there was non-ASCII character encountered
rcRequired == 0 && buffer.hasRemaining() -> -1
else -> maxOf(1, rcRequired)
}
if (requiredBytes != -1 && it.hasRemaining() && it[it.position()] == '\n'.code.toByte()) {
it.position(it.position() + 1)
newLine = true
}

addConsumed(decoded)
if (out is StringBuilder) {
out.append(output, 0, decoded)
} else {
val buffer = CharBuffer.wrap(output, 0, decoded)
out.append(buffer, 0, decoded)
}

if (out is StringBuilder) {
out.append(ca, 0, decoded)
} else {
out.append(cb, 0, decoded)
}
} while (required > 0)
consumed += decoded

return when (required) {
0 -> true
else -> false
if (limit != Int.MAX_VALUE && consumed >= limit && !newLine) {
throw TooLongLineException("Line is longer than limit")
}
}
}
}

private suspend fun readUTF8LineToUtf8Suspend(
out: Appendable,
limit: Int,
ca: CharArray,
cb: CharBuffer,
consumed0: Int
): Boolean {
var consumed1 = 0
var result = true

lookAheadSuspend {
val rc = readLineLoop(
out,
ca,
cb,
await = { awaitAtLeast(it) },
addConsumed = { consumed1 += it },
decode = { it.decodeUTF8Line(ca, 0, minOf(ca.size, limit - consumed1)) }
)

if (rc || !isClosedForWrite) {
return@lookAheadSuspend
}
val buffer = request(0, 1)
when {
buffer != null -> {
if (buffer.get() != '\r'.code.toByte()) {
buffer.position(buffer.position() - 1)
throw TooLongLineException("Line is longer than limit")
}

consumed(1)

if (buffer.hasRemaining()) {
throw MalformedInputException("Illegal trailing bytes: ${buffer.remaining()}")
if (!isClosedForRead && caret && !newLine) {
try {
read(1) {
if (it[it.position()] == '\n'.code.toByte()) {
it.position(it.position() + 1)
newLine = true
}
}
consumed1 == 0 && consumed0 == 0 -> {
result = false
}
} catch (_: EOFException) {
}
}

return result
return (isClosedForRead && consumed > 0) || (newLine || caret)
}

override suspend fun <A : Appendable> readUTF8LineTo(out: A, limit: Int): Boolean =
Expand Down
35 changes: 35 additions & 0 deletions ktor-io/jvm/test/io/ktor/utils/io/charsets/ByteChannelTextTest.kt
Expand Up @@ -21,4 +21,39 @@ class ByteChannelTextTest {
channel.readUTF8Line(50)
}
}

@Test
fun testReadUtf8Line32k() = runBlocking {
val line = "x".repeat(32 * 1024)
val bytes = line.encodeToByteArray()
val channel = ByteReadChannel(bytes)

val result = channel.readUTF8Line()
assertEquals(line, result)
}

@Test
fun testReadLineUtf8Chunks() = runBlocking {
val line = "x".repeat(32 * 1024)
val channel = writer {
channel.writeStringUtf8(line)
}.channel

val result = channel.readUTF8Line()
assertEquals(line, result)
}

@Test
fun test2EmptyLines() {
val text = ByteReadChannel("\r\n\r\n")

runBlocking {
assertEquals(4, text.availableForRead)
assertEquals("", text.readUTF8Line())
assertEquals(2, text.availableForRead)
assertEquals(2, text.totalBytesRead)
assertEquals("", text.readUTF8Line())
assertNull(text.readUTF8Line())
}
}
}
Expand Up @@ -12,6 +12,7 @@ import io.ktor.server.cio.internal.*
import io.ktor.util.*
import io.ktor.util.cio.*
import io.ktor.utils.io.*
import io.ktor.utils.io.charsets.*
import io.ktor.utils.io.errors.*
import kotlinx.coroutines.*
import kotlinx.coroutines.CancellationException
Expand Down Expand Up @@ -55,18 +56,15 @@ public fun CoroutineScope.startServerConnectionPipeline(
while (true) { // parse requests loop
val request = try {
parseRequest(connection.input) ?: break
} catch (cause: TooLongLineException) {
respondBadRequest(actorChannel)
break // end pipeline loop
} catch (io: IOException) {
throw io
} catch (cancelled: CancellationException) {
throw cancelled
} catch (parseFailed: Throwable) { // try to write 400 Bad Request
// TODO log parseFailed?
val bc = ByteChannel()
if (actorChannel.trySend(bc).isSuccess) {
bc.writePacket(BadRequestPacket.copy())
bc.close()
}
actorChannel.close()
respondBadRequest(actorChannel)
break // end pipeline loop
}

Expand Down Expand Up @@ -183,6 +181,15 @@ public fun CoroutineScope.startServerConnectionPipeline(
}
}

private suspend fun respondBadRequest(actorChannel: Channel<ByteReadChannel>) {
val bc = ByteChannel()
if (actorChannel.trySend(bc).isSuccess) {
bc.writePacket(BadRequestPacket.copy())
bc.close()
}
actorChannel.close()
}

@OptIn(InternalAPI::class)
private suspend fun pipelineWriterLoop(
channel: ReceiveChannel<ByteReadChannel>,
Expand Down