Skip to content

Commit

Permalink
Merge pull request #10640 from lrytz/t12921-backport
Browse files Browse the repository at this point in the history
[backport] Fix RedBlackTree.doFrom / doTo / doUntil
  • Loading branch information
SethTisue committed Jan 18, 2024
2 parents c3cd9a7 + c8b7f62 commit e8ace4b
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 5 deletions.
23 changes: 19 additions & 4 deletions src/library/scala/collection/immutable/RedBlackTree.scala
Expand Up @@ -25,6 +25,21 @@ import scala.annotation.tailrec
* optimizations behind a reasonably clean API.
*/
private[collection] object NewRedBlackTree {
def validate[A](tree: Tree[A, _])(implicit ordering: Ordering[A]): tree.type = {
def impl(tree: Tree[A, _], keyProp: A => Boolean): Int = {
assert(keyProp(tree.key), s"key check failed: $tree")
if (tree.isRed) {
assert(tree.left == null || tree.left.isBlack, s"red-red left $tree")
assert(tree.right == null || tree.right.isBlack, s"red-red right $tree")
}
val leftBlacks = if (tree.left == null) 0 else impl(tree.left, k => keyProp(k) && ordering.compare(k, tree.key) < 0)
val rightBlacks = if (tree.right == null) 0 else impl(tree.right, k => keyProp(k) && ordering.compare(k, tree.key) > 0)
assert(leftBlacks == rightBlacks, s"not balanced: $tree")
leftBlacks + (if (tree.isBlack) 1 else 0)
}
if (tree != null) impl(tree, _ => true)
tree
}

def isEmpty(tree: Tree[_, _]): Boolean = tree eq null

Expand Down Expand Up @@ -447,23 +462,23 @@ private[collection] object NewRedBlackTree {
if (ordering.lt(tree.key, from)) return doFrom(tree.right, from)
val newLeft = doFrom(tree.left, from)
if (newLeft eq tree.left) tree
else if (newLeft eq null) upd(tree.right, tree.key, tree.value, overwrite = false)
else if (newLeft eq null) maybeBlacken(upd(tree.right, tree.key, tree.value, overwrite = false))
else join(newLeft, tree.key, tree.value, tree.right)
}
private[this] def doTo[A, B](tree: Tree[A, B], to: A)(implicit ordering: Ordering[A]): Tree[A, B] = {
if (tree eq null) return null
if (ordering.lt(to, tree.key)) return doTo(tree.left, to)
val newRight = doTo(tree.right, to)
if (newRight eq tree.right) tree
else if (newRight eq null) upd(tree.left, tree.key, tree.value, overwrite = false)
else join (tree.left, tree.key, tree.value, newRight)
else if (newRight eq null) maybeBlacken(upd(tree.left, tree.key, tree.value, overwrite = false))
else join(tree.left, tree.key, tree.value, newRight)
}
private[this] def doUntil[A, B](tree: Tree[A, B], until: A)(implicit ordering: Ordering[A]): Tree[A, B] = {
if (tree eq null) return null
if (ordering.lteq(until, tree.key)) return doUntil(tree.left, until)
val newRight = doUntil(tree.right, until)
if (newRight eq tree.right) tree
else if (newRight eq null) upd(tree.left, tree.key, tree.value, overwrite = false)
else if (newRight eq null) maybeBlacken(upd(tree.left, tree.key, tree.value, overwrite = false))
else join(tree.left, tree.key, tree.value, newRight)
}

Expand Down
84 changes: 84 additions & 0 deletions test/junit/scala/collection/immutable/SortedSetTest.scala
@@ -1,8 +1,10 @@
package scala.collection.immutable

import org.junit.Assert.assertEquals
import org.junit.Test

import scala.tools.testing.AllocationTest
import scala.tools.testing.AssertUtil.assertThrows


class SortedSetTest extends AllocationTest{
Expand All @@ -23,4 +25,86 @@ class SortedSetTest extends AllocationTest{
val ord = Ordering[String]
exactAllocates(168)(SortedSet("a", "b")(ord))
}

@Test def redBlackValidate(): Unit = {
import NewRedBlackTree._
def redLeaf(x: Int) = RedTree(x, null, null, null)
def blackLeaf(x: Int) = BlackTree(x, null, null, null)

validate(redLeaf(1))
validate(blackLeaf(1))
assertThrows[AssertionError](validate(RedTree(2, null, redLeaf(1), null)), _.contains("red-red"))
assertThrows[AssertionError](validate(RedTree(2, null, blackLeaf(1), null)), _.contains("not balanced"))
validate(RedTree(2, null, blackLeaf(1), blackLeaf(3)))
validate(BlackTree(2, null, blackLeaf(1), blackLeaf(3)))
assertThrows[AssertionError](validate(RedTree(4, null, blackLeaf(1), blackLeaf(3))), _.contains("key check"))
}

@Test def t12921(): Unit = {
val s1 = TreeSet(6, 1, 11, 9, 10, 8)
NewRedBlackTree.validate(s1.tree)

val s2 = s1.from(2)
NewRedBlackTree.validate(s2.tree)
assertEquals(Set(6, 8, 9, 10, 11), s2)

val s3 = s2 ++ Seq(7,3,5)
NewRedBlackTree.validate(s3.tree)
assertEquals(Set(3, 5, 6, 7, 8, 9, 10, 11), s3)

val s4 = s3.from(4)
NewRedBlackTree.validate(s4.tree)
assertEquals(Set(5, 6, 7, 8, 9, 10, 11), s4)
}

@Test def t12921b(): Unit = {
import NewRedBlackTree._
val t = BlackTree(
5,
null,
BlackTree(
3,
null,
RedTree(1, null, null, null),
RedTree(4, null, null, null)
),
BlackTree(7, null, RedTree(6, null, null, null), null)
)
validate(t)
validate(from(t, 2))
}

@Test def t12921c(): Unit = {
import NewRedBlackTree._
val t = BlackTree(
8,
null,
BlackTree(4, null, null, RedTree(6, null, null, null)),
BlackTree(
12,
null,
RedTree(10, null, null, null),
RedTree(14, null, null, null)
)
)
validate(t)
validate(to(t, 13))
}

@Test def t12921d(): Unit = {
import NewRedBlackTree._
val t = BlackTree(
8,
null,
BlackTree(4, null, null, RedTree(6, null, null, null)),
BlackTree(
12,
null,
RedTree(10, null, null, null),
RedTree(14, null, null, null)
)
)
validate(t)
validate(until(t, 13))
}
}
2 changes: 1 addition & 1 deletion test/scalacheck/redblacktree.scala
Expand Up @@ -65,7 +65,7 @@ abstract class RedBlackTreeTest extends Properties("RedBlackTree") {
def genInput: Gen[(Tree[String, Int], ModifyParm, Tree[String, Int])] = for {
tree <- genTree
parm <- genParm(tree)
} yield (tree, parm, modify(tree, parm))
} yield (tree, parm, validate(modify(tree, parm)))
}

trait RedBlackTreeInvariants {
Expand Down

0 comments on commit e8ace4b

Please sign in to comment.