Skip to content

Commit

Permalink
Merge pull request #10674 from som-snytt/review/const-fold-warn
Browse files Browse the repository at this point in the history
Warn more constant inexactitude
  • Loading branch information
lrytz committed Jan 26, 2024
2 parents 8d598d1 + 69af723 commit 30cfb3c
Show file tree
Hide file tree
Showing 6 changed files with 102 additions and 25 deletions.
60 changes: 45 additions & 15 deletions src/compiler/scala/tools/nsc/typechecker/ConstantFolder.scala
Expand Up @@ -17,6 +17,7 @@ package typechecker
import java.lang.ArithmeticException

import scala.tools.nsc.Reporting.WarningCategory
import scala.util.control.ControlThrowable

/** This class ...
*
Expand Down Expand Up @@ -60,8 +61,8 @@ abstract class ConstantFolder {
def apply(tree: Tree, site: Symbol): Tree = if (isPastTyper) tree else
try {
tree match {
case Apply(Select(FoldableTerm(x), op), List(FoldableTerm(y))) => fold(tree, foldBinop(op, x, y), foldable = true)
case Apply(Select(ConstantTerm(x), op), List(ConstantTerm(y))) => fold(tree, foldBinop(op, x, y), foldable = false)
case Apply(Select(FoldableTerm(x), op), List(FoldableTerm(y))) => fold(tree, safelyFoldBinop(tree, site)(op, x, y), foldable = true)
case Apply(Select(ConstantTerm(x), op), List(ConstantTerm(y))) => fold(tree, safelyFoldBinop(tree, site)(op, x, y), foldable = false)
case Select(FoldableTerm(x), op) => fold(tree, foldUnop(op, x), foldable = true)
case Select(ConstantTerm(x), op) => fold(tree, foldUnop(op, x), foldable = false)
case _ => tree
Expand All @@ -86,9 +87,11 @@ abstract class ConstantFolder {
/** Set the computed constant type.
*/
private def fold(orig: Tree, folded: Constant, foldable: Boolean): Tree =
if ((folded eq null) || folded.tag == UnitTag) orig
else if (foldable) orig setType FoldableConstantType(folded)
else orig setType LiteralType(folded)
if (folded == null || folded.tag == UnitTag) orig
else orig.setType {
if (foldable) FoldableConstantType(folded)
else LiteralType(folded)
}

private def foldUnop(op: Name, x: Constant): Constant = {
val N = nme
Expand Down Expand Up @@ -127,8 +130,7 @@ abstract class ConstantFolder {
if (value != null) Constant(value) else null
}

/** These are local helpers to keep foldBinop from overly taxing the
* optimizer.
/** These are local helpers to keep foldBinop from overly taxing the optimizer.
*/
private def foldBooleanOp(op: Name, x: Constant, y: Constant): Constant = op match {
case nme.ZOR => Constant(x.booleanValue | y.booleanValue)
Expand All @@ -153,10 +155,18 @@ abstract class ConstantFolder {
case nme.GT => Constant(x.intValue > y.intValue)
case nme.LE => Constant(x.intValue <= y.intValue)
case nme.GE => Constant(x.intValue >= y.intValue)
case nme.ADD => Constant(x.intValue + y.intValue)
case nme.SUB => Constant(x.intValue - y.intValue)
case nme.MUL => Constant(x.intValue * y.intValue)
case nme.DIV => Constant(x.intValue / y.intValue)
case nme.ADD => Constant(safely(Math.addExact(x.intValue, y.intValue), x.intValue + y.intValue))
case nme.SUB => Constant(safely(Math.subtractExact(x.intValue, y.intValue), x.intValue - y.intValue))
case nme.MUL => Constant(safely(Math.multiplyExact(x.intValue, y.intValue), x.intValue * y.intValue))
case nme.DIV =>
val xd = x.intValue
val yd = y.intValue
val value =
if (yd == 0) xd / yd // Math.divideExact(xd, yd) // de-optimize
else if (yd == -1 && xd == Int.MinValue)
safely(throw new ArithmeticException("integer overflow"), xd / yd)
else xd / yd
Constant(value)
case nme.MOD => Constant(x.intValue % y.intValue)
case _ => null
}
Expand All @@ -179,10 +189,18 @@ abstract class ConstantFolder {
case nme.GT => Constant(x.longValue > y.longValue)
case nme.LE => Constant(x.longValue <= y.longValue)
case nme.GE => Constant(x.longValue >= y.longValue)
case nme.ADD => Constant(x.longValue + y.longValue)
case nme.SUB => Constant(x.longValue - y.longValue)
case nme.MUL => Constant(x.longValue * y.longValue)
case nme.DIV => Constant(x.longValue / y.longValue)
case nme.ADD => Constant(safely(Math.addExact(x.longValue, y.longValue), x.longValue + y.longValue))
case nme.SUB => Constant(safely(Math.subtractExact(x.longValue, y.longValue), x.longValue - y.longValue))
case nme.MUL => Constant(safely(Math.multiplyExact(x.longValue, y.longValue), x.longValue * y.longValue))
case nme.DIV =>
val xd = x.longValue
val yd = y.longValue
val value =
if (yd == 0) xd / yd // Math.divideExact(xd, yd) // de-optimize
else if (yd == -1 && xd == Long.MinValue)
safely(throw new ArithmeticException("long overflow"), xd / yd)
else xd / yd
Constant(value)
case nme.MOD => Constant(x.longValue % y.longValue)
case _ => null
}
Expand Down Expand Up @@ -231,4 +249,16 @@ abstract class ConstantFolder {
case _ => null
}
}
private def safelyFoldBinop(tree: Tree, site: Symbol)(op: Name, x: Constant, y: Constant): Constant =
try foldBinop(op, x, y)
catch {
case e: ConstFoldException =>
if (settings.warnConstant)
runReporting.warning(tree.pos, s"Evaluation of a constant expression results in an arithmetic error: ${e.getMessage}, using ${e.value}", WarningCategory.LintConstant, site)
Constant(e.value)
}
private def safely[A](exact: => A, inexact: A): A =
try exact
catch { case e: ArithmeticException => throw new ConstFoldException(e.getMessage, inexact) }
private class ConstFoldException(msg: String, val value: Any) extends ControlThrowable(msg)
}
3 changes: 1 addition & 2 deletions src/library/scala/collection/immutable/RedBlackTree.scala
Expand Up @@ -775,8 +775,7 @@ private[collection] object RedBlackTree {
//see #Tree docs "Colour, mutablity and size encoding"
//we make these final vals because the optimiser inlines them, without reference to the enclosing module
private[RedBlackTree] final val colourBit = 0x80000000
//really its ~colourBit but that doesnt get inlined
private[RedBlackTree] final val colourMask = colourBit - 1
private[RedBlackTree] final val colourMask = ~colourBit
private[RedBlackTree] final val initialBlackCount = colourBit
private[RedBlackTree] final val initialRedCount = 0

Expand Down
37 changes: 35 additions & 2 deletions test/files/neg/constant-warning.check
@@ -1,6 +1,39 @@
constant-warning.scala:3: warning: Evaluation of a constant expression results in an arithmetic error: / by zero
constant-warning.scala:4: warning: Evaluation of a constant expression results in an arithmetic error: / by zero
val fails = 1 + 2 / (3 - 2 - 1)
^
constant-warning.scala:6: warning: Evaluation of a constant expression results in an arithmetic error: integer overflow, using -2147483607
val addi: Int = Int.MaxValue + 42
^
constant-warning.scala:7: warning: Evaluation of a constant expression results in an arithmetic error: integer overflow, using 2147483606
val subi: Int = Int.MinValue - 42
^
constant-warning.scala:8: warning: Evaluation of a constant expression results in an arithmetic error: integer overflow, using -2
val muli: Int = Int.MaxValue * 2
^
constant-warning.scala:9: warning: Evaluation of a constant expression results in an arithmetic error: integer overflow, using -2147483648
val divi: Int = Int.MinValue / -1
^
constant-warning.scala:10: warning: Evaluation of a constant expression results in an arithmetic error: / by zero
val divz: Int = Int.MinValue / 0
^
constant-warning.scala:12: warning: Evaluation of a constant expression results in an arithmetic error: integer overflow, using 0
val long: Long = 100 * 1024 * 1024 * 1024
^
constant-warning.scala:13: warning: Evaluation of a constant expression results in an arithmetic error: long overflow, using -9223372036854775767
val addl: Long = Long.MaxValue + 42
^
constant-warning.scala:14: warning: Evaluation of a constant expression results in an arithmetic error: long overflow, using 9223372036854775766
val subl: Long = Long.MinValue - 42
^
constant-warning.scala:15: warning: Evaluation of a constant expression results in an arithmetic error: long overflow, using -2
val mull: Long = Long.MaxValue * 2
^
constant-warning.scala:16: warning: Evaluation of a constant expression results in an arithmetic error: long overflow, using -9223372036854775808
val divl: Long = Long.MinValue / -1
^
constant-warning.scala:17: warning: Evaluation of a constant expression results in an arithmetic error: / by zero
val divlz: Long = Long.MinValue / 0
^
error: No warnings can be incurred under -Werror.
1 warning
12 warnings
1 error
16 changes: 15 additions & 1 deletion test/files/neg/constant-warning.scala
@@ -1,4 +1,18 @@
// scalac: -Xlint:constant -Xfatal-warnings
//> using options -Werror -Xlint:constant
//-Vprint:cleanup (bytecode test to ensure warnable constants are folded)
object Test {
val fails = 1 + 2 / (3 - 2 - 1)

val addi: Int = Int.MaxValue + 42
val subi: Int = Int.MinValue - 42
val muli: Int = Int.MaxValue * 2
val divi: Int = Int.MinValue / -1
val divz: Int = Int.MinValue / 0

val long: Long = 100 * 1024 * 1024 * 1024
val addl: Long = Long.MaxValue + 42
val subl: Long = Long.MinValue - 42
val mull: Long = Long.MaxValue * 2
val divl: Long = Long.MinValue / -1
val divlz: Long = Long.MinValue / 0
}
9 changes: 5 additions & 4 deletions test/junit/scala/collection/immutable/VectorTest.scala
Expand Up @@ -5,7 +5,7 @@ import org.junit.runner.RunWith
import org.junit.runners.JUnit4
import org.junit.Test

import scala.annotation.unused
import scala.annotation.{nowarn, unused}
import scala.collection.immutable.VectorInline.{WIDTH3, WIDTH4, WIDTH5}
import scala.collection.mutable.{ListBuffer, StringBuilder}
import scala.tools.testkit.AssertUtil.intercept
Expand Down Expand Up @@ -110,10 +110,11 @@ class VectorTest {
@Test
def testBuilderAlignTo2(): Unit = {
val Large = 1 << 20
for (
size <- Seq(0, 1, 31, 1 << 5, 1 << 10, 1 << 15, 1 << 20, 9 << 20, 1 << 25, 9 << 25, 50 << 25, 1 << 30, (1 << 31) - (1 << 26) - 1000);
@nowarn val KrazyKonstant = (1 << 31) - (1 << 26) - 1000
for {
size <- Seq(0, 1, 31, 1 << 5, 1 << 10, 1 << 15, 1 << 20, 9 << 20, 1 << 25, 9 << 25, 50 << 25, 1 << 30, KrazyKonstant)
i <- Seq(0, 1, 5, 123)
) {
} {
// println((i, size))
val v = if (size < Large) Vector.tabulate(size)(_.toString) else Vector.fillSparse(size)("v")
val prefix = Vector.fill(i)("prefix")
Expand Down
2 changes: 1 addition & 1 deletion test/junit/scala/collection/mutable/ArrayBufferTest.scala
Expand Up @@ -412,7 +412,7 @@ class ArrayBufferTest {
assertThrows[Exception](rethrow(resizeUp(0, targetLen)),
_ == s"Array of array-backed collection exceeds VM length limit of $VM_MaxArraySize. Requested length: $targetLen; current length: 0")

checkExceedsMaxInt(Int.MaxValue + 1)
checkExceedsMaxInt(Int.MaxValue + 1: @nowarn)
checkExceedsVMArrayLimit(Int.MaxValue)
checkExceedsVMArrayLimit(Int.MaxValue - 1)
}
Expand Down

0 comments on commit 30cfb3c

Please sign in to comment.