Skip to content

Commit

Permalink
KTOR-4379 Validate body size equals Content-Length (#3069)
Browse files Browse the repository at this point in the history
* KTOR-4379 Validate body size equals Content-Length
  • Loading branch information
e5l committed Jun 27, 2022
1 parent 7c32f4e commit 2fffabe
Show file tree
Hide file tree
Showing 8 changed files with 122 additions and 58 deletions.
Expand Up @@ -7,7 +7,7 @@ package io.ktor.http.cio
import io.ktor.http.*
import io.ktor.http.cio.internals.*
import io.ktor.utils.io.*

import io.ktor.utils.io.errors.*
/**
* @return `true` if an http upgrade is expected accoding to request [method], [upgrade] header value and
* parsed [connectionOptions]
Expand Down Expand Up @@ -84,7 +84,11 @@ public suspend fun parseHttpBody(
}

if (contentLength != -1L) {
input.copyTo(out, contentLength)
val size = input.copyTo(out, contentLength)

if (size != contentLength) {
throw IOException("Unexpected body length: expected $contentLength, actual $size")
}
return
}

Expand Down
Expand Up @@ -19,9 +19,7 @@ public abstract class NettyApplicationCall(
private val requestMessage: Any,
) : BaseApplicationCall(application) {

@OptIn(InternalAPI::class)
public abstract override val request: NettyApplicationRequest
@OptIn(InternalAPI::class)
public abstract override val response: NettyApplicationResponse

internal lateinit var previousCallFinished: ChannelPromise
Expand Down Expand Up @@ -65,7 +63,6 @@ public abstract class NettyApplicationCall(

internal suspend fun finish() {
try {
@OptIn(InternalAPI::class)
response.ensureResponseSent()
} catch (cause: Throwable) {
finishedEvent.setFailure(cause)
Expand All @@ -89,14 +86,12 @@ public abstract class NettyApplicationCall(
}
}

@OptIn(InternalAPI::class)
private fun finishComplete() {
responseWriteJob.cancel()
request.close()
releaseRequestMessage()
}

@OptIn(InternalAPI::class)
internal fun dispose() {
response.close()
request.close()
Expand Down
Expand Up @@ -5,14 +5,18 @@
package io.ktor.server.netty.cio

import io.ktor.utils.io.*
import io.ktor.utils.io.errors.*
import io.netty.buffer.*
import io.netty.channel.*
import io.netty.handler.codec.http.*
import io.netty.util.*
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.Channel
import java.lang.Integer.*
import kotlin.coroutines.*

private class ChannelEvent(val channel: ByteWriteChannel, val expectedLength: Long)

internal class RequestBodyHandler(
val context: ChannelHandlerContext
) : ChannelInboundHandlerAdapter(), CoroutineScope {
Expand All @@ -26,36 +30,48 @@ internal class RequestBodyHandler(

private val job = launch(context.executor().asCoroutineDispatcher(), start = CoroutineStart.LAZY) {
var current: ByteWriteChannel? = null
var expectedLength = -1L
var written = 0L
var upgraded = false

fun checkCurrentLengthAndClose() {
if (expectedLength == -1L || written == expectedLength) {
current?.close()
return
}

val message = "Unexpected length of the request body. Expected $expectedLength but was $written"
current?.close(IOException(message))
}

try {
while (true) {
@OptIn(ExperimentalCoroutinesApi::class)
val event = queue.tryReceive().getOrNull()
?: run { current?.flush(); queue.receiveCatching().getOrNull() }
?: break

when (event) {
is ByteBufHolder -> {
val channel = current
?: throw IllegalStateException("No current channel but received a byte buf")
processContent(channel, event)
val channel = current ?: error("No current channel but received a byte buf")
written += processContent(channel, event)

if (!upgraded && event is LastHttpContent) {
current.close()
checkCurrentLengthAndClose()
current = null
}
requestMoreEvents()
}
is ByteBuf -> {
val channel =
current ?: throw IllegalStateException("No current channel but received a byte buf")
processContent(channel, event)
val channel = current ?: error("No current channel but received a byte buf")
written += processContent(channel, event)
requestMoreEvents()
}
is ByteWriteChannel -> {
current?.close()
current = event
is ChannelEvent -> {
checkCurrentLengthAndClose()

current = event.channel
expectedLength = event.expectedLength
written = 0L
}
is Upgrade -> {
upgraded = true
Expand All @@ -66,7 +82,7 @@ internal class RequestBodyHandler(
queue.close(t)
current?.close(t)
} finally {
current?.close()
checkCurrentLengthAndClose()
queue.close()
consumeAndReleaseQueue()
}
Expand All @@ -75,7 +91,7 @@ internal class RequestBodyHandler(
@OptIn(ExperimentalCoroutinesApi::class)
fun upgrade(): ByteReadChannel {
val result = queue.trySend(Upgrade)
if (result.isSuccess) return newChannel()
if (result.isSuccess) return newChannel(-1L)

if (queue.isClosedForSend) {
throw CancellationException("HTTP pipeline has been terminated.", result.exceptionOrNull())
Expand All @@ -87,10 +103,10 @@ internal class RequestBodyHandler(
)
}

fun newChannel(): ByteReadChannel {
val bc = ByteChannel()
tryOfferChannelOrToken(bc)
return bc
fun newChannel(contentLength: Long): ByteReadChannel {
val result = ByteChannel()
tryOfferChannelOrToken(ChannelEvent(result, contentLength))
return result
}

@OptIn(ExperimentalCoroutinesApi::class)
Expand Down Expand Up @@ -121,18 +137,18 @@ internal class RequestBodyHandler(
}
}

private suspend fun processContent(current: ByteWriteChannel, event: ByteBufHolder) {
private suspend fun processContent(current: ByteWriteChannel, event: ByteBufHolder): Int {
try {
val buf = event.content()
copy(buf, current)
return copy(buf, current)
} finally {
event.release()
}
}

private suspend fun processContent(current: ByteWriteChannel, buf: ByteBuf) {
private suspend fun processContent(current: ByteWriteChannel, buf: ByteBuf): Int {
try {
copy(buf, current)
return copy(buf, current)
} finally {
buf.release()
}
Expand Down Expand Up @@ -160,12 +176,14 @@ internal class RequestBodyHandler(
}
}

private suspend fun copy(buf: ByteBuf, dst: ByteWriteChannel) {
private suspend fun copy(buf: ByteBuf, dst: ByteWriteChannel): Int {
val length = buf.readableBytes()
if (length > 0) {
val buffer = buf.internalNioBuffer(buf.readerIndex(), length)
dst.writeFully(buffer)
}

return max(length, 0)
}

private fun handleBytesRead(content: ReferenceCounted) {
Expand Down
Expand Up @@ -51,7 +51,6 @@ internal class NettyHttp1Handler(
*/
internal val isChannelReadCompleted: AtomicBoolean = atomic(false)

@OptIn(InternalAPI::class)
override fun channelActive(context: ChannelHandlerContext) {
responseWriter = NettyHttpResponsePipeline(
context,
Expand Down Expand Up @@ -152,17 +151,19 @@ internal class NettyHttp1Handler(
)
}

private fun prepareRequestContentChannel(context: ChannelHandlerContext, message: HttpRequest): ByteReadChannel {
return when (message) {
is HttpContent -> {
val bodyHandler = context.pipeline().get(RequestBodyHandler::class.java)
bodyHandler.newChannel().also { bodyHandler.channelRead(context, message) }
}
else -> {
val bodyHandler = context.pipeline().get(RequestBodyHandler::class.java)
bodyHandler.newChannel()
}
private fun prepareRequestContentChannel(
context: ChannelHandlerContext,
message: HttpRequest
): ByteReadChannel {
val bodyHandler = context.pipeline().get(RequestBodyHandler::class.java)
val length = message.headers()[io.ktor.http.HttpHeaders.ContentLength]?.toLongOrNull() ?: -1
val result = bodyHandler.newChannel(length)

if (message is HttpContent) {
bodyHandler.channelRead(context, message)
}

return result
}

private fun callReadIfNeeded(context: ChannelHandlerContext) {
Expand Down
Expand Up @@ -20,12 +20,10 @@ internal fun CoroutineScope.servletReader(input: ServletInputStream, contentLeng
}
}

private class ServletReader(val input: ServletInputStream, contentLength: Int) : ReadListener {
private class ServletReader(val input: ServletInputStream, val contentLength: Int) : ReadListener {
val channel = ByteChannel()
private val events = Channel<Unit>(2)

private val contentLength: Int = if (contentLength < 0) Int.MAX_VALUE else contentLength

suspend fun run() {
val buffer = ArrayPool.borrow()
try {
Expand Down Expand Up @@ -72,20 +70,24 @@ private class ServletReader(val input: ServletInputStream, contentLength: Int) :

channel.writeFully(buffer, 0, readCount)

if (contentLength < 0) continue

if (bodySize == contentLength) {
channel.close()
events.close()
break
}

if (bodySize > contentLength) {
val cause = IOException(
"Client provided more bytes than content length. Expected $contentLength but got $bodySize."
)
channel.close(cause)
events.close()
break
val message = if (bodySize > contentLength) {
"Client provided more bytes than content length. Expected $contentLength but got $bodySize."
} else {
"Client provided less bytes than content length. Expected $contentLength but got $bodySize."
}

val cause = IOException(message)
channel.close(cause)
events.close()
break
}
}

Expand All @@ -112,8 +114,7 @@ private class ServletReader(val input: ServletInputStream, contentLength: Int) :
private fun wrapException(cause: Throwable): Throwable? {
return when (cause) {
is EOFException -> null
is TimeoutException,
is IOException -> ChannelReadException(
is TimeoutException -> ChannelReadException(
"Cannot read from a servlet input stream",
exception = cause as Exception
)
Expand Down
Expand Up @@ -639,10 +639,10 @@ abstract class ContentTestSuite<TEngine : ApplicationEngine, TConfiguration : Ap
call.receiveMultipart().forEachPart { part ->
when (part) {
is PartData.FormItem -> response.append("${part.name}=${part.value}\n")
is PartData.FileItem -> response.append(
"file:${part.name},${part.originalFileName}," +
"${part.streamProvider().bufferedReader().lineSequence().count()}\n"
)
is PartData.FileItem -> {
val lineSequence = part.streamProvider().bufferedReader().lineSequence()
response.append("file:${part.name},${part.originalFileName},${lineSequence.count()}\n")
}
is PartData.BinaryItem -> {
}
is PartData.BinaryChannelItem -> {}
Expand Down
Expand Up @@ -846,6 +846,49 @@ abstract class SustainabilityTestSuite<TEngine : ApplicationEngine, TConfigurati
outputStream.close()
}
}

@Test
fun testBodySmallerThanContentLength() {
var failCause: Throwable? = null
val result = Job()

createAndStartServer {
post("/") {
try {
println(call.receive<ByteArray>().size)
} catch (cause: Throwable) {
failCause = cause
} finally {
result.complete()
}

call.respond("OK")
}
}

socket {
val request = buildString {
append("POST / HTTP/1.1\r\n")
append("Content-Length: 4\r\n")
append("Content-Type: text/plain\r\n")
append("Connection: close\r\n")
append("Host: localhost\r\n")
append("\r\n")
append("ABC")
}

outputStream.writer().use {
it.write(request)
}
}

runBlocking {
result.join()
}

assertTrue(failCause != null)
assertTrue(failCause is IOException)
}
}

internal inline fun assertFails(block: () -> Unit) {
Expand Down
6 changes: 4 additions & 2 deletions ktor-utils/common/src/io/ktor/util/cio/Channels.kt
Expand Up @@ -23,5 +23,7 @@ public class ChannelWriteException(message: String = "Cannot write to a channel"
* An exception that is thrown when an IO error occurred during reading from the request channel.
* Usually it happens when a remote client closed the connection.
*/
public class ChannelReadException(message: String = "Cannot read from a channel", exception: Throwable) :
ChannelIOException(message, exception)
public class ChannelReadException(
message: String = "Cannot read from a channel",
exception: Throwable
) : ChannelIOException(message, exception)

0 comments on commit 2fffabe

Please sign in to comment.