Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

KTOR-5252 Fix EOFException in read and readUtf8Line #3285

Merged
merged 5 commits into from Dec 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
87 changes: 48 additions & 39 deletions ktor-io/jvm/src/io/ktor/utils/io/ByteBufferChannel.kt
Expand Up @@ -1655,7 +1655,7 @@ internal open class ByteBufferChannel(

if (!read) {
if (isClosedForRead) {
return
throw EOFException("Got EOF but at least $min bytes were expected")
}

readBlockSuspend(min, consumer)
Expand Down Expand Up @@ -1949,41 +1949,45 @@ internal open class ByteBufferChannel(

val output = CharArray(8 * 1024)
while (!isClosedForRead && !newLine && !caret && (limit == Int.MAX_VALUE || consumed <= limit)) {
read(required) {
val readLimit = if (limit == Int.MAX_VALUE) output.size else minOf(output.size, limit - consumed)
val decodeResult = it.decodeUTF8Line(output, 0, readLimit)
try {
read(required) {
val readLimit = if (limit == Int.MAX_VALUE) output.size else minOf(output.size, limit - consumed)
val decodeResult = it.decodeUTF8Line(output, 0, readLimit)

val decoded = (decodeResult shr 32).toInt()
val requiredBytes = (decodeResult and 0xffffffffL).toInt()
val decoded = (decodeResult shr 32).toInt()
val requiredBytes = (decodeResult and 0xffffffffL).toInt()

required = kotlin.math.max(1, requiredBytes)
required = kotlin.math.max(1, requiredBytes)

if (requiredBytes == -1) {
newLine = true
}
if (requiredBytes == -1) {
newLine = true
}

if (requiredBytes != -1 && it.hasRemaining() && it[it.position()] == '\r'.code.toByte()) {
it.position(it.position() + 1)
caret = true
}
if (requiredBytes != -1 && it.hasRemaining() && it[it.position()] == '\r'.code.toByte()) {
it.position(it.position() + 1)
caret = true
}

if (requiredBytes != -1 && it.hasRemaining() && it[it.position()] == '\n'.code.toByte()) {
it.position(it.position() + 1)
newLine = true
}
if (requiredBytes != -1 && it.hasRemaining() && it[it.position()] == '\n'.code.toByte()) {
it.position(it.position() + 1)
newLine = true
}

if (out is StringBuilder) {
out.append(output, 0, decoded)
} else {
val buffer = CharBuffer.wrap(output, 0, decoded)
out.append(buffer, 0, decoded)
}
if (out is StringBuilder) {
out.append(output, 0, decoded)
} else {
val buffer = CharBuffer.wrap(output, 0, decoded)
out.append(buffer, 0, decoded)
}

consumed += decoded
consumed += decoded

if (limit != Int.MAX_VALUE && consumed >= limit && !newLine) {
throw TooLongLineException("Line is longer than limit")
if (limit != Int.MAX_VALUE && consumed >= limit && !newLine) {
throw TooLongLineException("Line is longer than limit")
}
}
} catch (_: EOFException) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Im not an expert of course, but I see this as something easy to break - is it possible to avoid it? or at least add a comment, on why it's ignored?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure, it's ignored by the contract of the method

// Ignored by the contract of [ByteReadChannel.readUTF8LineTo] method
}
}

Expand All @@ -1996,6 +2000,7 @@ internal open class ByteBufferChannel(
}
}
} catch (_: EOFException) {
// Ignored by the contract of [ByteReadChannel.readUTF8LineTo] method
}
}

Expand Down Expand Up @@ -2325,19 +2330,23 @@ internal open class ByteBufferChannel(
var bytesCopied = 0
val desiredSize = (min + offset).coerceAtMost(4088L).toInt()

read(desiredSize) { nioBuffer ->
if (nioBuffer.remaining() > offset) {
val view = nioBuffer.duplicate()!!
view.position(view.position() + offset.toInt())

val oldLimit = view.limit()
val canCopyToDestination = minOf(max, destination.size - destinationOffset)
val newLimit = minOf(view.limit().toLong(), canCopyToDestination + offset)
view.limit(newLimit.toInt())
bytesCopied = view.remaining()
view.copyTo(destination, destinationOffset.toInt())
view.limit(oldLimit)
try {
read(desiredSize) { nioBuffer ->
if (nioBuffer.remaining() > offset) {
val view = nioBuffer.duplicate()!!
view.position(view.position() + offset.toInt())

val oldLimit = view.limit()
val canCopyToDestination = minOf(max, destination.size - destinationOffset)
val newLimit = minOf(view.limit().toLong(), canCopyToDestination + offset)
view.limit(newLimit.toInt())
bytesCopied = view.remaining()
view.copyTo(destination, destinationOffset.toInt())
view.limit(oldLimit)
}
}
} catch (_: EOFException) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure

// ignored by the contract of peekTo method
}

return bytesCopied.toLong()
Expand Down
48 changes: 48 additions & 0 deletions ktor-io/jvm/test/io/ktor/utils/io/ByteBufferChannelTest.kt
Expand Up @@ -4,6 +4,8 @@

package io.ktor.utils.io

import io.ktor.test.dispatcher.*
import io.ktor.utils.io.core.EOFException
import kotlinx.coroutines.*
import kotlinx.coroutines.debug.junit4.*
import org.junit.*
Expand All @@ -24,6 +26,29 @@ class ByteBufferChannelTest {
assertFailsWith<IOException> { runBlocking { channel.readByte() } }
}

@Test
fun testEarlyEOF() = testSuspend {
repeat(20000) {
val channel = ByteChannel(true)
launch(Dispatchers.IO) {
channel.writeFully("1\n".toByteArray())
channel.close()
}

launch(Dispatchers.IO) {
channel.read(1) {
it.get(ByteArray(it.remaining()))
}

assertFailsWith<EOFException> {
channel.read(1) {
it.get(ByteArray(it.remaining()))
}
}
}.join()
}
}

@Test
fun readRemainingThrowsOnClosed() = runBlocking {
val channel = ByteBufferChannel(false)
Expand All @@ -36,6 +61,29 @@ class ByteBufferChannelTest {
Unit
}

@Test
fun testReadUtf8LineEOF() = testSuspend {
repeat(20000) {
val channel = ByteChannel(true)
val writer = launch(Dispatchers.IO) {
channel.writeFully("1\n".toByteArray())
channel.close()
}

val reader = async(Dispatchers.IO) {
val lines = mutableListOf<String>()
while (true) {
val line = channel.readUTF8Line(5000) ?: break
lines.add(line)
}
lines
}

reader.await()
writer.join()
}
}

@Test
fun testWriteWriteAvailableRaceCondition() = runBlocking {
testWriteXRaceCondition { it.writeAvailable(1) { it.put(1) } }
Expand Down