diff --git a/ktor-server/ktor-server-test-host/jvm/test/TestApplicationTestJvm.kt b/ktor-server/ktor-server-test-host/jvm/test/TestApplicationTestJvm.kt index 59101d5d09..bc4fd51e1c 100644 --- a/ktor-server/ktor-server-test-host/jvm/test/TestApplicationTestJvm.kt +++ b/ktor-server/ktor-server-test-host/jvm/test/TestApplicationTestJvm.kt @@ -4,9 +4,13 @@ package io.ktor.tests.server.testing +import io.ktor.client.network.sockets.* +import io.ktor.client.plugins.* import io.ktor.client.plugins.websocket.* import io.ktor.client.request.* import io.ktor.client.statement.* +import io.ktor.http.* +import io.ktor.http.content.* import io.ktor.server.application.* import io.ktor.server.config.* import io.ktor.server.response.* @@ -14,6 +18,7 @@ import io.ktor.server.routing.* import io.ktor.server.testing.* import io.ktor.server.websocket.* import io.ktor.util.* +import io.ktor.utils.io.* import io.ktor.websocket.* import kotlinx.coroutines.* import kotlinx.coroutines.channels.* @@ -281,6 +286,54 @@ class TestApplicationTestJvm { } assertEquals("WebSocket connection failed", error.message) } + + private fun testSocketTimeoutWrite(timeout: Long, expectException: Boolean) = testApplication { + routing { + post { + call.respond(HttpStatusCode.OK, call.request.receiveChannel().readRemaining().toString()) + } + } + + val clientWithTimeout = createClient { + install(HttpTimeout) { + socketTimeoutMillis = timeout + } + } + + val body = object : OutgoingContent.WriteChannelContent() { + override suspend fun writeTo(channel: ByteWriteChannel) { + channel.writeAvailable("Hello".toByteArray()) + channel.flush() + delay(300) + channel.writeAvailable("World".toByteArray()) + channel.flush() + } + } + + if (expectException) { + assertFailsWith { + clientWithTimeout.post("/") { + setBody(body) + } + } + } else { + clientWithTimeout.post("/") { + setBody(body) + }.apply { + assertEquals(HttpStatusCode.OK, status) + } + } + } + + @Test + fun testSocketTimeoutWriteElapsed() { + testSocketTimeoutWrite(100, true) + } + + @Test + fun testSocketTimeoutWriteNotElapsed() { + testSocketTimeoutWrite(1000, false) + } } class TestClass(val value: Int) : Serializable diff --git a/ktor-server/ktor-server-test-host/jvmAndNix/src/io/ktor/server/testing/TestApplicationEngine.kt b/ktor-server/ktor-server-test-host/jvmAndNix/src/io/ktor/server/testing/TestApplicationEngine.kt index c58f8f8221..bb0a521202 100644 --- a/ktor-server/ktor-server-test-host/jvmAndNix/src/io/ktor/server/testing/TestApplicationEngine.kt +++ b/ktor-server/ktor-server-test-host/jvmAndNix/src/io/ktor/server/testing/TestApplicationEngine.kt @@ -6,6 +6,7 @@ package io.ktor.server.testing import io.ktor.client.* import io.ktor.client.engine.* +import io.ktor.client.plugins.* import io.ktor.events.* import io.ktor.http.* import io.ktor.server.application.* @@ -200,7 +201,7 @@ public class TestApplicationEngine( setup: TestApplicationRequest.() -> Unit ): TestApplicationCall { val callJob = GlobalScope.async(coroutineContext) { - handleRequestNonBlocking(closeRequest, setup) + handleRequestNonBlocking(closeRequest, timeoutAttributes = null, setup) } return runBlocking { callJob.await() } @@ -208,6 +209,7 @@ public class TestApplicationEngine( internal suspend fun handleRequestNonBlocking( closeRequest: Boolean = true, + timeoutAttributes: HttpTimeoutConfig? = null, setup: TestApplicationRequest.() -> Unit ): TestApplicationCall { val job = Job(testEngineJob) @@ -217,6 +219,9 @@ public class TestApplicationEngine( setup = { processRequest(setup) }, context = Dispatchers.IOBridge + job ) + if (timeoutAttributes != null) { + call.attributes.put(timeoutAttributesKey, timeoutAttributes) + } val context = SupervisorJob(job) + CoroutineName("request") withContext(coroutineContext + context) { @@ -303,3 +308,5 @@ public fun TestApplicationEngine.cookiesSession(callback: () -> Unit) { callback() } } + +internal val timeoutAttributesKey = AttributeKey("TimeoutAttributes") diff --git a/ktor-server/ktor-server-test-host/jvmAndNix/src/io/ktor/server/testing/TestApplicationResponse.kt b/ktor-server/ktor-server-test-host/jvmAndNix/src/io/ktor/server/testing/TestApplicationResponse.kt index afc0179248..61589a366d 100644 --- a/ktor-server/ktor-server-test-host/jvmAndNix/src/io/ktor/server/testing/TestApplicationResponse.kt +++ b/ktor-server/ktor-server-test-host/jvmAndNix/src/io/ktor/server/testing/TestApplicationResponse.kt @@ -25,6 +25,9 @@ public class TestApplicationResponse( call: TestApplicationCall, private val readResponse: Boolean = false ) : BaseApplicationResponse(call), CoroutineScope by call { + private val scope: CoroutineScope get() = this + + private val timeoutAttributes get() = call.attributes.getOrNull(timeoutAttributesKey) /** * Gets a response body text content. Could be blocking. Remains `null` until response appears. @@ -75,7 +78,6 @@ public class TestApplicationResponse( } @Suppress("DEPRECATION") - @OptIn(DelicateCoroutinesApi::class) override suspend fun responseChannel(): ByteWriteChannel { val result = ByteChannel(autoFlush = true) @@ -83,8 +85,12 @@ public class TestApplicationResponse( launchResponseJob(result) } - val job = GlobalScope.reader(responseJob ?: EmptyCoroutineContext) { - channel.copyAndClose(result, Long.MAX_VALUE) + val job = scope.reader(responseJob ?: EmptyCoroutineContext) { + val readJob = launch { + channel.copyAndClose(result, Long.MAX_VALUE) + } + + configureSocketTimeoutIfNeeded(timeoutAttributes, readJob) { channel.totalBytesRead } } if (responseJob == null) { diff --git a/ktor-server/ktor-server-test-host/jvmAndNix/src/io/ktor/server/testing/Utils.kt b/ktor-server/ktor-server-test-host/jvmAndNix/src/io/ktor/server/testing/Utils.kt index a249c3015d..d622a2409c 100644 --- a/ktor-server/ktor-server-test-host/jvmAndNix/src/io/ktor/server/testing/Utils.kt +++ b/ktor-server/ktor-server-test-host/jvmAndNix/src/io/ktor/server/testing/Utils.kt @@ -4,7 +4,12 @@ package io.ktor.server.testing +import io.ktor.client.plugins.* +import io.ktor.client.request.* import io.ktor.http.* +import io.ktor.util.* +import io.ktor.utils.io.* +import kotlinx.coroutines.* /** * [on] function receiver object @@ -35,3 +40,44 @@ public fun TestApplicationResponse.contentType(): ContentType { val contentTypeHeader = requireNotNull(headers[HttpHeaders.ContentType]) return ContentType.parse(contentTypeHeader) } + +internal fun CoroutineScope.configureSocketTimeoutIfNeeded( + timeoutAttributes: HttpTimeoutConfig?, + job: Job, + extract: () -> Long +) { + val socketTimeoutMillis = timeoutAttributes?.socketTimeoutMillis + if (socketTimeoutMillis != null) { + socketTimeoutKiller(socketTimeoutMillis, job, extract) + } +} + +internal fun CoroutineScope.socketTimeoutKiller(socketTimeoutMillis: Long, job: Job, extract: () -> Long) { + val killJob = launch { + var cur = extract() + while (job.isActive) { + delay(socketTimeoutMillis) + val next = extract() + if (cur == next) { + throw io.ktor.network.sockets.SocketTimeoutException("Socket timeout elapsed") + } + cur = next + } + } + job.invokeOnCompletion { + killJob.cancel() + } +} + +@OptIn(InternalAPI::class) +internal fun Throwable.mapToKtor(data: HttpRequestData): Throwable { + return when { + this is io.ktor.network.sockets.SocketTimeoutException -> SocketTimeoutException(data, this) + cause?.rootCause is io.ktor.network.sockets.SocketTimeoutException -> SocketTimeoutException( + data, + cause?.rootCause + ) + + else -> this + } +} diff --git a/ktor-server/ktor-server-test-host/jvmAndNix/src/io/ktor/server/testing/client/TestHttpClientEngine.kt b/ktor-server/ktor-server-test-host/jvmAndNix/src/io/ktor/server/testing/client/TestHttpClientEngine.kt index c7011f9404..38a2f24a69 100644 --- a/ktor-server/ktor-server-test-host/jvmAndNix/src/io/ktor/server/testing/client/TestHttpClientEngine.kt +++ b/ktor-server/ktor-server-test-host/jvmAndNix/src/io/ktor/server/testing/client/TestHttpClientEngine.kt @@ -6,6 +6,7 @@ package io.ktor.server.testing.client import io.ktor.client.call.* import io.ktor.client.engine.* +import io.ktor.client.plugins.* import io.ktor.client.request.* import io.ktor.http.* import io.ktor.http.content.* @@ -41,36 +42,40 @@ public class TestHttpClientEngine(override val config: TestHttpClientConfig) : H @OptIn(InternalAPI::class) override suspend fun execute(data: HttpRequestData): HttpResponseData { val callContext = callContext() - if (data.isUpgradeRequest()) { - val (testServerCall, session) = with(data) { - bridge.runWebSocketRequest(url.fullPath, headers, body, callContext) + try { + if (data.isUpgradeRequest()) { + val (testServerCall, session) = with(data) { + bridge.runWebSocketRequest(url.fullPath, headers, body, callContext) + } + return with(testServerCall.response) { + httpResponseData(session) + } } - return with(testServerCall.response) { - httpResponseData(session) - } - } - val testServerCall = with(data) { - runRequest(method, url, headers, body, url.protocol) + val testServerCall = with(data) { + runRequest(method, url, headers, body, url.protocol, data.getCapabilityOrNull(HttpTimeoutCapability)) + } + val response = testServerCall.response + val status = response.statusOrNotFound() + val headers = response.headers.allValues().takeUnless { it.isEmpty() } ?: Headers + .build { append(HttpHeaders.ContentLength, "0") } + val body = ByteReadChannel(response.byteContent ?: byteArrayOf()) + + val responseBody: Any = data.attributes.getOrNull(ResponseAdapterAttributeKey) + ?.adapt(data, status, headers, body, data.body, callContext) + ?: body + + return HttpResponseData( + status, + GMTDate(), + headers, + HttpProtocolVersion.HTTP_1_1, + responseBody, + callContext + ) + } catch (cause: Throwable) { + throw cause.mapToKtor(data) } - val response = testServerCall.response - val status = response.statusOrNotFound() - val headers = response.headers.allValues().takeUnless { it.isEmpty() } ?: Headers - .build { append(HttpHeaders.ContentLength, "0") } - val body = ByteReadChannel(response.byteContent ?: byteArrayOf()) - - val responseBody: Any = data.attributes.getOrNull(ResponseAdapterAttributeKey) - ?.adapt(data, status, headers, body, data.body, callContext) - ?: body - - return HttpResponseData( - status, - GMTDate(), - headers, - HttpProtocolVersion.HTTP_1_1, - responseBody, - callContext - ) } private suspend fun runRequest( @@ -78,9 +83,10 @@ public class TestHttpClientEngine(override val config: TestHttpClientConfig) : H url: Url, headers: Headers, content: OutgoingContent, - protocol: URLProtocol + protocol: URLProtocol, + timeoutAttributes: HttpTimeoutConfig? = null ): TestApplicationCall { - return app.handleRequestNonBlocking { + return app.handleRequestNonBlocking(timeoutAttributes = timeoutAttributes) { this.uri = url.fullPath this.port = url.port this.method = method @@ -88,7 +94,7 @@ public class TestHttpClientEngine(override val config: TestHttpClientConfig) : H this.protocol = protocol.name if (content !is OutgoingContent.NoContent) { - bodyChannel = content.toByteReadChannel() + bodyChannel = content.toByteReadChannel(timeoutAttributes) } } } @@ -125,16 +131,22 @@ public class TestHttpClientEngine(override val config: TestHttpClientConfig) : H } } - private fun OutgoingContent.toByteReadChannel(): ByteReadChannel = when (this) { - is OutgoingContent.NoContent -> ByteReadChannel.Empty - is OutgoingContent.ByteArrayContent -> ByteReadChannel(bytes()) - is OutgoingContent.ReadChannelContent -> readFrom() - is OutgoingContent.WriteChannelContent -> writer(coroutineContext) { - writeTo(channel) - }.channel - is OutgoingContent.ContentWrapper -> delegate().toByteReadChannel() - is OutgoingContent.ProtocolUpgrade -> throw UnsupportedContentTypeException(this) - } + private fun OutgoingContent.toByteReadChannel(timeoutAttributes: HttpTimeoutConfig?): ByteReadChannel = + when (this) { + is OutgoingContent.NoContent -> ByteReadChannel.Empty + is OutgoingContent.ByteArrayContent -> ByteReadChannel(bytes()) + is OutgoingContent.ReadChannelContent -> readFrom() + is OutgoingContent.WriteChannelContent -> writer(coroutineContext) { + val job = launch { + writeTo(channel) + } + + configureSocketTimeoutIfNeeded(timeoutAttributes, job) { channel.totalBytesWritten } + }.channel + + is OutgoingContent.ContentWrapper -> delegate().toByteReadChannel(timeoutAttributes) + is OutgoingContent.ProtocolUpgrade -> throw UnsupportedContentTypeException(this) + } private fun TestApplicationResponse.statusOrNotFound() = status() ?: HttpStatusCode.NotFound } diff --git a/ktor-server/ktor-server-test-host/jvmAndNix/test/TestApplicationTest.kt b/ktor-server/ktor-server-test-host/jvmAndNix/test/TestApplicationTest.kt index 55af0804e3..c6e75bc7f7 100644 --- a/ktor-server/ktor-server-test-host/jvmAndNix/test/TestApplicationTest.kt +++ b/ktor-server/ktor-server-test-host/jvmAndNix/test/TestApplicationTest.kt @@ -5,6 +5,7 @@ package io.ktor.tests.server.testing import io.ktor.client.* +import io.ktor.client.network.sockets.* import io.ktor.client.plugins.* import io.ktor.client.request.* import io.ktor.client.statement.* @@ -21,6 +22,7 @@ import io.ktor.server.testing.* import io.ktor.server.testing.client.* import io.ktor.util.* import io.ktor.utils.io.* +import io.ktor.utils.io.core.* import kotlinx.coroutines.* import kotlin.coroutines.* import kotlin.test.* @@ -395,6 +397,49 @@ class TestApplicationTest { } } + private fun testSocketTimeoutRead(timeout: Long, expectException: Boolean) = testApplication { + routing { + get { + call.respond( + HttpStatusCode.OK, + object : OutgoingContent.WriteChannelContent() { + override suspend fun writeTo(channel: ByteWriteChannel) { + channel.writeAvailable("Hello".toByteArray()) + channel.flush() + delay(300) + } + } + ) + } + } + + val clientWithTimeout = createClient { + install(HttpTimeout) { + socketTimeoutMillis = timeout + } + } + + if (expectException) { + assertFailsWith { + clientWithTimeout.get("/") + } + } else { + clientWithTimeout.get("/").apply { + assertEquals(HttpStatusCode.OK, status) + } + } + } + + @Test + fun testSocketTimeoutReadElapsed() { + testSocketTimeoutRead(100, true) + } + + @Test + fun testSocketTimeoutReadNotElapsed() { + testSocketTimeoutRead(1000, false) + } + class MyElement(val data: String) : CoroutineContext.Element { override val key: CoroutineContext.Key<*> get() = MyElement