Skip to content

Commit

Permalink
Merge pull request #9889 from som-snytt/topic/xlint-perf
Browse files Browse the repository at this point in the history
  • Loading branch information
SethTisue committed Mar 11, 2022
2 parents ce02613 + 6654dcf commit 2358bf0
Show file tree
Hide file tree
Showing 21 changed files with 350 additions and 148 deletions.
1 change: 1 addition & 0 deletions src/compiler/scala/tools/nsc/Reporting.scala
Expand Up @@ -404,6 +404,7 @@ object Reporting {
object LintRecurseWithDefault extends Lint; add(LintRecurseWithDefault)
object LintUnitSpecialization extends Lint; add(LintUnitSpecialization)
object LintMultiargInfix extends Lint; add(LintMultiargInfix)
object LintPerformance extends Lint; add(LintPerformance)

sealed trait Feature extends WarningCategory { override def summaryCategory: WarningCategory = Feature }
object Feature extends Feature { override def includes(o: WarningCategory): Boolean = o.isInstanceOf[Feature] }; add(Feature)
Expand Down
19 changes: 17 additions & 2 deletions src/compiler/scala/tools/nsc/settings/Warnings.scala
Expand Up @@ -98,7 +98,7 @@ trait Warnings {
|to prevent the shell from expanding patterns.""".stripMargin),
prepend = true)

// Non-lint warnings. -- TODO turn into MultiChoiceEnumeration
// Non-lint warnings.
val warnMacros = ChoiceSetting(
name = "-Wmacros",
helpArg = "mode",
Expand All @@ -117,6 +117,20 @@ trait Warnings {
val warnNumericWiden = BooleanSetting("-Wnumeric-widen", "Warn when numerics are widened.") withAbbreviation "-Ywarn-numeric-widen"
val warnOctalLiteral = BooleanSetting("-Woctal-literal", "Warn on obsolete octal syntax.") withAbbreviation "-Ywarn-octal-literal"

object PerformanceWarnings extends MultiChoiceEnumeration {
val Captured = Choice("captured", "Modification of var in closure causes boxing.")
val NonlocalReturn = Choice("nonlocal-return", "A return statement used an exception for flow control.")
}
val warnPerformance = MultiChoiceSetting(
name = "-Wperformance",
helpArg = "warning",
descr = "Enable or disable specific lints for performance",
domain = PerformanceWarnings,
default = Some(List("_"))
)
def warnCaptured = warnPerformance.contains(PerformanceWarnings.Captured)
def warnNonlocalReturn = warnPerformance.contains(PerformanceWarnings.NonlocalReturn)

object UnusedWarnings extends MultiChoiceEnumeration {
val Imports = Choice("imports", "Warn if an import selector is not referenced.")
val PatVars = Choice("patvars", "Warn if a variable bound in a pattern is unused.")
Expand Down Expand Up @@ -216,7 +230,6 @@ trait Warnings {
def warnStarsAlign = lint contains StarsAlign
def warnConstant = lint contains Constant
def lintUnused = lint contains Unused
def warnNonlocalReturn = lint contains NonlocalReturn
def lintImplicitNotFound = lint contains ImplicitNotFound
def warnSerialization = lint contains Serial
def lintValPatterns = lint contains ValPattern
Expand All @@ -241,6 +254,8 @@ trait Warnings {
if (s contains Unused) warnUnused.enable(UnusedWarnings.Linted)
else warnUnused.disable(UnusedWarnings.Linted)
if (s.contains(Deprecation) && deprecation.isDefault) deprecation.value = true
if (s.contains(NonlocalReturn)) warnPerformance.enable(PerformanceWarnings.NonlocalReturn)
else warnPerformance.disable(PerformanceWarnings.NonlocalReturn)
}

// Backward compatibility.
Expand Down
35 changes: 17 additions & 18 deletions src/compiler/scala/tools/nsc/transform/LambdaLift.scala
Expand Up @@ -18,6 +18,7 @@ import Flags._
import scala.annotation.tailrec
import scala.collection.mutable
import scala.collection.mutable.{LinkedHashMap, LinkedHashSet}
import scala.tools.nsc.Reporting.WarningCategory.LintPerformance

abstract class LambdaLift extends InfoTransform {
import global._
Expand Down Expand Up @@ -492,25 +493,23 @@ abstract class LambdaLift extends InfoTransform {
if (sym.isLocalToBlock) liftDef(withFreeParams)
else withFreeParams

case ValDef(mods, name, tpt, rhs) =>
if (sym.isCapturedVariable) {
val tpt1 = TypeTree(sym.tpe) setPos tpt.pos

val refTypeSym = sym.tpe.typeSymbol

val factoryCall = typer.typedPos(rhs.pos) {
rhs match {
case EmptyTree =>
val zeroMSym = refZeroMethod(refTypeSym)
gen.mkMethodCall(zeroMSym, Nil)
case arg =>
val createMSym = refCreateMethod(refTypeSym)
gen.mkMethodCall(createMSym, arg :: Nil)
}
case ValDef(mods, name, tpt, rhs) if sym.isCapturedVariable =>
val tpt1 = TypeTree(sym.tpe) setPos tpt.pos
val refTypeSym = sym.tpe.typeSymbol
val factoryCall = typer.typedPos(rhs.pos) {
rhs match {
case EmptyTree =>
val zeroMSym = refZeroMethod(refTypeSym)
gen.mkMethodCall(zeroMSym, Nil)
case arg =>
val createMSym = refCreateMethod(refTypeSym)
gen.mkMethodCall(createMSym, arg :: Nil)
}

treeCopy.ValDef(tree, mods, name, tpt1, factoryCall)
} else tree
}
if (settings.warnCaptured)
runReporting.warning(tree.pos, s"Modification of variable $name within a closure causes it to be boxed.", LintPerformance, sym)
treeCopy.ValDef(tree, mods, name, tpt1, factoryCall)
case ValDef(_, _, _, _) => tree
case Return(Block(stats, value)) =>
Block(stats, treeCopy.Return(tree, value)) setType tree.tpe setPos tree.pos
case Return(expr) =>
Expand Down
43 changes: 11 additions & 32 deletions src/library/scala/StringContext.scala
Expand Up @@ -14,6 +14,7 @@ package scala

import java.lang.{ StringBuilder => JLSBuilder }
import scala.annotation.tailrec
import scala.collection.mutable.ArrayBuilder

/** This class provides the basic mechanism to do String Interpolation.
* String Interpolation allows users
Expand Down Expand Up @@ -219,53 +220,31 @@ object StringContext {
val nameLength = input.length
// The final pattern is as long as all the chunks, separated by 1-character
// glob-wildcard placeholders
val patternLength = {
var n = numWildcards
for(chunk <- patternChunks) {
n += chunk.length
}
n
}
val patternLength = patternChunks.iterator.map(_.length).sum + numWildcards

// Convert the input pattern chunks into a single sequence of shorts; each
// non-negative short represents a character, while -1 represents a glob wildcard
val pattern = {
val arr = new Array[Short](patternLength)
var i = 0
var first = true
for(chunk <- patternChunks) {
if (first) first = false
else {
arr(i) = -1
i += 1
}
for(c <- chunk) {
arr(i) = c.toShort
i += 1
}
}
arr
val b = new ArrayBuilder.ofShort ; b.sizeHint(patternLength)
patternChunks.head.foreach(c => b.addOne(c.toShort))
patternChunks.tail.foreach { s => b.addOne(-1) ; s.foreach(c => b.addOne(c.toShort)) }
b.result()
}

// Lookup table for each character in the pattern to check whether or not
// it refers to a glob wildcard; a non-negative integer indicates which
// glob wildcard it represents, while -1 means it doesn't represent any
val matchIndices = {
val arr = Array.fill(patternLength + 1)(-1)
var i = 0
var j = 0
for(chunk <- patternChunks) {
if (j < numWildcards) {
i += chunk.length
arr(i) = j
i += 1
j += 1
}
patternChunks.init.zipWithIndex.foldLeft(0) { case (ttl, (chunk, i)) =>
val sum = ttl + chunk.length
arr(sum) = i
sum + 1
}
arr
}

while(patternIndex < patternLength || inputIndex < nameLength) {
while (patternIndex < patternLength || inputIndex < nameLength) {
matchIndices(patternIndex) match {
case -1 => // do nothing
case n =>
Expand Down
35 changes: 14 additions & 21 deletions src/library/scala/collection/ArrayOps.scala
Expand Up @@ -173,6 +173,9 @@ object ArrayOps {
* an implementation that copies the data to a boxed representation for use with `Arrays.sort`.
*/
private final val MaxStableSortLength = 300

/** Avoid an allocation in [[collect]]. */
private val fallback: Any => Any = _ => fallback
}

/** This class serves as a wrapper for `Array`s with many of the operations found in
Expand Down Expand Up @@ -1010,36 +1013,26 @@ final class ArrayOps[A](private val xs: Array[A]) extends AnyVal {
* `pf` to each element on which it is defined and collecting the results.
* The order of the elements is preserved.
*/
def collect[B : ClassTag](pf: PartialFunction[A, B]): Array[B] = {
var i = 0
var matched = true
def d(x: A): B = {
matched = false
null.asInstanceOf[B]
}
def collect[B: ClassTag](pf: PartialFunction[A, B]): Array[B] = {
val fallback: Any => Any = ArrayOps.fallback
val b = ArrayBuilder.make[B]
while(i < xs.length) {
matched = true
val v = pf.applyOrElse(xs(i), d)
if(matched) b += v
var i = 0
while (i < xs.length) {
val v = pf.applyOrElse(xs(i), fallback)
if (v.asInstanceOf[AnyRef] ne fallback) b.addOne(v.asInstanceOf[B])
i += 1
}
b.result()
}

/** Finds the first element of the array for which the given partial function is defined, and applies the
* partial function to it. */
def collectFirst[B](f: PartialFunction[A, B]): Option[B] = {
def collectFirst[B](@deprecatedName("f","2.13.9") pf: PartialFunction[A, B]): Option[B] = {
val fallback: Any => Any = ArrayOps.fallback
var i = 0
var matched = true
def d(x: A): B = {
matched = false
null.asInstanceOf[B]
}
while(i < xs.length) {
matched = true
val v = f.applyOrElse(xs(i), d)
if(matched) return Some(v)
while (i < xs.length) {
val v = pf.applyOrElse(xs(i), fallback)
if (v.asInstanceOf[AnyRef] ne fallback) return Some(v.asInstanceOf[B])
i += 1
}
None
Expand Down
25 changes: 15 additions & 10 deletions src/library/scala/collection/Iterable.scala
Expand Up @@ -597,11 +597,13 @@ trait IterableOps[+A, +CC[_], +C] extends Any with IterableOnce[A] with Iterable
val bldr = m.getOrElseUpdate(k, iterableFactory.newBuilder[B])
bldr += f(elem)
}
var result = immutable.Map.empty[K, CC[B]]
m.foreach { case (k, v) =>
result = result + ((k, v.result()))
object result extends Function[(K, Builder[B, CC[B]]), Unit] {
var built = immutable.Map.empty[K, CC[B]]
def apply(kv: (K, Builder[B, CC[B]])) =
built = built.updated(kv._1, kv._2.result())
}
result
m.foreach(result)
result.built
}

/**
Expand Down Expand Up @@ -663,13 +665,16 @@ trait IterableOps[+A, +CC[_], +C] extends Any with IterableOnce[A] with Iterable
* @return collection with intermediate results
*/
def scanRight[B](z: B)(op: (A, B) => B): CC[B] = {
var scanned = z :: immutable.Nil
var acc = z
for (x <- reversed) {
acc = op(x, acc)
scanned ::= acc
object scanner extends (A => Unit) {
var acc = z
var scanned = acc :: immutable.Nil
def apply(x: A) = {
acc = op(x, acc)
scanned ::= acc
}
}
iterableFactory.from(scanned)
reversed.foreach(scanner)
iterableFactory.from(scanner.scanned)
}

def map[B](f: A => B): CC[B] = iterableFactory.from(new View.Map(this, f))
Expand Down
53 changes: 29 additions & 24 deletions src/library/scala/collection/IterableOnce.scala
Expand Up @@ -508,8 +508,11 @@ trait IterableOnceOps[+A, +CC[_], +C] extends Any { this: IterableOnce[A] =>
* elements of this $coll, and the other elements.
*/
def splitAt(n: Int): (C, C) = {
var i = 0
span { _ => if (i < n) { i += 1; true } else false }
object spanner extends (A => Boolean) {
var i = 0
def apply(a: A) = i < n && { i += 1 ; true }
}
span(spanner)
}

/** Applies a side-effecting function to each element in this collection.
Expand Down Expand Up @@ -984,19 +987,20 @@ trait IterableOnceOps[+A, +CC[_], +C] extends Any { this: IterableOnce[A] =>
if (isEmpty)
throw new UnsupportedOperationException("empty.maxBy")

var maxF: B = null.asInstanceOf[B]
var maxElem: A = null.asInstanceOf[A]
var first = true

for (elem <- this) {
val fx = f(elem)
if (first || cmp.gt(fx, maxF)) {
maxElem = elem
maxF = fx
first = false
object maximizer extends (A => Unit) {
var maxF: B = null.asInstanceOf[B]
var maxElem: A = null.asInstanceOf[A]
var first = true
def apply(elem: A) = {
val fx = f(elem)
if (first && { first = false ; true } || cmp.gt(fx, maxF)) {
maxElem = elem
maxF = fx
}
}
}
maxElem
foreach(maximizer)
maximizer.maxElem
}

/** Finds the first element which yields the largest value measured by function f.
Expand Down Expand Up @@ -1031,19 +1035,20 @@ trait IterableOnceOps[+A, +CC[_], +C] extends Any { this: IterableOnce[A] =>
if (isEmpty)
throw new UnsupportedOperationException("empty.minBy")

var minF: B = null.asInstanceOf[B]
var minElem: A = null.asInstanceOf[A]
var first = true

for (elem <- this) {
val fx = f(elem)
if (first || cmp.lt(fx, minF)) {
minElem = elem
minF = fx
first = false
object minimizer extends (A => Unit) {
var minF: B = null.asInstanceOf[B]
var minElem: A = null.asInstanceOf[A]
var first = true
def apply(elem: A) = {
val fx = f(elem)
if (first && { first = false ; true } || cmp.lt(fx, minF)) {
minElem = elem
minF = fx
}
}
}
minElem
foreach(minimizer)
minimizer.minElem
}

/** Finds the first element which yields the smallest value measured by function f.
Expand Down

0 comments on commit 2358bf0

Please sign in to comment.