Skip to content

Commit

Permalink
Properly recover exceptions when they are constructed from 'Throwable… (
Browse files Browse the repository at this point in the history
#3731)

* Properly recover exceptions when they are constructed from 'Throwable(cause)' constructor.

And restore the original behaviour. After #1631 this constructor's recovery mechanism was broken because 'Throwable(cause)' changes the message to be equal to 'cause.toString()', which isn't equal to the original message.

Also, make reflective constructor choice undependable from source-code order

Fixes #3714
  • Loading branch information
qwwdfsad committed May 3, 2023
1 parent 298419f commit 25a3553
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 51 deletions.
2 changes: 1 addition & 1 deletion .idea/codeStyles/Project.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

68 changes: 37 additions & 31 deletions kotlinx-coroutines-core/jvm/src/internal/ExceptionsConstructor.kt
Expand Up @@ -32,42 +32,48 @@ internal fun <E : Throwable> tryCopyException(exception: E): E? {

private fun <E : Throwable> createConstructor(clz: Class<E>): Ctor {
val nullResult: Ctor = { null } // Pre-cache class
// Skip reflective copy if an exception has additional fields (that are usually populated in user-defined constructors)
// Skip reflective copy if an exception has additional fields (that are typically populated in user-defined constructors)
if (throwableFields != clz.fieldsCountOrDefault(0)) return nullResult
/*
* Try to reflectively find constructor(), constructor(message, cause), constructor(cause) or constructor(message).
* Exceptions are shared among coroutines, so we should copy exception before recovering current stacktrace.
*/
val constructors = clz.constructors.sortedByDescending { it.parameterTypes.size }
for (constructor in constructors) {
val result = createSafeConstructor(constructor)
if (result != null) return result
}
return nullResult
}

private fun createSafeConstructor(constructor: Constructor<*>): Ctor? {
val p = constructor.parameterTypes
return when (p.size) {
2 -> when {
p[0] == String::class.java && p[1] == Throwable::class.java ->
safeCtor { e -> constructor.newInstance(e.message, e) as Throwable }
else -> null
* Try to reflectively find constructor(message, cause), constructor(message), constructor(cause), or constructor(),
* in that order of priority.
* Exceptions are shared among coroutines, so we should copy exception before recovering current stacktrace.
*
* By default, Java's reflection iterates over ctors in the source-code order and the sorting is stable, so we can
* not rely on the order of iteration. Instead, we assign a unique priority to each ctor type.
*/
return clz.constructors.map { constructor ->
val p = constructor.parameterTypes
when (p.size) {
2 -> when {
p[0] == String::class.java && p[1] == Throwable::class.java ->
safeCtor { e -> constructor.newInstance(e.message, e) as Throwable } to 3
else -> null to -1
}
1 -> when (p[0]) {
String::class.java ->
safeCtor { e -> (constructor.newInstance(e.message) as Throwable).also { it.initCause(e) } } to 2
Throwable::class.java ->
safeCtor { e -> constructor.newInstance(e) as Throwable } to 1
else -> null to -1
}
0 -> safeCtor { e -> (constructor.newInstance() as Throwable).also { it.initCause(e) } } to 0
else -> null to -1
}
1 -> when (p[0]) {
Throwable::class.java ->
safeCtor { e -> constructor.newInstance(e) as Throwable }
String::class.java ->
safeCtor { e -> (constructor.newInstance(e.message) as Throwable).also { it.initCause(e) } }
else -> null
}
0 -> safeCtor { e -> (constructor.newInstance() as Throwable).also { it.initCause(e) } }
else -> null
}
}.maxByOrNull(Pair<*, Int>::second)?.first ?: nullResult
}

private inline fun safeCtor(crossinline block: (Throwable) -> Throwable): Ctor =
{ e -> runCatching { block(e) }.getOrNull() }
private fun safeCtor(block: (Throwable) -> Throwable): Ctor = { e ->
runCatching {
val result = block(e)
/*
* Verify that the new exception has the same message as the original one (bail out if not, see #1631)
* or if the new message complies the contract from `Throwable(cause).message` contract.
*/
if (e.message != result.message && result.message != e.toString()) null
else result
}.getOrNull()
}

private fun Class<*>.fieldsCountOrDefault(defaultValue: Int) =
kotlin.runCatching { fieldsCount() }.getOrDefault(defaultValue)
Expand Down
21 changes: 6 additions & 15 deletions kotlinx-coroutines-core/jvm/src/internal/StackTraceRecovery.kt
Expand Up @@ -33,16 +33,16 @@ private val stackTraceRecoveryClassName = runCatching {
internal actual fun <E : Throwable> recoverStackTrace(exception: E): E {
if (!RECOVER_STACK_TRACES) return exception
// No unwrapping on continuation-less path: exception is not reported multiple times via slow paths
val copy = tryCopyAndVerify(exception) ?: return exception
val copy = tryCopyException(exception) ?: return exception
return copy.sanitizeStackTrace()
}

private fun <E : Throwable> E.sanitizeStackTrace(): E {
val stackTrace = stackTrace
val size = stackTrace.size
val lastIntrinsic = stackTrace.frameIndex(stackTraceRecoveryClassName)
val lastIntrinsic = stackTrace.indexOfLast { stackTraceRecoveryClassName == it.className }
val startIndex = lastIntrinsic + 1
val endIndex = stackTrace.frameIndex(baseContinuationImplClassName)
val endIndex = stackTrace.firstFrameIndex(baseContinuationImplClassName)
val adjustment = if (endIndex == -1) 0 else size - endIndex
val trace = Array(size - lastIntrinsic - adjustment) {
if (it == 0) {
Expand Down Expand Up @@ -70,7 +70,7 @@ private fun <E : Throwable> recoverFromStackFrame(exception: E, continuation: Co
val (cause, recoveredStacktrace) = exception.causeAndStacktrace()

// Try to create an exception of the same type and get stacktrace from continuation
val newException = tryCopyAndVerify(cause) ?: return exception
val newException = tryCopyException(cause) ?: return exception
// Update stacktrace
val stacktrace = createStackTrace(continuation)
if (stacktrace.isEmpty()) return exception
Expand All @@ -82,14 +82,6 @@ private fun <E : Throwable> recoverFromStackFrame(exception: E, continuation: Co
return createFinalException(cause, newException, stacktrace)
}

private fun <E : Throwable> tryCopyAndVerify(exception: E): E? {
val newException = tryCopyException(exception) ?: return null
// Verify that the new exception has the same message as the original one (bail out if not, see #1631)
// CopyableThrowable has control over its message and thus can modify it the way it wants
if (exception !is CopyableThrowable<*> && newException.message != exception.message) return null
return newException
}

/*
* Here we partially copy original exception stackTrace to make current one much prettier.
* E.g. for
Expand All @@ -109,7 +101,7 @@ private fun <E : Throwable> tryCopyAndVerify(exception: E): E? {
private fun <E : Throwable> createFinalException(cause: E, result: E, resultStackTrace: ArrayDeque<StackTraceElement>): E {
resultStackTrace.addFirst(ARTIFICIAL_FRAME)
val causeTrace = cause.stackTrace
val size = causeTrace.frameIndex(baseContinuationImplClassName)
val size = causeTrace.firstFrameIndex(baseContinuationImplClassName)
if (size == -1) {
result.stackTrace = resultStackTrace.toTypedArray()
return result
Expand Down Expand Up @@ -157,7 +149,6 @@ private fun mergeRecoveredTraces(recoveredStacktrace: Array<StackTraceElement>,
}
}

@Suppress("NOTHING_TO_INLINE")
internal actual suspend inline fun recoverAndThrow(exception: Throwable): Nothing {
if (!RECOVER_STACK_TRACES) throw exception
suspendCoroutineUninterceptedOrReturn<Nothing> {
Expand Down Expand Up @@ -198,7 +189,7 @@ private fun createStackTrace(continuation: CoroutineStackFrame): ArrayDeque<Stac
}

internal fun StackTraceElement.isArtificial() = className.startsWith(ARTIFICIAL_FRAME_PACKAGE_NAME)
private fun Array<StackTraceElement>.frameIndex(methodName: String) = indexOfFirst { methodName == it.className }
private fun Array<StackTraceElement>.firstFrameIndex(methodName: String) = indexOfFirst { methodName == it.className }

private fun StackTraceElement.elementWiseEquals(e: StackTraceElement): Boolean {
/*
Expand Down
@@ -1,10 +1,9 @@
kotlinx.coroutines.RecoverableTestException
at kotlinx.coroutines.internal.StackTraceRecoveryKt.recoverStackTrace(StackTraceRecovery.kt)
at kotlinx.coroutines.channels.BufferedChannel.receive$suspendImpl(BufferedChannel.kt)
at kotlinx.coroutines.channels.BufferedChannel.receive(BufferedChannel.kt)
at kotlinx.coroutines.exceptions.StackTraceRecoveryChannelsTest.channelReceive(StackTraceRecoveryChannelsTest.kt)
at kotlinx.coroutines.exceptions.StackTraceRecoveryChannelsTest.access$channelReceive(StackTraceRecoveryChannelsTest.kt)
at kotlinx.coroutines.exceptions.StackTraceRecoveryChannelsTest$channelReceive$1.invokeSuspend(StackTraceRecoveryChannelsTest.kt)
Caused by: kotlinx.coroutines.RecoverableTestException
at kotlinx.coroutines.exceptions.StackTraceRecoveryChannelsTest$testReceiveFromChannel$1$job$1.invokeSuspend(StackTraceRecoveryChannelsTest.kt)
at kotlin.coroutines.jvm.internal.BaseContinuationImpl.resumeWith(ContinuationImpl.kt)
at kotlin.coroutines.jvm.internal.BaseContinuationImpl.resumeWith(ContinuationImpl.kt)
@@ -1,10 +1,9 @@
kotlinx.coroutines.RecoverableTestException
at kotlinx.coroutines.internal.StackTraceRecoveryKt.recoverStackTrace(StackTraceRecovery.kt)
at kotlinx.coroutines.channels.BufferedChannel.receive$suspendImpl(BufferedChannel.kt)
at kotlinx.coroutines.channels.BufferedChannel.receive(BufferedChannel.kt)
at kotlinx.coroutines.exceptions.StackTraceRecoveryResumeModeTest$withContext$2.invokeSuspend(StackTraceRecoveryResumeModeTest.kt)
Caused by: kotlinx.coroutines.RecoverableTestException
at kotlinx.coroutines.exceptions.StackTraceRecoveryResumeModeTest.testResumeModeFastPath(StackTraceRecoveryResumeModeTest.kt)
at kotlinx.coroutines.exceptions.StackTraceRecoveryResumeModeTest.access$testResumeModeFastPath(StackTraceRecoveryResumeModeTest.kt)
at kotlinx.coroutines.exceptions.StackTraceRecoveryResumeModeTest$testUnconfined$1.invokeSuspend(StackTraceRecoveryResumeModeTest.kt)
at kotlin.coroutines.jvm.internal.BaseContinuationImpl.resumeWith(ContinuationImpl.kt)
at kotlin.coroutines.jvm.internal.BaseContinuationImpl.resumeWith(ContinuationImpl.kt)
Expand Up @@ -87,6 +87,26 @@ class StackTraceRecoveryCustomExceptionsTest : TestBase() {
assertEquals("Token OK", ex.message)
}

@Test
fun testNestedExceptionWithCause() = runTest {
val result = runCatching {
coroutineScope<Unit> {
throw NestedException(IllegalStateException("ERROR"))
}
}
val ex = result.exceptionOrNull() ?: error("Expected to fail")
assertIs<NestedException>(ex)
assertIs<NestedException>(ex.cause)
val originalCause = ex.cause?.cause
assertIs<IllegalStateException>(originalCause)
assertEquals("ERROR", originalCause.message)
}

class NestedException : RuntimeException {
constructor(cause: Throwable) : super(cause)
constructor() : super()
}

@Test
fun testWrongMessageExceptionInChannel() = runTest {
val result = produce<Unit>(SupervisorJob() + Dispatchers.Unconfined) {
Expand Down

0 comments on commit 25a3553

Please sign in to comment.