Skip to content

Commit

Permalink
Make sure isEmpty's type isn't constant folded
Browse files Browse the repository at this point in the history
Also test the same is true for unapply.
  • Loading branch information
dwijnand committed Feb 17, 2021
1 parent 56da205 commit a275b61
Show file tree
Hide file tree
Showing 7 changed files with 132 additions and 8 deletions.
9 changes: 9 additions & 0 deletions src/compiler/scala/tools/nsc/Global.scala
Expand Up @@ -1016,11 +1016,20 @@ class Global(var currentSettings: Settings, reporter0: Reporter)
&& rootMirror.isMirrorInitialized
)
override def isPastTyper = isPast(currentRun.typerPhase)
def isBeforeErasure = isBefore(currentRun.erasurePhase)
def isPast(phase: Phase) = (
(curRun ne null)
&& isGlobalInitialized // defense against init order issues
&& (globalPhase.id > phase.id)
)
def isBefore(phase: Phase) = (
(curRun ne null)
&& isGlobalInitialized // defense against init order issues
&& (phase match {
case NoPhase => true // if phase is NoPhase then that phase ain't comin', so we're "before it"
case _ => globalPhase.id < phase.id
})
)

// TODO - trim these to the absolute minimum.
@inline final def exitingErasure[T](op: => T): T = exitingPhase(currentRun.erasurePhase)(op)
Expand Down
4 changes: 2 additions & 2 deletions src/compiler/scala/tools/nsc/typechecker/Typers.scala
Expand Up @@ -1240,8 +1240,8 @@ trait Typers extends Adaptations with Tags with TypersTracking with PatternTyper
inferExprAlternative(tree, pt)
adaptAfterOverloadResolution(tree, mode, pt, original)
case NullaryMethodType(restpe) => // (2)
val resTpDeconst =
if (isPastTyper || (tree.symbol.isAccessor && tree.symbol.hasFlag(STABLE) && treeInfo.isExprSafeToInline(tree))) restpe
val resTpDeconst = // keep constant types when they are safe to fold. erasure eliminates constant types modulo some exceptions, so keep those.
if (isBeforeErasure && tree.symbol.isAccessor && tree.symbol.hasFlag(STABLE) && treeInfo.isExprSafeToInline(tree)) restpe
else restpe.deconst
adapt(tree setType resTpDeconst, mode, pt, original)
case TypeRef(_, ByNameParamClass, arg :: Nil) if mode.inExprMode => // (2)
Expand Down
36 changes: 36 additions & 0 deletions test/files/run/patmat-no-inline-isEmpty.check
@@ -0,0 +1,36 @@
[[syntax trees at end of patmat]] // newSource1.scala
package <empty> {
object A extends scala.AnyRef {
def <init>(): A.type = {
A.super.<init>();
()
};
def unapplySeq(a: Int): Wrap = new Wrap(a)
};
class T extends scala.AnyRef {
def <init>(): T = {
T.super.<init>();
()
};
def t: Any = {
case <synthetic> val x1: Int(2) = 2;
case5(){
<synthetic> val o7: Wrap = A.unapplySeq(x1);
if (o7.isEmpty.unary_!)
{
val xs: Seq[Int] = o7.get.toSeq;
matchEnd4(xs)
}
else
case6()
};
case6(){
matchEnd4("other")
};
matchEnd4(x: Any){
x
}
}
}
}

31 changes: 31 additions & 0 deletions test/files/run/patmat-no-inline-isEmpty.scala
@@ -0,0 +1,31 @@
import scala.tools.partest._

object Test extends DirectTest {
def depCode =
"""class Wrap(private val a: Int) extends AnyVal {
| def isEmpty: false = { println("confirm seq isEmpty method doesn't get elided"); false }
| def get = this
| def lengthCompare(len: Int) = Integer.compare(1, len)
| def apply(i: Int) = if (i == 0) a else Nil(i)
| def drop(n: Int): scala.Seq[Int] = if (n == 0) toSeq else Nil
| def toSeq: scala.Seq[Int] = List(a)
|}
""".stripMargin

override def code =
"""object A {
| def unapplySeq(a: Int) = new Wrap(a)
|}
|class T {
| def t: Any = 2 match {
| case A(xs @ _*) => xs
| case _ => "other"
| }
|}
""".stripMargin

def show(): Unit = Console.withErr(System.out) {
compileString(newCompiler("-usejavacp"))(depCode)
compileString(newCompiler("-usejavacp", "-cp", testOutput.path, "-Vprint:patmat"))(code)
}
}
25 changes: 25 additions & 0 deletions test/files/run/patmat-no-inline-unapply.check
@@ -0,0 +1,25 @@
[[syntax trees at end of patmat]] // newSource1.scala
package <empty> {
class T extends scala.AnyRef {
def <init>(): T = {
T.super.<init>();
()
};
def t: Any = {
case <synthetic> val x1: Int(2) = 2;
case5(){
if (A.unapply(x1))
matchEnd4("ok")
else
case6()
};
case6(){
matchEnd4("other")
};
matchEnd4(x: Any){
x
}
}
}
}

23 changes: 23 additions & 0 deletions test/files/run/patmat-no-inline-unapply.scala
@@ -0,0 +1,23 @@
import scala.tools.partest._

object Test extends DirectTest {
def depCode =
"""object A {
| def unapply(a: Int): true = true
|}
""".stripMargin

override def code =
"""class T {
| def t: Any = 2 match {
| case A() => "ok"
| case _ => "other"
| }
|}
""".stripMargin

def show(): Unit = Console.withErr(System.out) {
compileString(newCompiler("-usejavacp"))(depCode)
compileString(newCompiler("-usejavacp", "-cp", testOutput.path, "-Vprint:patmat"))(code)
}
}
12 changes: 6 additions & 6 deletions test/files/run/patmat-seq.check
Expand Up @@ -66,7 +66,7 @@ package <empty> {
case <synthetic> val x1: Int(2) = 2;
case16(){
<synthetic> val o18: collection.SeqFactory.UnapplySeqWrapper[Int] = A.unapplySeq(x1);
if (false.unary_!)
if (o18.isEmpty.unary_!)
{
val xs: Seq[Int] = o18.get.toSeq;
matchEnd15(xs)
Expand All @@ -76,7 +76,7 @@ package <empty> {
};
case17(){
<synthetic> val o20: collection.SeqFactory.UnapplySeqWrapper[Int] = A.unapplySeq(x1);
if (false.unary_!.&&(o20.get.!=(null).&&(o20.get.lengthCompare(2).==(0))))
if (o20.isEmpty.unary_!.&&(o20.get.!=(null).&&(o20.get.lengthCompare(2).==(0))))
{
val x: Int = o20.get.apply(0);
val y: Int = o20.get.apply(1);
Expand All @@ -87,7 +87,7 @@ package <empty> {
};
case19(){
<synthetic> val o22: collection.SeqFactory.UnapplySeqWrapper[Int] = A.unapplySeq(x1);
if (false.unary_!.&&(o22.get.!=(null).&&(o22.get.lengthCompare(1).>=(0))))
if (o22.isEmpty.unary_!.&&(o22.get.!=(null).&&(o22.get.lengthCompare(1).>=(0))))
{
val x: Int = o22.get.apply(0);
val xs: Seq[Int] = o22.get.drop(1);
Expand Down Expand Up @@ -267,7 +267,7 @@ package <empty> {
case <synthetic> val x1: Int(2) = 2;
case16(){
<synthetic> val o18: scala.collection.SeqOps = A.unapplySeq(x1);
if (false.unary_!())
if (scala.collection.SeqFactory.UnapplySeqWrapper.isEmpty$extension(o18).unary_!())
{
val xs: Seq = scala.collection.SeqFactory.UnapplySeqWrapper.toSeq$extension(scala.collection.SeqFactory.UnapplySeqWrapper.get$extension(o18));
matchEnd15(xs)
Expand All @@ -277,7 +277,7 @@ package <empty> {
};
case17(){
<synthetic> val o20: scala.collection.SeqOps = A.unapplySeq(x1);
if (false.unary_!().&&(new collection.SeqFactory.UnapplySeqWrapper(scala.collection.SeqFactory.UnapplySeqWrapper.get$extension(o20)).!=(null).&&(scala.collection.SeqFactory.UnapplySeqWrapper.lengthCompare$extension(scala.collection.SeqFactory.UnapplySeqWrapper.get$extension(o20), 2).==(0))))
if (scala.collection.SeqFactory.UnapplySeqWrapper.isEmpty$extension(o20).unary_!().&&(new collection.SeqFactory.UnapplySeqWrapper(scala.collection.SeqFactory.UnapplySeqWrapper.get$extension(o20)).!=(null).&&(scala.collection.SeqFactory.UnapplySeqWrapper.lengthCompare$extension(scala.collection.SeqFactory.UnapplySeqWrapper.get$extension(o20), 2).==(0))))
{
val x: Int = unbox(scala.collection.SeqFactory.UnapplySeqWrapper.apply$extension(scala.collection.SeqFactory.UnapplySeqWrapper.get$extension(o20), 0));
val y: Int = unbox(scala.collection.SeqFactory.UnapplySeqWrapper.apply$extension(scala.collection.SeqFactory.UnapplySeqWrapper.get$extension(o20), 1));
Expand All @@ -288,7 +288,7 @@ package <empty> {
};
case19(){
<synthetic> val o22: scala.collection.SeqOps = A.unapplySeq(x1);
if (false.unary_!().&&(new collection.SeqFactory.UnapplySeqWrapper(scala.collection.SeqFactory.UnapplySeqWrapper.get$extension(o22)).!=(null).&&(scala.collection.SeqFactory.UnapplySeqWrapper.lengthCompare$extension(scala.collection.SeqFactory.UnapplySeqWrapper.get$extension(o22), 1).>=(0))))
if (scala.collection.SeqFactory.UnapplySeqWrapper.isEmpty$extension(o22).unary_!().&&(new collection.SeqFactory.UnapplySeqWrapper(scala.collection.SeqFactory.UnapplySeqWrapper.get$extension(o22)).!=(null).&&(scala.collection.SeqFactory.UnapplySeqWrapper.lengthCompare$extension(scala.collection.SeqFactory.UnapplySeqWrapper.get$extension(o22), 1).>=(0))))
{
val x: Int = unbox(scala.collection.SeqFactory.UnapplySeqWrapper.apply$extension(scala.collection.SeqFactory.UnapplySeqWrapper.get$extension(o22), 0));
val xs: Seq = scala.collection.SeqFactory.UnapplySeqWrapper.drop$extension(scala.collection.SeqFactory.UnapplySeqWrapper.get$extension(o22), 1);
Expand Down

0 comments on commit a275b61

Please sign in to comment.