Skip to content

Commit

Permalink
Tailcalls is more selective when reporting self-recursion
Browse files Browse the repository at this point in the history
  • Loading branch information
som-snytt committed Mar 23, 2024
1 parent 5aa3dc5 commit ece451c
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 51 deletions.
113 changes: 66 additions & 47 deletions src/compiler/scala/tools/nsc/transform/TailCalls.scala
Expand Up @@ -17,6 +17,7 @@ package transform
import symtab.Flags
import Flags.SYNTHETIC
import scala.annotation.tailrec
import scala.collection.mutable.ListBuffer

/** Perform tail recursive call elimination.
*
Expand Down Expand Up @@ -86,15 +87,10 @@ abstract class TailCalls extends Transform {
*/
class TailCallElimination(unit: CompilationUnit) extends AstTransformer {
private def defaultReason = "it contains a recursive call not in tail position"
private val failPositions = perRunCaches.newMap[TailContext, Position]() withDefault (_.methodPos)
private val failReasons = perRunCaches.newMap[TailContext, String]() withDefaultValue defaultReason
private def tailrecFailure(ctx: TailContext): Unit = {
val method = ctx.method
val failReason = failReasons(ctx)
val failPos = failPositions(ctx)

reporter.error(failPos, s"could not optimize @tailrec annotated $method: $failReason")
}
private val failPositions = perRunCaches.newMap[TailContext, Position]().withDefault(_.methodPos)
private val failReasons = perRunCaches.newMap[TailContext, String]().withDefaultValue(defaultReason)
private def tailrecFailure(ctx: TailContext): Unit =
reporter.error(failPositions(ctx), s"could not optimize @tailrec annotated ${ctx.method}: ${failReasons(ctx)}")

/** Has the label been accessed? Then its symbol is in this set. */
private val accessed = perRunCaches.newSet[Symbol]()
Expand All @@ -118,11 +114,9 @@ abstract class TailCalls extends Transform {
def isTransformed = isEligible && accessed(label)

def newThis(pos: Position) = {
def msg = "Creating new `this` during tailcalls\n method: %s\n current class: %s".format(
method.ownerChain.mkString(" -> "),
currentClass.ownerChain.mkString(" -> ")
)
logResult(msg)(method.newValue(nme.THIS, pos, SYNTHETIC) setInfo currentClass.typeOfThis)
def ownedBy(header: String)(sym: Symbol) = sym.ownerChain.mkString(s" $header: ", " -> ", "")
def msg = s"Creating new `this` during tailcalls\n${ownedBy("method")(method)}\n${ownedBy("current class")(currentClass)}"
logResult(msg)(method.newValue(nme.THIS, pos, SYNTHETIC).setInfo(currentClass.typeOfThis))
}
override def toString = s"${method.name} tparams=$tparams tailPos=$tailPos label=$label label info=${label.info}"

Expand Down Expand Up @@ -168,16 +162,37 @@ abstract class TailCalls extends Transform {

label
}
private def isRecursiveCall(t: Tree) = {
val receiver = t.symbol

( (receiver != null)
&& receiver.isMethod
&& (method.name == receiver.name)
&& (method.enclClass isSubClass receiver.enclClass)
)
// self-recursive calls, eagerly evaluated
object detectRecursion extends Traverser {
val detected = ListBuffer.empty[Tree]
private def ignore(sym: Symbol): Boolean =
sym.isArtifact && sym.name.containsName(nme.ANON_FUN_NAME) && sym.isLocalToBlock
override def traverse(tree: Tree) = tree match {
case Apply(fun, args) =>
if (isRecursiveCall(fun.symbol)) detected.addOne(tree)
traverse(fun)
for ((p, a) <- fun.symbol.paramLists.head.lazyZip(args) if !p.isByNameParam)
traverse(a)
case _: DefDef if ignore(tree.symbol) =>
case Function(_, _) =>
case _ => super.traverse(tree)
}
def recursiveCalls(t: Tree): List[Tree] = {
detected.clear()
traverse(t)
detected.toList
}
}
def recursiveCalls(t: Tree): List[Tree] = detectRecursion.recursiveCalls(t)
def isRecursiveCall(sym: Symbol): Boolean = (method eq sym) && tailPos

def containsRecursiveCallCandidate(t: Tree): Boolean = {
def isRecursiveCallCandidate(t: Tree) = {
val receiver = t.symbol
receiver != null && receiver.isMethod && method.name == receiver.name && method.enclClass.isSubClass(receiver.enclClass)
}
t.exists(isRecursiveCallCandidate)
}
def containsRecursiveCall(t: Tree) = t exists isRecursiveCall
}
class ClonedTailContext(val that: TailContext, override val tailPos: Boolean) extends TailContext {
def method = that.method
Expand Down Expand Up @@ -216,46 +231,45 @@ abstract class TailCalls extends Transform {
}

override def transform(tree: Tree): Tree = {
/* A possibly polymorphic apply to be considered for tail call transformation. */
def rewriteApply(target: Tree, fun: Tree, targs: List[Tree], args: List[Tree], mustTransformArgs: Boolean = true) = {
// A possibly polymorphic apply to be considered for tail call transformation.
def rewriteApply(target: Tree, fun: Tree, targs: List[Tree], args: List[Tree], transformArgs: Boolean) = {
val receiver: Tree = fun match {
case Select(qual, _) => qual
case _ => EmptyTree
}
def receiverIsSame = ctx.enclosingType.widen =:= receiver.tpe.widen
def receiverIsSuper = ctx.enclosingType.widen <:< receiver.tpe.widen
def isRecursiveCall = (ctx.method eq fun.symbol) && ctx.tailPos
def transformArgs = if (mustTransformArgs) noTailTransforms(args) else args
def transformedArgs = if (transformArgs) noTailTransforms(args) else args
def matchesTypeArgs = (ctx.tparams corresponds targs)((p, a) => !isSpecialized(p) || p == a.tpe.typeSymbol)

def isSpecialized(tparam: Symbol) =
tparam.hasAnnotation(SpecializedClass)
def isSpecialized(tparam: Symbol) = tparam.hasAnnotation(SpecializedClass)

/* Records failure reason in Context for reporting.
* Position is unchanged (by default, the method definition.)
*/
def fail(reason: String) = {
debuglog("Cannot rewrite recursive call at: " + fun.pos + " because: " + reason)
debuglog(s"Cannot rewrite recursive call at: ${fun.pos} because: $reason")
if (ctx.isMandatory) failReasons(ctx) = reason
treeCopy.Apply(tree, noTailTransform(target), transformArgs)
treeCopy.Apply(tree, noTailTransform(target), transformedArgs)
}
/* Position of failure is that of the tree being considered. */
// Position of failure is that of the tree being considered.
def failHere(reason: String) = {
if (ctx.isMandatory) failPositions(ctx) = fun.pos
fail(reason)
}
def rewriteTailCall(recv: Tree): Tree = {
debuglog("Rewriting tail recursive call: " + fun.pos.lineContent.trim)
debuglog(s"Rewriting tail recursive call: [${fun.pos.lineContent.trim}]")
accessed += ctx.label
typedPos(fun.pos) {
val args = mapWithIndex(transformArgs)((arg, i) => mkAttributedCastHack(arg, ctx.label.info.params(i + 1).tpe))
val args = mapWithIndex(transformedArgs)((arg, i) => mkAttributedCastHack(arg, ctx.label.info.params(i + 1).tpe))
Apply(Ident(ctx.label), noTailTransform(recv) :: args)
}
}

if (!ctx.isEligible) fail("it is neither private nor final so can be overridden")
else if (!isRecursiveCall) {
if (ctx.isMandatory && receiverIsSuper) // OPT expensive check, avoid unless we will actually report the error
if (ctx.isMandatory && receiverIsSuper && !receiverIsSame) // OPT expensive check, avoid unless we will actually report the error
failHere("it contains a recursive call targeting a supertype")
else failHere(defaultReason)
}
Expand All @@ -281,23 +295,20 @@ abstract class TailCalls extends Transform {

case dd @ DefDef(_, name, _, vparamss0, _, rhs0) if isEligible(dd) =>
val newCtx = new DefDefTailContext(dd)
if (newCtx.isMandatory && !(newCtx containsRecursiveCall rhs0))
if (newCtx.isMandatory && !newCtx.containsRecursiveCallCandidate(rhs0))
reporter.error(tree.pos, "@tailrec annotated method contains no recursive calls")

debuglog(s"Considering $name for tailcalls, with labels in tailpos: ${newCtx.tailLabels}")
val newRHS = transform(rhs0, newCtx)

deriveDefDef(tree) { rhs =>
if (newCtx.isTransformed) {
/* We have rewritten the tree, but there may be nested recursive calls remaining.
* If @tailrec is given we need to fail those now.
*/
if (newCtx.isMandatory) {
for (t @ Apply(fn, _) <- newRHS ; if fn.symbol == newCtx.method) {
failPositions(newCtx) = t.pos
// any remaining self-recursive calls after transform must fail under @tailrec
if (newCtx.isMandatory)
for (remaining <- newCtx.recursiveCalls(newRHS)) {
failPositions(newCtx) = remaining.pos
tailrecFailure(newCtx)
}
}
val newThis = newCtx.newThis(tree.pos)
val vpSyms = vparamss0.flatten map (_.symbol)

Expand All @@ -307,9 +318,17 @@ abstract class TailCalls extends Transform {
))
}
else {
if (newCtx.isMandatory && (newCtx containsRecursiveCall newRHS))
if (newCtx.isMandatory) {
val remainders = newCtx.recursiveCalls(newRHS)
if (remainders.isEmpty) {
failReasons(newCtx) = "@tailrec annotated method contains no recursive calls"
failPositions(newCtx) = tree.pos
}
else
for (remaining <- remainders)
failPositions(newCtx) = remaining.pos
tailrecFailure(newCtx)

}
newRHS
}
}
Expand Down Expand Up @@ -375,7 +394,7 @@ abstract class TailCalls extends Transform {
)

case Apply(tapply @ TypeApply(fun, targs), vargs) =>
rewriteApply(tapply, fun, targs, vargs)
rewriteApply(tapply, fun, targs, vargs, transformArgs = true)

case Apply(fun, args) if fun.symbol == Boolean_or || fun.symbol == Boolean_and =>
treeCopy.Apply(tree, noTailTransform(fun), transformTrees(args))
Expand All @@ -392,10 +411,10 @@ abstract class TailCalls extends Transform {
if (res ne arg)
treeCopy.Apply(tree, fun, res :: Nil)
else
rewriteApply(fun, fun, Nil, args, mustTransformArgs = false)
rewriteApply(fun, fun, Nil, args, transformArgs = false)

case Apply(fun, args) =>
rewriteApply(fun, fun, Nil, args)
rewriteApply(fun, fun, Nil, args, transformArgs = true)
case Alternative(_) | Star(_) | Bind(_, _) =>
assert(false, "We should've never gotten inside a pattern")
tree
Expand Down
4 changes: 4 additions & 0 deletions test/files/neg/t4649.check
@@ -0,0 +1,4 @@
t4649.scala:12: error: could not optimize @tailrec annotated method remove: @tailrec annotated method contains no recursive calls
@tailrec final def remove(idx: Int, count: Int): Unit =
^
1 error
17 changes: 17 additions & 0 deletions test/files/neg/t4649.scala
@@ -0,0 +1,17 @@

import annotation.tailrec

object Test {

var sz = 3
def remove(idx: Int) =
if (idx >= 0 && idx < sz)
sz -= 1
else throw new IndexOutOfBoundsException(s"$idx is out of bounds (min 0, max ${sz-1})")

@tailrec final def remove(idx: Int, count: Int): Unit =
if (count > 0) {
remove(idx) // at a glance, looks like a tailrec candidate, but must error in the end
}

}
35 changes: 31 additions & 4 deletions test/files/pos/t4649.scala
@@ -1,10 +1,37 @@
//> abusing options -Vlog:tailcalls -Vdebug -Vprint:~tailcalls

import annotation.tailrec

// scalac: -Xfatal-warnings
//
object Test {
// @annotation.tailrec
@tailrec
def lazyFilter[E](s: LazyList[E], p: E => Boolean): LazyList[E] = s match {
case h #:: t => if (p(h)) h #:: lazyFilter(t, p) else lazyFilter(t, p)
case _ => LazyList.empty[E]
}

@tailrec
def f(i: Int): Int =
if (i <= 0) i
/* not optimized
else if (i == 27) {
val x = f(i - 1)
x
}
*/
else if (i == 42) {
val g: Int => Int = f(_)
f(i - 1)
}
else f(i - 1)

var sz = 3
def remove(idx: Int) =
if (idx >= 0 && idx < sz)
sz -= 1
else throw new IndexOutOfBoundsException(s"$idx is out of bounds (min 0, max ${sz-1})")

@tailrec final def remove(idx: Int, count: Int): Unit =
if (count > 0) {
remove(idx) // after rewrite, don't flag me as a leftover tailrec
remove(idx, count-1)
}
}

0 comments on commit ece451c

Please sign in to comment.