Skip to content

Commit

Permalink
KTOR-4511 Ignore ByteReadChannel for server response and client request
Browse files Browse the repository at this point in the history
  • Loading branch information
rsinukov committed Jun 22, 2022
1 parent 30e05ae commit 084529f
Show file tree
Hide file tree
Showing 18 changed files with 254 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,8 @@ public fun HttpClient.defaultTransformers() {
override fun readFrom(): ByteReadChannel = body
}
is OutgoingContent -> body
else -> null
else -> platformRequestDefaultTransform(contentType, context, body)
}

if (content != null) {
context.headers.remove(HttpHeaders.ContentType)
proceedWith(content)
Expand Down Expand Up @@ -111,7 +110,13 @@ public fun HttpClient.defaultTransformers() {
}
}

platformDefaultTransformers()
platformResponseDefaultTransformers()
}

internal expect fun HttpClient.platformDefaultTransformers()
internal expect fun platformRequestDefaultTransform(
contentType: ContentType?,
context: HttpRequestBuilder,
body: Any
): OutgoingContent?

internal expect fun HttpClient.platformResponseDefaultTransformers()
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,14 @@
package io.ktor.client.plugins

import io.ktor.client.*
import io.ktor.client.request.*
import io.ktor.http.*
import io.ktor.http.content.*

internal actual fun HttpClient.platformDefaultTransformers() {}
internal actual fun platformRequestDefaultTransform(
contentType: ContentType?,
context: HttpRequestBuilder,
body: Any
): OutgoingContent? = null

internal actual fun HttpClient.platformResponseDefaultTransformers() {}
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,18 @@
package io.ktor.client.plugins

import io.ktor.client.*
import io.ktor.client.request.*
import io.ktor.client.statement.*
import io.ktor.http.*
import io.ktor.http.content.*
import io.ktor.util.*
import io.ktor.utils.io.*
import io.ktor.utils.io.jvm.javaio.*
import kotlinx.coroutines.*
import java.io.*

@OptIn(InternalAPI::class)
internal actual fun HttpClient.platformDefaultTransformers() {
internal actual fun HttpClient.platformResponseDefaultTransformers() {
responsePipeline.intercept(HttpResponsePipeline.Parse) { (info, body) ->
if (body !is ByteReadChannel) return@intercept
when (info.type) {
Expand All @@ -35,3 +38,16 @@ internal actual fun HttpClient.platformDefaultTransformers() {
}
}
}

internal actual fun platformRequestDefaultTransform(
contentType: ContentType?,
context: HttpRequestBuilder,
body: Any
): OutgoingContent? = when (body) {
is InputStream -> object : OutgoingContent.ReadChannelContent() {
override val contentLength = context.headers[HttpHeaders.ContentLength]?.toLong()
override val contentType: ContentType = contentType ?: ContentType.Application.OctetStream
override fun readFrom(): ByteReadChannel = body.toByteReadChannel()
}
else -> null
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,15 @@
package io.ktor.client.plugins

import io.ktor.client.*
import io.ktor.client.request.*
import io.ktor.http.*
import io.ktor.http.content.*

internal actual fun HttpClient.platformDefaultTransformers() {
internal actual fun platformRequestDefaultTransform(
contentType: ContentType?,
context: HttpRequestBuilder,
body: Any
): OutgoingContent? = null

internal actual fun HttpClient.platformResponseDefaultTransformers() {
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,14 @@ import io.ktor.client.request.*
import io.ktor.client.statement.*
import io.ktor.client.utils.*
import io.ktor.http.*
import io.ktor.http.content.*
import io.ktor.serialization.*
import io.ktor.util.*
import io.ktor.utils.io.*
import io.ktor.utils.io.charsets.*
import kotlin.reflect.*

internal expect val DefaultIgnoredTypes: Set<KClass<*>>

/**
* A plugin that serves two primary purposes:
Expand Down Expand Up @@ -95,9 +99,12 @@ public class ContentNegotiation internal constructor(
val registrations = plugin.registrations
registrations.forEach { context.accept(it.contentTypeToSend) }

if (subject is OutgoingContent || DefaultIgnoredTypes.any { it.isInstance(payload) }) {
return@intercept
}
val contentType = context.contentType() ?: return@intercept

if (payload is Unit || payload is EmptyContent) {
if (payload is Unit) {
context.headers.remove(HttpHeaders.ContentType)
proceedWith(EmptyContent)
return@intercept
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2014-2021 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
* Copyright 2014-2022 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
*/

package io.ktor.client.plugins
Expand Down Expand Up @@ -84,7 +84,7 @@ class ContentNegotiationTests {
}

@Test
fun testIgnoresByteReadChannel() {
fun testReceiveByteReadChannel() {
val contentType = ContentType("testing", "a")
testWithEngine(MockEngine) {
setupWithContentNegotiation {
Expand All @@ -103,6 +103,33 @@ class ContentNegotiationTests {
}
}

@Test
fun testSendByteReadChannel() = testWithEngine(MockEngine) {
config {
install(ContentNegotiation) {
register(ContentType.Application.Json, TestContentConverter()) {
deserializeFn = { _, _, _ -> fail() }
serializeFn = { _, _, _, _ -> fail() }
}
}
engine {
addHandler {
val text = (it.body as OutgoingContent.ReadChannelContent).readFrom().readRemaining().readText()
respond(text)
}
}
}

test { client ->
val response = client.post("/post") {
val channel = ByteReadChannel("""{"x": 123}""".toByteArray())
contentType(ContentType.Application.Json)
setBody(channel)
}.bodyAsText()
assertEquals("""{"x": 123}""", response)
}
}

@Test
fun replaceContentTypeInRequestPipeline(): Unit = testWithEngine(MockEngine) {
val bodyContentType = ContentType("testing", "a")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
/*
* Copyright 2014-2022 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
*/

package io.ktor.client.plugins.contentnegotiation

import io.ktor.http.content.*
import io.ktor.utils.io.*
import java.io.*
import kotlin.reflect.*

internal actual val DefaultIgnoredTypes: Set<KClass<*>> =
mutableSetOf(OutgoingContent::class, ByteReadChannel::class, ByteArray::class)
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
/*
* Copyright 2014-2022 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
*/

package io.ktor.client.plugins.contentnegotiation

import io.ktor.http.content.*
import io.ktor.utils.io.*
import java.io.*
import kotlin.reflect.*

internal actual val DefaultIgnoredTypes: Set<KClass<*>> =
mutableSetOf(OutgoingContent::class, ByteReadChannel::class, InputStream::class, ByteArray::class)
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
/*
* Copyright 2014-2022 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
*/

package io.ktor.client.plugins.contentnegotiation

import io.ktor.http.content.*
import io.ktor.utils.io.*
import java.io.*
import kotlin.reflect.*

internal actual val DefaultIgnoredTypes: Set<KClass<*>> =
mutableSetOf(OutgoingContent::class, ByteReadChannel::class, ByteArray::class)
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
* Copyright 2014-2022 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
*/

package io.ktor.client.tests

import io.ktor.client.call.*
import io.ktor.client.engine.mock.*
import io.ktor.client.request.*
import io.ktor.client.statement.*
import io.ktor.client.tests.utils.*
import io.ktor.http.*
import io.ktor.http.content.*
import io.ktor.utils.io.*
import java.io.*
import kotlin.test.*

/*
* Copyright 2014-2022 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
*/

class DefaultTransformTest {

@Test
fun testSendInputStream() = testWithEngine(MockEngine) {
config {
engine {
addHandler {
val text = (it.body as OutgoingContent.ReadChannelContent).readFrom().readRemaining().readText()
respond(text)
}
}
}

test { client ->
val response = client.post("/post") {
val stream = ByteArrayInputStream("""{"x": 123}""".toByteArray())
contentType(ContentType.Application.Json)
setBody(stream)
}.bodyAsText()
assertEquals("""{"x": 123}""", response)
}
}

@Test
fun testReceiveInputStream() = testWithEngine(MockEngine) {
config {
engine {
addHandler {
respond("""{"x": 123}""")
}
}
}

test { client ->
val response = client.get("/").body<InputStream>()
assertEquals("""{"x": 123}""", response.bufferedReader().readText())
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ package io.ktor.server.http.content

import io.ktor.http.content.*
import io.ktor.server.application.*
import io.ktor.utils.io.*
import io.ktor.utils.io.jvm.javaio.*
import java.io.*

/**
Expand All @@ -21,5 +23,8 @@ internal actual fun platformTransformDefaultContent(
else -> null
}
}
is InputStream -> object : OutgoingContent.ReadChannelContent() {
override fun readFrom(): ByteReadChannel = value.toByteReadChannel()
}
else -> null
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class MDCProviderTest {

@Test
fun testLogErrorWithEmptyApplication() = testApplication {
val environment = createTestEnvironment { }
val environment = createTestEnvironment { }
val application = Application(environment)
assertNotNull(application.mdcProvider)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
/*
* Copyright 2014-2022 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
*/

package io.ktor.server.plugins.contentnegotiation

import io.ktor.http.*
import io.ktor.utils.io.*
import java.io.*
import kotlin.reflect.*

internal actual val DefaultIgnoredTypes: Set<KClass<*>> =
mutableSetOf(HttpStatusCode::class, ByteReadChannel::class, InputStream::class)
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import io.ktor.serialization.*
import io.ktor.util.*
import kotlin.reflect.*

internal expect val DefaultIgnoredTypes: Set<KClass<*>>

/**
* A configuration for the [ContentNegotiation] plugin.
*/
Expand All @@ -18,7 +20,7 @@ public class ContentNegotiationConfig : Configuration {
internal val acceptContributors = mutableListOf<AcceptHeaderContributor>()

@PublishedApi
internal val ignoredTypes: MutableSet<KClass<*>> = mutableSetOf(HttpStatusCode::class)
internal val ignoredTypes: MutableSet<KClass<*>> = DefaultIgnoredTypes.toMutableSet()

/**
* Checks that the `ContentType` header value of a response suits the `Accept` header value of a request.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ import io.ktor.utils.io.charsets.*
private val NOT_ACCEPTABLE = HttpStatusCodeContent(HttpStatusCode.NotAcceptable)

internal fun PluginBuilder<ContentNegotiationConfig>.convertResponseBody() = onCallRespond { call, subject ->
if (subject is OutgoingContent || subject::class in pluginConfig.ignoredTypes) return@onCallRespond
if (subject is OutgoingContent || pluginConfig.ignoredTypes.any { it.isInstance(subject) }) {
return@onCallRespond
}
if (call.response.responseType == null) return@onCallRespond

val registrations = pluginConfig.registrations
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
/*
* Copyright 2014-2022 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
*/

package io.ktor.server.plugins.contentnegotiation

import io.ktor.http.*
import io.ktor.utils.io.*
import kotlin.reflect.*

internal actual val DefaultIgnoredTypes: Set<KClass<*>> =
mutableSetOf(HttpStatusCode::class, ByteReadChannel::class)
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package io.ktor.server.plugins

import io.ktor.client.call.*
import io.ktor.client.request.*
import io.ktor.client.statement.*
import io.ktor.http.*
import io.ktor.http.content.*
import io.ktor.serialization.*
Expand Down Expand Up @@ -105,4 +106,20 @@ class ContentNegotiationTest {
val response = client.get("/").body<ByteArray>()
assertContentEquals("test".toByteArray(), response)
}

@Test
fun testRespondInputStream() = testApplication {
application {
routing {
install(ContentNegotiation) {
register(ContentType.Application.Json, alwaysFailingConverter)
}
get("/") {
call.respond(ByteArrayInputStream("""{"x": 123}""".toByteArray()))
}
}
}
val response = client.get("/").bodyAsText()
assertEquals("""{"x": 123}""", response)
}
}

0 comments on commit 084529f

Please sign in to comment.