diff --git a/ktor-io/jvm/src/io/ktor/utils/io/ByteBufferChannel.kt b/ktor-io/jvm/src/io/ktor/utils/io/ByteBufferChannel.kt index f0eb60c8f2..f2a9e2752a 100644 --- a/ktor-io/jvm/src/io/ktor/utils/io/ByteBufferChannel.kt +++ b/ktor-io/jvm/src/io/ktor/utils/io/ByteBufferChannel.kt @@ -1655,7 +1655,7 @@ internal open class ByteBufferChannel( if (!read) { if (isClosedForRead) { - return + throw EOFException("Got EOF but at least $min bytes were expected") } readBlockSuspend(min, consumer) @@ -1949,41 +1949,45 @@ internal open class ByteBufferChannel( val output = CharArray(8 * 1024) while (!isClosedForRead && !newLine && !caret && (limit == Int.MAX_VALUE || consumed <= limit)) { - read(required) { - val readLimit = if (limit == Int.MAX_VALUE) output.size else minOf(output.size, limit - consumed) - val decodeResult = it.decodeUTF8Line(output, 0, readLimit) + try { + read(required) { + val readLimit = if (limit == Int.MAX_VALUE) output.size else minOf(output.size, limit - consumed) + val decodeResult = it.decodeUTF8Line(output, 0, readLimit) - val decoded = (decodeResult shr 32).toInt() - val requiredBytes = (decodeResult and 0xffffffffL).toInt() + val decoded = (decodeResult shr 32).toInt() + val requiredBytes = (decodeResult and 0xffffffffL).toInt() - required = kotlin.math.max(1, requiredBytes) + required = kotlin.math.max(1, requiredBytes) - if (requiredBytes == -1) { - newLine = true - } + if (requiredBytes == -1) { + newLine = true + } - if (requiredBytes != -1 && it.hasRemaining() && it[it.position()] == '\r'.code.toByte()) { - it.position(it.position() + 1) - caret = true - } + if (requiredBytes != -1 && it.hasRemaining() && it[it.position()] == '\r'.code.toByte()) { + it.position(it.position() + 1) + caret = true + } - if (requiredBytes != -1 && it.hasRemaining() && it[it.position()] == '\n'.code.toByte()) { - it.position(it.position() + 1) - newLine = true - } + if (requiredBytes != -1 && it.hasRemaining() && it[it.position()] == '\n'.code.toByte()) { + it.position(it.position() + 1) + newLine = true + } - if (out is StringBuilder) { - out.append(output, 0, decoded) - } else { - val buffer = CharBuffer.wrap(output, 0, decoded) - out.append(buffer, 0, decoded) - } + if (out is StringBuilder) { + out.append(output, 0, decoded) + } else { + val buffer = CharBuffer.wrap(output, 0, decoded) + out.append(buffer, 0, decoded) + } - consumed += decoded + consumed += decoded - if (limit != Int.MAX_VALUE && consumed >= limit && !newLine) { - throw TooLongLineException("Line is longer than limit") + if (limit != Int.MAX_VALUE && consumed >= limit && !newLine) { + throw TooLongLineException("Line is longer than limit") + } } + } catch (_: EOFException) { + // Ignored by the contract of [ByteReadChannel.readUTF8LineTo] method } } @@ -1996,6 +2000,7 @@ internal open class ByteBufferChannel( } } } catch (_: EOFException) { + // Ignored by the contract of [ByteReadChannel.readUTF8LineTo] method } } @@ -2325,19 +2330,23 @@ internal open class ByteBufferChannel( var bytesCopied = 0 val desiredSize = (min + offset).coerceAtMost(4088L).toInt() - read(desiredSize) { nioBuffer -> - if (nioBuffer.remaining() > offset) { - val view = nioBuffer.duplicate()!! - view.position(view.position() + offset.toInt()) - - val oldLimit = view.limit() - val canCopyToDestination = minOf(max, destination.size - destinationOffset) - val newLimit = minOf(view.limit().toLong(), canCopyToDestination + offset) - view.limit(newLimit.toInt()) - bytesCopied = view.remaining() - view.copyTo(destination, destinationOffset.toInt()) - view.limit(oldLimit) + try { + read(desiredSize) { nioBuffer -> + if (nioBuffer.remaining() > offset) { + val view = nioBuffer.duplicate()!! + view.position(view.position() + offset.toInt()) + + val oldLimit = view.limit() + val canCopyToDestination = minOf(max, destination.size - destinationOffset) + val newLimit = minOf(view.limit().toLong(), canCopyToDestination + offset) + view.limit(newLimit.toInt()) + bytesCopied = view.remaining() + view.copyTo(destination, destinationOffset.toInt()) + view.limit(oldLimit) + } } + } catch (_: EOFException) { + // ignored by the contract of peekTo method } return bytesCopied.toLong() diff --git a/ktor-io/jvm/test/io/ktor/utils/io/ByteBufferChannelTest.kt b/ktor-io/jvm/test/io/ktor/utils/io/ByteBufferChannelTest.kt index cba1283972..54884433cc 100644 --- a/ktor-io/jvm/test/io/ktor/utils/io/ByteBufferChannelTest.kt +++ b/ktor-io/jvm/test/io/ktor/utils/io/ByteBufferChannelTest.kt @@ -4,6 +4,8 @@ package io.ktor.utils.io +import io.ktor.test.dispatcher.* +import io.ktor.utils.io.core.EOFException import kotlinx.coroutines.* import kotlinx.coroutines.debug.junit4.* import org.junit.* @@ -24,6 +26,29 @@ class ByteBufferChannelTest { assertFailsWith { runBlocking { channel.readByte() } } } + @Test + fun testEarlyEOF() = testSuspend { + repeat(20000) { + val channel = ByteChannel(true) + launch(Dispatchers.IO) { + channel.writeFully("1\n".toByteArray()) + channel.close() + } + + launch(Dispatchers.IO) { + channel.read(1) { + it.get(ByteArray(it.remaining())) + } + + assertFailsWith { + channel.read(1) { + it.get(ByteArray(it.remaining())) + } + } + }.join() + } + } + @Test fun readRemainingThrowsOnClosed() = runBlocking { val channel = ByteBufferChannel(false) @@ -36,6 +61,29 @@ class ByteBufferChannelTest { Unit } + @Test + fun testReadUtf8LineEOF() = testSuspend { + repeat(20000) { + val channel = ByteChannel(true) + val writer = launch(Dispatchers.IO) { + channel.writeFully("1\n".toByteArray()) + channel.close() + } + + val reader = async(Dispatchers.IO) { + val lines = mutableListOf() + while (true) { + val line = channel.readUTF8Line(5000) ?: break + lines.add(line) + } + lines + } + + reader.await() + writer.join() + } + } + @Test fun testWriteWriteAvailableRaceCondition() = runBlocking { testWriteXRaceCondition { it.writeAvailable(1) { it.put(1) } }