Skip to content

Commit

Permalink
KTOR-5252 Fix EOFException in read and readUtf8Line (#3285)
Browse files Browse the repository at this point in the history
* KTOR-5252 Fix Missing EOF exception

* KTOR-5252 Mute EOF exception in line reading

* KTOR-5252 Fix peekTo EOFException
  • Loading branch information
e5l committed Dec 6, 2022
1 parent 9e12c15 commit bfde300
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 39 deletions.
87 changes: 48 additions & 39 deletions ktor-io/jvm/src/io/ktor/utils/io/ByteBufferChannel.kt
Expand Up @@ -1655,7 +1655,7 @@ internal open class ByteBufferChannel(

if (!read) {
if (isClosedForRead) {
return
throw EOFException("Got EOF but at least $min bytes were expected")
}

readBlockSuspend(min, consumer)
Expand Down Expand Up @@ -1949,41 +1949,45 @@ internal open class ByteBufferChannel(

val output = CharArray(8 * 1024)
while (!isClosedForRead && !newLine && !caret && (limit == Int.MAX_VALUE || consumed <= limit)) {
read(required) {
val readLimit = if (limit == Int.MAX_VALUE) output.size else minOf(output.size, limit - consumed)
val decodeResult = it.decodeUTF8Line(output, 0, readLimit)
try {
read(required) {
val readLimit = if (limit == Int.MAX_VALUE) output.size else minOf(output.size, limit - consumed)
val decodeResult = it.decodeUTF8Line(output, 0, readLimit)

val decoded = (decodeResult shr 32).toInt()
val requiredBytes = (decodeResult and 0xffffffffL).toInt()
val decoded = (decodeResult shr 32).toInt()
val requiredBytes = (decodeResult and 0xffffffffL).toInt()

required = kotlin.math.max(1, requiredBytes)
required = kotlin.math.max(1, requiredBytes)

if (requiredBytes == -1) {
newLine = true
}
if (requiredBytes == -1) {
newLine = true
}

if (requiredBytes != -1 && it.hasRemaining() && it[it.position()] == '\r'.code.toByte()) {
it.position(it.position() + 1)
caret = true
}
if (requiredBytes != -1 && it.hasRemaining() && it[it.position()] == '\r'.code.toByte()) {
it.position(it.position() + 1)
caret = true
}

if (requiredBytes != -1 && it.hasRemaining() && it[it.position()] == '\n'.code.toByte()) {
it.position(it.position() + 1)
newLine = true
}
if (requiredBytes != -1 && it.hasRemaining() && it[it.position()] == '\n'.code.toByte()) {
it.position(it.position() + 1)
newLine = true
}

if (out is StringBuilder) {
out.append(output, 0, decoded)
} else {
val buffer = CharBuffer.wrap(output, 0, decoded)
out.append(buffer, 0, decoded)
}
if (out is StringBuilder) {
out.append(output, 0, decoded)
} else {
val buffer = CharBuffer.wrap(output, 0, decoded)
out.append(buffer, 0, decoded)
}

consumed += decoded
consumed += decoded

if (limit != Int.MAX_VALUE && consumed >= limit && !newLine) {
throw TooLongLineException("Line is longer than limit")
if (limit != Int.MAX_VALUE && consumed >= limit && !newLine) {
throw TooLongLineException("Line is longer than limit")
}
}
} catch (_: EOFException) {
// Ignored by the contract of [ByteReadChannel.readUTF8LineTo] method
}
}

Expand All @@ -1996,6 +2000,7 @@ internal open class ByteBufferChannel(
}
}
} catch (_: EOFException) {
// Ignored by the contract of [ByteReadChannel.readUTF8LineTo] method
}
}

Expand Down Expand Up @@ -2325,19 +2330,23 @@ internal open class ByteBufferChannel(
var bytesCopied = 0
val desiredSize = (min + offset).coerceAtMost(4088L).toInt()

read(desiredSize) { nioBuffer ->
if (nioBuffer.remaining() > offset) {
val view = nioBuffer.duplicate()!!
view.position(view.position() + offset.toInt())

val oldLimit = view.limit()
val canCopyToDestination = minOf(max, destination.size - destinationOffset)
val newLimit = minOf(view.limit().toLong(), canCopyToDestination + offset)
view.limit(newLimit.toInt())
bytesCopied = view.remaining()
view.copyTo(destination, destinationOffset.toInt())
view.limit(oldLimit)
try {
read(desiredSize) { nioBuffer ->
if (nioBuffer.remaining() > offset) {
val view = nioBuffer.duplicate()!!
view.position(view.position() + offset.toInt())

val oldLimit = view.limit()
val canCopyToDestination = minOf(max, destination.size - destinationOffset)
val newLimit = minOf(view.limit().toLong(), canCopyToDestination + offset)
view.limit(newLimit.toInt())
bytesCopied = view.remaining()
view.copyTo(destination, destinationOffset.toInt())
view.limit(oldLimit)
}
}
} catch (_: EOFException) {
// ignored by the contract of peekTo method
}

return bytesCopied.toLong()
Expand Down
48 changes: 48 additions & 0 deletions ktor-io/jvm/test/io/ktor/utils/io/ByteBufferChannelTest.kt
Expand Up @@ -4,6 +4,8 @@

package io.ktor.utils.io

import io.ktor.test.dispatcher.*
import io.ktor.utils.io.core.EOFException
import kotlinx.coroutines.*
import kotlinx.coroutines.debug.junit4.*
import org.junit.*
Expand All @@ -24,6 +26,29 @@ class ByteBufferChannelTest {
assertFailsWith<IOException> { runBlocking { channel.readByte() } }
}

@Test
fun testEarlyEOF() = testSuspend {
repeat(20000) {
val channel = ByteChannel(true)
launch(Dispatchers.IO) {
channel.writeFully("1\n".toByteArray())
channel.close()
}

launch(Dispatchers.IO) {
channel.read(1) {
it.get(ByteArray(it.remaining()))
}

assertFailsWith<EOFException> {
channel.read(1) {
it.get(ByteArray(it.remaining()))
}
}
}.join()
}
}

@Test
fun readRemainingThrowsOnClosed() = runBlocking {
val channel = ByteBufferChannel(false)
Expand All @@ -36,6 +61,29 @@ class ByteBufferChannelTest {
Unit
}

@Test
fun testReadUtf8LineEOF() = testSuspend {
repeat(20000) {
val channel = ByteChannel(true)
val writer = launch(Dispatchers.IO) {
channel.writeFully("1\n".toByteArray())
channel.close()
}

val reader = async(Dispatchers.IO) {
val lines = mutableListOf<String>()
while (true) {
val line = channel.readUTF8Line(5000) ?: break
lines.add(line)
}
lines
}

reader.await()
writer.join()
}
}

@Test
fun testWriteWriteAvailableRaceCondition() = runBlocking {
testWriteXRaceCondition { it.writeAvailable(1) { it.put(1) } }
Expand Down

0 comments on commit bfde300

Please sign in to comment.