diff --git a/spec/06-expressions.md b/spec/06-expressions.md index f2ef365214b1..3ed327e85aed 100644 --- a/spec/06-expressions.md +++ b/spec/06-expressions.md @@ -1177,7 +1177,7 @@ If there is no expected type for the function literal, all formal parameter type The eventual run-time value of an anonymous function is determined by the expected type: - a subclass of one of the builtin function types, `scala.Function$n$[$S_1 , \ldots , S_n$, $R\,$]` (with $S_i$ and $R$ fully defined), - a [single-abstract-method (SAM) type](#sam-conversion); - - `PartialFunction[$T$, $U$]`, if the function literal is of the shape `x => x match { $\ldots$ }` + - `PartialFunction[$T$, $U$]` - some other type. The standard anonymous function evaluates in the same way as the following instance creation expression: @@ -1192,7 +1192,15 @@ The same evaluation holds for a SAM type, except that the instantiated type is g The underlying platform may provide more efficient ways of constructing these instances, such as Java 8's `invokedynamic` bytecode and `LambdaMetaFactory` class. -A `PartialFunction`'s value receives an additional `isDefinedAt` member, which is derived from the pattern match in the function literal, with each case's body being replaced by `true`, and an added default (if none was given) that evaluates to `false`. +When a `PartialFunction` is required, an additional member `isDefinedAt` +is synthesized, which simply returns `true`. +However, if the function literal has the shape `x => x match { $\ldots$ }`, +then `isDefinedAt` is derived from the pattern match in the following way: +each case from the match expression evaluates to `true`, +and if there is no default case, +a default case is added that evalutes to `false`. +For more details on how that is implemented see +["Pattern Matching Anonymous Functions"](08-pattern-matching.html#pattern-matching-anonymous-functions). ###### Example Examples of anonymous functions: diff --git a/src/compiler/scala/tools/nsc/typechecker/Typers.scala b/src/compiler/scala/tools/nsc/typechecker/Typers.scala index 3742e866c433..e7fdab5e4aa6 100644 --- a/src/compiler/scala/tools/nsc/typechecker/Typers.scala +++ b/src/compiler/scala/tools/nsc/typechecker/Typers.scala @@ -3059,23 +3059,25 @@ trait Typers extends Adaptations with Tags with TypersTracking with PatternTyper setError(fun) } - } else { - fun.body match { - // translate `x => x match { }` : PartialFunction to - // `new PartialFunction { def applyOrElse(x, default) = x match { } def isDefinedAt(x) = ... }` - case Match(sel, cases) if (sel ne EmptyTree) && (pt.typeSymbol == PartialFunctionClass) => - // go to outer context -- must discard the context that was created for the Function since we're discarding the function - // thus, its symbol, which serves as the current context.owner, is not the right owner - // you won't know you're using the wrong owner until lambda lift crashes (unless you know better than to use the wrong owner) - val outerTyper = newTyper(context.outer) - val p = vparams.head - if (p.tpt.tpe == null) p.tpt setType outerTyper.typedType(p.tpt).tpe - - outerTyper.synthesizePartialFunction(p.name, p.pos, paramSynthetic = false, fun.body, mode, pt) - - case _ => doTypedFunction(fun, resProto) + } else if (pt.typeSymbol == PartialFunctionClass) { + // translate `x => x match { }` : PartialFunction to + // `new PartialFunction { def applyOrElse(x, default) = x match { } def isDefinedAt(x) = ... }` + val funBody = fun.body match { + case Match(sel, _) if sel ne EmptyTree => fun.body + case funBody => + atPos(funBody.pos.makeTransparent) { + Match(EmptyTree, List(CaseDef(Bind(nme.DEFAULT_CASE, Ident(nme.WILDCARD)), funBody))) + } } - } + // go to outer context -- must discard the context that was created for the Function since we're discarding the function + // thus, its symbol, which serves as the current context.owner, is not the right owner + // you won't know you're using the wrong owner until lambda lift crashes (unless you know better than to use the wrong owner) + val outerTyper = newTyper(context.outer) + val p = vparams.head + if (p.tpt.tpe == null) p.tpt setType outerTyper.typedType(p.tpt).tpe + + outerTyper.synthesizePartialFunction(p.name, p.pos, paramSynthetic = false, funBody, mode, pt) + } else doTypedFunction(fun, resProto) } } } diff --git a/test/files/pos/partialfun.scala b/test/files/pos/partialfun.scala index 9f32a2202313..87c410e1ee39 100644 --- a/test/files/pos/partialfun.scala +++ b/test/files/pos/partialfun.scala @@ -8,4 +8,7 @@ object partialfun { case None => throw new MatchError(None) } (None); + // Again, but using function literal + applyPartial(_.get)(None) + } diff --git a/test/files/pos/virtpatmat_partialfun_nsdnho.scala b/test/files/pos/virtpatmat_partialfun_nsdnho.scala index 2a2a23d883c4..46536480730a 100644 --- a/test/files/pos/virtpatmat_partialfun_nsdnho.scala +++ b/test/files/pos/virtpatmat_partialfun_nsdnho.scala @@ -15,4 +15,10 @@ class Test { // at scala.tools.nsc.typechecker.SuperAccessors$SuperAccTransformer.hostForAccessorOf(SuperAccessors.scala:474) // at scala.tools.nsc.typechecker.SuperAccessors$SuperAccTransformer.needsProtectedAccessor(SuperAccessors.scala:457) val c: (Int => (Any => Any)) = { m => { case _ => m.toInt } } + + + // Again, but using function literal + val a2: (Map[Int, Int] => (Any => Any)) = { m => { _ => m - 1} } + val b2: (Int => (Any => Any)) = { m => { _ => m } } + val c2: (Int => (Any => Any)) = { m => { _ => m.toInt } } } diff --git a/test/files/run/partialfun.check b/test/files/run/partialfun.check index d4e9f494cd6f..06f56ae9c3b6 100644 --- a/test/files/run/partialfun.check +++ b/test/files/run/partialfun.check @@ -4,3 +4,41 @@ 0:isDefinedAt 1:isDefinedAt 2:apply + +false +true +Vector(1, 2, 3, 4, 5) +Vector(1, 2, 3, 4, 5) + +testing function literal syntax with methods overloaded for Function1 and PartialFunction +base case: a method that takes a Function1, so no overloading +fn only +fn only +fn only + +base case: a method that takes a PartialFunction, so no overloading +pf only +pf only +pf only + +test case: a method that is overloaded for Funtion1 and PartialFunction +fn wins +pf wins +pf wins + +testing eta-expansion with methods overloaded for Function1 and PartialFunction +base case: a method that takes a Function1, so no overloading +fn only +fn only +fn only +fn only + +base case: a method that takes a PartialFunction, so no overloading +pf only +pf only +pf only + +test case: a method that is overloaded for Funtion1 and PartialFunction +fn wins +fn wins +fn wins diff --git a/test/files/run/partialfun.scala b/test/files/run/partialfun.scala index af55ce2e695e..21fc9a53252e 100644 --- a/test/files/run/partialfun.scala +++ b/test/files/run/partialfun.scala @@ -81,8 +81,75 @@ object Test { chained(()) } + def fromFunctionLiteralTest(): Unit = { + def isEven(n: Int): Boolean = PartialFunction.cond(n)(_ % 2 == 0) + println(isEven(1)) + println(isEven(2)) + println((1 to 5).map(_.toString)) + println((1 to 5).collect(_.toString)) + } + + def takeFunction1(fn: String => String) = println("fn only") + def takePartialFunction(pf: PartialFunction[String, String]) = println("pf only") + def takeFunctionLike(fn: String => String) = println("fn wins") + def takeFunctionLike(pf: PartialFunction[String, String]) = println("pf wins") + + def testOverloadingWithFunction1(): Unit = { + println("testing function literal syntax with methods overloaded for Function1 and PartialFunction") + println("base case: a method that takes a Function1, so no overloading") + takeFunction1(_.reverse) + takeFunction1 { case s => s.reverse } + takeFunction1 { case s: String => s.reverse } + println() + + println("base case: a method that takes a PartialFunction, so no overloading") + takePartialFunction(_.reverse) + takePartialFunction { case s => s.reverse } + takePartialFunction { case s: String => s.reverse } + println() + + println("test case: a method that is overloaded for Funtion1 and PartialFunction") + takeFunctionLike(_.reverse) + takeFunctionLike { case s => s.reverse } + takeFunctionLike { case s: String => s.reverse } + } + + def reverse(s: String): String = s.reverse + + def testEtaExpansion(): Unit = { + println("testing eta-expansion with methods overloaded for Function1 and PartialFunction") + println("base case: a method that takes a Function1, so no overloading") + takeFunction1(x => reverse(x)) + takeFunction1(reverse(_)) + takeFunction1(reverse _) + takeFunction1(reverse) + println() + + println("base case: a method that takes a PartialFunction, so no overloading") + takePartialFunction(x => reverse(x)) + takePartialFunction(reverse(_)) + takePartialFunction(reverse _) + //takePartialFunction(reverse) // can't pass a method to a method that takes a PartialFunction + println() + + println("test case: a method that is overloaded for Funtion1 and PartialFunction") + takeFunctionLike(x => reverse(x)) + takeFunctionLike(reverse(_)) + takeFunctionLike(reverse _) + //takeFunctionLike(reverse) // can't pass a method to a method overloaded to take a Function1 or a PartialFunction + } + def main(args: Array[String]): Unit = { collectTest() orElseTest() + println() + + fromFunctionLiteralTest() + println() + + testOverloadingWithFunction1() + println() + + testEtaExpansion() } }