Skip to content

Commit

Permalink
Tailcalls is more selective when reporting self-recursion
Browse files Browse the repository at this point in the history
  • Loading branch information
som-snytt committed Mar 23, 2024
1 parent 5aa3dc5 commit 5ecfa88
Show file tree
Hide file tree
Showing 10 changed files with 157 additions and 85 deletions.
126 changes: 75 additions & 51 deletions src/compiler/scala/tools/nsc/transform/TailCalls.scala
Expand Up @@ -17,6 +17,8 @@ package transform
import symtab.Flags
import Flags.SYNTHETIC
import scala.annotation.tailrec
import scala.collection.mutable.ListBuffer
import scala.util.chaining._

/** Perform tail recursive call elimination.
*
Expand Down Expand Up @@ -86,15 +88,10 @@ abstract class TailCalls extends Transform {
*/
class TailCallElimination(unit: CompilationUnit) extends AstTransformer {
private def defaultReason = "it contains a recursive call not in tail position"
private val failPositions = perRunCaches.newMap[TailContext, Position]() withDefault (_.methodPos)
private val failReasons = perRunCaches.newMap[TailContext, String]() withDefaultValue defaultReason
private def tailrecFailure(ctx: TailContext): Unit = {
val method = ctx.method
val failReason = failReasons(ctx)
val failPos = failPositions(ctx)

reporter.error(failPos, s"could not optimize @tailrec annotated $method: $failReason")
}
private val failPositions = perRunCaches.newMap[TailContext, Position]().withDefault(_.methodPos)
private val failReasons = perRunCaches.newMap[TailContext, String]().withDefaultValue(defaultReason)
private def tailrecFailure(ctx: TailContext): Unit =
reporter.error(failPositions(ctx), s"could not optimize @tailrec annotated ${ctx.method}: ${failReasons(ctx)}")

/** Has the label been accessed? Then its symbol is in this set. */
private val accessed = perRunCaches.newSet[Symbol]()
Expand All @@ -118,11 +115,9 @@ abstract class TailCalls extends Transform {
def isTransformed = isEligible && accessed(label)

def newThis(pos: Position) = {
def msg = "Creating new `this` during tailcalls\n method: %s\n current class: %s".format(
method.ownerChain.mkString(" -> "),
currentClass.ownerChain.mkString(" -> ")
)
logResult(msg)(method.newValue(nme.THIS, pos, SYNTHETIC) setInfo currentClass.typeOfThis)
def ownedBy(header: String)(sym: Symbol) = sym.ownerChain.mkString(s" $header: ", " -> ", "")
def msg = s"Creating new `this` during tailcalls\n${ownedBy("method")(method)}\n${ownedBy("current class")(currentClass)}"
logResult(msg)(method.newValue(nme.THIS, pos, SYNTHETIC).setInfo(currentClass.typeOfThis))
}
override def toString = s"${method.name} tparams=$tparams tailPos=$tailPos label=$label label info=${label.info}"

Expand Down Expand Up @@ -168,16 +163,37 @@ abstract class TailCalls extends Transform {

label
}
private def isRecursiveCall(t: Tree) = {
val receiver = t.symbol

( (receiver != null)
&& receiver.isMethod
&& (method.name == receiver.name)
&& (method.enclClass isSubClass receiver.enclClass)
)
// self-recursive calls, eagerly evaluated
object detectRecursion extends Traverser {
val detected = ListBuffer.empty[Tree]
private def ignore(sym: Symbol): Boolean =
sym.isArtifact && sym.name.containsName(nme.ANON_FUN_NAME) && sym.isLocalToBlock
override def traverse(tree: Tree) = tree match {
case Apply(fun, args) =>
if (isRecursiveCall(fun.symbol)) detected.addOne(tree)
traverse(fun)
for ((p, a) <- fun.symbol.paramLists.head.lazyZip(args) if !p.isByNameParam)
traverse(a)
case _: DefDef if ignore(tree.symbol) =>
case Function(_, _) =>
case _ => super.traverse(tree)
}
def recursiveCalls(t: Tree): List[Tree] = {
detected.clear()
traverse(t)
detected.toList
}
}
def recursiveCalls(t: Tree): List[Tree] = detectRecursion.recursiveCalls(t)
def isRecursiveCall(sym: Symbol): Boolean = (method eq sym) && tailPos

def containsRecursiveCallCandidate(t: Tree): Boolean = {
def isRecursiveCallCandidate(t: Tree) = {
val receiver = t.symbol
receiver != null && receiver.isMethod && method.name == receiver.name && method.enclClass.isSubClass(receiver.enclClass)
}
t.exists(isRecursiveCallCandidate)
}
def containsRecursiveCall(t: Tree) = t exists isRecursiveCall
}
class ClonedTailContext(val that: TailContext, override val tailPos: Boolean) extends TailContext {
def method = that.method
Expand Down Expand Up @@ -216,48 +232,48 @@ abstract class TailCalls extends Transform {
}

override def transform(tree: Tree): Tree = {
/* A possibly polymorphic apply to be considered for tail call transformation. */
def rewriteApply(target: Tree, fun: Tree, targs: List[Tree], args: List[Tree], mustTransformArgs: Boolean = true) = {
// A possibly polymorphic apply to be considered for tail call transformation.
def rewriteApply(target: Tree, fun: Tree, targs: List[Tree], args: List[Tree], transformArgs: Boolean) = {
val receiver: Tree = fun match {
case Select(qual, _) => qual
case _ => EmptyTree
}
def receiverIsSame = ctx.enclosingType.widen =:= receiver.tpe.widen
def receiverIsSuper = ctx.enclosingType.widen <:< receiver.tpe.widen
def isRecursiveCall = (ctx.method eq fun.symbol) && ctx.tailPos
def transformArgs = if (mustTransformArgs) noTailTransforms(args) else args
def transformedArgs = if (transformArgs) noTailTransforms(args) else args
def matchesTypeArgs = (ctx.tparams corresponds targs)((p, a) => !isSpecialized(p) || p == a.tpe.typeSymbol)

def isSpecialized(tparam: Symbol) =
tparam.hasAnnotation(SpecializedClass)
def isSpecialized(tparam: Symbol) = tparam.hasAnnotation(SpecializedClass)

/* Records failure reason in Context for reporting.
* Position is unchanged (by default, the method definition.)
*/
def fail(reason: String) = {
debuglog("Cannot rewrite recursive call at: " + fun.pos + " because: " + reason)
debuglog(s"Cannot rewrite recursive call at: ${fun.pos} because: $reason")
if (ctx.isMandatory) failReasons(ctx) = reason
treeCopy.Apply(tree, noTailTransform(target), transformArgs)
unrewritten
}
/* Position of failure is that of the tree being considered. */
// Position of failure is that of the tree being considered.
def failHere(reason: String) = {
if (ctx.isMandatory) failPositions(ctx) = fun.pos
fail(reason)
}
def unrewritten: Tree = treeCopy.Apply(tree, noTailTransform(target), transformedArgs)
def rewriteTailCall(recv: Tree): Tree = {
debuglog("Rewriting tail recursive call: " + fun.pos.lineContent.trim)
debuglog(s"Rewriting tail recursive call: [${fun.pos.lineContent.trim}]")
accessed += ctx.label
typedPos(fun.pos) {
val args = mapWithIndex(transformArgs)((arg, i) => mkAttributedCastHack(arg, ctx.label.info.params(i + 1).tpe))
val args = mapWithIndex(transformedArgs)((arg, i) => mkAttributedCastHack(arg, ctx.label.info.params(i + 1).tpe))
Apply(Ident(ctx.label), noTailTransform(recv) :: args)
}
}

if (!ctx.isEligible) fail("it is neither private nor final so can be overridden")
else if (!isRecursiveCall) {
if (ctx.isMandatory && receiverIsSuper) // OPT expensive check, avoid unless we will actually report the error
if (ctx.isMandatory && receiverIsSuper && !receiverIsSame) // OPT expensive check, avoid unless we will actually report the error
failHere("it contains a recursive call targeting a supertype")
else failHere(defaultReason)
else unrewritten // failHere(defaultReason)
}
else if (!matchesTypeArgs) failHere("it is called recursively with different specialized type arguments")
else if (receiver == EmptyTree) rewriteTailCall(This(currentClass))
Expand All @@ -281,23 +297,19 @@ abstract class TailCalls extends Transform {

case dd @ DefDef(_, name, _, vparamss0, _, rhs0) if isEligible(dd) =>
val newCtx = new DefDefTailContext(dd)
if (newCtx.isMandatory && !(newCtx containsRecursiveCall rhs0))
reporter.error(tree.pos, "@tailrec annotated method contains no recursive calls")

debuglog(s"Considering $name for tailcalls, with labels in tailpos: ${newCtx.tailLabels}")
val newRHS = transform(rhs0, newCtx)

deriveDefDef(tree) { rhs =>
def unreported = !failPositions.contains(newCtx) && !failReasons.contains(newCtx)
if (newCtx.isTransformed) {
/* We have rewritten the tree, but there may be nested recursive calls remaining.
* If @tailrec is given we need to fail those now.
*/
if (newCtx.isMandatory) {
for (t @ Apply(fn, _) <- newRHS ; if fn.symbol == newCtx.method) {
failPositions(newCtx) = t.pos
// any remaining self-recursive calls after transform must fail under @tailrec
if (newCtx.isMandatory)
for (remaining <- newCtx.recursiveCalls(newRHS).take(1)) {
if (unreported)
failPositions(newCtx) = remaining.pos
tailrecFailure(newCtx)
}
}
val newThis = newCtx.newThis(tree.pos)
val vpSyms = vparamss0.flatten map (_.symbol)

Expand All @@ -307,11 +319,23 @@ abstract class TailCalls extends Transform {
))
}
else {
if (newCtx.isMandatory && (newCtx containsRecursiveCall newRHS))
tailrecFailure(newCtx)

if (newCtx.isMandatory) {
if (unreported) {
val remainders = newCtx.recursiveCalls(newRHS)
if (remainders.isEmpty)
//failReasons(newCtx) = "@tailrec annotated method contains no recursive calls"
reporter.error(tree.pos, "@tailrec annotated method contains no recursive calls")
else
failPositions(newCtx) = remainders.head.pos
}
if (!unreported)
tailrecFailure(newCtx)
}
newRHS
}
}.tap { _ =>
failPositions.remove(newCtx)
failReasons.remove(newCtx)
}

// a translated match
Expand Down Expand Up @@ -375,7 +399,7 @@ abstract class TailCalls extends Transform {
)

case Apply(tapply @ TypeApply(fun, targs), vargs) =>
rewriteApply(tapply, fun, targs, vargs)
rewriteApply(tapply, fun, targs, vargs, transformArgs = true)

case Apply(fun, args) if fun.symbol == Boolean_or || fun.symbol == Boolean_and =>
treeCopy.Apply(tree, noTailTransform(fun), transformTrees(args))
Expand All @@ -392,10 +416,10 @@ abstract class TailCalls extends Transform {
if (res ne arg)
treeCopy.Apply(tree, fun, res :: Nil)
else
rewriteApply(fun, fun, Nil, args, mustTransformArgs = false)
rewriteApply(fun, fun, Nil, args, transformArgs = false)

case Apply(fun, args) =>
rewriteApply(fun, fun, Nil, args)
rewriteApply(fun, fun, Nil, args, transformArgs = true)
case Alternative(_) | Star(_) | Bind(_, _) =>
assert(false, "We should've never gotten inside a pattern")
tree
Expand Down
2 changes: 1 addition & 1 deletion test/files/neg/t12513b.check
@@ -1,4 +1,4 @@
t12513b.scala:8: error: could not optimize @tailrec annotated method f: it contains a recursive call not in tail position
@T def f: Int = { f ; 42 } // the annotation worked: error, f is not tail recursive
^
^
1 error
24 changes: 12 additions & 12 deletions test/files/neg/t1672b.check
@@ -1,16 +1,16 @@
t1672b.scala:3: error: could not optimize @tailrec annotated method bar: it contains a recursive call not in tail position
def bar : Nothing = {
^
t1672b.scala:14: error: could not optimize @tailrec annotated method baz: it contains a recursive call not in tail position
def baz : Nothing = {
^
t1672b.scala:7: error: could not optimize @tailrec annotated method bar: it contains a recursive call not in tail position
case _: Throwable => bar
^
t1672b.scala:18: error: could not optimize @tailrec annotated method baz: it contains a recursive call not in tail position
case _: Throwable => baz
^
t1672b.scala:29: error: could not optimize @tailrec annotated method boz: it contains a recursive call not in tail position
case _: Throwable => boz; ???
^
t1672b.scala:34: error: could not optimize @tailrec annotated method bez: it contains a recursive call not in tail position
def bez : Nothing = {
^
t1672b.scala:36: error: could not optimize @tailrec annotated method bez: it contains a recursive call not in tail position
bez
^
t1672b.scala:46: error: could not optimize @tailrec annotated method bar: it contains a recursive call not in tail position
else 1 + (try {
^
t1672b.scala:49: error: could not optimize @tailrec annotated method bar: it contains a recursive call not in tail position
case _: Throwable => bar(i - 1)
^
5 errors
4 changes: 4 additions & 0 deletions test/files/neg/t4649.check
@@ -0,0 +1,4 @@
t4649.scala:12: error: @tailrec annotated method contains no recursive calls
@tailrec final def remove(idx: Int, count: Int): Unit =
^
1 error
17 changes: 17 additions & 0 deletions test/files/neg/t4649.scala
@@ -0,0 +1,17 @@

import annotation.tailrec

object Test {

var sz = 3
def remove(idx: Int) =
if (idx >= 0 && idx < sz)
sz -= 1
else throw new IndexOutOfBoundsException(s"$idx is out of bounds (min 0, max ${sz-1})")

@tailrec final def remove(idx: Int, count: Int): Unit =
if (count > 0) {
remove(idx) // at a glance, looks like a tailrec candidate, but must error in the end
}

}
10 changes: 5 additions & 5 deletions test/files/neg/t6526.check
@@ -1,16 +1,16 @@
t6526.scala:8: error: could not optimize @tailrec annotated method inner: it contains a recursive call not in tail position
@tailrec def inner(i: Int): Int = 1 + inner(i)
^
^
t6526.scala:14: error: could not optimize @tailrec annotated method inner: it contains a recursive call not in tail position
@tailrec def inner(i: Int): Int = 1 + inner(i)
^
^
t6526.scala:20: error: could not optimize @tailrec annotated method inner: it contains a recursive call not in tail position
@tailrec def inner(i: Int): Int = 1 + inner(i)
^
^
t6526.scala:30: error: could not optimize @tailrec annotated method inner: it contains a recursive call not in tail position
@tailrec def inner(i: Int): Int = 1 + inner(i)
^
^
t6526.scala:39: error: could not optimize @tailrec annotated method inner: it contains a recursive call not in tail position
def inner(i: Int): Int = 1 + inner(i)
^
^
5 errors
6 changes: 3 additions & 3 deletions test/files/neg/t6574.check
@@ -1,4 +1,4 @@
t6574.scala:4: error: could not optimize @tailrec annotated method notTailPos$extension: it contains a recursive call not in tail position
println("tail")
^
t6574.scala:3: error: could not optimize @tailrec annotated method notTailPos$extension: it contains a recursive call not in tail position
this.notTailPos[Z](a)(b)
^
1 error
10 changes: 5 additions & 5 deletions test/files/neg/tailrec-4.check
@@ -1,16 +1,16 @@
tailrec-4.scala:6: error: could not optimize @tailrec annotated method foo: it contains a recursive call not in tail position
@tailrec def foo: Int = foo + 1
^
^
tailrec-4.scala:11: error: could not optimize @tailrec annotated method foo: it contains a recursive call not in tail position
@tailrec def foo: Int = foo + 1
^
^
tailrec-4.scala:17: error: could not optimize @tailrec annotated method foo: it contains a recursive call not in tail position
@tailrec def foo: Int = foo + 1
^
^
tailrec-4.scala:23: error: could not optimize @tailrec annotated method foo: it contains a recursive call not in tail position
@tailrec def foo: Int = foo + 1
^
^
tailrec-4.scala:31: error: could not optimize @tailrec annotated method foo: it contains a recursive call not in tail position
@tailrec def foo: Int = foo + 1
^
^
5 errors
8 changes: 4 additions & 4 deletions test/files/neg/tailrec.check
@@ -1,12 +1,12 @@
tailrec.scala:45: error: could not optimize @tailrec annotated method facfail: it contains a recursive call not in tail position
else n * facfail(n - 1)
^
^
tailrec.scala:50: error: could not optimize @tailrec annotated method fail1: it is neither private nor final so can be overridden
@tailrec def fail1(x: Int): Int = fail1(x)
^
tailrec.scala:53: error: could not optimize @tailrec annotated method fail2: it contains a recursive call not in tail position
@tailrec final def fail2[T](xs: List[T]): List[T] = xs match {
^
tailrec.scala:55: error: could not optimize @tailrec annotated method fail2: it contains a recursive call not in tail position
case x :: xs => x :: fail2[T](xs)
^
tailrec.scala:59: error: could not optimize @tailrec annotated method fail3: it is called recursively with different specialized type arguments
@tailrec final def fail3[@specialized(Int) T](x: Int): Int = fail3(x - 1)
^
Expand Down
35 changes: 31 additions & 4 deletions test/files/pos/t4649.scala
@@ -1,10 +1,37 @@
//> abusing options -Vlog:tailcalls -Vdebug -Vprint:~tailcalls

import annotation.tailrec

// scalac: -Xfatal-warnings
//
object Test {
// @annotation.tailrec
@tailrec
def lazyFilter[E](s: LazyList[E], p: E => Boolean): LazyList[E] = s match {
case h #:: t => if (p(h)) h #:: lazyFilter(t, p) else lazyFilter(t, p)
case _ => LazyList.empty[E]
}

@tailrec
def f(i: Int): Int =
if (i <= 0) i
/* not optimized
else if (i == 27) {
val x = f(i - 1)
x
}
*/
else if (i == 42) {
val g: Int => Int = f(_)
f(i - 1)
}
else f(i - 1)

var sz = 3
def remove(idx: Int) =
if (idx >= 0 && idx < sz)
sz -= 1
else throw new IndexOutOfBoundsException(s"$idx is out of bounds (min 0, max ${sz-1})")

@tailrec final def remove(idx: Int, count: Int): Unit =
if (count > 0) {
remove(idx) // after rewrite, don't flag me as a leftover tailrec
remove(idx, count-1)
}
}

0 comments on commit 5ecfa88

Please sign in to comment.