Skip to content

Commit

Permalink
Fixed memory leak on a race between adding/removing from lock-free li…
Browse files Browse the repository at this point in the history
…st (#1845)

* The problem was introduced by #1565. When doing concurrent add+removeFirst the following can happen:
  - "add" completes, but has not correct prev pointer in next node yet
  - "removeFirst" removes freshly added element
  - "add" performs "finishAdd" that adjust prev pointer of the next node and thus removed element is pointed from the list again
* A separate LockFreeLinkedListAddRemoveStressTest is added that reproduces this problem.
* The old LockFreeLinkedListAtomicLFStressTest is refactored a bit.
  • Loading branch information
elizarov committed Mar 6, 2020
1 parent c67aed0 commit 7df61ee
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 29 deletions.
15 changes: 10 additions & 5 deletions kotlinx-coroutines-core/jvm/src/internal/LockFreeLinkedList.kt
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ public actual open class LockFreeLinkedListNode {
final override fun updatedNext(affected: Node, next: Node): Any = next.removed()

final override fun finishOnSuccess(affected: Node, next: Node) {
// Complete removal operation here. It bails out if next node is also removed and it becomes
// Complete removal operation here. It bails out if next node is also removed. It becomes
// responsibility of the next's removes to call correctPrev which would help fix all the links.
next.correctPrev(null)
}
Expand Down Expand Up @@ -531,7 +531,12 @@ public actual open class LockFreeLinkedListNode {
private fun finishAdd(next: Node) {
next._prev.loop { nextPrev ->
if (this.next !== next) return // this or next was removed or another node added, remover/adder fixes up links
if (next._prev.compareAndSet(nextPrev, this)) return
if (next._prev.compareAndSet(nextPrev, this)) {
// This newly added node could have been removed, and the above CAS would have added it physically again.
// Let us double-check for this situation and correct if needed
if (isRemoved) next.correctPrev(null)
return
}
}
}

Expand All @@ -546,15 +551,15 @@ public actual open class LockFreeLinkedListNode {
* * When this node is removed. In this case there is no need to waste time on corrections, because
* remover of this node will ultimately call [correctPrev] on the next node and that will fix all
* the links from this node, too.
* * When [op] descriptor is not `null` and and operation descriptor that is [OpDescriptor.isEarlierThan]
* * When [op] descriptor is not `null` and operation descriptor that is [OpDescriptor.isEarlierThan]
* that current [op] is found while traversing the list. This `null` result will be translated
* by callers to [RETRY_ATOMIC].
*/
private tailrec fun correctPrev(op: OpDescriptor?): Node? {
val oldPrev = _prev.value
var prev: Node = oldPrev
var last: Node? = null // will be set so that last.next === prev
while (true) { // move the the left until first non-removed node
while (true) { // move the left until first non-removed node
val prevNext: Any = prev._next.value
when {
// fast path to find quickly find prev node when everything is properly linked
Expand All @@ -565,7 +570,7 @@ public actual open class LockFreeLinkedListNode {
// Note: retry from scratch on failure to update prev
return correctPrev(op)
}
return prev // return a correct prev
return prev // return the correct prev
}
// slow path when we need to help remove operations
this.isRemoved -> return null // nothing to do, this node was removed, bail out asap to save time
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

package kotlinx.coroutines.internal

import kotlinx.atomicfu.*
import kotlinx.coroutines.*
import java.util.concurrent.*
import kotlin.concurrent.*
import kotlin.test.*

class LockFreeLinkedListAddRemoveStressTest : TestBase() {
private class Node : LockFreeLinkedListNode()

private val nRepeat = 100_000 * stressTestMultiplier
private val list = LockFreeLinkedListHead()
private val barrier = CyclicBarrier(3)
private val done = atomic(false)
private val removed = atomic(0)

@Test
fun testStressAddRemove() {
val threads = ArrayList<Thread>()
threads += testThread("adder") {
val node = Node()
list.addLast(node)
if (node.remove()) removed.incrementAndGet()
}
threads += testThread("remover") {
val node = list.removeFirstOrNull()
if (node != null) removed.incrementAndGet()
}
try {
for (i in 1..nRepeat) {
barrier.await()
barrier.await()
assertEquals(i, removed.value)
list.validate()
}
} finally {
done.value = true
barrier.await()
threads.forEach { it.join() }
}
}

private fun testThread(name: String, op: () -> Unit) = thread(name = name) {
while (true) {
barrier.await()
if (done.value) break
op()
barrier.await()
}
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2016-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
* Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

package kotlinx.coroutines.internal
Expand All @@ -19,9 +19,9 @@ import kotlin.test.*
class LockFreeLinkedListAtomicLFStressTest {
private val env = LockFreedomTestEnvironment("LockFreeLinkedListAtomicLFStressTest")

data class IntNode(val i: Int) : LockFreeLinkedListNode()
private data class Node(val i: Long) : LockFreeLinkedListNode()

private val TEST_DURATION_SEC = 5 * stressTestMultiplier
private val nSeconds = 5 * stressTestMultiplier

private val nLists = 4
private val nAdderThreads = 4
Expand All @@ -32,7 +32,8 @@ class LockFreeLinkedListAtomicLFStressTest {
private val undone = AtomicLong()
private val missed = AtomicLong()
private val removed = AtomicLong()
val error = AtomicReference<Throwable>()
private val error = AtomicReference<Throwable>()
private val index = AtomicLong()

@Test
fun testStress() {
Expand All @@ -42,23 +43,23 @@ class LockFreeLinkedListAtomicLFStressTest {
when (rnd.nextInt(4)) {
0 -> {
val list = lists[rnd.nextInt(nLists)]
val node = IntNode(threadId)
val node = Node(index.incrementAndGet())
addLastOp(list, node)
randomSpinWaitIntermission()
tryRemoveOp(node)
}
1 -> {
// just to test conditional add
val list = lists[rnd.nextInt(nLists)]
val node = IntNode(threadId)
val node = Node(index.incrementAndGet())
addLastIfTrueOp(list, node)
randomSpinWaitIntermission()
tryRemoveOp(node)
}
2 -> {
// just to test failed conditional add and burn some time
val list = lists[rnd.nextInt(nLists)]
val node = IntNode(threadId)
val node = Node(index.incrementAndGet())
addLastIfFalseOp(list, node)
}
3 -> {
Expand All @@ -68,8 +69,8 @@ class LockFreeLinkedListAtomicLFStressTest {
check(idx1 < idx2) // that is our global order
val list1 = lists[idx1]
val list2 = lists[idx2]
val node1 = IntNode(threadId)
val node2 = IntNode(-threadId - 1)
val node1 = Node(index.incrementAndGet())
val node2 = Node(index.incrementAndGet())
addTwoOp(list1, node1, list2, node2)
randomSpinWaitIntermission()
tryRemoveOp(node1)
Expand All @@ -91,13 +92,13 @@ class LockFreeLinkedListAtomicLFStressTest {
removeTwoOp(list1, list2)
}
}
env.performTest(TEST_DURATION_SEC) {
val _undone = undone.get()
val _missed = missed.get()
val _removed = removed.get()
println(" Adders undone $_undone node additions")
println(" Adders missed $_missed nodes")
println("Remover removed $_removed nodes")
env.performTest(nSeconds) {
val undone = undone.get()
val missed = missed.get()
val removed = removed.get()
println(" Adders undone $undone node additions")
println(" Adders missed $missed nodes")
println("Remover removed $removed nodes")
}
error.get()?.let { throw it }
assertEquals(missed.get(), removed.get())
Expand All @@ -106,19 +107,19 @@ class LockFreeLinkedListAtomicLFStressTest {
lists.forEach { it.validate() }
}

private fun addLastOp(list: LockFreeLinkedListHead, node: IntNode) {
private fun addLastOp(list: LockFreeLinkedListHead, node: Node) {
list.addLast(node)
}

private fun addLastIfTrueOp(list: LockFreeLinkedListHead, node: IntNode) {
assertTrue(list.addLastIf(node, { true }))
private fun addLastIfTrueOp(list: LockFreeLinkedListHead, node: Node) {
assertTrue(list.addLastIf(node) { true })
}

private fun addLastIfFalseOp(list: LockFreeLinkedListHead, node: IntNode) {
assertFalse(list.addLastIf(node, { false }))
private fun addLastIfFalseOp(list: LockFreeLinkedListHead, node: Node) {
assertFalse(list.addLastIf(node) { false })
}

private fun addTwoOp(list1: LockFreeLinkedListHead, node1: IntNode, list2: LockFreeLinkedListHead, node2: IntNode) {
private fun addTwoOp(list1: LockFreeLinkedListHead, node1: Node, list2: LockFreeLinkedListHead, node2: Node) {
val add1 = list1.describeAddLast(node1)
val add2 = list2.describeAddLast(node2)
val op = object : AtomicOp<Any?>() {
Expand All @@ -138,7 +139,7 @@ class LockFreeLinkedListAtomicLFStressTest {
assertTrue(op.perform(null) == null)
}

private fun tryRemoveOp(node: IntNode) {
private fun tryRemoveOp(node: Node) {
if (node.remove())
undone.incrementAndGet()
else
Expand All @@ -165,5 +166,4 @@ class LockFreeLinkedListAtomicLFStressTest {
val success = op.perform(null) == null
if (success) removed.addAndGet(2)
}

}

0 comments on commit 7df61ee

Please sign in to comment.