From 5182656f7fa790e8c0ec931c4a385f315150210e Mon Sep 17 00:00:00 2001 From: Dale Wijnand Date: Sun, 23 Jun 2019 10:02:12 +0100 Subject: [PATCH] Synthesize a PartialFunction from function literal --- .../scala/tools/nsc/typechecker/Typers.scala | 26 ++++++++++++------- test/files/pos/partialfun.scala | 3 +++ .../pos/virtpatmat_partialfun_nsdnho.scala | 6 +++++ test/files/run/partialfun.check | 4 +++ test/files/run/partialfun.scala | 9 +++++++ 5 files changed, 38 insertions(+), 10 deletions(-) diff --git a/src/compiler/scala/tools/nsc/typechecker/Typers.scala b/src/compiler/scala/tools/nsc/typechecker/Typers.scala index 3742e866c433..e7f9909f2182 100644 --- a/src/compiler/scala/tools/nsc/typechecker/Typers.scala +++ b/src/compiler/scala/tools/nsc/typechecker/Typers.scala @@ -3060,18 +3060,24 @@ trait Typers extends Adaptations with Tags with TypersTracking with PatternTyper setError(fun) } } else { + // translate `x => x match { }` : PartialFunction to + // `new PartialFunction { def applyOrElse(x, default) = x match { } def isDefinedAt(x) = ... }` + def synthesizePartialFunction(tree: Tree): Tree = { + // go to outer context -- must discard the context that was created for the Function since we're discarding the function + // thus, its symbol, which serves as the current context.owner, is not the right owner + // you won't know you're using the wrong owner until lambda lift crashes (unless you know better than to use the wrong owner) + val outerTyper = newTyper(context.outer) + val p = vparams.head + if (p.tpt.tpe == null) p.tpt setType outerTyper.typedType(p.tpt).tpe + + outerTyper.synthesizePartialFunction(p.name, p.pos, paramSynthetic = false, tree, mode, pt) + } fun.body match { - // translate `x => x match { }` : PartialFunction to - // `new PartialFunction { def applyOrElse(x, default) = x match { } def isDefinedAt(x) = ... }` case Match(sel, cases) if (sel ne EmptyTree) && (pt.typeSymbol == PartialFunctionClass) => - // go to outer context -- must discard the context that was created for the Function since we're discarding the function - // thus, its symbol, which serves as the current context.owner, is not the right owner - // you won't know you're using the wrong owner until lambda lift crashes (unless you know better than to use the wrong owner) - val outerTyper = newTyper(context.outer) - val p = vparams.head - if (p.tpt.tpe == null) p.tpt setType outerTyper.typedType(p.tpt).tpe - - outerTyper.synthesizePartialFunction(p.name, p.pos, paramSynthetic = false, fun.body, mode, pt) + synthesizePartialFunction(fun.body) + + case _ if pt.typeSymbol == PartialFunctionClass => + synthesizePartialFunction(Match(EmptyTree, List(CaseDef(EmptyTree, fun.body)))) case _ => doTypedFunction(fun, resProto) } diff --git a/test/files/pos/partialfun.scala b/test/files/pos/partialfun.scala index 9f32a2202313..87c410e1ee39 100644 --- a/test/files/pos/partialfun.scala +++ b/test/files/pos/partialfun.scala @@ -8,4 +8,7 @@ object partialfun { case None => throw new MatchError(None) } (None); + // Again, but using function literal + applyPartial(_.get)(None) + } diff --git a/test/files/pos/virtpatmat_partialfun_nsdnho.scala b/test/files/pos/virtpatmat_partialfun_nsdnho.scala index 2a2a23d883c4..46536480730a 100644 --- a/test/files/pos/virtpatmat_partialfun_nsdnho.scala +++ b/test/files/pos/virtpatmat_partialfun_nsdnho.scala @@ -15,4 +15,10 @@ class Test { // at scala.tools.nsc.typechecker.SuperAccessors$SuperAccTransformer.hostForAccessorOf(SuperAccessors.scala:474) // at scala.tools.nsc.typechecker.SuperAccessors$SuperAccTransformer.needsProtectedAccessor(SuperAccessors.scala:457) val c: (Int => (Any => Any)) = { m => { case _ => m.toInt } } + + + // Again, but using function literal + val a2: (Map[Int, Int] => (Any => Any)) = { m => { _ => m - 1} } + val b2: (Int => (Any => Any)) = { m => { _ => m } } + val c2: (Int => (Any => Any)) = { m => { _ => m.toInt } } } diff --git a/test/files/run/partialfun.check b/test/files/run/partialfun.check index d4e9f494cd6f..e30dc9e89d2c 100644 --- a/test/files/run/partialfun.check +++ b/test/files/run/partialfun.check @@ -4,3 +4,7 @@ 0:isDefinedAt 1:isDefinedAt 2:apply +false +true +Vector(1, 2, 3, 4, 5) +Vector(1, 2, 3, 4, 5) diff --git a/test/files/run/partialfun.scala b/test/files/run/partialfun.scala index af55ce2e695e..8eb9f24c22a0 100644 --- a/test/files/run/partialfun.scala +++ b/test/files/run/partialfun.scala @@ -81,8 +81,17 @@ object Test { chained(()) } + def fromFunctionLiteralTest(): Unit = { + def isEven(n: Int): Boolean = PartialFunction.cond(n)(_ % 2 == 0) + println(isEven(1)) + println(isEven(2)) + println((1 to 5).map(_.toString)) + println((1 to 5).collect(_.toString)) + } + def main(args: Array[String]): Unit = { collectTest() orElseTest() + fromFunctionLiteralTest() } }