Skip to content

Commit

Permalink
Prevent setting Dispatchers.Main concurrently
Browse files Browse the repository at this point in the history
  • Loading branch information
dkhalanskyjb committed Nov 17, 2021
1 parent f6c8fdb commit 1a8af52
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 22 deletions.
12 changes: 5 additions & 7 deletions kotlinx-coroutines-test/common/src/TestCoroutineDispatchers.kt
Expand Up @@ -7,7 +7,6 @@ package kotlinx.coroutines.test
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.*
import kotlinx.coroutines.flow.*
import kotlinx.coroutines.test.internal.*
import kotlinx.coroutines.test.internal.TestMainDispatcher
import kotlin.coroutines.*

Expand Down Expand Up @@ -84,7 +83,8 @@ import kotlin.coroutines.*
public fun UnconfinedTestDispatcher(
scheduler: TestCoroutineScheduler? = null,
name: String? = null
): TestDispatcher = UnconfinedTestDispatcherImpl(scheduler ?: mainTestScheduler ?: TestCoroutineScheduler(), name)
): TestDispatcher = UnconfinedTestDispatcherImpl(
scheduler ?: TestMainDispatcher.currentTestScheduler ?: TestCoroutineScheduler(), name)

private class UnconfinedTestDispatcherImpl(
override val scheduler: TestCoroutineScheduler,
Expand Down Expand Up @@ -141,7 +141,8 @@ private class UnconfinedTestDispatcherImpl(
public fun StandardTestDispatcher(
scheduler: TestCoroutineScheduler? = null,
name: String? = null
): TestDispatcher = StandardTestDispatcherImpl(scheduler ?: mainTestScheduler ?: TestCoroutineScheduler(), name)
): TestDispatcher = StandardTestDispatcherImpl(
scheduler ?: TestMainDispatcher.currentTestScheduler ?: TestCoroutineScheduler(), name)

private class StandardTestDispatcherImpl(
override val scheduler: TestCoroutineScheduler = TestCoroutineScheduler(),
Expand All @@ -154,7 +155,4 @@ private class StandardTestDispatcherImpl(
}

override fun toString(): String = "${name ?: "StandardTestDispatcher"}[scheduler=$scheduler]"
}

private val mainTestScheduler
get() = ((Dispatchers.Main as? TestMainDispatcher)?.delegate as? TestDispatcher)?.scheduler
}
2 changes: 1 addition & 1 deletion kotlinx-coroutines-test/common/src/TestDispatchers.kt
Expand Up @@ -21,7 +21,7 @@ import kotlin.jvm.*
@ExperimentalCoroutinesApi
public fun Dispatchers.setMain(dispatcher: CoroutineDispatcher) {
require(dispatcher !is TestMainDispatcher) { "Dispatchers.setMain(Dispatchers.Main) is prohibited, probably Dispatchers.resetMain() should be used instead" }
getTestMainDispatcher().delegate = dispatcher
getTestMainDispatcher().setDispatcher(dispatcher)
}

/**
Expand Down
65 changes: 57 additions & 8 deletions kotlinx-coroutines-test/common/src/internal/TestMainDispatcher.kt
Expand Up @@ -3,40 +3,89 @@
*/

package kotlinx.coroutines.test.internal

import kotlinx.atomicfu.*
import kotlinx.coroutines.*
import kotlinx.coroutines.test.*
import kotlin.coroutines.*

/**
* The testable main dispatcher used by kotlinx-coroutines-test.
* It is a [MainCoroutineDispatcher] that delegates all actions to a settable delegate.
*/
internal class TestMainDispatcher(var delegate: CoroutineDispatcher):
internal class TestMainDispatcher(delegate: CoroutineDispatcher):
MainCoroutineDispatcher(),
Delay
{
private val mainDispatcher = delegate // the initial value passed to the constructor
private val mainDispatcher = delegate
private var delegate = NonConcurrentlyModifiable(mainDispatcher, "Dispatchers.Main")

private val delay
get() = delegate as? Delay ?: defaultDelay
get() = delegate.value as? Delay ?: defaultDelay

override val immediate: MainCoroutineDispatcher
get() = (delegate as? MainCoroutineDispatcher)?.immediate ?: this
get() = (delegate.value as? MainCoroutineDispatcher)?.immediate ?: this

override fun dispatch(context: CoroutineContext, block: Runnable) = delegate.value.dispatch(context, block)

override fun dispatch(context: CoroutineContext, block: Runnable) = delegate.dispatch(context, block)
override fun isDispatchNeeded(context: CoroutineContext): Boolean = delegate.value.isDispatchNeeded(context)

override fun isDispatchNeeded(context: CoroutineContext): Boolean = delegate.isDispatchNeeded(context)
override fun dispatchYield(context: CoroutineContext, block: Runnable) = delegate.value.dispatchYield(context, block)

override fun dispatchYield(context: CoroutineContext, block: Runnable) = delegate.dispatchYield(context, block)
fun setDispatcher(dispatcher: CoroutineDispatcher) {
delegate.value = dispatcher
}

fun resetDispatcher() {
delegate = mainDispatcher
delegate.value = mainDispatcher
}

override fun scheduleResumeAfterDelay(timeMillis: Long, continuation: CancellableContinuation<Unit>) =
delay.scheduleResumeAfterDelay(timeMillis, continuation)

override fun invokeOnTimeout(timeMillis: Long, block: Runnable, context: CoroutineContext): DisposableHandle =
delay.invokeOnTimeout(timeMillis, block, context)

companion object {
internal val currentTestDispatcher
get() = (Dispatchers.Main as? TestMainDispatcher)?.delegate?.value as? TestDispatcher

internal val currentTestScheduler
get() = currentTestDispatcher?.scheduler
}

/**
* A wrapper around a value that attempts to throw when writing happens concurrently with reading.
*
* The read operations never throw. Instead, the failures detected inside them will be remembered and thrown on the
* next modification.
*/
private class NonConcurrentlyModifiable<T>(private val initialValue: T, private val name: String) {
private val readers = atomic(0) // number of concurrent readers
private val isWriting = atomic(false) // a modification is happening currently
private val exceptionWhenReading: AtomicRef<Throwable?> = atomic(null) // exception from reading
private val _value = atomic(initialValue) // the backing field for the value

private fun concurrentWW() = IllegalStateException("$name is modified concurrently")
private fun concurrentRW() = IllegalStateException("$name is used concurrently with setting it")

var value: T
get() {
readers.incrementAndGet()
if (isWriting.value) exceptionWhenReading.value = concurrentRW()
val result = _value.value
readers.decrementAndGet()
return result
}
set(value: T) {
exceptionWhenReading.getAndSet(null)?.let { throw it }
if (readers.value != 0) throw concurrentRW()
if (!isWriting.compareAndSet(expect = false, update = true)) throw concurrentWW()
_value.value = value
isWriting.value = false
if (readers.value != 0) throw concurrentRW()
}
}
}

@Suppress("INVISIBLE_MEMBER")
Expand Down
6 changes: 2 additions & 4 deletions kotlinx-coroutines-test/common/test/TestDispatchersTest.kt
Expand Up @@ -24,7 +24,7 @@ class TestDispatchersTest: OrderedExecutionTestBase() {
@NoJs
@Test
fun testMainMocking() = runTest {
val mainAtStart = mainTestDispatcher
val mainAtStart = TestMainDispatcher.currentTestDispatcher
assertNotNull(mainAtStart)
withContext(Dispatchers.Main) {
delay(10)
Expand All @@ -35,7 +35,7 @@ class TestDispatchersTest: OrderedExecutionTestBase() {
withContext(Dispatchers.Main) {
delay(10)
}
assertSame(mainAtStart, mainTestDispatcher)
assertSame(mainAtStart, TestMainDispatcher.currentTestDispatcher)
}

/** Tests that the mocked [Dispatchers.Main] correctly forwards [Delay] methods. */
Expand Down Expand Up @@ -96,5 +96,3 @@ class TestDispatchersTest: OrderedExecutionTestBase() {
}
}
}

private val mainTestDispatcher get() = ((Dispatchers.Main as? TestMainDispatcher)?.delegate as? TestDispatcher)
4 changes: 2 additions & 2 deletions kotlinx-coroutines-test/js/test/FailingTests.kt
Expand Up @@ -25,11 +25,11 @@ class FailingTests {
@Test
fun testAfterTestIsConcurrent() = runTest {
try {
val mainAtStart = (Dispatchers.Main as? TestMainDispatcher)?.delegate as? TestDispatcher ?: return@runTest
val mainAtStart = TestMainDispatcher.currentTestDispatcher ?: return@runTest
withContext(Dispatchers.Default) {
// context switch
}
assertNotSame(mainAtStart, (Dispatchers.Main as TestMainDispatcher).delegate)
assertNotSame(mainAtStart, TestMainDispatcher.currentTestDispatcher!!)
} finally {
assertTrue(tearDownEntered)
}
Expand Down

0 comments on commit 1a8af52

Please sign in to comment.