Skip to content

Commit

Permalink
Refactor concurrency (#3623)
Browse files Browse the repository at this point in the history
* Bind BlockHoundMode setting to coroutines, lifting restrictions

* Refactor concurrency

* Avoid the use of `runBlocking` inside coroutines. Schedule all
  coroutines below the test engine layer via regular dispatchers,
  bringing the test environment more in line with actual production
  environments.

* Remove direct use of Java's `Executors`, delegating to
  kotlinx.coroutines-provided dispatchers instead.

* Preserve coroutine context (which includes lots of Kotest
  configuration) in multithreaded test executions.

* Remove hard-coded 1-day time limit for multithreaded test executions.

* Change tests relying on single-threading from a coroutine model to a
  thread model, removing coroutine invocations (`delay`), replacing
  those with thread invocations (`Thread.sleep`).

* Avoid leaking threads by closing dispatchers after use.

* Keep the `assertionCounter` synchronized with thread-switching
  coroutines.

* Remove `fun <K, V> concurrentHashMap(): MutableMap<K, V>`, which was
  not thread-safe.

* Make `FixedThreadCoroutineDispatcherFactory.dispatcherAffinity`
  thread-safe.

* Revert a change in `ConcurrentTestSuiteScheduler.schedule` by
  commit c316bbd, which replaced
  `launch` with `async` plus `joinAll`. This was functionally
  equivalent, but `async` was not needed, as there were no values
  returned, and `joinAll` was already performed by the enclosing
  `coroutineScope`.

* Allow 1 thread configured for test case, let invocations default to 1

* Update API

* Simplify replay

* Fix multithreading test, provoke thread switching more reliably

* Fix System.out/err handling: flush streams when switching or capturing

This avoids missing captured output.

* Stabilize multithreading tests, provoke thread switching consistently

* Make CollectingTestEngineListener thread-safe internally

* CollectingTestEngineListener: bring back `result` methods

* Stabilize provoked thread switching with Dispatchers.Unconfined

Increase the minimum delay to 50 milliseconds based on observed
behavior.

* Update API dump

* Make provokeThreadSwitch internal and non-suspending

---------

Co-authored-by: Sam <sam@sksamuel.com>
  • Loading branch information
OliverO2 and sksamuel committed Sep 1, 2023
1 parent 5401405 commit 29ed6c9
Show file tree
Hide file tree
Showing 45 changed files with 399 additions and 305 deletions.
@@ -0,0 +1,25 @@
package com.sksamuel.kotest.assertions

import io.kotest.assertions.assertionCounter
import io.kotest.assertions.assertionCounterContextElement
import io.kotest.core.spec.style.FunSpec
import io.kotest.matchers.collections.shouldHaveSize
import io.kotest.matchers.shouldBe
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.delay
import kotlinx.coroutines.withContext

class AssertionCounterMultithreadingTests : FunSpec({
test("assertionCounter should work across coroutine thread switch") {
withContext(Dispatchers.Unconfined + assertionCounterContextElement) {
val threadIds = mutableSetOf<Long>()
assertionCounter.inc()
threadIds.add(Thread.currentThread().id)
delay(50)
assertionCounter.inc()
threadIds.add(Thread.currentThread().id)
assertionCounter.get() shouldBe 2
threadIds shouldHaveSize 2
}
}
})
Expand Up @@ -20,20 +20,21 @@ import io.kotest.matchers.nulls.shouldBeNull
import io.kotest.matchers.nulls.shouldNotBeNull
import io.kotest.matchers.shouldBe
import io.kotest.matchers.string.shouldContain
import kotlinx.coroutines.DelicateCoroutinesApi
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
import kotlinx.coroutines.newSingleThreadContext
import kotlinx.coroutines.withContext
import java.io.FileNotFoundException
import java.io.IOException
import java.util.concurrent.CountDownLatch
import java.util.concurrent.Executors
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicInteger
import kotlin.time.Duration.Companion.days
import kotlin.time.Duration.Companion.milliseconds
import kotlin.time.Duration.Companion.seconds
import kotlinx.coroutines.asCoroutineDispatcher
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext

@OptIn(DelicateCoroutinesApi::class)
class EventuallyTest : WordSpec() {

init {
Expand Down Expand Up @@ -139,35 +140,37 @@ class EventuallyTest : WordSpec() {
count.shouldBeLessThan(3)
}
"do one final iteration if we never executed before interval expired" {
val dispatcher = Executors.newSingleThreadExecutor().asCoroutineDispatcher()
launch(dispatcher) {
Thread.sleep(250)
}
val counter = AtomicInteger(0)
withContext(dispatcher) {
// we won't be able to run in here
eventually(1.seconds, 5.milliseconds) {
counter.incrementAndGet()
newSingleThreadContext("single").use { dispatcher ->
launch(dispatcher) {
Thread.sleep(250)
}
val counter = AtomicInteger(0)
withContext(dispatcher) {
// we won't be able to run in here
eventually(1.seconds, 5.milliseconds) {
counter.incrementAndGet()
}
}
counter.get().shouldBe(1)
}
counter.get().shouldBe(1)
}
"do one final iteration if we only executed once and the last delay > interval" {
val dispatcher = Executors.newSingleThreadExecutor().asCoroutineDispatcher()
// this will start immediately, free the dispatcher to allow eventually to run once, then block the thread
launch(dispatcher) {
delay(100.milliseconds)
Thread.sleep(500)
}
val counter = AtomicInteger(0)
withContext(dispatcher) {
// this will execute once immediately, then the earlier async will steal the thread
// and then since the delay has been > interval and times == 1, we will execute once more
eventually(250.milliseconds, 25.milliseconds) {
counter.incrementAndGet() shouldBe 2
newSingleThreadContext("single").use { dispatcher ->
// this will start immediately, free the dispatcher to allow eventually to run once, then block the thread
launch(dispatcher) {
delay(100.milliseconds)
Thread.sleep(500)
}
val counter = AtomicInteger(0)
withContext(dispatcher) {
// this will execute once immediately, then the earlier async will steal the thread
// and then since the delay has been > interval and times == 1, we will execute once more
eventually(250.milliseconds, 25.milliseconds) {
counter.incrementAndGet() shouldBe 2
}
}
counter.get().shouldBe(2)
}
counter.get().shouldBe(2)
}
"handle shouldNotBeNull" {
val duration = measureTimeMillisCompat {
Expand Down
Expand Up @@ -3,11 +3,22 @@ package com.sksamuel.kotest.matchers.future
import io.kotest.assertions.throwables.shouldThrowMessage
import io.kotest.core.spec.style.StringSpec
import io.kotest.matchers.future.*
import kotlinx.coroutines.DelicateCoroutinesApi
import kotlinx.coroutines.delay
import kotlinx.coroutines.newSingleThreadContext
import kotlinx.coroutines.withContext
import java.util.concurrent.CompletableFuture
import java.util.concurrent.Executors

class FutureMatcherTest : StringSpec({
suspend fun runOnSeparateThread(block: () -> Unit) {
@OptIn(DelicateCoroutinesApi::class)
newSingleThreadContext("separate").use {
withContext(it) {
block()
}
}
}

"test future is completed" {
val completableFuture = CompletableFuture<Int>()
completableFuture.complete(2)
Expand All @@ -28,7 +39,7 @@ class FutureMatcherTest : StringSpec({
}
"test future is completed exceptionally" {
val completableFuture = CompletableFuture<Int>()
Executors.newFixedThreadPool(1).submit {
runOnSeparateThread {
completableFuture.cancel(false)
}
delay(200)
Expand All @@ -39,11 +50,11 @@ class FutureMatcherTest : StringSpec({
completableFuture.complete(2)
completableFuture.shouldNotBeCompletedExceptionally()
}
"test future completes exceptionally with the given exception"{
"test future completes exceptionally with the given exception" {
val completableFuture = CompletableFuture<Int>()
val exception = RuntimeException("Boom Boom")

Executors.newFixedThreadPool(1).submit {
runOnSeparateThread {
completableFuture.completeExceptionally(exception)
}

Expand All @@ -52,7 +63,7 @@ class FutureMatcherTest : StringSpec({
"test future does not completes exceptionally with given exception " {
val completableFuture = CompletableFuture<Int>()

Executors.newFixedThreadPool(1).submit {
runOnSeparateThread {
completableFuture.completeExceptionally(RuntimeException("Boom Boom"))
}

Expand Down
Expand Up @@ -300,13 +300,6 @@ public final class io/kotest/assertions/RetryKt {
public static final fun retryConfig (Lkotlin/jvm/functions/Function1;)Lio/kotest/assertions/RetryConfig;
}

public final class io/kotest/assertions/ThreadLocalAssertionCounter : io/kotest/assertions/AssertionCounter {
public static final field INSTANCE Lio/kotest/assertions/ThreadLocalAssertionCounter;
public fun get ()I
public fun inc ()V
public fun reset ()V
}

public final class io/kotest/assertions/async/TimeoutKt {
public static final fun shouldTimeout (JLjava/util/concurrent/TimeUnit;Lkotlin/jvm/functions/Function1;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
public static final fun shouldTimeout (Ljava/time/Duration;Lkotlin/jvm/functions/Function1;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
Expand Down Expand Up @@ -389,6 +382,7 @@ public final class io/kotest/assertions/eq/ThrowableEq : io/kotest/assertions/eq

public final class io/kotest/assertions/jvmcounter {
public static final fun getAssertionCounter ()Lio/kotest/assertions/AssertionCounter;
public static final fun getAssertionCounterContextElement ()Lkotlin/coroutines/CoroutineContext$Element;
}

public final class io/kotest/assertions/jvmerrorcollector {
Expand Down
Expand Up @@ -22,9 +22,7 @@ val errorCollectorContextElement: CoroutineContext.Element
get() = ErrorCollectorContextElement(threadLocalErrorCollector.get())


private val threadLocalErrorCollector = object : ThreadLocal<CoroutineLocalErrorCollector>() {
override fun initialValue() = CoroutineLocalErrorCollector()
}
private val threadLocalErrorCollector = ThreadLocal.withInitial { CoroutineLocalErrorCollector() }


private class CoroutineLocalErrorCollector : BasicErrorCollector() {
Expand Down
@@ -1,15 +1,36 @@
@file:JvmName("jvmcounter")

package io.kotest.assertions

actual val assertionCounter: AssertionCounter = ThreadLocalAssertionCounter
import kotlinx.coroutines.asContextElement
import kotlin.coroutines.CoroutineContext

actual val assertionCounter: AssertionCounter get() = threadLocalAssertionCounter.get()

/**
* A [CoroutineContext.Element] which keeps the [assertionCounter] synchronized with thread-switching coroutines.
*
* When using [assertionCounter] without the Kotest framework, this context element should be added to a
* coroutine context, e.g. via
* - `runBlocking(assertionCounterContextElement) { ... }`
* - `runTest(Dispatchers.IO + assertionCounterContextElement) { ... }`
*/
val assertionCounterContextElement: CoroutineContext.Element
get() = threadLocalAssertionCounter.asContextElement()

object ThreadLocalAssertionCounter : AssertionCounter {
private val threadLocalAssertionCounter: ThreadLocal<CoroutineLocalAssertionCounter> =
ThreadLocal.withInitial { CoroutineLocalAssertionCounter() }

private val context = object : ThreadLocal<Int>() {
override fun initialValue(): Int = 0
private class CoroutineLocalAssertionCounter : AssertionCounter {
private var value = 0

override fun get(): Int = value

override fun reset() {
value = 0
}

override fun get(): Int = context.get()
override fun reset() = context.set(0)
override fun inc() = context.set(context.get() + 1)
override fun inc() {
value++
}
}
Expand Up @@ -18,7 +18,7 @@ class AssertSoftlyTests : FunSpec({
threadIds.add(Thread.currentThread().id)
"assertSoftly block begins on $name, id $id" shouldBe "collected failure"
}
delay(10)
delay(50)
Thread.currentThread().run {
threadIds.add(Thread.currentThread().id)
"assertSoftly block ends on $name, id $id" shouldBe "collected failure"
Expand Down
Expand Up @@ -18,7 +18,7 @@ class CluesTests : FunSpec({
val threadIds = mutableSetOf<Long>()
withClue("should not fail") {
threadIds.add(Thread.currentThread().id)
delay(10)
delay(50)
threadIds.add(Thread.currentThread().id)
}
threadIds shouldHaveSize 2
Expand Down
4 changes: 0 additions & 4 deletions kotest-common/api/kotest-common.api
@@ -1,7 +1,3 @@
public final class io/kotest/common/ConcurrentHashMapKt {
public static final fun concurrentHashMap ()Ljava/util/Map;
}

public abstract interface annotation class io/kotest/common/DelicateKotest : java/lang/annotation/Annotation {
}

Expand Down

This file was deleted.

This file was deleted.

This file was deleted.

This file was deleted.

28 changes: 10 additions & 18 deletions kotest-common/src/jvmMain/kotlin/io/kotest/mpp/replay.kt
@@ -1,9 +1,9 @@
package io.kotest.mpp

import kotlinx.coroutines.runBlocking
import java.util.concurrent.Executors
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicReference
import kotlinx.coroutines.DelicateCoroutinesApi
import kotlinx.coroutines.launch
import kotlinx.coroutines.newFixedThreadPoolContext
import kotlinx.coroutines.withContext

actual suspend fun replay(
times: Int,
Expand All @@ -15,23 +15,15 @@ actual suspend fun replay(
action(it)
}
} else {
val executor = Executors.newFixedThreadPool(threads, NamedThreadFactory("replay-%d"))
val error = AtomicReference<Throwable>(null)
for (k in 0 until times) {
executor.submit {
runBlocking {
try {
action(k)
} catch (t: Throwable) {
error.compareAndSet(null, t)
@OptIn(DelicateCoroutinesApi::class)
newFixedThreadPoolContext(threads, "replay").use { dispatcher ->
withContext(dispatcher) {
repeat(times) {
launch {
action(it)
}
}
}
}
executor.shutdown()
executor.awaitTermination(1, TimeUnit.DAYS)

if (error.get() != null)
throw error.get()
}
}
Expand Up @@ -2,7 +2,6 @@ package io.kotest.extensions.blockhound

import io.kotest.assertions.throwables.shouldNotThrow
import io.kotest.assertions.throwables.shouldThrow
import io.kotest.core.annotation.DoNotParallelize
import io.kotest.core.spec.style.FunSpec
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext
Expand Down Expand Up @@ -39,7 +38,6 @@ class BlockHoundCaseTest : FunSpec({
}
})

@DoNotParallelize
class BlockHoundSpecTest : FunSpec({
extension(BlockHound())

Expand All @@ -64,4 +62,13 @@ class BlockHoundSpecTest : FunSpec({
test("nested configuration").config(extensions = listOf(BlockHound(BlockHoundMode.DISABLED))) {
shouldNotThrow<BlockingOperationError> { blockInNonBlockingContext() }
}

test("parallelism").config(invocations = 2, threads = 2) {
shouldThrow<BlockingOperationError> {
withContext(Dispatchers.Default) {
@Suppress("BlockingMethodInNonBlockingContext")
Thread.sleep(2)
}
}
}
})

0 comments on commit 29ed6c9

Please sign in to comment.