diff --git a/arrow-libs/core/arrow-core/api/arrow-core.api b/arrow-libs/core/arrow-core/api/arrow-core.api index fc33b4f3d0d..911840b8a4f 100644 --- a/arrow-libs/core/arrow-core/api/arrow-core.api +++ b/arrow-libs/core/arrow-core/api/arrow-core.api @@ -2722,6 +2722,7 @@ public final class arrow/core/continuations/EffectScopeKt { public final class arrow/core/continuations/FoldContinuation : kotlin/coroutines/Continuation { public fun (Larrow/core/continuations/Token;Lkotlin/coroutines/CoroutineContext;Lkotlin/coroutines/Continuation;)V + public fun (Larrow/core/continuations/Token;Lkotlin/coroutines/CoroutineContext;Lkotlin/jvm/functions/Function2;Lkotlin/coroutines/Continuation;)V public fun getContext ()Lkotlin/coroutines/CoroutineContext; public fun resumeWith (Ljava/lang/Object;)V } diff --git a/arrow-libs/core/arrow-core/src/commonMain/kotlin/arrow/core/continuations/Effect.kt b/arrow-libs/core/arrow-core/src/commonMain/kotlin/arrow/core/continuations/Effect.kt index a66f6563786..73e00375f19 100644 --- a/arrow-libs/core/arrow-core/src/commonMain/kotlin/arrow/core/continuations/Effect.kt +++ b/arrow-libs/core/arrow-core/src/commonMain/kotlin/arrow/core/continuations/Effect.kt @@ -11,10 +11,11 @@ import arrow.core.nonFatalOrThrow import kotlin.coroutines.Continuation import kotlin.coroutines.CoroutineContext import kotlin.coroutines.cancellation.CancellationException -import kotlin.coroutines.intrinsics.createCoroutineUnintercepted +import kotlin.coroutines.intrinsics.COROUTINE_SUSPENDED import kotlin.coroutines.intrinsics.startCoroutineUninterceptedOrReturn import kotlin.coroutines.intrinsics.suspendCoroutineUninterceptedOrReturn import kotlin.coroutines.resume +import kotlin.coroutines.resumeWithException /** * [Effect] represents a function of `suspend () -> A`, that short-circuit with a value of [R] (and [Throwable]), @@ -711,14 +712,34 @@ internal class Token { internal class FoldContinuation( private val token: Token, override val context: CoroutineContext, + private val error: suspend (Throwable) -> B, private val parent: Continuation ) : Continuation { + + constructor(token: Token, context: CoroutineContext, parent: Continuation) : this(token, context, { throw it }, parent) + + // In contrast to `createCoroutineUnintercepted this doesn't create a new ContinuationImpl + private fun (suspend () -> A).startCoroutineUnintercepted(cont: Continuation): Unit { + try { + when (val res = startCoroutineUninterceptedOrReturn(cont)) { + COROUTINE_SUSPENDED -> Unit + else -> cont.resume(res as A) + } + // We need to wire all immediately throw exceptions to the parent Continuation + } catch (e: Throwable) { + cont.resumeWithException(e) + } + } + override fun resumeWith(result: Result) { result.fold(parent::resume) { throwable -> - if (throwable is Suspend && token == throwable.token) { - val f: suspend () -> B = { throwable.recover(throwable.shifted) as B } - f.createCoroutineUnintercepted(parent).resume(Unit) - } else parent.resumeWith(result) + when { + throwable is Suspend && token == throwable.token -> + suspend { throwable.recover(throwable.shifted) as B }.startCoroutineUnintercepted(parent) + + throwable !is Suspend -> suspend { error(throwable.nonFatalOrThrow()) }.startCoroutineUnintercepted(parent) + else -> parent.resumeWith(result) + } } } } @@ -756,9 +777,18 @@ internal class FoldContinuation( public fun effect(f: suspend EffectScope.() -> A): Effect = DefaultEffect(f) private class DefaultEffect(val f: suspend EffectScope.() -> A) : Effect { - // We create a `Token` for fold Continuation, so we can properly differentiate between nested - // folds - override suspend fun fold(recover: suspend (R) -> B, transform: suspend (A) -> B): B = + + override suspend fun fold( + recover: suspend (shifted: R) -> B, + transform: suspend (value: A) -> B, + ): B = fold({ throw it }, recover, transform) + + // We create a `Token` for fold Continuation, so we can properly differentiate between nested folds + override suspend fun fold( + error: suspend (error: Throwable) -> B, + recover: suspend (shifted: R) -> B, + transform: suspend (value: A) -> B, + ): B = suspendCoroutineUninterceptedOrReturn { cont -> val token = Token() val effectScope = @@ -780,12 +810,15 @@ private class DefaultEffect(val f: suspend EffectScope.() -> A) : Effec try { suspend { transform(f(effectScope)) } - .startCoroutineUninterceptedOrReturn(FoldContinuation(token, cont.context, cont)) + .startCoroutineUninterceptedOrReturn(FoldContinuation(token, cont.context, error, cont)) } catch (e: Suspend) { if (token == e.token) { val f: suspend () -> B = { e.recover(e.shifted) as B } f.startCoroutineUninterceptedOrReturn(cont) } else throw e + } catch (e: Throwable) { + val f: suspend () -> B = { error(e.nonFatalOrThrow()) } + f.startCoroutineUninterceptedOrReturn(cont) } } } diff --git a/arrow-libs/core/arrow-core/src/commonTest/kotlin/arrow/core/continuations/EffectSpec.kt b/arrow-libs/core/arrow-core/src/commonTest/kotlin/arrow/core/continuations/EffectSpec.kt index 81f62f911e1..af23c30c529 100644 --- a/arrow-libs/core/arrow-core/src/commonTest/kotlin/arrow/core/continuations/EffectSpec.kt +++ b/arrow-libs/core/arrow-core/src/commonTest/kotlin/arrow/core/continuations/EffectSpec.kt @@ -5,11 +5,13 @@ import arrow.core.identity import arrow.core.left import arrow.core.right import io.kotest.assertions.fail -import io.kotest.common.runBlocking +import io.kotest.assertions.throwables.shouldThrow import io.kotest.core.spec.style.StringSpec import io.kotest.matchers.shouldBe import io.kotest.property.Arb +import io.kotest.property.arbitrary.arbitrary import io.kotest.property.arbitrary.boolean +import io.kotest.property.arbitrary.flatMap import io.kotest.property.arbitrary.int import io.kotest.property.arbitrary.long import io.kotest.property.arbitrary.orNull @@ -323,12 +325,61 @@ class EffectSpec : newError.toEither() shouldBe Either.Left(error.reversed().toList()) } } + + "Can handle thrown exceptions" { + checkAll(Arb.string().suspend(), Arb.string().suspend()) { msg, fallback -> + effect { + throw RuntimeException(msg()) + }.fold( + { fallback() }, + ::identity, + ::identity + ) shouldBe fallback() + } + } + + "Can shift from thrown exceptions" { + checkAll(Arb.string().suspend(), Arb.string().suspend()) { msg, fallback -> + effect { + effect { + throw RuntimeException(msg()) + }.fold( + { shift(fallback()) }, + ::identity, + { it.length } + ) + }.runCont() shouldBe fallback() + } + } + + "Can throw from thrown exceptions" { + checkAll(Arb.string().suspend(), Arb.string().suspend()) { msg, fallback -> + shouldThrow { + effect { + throw RuntimeException(msg()) + }.fold( + { throw IllegalStateException(fallback()) }, + ::identity, + { it.length } + ) + }.message shouldBe fallback() + } + } }) private data class Failure(val msg: String) suspend fun currentContext(): CoroutineContext = kotlin.coroutines.coroutineContext +// Turn `A` into `suspend () -> A` which tests both the `immediate` and `COROUTINE_SUSPENDED` path. +private fun Arb.suspend(): Arb A> = + flatMap { a -> + arbitrary(listOf( + { a }, + suspend { a.suspend() } + )) { suspend { a.suspend() } } + } + internal suspend fun Throwable.suspend(): Nothing = suspendCoroutineUninterceptedOrReturn { cont -> suspend { throw this } .startCoroutine(Continuation(Dispatchers.Default) { cont.intercepted().resumeWith(it) })