Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tailcalls is more selective when reporting self-recursion #10723

Draft
wants to merge 2 commits into
base: 2.13.x
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions src/compiler/scala/tools/nsc/settings/Warnings.scala
Expand Up @@ -121,6 +121,7 @@ trait Warnings {
val warnValueDiscard = BooleanSetting("-Wvalue-discard", "Warn when non-Unit expression results are unused.") withAbbreviation "-Ywarn-value-discard"
val warnNumericWiden = BooleanSetting("-Wnumeric-widen", "Warn when numerics are widened.") withAbbreviation "-Ywarn-numeric-widen"
val warnOctalLiteral = BooleanSetting("-Woctal-literal", "Warn on obsolete octal syntax.") withAbbreviation "-Ywarn-octal-literal"
val strictTailRec = BooleanSetting("-Wstrict-tailrec", "@tailrec warns on calls in functions or by-name args.")

object PerformanceWarnings extends MultiChoiceEnumeration {
val Captured = Choice("captured", "Modification of var in closure causes boxing.")
Expand Down
127 changes: 76 additions & 51 deletions src/compiler/scala/tools/nsc/transform/TailCalls.scala
Expand Up @@ -17,6 +17,8 @@ package transform
import symtab.Flags
import Flags.SYNTHETIC
import scala.annotation.tailrec
import scala.collection.mutable.ListBuffer
import scala.util.chaining._

/** Perform tail recursive call elimination.
*
Expand Down Expand Up @@ -85,16 +87,12 @@ abstract class TailCalls extends Transform {
* </p>
*/
class TailCallElimination(unit: CompilationUnit) extends AstTransformer {
private val strict = settings.strictTailRec.value
private def defaultReason = "it contains a recursive call not in tail position"
private val failPositions = perRunCaches.newMap[TailContext, Position]() withDefault (_.methodPos)
private val failReasons = perRunCaches.newMap[TailContext, String]() withDefaultValue defaultReason
private def tailrecFailure(ctx: TailContext): Unit = {
val method = ctx.method
val failReason = failReasons(ctx)
val failPos = failPositions(ctx)

reporter.error(failPos, s"could not optimize @tailrec annotated $method: $failReason")
}
private val failPositions = perRunCaches.newMap[TailContext, Position]().withDefault(_.methodPos)
private val failReasons = perRunCaches.newMap[TailContext, String]().withDefaultValue(defaultReason)
private def tailrecFailure(ctx: TailContext): Unit =
reporter.error(failPositions(ctx), s"could not optimize @tailrec annotated ${ctx.method}: ${failReasons(ctx)}")

/** Has the label been accessed? Then its symbol is in this set. */
private val accessed = perRunCaches.newSet[Symbol]()
Expand All @@ -118,11 +116,9 @@ abstract class TailCalls extends Transform {
def isTransformed = isEligible && accessed(label)

def newThis(pos: Position) = {
def msg = "Creating new `this` during tailcalls\n method: %s\n current class: %s".format(
method.ownerChain.mkString(" -> "),
currentClass.ownerChain.mkString(" -> ")
)
logResult(msg)(method.newValue(nme.THIS, pos, SYNTHETIC) setInfo currentClass.typeOfThis)
def ownedBy(header: String)(sym: Symbol) = sym.ownerChain.mkString(s" $header: ", " -> ", "")
def msg = s"Creating new `this` during tailcalls\n${ownedBy("method")(method)}\n${ownedBy("current class")(currentClass)}"
logResult(msg)(method.newValue(nme.THIS, pos, SYNTHETIC).setInfo(currentClass.typeOfThis))
}
override def toString = s"${method.name} tparams=$tparams tailPos=$tailPos label=$label label info=${label.info}"

Expand Down Expand Up @@ -168,16 +164,37 @@ abstract class TailCalls extends Transform {

label
}
private def isRecursiveCall(t: Tree) = {
val receiver = t.symbol

( (receiver != null)
&& receiver.isMethod
&& (method.name == receiver.name)
&& (method.enclClass isSubClass receiver.enclClass)
)
// self-recursive calls, eagerly evaluated
object detectRecursion extends Traverser {
val detected = ListBuffer.empty[Tree]
private def ignore(sym: Symbol): Boolean =
sym.isArtifact && sym.name.containsName(nme.ANON_FUN_NAME) && sym.isLocalToBlock
override def traverse(tree: Tree) = tree match {
case Apply(fun, args) =>
if (isRecursiveCall(fun.symbol)) detected.addOne(tree)
traverse(fun)
for ((p, a) <- fun.symbol.paramLists.head.lazyZip(args) if strict || !p.isByNameParam)
traverse(a)
case _: DefDef if !strict && ignore(tree.symbol) =>
case Function(_, _) if !strict =>
case _ => super.traverse(tree)
}
def recursiveCalls(t: Tree): List[Tree] = {
detected.clear()
traverse(t)
detected.toList
}
}
def recursiveCalls(t: Tree): List[Tree] = detectRecursion.recursiveCalls(t)
def isRecursiveCall(sym: Symbol): Boolean = (method eq sym) && tailPos

def containsRecursiveCallCandidate(t: Tree): Boolean = {
def isRecursiveCallCandidate(t: Tree) = {
val receiver = t.symbol
receiver != null && receiver.isMethod && method.name == receiver.name && method.enclClass.isSubClass(receiver.enclClass)
}
t.exists(isRecursiveCallCandidate)
}
def containsRecursiveCall(t: Tree) = t exists isRecursiveCall
}
class ClonedTailContext(val that: TailContext, override val tailPos: Boolean) extends TailContext {
def method = that.method
Expand Down Expand Up @@ -216,48 +233,48 @@ abstract class TailCalls extends Transform {
}

override def transform(tree: Tree): Tree = {
/* A possibly polymorphic apply to be considered for tail call transformation. */
def rewriteApply(target: Tree, fun: Tree, targs: List[Tree], args: List[Tree], mustTransformArgs: Boolean = true) = {
// A possibly polymorphic apply to be considered for tail call transformation.
def rewriteApply(target: Tree, fun: Tree, targs: List[Tree], args: List[Tree], transformArgs: Boolean) = {
val receiver: Tree = fun match {
case Select(qual, _) => qual
case _ => EmptyTree
}
def receiverIsSame = ctx.enclosingType.widen =:= receiver.tpe.widen
def receiverIsSuper = ctx.enclosingType.widen <:< receiver.tpe.widen
def isRecursiveCall = (ctx.method eq fun.symbol) && ctx.tailPos
def transformArgs = if (mustTransformArgs) noTailTransforms(args) else args
def transformedArgs = if (transformArgs) noTailTransforms(args) else args
def matchesTypeArgs = (ctx.tparams corresponds targs)((p, a) => !isSpecialized(p) || p == a.tpe.typeSymbol)

def isSpecialized(tparam: Symbol) =
tparam.hasAnnotation(SpecializedClass)
def isSpecialized(tparam: Symbol) = tparam.hasAnnotation(SpecializedClass)

/* Records failure reason in Context for reporting.
* Position is unchanged (by default, the method definition.)
*/
def fail(reason: String) = {
debuglog("Cannot rewrite recursive call at: " + fun.pos + " because: " + reason)
debuglog(s"Cannot rewrite recursive call at: ${fun.pos} because: $reason")
if (ctx.isMandatory) failReasons(ctx) = reason
treeCopy.Apply(tree, noTailTransform(target), transformArgs)
unrewritten
}
/* Position of failure is that of the tree being considered. */
// Position of failure is that of the tree being considered.
def failHere(reason: String) = {
if (ctx.isMandatory) failPositions(ctx) = fun.pos
fail(reason)
}
def unrewritten: Tree = treeCopy.Apply(tree, noTailTransform(target), transformedArgs)
def rewriteTailCall(recv: Tree): Tree = {
debuglog("Rewriting tail recursive call: " + fun.pos.lineContent.trim)
debuglog(s"Rewriting tail recursive call: [${fun.pos.lineContent.trim}]")
accessed += ctx.label
typedPos(fun.pos) {
val args = mapWithIndex(transformArgs)((arg, i) => mkAttributedCastHack(arg, ctx.label.info.params(i + 1).tpe))
val args = mapWithIndex(transformedArgs)((arg, i) => mkAttributedCastHack(arg, ctx.label.info.params(i + 1).tpe))
Apply(Ident(ctx.label), noTailTransform(recv) :: args)
}
}

if (!ctx.isEligible) fail("it is neither private nor final so can be overridden")
else if (!isRecursiveCall) {
if (ctx.isMandatory && receiverIsSuper) // OPT expensive check, avoid unless we will actually report the error
if (ctx.isMandatory && receiverIsSuper && !receiverIsSame) // OPT expensive check, avoid unless we will actually report the error
failHere("it contains a recursive call targeting a supertype")
else failHere(defaultReason)
else unrewritten // failHere(defaultReason)
}
else if (!matchesTypeArgs) failHere("it is called recursively with different specialized type arguments")
else if (receiver == EmptyTree) rewriteTailCall(This(currentClass))
Expand All @@ -281,23 +298,19 @@ abstract class TailCalls extends Transform {

case dd @ DefDef(_, name, _, vparamss0, _, rhs0) if isEligible(dd) =>
val newCtx = new DefDefTailContext(dd)
if (newCtx.isMandatory && !(newCtx containsRecursiveCall rhs0))
reporter.error(tree.pos, "@tailrec annotated method contains no recursive calls")

debuglog(s"Considering $name for tailcalls, with labels in tailpos: ${newCtx.tailLabels}")
val newRHS = transform(rhs0, newCtx)

deriveDefDef(tree) { rhs =>
def unreported = !failPositions.contains(newCtx) && !failReasons.contains(newCtx)
if (newCtx.isTransformed) {
/* We have rewritten the tree, but there may be nested recursive calls remaining.
* If @tailrec is given we need to fail those now.
*/
if (newCtx.isMandatory) {
for (t @ Apply(fn, _) <- newRHS ; if fn.symbol == newCtx.method) {
failPositions(newCtx) = t.pos
// any remaining self-recursive calls after transform must fail under @tailrec
if (newCtx.isMandatory)
for (remaining <- newCtx.recursiveCalls(newRHS).take(1)) {
if (unreported)
failPositions(newCtx) = remaining.pos
tailrecFailure(newCtx)
}
}
val newThis = newCtx.newThis(tree.pos)
val vpSyms = vparamss0.flatten map (_.symbol)

Expand All @@ -307,11 +320,23 @@ abstract class TailCalls extends Transform {
))
}
else {
if (newCtx.isMandatory && (newCtx containsRecursiveCall newRHS))
tailrecFailure(newCtx)

if (newCtx.isMandatory) {
if (unreported) {
val remainders = newCtx.recursiveCalls(newRHS)
if (remainders.isEmpty)
//failReasons(newCtx) = "@tailrec annotated method contains no recursive calls"
reporter.error(tree.pos, "@tailrec annotated method contains no recursive calls")
else
failPositions(newCtx) = remainders.head.pos
}
if (!unreported)
tailrecFailure(newCtx)
}
newRHS
}
}.tap { _ =>
failPositions.remove(newCtx)
failReasons.remove(newCtx)
}

// a translated match
Expand Down Expand Up @@ -375,7 +400,7 @@ abstract class TailCalls extends Transform {
)

case Apply(tapply @ TypeApply(fun, targs), vargs) =>
rewriteApply(tapply, fun, targs, vargs)
rewriteApply(tapply, fun, targs, vargs, transformArgs = true)

case Apply(fun, args) if fun.symbol == Boolean_or || fun.symbol == Boolean_and =>
treeCopy.Apply(tree, noTailTransform(fun), transformTrees(args))
Expand All @@ -392,10 +417,10 @@ abstract class TailCalls extends Transform {
if (res ne arg)
treeCopy.Apply(tree, fun, res :: Nil)
else
rewriteApply(fun, fun, Nil, args, mustTransformArgs = false)
rewriteApply(fun, fun, Nil, args, transformArgs = false)

case Apply(fun, args) =>
rewriteApply(fun, fun, Nil, args)
rewriteApply(fun, fun, Nil, args, transformArgs = true)
case Alternative(_) | Star(_) | Bind(_, _) =>
assert(false, "We should've never gotten inside a pattern")
tree
Expand Down
2 changes: 1 addition & 1 deletion test/files/neg/t12513b.check
@@ -1,4 +1,4 @@
t12513b.scala:8: error: could not optimize @tailrec annotated method f: it contains a recursive call not in tail position
@T def f: Int = { f ; 42 } // the annotation worked: error, f is not tail recursive
^
^
1 error
24 changes: 12 additions & 12 deletions test/files/neg/t1672b.check
@@ -1,16 +1,16 @@
t1672b.scala:3: error: could not optimize @tailrec annotated method bar: it contains a recursive call not in tail position
def bar : Nothing = {
^
t1672b.scala:14: error: could not optimize @tailrec annotated method baz: it contains a recursive call not in tail position
def baz : Nothing = {
^
t1672b.scala:7: error: could not optimize @tailrec annotated method bar: it contains a recursive call not in tail position
case _: Throwable => bar
^
t1672b.scala:18: error: could not optimize @tailrec annotated method baz: it contains a recursive call not in tail position
case _: Throwable => baz
^
t1672b.scala:29: error: could not optimize @tailrec annotated method boz: it contains a recursive call not in tail position
case _: Throwable => boz; ???
^
t1672b.scala:34: error: could not optimize @tailrec annotated method bez: it contains a recursive call not in tail position
def bez : Nothing = {
^
t1672b.scala:36: error: could not optimize @tailrec annotated method bez: it contains a recursive call not in tail position
bez
^
t1672b.scala:46: error: could not optimize @tailrec annotated method bar: it contains a recursive call not in tail position
else 1 + (try {
^
t1672b.scala:49: error: could not optimize @tailrec annotated method bar: it contains a recursive call not in tail position
case _: Throwable => bar(i - 1)
^
5 errors
4 changes: 4 additions & 0 deletions test/files/neg/t4649.check
@@ -0,0 +1,4 @@
t4649.scala:12: error: @tailrec annotated method contains no recursive calls
@tailrec final def remove(idx: Int, count: Int): Unit =
^
1 error
17 changes: 17 additions & 0 deletions test/files/neg/t4649.scala
@@ -0,0 +1,17 @@

import annotation.tailrec

object Test {

var sz = 3
def remove(idx: Int) =
if (idx >= 0 && idx < sz)
sz -= 1
else throw new IndexOutOfBoundsException(s"$idx is out of bounds (min 0, max ${sz-1})")

@tailrec final def remove(idx: Int, count: Int): Unit =
if (count > 0) {
remove(idx) // at a glance, looks like a tailrec candidate, but must error in the end
}

}
7 changes: 7 additions & 0 deletions test/files/neg/t4649b.check
@@ -0,0 +1,7 @@
t4649b.scala:8: error: could not optimize @tailrec annotated method lazyFilter: it contains a recursive call not in tail position
case h #:: t => if (p(h)) h #:: lazyFilter(t, p) else lazyFilter(t, p) // error
^
t4649b.scala:21: error: could not optimize @tailrec annotated method f: it contains a recursive call not in tail position
val g: Int => Int = f(_) // error
^
2 errors
25 changes: 25 additions & 0 deletions test/files/neg/t4649b.scala
@@ -0,0 +1,25 @@
//> using options -Wstrict-tailrec

import annotation.tailrec

object Test {
@tailrec
def lazyFilter[E](s: LazyList[E], p: E => Boolean): LazyList[E] = s match {
case h #:: t => if (p(h)) h #:: lazyFilter(t, p) else lazyFilter(t, p) // error
}

@tailrec
def f(i: Int): Int =
if (i <= 0) i
/* not optimized
else if (i == 27) {
val x = f(i - 1)
x
}
*/
else if (i == 42) {
val g: Int => Int = f(_) // error
f(i - 1)
}
else f(i - 1)
}
10 changes: 5 additions & 5 deletions test/files/neg/t6526.check
@@ -1,16 +1,16 @@
t6526.scala:8: error: could not optimize @tailrec annotated method inner: it contains a recursive call not in tail position
@tailrec def inner(i: Int): Int = 1 + inner(i)
^
^
t6526.scala:14: error: could not optimize @tailrec annotated method inner: it contains a recursive call not in tail position
@tailrec def inner(i: Int): Int = 1 + inner(i)
^
^
t6526.scala:20: error: could not optimize @tailrec annotated method inner: it contains a recursive call not in tail position
@tailrec def inner(i: Int): Int = 1 + inner(i)
^
^
t6526.scala:30: error: could not optimize @tailrec annotated method inner: it contains a recursive call not in tail position
@tailrec def inner(i: Int): Int = 1 + inner(i)
^
^
t6526.scala:39: error: could not optimize @tailrec annotated method inner: it contains a recursive call not in tail position
def inner(i: Int): Int = 1 + inner(i)
^
^
5 errors
6 changes: 3 additions & 3 deletions test/files/neg/t6574.check
@@ -1,4 +1,4 @@
t6574.scala:4: error: could not optimize @tailrec annotated method notTailPos$extension: it contains a recursive call not in tail position
println("tail")
^
t6574.scala:3: error: could not optimize @tailrec annotated method notTailPos$extension: it contains a recursive call not in tail position
this.notTailPos[Z](a)(b)
^
1 error
10 changes: 5 additions & 5 deletions test/files/neg/tailrec-4.check
@@ -1,16 +1,16 @@
tailrec-4.scala:6: error: could not optimize @tailrec annotated method foo: it contains a recursive call not in tail position
@tailrec def foo: Int = foo + 1
^
^
tailrec-4.scala:11: error: could not optimize @tailrec annotated method foo: it contains a recursive call not in tail position
@tailrec def foo: Int = foo + 1
^
^
tailrec-4.scala:17: error: could not optimize @tailrec annotated method foo: it contains a recursive call not in tail position
@tailrec def foo: Int = foo + 1
^
^
tailrec-4.scala:23: error: could not optimize @tailrec annotated method foo: it contains a recursive call not in tail position
@tailrec def foo: Int = foo + 1
^
^
tailrec-4.scala:31: error: could not optimize @tailrec annotated method foo: it contains a recursive call not in tail position
@tailrec def foo: Int = foo + 1
^
^
5 errors