Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

KTOR-5252 Fix EOFException in read and readUtf8Line #3285

Merged
merged 5 commits into from Dec 6, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
57 changes: 30 additions & 27 deletions ktor-io/jvm/src/io/ktor/utils/io/ByteBufferChannel.kt
Expand Up @@ -1655,7 +1655,7 @@ internal open class ByteBufferChannel(

if (!read) {
if (isClosedForRead) {
return
throw EOFException("Got EOF but at least $min bytes were expected")
}

readBlockSuspend(min, consumer)
Expand Down Expand Up @@ -1949,41 +1949,44 @@ internal open class ByteBufferChannel(

val output = CharArray(8 * 1024)
while (!isClosedForRead && !newLine && !caret && (limit == Int.MAX_VALUE || consumed <= limit)) {
read(required) {
val readLimit = if (limit == Int.MAX_VALUE) output.size else minOf(output.size, limit - consumed)
val decodeResult = it.decodeUTF8Line(output, 0, readLimit)
try {
read(required) {
val readLimit = if (limit == Int.MAX_VALUE) output.size else minOf(output.size, limit - consumed)
val decodeResult = it.decodeUTF8Line(output, 0, readLimit)

val decoded = (decodeResult shr 32).toInt()
val requiredBytes = (decodeResult and 0xffffffffL).toInt()
val decoded = (decodeResult shr 32).toInt()
val requiredBytes = (decodeResult and 0xffffffffL).toInt()

required = kotlin.math.max(1, requiredBytes)
required = kotlin.math.max(1, requiredBytes)

if (requiredBytes == -1) {
newLine = true
}
if (requiredBytes == -1) {
newLine = true
}

if (requiredBytes != -1 && it.hasRemaining() && it[it.position()] == '\r'.code.toByte()) {
it.position(it.position() + 1)
caret = true
}
if (requiredBytes != -1 && it.hasRemaining() && it[it.position()] == '\r'.code.toByte()) {
it.position(it.position() + 1)
caret = true
}

if (requiredBytes != -1 && it.hasRemaining() && it[it.position()] == '\n'.code.toByte()) {
it.position(it.position() + 1)
newLine = true
}
if (requiredBytes != -1 && it.hasRemaining() && it[it.position()] == '\n'.code.toByte()) {
it.position(it.position() + 1)
newLine = true
}

if (out is StringBuilder) {
out.append(output, 0, decoded)
} else {
val buffer = CharBuffer.wrap(output, 0, decoded)
out.append(buffer, 0, decoded)
}
if (out is StringBuilder) {
out.append(output, 0, decoded)
} else {
val buffer = CharBuffer.wrap(output, 0, decoded)
out.append(buffer, 0, decoded)
}

consumed += decoded
consumed += decoded

if (limit != Int.MAX_VALUE && consumed >= limit && !newLine) {
throw TooLongLineException("Line is longer than limit")
if (limit != Int.MAX_VALUE && consumed >= limit && !newLine) {
throw TooLongLineException("Line is longer than limit")
}
}
} catch (_: EOFException) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Im not an expert of course, but I see this as something easy to break - is it possible to avoid it? or at least add a comment, on why it's ignored?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure, it's ignored by the contract of the method

}
}

Expand Down
48 changes: 48 additions & 0 deletions ktor-io/jvm/test/io/ktor/utils/io/ByteBufferChannelTest.kt
Expand Up @@ -4,6 +4,8 @@

package io.ktor.utils.io

import io.ktor.test.dispatcher.*
import io.ktor.utils.io.core.EOFException
import kotlinx.coroutines.*
import kotlinx.coroutines.debug.junit4.*
import org.junit.*
Expand All @@ -24,6 +26,29 @@ class ByteBufferChannelTest {
assertFailsWith<IOException> { runBlocking { channel.readByte() } }
}

@Test
fun testEarlyEOF() = testSuspend {
repeat(20000) {
val channel = ByteChannel(true)
launch(Dispatchers.IO) {
channel.writeFully("1\n".toByteArray())
channel.close()
}

launch(Dispatchers.IO) {
channel.read(1) {
it.get(ByteArray(it.remaining()))
}

assertFailsWith<EOFException> {
channel.read(1) {
it.get(ByteArray(it.remaining()))
}
}
}.join()
}
}

@Test
fun readRemainingThrowsOnClosed() = runBlocking {
val channel = ByteBufferChannel(false)
Expand All @@ -36,6 +61,29 @@ class ByteBufferChannelTest {
Unit
}

@Test
fun testReadUtf8LineEOF() = testSuspend {
repeat(20000) {
val channel = ByteChannel(true)
val writer = launch(Dispatchers.IO) {
channel.writeFully("1\n".toByteArray())
channel.close()
}

val reader = async(Dispatchers.IO) {
val lines = mutableListOf<String>()
while (true) {
val line = channel.readUTF8Line(5000) ?: break
lines.add(line)
}
lines
}

reader.await()
writer.join()
}
}

@Test
fun testWriteWriteAvailableRaceCondition() = runBlocking {
testWriteXRaceCondition { it.writeAvailable(1) { it.put(1) } }
Expand Down