Skip to content

Commit

Permalink
KTOR-4770 Fix ignoreBody for Server ContentNegotiation Request (#3134)
Browse files Browse the repository at this point in the history
  • Loading branch information
e5l committed Aug 29, 2022
1 parent cc7f6fc commit af882b3
Show file tree
Hide file tree
Showing 8 changed files with 115 additions and 41 deletions.
15 changes: 6 additions & 9 deletions ktor-http/common/src/io/ktor/http/content/Versions.kt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ package io.ktor.http.content
import io.ktor.http.*
import io.ktor.util.*
import io.ktor.util.date.*
import kotlin.native.concurrent.*

/**
* Specifies a key for the VersionList extension property for [OutgoingContent].
Expand Down Expand Up @@ -77,16 +76,14 @@ public data class LastModifiedVersion(val lastModified: GMTDate) : Version {
* [VersionCheckResult.PRECONDITION_FAILED] for `If-Unmodified-Since`
*/
override fun check(requestHeaders: Headers): VersionCheckResult {
requestHeaders.getAll(HttpHeaders.IfModifiedSince)?.parseDates()?.let { dates ->
if (!ifModifiedSince(dates)) {
return VersionCheckResult.NOT_MODIFIED
}
val modifiedSince = requestHeaders.getAll(HttpHeaders.IfModifiedSince)?.parseDates()
if (modifiedSince != null && !ifModifiedSince(modifiedSince)) {
return VersionCheckResult.NOT_MODIFIED
}

requestHeaders.getAll(HttpHeaders.IfUnmodifiedSince)?.parseDates()?.let { dates ->
if (!ifUnmodifiedSince(dates)) {
return VersionCheckResult.PRECONDITION_FAILED
}
val unmodifiedSince = requestHeaders.getAll(HttpHeaders.IfUnmodifiedSince)?.parseDates()
if (unmodifiedSince != null && !ifUnmodifiedSince(unmodifiedSince)) {
return VersionCheckResult.PRECONDITION_FAILED
}

return VersionCheckResult.OK
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ internal fun PluginBuilder<ContentNegotiationConfig>.convertRequestBody() {
val registrations = pluginConfig.registrations
val requestedType = call.receiveType

if (requestedType.type == ByteReadChannel::class) return@onCallReceive
if (requestedType.type in pluginConfig.ignoredTypes) return@onCallReceive

transformBody { body: ByteReadChannel ->
val requestContentType = try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ internal fun PluginBuilder<ContentNegotiationConfig>.convertResponseBody() = onC
if (subject is OutgoingContent || pluginConfig.ignoredTypes.any { it.isInstance(subject) }) {
return@onCallRespond
}
if (call.response.responseType == null) return@onCallRespond

val responseType = call.response.responseType ?: return@onCallRespond
val registrations = pluginConfig.registrations
val checkAcceptHeader = pluginConfig.checkAcceptHeaderCompliance

Expand Down Expand Up @@ -52,7 +52,7 @@ internal fun PluginBuilder<ContentNegotiationConfig>.convertResponseBody() = onC
it.converter.serializeNullable(
contentType = contentType ?: it.contentType,
charset = acceptCharset ?: Charsets.UTF_8,
typeInfo = call.response.responseType!!,
typeInfo = responseType,
value = subject.takeIf { it != NullBody }
)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package io.ktor.server.plugins.contentnegotiation

import io.ktor.client.request.*
import io.ktor.client.statement.*
import io.ktor.http.*
import io.ktor.serialization.*
import io.ktor.server.application.*
import io.ktor.server.request.*
import io.ktor.server.response.*
import io.ktor.server.routing.*
import io.ktor.server.testing.*
import io.ktor.util.reflect.*
import io.ktor.utils.io.*
import io.ktor.utils.io.charsets.*
import kotlin.test.*

/*
* Copyright 2014-2022 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
*/

class RequestConverterTest {

@Test
fun testIgnoreType() = testApplication {
var used = false

install(ContentNegotiation) {
ignoreType<NonSerializableClass>()
register(
ContentType.Application.Json,
object : ContentConverter {
override suspend fun deserialize(
charset: Charset,
typeInfo: TypeInfo,
content: ByteReadChannel
): Any? {
used = true
return null
}
}
)
}

routing {
post("/foo") {
val result: String = try {
call.receive<NonSerializableClass>()
"OK"
} catch (cause: Throwable) {
cause.message ?: cause.toString()
}
call.respondText(result)
}
post("/bar") {
val result: String = try {
call.receive<SerializableClass>()
"OK"
} catch (cause: Throwable) {
cause.message ?: cause.toString()
}
call.respondText(result)
}
}

val responseFoo = client.post("/foo") {
contentType(ContentType.Application.Json)
}

assertEquals(
"Cannot transform this request's content to io.ktor.server.plugins.contentnegotiation.NonSerializableClass",
responseFoo.bodyAsText()
)
assertFalse(used)

val responseBar = client.post("/bar") {
contentType(ContentType.Application.Json)
}

assertEquals(
"No suitable converter found for TypeInfo(" +
"type=class io.ktor.server.plugins.contentnegotiation.SerializableClass, " +
"reifiedType=class io.ktor.server.plugins.contentnegotiation.SerializableClass, " +
"kotlinType=io.ktor.server.plugins.contentnegotiation.SerializableClass" +
")",
responseBar.bodyAsText()
)
assertTrue(used)
}
}

internal class NonSerializableClass

internal class SerializableClass
Original file line number Diff line number Diff line change
Expand Up @@ -93,27 +93,14 @@ public class TestHttpClientEngine(override val config: TestHttpClientConfig) : H
callContext()
)

@OptIn(InternalAPI::class)
internal fun TestApplicationRequest.appendRequestHeaders(
headers: Headers,
content: OutgoingContent
) {
headers.flattenForEach { name, value ->
if (HttpHeaders.ContentLength == name) return@flattenForEach // set later
if (HttpHeaders.ContentType == name) return@flattenForEach // set later
addHeader(name, value)
}

content.headers.flattenForEach { name, value ->
if (HttpHeaders.ContentLength == name) return@flattenForEach // TODO: throw exception for unsafe header?
if (HttpHeaders.ContentType == name) return@flattenForEach
mergeHeaders(headers, content) { name, value ->
addHeader(name, value)
}

val contentLength = headers[HttpHeaders.ContentLength] ?: content.contentLength?.toString()
val contentType = headers[HttpHeaders.ContentType] ?: content.contentType?.toString()

contentLength?.let { addHeader(HttpHeaders.ContentLength, it) }
contentType?.let { addHeader(HttpHeaders.ContentType, it) }
}

override fun close() {
Expand Down
21 changes: 10 additions & 11 deletions ktor-shared/ktor-serialization/common/src/ContentConverter.kt
Original file line number Diff line number Diff line change
Expand Up @@ -119,19 +119,18 @@ public suspend fun List<ContentConverter>.deserialize(
// 1. there is no suitable converter
// 2. result of deserialization is null
// We can differentiate these cases by checking if body was consumed or not
val result = this.asFlow()
.filter { !body.isClosedForRead }
.map { converter ->
converter.deserialize(
charset = charset,
typeInfo = typeInfo,
content = body
)
}
.firstOrNull { it != null }
val result = asFlow().map { converter ->
converter.deserialize(
charset = charset,
typeInfo = typeInfo,
content = body
)
}.firstOrNull { it != null || body.isClosedForRead }

return when {
result != null -> result
!body.isClosedForRead -> body
else -> NullBody
typeInfo.kotlinType?.isMarkedNullable == true -> NullBody
else -> throw ContentConvertException("No suitable converter found for $typeInfo")
}
}
4 changes: 1 addition & 3 deletions ktor-utils/common/src/io/ktor/util/date/GMTDateParser.kt
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@

package io.ktor.util.date

import io.ktor.util.*

/**
* Build [GMTDate] parser using [pattern] string.
*
Expand Down Expand Up @@ -116,7 +114,7 @@ internal class GMTDateBuilder {
lateinit var month: Month
var year: Int? = null

public fun build(): GMTDate = GMTDate(seconds!!, minutes!!, hours!!, dayOfMonth!!, month, year!!)
fun build(): GMTDate = GMTDate(seconds!!, minutes!!, hours!!, dayOfMonth!!, month, year!!)
}

/**
Expand Down

0 comments on commit af882b3

Please sign in to comment.