Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix usage of startCoroutineUninterceptedOrReturn #2789

Merged
merged 7 commits into from Aug 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions arrow-libs/core/arrow-core/api/arrow-core.api
Expand Up @@ -2722,6 +2722,7 @@ public final class arrow/core/continuations/EffectScopeKt {

public final class arrow/core/continuations/FoldContinuation : kotlin/coroutines/Continuation {
public fun <init> (Larrow/core/continuations/Token;Lkotlin/coroutines/CoroutineContext;Lkotlin/coroutines/Continuation;)V
public fun <init> (Larrow/core/continuations/Token;Lkotlin/coroutines/CoroutineContext;Lkotlin/jvm/functions/Function2;Lkotlin/coroutines/Continuation;)V
public fun getContext ()Lkotlin/coroutines/CoroutineContext;
public fun resumeWith (Ljava/lang/Object;)V
}
Expand Down
Expand Up @@ -11,10 +11,11 @@ import arrow.core.nonFatalOrThrow
import kotlin.coroutines.Continuation
import kotlin.coroutines.CoroutineContext
import kotlin.coroutines.cancellation.CancellationException
import kotlin.coroutines.intrinsics.createCoroutineUnintercepted
import kotlin.coroutines.intrinsics.COROUTINE_SUSPENDED
import kotlin.coroutines.intrinsics.startCoroutineUninterceptedOrReturn
import kotlin.coroutines.intrinsics.suspendCoroutineUninterceptedOrReturn
import kotlin.coroutines.resume
import kotlin.coroutines.resumeWithException

/**
* [Effect] represents a function of `suspend () -> A`, that short-circuit with a value of [R] (and [Throwable]),
Expand Down Expand Up @@ -711,14 +712,34 @@ internal class Token {
internal class FoldContinuation<B>(
private val token: Token,
override val context: CoroutineContext,
private val error: suspend (Throwable) -> B,
private val parent: Continuation<B>
) : Continuation<B> {

constructor(token: Token, context: CoroutineContext, parent: Continuation<B>) : this(token, context, { throw it }, parent)

// In contrast to `createCoroutineUnintercepted this doesn't create a new ContinuationImpl
private fun <A> (suspend () -> A).startCoroutineUnintercepted(cont: Continuation<A>): Unit {
try {
when (val res = startCoroutineUninterceptedOrReturn(cont)) {
COROUTINE_SUSPENDED -> Unit
else -> cont.resume(res as A)
}
// We need to wire all immediately throw exceptions to the parent Continuation
} catch (e: Throwable) {
cont.resumeWithException(e)
}
}

override fun resumeWith(result: Result<B>) {
result.fold(parent::resume) { throwable ->
if (throwable is Suspend && token == throwable.token) {
val f: suspend () -> B = { throwable.recover(throwable.shifted) as B }
f.createCoroutineUnintercepted(parent).resume(Unit)
} else parent.resumeWith(result)
when {
throwable is Suspend && token == throwable.token ->
suspend { throwable.recover(throwable.shifted) as B }.startCoroutineUnintercepted(parent)

throwable !is Suspend -> suspend { error(throwable.nonFatalOrThrow()) }.startCoroutineUnintercepted(parent)
else -> parent.resumeWith(result)
}
}
}
}
Expand Down Expand Up @@ -756,9 +777,18 @@ internal class FoldContinuation<B>(
public fun <R, A> effect(f: suspend EffectScope<R>.() -> A): Effect<R, A> = DefaultEffect(f)

private class DefaultEffect<R, A>(val f: suspend EffectScope<R>.() -> A) : Effect<R, A> {
// We create a `Token` for fold Continuation, so we can properly differentiate between nested
// folds
override suspend fun <B> fold(recover: suspend (R) -> B, transform: suspend (A) -> B): B =

override suspend fun <B> fold(
recover: suspend (shifted: R) -> B,
transform: suspend (value: A) -> B,
): B = fold({ throw it }, recover, transform)

// We create a `Token` for fold Continuation, so we can properly differentiate between nested folds
override suspend fun <B> fold(
error: suspend (error: Throwable) -> B,
recover: suspend (shifted: R) -> B,
transform: suspend (value: A) -> B,
): B =
suspendCoroutineUninterceptedOrReturn { cont ->
val token = Token()
val effectScope =
Expand All @@ -780,12 +810,15 @@ private class DefaultEffect<R, A>(val f: suspend EffectScope<R>.() -> A) : Effec

try {
suspend { transform(f(effectScope)) }
.startCoroutineUninterceptedOrReturn(FoldContinuation(token, cont.context, cont))
.startCoroutineUninterceptedOrReturn(FoldContinuation(token, cont.context, error, cont))
} catch (e: Suspend) {
if (token == e.token) {
val f: suspend () -> B = { e.recover(e.shifted) as B }
f.startCoroutineUninterceptedOrReturn(cont)
} else throw e
} catch (e: Throwable) {
val f: suspend () -> B = { error(e.nonFatalOrThrow()) }
f.startCoroutineUninterceptedOrReturn(cont)
}
}
}
Expand Down
Expand Up @@ -5,11 +5,13 @@ import arrow.core.identity
import arrow.core.left
import arrow.core.right
import io.kotest.assertions.fail
import io.kotest.common.runBlocking
import io.kotest.assertions.throwables.shouldThrow
import io.kotest.core.spec.style.StringSpec
import io.kotest.matchers.shouldBe
import io.kotest.property.Arb
import io.kotest.property.arbitrary.arbitrary
import io.kotest.property.arbitrary.boolean
import io.kotest.property.arbitrary.flatMap
import io.kotest.property.arbitrary.int
import io.kotest.property.arbitrary.long
import io.kotest.property.arbitrary.orNull
Expand Down Expand Up @@ -323,12 +325,61 @@ class EffectSpec :
newError.toEither() shouldBe Either.Left(error.reversed().toList())
}
}

"Can handle thrown exceptions" {
checkAll(Arb.string().suspend(), Arb.string().suspend()) { msg, fallback ->
effect<Int, String> {
throw RuntimeException(msg())
}.fold(
{ fallback() },
::identity,
::identity
) shouldBe fallback()
}
}

"Can shift from thrown exceptions" {
checkAll(Arb.string().suspend(), Arb.string().suspend()) { msg, fallback ->
effect<String, Int> {
effect<Int, String> {
throw RuntimeException(msg())
}.fold(
{ shift(fallback()) },
::identity,
{ it.length }
)
}.runCont() shouldBe fallback()
}
}

"Can throw from thrown exceptions" {
checkAll(Arb.string().suspend(), Arb.string().suspend()) { msg, fallback ->
shouldThrow<IllegalStateException> {
effect<Int, String> {
throw RuntimeException(msg())
}.fold(
{ throw IllegalStateException(fallback()) },
::identity,
{ it.length }
)
}.message shouldBe fallback()
}
}
})

private data class Failure(val msg: String)

suspend fun currentContext(): CoroutineContext = kotlin.coroutines.coroutineContext

// Turn `A` into `suspend () -> A` which tests both the `immediate` and `COROUTINE_SUSPENDED` path.
private fun <A> Arb<A>.suspend(): Arb<suspend () -> A> =
flatMap { a ->
arbitrary(listOf(
{ a },
suspend { a.suspend() }
)) { suspend { a.suspend() } }
}

internal suspend fun Throwable.suspend(): Nothing = suspendCoroutineUninterceptedOrReturn { cont ->
suspend { throw this }
.startCoroutine(Continuation(Dispatchers.Default) { cont.intercepted().resumeWith(it) })
Expand Down