Skip to content

Commit

Permalink
Merge pull request #10448 from dragonfly-ai/2.13.x
Browse files Browse the repository at this point in the history
Prevent ArrayBuilder capacity overflow/infinite looping in ArrayBuilder.ensureSize(size:Int).
  • Loading branch information
som-snytt committed Dec 14, 2023
2 parents 43aa816 + 862435d commit ed8005f
Show file tree
Hide file tree
Showing 7 changed files with 99 additions and 59 deletions.
7 changes: 0 additions & 7 deletions src/library/scala/collection/IterableOnce.scala
Expand Up @@ -283,13 +283,6 @@ object IterableOnce {
case src: Iterable[A] => src.copyToArray[B](xs, start, len)
case src => src.iterator.copyToArray[B](xs, start, len)
}

@inline private[collection] def checkArraySizeWithinVMLimit(size: Int): Unit = {
import scala.runtime.PStatics.VM_MaxArraySize
if (size > VM_MaxArraySize) {
throw new Exception(s"Size of array-backed collection exceeds VM array size limit of ${VM_MaxArraySize}")
}
}
}

/** This implementation trait can be mixed into an `IterableOnce` to get the basic methods that are shared between
Expand Down
55 changes: 29 additions & 26 deletions src/library/scala/collection/mutable/ArrayBuffer.scala
Expand Up @@ -15,11 +15,10 @@ package collection
package mutable

import java.util.Arrays

import scala.annotation.nowarn
import scala.annotation.tailrec
import scala.annotation.{nowarn, tailrec}
import scala.collection.Stepper.EfficientSplit
import scala.collection.generic.DefaultSerializable
import scala.runtime.PStatics.VM_MaxArraySize

/** An implementation of the `Buffer` class using an array to
* represent the assembled sequence internally. Append, update and random
Expand Down Expand Up @@ -70,13 +69,6 @@ class ArrayBuffer[A] private (initialElements: Array[AnyRef], initialSize: Int)
array = ArrayBuffer.ensureSize(array, size0, n)
}

// TODO 3.T: should be `protected`, perhaps `protected[this]`
/** Ensure that the internal array has at least `n` additional cells more than `size0`. */
private[mutable] def ensureAdditionalSize(n: Int): Unit = {
// `.toLong` to ensure `Long` arithmetic is used and prevent `Int` overflow
array = ArrayBuffer.ensureSize(array, size0, size0.toLong + n)
}

/** Uses the given size to resize internal storage, if necessary.
*
* @param size Expected maximum number of elements.
Expand Down Expand Up @@ -147,10 +139,10 @@ class ArrayBuffer[A] private (initialElements: Array[AnyRef], initialSize: Int)

def addOne(elem: A): this.type = {
mutationCount += 1
ensureAdditionalSize(1)
val oldSize = size0
size0 = oldSize + 1
this(oldSize) = elem
val newSize = size0 + 1
ensureSize(newSize)
size0 = newSize
this(size0 - 1) = elem
this
}

Expand All @@ -161,7 +153,7 @@ class ArrayBuffer[A] private (initialElements: Array[AnyRef], initialSize: Int)
val elemsLength = elems.size0
if (elemsLength > 0) {
mutationCount += 1
ensureAdditionalSize(elemsLength)
ensureSize(size0 + elemsLength)
Array.copy(elems.array, 0, array, length, elemsLength)
size0 = length + elemsLength
}
Expand All @@ -173,7 +165,7 @@ class ArrayBuffer[A] private (initialElements: Array[AnyRef], initialSize: Int)
def insert(@deprecatedName("n", "2.13.0") index: Int, elem: A): Unit = {
checkWithinBounds(index, index)
mutationCount += 1
ensureAdditionalSize(1)
ensureSize(size0 + 1)
Array.copy(array, index, array, index + 1, size0 - index)
size0 += 1
this(index) = elem
Expand All @@ -191,7 +183,7 @@ class ArrayBuffer[A] private (initialElements: Array[AnyRef], initialSize: Int)
val elemsLength = elems.size
if (elemsLength > 0) {
mutationCount += 1
ensureAdditionalSize(elemsLength)
ensureSize(size0 + elemsLength)
val len = size0
Array.copy(array, index, array, index + elemsLength, len - index)
// if `elems eq this`, this copy is safe because
Expand Down Expand Up @@ -314,24 +306,35 @@ object ArrayBuffer extends StrictOptimizedSeqFactory[ArrayBuffer] {

def empty[A]: ArrayBuffer[A] = new ArrayBuffer[A]()

@inline private def checkArrayLengthLimit(length: Int, currentLength: Int): Unit =
if (length > VM_MaxArraySize)
throw new Exception(s"Array of array-backed collection exceeds VM length limit of $VM_MaxArraySize. Requested length: $length; current length: $currentLength")
else if (length < 0)
throw new Exception(s"Overflow while resizing array of array-backed collection. Requested length: $length; current length: $currentLength; increase: ${length - currentLength}")

/**
* The increased size for an array-backed collection.
*
* @param arrayLen the length of the backing array
* @param targetLen the minimum length to resize up to
* @return -1 if no resizing is needed, or the size for the new array otherwise
* @return
* - `-1` if no resizing is needed, else
* - `VM_MaxArraySize` if `arrayLen` is too large to be doubled, else
* - `max(targetLen, arrayLen * 2, , DefaultInitialSize)`.
* - Throws an exception if `targetLen` exceeds `VM_MaxArraySize` or is negative (overflow).
*/
private def resizeUp(arrayLen: Long, targetLen: Long): Int = {
if (targetLen <= arrayLen) -1
private[mutable] def resizeUp(arrayLen: Int, targetLen: Int): Int = {
if (targetLen > 0 && targetLen <= arrayLen) -1
else {
if (targetLen > Int.MaxValue) throw new Exception(s"Collections cannot have more than ${Int.MaxValue} elements")
IterableOnce.checkArraySizeWithinVMLimit(targetLen.toInt) // safe because `targetSize <= Int.MaxValue`

val newLen = math.max(targetLen, math.max(arrayLen * 2, DefaultInitialSize))
math.min(newLen, scala.runtime.PStatics.VM_MaxArraySize).toInt
checkArrayLengthLimit(targetLen, arrayLen)
if (arrayLen > VM_MaxArraySize / 2) VM_MaxArraySize
else math.max(targetLen, math.max(arrayLen * 2, DefaultInitialSize))
}
}

// if necessary, copy (curSize elements of) the array to a new array of capacity n.
// Should use Array.copyOf(array, resizeEnsuring(array.length))?
private def ensureSize(array: Array[AnyRef], curSize: Int, targetSize: Long): Array[AnyRef] = {
private def ensureSize(array: Array[AnyRef], curSize: Int, targetSize: Int): Array[AnyRef] = {
val newLen = resizeUp(array.length, targetSize)
if (newLen < 0) array
else {
Expand Down
25 changes: 14 additions & 11 deletions src/library/scala/collection/mutable/ArrayBuilder.scala
Expand Up @@ -13,6 +13,7 @@
package scala.collection
package mutable

import scala.collection.mutable.ArrayBuffer.resizeUp
import scala.reflect.ClassTag

/** A builder class for arrays.
Expand All @@ -34,15 +35,11 @@ sealed abstract class ArrayBuilder[T]
override def knownSize: Int = size

protected[this] final def ensureSize(size: Int): Unit = {
if (capacity < size || capacity == 0) {
var newsize = if (capacity == 0) 16 else capacity * 2
while (newsize < size) newsize *= 2
resize(newsize)
}
val newLen = resizeUp(capacity, size)
if (newLen > 0) resize(newLen)
}

override final def sizeHint(size: Int): Unit =
if (capacity < size) resize(size)
override final def sizeHint(size: Int): Unit = if (capacity < size) resize(size)

def clear(): Unit = size = 0

Expand Down Expand Up @@ -491,17 +488,23 @@ object ArrayBuilder {
protected def elems: Array[Unit] = throw new UnsupportedOperationException()

def addOne(elem: Unit): this.type = {
size += 1
val newSize = size + 1
ensureSize(newSize)
size = newSize
this
}

override def addAll(xs: IterableOnce[Unit]): this.type = {
size += xs.iterator.size
val newSize = size + xs.iterator.size
ensureSize(newSize)
size = newSize
this
}

override def addAll(xs: Array[_ <: Unit], offset: Int, length: Int): this.type = {
size += length
val newSize = size + length
ensureSize(newSize)
size = newSize
this
}

Expand All @@ -517,7 +520,7 @@ object ArrayBuilder {
case _ => false
}

protected[this] def resize(size: Int): Unit = ()
protected[this] def resize(size: Int): Unit = capacity = size

override def toString = "ArrayBuilder.ofUnit"
}
Expand Down
2 changes: 1 addition & 1 deletion src/library/scala/collection/mutable/PriorityQueue.scala
Expand Up @@ -89,7 +89,7 @@ sealed class PriorityQueue[A](implicit val ord: Ordering[A])
def p_size0_=(s: Int) = size0 = s
def p_array = array
def p_ensureSize(n: Int) = super.ensureSize(n)
def p_ensureAdditionalSize(n: Int) = super.ensureAdditionalSize(n)
def p_ensureAdditionalSize(n: Int) = super.ensureSize(size0 + n)
def p_swap(a: Int, b: Int): Unit = {
val h = array(a)
array(a) = array(b)
Expand Down
4 changes: 3 additions & 1 deletion src/library/scala/runtime/PStatics.scala
Expand Up @@ -15,5 +15,7 @@ package scala.runtime
// things that should be in `Statics`, but can't be yet for bincompat reasons
// TODO 3.T: move to `Statics`
private[scala] object PStatics {
final val VM_MaxArraySize = 2147483645 // == `Int.MaxValue - 2`, hotspot limit
// `Int.MaxValue - 8` traditional soft limit to maximize compatibility with diverse JVMs
// See https://stackoverflow.com/a/8381338 for example
final val VM_MaxArraySize = 2147483639
}
30 changes: 17 additions & 13 deletions test/junit/scala/collection/mutable/ArrayBufferTest.scala
@@ -1,10 +1,11 @@
package scala.collection.mutable

import org.junit.Test
import org.junit.Assert.{assertEquals, assertTrue}
import org.junit.Test

import java.lang.reflect.InvocationTargetException
import scala.annotation.nowarn
import scala.runtime.PStatics.VM_MaxArraySize
import scala.tools.testkit.AssertUtil.{assertSameElements, assertThrows, fail}
import scala.tools.testkit.ReflectUtil.{getMethodAccessible, _}
import scala.util.chaining._
Expand Down Expand Up @@ -387,28 +388,31 @@ class ArrayBufferTest {
// scala/bug#7880 and scala/bug#12464
@Test def `ensureSize must terminate and have limits`(): Unit = {
val sut = getMethodAccessible[ArrayBuffer.type]("resizeUp")
def resizeUp(arrayLen: Long, targetLen: Long): Int = sut.invoke(ArrayBuffer, arrayLen, targetLen).asInstanceOf[Int]
def resizeUp(arrayLen: Int, targetLen: Int): Int = sut.invoke(ArrayBuffer, arrayLen, targetLen).asInstanceOf[Int]

// check termination and correctness
assertTrue(7 < ArrayBuffer.DefaultInitialSize) // condition of test
assertTrue(7 < ArrayBuffer.DefaultInitialSize) // condition of test
assertEquals(ArrayBuffer.DefaultInitialSize, resizeUp(7, 10))
assertEquals(Int.MaxValue - 2, resizeUp(Int.MaxValue / 2, Int.MaxValue / 2 + 1)) // was: ok
assertEquals(Int.MaxValue - 2, resizeUp(Int.MaxValue / 2, Int.MaxValue / 2 + 2)) // was: ok
assertEquals(-1, resizeUp(Int.MaxValue / 2 + 1, Int.MaxValue / 2 + 1)) // was: wrong
assertEquals(Int.MaxValue - 2, resizeUp(Int.MaxValue / 2 + 1, Int.MaxValue / 2 + 2)) // was: hang
assertEquals(Int.MaxValue - 2, resizeUp(Int.MaxValue / 2, Int.MaxValue - 2))
assertEquals(VM_MaxArraySize, resizeUp(Int.MaxValue / 2, Int.MaxValue / 2 + 1)) // `MaxValue / 2` cannot be doubled
assertEquals(VM_MaxArraySize / 2 * 2, resizeUp(VM_MaxArraySize / 2, VM_MaxArraySize / 2 + 1)) // `VM_MaxArraySize / 2` can be doubled
assertEquals(VM_MaxArraySize, resizeUp(Int.MaxValue / 2, Int.MaxValue / 2 + 2))
assertEquals(-1, resizeUp(Int.MaxValue / 2 + 1, Int.MaxValue / 2 + 1)) // no resize needed
assertEquals(VM_MaxArraySize, resizeUp(Int.MaxValue / 2 + 1, Int.MaxValue / 2 + 2))
assertEquals(VM_MaxArraySize, resizeUp(Int.MaxValue / 2, VM_MaxArraySize))
assertEquals(123456*2+33, resizeUp(123456, 123456*2+33)) // use targetLen if it's larger than double the current

// check limits
def rethrow(op: => Any): Unit =
try op catch { case e: InvocationTargetException => throw e.getCause }
def checkExceedsMaxInt(targetLen: Long): Unit =
def checkExceedsMaxInt(targetLen: Int): Unit = {
assertThrows[Exception](rethrow(resizeUp(0, targetLen)),
_ == "Collections cannot have more than 2147483647 elements")
def checkExceedsVMArrayLimit(targetLen: Long): Unit =
_ == s"Overflow while resizing array of array-backed collection. Requested length: $targetLen; current length: 0; increase: $targetLen")
}
def checkExceedsVMArrayLimit(targetLen: Int): Unit =
assertThrows[Exception](rethrow(resizeUp(0, targetLen)),
_ == "Size of array-backed collection exceeds VM array size limit of 2147483645")
_ == s"Array of array-backed collection exceeds VM length limit of $VM_MaxArraySize. Requested length: $targetLen; current length: 0")

checkExceedsMaxInt(Int.MaxValue + 1L)
checkExceedsMaxInt(Int.MaxValue + 1)
checkExceedsVMArrayLimit(Int.MaxValue)
checkExceedsVMArrayLimit(Int.MaxValue - 1)
}
Expand Down
35 changes: 35 additions & 0 deletions test/junit/scala/collection/mutable/ArrayBuilderTest.scala
@@ -0,0 +1,35 @@
package scala.collection.mutable

import org.junit.Assert.assertEquals
import org.junit.Test

import scala.runtime.PStatics.VM_MaxArraySize
import scala.tools.testkit.AssertUtil.assertThrows

class ArrayBuilderTest {
@Test
def t12617: Unit = {
val ab: ArrayBuilder[Unit] = ArrayBuilder.make[Unit]

// ArrayBuilder.ofUnit.addAll doesn't iterate if the iterator has a `knownSize`
ab.addAll(new Iterator[Unit] {
override def knownSize: Int = VM_MaxArraySize
def hasNext: Boolean = true
def next(): Unit = ()
})

// reached maximum size without entering an infinite loop?
assertEquals(ab.length, VM_MaxArraySize)

// expect an exception when trying to grow larger than maximum size by addOne
assertThrows[Exception](ab.addOne(()), _.endsWith("Requested length: 2147483640; current length: 2147483639"))

val arr = Array[Unit]((), (), (), (), (), (), (), (), (), (), (), ())

// expect an exception when trying to grow larger than maximum size by addAll(iterator)
assertThrows[Exception](ab.addAll(arr.iterator), _.endsWith("Requested length: -2147483645; current length: 2147483639; increase: 12"))

// expect an exception when trying to grow larger than maximum size by addAll(array)
assertThrows[Exception](ab.addAll(arr))
}
}

0 comments on commit ed8005f

Please sign in to comment.