Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Synthesize a PartialFunction from function literal #8172

Merged
merged 1 commit into from Jul 11, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
12 changes: 10 additions & 2 deletions spec/06-expressions.md
Expand Up @@ -1177,7 +1177,7 @@ If there is no expected type for the function literal, all formal parameter type
The eventual run-time value of an anonymous function is determined by the expected type:
- a subclass of one of the builtin function types, `scala.Function$n$[$S_1 , \ldots , S_n$, $R\,$]` (with $S_i$ and $R$ fully defined),
- a [single-abstract-method (SAM) type](#sam-conversion);
- `PartialFunction[$T$, $U$]`, if the function literal is of the shape `x => x match { $\ldots$ }`
- `PartialFunction[$T$, $U$]`
- some other type.

The standard anonymous function evaluates in the same way as the following instance creation expression:
Expand All @@ -1192,7 +1192,15 @@ The same evaluation holds for a SAM type, except that the instantiated type is g

The underlying platform may provide more efficient ways of constructing these instances, such as Java 8's `invokedynamic` bytecode and `LambdaMetaFactory` class.

A `PartialFunction`'s value receives an additional `isDefinedAt` member, which is derived from the pattern match in the function literal, with each case's body being replaced by `true`, and an added default (if none was given) that evaluates to `false`.
When a `PartialFunction` is required, an additional member `isDefinedAt`
is synthesized, which simply returns `true`.
However, if the function literal has the shape `x => x match { $\ldots$ }`,
then `isDefinedAt` is derived from the pattern match in the following way:
each case from the match expression evaluates to `true`,
and if there is no default case,
a default case is added that evalutes to `false`.
For more details on how that is implemented see
["Pattern Matching Anonymous Functions"](08-pattern-matching.html#pattern-matching-anonymous-functions).

###### Example
Examples of anonymous functions:
Expand Down
34 changes: 18 additions & 16 deletions src/compiler/scala/tools/nsc/typechecker/Typers.scala
Expand Up @@ -3059,23 +3059,25 @@ trait Typers extends Adaptations with Tags with TypersTracking with PatternTyper

setError(fun)
}
} else {
fun.body match {
// translate `x => x match { <cases> }` : PartialFunction to
// `new PartialFunction { def applyOrElse(x, default) = x match { <cases> } def isDefinedAt(x) = ... }`
case Match(sel, cases) if (sel ne EmptyTree) && (pt.typeSymbol == PartialFunctionClass) =>
// go to outer context -- must discard the context that was created for the Function since we're discarding the function
// thus, its symbol, which serves as the current context.owner, is not the right owner
// you won't know you're using the wrong owner until lambda lift crashes (unless you know better than to use the wrong owner)
val outerTyper = newTyper(context.outer)
val p = vparams.head
if (p.tpt.tpe == null) p.tpt setType outerTyper.typedType(p.tpt).tpe

outerTyper.synthesizePartialFunction(p.name, p.pos, paramSynthetic = false, fun.body, mode, pt)

case _ => doTypedFunction(fun, resProto)
} else if (pt.typeSymbol == PartialFunctionClass) {
// translate `x => x match { <cases> }` : PartialFunction to
// `new PartialFunction { def applyOrElse(x, default) = x match { <cases> } def isDefinedAt(x) = ... }`
val funBody = fun.body match {
case Match(sel, _) if sel ne EmptyTree => fun.body
case funBody =>
atPos(funBody.pos.makeTransparent) {
Match(EmptyTree, List(CaseDef(Bind(nme.DEFAULT_CASE, Ident(nme.WILDCARD)), funBody)))
}
}
}
// go to outer context -- must discard the context that was created for the Function since we're discarding the function
// thus, its symbol, which serves as the current context.owner, is not the right owner
// you won't know you're using the wrong owner until lambda lift crashes (unless you know better than to use the wrong owner)
val outerTyper = newTyper(context.outer)
val p = vparams.head
if (p.tpt.tpe == null) p.tpt setType outerTyper.typedType(p.tpt).tpe

outerTyper.synthesizePartialFunction(p.name, p.pos, paramSynthetic = false, funBody, mode, pt)
} else doTypedFunction(fun, resProto)
}
}
}
Expand Down
3 changes: 3 additions & 0 deletions test/files/pos/partialfun.scala
Expand Up @@ -8,4 +8,7 @@ object partialfun {
case None => throw new MatchError(None)
} (None);

// Again, but using function literal
applyPartial(_.get)(None)

}
6 changes: 6 additions & 0 deletions test/files/pos/virtpatmat_partialfun_nsdnho.scala
Expand Up @@ -15,4 +15,10 @@ class Test {
// at scala.tools.nsc.typechecker.SuperAccessors$SuperAccTransformer.hostForAccessorOf(SuperAccessors.scala:474)
// at scala.tools.nsc.typechecker.SuperAccessors$SuperAccTransformer.needsProtectedAccessor(SuperAccessors.scala:457)
val c: (Int => (Any => Any)) = { m => { case _ => m.toInt } }


// Again, but using function literal
val a2: (Map[Int, Int] => (Any => Any)) = { m => { _ => m - 1} }
val b2: (Int => (Any => Any)) = { m => { _ => m } }
val c2: (Int => (Any => Any)) = { m => { _ => m.toInt } }
}
38 changes: 38 additions & 0 deletions test/files/run/partialfun.check
Expand Up @@ -4,3 +4,41 @@
0:isDefinedAt
1:isDefinedAt
2:apply

false
true
Vector(1, 2, 3, 4, 5)
Vector(1, 2, 3, 4, 5)

testing function literal syntax with methods overloaded for Function1 and PartialFunction
base case: a method that takes a Function1, so no overloading
fn only
fn only
fn only

base case: a method that takes a PartialFunction, so no overloading
pf only
pf only
pf only

test case: a method that is overloaded for Funtion1 and PartialFunction
fn wins
pf wins
pf wins

testing eta-expansion with methods overloaded for Function1 and PartialFunction
base case: a method that takes a Function1, so no overloading
fn only
fn only
fn only
fn only

base case: a method that takes a PartialFunction, so no overloading
pf only
pf only
pf only

test case: a method that is overloaded for Funtion1 and PartialFunction
fn wins
fn wins
fn wins
67 changes: 67 additions & 0 deletions test/files/run/partialfun.scala
Expand Up @@ -81,8 +81,75 @@ object Test {
chained(())
}

def fromFunctionLiteralTest(): Unit = {
def isEven(n: Int): Boolean = PartialFunction.cond(n)(_ % 2 == 0)
println(isEven(1))
println(isEven(2))
println((1 to 5).map(_.toString))
println((1 to 5).collect(_.toString))
}

def takeFunction1(fn: String => String) = println("fn only")
def takePartialFunction(pf: PartialFunction[String, String]) = println("pf only")
def takeFunctionLike(fn: String => String) = println("fn wins")
def takeFunctionLike(pf: PartialFunction[String, String]) = println("pf wins")

def testOverloadingWithFunction1(): Unit = {
println("testing function literal syntax with methods overloaded for Function1 and PartialFunction")
println("base case: a method that takes a Function1, so no overloading")
takeFunction1(_.reverse)
takeFunction1 { case s => s.reverse }
takeFunction1 { case s: String => s.reverse }
println()

println("base case: a method that takes a PartialFunction, so no overloading")
takePartialFunction(_.reverse)
takePartialFunction { case s => s.reverse }
takePartialFunction { case s: String => s.reverse }
println()

println("test case: a method that is overloaded for Funtion1 and PartialFunction")
takeFunctionLike(_.reverse)
takeFunctionLike { case s => s.reverse }
takeFunctionLike { case s: String => s.reverse }
}

def reverse(s: String): String = s.reverse

def testEtaExpansion(): Unit = {
println("testing eta-expansion with methods overloaded for Function1 and PartialFunction")
println("base case: a method that takes a Function1, so no overloading")
takeFunction1(x => reverse(x))
takeFunction1(reverse(_))
takeFunction1(reverse _)
takeFunction1(reverse)
println()

println("base case: a method that takes a PartialFunction, so no overloading")
takePartialFunction(x => reverse(x))
takePartialFunction(reverse(_))
takePartialFunction(reverse _)
//takePartialFunction(reverse) // can't pass a method to a method that takes a PartialFunction
println()

println("test case: a method that is overloaded for Funtion1 and PartialFunction")
takeFunctionLike(x => reverse(x))
takeFunctionLike(reverse(_))
takeFunctionLike(reverse _)
//takeFunctionLike(reverse) // can't pass a method to a method overloaded to take a Function1 or a PartialFunction
}

def main(args: Array[String]): Unit = {
collectTest()
orElseTest()
println()

fromFunctionLiteralTest()
println()

testOverloadingWithFunction1()
println()

testEtaExpansion()
}
}