diff --git a/ktor-io/jvm/src/io/ktor/utils/io/ByteBufferChannel.kt b/ktor-io/jvm/src/io/ktor/utils/io/ByteBufferChannel.kt index 2ecd3fefa6..27239eb5e1 100644 --- a/ktor-io/jvm/src/io/ktor/utils/io/ByteBufferChannel.kt +++ b/ktor-io/jvm/src/io/ktor/utils/io/ByteBufferChannel.kt @@ -10,7 +10,6 @@ import io.ktor.utils.io.pool.* import kotlinx.atomicfu.* import kotlinx.coroutines.* import kotlinx.coroutines.channels.* -import java.io.EOFException import java.lang.Double.* import java.lang.Float.* import java.nio.* @@ -226,19 +225,23 @@ internal open class ByteBufferChannel( allocatedState?.let { releaseBuffer(it) } return null } + closed != null -> { allocatedState?.let { releaseBuffer(it) } rethrowClosed(closed!!.sendException) } + state === ReadWriteBufferState.IdleEmpty -> { val allocated = allocatedState ?: newBuffer().also { allocatedState = it } allocated.startWriting() } + state === ReadWriteBufferState.Terminated -> { allocatedState?.let { releaseBuffer(it) } if (joining != null) return null rethrowClosed(closed!!.sendException) } + else -> { state.startWriting() } @@ -290,6 +293,7 @@ internal open class ByteBufferChannel( val cause = closed?.cause ?: return null rethrowClosed(cause) } + else -> { closed?.cause?.let { rethrowClosed(it) } if (state.capacity.availableForRead == 0) return null @@ -406,11 +410,13 @@ internal open class ByteBufferChannel( toRelease = state.initial ReadWriteBufferState.Terminated } + forceTermination && state is ReadWriteBufferState.IdleNonEmpty && state.capacity.tryLockForRelease() -> { toRelease = state.initial ReadWriteBufferState.Terminated } + else -> return false } } @@ -672,6 +678,7 @@ internal open class ByteBufferChannel( -1 } } + consumed > 0 || length == 0 -> consumed else -> readAvailableSuspend(dst, offset, length) } @@ -688,6 +695,7 @@ internal open class ByteBufferChannel( -1 } } + consumed > 0 || !dst.hasRemaining() -> consumed else -> readAvailableSuspend(dst) } @@ -704,6 +712,7 @@ internal open class ByteBufferChannel( -1 } } + consumed > 0 || !dst.canWrite() -> consumed else -> readAvailableSuspend(dst) } @@ -1907,18 +1916,6 @@ internal open class ByteBufferChannel( return rc } - private suspend fun consumeEachBufferRangeSuspend(visitor: (buffer: ByteBuffer, last: Boolean) -> Boolean) { - var last = false - - do { - if (consumeEachBufferRangeFast(last, visitor)) return - if (last) return - if (!readSuspend(1)) { - last = true - } - } while (true) - } - private fun afterBufferVisited(buffer: ByteBuffer, capacity: RingBufferCapacity): Int { val consumed = buffer.position() - readPosition if (consumed > 0) { @@ -1934,134 +1931,75 @@ internal open class ByteBufferChannel( private suspend fun readUTF8LineToAscii(out: Appendable, limit: Int): Boolean { if (state === ReadWriteBufferState.Terminated) { val cause = closedCause - if (cause != null) { - throw cause - } - + if (cause != null) throw cause return false } - var consumed = 0 - - val array = CharArray(8192) - val buffer = CharBuffer.wrap(array) - var eol = false - - lookAhead { - eol = readLineLoop( - out, - array, - buffer, - await = { expected -> availableForRead >= expected }, - addConsumed = { consumed += it }, - decode = { - it.decodeASCIILine(array, 0, minOf(array.size, limit - consumed)) - } - ) - } - - if (eol) return true - if (consumed == 0 && isClosedForRead) return false - - return readUTF8LineToUtf8Suspend(out, limit - consumed, array, buffer, consumed) + return readUTF8LineToUtf8Suspend(out, limit) } - private inline fun LookAheadSession.readLineLoop( + private suspend fun readUTF8LineToUtf8Suspend( out: Appendable, - ca: CharArray, - cb: CharBuffer, - await: (Int) -> Boolean, - addConsumed: (Int) -> Unit, - decode: (ByteBuffer) -> Long + limit: Int ): Boolean { - // number of bytes required for the next character, <= 0 when no characters required anymore (exit loop) + var consumed = 0 var required = 1 + var caret = false + var newLine = false - do { - if (!await(required)) break - val buffer = request(0, 1) ?: break + val output = CharArray(8 * 1024) + while (!isClosedForRead && !newLine && !caret && (limit == Int.MAX_VALUE || consumed <= limit)) { + read(required) { + val readLimit = if (limit == Int.MAX_VALUE) output.size else minOf(output.size, limit - consumed) + val decodeResult = it.decodeUTF8Line(output, 0, readLimit) - val before = buffer.position() - if (buffer.remaining() < required) { - buffer.rollBytes(required) - } + val decoded = (decodeResult shr 32).toInt() + val requiredBytes = (decodeResult and 0xffffffffL).toInt() - val rc = decode(buffer) + required = kotlin.math.max(1, requiredBytes) - val after = buffer.position() - consumed(after - before) + if (requiredBytes == -1) { + newLine = true + } - val decoded = (rc shr 32).toInt() - val rcRequired = (rc and 0xffffffffL).toInt() + if (it.hasRemaining() && it[it.position()] == '\r'.code.toByte()) { + it.position(it.position() + 1) + caret = true + } - required = when { - // EOL - rcRequired == -1 -> 0 - // no EOL, no demands but untouched bytes - // for ascii decoder that could mean that there was non-ASCII character encountered - rcRequired == 0 && buffer.hasRemaining() -> -1 - else -> maxOf(1, rcRequired) - } + if (it.hasRemaining() && it[it.position()] == '\n'.code.toByte()) { + it.position(it.position() + 1) + newLine = true + } - addConsumed(decoded) + if (out is StringBuilder) { + out.append(output, 0, decoded) + } else { + val buffer = CharBuffer.wrap(output, 0, decoded) + out.append(buffer, 0, decoded) + } - if (out is StringBuilder) { - out.append(ca, 0, decoded) - } else { - out.append(cb, 0, decoded) - } - } while (required > 0) + consumed += decoded - return when (required) { - 0 -> true - else -> false + if (limit != Int.MAX_VALUE && consumed >= limit && !newLine) { + throw TooLongLineException("Line is longer than limit") + } + } } - } - - private suspend fun readUTF8LineToUtf8Suspend( - out: Appendable, - limit: Int, - ca: CharArray, - cb: CharBuffer, - consumed0: Int - ): Boolean { - var consumed1 = 0 - var result = true - - lookAheadSuspend { - val rc = readLineLoop( - out, - ca, - cb, - await = { awaitAtLeast(it) }, - addConsumed = { consumed1 += it }, - decode = { it.decodeUTF8Line(ca, 0, minOf(ca.size, limit - consumed1)) } - ) - - if (rc || !isClosedForWrite) { - return@lookAheadSuspend - } - val buffer = request(0, 1) - when { - buffer != null -> { - if (buffer.get() != '\r'.code.toByte()) { - buffer.position(buffer.position() - 1) - throw TooLongLineException("Line is longer than limit") - } - - consumed(1) - if (buffer.hasRemaining()) { - throw MalformedInputException("Illegal trailing bytes: ${buffer.remaining()}") + if (!isClosedForRead && caret && !newLine) { + try { + read(1) { + if (it[it.position()] == '\n'.code.toByte()) { + it.position(it.position() + 1) + newLine = true } } - consumed1 == 0 && consumed0 == 0 -> { - result = false - } + } catch (_: EOFException) { } } - return result + return (isClosedForRead && consumed > 0) || (newLine || caret) } override suspend fun readUTF8LineTo(out: A, limit: Int): Boolean = diff --git a/ktor-io/jvm/test/io/ktor/utils/io/charsets/ByteChannelTextTest.kt b/ktor-io/jvm/test/io/ktor/utils/io/charsets/ByteChannelTextTest.kt index aba4b31034..fc7bcee1c3 100644 --- a/ktor-io/jvm/test/io/ktor/utils/io/charsets/ByteChannelTextTest.kt +++ b/ktor-io/jvm/test/io/ktor/utils/io/charsets/ByteChannelTextTest.kt @@ -21,4 +21,25 @@ class ByteChannelTextTest { channel.readUTF8Line(50) } } + + @Test + fun testReadUtf8Line32k() = runBlocking { + val line = "x".repeat(32 * 1024) + val bytes = line.encodeToByteArray() + val channel = ByteReadChannel(bytes) + + val result = channel.readUTF8Line() + assertEquals(line, result) + } + + @Test + fun testReadLineUtf8Chunks() = runBlocking { + val line = "x".repeat(32 * 1024) + val channel = writer { + channel.writeStringUtf8(line) + }.channel + + val result = channel.readUTF8Line() + assertEquals(line, result) + } }