Skip to content

Commit

Permalink
Do not request additional worker from 'yield' calls and during post-e…
Browse files Browse the repository at this point in the history
…xecution phase in LimitingDispatcher

Fixes #1704
Fixes #1706
  • Loading branch information
qwwdfsad committed Dec 20, 2019
1 parent 12a0318 commit 4e19954
Show file tree
Hide file tree
Showing 7 changed files with 66 additions and 23 deletions.
34 changes: 22 additions & 12 deletions kotlinx-coroutines-core/jvm/src/scheduling/CoroutineScheduler.kt
Expand Up @@ -372,25 +372,34 @@ internal class CoroutineScheduler(
* Dispatches execution of a runnable [block] with a hint to a scheduler whether
* this [block] may execute blocking operations (IO, system calls, locking primitives etc.)
*
* @param taskContext concurrency context of given [block]
* @param fair whether the task should be dispatched fairly (strict FIFO) or not (semi-FIFO)
* [taskContext] -- concurrency context of given [block].
* [tailDispatch] -- whether this [dispatch] call is the last action the (presumably) worker thread does in its current task.
* If `true`, then the task will be dispatched in a FIFO manner and no additional workers will be requested,
* but only if the current thread is a corresponding worker thread.
* Note that caller cannot be ensured that it is being executed on worker thread for the following reasons:
* -- [CoroutineStart.UNDISPATCHED]
* -- Concurrent [close] that effectively shutdowns the worker thread
*/
fun dispatch(block: Runnable, taskContext: TaskContext = NonBlockingContext, fair: Boolean = false) {
fun dispatch(block: Runnable, taskContext: TaskContext = NonBlockingContext, tailDispatch: Boolean = false) {
trackTask() // this is needed for virtual time support
val task = createTask(block, taskContext)
// try to submit the task to the local queue and act depending on the result
val notAdded = submitToLocalQueue(task, fair)
val currentWorker = currentWorker()
val notAdded = currentWorker.submitToLocalQueue(task, tailDispatch)
if (notAdded != null) {
if (!addToGlobalQueue(notAdded)) {
// Global queue is closed in the last step of close/shutdown -- no more tasks should be accepted
throw RejectedExecutionException("$schedulerName was terminated")
}
}
val skipUnpark = tailDispatch && currentWorker != null
// Checking 'task' instead of 'notAdded' is completely okay
if (task.mode == TaskMode.NON_BLOCKING) {
if (skipUnpark) return
signalCpuWork()
} else {
signalBlockingWork()
// Increment blocking tasks anyway
signalBlockingWork(skipUnpark = skipUnpark)
}
}

Expand All @@ -404,9 +413,10 @@ internal class CoroutineScheduler(
return TaskImpl(block, nanoTime, taskContext)
}

private fun signalBlockingWork() {
private fun signalBlockingWork(skipUnpark: Boolean) {
// Use state snapshot to avoid thread overprovision
val stateSnapshot = incrementBlockingTasks()
if (skipUnpark) return
if (tryUnpark()) return
if (tryCreateWorker(stateSnapshot)) return
tryUnpark() // Try unpark again in case there was race between permit release and parking
Expand Down Expand Up @@ -481,19 +491,19 @@ internal class CoroutineScheduler(
* Returns `null` if task was successfully added or an instance of the
* task that was not added or replaced (thus should be added to global queue).
*/
private fun submitToLocalQueue(task: Task, fair: Boolean): Task? {
val worker = currentWorker() ?: return task
private fun Worker?.submitToLocalQueue(task: Task, tailDispatch: Boolean): Task? {
if (this === null) return task
/*
* This worker could have been already terminated from this thread by close/shutdown and it should not
* accept any more tasks into its local queue.
*/
if (worker.state === WorkerState.TERMINATED) return task
if (state === WorkerState.TERMINATED) return task
// Do not add CPU tasks in local queue if we are not able to execute it
if (task.mode === TaskMode.NON_BLOCKING && worker.state === WorkerState.BLOCKING) {
if (task.mode === TaskMode.NON_BLOCKING && state === WorkerState.BLOCKING) {
return task
}
worker.mayHaveLocalTasks = true
return worker.localQueue.add(task, fair = fair)
mayHaveLocalTasks = true
return localQueue.add(task, fair = tailDispatch)
}

private fun currentWorker(): Worker? = (Thread.currentThread() as? Worker)?.takeIf { it.scheduler == this }
Expand Down
14 changes: 9 additions & 5 deletions kotlinx-coroutines-core/jvm/src/scheduling/Dispatcher.kt
Expand Up @@ -65,7 +65,7 @@ open class ExperimentalCoroutineDispatcher(

override fun dispatchYield(context: CoroutineContext, block: Runnable): Unit =
try {
coroutineScheduler.dispatch(block, fair = true)
coroutineScheduler.dispatch(block, tailDispatch = true)
} catch (e: RejectedExecutionException) {
DefaultExecutor.dispatchYield(context, block)
}
Expand Down Expand Up @@ -101,9 +101,9 @@ open class ExperimentalCoroutineDispatcher(
return LimitingDispatcher(this, parallelism, TaskMode.NON_BLOCKING)
}

internal fun dispatchWithContext(block: Runnable, context: TaskContext, fair: Boolean) {
internal fun dispatchWithContext(block: Runnable, context: TaskContext, tailDispatch: Boolean) {
try {
coroutineScheduler.dispatch(block, context, fair)
coroutineScheduler.dispatch(block, context, tailDispatch)
} catch (e: RejectedExecutionException) {
// Context shouldn't be lost here to properly invoke before/after task
DefaultExecutor.enqueue(coroutineScheduler.createTask(block, context))
Expand Down Expand Up @@ -147,15 +147,15 @@ private class LimitingDispatcher(

override fun dispatch(context: CoroutineContext, block: Runnable) = dispatch(block, false)

private fun dispatch(block: Runnable, fair: Boolean) {
private fun dispatch(block: Runnable, tailDispatch: Boolean) {
var taskToSchedule = block
while (true) {
// Commit in-flight tasks slot
val inFlight = inFlightTasks.incrementAndGet()

// Fast path, if parallelism limit is not reached, dispatch task and return
if (inFlight <= parallelism) {
dispatcher.dispatchWithContext(taskToSchedule, this, fair)
dispatcher.dispatchWithContext(taskToSchedule, this, tailDispatch)
return
}

Expand Down Expand Up @@ -185,6 +185,10 @@ private class LimitingDispatcher(
}
}

override fun dispatchYield(context: CoroutineContext, block: Runnable) {
dispatch(block, tailDispatch = true)
}

override fun toString(): String {
return "${super.toString()}[dispatcher = $dispatcher]"
}
Expand Down
Expand Up @@ -194,10 +194,10 @@ class BlockingCoroutineDispatcherTest : SchedulerTestBase() {
fun testYield() = runBlocking {
corePoolSize = 1
maxPoolSize = 1
val ds = blockingDispatcher(1)
val outerJob = launch(ds) {
val bd = blockingDispatcher(1)
val outerJob = launch(bd) {
expect(1)
val innerJob = launch(ds) {
val innerJob = launch(bd) {
// Do nothing
expect(3)
}
Expand All @@ -215,6 +215,21 @@ class BlockingCoroutineDispatcherTest : SchedulerTestBase() {
finish(5)
}

@Test
fun testUndispatchedYield() = runTest {
expect(1)
corePoolSize = 1
maxPoolSize = 1
val blockingDispatcher = blockingDispatcher(1)
val job = launch(blockingDispatcher, CoroutineStart.UNDISPATCHED) {
expect(2)
yield()
}
expect(3)
job.join()
finish(4)
}

@Test(expected = IllegalArgumentException::class)
fun testNegativeParallelism() {
blockingDispatcher(-1)
Expand Down
Expand Up @@ -21,7 +21,6 @@ class BlockingCoroutineDispatcherThreadLimitStressTest : SchedulerTestBase() {
private val concurrentWorkers = AtomicInteger(0)

@Test
@Ignore
fun testLimitParallelismToOne() = runTest {
val limitingDispatcher = blockingDispatcher(1)
// Do in bursts to avoid OOM
Expand Down
Expand Up @@ -117,6 +117,18 @@ class CoroutineDispatcherTest : SchedulerTestBase() {
finish(5)
}

@Test
fun testUndispatchedYield() = runTest {
expect(1)
val job = launch(dispatcher, CoroutineStart.UNDISPATCHED) {
expect(2)
yield()
}
expect(3)
job.join()
finish(4)
}

@Test
fun testThreadName() = runBlocking {
val initialCount = Thread.getAllStackTraces().keys.asSequence()
Expand Down
Expand Up @@ -10,7 +10,6 @@ import org.junit.Test
import org.junit.runner.*
import org.junit.runners.*
import java.util.*
import java.util.concurrent.*
import kotlin.test.*

@RunWith(Parameterized::class)
Expand Down Expand Up @@ -79,6 +78,10 @@ class CoroutineSchedulerCloseStressTest(private val mode: Mode) : TestBase() {
} else {
if (rnd.nextBoolean()) {
delay(1000)
val t = Thread.currentThread()
if (!t.name.contains("DefaultDispatcher-worker")) {
val a = 2
}
} else {
yield()
}
Expand Down
Expand Up @@ -82,7 +82,7 @@ class CoroutineSchedulerTest : TestBase() {
it.dispatch(Runnable {
expect(2)
finishLatch.countDown()
}, fair = true)
}, tailDispatch = true)
})

startLatch.countDown()
Expand Down

0 comments on commit 4e19954

Please sign in to comment.