Skip to content

Commit

Permalink
KTOR-2588 Fix long line reading (#3182)
Browse files Browse the repository at this point in the history
* KTOR-2588 Fix long line reading
  • Loading branch information
e5l committed Oct 12, 2022
1 parent 3455bce commit e1bfa58
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 124 deletions.
172 changes: 55 additions & 117 deletions ktor-io/jvm/src/io/ktor/utils/io/ByteBufferChannel.kt
Expand Up @@ -10,7 +10,6 @@ import io.ktor.utils.io.pool.*
import kotlinx.atomicfu.*
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.*
import java.io.EOFException
import java.lang.Double.*
import java.lang.Float.*
import java.nio.*
Expand Down Expand Up @@ -226,19 +225,23 @@ internal open class ByteBufferChannel(
allocatedState?.let { releaseBuffer(it) }
return null
}

closed != null -> {
allocatedState?.let { releaseBuffer(it) }
rethrowClosed(closed!!.sendException)
}

state === ReadWriteBufferState.IdleEmpty -> {
val allocated = allocatedState ?: newBuffer().also { allocatedState = it }
allocated.startWriting()
}

state === ReadWriteBufferState.Terminated -> {
allocatedState?.let { releaseBuffer(it) }
if (joining != null) return null
rethrowClosed(closed!!.sendException)
}

else -> {
state.startWriting()
}
Expand Down Expand Up @@ -290,6 +293,7 @@ internal open class ByteBufferChannel(
val cause = closed?.cause ?: return null
rethrowClosed(cause)
}

else -> {
closed?.cause?.let { rethrowClosed(it) }
if (state.capacity.availableForRead == 0) return null
Expand Down Expand Up @@ -406,11 +410,13 @@ internal open class ByteBufferChannel(
toRelease = state.initial
ReadWriteBufferState.Terminated
}

forceTermination && state is ReadWriteBufferState.IdleNonEmpty &&
state.capacity.tryLockForRelease() -> {
toRelease = state.initial
ReadWriteBufferState.Terminated
}

else -> return false
}
}
Expand Down Expand Up @@ -672,6 +678,7 @@ internal open class ByteBufferChannel(
-1
}
}

consumed > 0 || length == 0 -> consumed
else -> readAvailableSuspend(dst, offset, length)
}
Expand All @@ -688,6 +695,7 @@ internal open class ByteBufferChannel(
-1
}
}

consumed > 0 || !dst.hasRemaining() -> consumed
else -> readAvailableSuspend(dst)
}
Expand All @@ -704,6 +712,7 @@ internal open class ByteBufferChannel(
-1
}
}

consumed > 0 || !dst.canWrite() -> consumed
else -> readAvailableSuspend(dst)
}
Expand Down Expand Up @@ -1907,18 +1916,6 @@ internal open class ByteBufferChannel(
return rc
}

private suspend fun consumeEachBufferRangeSuspend(visitor: (buffer: ByteBuffer, last: Boolean) -> Boolean) {
var last = false

do {
if (consumeEachBufferRangeFast(last, visitor)) return
if (last) return
if (!readSuspend(1)) {
last = true
}
} while (true)
}

private fun afterBufferVisited(buffer: ByteBuffer, capacity: RingBufferCapacity): Int {
val consumed = buffer.position() - readPosition
if (consumed > 0) {
Expand All @@ -1934,134 +1931,75 @@ internal open class ByteBufferChannel(
private suspend fun readUTF8LineToAscii(out: Appendable, limit: Int): Boolean {
if (state === ReadWriteBufferState.Terminated) {
val cause = closedCause
if (cause != null) {
throw cause
}

if (cause != null) throw cause
return false
}

var consumed = 0

val array = CharArray(8192)
val buffer = CharBuffer.wrap(array)
var eol = false

lookAhead {
eol = readLineLoop(
out,
array,
buffer,
await = { expected -> availableForRead >= expected },
addConsumed = { consumed += it },
decode = {
it.decodeASCIILine(array, 0, minOf(array.size, limit - consumed))
}
)
}

if (eol) return true
if (consumed == 0 && isClosedForRead) return false

return readUTF8LineToUtf8Suspend(out, limit - consumed, array, buffer, consumed)
return readUTF8LineToUtf8Suspend(out, limit)
}

private inline fun LookAheadSession.readLineLoop(
private suspend fun readUTF8LineToUtf8Suspend(
out: Appendable,
ca: CharArray,
cb: CharBuffer,
await: (Int) -> Boolean,
addConsumed: (Int) -> Unit,
decode: (ByteBuffer) -> Long
limit: Int
): Boolean {
// number of bytes required for the next character, <= 0 when no characters required anymore (exit loop)
var consumed = 0
var required = 1
var caret = false
var newLine = false

do {
if (!await(required)) break
val buffer = request(0, 1) ?: break
val output = CharArray(8 * 1024)
while (!isClosedForRead && !newLine && !caret && (limit == Int.MAX_VALUE || consumed <= limit)) {
read(required) {
val readLimit = if (limit == Int.MAX_VALUE) output.size else minOf(output.size, limit - consumed)
val decodeResult = it.decodeUTF8Line(output, 0, readLimit)

val before = buffer.position()
if (buffer.remaining() < required) {
buffer.rollBytes(required)
}
val decoded = (decodeResult shr 32).toInt()
val requiredBytes = (decodeResult and 0xffffffffL).toInt()

val rc = decode(buffer)
required = kotlin.math.max(1, requiredBytes)

val after = buffer.position()
consumed(after - before)
if (requiredBytes == -1) {
newLine = true
}

val decoded = (rc shr 32).toInt()
val rcRequired = (rc and 0xffffffffL).toInt()
if (requiredBytes != -1 && it.hasRemaining() && it[it.position()] == '\r'.code.toByte()) {
it.position(it.position() + 1)
caret = true
}

required = when {
// EOL
rcRequired == -1 -> 0
// no EOL, no demands but untouched bytes
// for ascii decoder that could mean that there was non-ASCII character encountered
rcRequired == 0 && buffer.hasRemaining() -> -1
else -> maxOf(1, rcRequired)
}
if (requiredBytes != -1 && it.hasRemaining() && it[it.position()] == '\n'.code.toByte()) {
it.position(it.position() + 1)
newLine = true
}

addConsumed(decoded)
if (out is StringBuilder) {
out.append(output, 0, decoded)
} else {
val buffer = CharBuffer.wrap(output, 0, decoded)
out.append(buffer, 0, decoded)
}

if (out is StringBuilder) {
out.append(ca, 0, decoded)
} else {
out.append(cb, 0, decoded)
}
} while (required > 0)
consumed += decoded

return when (required) {
0 -> true
else -> false
if (limit != Int.MAX_VALUE && consumed >= limit && !newLine) {
throw TooLongLineException("Line is longer than limit")
}
}
}
}

private suspend fun readUTF8LineToUtf8Suspend(
out: Appendable,
limit: Int,
ca: CharArray,
cb: CharBuffer,
consumed0: Int
): Boolean {
var consumed1 = 0
var result = true

lookAheadSuspend {
val rc = readLineLoop(
out,
ca,
cb,
await = { awaitAtLeast(it) },
addConsumed = { consumed1 += it },
decode = { it.decodeUTF8Line(ca, 0, minOf(ca.size, limit - consumed1)) }
)

if (rc || !isClosedForWrite) {
return@lookAheadSuspend
}
val buffer = request(0, 1)
when {
buffer != null -> {
if (buffer.get() != '\r'.code.toByte()) {
buffer.position(buffer.position() - 1)
throw TooLongLineException("Line is longer than limit")
}

consumed(1)

if (buffer.hasRemaining()) {
throw MalformedInputException("Illegal trailing bytes: ${buffer.remaining()}")
if (!isClosedForRead && caret && !newLine) {
try {
read(1) {
if (it[it.position()] == '\n'.code.toByte()) {
it.position(it.position() + 1)
newLine = true
}
}
consumed1 == 0 && consumed0 == 0 -> {
result = false
}
} catch (_: EOFException) {
}
}

return result
return (isClosedForRead && consumed > 0) || (newLine || caret)
}

override suspend fun <A : Appendable> readUTF8LineTo(out: A, limit: Int): Boolean =
Expand Down
35 changes: 35 additions & 0 deletions ktor-io/jvm/test/io/ktor/utils/io/charsets/ByteChannelTextTest.kt
Expand Up @@ -21,4 +21,39 @@ class ByteChannelTextTest {
channel.readUTF8Line(50)
}
}

@Test
fun testReadUtf8Line32k() = runBlocking {
val line = "x".repeat(32 * 1024)
val bytes = line.encodeToByteArray()
val channel = ByteReadChannel(bytes)

val result = channel.readUTF8Line()
assertEquals(line, result)
}

@Test
fun testReadLineUtf8Chunks() = runBlocking {
val line = "x".repeat(32 * 1024)
val channel = writer {
channel.writeStringUtf8(line)
}.channel

val result = channel.readUTF8Line()
assertEquals(line, result)
}

@Test
fun test2EmptyLines() {
val text = ByteReadChannel("\r\n\r\n")

runBlocking {
assertEquals(4, text.availableForRead)
assertEquals("", text.readUTF8Line())
assertEquals(2, text.availableForRead)
assertEquals(2, text.totalBytesRead)
assertEquals("", text.readUTF8Line())
assertNull(text.readUTF8Line())
}
}
}
Expand Up @@ -12,6 +12,7 @@ import io.ktor.server.cio.internal.*
import io.ktor.util.*
import io.ktor.util.cio.*
import io.ktor.utils.io.*
import io.ktor.utils.io.charsets.*
import io.ktor.utils.io.errors.*
import kotlinx.coroutines.*
import kotlinx.coroutines.CancellationException
Expand Down Expand Up @@ -55,18 +56,15 @@ public fun CoroutineScope.startServerConnectionPipeline(
while (true) { // parse requests loop
val request = try {
parseRequest(connection.input) ?: break
} catch (cause: TooLongLineException) {
respondBadRequest(actorChannel)
break // end pipeline loop
} catch (io: IOException) {
throw io
} catch (cancelled: CancellationException) {
throw cancelled
} catch (parseFailed: Throwable) { // try to write 400 Bad Request
// TODO log parseFailed?
val bc = ByteChannel()
if (actorChannel.trySend(bc).isSuccess) {
bc.writePacket(BadRequestPacket.copy())
bc.close()
}
actorChannel.close()
respondBadRequest(actorChannel)
break // end pipeline loop
}

Expand Down Expand Up @@ -183,6 +181,15 @@ public fun CoroutineScope.startServerConnectionPipeline(
}
}

private suspend fun respondBadRequest(actorChannel: Channel<ByteReadChannel>) {
val bc = ByteChannel()
if (actorChannel.trySend(bc).isSuccess) {
bc.writePacket(BadRequestPacket.copy())
bc.close()
}
actorChannel.close()
}

@OptIn(InternalAPI::class)
private suspend fun pipelineWriterLoop(
channel: ReceiveChannel<ByteReadChannel>,
Expand Down

0 comments on commit e1bfa58

Please sign in to comment.