Skip to content

Commit

Permalink
Merge pull request #10775 from lrytz/t12990
Browse files Browse the repository at this point in the history
Desugar switchable matches with guards in async methods
  • Loading branch information
lrytz committed May 8, 2024
2 parents b5bfde4 + aeb5061 commit ccdcde3
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 6 deletions.
2 changes: 2 additions & 0 deletions src/compiler/scala/tools/nsc/transform/async/AsyncPhase.scala
Expand Up @@ -31,6 +31,8 @@ abstract class AsyncPhase extends Transform with TypingTransformers with AnfTran
stateDiagram: ((Symbol, Tree) => Option[String => Unit]),
allowExceptionsToPropagate: Boolean) extends PlainAttachment

def hasAsyncAttachment(dd: DefDef) = dd.hasAttachment[AsyncAttachment]

// Optimization: avoid the transform altogether if there are no async blocks in a unit.
private val sourceFilesToTransform = perRunCaches.newSet[SourceFile]()
private val awaits: mutable.Set[Symbol] = perRunCaches.newSet[Symbol]()
Expand Down
Expand Up @@ -209,6 +209,8 @@ trait MatchOptimization extends MatchTreeMaking with MatchApproximation {
trait SwitchEmission extends TreeMakers with MatchMonadInterface {
import treeInfo.isGuardedCase

def inAsync: Boolean

abstract class SwitchMaker {
abstract class SwitchableTreeMakerExtractor { def unapply(x: TreeMaker): Option[Tree] }
val SwitchableTreeMaker: SwitchableTreeMakerExtractor
Expand Down Expand Up @@ -497,7 +499,7 @@ trait MatchOptimization extends MatchTreeMaking with MatchApproximation {
class RegularSwitchMaker(scrutSym: Symbol, matchFailGenOverride: Option[Tree => Tree], val unchecked: Boolean) extends SwitchMaker { import CODE._
val switchableTpe = Set(ByteTpe, ShortTpe, IntTpe, CharTpe, StringTpe)
val alternativesSupported = true
val canJump = true
val canJump = !inAsync

// Constant folding sets the type of a constant tree to `ConstantType(Constant(folded))`
// The tree itself can be a literal, an ident, a selection, ...
Expand Down
Expand Up @@ -65,7 +65,17 @@ trait PatternMatching extends Transform
def newTransformer(unit: CompilationUnit): AstTransformer = new MatchTransformer(unit)

class MatchTransformer(unit: CompilationUnit) extends TypingTransformer(unit) {
private var inAsync = false

override def transform(tree: Tree): Tree = tree match {
case dd: DefDef if async.hasAsyncAttachment(dd) =>
val wasInAsync = inAsync
try {
inAsync = true
super.transform(dd)
} finally
inAsync = wasInAsync

case CaseDef(UnApply(Apply(Select(qual, nme.unapply), Ident(nme.SELECTOR_DUMMY) :: Nil), (bind@Bind(b, Ident(nme.WILDCARD))) :: Nil), guard, body)
if guard.isEmpty && qual.symbol == definitions.NonFatalModule =>
transform(treeCopy.CaseDef(
Expand Down Expand Up @@ -103,16 +113,17 @@ trait PatternMatching extends Transform
}

def translator(selectorPos: Position): MatchTranslator with CodegenCore = {
new OptimizingMatchTranslator(localTyper, selectorPos)
new OptimizingMatchTranslator(localTyper, selectorPos, inAsync)
}

}


class OptimizingMatchTranslator(val typer: analyzer.Typer, val selectorPos: Position) extends MatchTranslator
with MatchOptimizer
with MatchAnalyzer
with Solver
class OptimizingMatchTranslator(val typer: analyzer.Typer, val selectorPos: Position, val inAsync: Boolean)
extends MatchTranslator
with MatchOptimizer
with MatchAnalyzer
with Solver
}

trait Debugging {
Expand Down
50 changes: 50 additions & 0 deletions test/async/run/switch-await-in-guard.scala
@@ -0,0 +1,50 @@
//> using options -Xasync

import scala.tools.partest.async.OptionAwait._
import org.junit.Assert._

object Test {
def main(args: Array[String]): Unit = {
assertEquals(Some(22), sw1(11))
assertEquals(Some(3), sw1(3))

assertEquals(Some(22), sw2(11))
assertEquals(Some(3), sw2(3))

assertEquals(Some(22), sw3(11))
assertEquals(Some(44), sw3(22))
assertEquals(Some(3), sw3(3))

assertEquals(Some("22"), swS("11"))
assertEquals(Some("3"), swS("3"))
}

private def sw1(i: Int) = optionally {
i match {
case 11 if value(Some(430)) > 42 => 22
case p => p
}
}

private def sw2(i: Int) = optionally {
i match {
case 11 => if (value(Some(430)) > 42) 22 else i
case p => p
}
}

private def sw3(i: Int) = optionally {
i match {
case 11 => if (value(Some(430)) > 42) 22 else i
case 22 | 33 => 44
case p => p
}
}

private def swS(s: String) = optionally {
s match {
case "11" if value(Some(430)) > 42 => "22"
case p => p
}
}
}
42 changes: 42 additions & 0 deletions test/junit/scala/tools/nsc/backend/jvm/BytecodeTest.scala
Expand Up @@ -1003,4 +1003,46 @@ class BytecodeTest extends BytecodeTesting {
val lines = compileMethod(c1).instructions.collect { case l: LineNumber => l }
assertSameCode(List(LineNumber(2, Label(0))), lines)
}


@Test
def t12990(): Unit = {
val komp = BytecodeTesting.newCompiler(extraArgs = "-Xasync")
val code =
"""import scala.tools.nsc.OptionAwait._
|
|class C {
| def sw1(i: Int) = optionally {
| i match {
| case 11 if value(Some(430)) > 42 => 22
| case p => p
| }
| }
| def sw2(i: Int) = optionally {
| i match {
| case 11 => if (value(Some(430)) > 42) 22 else i
| case p => p
| }
| }
| def sw3(i: Int) = optionally {
| i match {
| case 11 => if (value(Some(430)) > 42) 22 else i
| case 22 | 33 => 44
| case p => p
| }
| }
|}
|""".stripMargin
val cs = komp.compileClasses(code)

val sm1 = getMethod(cs.find(_.name == "C$stateMachine$async$1").get, "apply")
assertSame(1, sm1.instructions.count(_.opcode == TABLESWITCH))

val sm2 = getMethod(cs.find(_.name == "C$stateMachine$async$2").get, "apply")
assertSame(2, sm2.instructions.count(_.opcode == TABLESWITCH))

val sm3 = getMethod(cs.find(_.name == "C$stateMachine$async$3").get, "apply")
assertSame(1, sm3.instructions.count(_.opcode == TABLESWITCH))
assertSame(1, sm3.instructions.count(_.opcode == LOOKUPSWITCH))
}
}

0 comments on commit ccdcde3

Please sign in to comment.