Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More efficient ArraySeq iteration #8300

Merged
merged 1 commit into from Aug 26, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
26 changes: 26 additions & 0 deletions build.sbt
Expand Up @@ -95,6 +95,32 @@ val mimaFilterSettings = Seq(
ProblemFilters.exclude[DirectMissingMethodProblem]("scala.reflect.api.TypeTags.TypeTagImpl"),
ProblemFilters.exclude[DirectMissingMethodProblem]("scala.reflect.api.Universe.TypeTagImpl"),
ProblemFilters.exclude[MissingClassProblem]("scala.reflect.macros.Attachments$"),
ProblemFilters.exclude[DirectAbstractMethodProblem]("scala.collection.immutable.ArraySeq.stepper"),
ProblemFilters.exclude[ReversedAbstractMethodProblem]("scala.collection.immutable.ArraySeq.stepper"),
ProblemFilters.exclude[DirectAbstractMethodProblem]("scala.collection.mutable.ArraySeq.stepper"),
ProblemFilters.exclude[ReversedAbstractMethodProblem]("scala.collection.mutable.ArraySeq.stepper"),
ProblemFilters.exclude[FinalClassProblem]("scala.collection.ArrayOps$GroupedIterator"),
ProblemFilters.exclude[MissingTypesProblem]("scala.collection.ArrayOps$ArrayIterator$mcB$sp"),
ProblemFilters.exclude[MissingTypesProblem]("scala.collection.ArrayOps$ArrayIterator$mcZ$sp"),
ProblemFilters.exclude[MissingTypesProblem]("scala.collection.ArrayOps$ReverseIterator$mcV$sp"),
ProblemFilters.exclude[MissingTypesProblem]("scala.collection.ArrayOps$ReverseIterator$mcD$sp"),
ProblemFilters.exclude[MissingTypesProblem]("scala.collection.ArrayOps$ArrayIterator$mcJ$sp"),
ProblemFilters.exclude[MissingTypesProblem]("scala.collection.ArrayOps$ArrayIterator$mcV$sp"),
ProblemFilters.exclude[MissingTypesProblem]("scala.collection.ArrayOps$ReverseIterator$mcB$sp"),
ProblemFilters.exclude[MissingTypesProblem]("scala.collection.ArrayOps$GroupedIterator"),
ProblemFilters.exclude[MissingTypesProblem]("scala.collection.ArrayOps$ArrayIterator"),
ProblemFilters.exclude[MissingTypesProblem]("scala.collection.ArrayOps$ReverseIterator$mcF$sp"),
ProblemFilters.exclude[MissingTypesProblem]("scala.collection.ArrayOps$ReverseIterator$mcC$sp"),
ProblemFilters.exclude[MissingTypesProblem]("scala.collection.ArrayOps$ArrayIterator$mcF$sp"),
ProblemFilters.exclude[MissingTypesProblem]("scala.collection.ArrayOps$ReverseIterator"),
ProblemFilters.exclude[MissingTypesProblem]("scala.collection.ArrayOps$ReverseIterator$mcS$sp"),
ProblemFilters.exclude[MissingTypesProblem]("scala.collection.ArrayOps$ReverseIterator$mcI$sp"),
ProblemFilters.exclude[MissingTypesProblem]("scala.collection.ArrayOps$ArrayIterator$mcC$sp"),
ProblemFilters.exclude[MissingTypesProblem]("scala.collection.ArrayOps$ReverseIterator$mcJ$sp"),
ProblemFilters.exclude[MissingTypesProblem]("scala.collection.ArrayOps$ArrayIterator$mcD$sp"),
ProblemFilters.exclude[MissingTypesProblem]("scala.collection.ArrayOps$ReverseIterator$mcZ$sp"),
ProblemFilters.exclude[MissingTypesProblem]("scala.collection.ArrayOps$ArrayIterator$mcI$sp"),
ProblemFilters.exclude[MissingTypesProblem]("scala.collection.ArrayOps$ArrayIterator$mcS$sp"),
),
)

Expand Down
10 changes: 7 additions & 3 deletions src/library/scala/collection/ArrayOps.scala
Expand Up @@ -119,9 +119,11 @@ object ArrayOps {
def withFilter(q: A => Boolean): WithFilter[A] = new WithFilter[A](a => p(a) && q(a), xs)
}

private final class ArrayIterator[@specialized(Specializable.Everything) A](xs: Array[A]) extends AbstractIterator[A] {
@SerialVersionUID(3L)
private[collection] final class ArrayIterator[@specialized(Specializable.Everything) A](xs: Array[A]) extends AbstractIterator[A] with Serializable {
private[this] var pos = 0
private[this] val len = xs.length
override def knownSize = len - pos
def hasNext: Boolean = pos < len
def next(): A = try {
val r = xs(pos)
Expand All @@ -134,7 +136,8 @@ object ArrayOps {
}
}

private final class ReverseIterator[@specialized(Specializable.Everything) A](xs: Array[A]) extends AbstractIterator[A] {
@SerialVersionUID(3L)
private final class ReverseIterator[@specialized(Specializable.Everything) A](xs: Array[A]) extends AbstractIterator[A] with Serializable {
private[this] var pos = xs.length-1
def hasNext: Boolean = pos >= 0
def next(): A = try {
Expand All @@ -149,7 +152,8 @@ object ArrayOps {
}
}

private class GroupedIterator[A](xs: Array[A], groupSize: Int) extends AbstractIterator[Array[A]] {
@SerialVersionUID(3L)
private final class GroupedIterator[A](xs: Array[A], groupSize: Int) extends AbstractIterator[Array[A]] with Serializable {
private[this] var pos = 0
def hasNext: Boolean = pos < xs.length
def next(): Array[A] = {
Expand Down
83 changes: 56 additions & 27 deletions src/library/scala/collection/immutable/ArraySeq.scala
Expand Up @@ -18,6 +18,7 @@ import java.util.Arrays
import scala.annotation.unchecked.uncheckedVariance
import scala.collection.Stepper.EfficientSplit
import scala.collection.mutable.{ArrayBuffer, ArrayBuilder, Builder, ArraySeq => MutableArraySeq}
import scala.collection.convert.impl._
import scala.reflect.ClassTag
import scala.runtime.ScalaRunTime
import scala.util.Sorting
Expand Down Expand Up @@ -55,33 +56,7 @@ sealed abstract class ArraySeq[+A]
protected def evidenceIterableFactory: ArraySeq.type = ArraySeq
protected def iterableEvidence: ClassTag[A @uncheckedVariance] = elemTag.asInstanceOf[ClassTag[A]]

override def stepper[S <: Stepper[_]](implicit shape: StepperShape[A, S]): S with EfficientSplit = {
import scala.collection.convert.impl._
val isRefShape = shape.shape == StepperShape.ReferenceShape
val s = if (isRefShape) unsafeArray match {
case a: Array[Int] => AnyStepper.ofParIntStepper (new IntArrayStepper(a, 0, a.length))
case a: Array[Long] => AnyStepper.ofParLongStepper (new LongArrayStepper(a, 0, a.length))
case a: Array[Double] => AnyStepper.ofParDoubleStepper(new DoubleArrayStepper(a, 0, a.length))
case a: Array[Byte] => AnyStepper.ofParIntStepper (new WidenedByteArrayStepper(a, 0, a.length))
case a: Array[Short] => AnyStepper.ofParIntStepper (new WidenedShortArrayStepper(a, 0, a.length))
case a: Array[Char] => AnyStepper.ofParIntStepper (new WidenedCharArrayStepper(a, 0, a.length))
case a: Array[Float] => AnyStepper.ofParDoubleStepper(new WidenedFloatArrayStepper(a, 0, a.length))
case a: Array[Boolean] => new BoxedBooleanArrayStepper(a, 0, a.length)
case a: Array[AnyRef] => new ObjectArrayStepper(a, 0, a.length)
} else {
unsafeArray match {
case a: Array[AnyRef] => shape.parUnbox(new ObjectArrayStepper(a, 0, a.length).asInstanceOf[AnyStepper[A] with EfficientSplit])
case a: Array[Int] => new IntArrayStepper(a, 0, a.length)
case a: Array[Long] => new LongArrayStepper(a, 0, a.length)
case a: Array[Double] => new DoubleArrayStepper(a, 0, a.length)
case a: Array[Byte] => new WidenedByteArrayStepper(a, 0, a.length)
case a: Array[Short] => new WidenedShortArrayStepper(a, 0, a.length)
case a: Array[Char] => new WidenedCharArrayStepper(a, 0, a.length)
case a: Array[Float] => new WidenedFloatArrayStepper(a, 0, a.length)
}
}
s.asInstanceOf[S with EfficientSplit]
}
def stepper[S <: Stepper[_]](implicit shape: StepperShape[A, S]): S with EfficientSplit

@throws[ArrayIndexOutOfBoundsException]
def apply(i: Int): A
Expand Down Expand Up @@ -276,6 +251,12 @@ object ArraySeq extends StrictOptimizedClassTagSeqFactory[ArraySeq] { self =>
new ArraySeq.ofRef(a)
}
}
override def iterator: Iterator[T] = new ArrayOps.ArrayIterator[T](unsafeArray)
override def stepper[S <: Stepper[_]](implicit shape: StepperShape[T, S]): S with EfficientSplit = (
if(shape.shape == StepperShape.ReferenceShape)
new ObjectArrayStepper(unsafeArray, 0, unsafeArray.length)
else shape.parUnbox(new ObjectArrayStepper(unsafeArray, 0, unsafeArray.length).asInstanceOf[AnyStepper[T] with EfficientSplit])
).asInstanceOf[S with EfficientSplit]
}

@SerialVersionUID(3L)
Expand All @@ -296,6 +277,12 @@ object ArraySeq extends StrictOptimizedClassTagSeqFactory[ArraySeq] { self =>
Arrays.sort(a)
new ArraySeq.ofByte(a)
} else super.sorted[B]
override def iterator: Iterator[Byte] = new ArrayOps.ArrayIterator[Byte](unsafeArray)
override def stepper[S <: Stepper[_]](implicit shape: StepperShape[Byte, S]): S with EfficientSplit = (
if(shape.shape == StepperShape.ReferenceShape)
AnyStepper.ofParIntStepper(new WidenedByteArrayStepper(unsafeArray, 0, unsafeArray.length))
else new WidenedByteArrayStepper(unsafeArray, 0, unsafeArray.length)
).asInstanceOf[S with EfficientSplit]
}

@SerialVersionUID(3L)
Expand All @@ -316,6 +303,12 @@ object ArraySeq extends StrictOptimizedClassTagSeqFactory[ArraySeq] { self =>
Arrays.sort(a)
new ArraySeq.ofShort(a)
} else super.sorted[B]
override def iterator: Iterator[Short] = new ArrayOps.ArrayIterator[Short](unsafeArray)
override def stepper[S <: Stepper[_]](implicit shape: StepperShape[Short, S]): S with EfficientSplit = (
if(shape.shape == StepperShape.ReferenceShape)
AnyStepper.ofParIntStepper(new WidenedShortArrayStepper(unsafeArray, 0, unsafeArray.length))
else new WidenedShortArrayStepper(unsafeArray, 0, unsafeArray.length)
).asInstanceOf[S with EfficientSplit]
}

@SerialVersionUID(3L)
Expand All @@ -336,6 +329,12 @@ object ArraySeq extends StrictOptimizedClassTagSeqFactory[ArraySeq] { self =>
Arrays.sort(a)
new ArraySeq.ofChar(a)
} else super.sorted[B]
override def iterator: Iterator[Char] = new ArrayOps.ArrayIterator[Char](unsafeArray)
override def stepper[S <: Stepper[_]](implicit shape: StepperShape[Char, S]): S with EfficientSplit = (
if(shape.shape == StepperShape.ReferenceShape)
AnyStepper.ofParIntStepper(new WidenedCharArrayStepper(unsafeArray, 0, unsafeArray.length))
else new WidenedCharArrayStepper(unsafeArray, 0, unsafeArray.length)
).asInstanceOf[S with EfficientSplit]

override def addString(sb: StringBuilder, start: String, sep: String, end: String): StringBuilder =
(new MutableArraySeq.ofChar(unsafeArray)).addString(sb, start, sep, end)
Expand All @@ -359,6 +358,12 @@ object ArraySeq extends StrictOptimizedClassTagSeqFactory[ArraySeq] { self =>
Arrays.sort(a)
new ArraySeq.ofInt(a)
} else super.sorted[B]
override def iterator: Iterator[Int] = new ArrayOps.ArrayIterator[Int](unsafeArray)
override def stepper[S <: Stepper[_]](implicit shape: StepperShape[Int, S]): S with EfficientSplit = (
if(shape.shape == StepperShape.ReferenceShape)
AnyStepper.ofParIntStepper(new IntArrayStepper(unsafeArray, 0, unsafeArray.length))
else new IntArrayStepper(unsafeArray, 0, unsafeArray.length)
).asInstanceOf[S with EfficientSplit]
}

@SerialVersionUID(3L)
Expand All @@ -379,6 +384,12 @@ object ArraySeq extends StrictOptimizedClassTagSeqFactory[ArraySeq] { self =>
Arrays.sort(a)
new ArraySeq.ofLong(a)
} else super.sorted[B]
override def iterator: Iterator[Long] = new ArrayOps.ArrayIterator[Long](unsafeArray)
override def stepper[S <: Stepper[_]](implicit shape: StepperShape[Long, S]): S with EfficientSplit = (
if(shape.shape == StepperShape.ReferenceShape)
AnyStepper.ofParLongStepper(new LongArrayStepper(unsafeArray, 0, unsafeArray.length))
else new LongArrayStepper(unsafeArray, 0, unsafeArray.length)
).asInstanceOf[S with EfficientSplit]
}

@SerialVersionUID(3L)
Expand All @@ -392,6 +403,12 @@ object ArraySeq extends StrictOptimizedClassTagSeqFactory[ArraySeq] { self =>
case that: ofFloat => Arrays.equals(unsafeArray, that.unsafeArray)
case _ => super.equals(that)
}
override def iterator: Iterator[Float] = new ArrayOps.ArrayIterator[Float](unsafeArray)
override def stepper[S <: Stepper[_]](implicit shape: StepperShape[Float, S]): S with EfficientSplit = (
if(shape.shape == StepperShape.ReferenceShape)
AnyStepper.ofParDoubleStepper(new WidenedFloatArrayStepper(unsafeArray, 0, unsafeArray.length))
else new WidenedFloatArrayStepper(unsafeArray, 0, unsafeArray.length)
).asInstanceOf[S with EfficientSplit]
}

@SerialVersionUID(3L)
Expand All @@ -405,6 +422,12 @@ object ArraySeq extends StrictOptimizedClassTagSeqFactory[ArraySeq] { self =>
case that: ofDouble => Arrays.equals(unsafeArray, that.unsafeArray)
case _ => super.equals(that)
}
override def iterator: Iterator[Double] = new ArrayOps.ArrayIterator[Double](unsafeArray)
override def stepper[S <: Stepper[_]](implicit shape: StepperShape[Double, S]): S with EfficientSplit = (
if(shape.shape == StepperShape.ReferenceShape)
AnyStepper.ofParDoubleStepper(new DoubleArrayStepper(unsafeArray, 0, unsafeArray.length))
else new DoubleArrayStepper(unsafeArray, 0, unsafeArray.length)
).asInstanceOf[S with EfficientSplit]
}

@SerialVersionUID(3L)
Expand All @@ -425,6 +448,9 @@ object ArraySeq extends StrictOptimizedClassTagSeqFactory[ArraySeq] { self =>
Sorting.stableSort(a)
new ArraySeq.ofBoolean(a)
} else super.sorted[B]
override def iterator: Iterator[Boolean] = new ArrayOps.ArrayIterator[Boolean](unsafeArray)
override def stepper[S <: Stepper[_]](implicit shape: StepperShape[Boolean, S]): S with EfficientSplit =
new BoxedBooleanArrayStepper(unsafeArray, 0, unsafeArray.length).asInstanceOf[S with EfficientSplit]
}

@SerialVersionUID(3L)
Expand All @@ -438,5 +464,8 @@ object ArraySeq extends StrictOptimizedClassTagSeqFactory[ArraySeq] { self =>
case that: ofUnit => unsafeArray.length == that.unsafeArray.length
case _ => super.equals(that)
}
override def iterator: Iterator[Unit] = new ArrayOps.ArrayIterator[Unit](unsafeArray)
override def stepper[S <: Stepper[_]](implicit shape: StepperShape[Unit, S]): S with EfficientSplit =
new ObjectArrayStepper[AnyRef](unsafeArray.asInstanceOf[Array[AnyRef]], 0, unsafeArray.length).asInstanceOf[S with EfficientSplit]
}
}