Skip to content

Commit

Permalink
KTOR-4164 Fix ClassCastException when development mode is on
Browse files Browse the repository at this point in the history
  • Loading branch information
rsinukov committed Jun 27, 2022
1 parent 7c32f4e commit a85f88e
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 1 deletion.
1 change: 1 addition & 0 deletions ktor-server/ktor-server-host-common/build.gradle.kts
Expand Up @@ -11,6 +11,7 @@ kotlin.sourceSets {

jvmTest {
dependencies {
implementation(project(":ktor-server:ktor-server-cio"))
implementation(project(":ktor-server:ktor-server-test-host"))
implementation(project(":ktor-server:ktor-server-test-suites"))
api(project(":ktor-server:ktor-server-core", configuration = "testOutput"))
Expand Down
Expand Up @@ -36,11 +36,16 @@ public class ApplicationEngineEnvironmentReloading(
override val connectors: List<EngineConnectorConfig>,
internal val modules: List<Application.() -> Unit>,
internal val watchPaths: List<String> = emptyList(),
override val parentCoroutineContext: CoroutineContext = EmptyCoroutineContext,
parentCoroutineContext: CoroutineContext = EmptyCoroutineContext,
override val rootPath: String = "",
override val developmentMode: Boolean = true
) : ApplicationEngineEnvironment {

override val parentCoroutineContext: CoroutineContext = when {
developmentMode -> parentCoroutineContext + ClassLoaderAwareContinuationInterceptor
else -> parentCoroutineContext
}

public constructor(
classLoader: ClassLoader,
log: Logger,
Expand Down Expand Up @@ -365,3 +370,20 @@ public class ApplicationEngineEnvironmentReloading(

public companion object
}

private object ClassLoaderAwareContinuationInterceptor : ContinuationInterceptor {
override val key: CoroutineContext.Key<*> =
object : CoroutineContext.Key<ClassLoaderAwareContinuationInterceptor> {}

override fun <T> interceptContinuation(continuation: Continuation<T>): Continuation<T> {
val classLoader = Thread.currentThread().contextClassLoader
return object : Continuation<T> {
override val context: CoroutineContext = continuation.context

override fun resumeWith(result: Result<T>) {
Thread.currentThread().contextClassLoader = classLoader
continuation.resumeWith(result)
}
}
}
}
Expand Up @@ -9,6 +9,7 @@ package io.ktor.tests.hosts
import com.typesafe.config.*
import io.ktor.http.*
import io.ktor.server.application.*
import io.ktor.server.cio.*
import io.ktor.server.config.*
import io.ktor.server.engine.*
import io.ktor.server.response.*
Expand All @@ -17,6 +18,8 @@ import io.ktor.server.testing.*
import io.ktor.util.*
import kotlinx.coroutines.*
import org.slf4j.helpers.*
import java.io.*
import kotlin.coroutines.*
import kotlin.reflect.*
import kotlin.reflect.jvm.*
import kotlin.test.*
Expand Down Expand Up @@ -45,6 +48,38 @@ class ApplicationEngineEnvironmentReloadingTests {
environment.stop()
}

@Test
fun `test class loader in module function with launch`() {
var error: Throwable? = null
val exceptionHandler: CoroutineContext = object : CoroutineExceptionHandler {
override val key: CoroutineContext.Key<*> = CoroutineExceptionHandler.Key
override fun handleException(context: CoroutineContext, exception: Throwable) {
error = exception
}
}
val server = embeddedServer(CIO, applicationEngineEnvironment {
parentCoroutineContext = exceptionHandler
developmentMode = true
module {
launch {
val byteArrayInputStream = ByteArrayOutputStream()
val objectOutputStream = ObjectOutputStream(byteArrayInputStream)
objectOutputStream.writeObject(TestClass(123))
objectOutputStream.flush()
objectOutputStream.close()

val ois = TestObjectInputStream(ByteArrayInputStream(byteArrayInputStream.toByteArray()))
val test = ois.readObject()
test as TestClass
}
}
}).start(false)

Thread.sleep(3000)
assertNull(error)
server.stop()
}

@Test
fun `top level extension function as module function reloading stress`() {
val environment = applicationEngineEnvironment {
Expand Down Expand Up @@ -546,3 +581,18 @@ fun Application.topLevelWithDefaultArg(testing: Boolean = false) {
fun Application.topLevelWithJvmOverloads(testing: Boolean = false) {
attributes.put(ApplicationEngineEnvironmentReloadingTests.TestKey, "topLevelWithJvmOverloads")
}

class TestClass(val value: Int) : Serializable

class TestObjectInputStream(input: InputStream) : ObjectInputStream(input) {
override fun resolveClass(desc: ObjectStreamClass?): Class<*> {
val name = desc?.name
val loader = Thread.currentThread().contextClassLoader

return try {
Class.forName(name, false, loader)
} catch (e: ClassNotFoundException) {
super.resolveClass(desc)
}
}
}

0 comments on commit a85f88e

Please sign in to comment.