Skip to content

Commit

Permalink
KTOR-6909 Add SocketTimeout for Test Engine (#4021)
Browse files Browse the repository at this point in the history
  • Loading branch information
marychatte committed Apr 26, 2024
1 parent 5f27f30 commit e4c4c1c
Show file tree
Hide file tree
Showing 6 changed files with 213 additions and 44 deletions.
Expand Up @@ -4,16 +4,21 @@

package io.ktor.tests.server.testing

import io.ktor.client.network.sockets.*
import io.ktor.client.plugins.*
import io.ktor.client.plugins.websocket.*
import io.ktor.client.request.*
import io.ktor.client.statement.*
import io.ktor.http.*
import io.ktor.http.content.*
import io.ktor.server.application.*
import io.ktor.server.config.*
import io.ktor.server.response.*
import io.ktor.server.routing.*
import io.ktor.server.testing.*
import io.ktor.server.websocket.*
import io.ktor.util.*
import io.ktor.utils.io.*
import io.ktor.websocket.*
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.*
Expand Down Expand Up @@ -281,6 +286,54 @@ class TestApplicationTestJvm {
}
assertEquals("WebSocket connection failed", error.message)
}

private fun testSocketTimeoutWrite(timeout: Long, expectException: Boolean) = testApplication {
routing {
post {
call.respond(HttpStatusCode.OK, call.request.receiveChannel().readRemaining().toString())
}
}

val clientWithTimeout = createClient {
install(HttpTimeout) {
socketTimeoutMillis = timeout
}
}

val body = object : OutgoingContent.WriteChannelContent() {
override suspend fun writeTo(channel: ByteWriteChannel) {
channel.writeAvailable("Hello".toByteArray())
channel.flush()
delay(300)
channel.writeAvailable("World".toByteArray())
channel.flush()
}
}

if (expectException) {
assertFailsWith<SocketTimeoutException> {
clientWithTimeout.post("/") {
setBody(body)
}
}
} else {
clientWithTimeout.post("/") {
setBody(body)
}.apply {
assertEquals(HttpStatusCode.OK, status)
}
}
}

@Test
fun testSocketTimeoutWriteElapsed() {
testSocketTimeoutWrite(100, true)
}

@Test
fun testSocketTimeoutWriteNotElapsed() {
testSocketTimeoutWrite(1000, false)
}
}

class TestClass(val value: Int) : Serializable
Expand Down
Expand Up @@ -6,6 +6,7 @@ package io.ktor.server.testing

import io.ktor.client.*
import io.ktor.client.engine.*
import io.ktor.client.plugins.*
import io.ktor.events.*
import io.ktor.http.*
import io.ktor.server.application.*
Expand Down Expand Up @@ -200,14 +201,15 @@ public class TestApplicationEngine(
setup: TestApplicationRequest.() -> Unit
): TestApplicationCall {
val callJob = GlobalScope.async(coroutineContext) {
handleRequestNonBlocking(closeRequest, setup)
handleRequestNonBlocking(closeRequest, timeoutAttributes = null, setup)
}

return runBlocking { callJob.await() }
}

internal suspend fun handleRequestNonBlocking(
closeRequest: Boolean = true,
timeoutAttributes: HttpTimeoutConfig? = null,
setup: TestApplicationRequest.() -> Unit
): TestApplicationCall {
val job = Job(testEngineJob)
Expand All @@ -217,6 +219,9 @@ public class TestApplicationEngine(
setup = { processRequest(setup) },
context = Dispatchers.IOBridge + job
)
if (timeoutAttributes != null) {
call.attributes.put(timeoutAttributesKey, timeoutAttributes)
}

val context = SupervisorJob(job) + CoroutineName("request")
withContext(coroutineContext + context) {
Expand Down Expand Up @@ -303,3 +308,5 @@ public fun TestApplicationEngine.cookiesSession(callback: () -> Unit) {
callback()
}
}

internal val timeoutAttributesKey = AttributeKey<HttpTimeoutConfig>("TimeoutAttributes")
Expand Up @@ -25,6 +25,9 @@ public class TestApplicationResponse(
call: TestApplicationCall,
private val readResponse: Boolean = false
) : BaseApplicationResponse(call), CoroutineScope by call {
private val scope: CoroutineScope get() = this

private val timeoutAttributes get() = call.attributes.getOrNull(timeoutAttributesKey)

/**
* Gets a response body text content. Could be blocking. Remains `null` until response appears.
Expand Down Expand Up @@ -75,16 +78,19 @@ public class TestApplicationResponse(
}

@Suppress("DEPRECATION")
@OptIn(DelicateCoroutinesApi::class)
override suspend fun responseChannel(): ByteWriteChannel {
val result = ByteChannel(autoFlush = true)

if (readResponse) {
launchResponseJob(result)
}

val job = GlobalScope.reader(responseJob ?: EmptyCoroutineContext) {
channel.copyAndClose(result, Long.MAX_VALUE)
val job = scope.reader(responseJob ?: EmptyCoroutineContext) {
val readJob = launch {
channel.copyAndClose(result, Long.MAX_VALUE)
}

configureSocketTimeoutIfNeeded(timeoutAttributes, readJob) { channel.totalBytesRead }
}

if (responseJob == null) {
Expand Down
Expand Up @@ -4,7 +4,12 @@

package io.ktor.server.testing

import io.ktor.client.plugins.*
import io.ktor.client.request.*
import io.ktor.http.*
import io.ktor.util.*
import io.ktor.utils.io.*
import kotlinx.coroutines.*

/**
* [on] function receiver object
Expand Down Expand Up @@ -35,3 +40,44 @@ public fun TestApplicationResponse.contentType(): ContentType {
val contentTypeHeader = requireNotNull(headers[HttpHeaders.ContentType])
return ContentType.parse(contentTypeHeader)
}

internal fun CoroutineScope.configureSocketTimeoutIfNeeded(
timeoutAttributes: HttpTimeoutConfig?,
job: Job,
extract: () -> Long
) {
val socketTimeoutMillis = timeoutAttributes?.socketTimeoutMillis
if (socketTimeoutMillis != null) {
socketTimeoutKiller(socketTimeoutMillis, job, extract)
}
}

internal fun CoroutineScope.socketTimeoutKiller(socketTimeoutMillis: Long, job: Job, extract: () -> Long) {
val killJob = launch {
var cur = extract()
while (job.isActive) {
delay(socketTimeoutMillis)
val next = extract()
if (cur == next) {
throw io.ktor.network.sockets.SocketTimeoutException("Socket timeout elapsed")
}
cur = next
}
}
job.invokeOnCompletion {
killJob.cancel()
}
}

@OptIn(InternalAPI::class)
internal fun Throwable.mapToKtor(data: HttpRequestData): Throwable {
return when {
this is io.ktor.network.sockets.SocketTimeoutException -> SocketTimeoutException(data, this)
cause?.rootCause is io.ktor.network.sockets.SocketTimeoutException -> SocketTimeoutException(
data,
cause?.rootCause
)

else -> this
}
}
Expand Up @@ -6,6 +6,7 @@ package io.ktor.server.testing.client

import io.ktor.client.call.*
import io.ktor.client.engine.*
import io.ktor.client.plugins.*
import io.ktor.client.request.*
import io.ktor.http.*
import io.ktor.http.content.*
Expand Down Expand Up @@ -41,54 +42,59 @@ public class TestHttpClientEngine(override val config: TestHttpClientConfig) : H
@OptIn(InternalAPI::class)
override suspend fun execute(data: HttpRequestData): HttpResponseData {
val callContext = callContext()
if (data.isUpgradeRequest()) {
val (testServerCall, session) = with(data) {
bridge.runWebSocketRequest(url.fullPath, headers, body, callContext)
try {
if (data.isUpgradeRequest()) {
val (testServerCall, session) = with(data) {
bridge.runWebSocketRequest(url.fullPath, headers, body, callContext)
}
return with(testServerCall.response) {
httpResponseData(session)
}
}
return with(testServerCall.response) {
httpResponseData(session)
}
}

val testServerCall = with(data) {
runRequest(method, url, headers, body, url.protocol)
val testServerCall = with(data) {
runRequest(method, url, headers, body, url.protocol, data.getCapabilityOrNull(HttpTimeoutCapability))
}
val response = testServerCall.response
val status = response.statusOrNotFound()
val headers = response.headers.allValues().takeUnless { it.isEmpty() } ?: Headers
.build { append(HttpHeaders.ContentLength, "0") }
val body = ByteReadChannel(response.byteContent ?: byteArrayOf())

val responseBody: Any = data.attributes.getOrNull(ResponseAdapterAttributeKey)
?.adapt(data, status, headers, body, data.body, callContext)
?: body

return HttpResponseData(
status,
GMTDate(),
headers,
HttpProtocolVersion.HTTP_1_1,
responseBody,
callContext
)
} catch (cause: Throwable) {
throw cause.mapToKtor(data)
}
val response = testServerCall.response
val status = response.statusOrNotFound()
val headers = response.headers.allValues().takeUnless { it.isEmpty() } ?: Headers
.build { append(HttpHeaders.ContentLength, "0") }
val body = ByteReadChannel(response.byteContent ?: byteArrayOf())

val responseBody: Any = data.attributes.getOrNull(ResponseAdapterAttributeKey)
?.adapt(data, status, headers, body, data.body, callContext)
?: body

return HttpResponseData(
status,
GMTDate(),
headers,
HttpProtocolVersion.HTTP_1_1,
responseBody,
callContext
)
}

private suspend fun runRequest(
method: HttpMethod,
url: Url,
headers: Headers,
content: OutgoingContent,
protocol: URLProtocol
protocol: URLProtocol,
timeoutAttributes: HttpTimeoutConfig? = null
): TestApplicationCall {
return app.handleRequestNonBlocking {
return app.handleRequestNonBlocking(timeoutAttributes = timeoutAttributes) {
this.uri = url.fullPath
this.port = url.port
this.method = method
appendRequestHeaders(headers, content)
this.protocol = protocol.name

if (content !is OutgoingContent.NoContent) {
bodyChannel = content.toByteReadChannel()
bodyChannel = content.toByteReadChannel(timeoutAttributes)
}
}
}
Expand Down Expand Up @@ -125,16 +131,22 @@ public class TestHttpClientEngine(override val config: TestHttpClientConfig) : H
}
}

private fun OutgoingContent.toByteReadChannel(): ByteReadChannel = when (this) {
is OutgoingContent.NoContent -> ByteReadChannel.Empty
is OutgoingContent.ByteArrayContent -> ByteReadChannel(bytes())
is OutgoingContent.ReadChannelContent -> readFrom()
is OutgoingContent.WriteChannelContent -> writer(coroutineContext) {
writeTo(channel)
}.channel
is OutgoingContent.ContentWrapper -> delegate().toByteReadChannel()
is OutgoingContent.ProtocolUpgrade -> throw UnsupportedContentTypeException(this)
}
private fun OutgoingContent.toByteReadChannel(timeoutAttributes: HttpTimeoutConfig?): ByteReadChannel =
when (this) {
is OutgoingContent.NoContent -> ByteReadChannel.Empty
is OutgoingContent.ByteArrayContent -> ByteReadChannel(bytes())
is OutgoingContent.ReadChannelContent -> readFrom()
is OutgoingContent.WriteChannelContent -> writer(coroutineContext) {
val job = launch {
writeTo(channel)
}

configureSocketTimeoutIfNeeded(timeoutAttributes, job) { channel.totalBytesWritten }
}.channel

is OutgoingContent.ContentWrapper -> delegate().toByteReadChannel(timeoutAttributes)
is OutgoingContent.ProtocolUpgrade -> throw UnsupportedContentTypeException(this)
}

private fun TestApplicationResponse.statusOrNotFound() = status() ?: HttpStatusCode.NotFound
}

0 comments on commit e4c4c1c

Please sign in to comment.