Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

KTOR-4379 Validate body size equals Content-Length #3069

Merged
merged 3 commits into from
Jun 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ package io.ktor.http.cio
import io.ktor.http.*
import io.ktor.http.cio.internals.*
import io.ktor.utils.io.*

import io.ktor.utils.io.errors.*
/**
* @return `true` if an http upgrade is expected accoding to request [method], [upgrade] header value and
* parsed [connectionOptions]
Expand Down Expand Up @@ -84,7 +84,11 @@ public suspend fun parseHttpBody(
}

if (contentLength != -1L) {
input.copyTo(out, contentLength)
val size = input.copyTo(out, contentLength)

if (size != contentLength) {
throw IOException("Unexpected body length: expected $contentLength, actual $size")
}
return
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@ public abstract class NettyApplicationCall(
private val requestMessage: Any,
) : BaseApplicationCall(application) {

@OptIn(InternalAPI::class)
public abstract override val request: NettyApplicationRequest
@OptIn(InternalAPI::class)
public abstract override val response: NettyApplicationResponse

internal lateinit var previousCallFinished: ChannelPromise
Expand Down Expand Up @@ -65,7 +63,6 @@ public abstract class NettyApplicationCall(

internal suspend fun finish() {
try {
@OptIn(InternalAPI::class)
response.ensureResponseSent()
} catch (cause: Throwable) {
finishedEvent.setFailure(cause)
Expand All @@ -89,14 +86,12 @@ public abstract class NettyApplicationCall(
}
}

@OptIn(InternalAPI::class)
private fun finishComplete() {
responseWriteJob.cancel()
request.close()
releaseRequestMessage()
}

@OptIn(InternalAPI::class)
internal fun dispose() {
response.close()
request.close()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,18 @@
package io.ktor.server.netty.cio

import io.ktor.utils.io.*
import io.ktor.utils.io.errors.*
import io.netty.buffer.*
import io.netty.channel.*
import io.netty.handler.codec.http.*
import io.netty.util.*
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.Channel
import java.lang.Integer.*
import kotlin.coroutines.*

private class ChannelEvent(val channel: ByteWriteChannel, val expectedLength: Long)

internal class RequestBodyHandler(
val context: ChannelHandlerContext
) : ChannelInboundHandlerAdapter(), CoroutineScope {
Expand All @@ -26,36 +30,48 @@ internal class RequestBodyHandler(

private val job = launch(context.executor().asCoroutineDispatcher(), start = CoroutineStart.LAZY) {
var current: ByteWriteChannel? = null
var expectedLength = -1L
var written = 0L
var upgraded = false

fun checkCurrentLengthAndClose() {
if (expectedLength == -1L || written == expectedLength) {
current?.close()
return
}

val message = "Unexpected length of the request body. Expected $expectedLength but was $written"
current?.close(IOException(message))
}

try {
while (true) {
@OptIn(ExperimentalCoroutinesApi::class)
val event = queue.tryReceive().getOrNull()
?: run { current?.flush(); queue.receiveCatching().getOrNull() }
?: break

when (event) {
is ByteBufHolder -> {
val channel = current
?: throw IllegalStateException("No current channel but received a byte buf")
processContent(channel, event)
val channel = current ?: error("No current channel but received a byte buf")
written += processContent(channel, event)

if (!upgraded && event is LastHttpContent) {
current.close()
checkCurrentLengthAndClose()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't it redundant, since this check is in finally block?

Copy link
Member Author

@e5l e5l Jun 27, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nope: the next message can be a new channel in the pipeline or upgrade

current = null
}
requestMoreEvents()
}
is ByteBuf -> {
val channel =
current ?: throw IllegalStateException("No current channel but received a byte buf")
processContent(channel, event)
val channel = current ?: error("No current channel but received a byte buf")
written += processContent(channel, event)
requestMoreEvents()
}
is ByteWriteChannel -> {
current?.close()
current = event
is ChannelEvent -> {
checkCurrentLengthAndClose()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here


current = event.channel
expectedLength = event.expectedLength
written = 0L
}
is Upgrade -> {
upgraded = true
Expand All @@ -66,7 +82,7 @@ internal class RequestBodyHandler(
queue.close(t)
current?.close(t)
} finally {
current?.close()
checkCurrentLengthAndClose()
queue.close()
consumeAndReleaseQueue()
}
Expand All @@ -75,7 +91,7 @@ internal class RequestBodyHandler(
@OptIn(ExperimentalCoroutinesApi::class)
fun upgrade(): ByteReadChannel {
val result = queue.trySend(Upgrade)
if (result.isSuccess) return newChannel()
if (result.isSuccess) return newChannel(-1L)

if (queue.isClosedForSend) {
throw CancellationException("HTTP pipeline has been terminated.", result.exceptionOrNull())
Expand All @@ -87,10 +103,10 @@ internal class RequestBodyHandler(
)
}

fun newChannel(): ByteReadChannel {
val bc = ByteChannel()
tryOfferChannelOrToken(bc)
return bc
fun newChannel(contentLength: Long): ByteReadChannel {
val result = ByteChannel()
tryOfferChannelOrToken(ChannelEvent(result, contentLength))
return result
}

@OptIn(ExperimentalCoroutinesApi::class)
Expand Down Expand Up @@ -121,18 +137,18 @@ internal class RequestBodyHandler(
}
}

private suspend fun processContent(current: ByteWriteChannel, event: ByteBufHolder) {
private suspend fun processContent(current: ByteWriteChannel, event: ByteBufHolder): Int {
try {
val buf = event.content()
copy(buf, current)
return copy(buf, current)
} finally {
event.release()
}
}

private suspend fun processContent(current: ByteWriteChannel, buf: ByteBuf) {
private suspend fun processContent(current: ByteWriteChannel, buf: ByteBuf): Int {
try {
copy(buf, current)
return copy(buf, current)
} finally {
buf.release()
}
Expand Down Expand Up @@ -160,12 +176,14 @@ internal class RequestBodyHandler(
}
}

private suspend fun copy(buf: ByteBuf, dst: ByteWriteChannel) {
private suspend fun copy(buf: ByteBuf, dst: ByteWriteChannel): Int {
val length = buf.readableBytes()
if (length > 0) {
val buffer = buf.internalNioBuffer(buf.readerIndex(), length)
dst.writeFully(buffer)
}

return max(length, 0)
}

private fun handleBytesRead(content: ReferenceCounted) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ internal class NettyHttp1Handler(
*/
internal val isChannelReadCompleted: AtomicBoolean = atomic(false)

@OptIn(InternalAPI::class)
override fun channelActive(context: ChannelHandlerContext) {
responseWriter = NettyHttpResponsePipeline(
context,
Expand Down Expand Up @@ -152,17 +151,19 @@ internal class NettyHttp1Handler(
)
}

private fun prepareRequestContentChannel(context: ChannelHandlerContext, message: HttpRequest): ByteReadChannel {
return when (message) {
is HttpContent -> {
val bodyHandler = context.pipeline().get(RequestBodyHandler::class.java)
bodyHandler.newChannel().also { bodyHandler.channelRead(context, message) }
}
else -> {
val bodyHandler = context.pipeline().get(RequestBodyHandler::class.java)
bodyHandler.newChannel()
}
private fun prepareRequestContentChannel(
context: ChannelHandlerContext,
message: HttpRequest
): ByteReadChannel {
val bodyHandler = context.pipeline().get(RequestBodyHandler::class.java)
val length = message.headers()[io.ktor.http.HttpHeaders.ContentLength]?.toLongOrNull() ?: -1
val result = bodyHandler.newChannel(length)

if (message is HttpContent) {
bodyHandler.channelRead(context, message)
}

return result
}

private fun callReadIfNeeded(context: ChannelHandlerContext) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,10 @@ internal fun CoroutineScope.servletReader(input: ServletInputStream, contentLeng
}
}

private class ServletReader(val input: ServletInputStream, contentLength: Int) : ReadListener {
private class ServletReader(val input: ServletInputStream, val contentLength: Int) : ReadListener {
val channel = ByteChannel()
private val events = Channel<Unit>(2)

private val contentLength: Int = if (contentLength < 0) Int.MAX_VALUE else contentLength

suspend fun run() {
val buffer = ArrayPool.borrow()
try {
Expand Down Expand Up @@ -72,20 +70,24 @@ private class ServletReader(val input: ServletInputStream, contentLength: Int) :

channel.writeFully(buffer, 0, readCount)

if (contentLength < 0) continue

if (bodySize == contentLength) {
channel.close()
events.close()
break
}

if (bodySize > contentLength) {
val cause = IOException(
"Client provided more bytes than content length. Expected $contentLength but got $bodySize."
)
channel.close(cause)
events.close()
break
val message = if (bodySize > contentLength) {
"Client provided more bytes than content length. Expected $contentLength but got $bodySize."
} else {
"Client provided less bytes than content length. Expected $contentLength but got $bodySize."
}

val cause = IOException(message)
channel.close(cause)
events.close()
break
}
}

Expand All @@ -112,8 +114,7 @@ private class ServletReader(val input: ServletInputStream, contentLength: Int) :
private fun wrapException(cause: Throwable): Throwable? {
return when (cause) {
is EOFException -> null
is TimeoutException,
is IOException -> ChannelReadException(
is TimeoutException -> ChannelReadException(
"Cannot read from a servlet input stream",
exception = cause as Exception
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -639,10 +639,10 @@ abstract class ContentTestSuite<TEngine : ApplicationEngine, TConfiguration : Ap
call.receiveMultipart().forEachPart { part ->
when (part) {
is PartData.FormItem -> response.append("${part.name}=${part.value}\n")
is PartData.FileItem -> response.append(
"file:${part.name},${part.originalFileName}," +
"${part.streamProvider().bufferedReader().lineSequence().count()}\n"
)
is PartData.FileItem -> {
val lineSequence = part.streamProvider().bufferedReader().lineSequence()
response.append("file:${part.name},${part.originalFileName},${lineSequence.count()}\n")
}
is PartData.BinaryItem -> {
}
is PartData.BinaryChannelItem -> {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -846,6 +846,49 @@ abstract class SustainabilityTestSuite<TEngine : ApplicationEngine, TConfigurati
outputStream.close()
}
}

@Test
fun testBodySmallerThanContentLength() {
var failCause: Throwable? = null
val result = Job()

createAndStartServer {
post("/") {
try {
println(call.receive<ByteArray>().size)
} catch (cause: Throwable) {
failCause = cause
} finally {
result.complete()
}

call.respond("OK")
}
}

socket {
val request = buildString {
append("POST / HTTP/1.1\r\n")
append("Content-Length: 4\r\n")
append("Content-Type: text/plain\r\n")
append("Connection: close\r\n")
append("Host: localhost\r\n")
append("\r\n")
append("ABC")
}

outputStream.writer().use {
it.write(request)
}
}

runBlocking {
result.join()
}

assertTrue(failCause != null)
assertTrue(failCause is IOException)
}
}

internal inline fun assertFails(block: () -> Unit) {
Expand Down
6 changes: 4 additions & 2 deletions ktor-utils/common/src/io/ktor/util/cio/Channels.kt
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,7 @@ public class ChannelWriteException(message: String = "Cannot write to a channel"
* An exception that is thrown when an IO error occurred during reading from the request channel.
* Usually it happens when a remote client closed the connection.
*/
public class ChannelReadException(message: String = "Cannot read from a channel", exception: Throwable) :
ChannelIOException(message, exception)
public class ChannelReadException(
message: String = "Cannot read from a channel",
exception: Throwable
) : ChannelIOException(message, exception)