Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix assertions mode when using coroutines in another thread #3604

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -4,7 +4,7 @@ import io.kotest.core.spec.style.FreeSpec
import io.kotest.core.test.AssertionMode
import io.kotest.matchers.shouldBe

class AssertionCounterFreeSpecTest : FreeSpec({
class AssertionModeFreeSpecTest : FreeSpec({
assertions = AssertionMode.Error
"container should not need to have an assertion" - {
"neither should this container" - {
Expand Down
Expand Up @@ -10,7 +10,7 @@ import io.kotest.core.test.TestCase
import io.kotest.core.test.TestResult
import io.kotest.matchers.shouldBe

class AssertionCounterFunSpecTest : FunSpec() {
class AssertionModeFunSpecTest : FunSpec() {

override fun assertionMode() = AssertionMode.Error

Expand Down
@@ -0,0 +1,28 @@
package com.sksamuel.kotest.assertions

import io.kotest.core.spec.style.FunSpec
import io.kotest.core.test.AssertionMode
import io.kotest.matchers.shouldBe
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext

@OptIn(ExperimentalCoroutinesApi::class)
class AssertionModeThreadTest : FunSpec() {

override fun assertionMode() = AssertionMode.Error

init {

test("assertions from another thread should be counted") {
withContext(Dispatchers.Default.limitedParallelism(1)) {
launch {
1 shouldBe 1
1 shouldBe 1
1 shouldBe 1
}
}
}
}
}
Expand Up @@ -297,6 +297,7 @@ public final class io/kotest/assertions/RetryKt {
public final class io/kotest/assertions/ThreadLocalAssertionCounter : io/kotest/assertions/AssertionCounter {
public static final field INSTANCE Lio/kotest/assertions/ThreadLocalAssertionCounter;
public fun get ()I
public final fun getValues ()Ljava/lang/ThreadLocal;
public fun inc ()V
public fun reset ()V
}
Expand Down
Expand Up @@ -5,11 +5,11 @@ actual val assertionCounter: AssertionCounter = ThreadLocalAssertionCounter

object ThreadLocalAssertionCounter : AssertionCounter {

private val context = object : ThreadLocal<Int>() {
val values = object : ThreadLocal<Int>() {
override fun initialValue(): Int = 0
}

override fun get(): Int = context.get()
override fun reset() = context.set(0)
override fun inc() = context.set(context.get() + 1)
override fun get(): Int = values.get()
override fun reset() = values.set(0)
override fun inc() = values.set(values.get() + 1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this safe if multiple coroutines are running for the same test simultaneously?

If all we need to track is "an assertion did / did not occur", perhaps an atomic boolean would be easier to reason about?

}
Expand Up @@ -800,6 +800,28 @@ public final class io/kotest/engine/test/TestCaseExecutor {
public final fun execute (Lio/kotest/core/test/TestCase;Lio/kotest/core/test/TestScope;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
}

public final class io/kotest/engine/test/interceptors/AssertionCounterThreadContextElement : kotlinx/coroutines/ThreadContextElement {
public static final field Key Lio/kotest/engine/test/interceptors/AssertionCounterThreadContextElement$Key;
public fun <init> (Ljava/lang/String;)V
public fun fold (Ljava/lang/Object;Lkotlin/jvm/functions/Function2;)Ljava/lang/Object;
public fun get (Lkotlin/coroutines/CoroutineContext$Key;)Lkotlin/coroutines/CoroutineContext$Element;
public fun getKey ()Lkotlin/coroutines/CoroutineContext$Key;
public fun minusKey (Lkotlin/coroutines/CoroutineContext$Key;)Lkotlin/coroutines/CoroutineContext;
public fun plus (Lkotlin/coroutines/CoroutineContext;)Lkotlin/coroutines/CoroutineContext;
public fun restoreThreadContext (Lkotlin/coroutines/CoroutineContext;I)V
public synthetic fun restoreThreadContext (Lkotlin/coroutines/CoroutineContext;Ljava/lang/Object;)V
public fun updateThreadContext (Lkotlin/coroutines/CoroutineContext;)Ljava/lang/Integer;
public synthetic fun updateThreadContext (Lkotlin/coroutines/CoroutineContext;)Ljava/lang/Object;
}

public final class io/kotest/engine/test/interceptors/AssertionCounterThreadContextElement$Key : kotlin/coroutines/CoroutineContext$Key {
}

public final class io/kotest/engine/test/interceptors/AssertionModeThreadLocalContextInterceptor : io/kotest/engine/test/interceptors/TestExecutionInterceptor {
public static final field INSTANCE Lio/kotest/engine/test/interceptors/AssertionModeThreadLocalContextInterceptor;
public fun intercept (Lio/kotest/core/test/TestCase;Lio/kotest/core/test/TestScope;Lkotlin/jvm/functions/Function3;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
}

public final class io/kotest/engine/test/interceptors/BlockedThreadTestTimeoutException : io/kotest/engine/test/interceptors/TestTimeoutException {
public synthetic fun <init> (JLjava/lang/String;Lkotlin/jvm/internal/DefaultConstructorMarker;)V
}
Expand Down
Expand Up @@ -28,6 +28,7 @@ import io.kotest.engine.test.interceptors.TestFinishedInterceptor
import io.kotest.engine.test.interceptors.TestNameContextInterceptor
import io.kotest.engine.test.interceptors.TestPathContextInterceptor
import io.kotest.engine.test.interceptors.TimeoutInterceptor
import io.kotest.engine.test.interceptors.assertionModeThreadLocalContextInterceptor
import io.kotest.engine.test.interceptors.blockedThreadTimeoutInterceptor
import io.kotest.engine.test.interceptors.coroutineDispatcherFactoryInterceptor
import io.kotest.engine.test.interceptors.coroutineErrorCollectorInterceptor
Expand Down Expand Up @@ -67,6 +68,7 @@ class TestCaseExecutor(
TestCaseExtensionInterceptor(configuration.registry),
EnabledCheckInterceptor(configuration),
LifecycleInterceptor(listener, timeMark, configuration.registry),
if (platform == Platform.JVM) assertionModeThreadLocalContextInterceptor() else null,
AssertionModeInterceptor,
SoftAssertInterceptor(),
CoroutineLoggingInterceptor(configuration),
Expand Down
Expand Up @@ -32,3 +32,6 @@ internal expect fun blockedThreadTimeoutInterceptor(
*/
@JVMOnly
internal expect fun coroutineErrorCollectorInterceptor(): TestExecutionInterceptor

@JVMOnly
internal expect fun assertionModeThreadLocalContextInterceptor(): TestExecutionInterceptor
Expand Up @@ -16,3 +16,6 @@ internal actual fun blockedThreadTimeoutInterceptor(

internal actual fun coroutineErrorCollectorInterceptor(): TestExecutionInterceptor =
error("Unsupported on $platform")

internal actual fun assertionModeThreadLocalContextInterceptor(): TestExecutionInterceptor =
error("Unsupported on $platform")
Expand Up @@ -18,3 +18,6 @@ internal actual fun blockedThreadTimeoutInterceptor(

internal actual fun coroutineErrorCollectorInterceptor(): TestExecutionInterceptor =
error("Unsupported on $platform")

internal actual fun assertionModeThreadLocalContextInterceptor(): TestExecutionInterceptor =
error("Unsupported on $platform")
@@ -0,0 +1,60 @@
package io.kotest.engine.test.interceptors

import io.kotest.assertions.ThreadLocalAssertionCounter
import io.kotest.common.JVMOnly
import io.kotest.common.TestNameContextElement
import io.kotest.core.test.TestCase
import io.kotest.core.test.TestResult
import io.kotest.core.test.TestScope
import io.kotest.engine.test.scopes.withCoroutineContext
import kotlinx.coroutines.ThreadContextElement
import kotlinx.coroutines.withContext
import java.util.concurrent.ConcurrentHashMap
import kotlin.coroutines.CoroutineContext
import kotlin.coroutines.coroutineContext

@JVMOnly
internal actual fun assertionModeThreadLocalContextInterceptor(): TestExecutionInterceptor =
AssertionModeThreadLocalContextInterceptor

/**
* Installs the [AssertionCounterThreadContextElement]s into the running coroutine context.
*/
object AssertionModeThreadLocalContextInterceptor : TestExecutionInterceptor {
override suspend fun intercept(
testCase: TestCase,
scope: TestScope,
test: suspend (TestCase, TestScope) -> TestResult
): TestResult {
val testNameContextElement = coroutineContext[TestNameContextElement] ?: error("Requires TestNameContextElement")
return withContext(AssertionCounterThreadContextElement(testNameContextElement.testName)) {
test(testCase, scope.withCoroutineContext(this.coroutineContext))
}
}
}

private val testNameCounters = ConcurrentHashMap<String, Int>()

class AssertionCounterThreadContextElement(private val testName: String) : ThreadContextElement<Int> {

companion object Key : CoroutineContext.Key<AssertionCounterThreadContextElement>

override val key: CoroutineContext.Key<AssertionCounterThreadContextElement>
get() = Key

// this is invoked before coroutine is resumed on current thread
override fun updateThreadContext(context: CoroutineContext): Int {
// need to use our backing map's value and install that in the thread local copy
val counter = testNameCounters.getOrPut(testName) { 0 }
ThreadLocalAssertionCounter.values.set(counter)
// we track the state in our backing map so this can be ignored
return -1
}

// this is invoked after coroutine has suspended on current thread
override fun restoreThreadContext(context: CoroutineContext, oldState: Int) {
// need to put the current thread-local value into our backing map before the coroutine is switched out
testNameCounters[testName] = ThreadLocalAssertionCounter.values.get()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens if a test launches multiple coroutines and makes assertions in both?

What happens if multiple coroutines launch at the same time but the last one to finish doesn't make any assertions? Won't we end up with a count of 0 assertions for the test?

ThreadLocalAssertionCounter.values.set(oldState)
}
}