Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance type inference for intercept and assertThrows #2043

Open
wants to merge 1 commit into
base: 3.2.x-new
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
12 changes: 7 additions & 5 deletions dotty/core/src/main/scala/org/scalatest/Assertions.scala
Expand Up @@ -727,6 +727,7 @@ trait Assertions extends TripleEquals {
* function. If the function throws an exception that's an instance of the specified type,
* this method returns that exception. Else, whether the passed function returns normally
* or completes abruptly with a different exception, this method throws <code>TestFailedException</code>.
* If the type parameter is left to the type inferencer, any Exception will be caught.
*
* <p>
* Note that the type specified as this method's type parameter may represent any subtype of
Expand Down Expand Up @@ -754,18 +755,18 @@ trait Assertions extends TripleEquals {
* @throws TestFailedException if the passed function does not complete abruptly with an exception
* that's an instance of the specified type.
*/
inline def intercept[T <: AnyRef](f: => Any)(implicit classTag: ClassTag[T]): T =
inline def intercept[T <: AnyRef](f: => Any)(implicit classTag: ClassTag[T]): T with Exception =
${ source.Position.withPosition[T]('{(pos: source.Position) => interceptImpl[T](f, classTag, pos) }) }

private final def interceptImpl[T <: AnyRef](f: => Any, classTag: ClassTag[T], pos: source.Position): T = {
private final def interceptImpl[T <: AnyRef](f: => Any, classTag: ClassTag[T], pos: source.Position): T with Exception = {
val clazz = classTag.runtimeClass
val caught = try {
f
None
}
catch {
case u: Throwable => {
if (!clazz.isAssignableFrom(u.getClass)) {
if (!clazz.isAssignableFrom(u.getClass) && clazz.getName != "scala.runtime.Nothing$") {
val s = Resources.wrongException(clazz.getName, u.getClass.getName)
throw newAssertionFailedException(Some(s), Some(u), pos, Vector.empty)
}
Expand All @@ -778,7 +779,7 @@ trait Assertions extends TripleEquals {
case None =>
val message = Resources.exceptionExpected(clazz.getName)
throw newAssertionFailedException(Some(message), None, pos, Vector.empty)
case Some(e) => e.asInstanceOf[T] // I know this cast will succeed, becuase isAssignableFrom succeeded above
case Some(e) => e.asInstanceOf[T with Exception] // We know this cast will succeed, because isAssignableFrom succeeded above
}
}

Expand All @@ -788,6 +789,7 @@ trait Assertions extends TripleEquals {
* function. If the function throws an exception that's an instance of the specified type,
* this method returns <code>Succeeded</code>. Else, whether the passed function returns normally
* or completes abruptly with a different exception, this method throws <code>TestFailedException</code>.
* If the type parameter is left to the type inferencer, any Exception will be caught.
*
* <p>
* Note that the type specified as this method's type parameter may represent any subtype of
Expand Down Expand Up @@ -827,7 +829,7 @@ trait Assertions extends TripleEquals {
}
catch {
case u: Throwable => {
if (!clazz.isAssignableFrom(u.getClass)) {
if (!clazz.isAssignableFrom(u.getClass) && clazz.getName != "scala.runtime.Nothing$") {
val s = Resources.wrongException(clazz.getName, u.getClass.getName)
throw newAssertionFailedException(Some(s), Some(u), pos, Vector.empty)
}
Expand Down
10 changes: 6 additions & 4 deletions jvm/core/src/main/scala/org/scalatest/Assertions.scala
Expand Up @@ -716,6 +716,7 @@ trait Assertions extends TripleEquals {
* function. If the function throws an exception that's an instance of the specified type,
* this method returns that exception. Else, whether the passed function returns normally
* or completes abruptly with a different exception, this method throws <code>TestFailedException</code>.
* If the type parameter is left to the type inferencer, any Exception will be caught.
*
* <p>
* Note that the type specified as this method's type parameter may represent any subtype of
Expand Down Expand Up @@ -743,15 +744,15 @@ trait Assertions extends TripleEquals {
* @throws TestFailedException if the passed function does not complete abruptly with an exception
* that's an instance of the specified type.
*/
def intercept[T <: AnyRef](f: => Any)(implicit classTag: ClassTag[T], pos: source.Position): T = {
def intercept[T <: AnyRef](f: => Any)(implicit classTag: ClassTag[T], pos: source.Position): T with Exception = {
val clazz = classTag.runtimeClass
val caught = try {
f
None
}
catch {
case u: Throwable => {
if (!clazz.isAssignableFrom(u.getClass)) {
if (!clazz.isAssignableFrom(u.getClass) && clazz.getName != "scala.runtime.Nothing$") {
val s = Resources.wrongException(clazz.getName, u.getClass.getName)
throw newAssertionFailedException(Some(s), Some(u), pos, Vector.empty)
}
Expand All @@ -764,7 +765,7 @@ trait Assertions extends TripleEquals {
case None =>
val message = Resources.exceptionExpected(clazz.getName)
throw newAssertionFailedException(Some(message), None, pos, Vector.empty)
case Some(e) => e.asInstanceOf[T] // I know this cast will succeed, becuase isAssignableFrom succeeded above
case Some(e) => e.asInstanceOf[T with Exception] // We know this cast will succeed, because isAssignableFrom succeeded above
}
}

Expand All @@ -774,6 +775,7 @@ trait Assertions extends TripleEquals {
* function. If the function throws an exception that's an instance of the specified type,
* this method returns <code>Succeeded</code>. Else, whether the passed function returns normally
* or completes abruptly with a different exception, this method throws <code>TestFailedException</code>.
* If the type parameter is left to the type inferencer, any Exception will be caught.
*
* <p>
* Note that the type specified as this method's type parameter may represent any subtype of
Expand Down Expand Up @@ -810,7 +812,7 @@ trait Assertions extends TripleEquals {
}
catch {
case u: Throwable => {
if (!clazz.isAssignableFrom(u.getClass)) {
if (!clazz.isAssignableFrom(u.getClass) && clazz.getName != "scala.runtime.Nothing$") {
val s = Resources.wrongException(clazz.getName, u.getClass.getName)
throw newAssertionFailedException(Some(s), Some(u), pos, Vector.empty)
}
Expand Down
Expand Up @@ -131,6 +131,14 @@ class AssertionsSpec extends AnyFunSpec {
assert(result eq e)
}

it("should catch any exception when the type parameter is left blank") {
val e = new RuntimeException
val result = intercept {
throw e
}
assert(result == e)
}

describe("when the bit of code throws the wrong exception") {
it("should include that wrong exception as the TFE's cause") {
val wrongException = new RuntimeException("oops!")
Expand Down Expand Up @@ -180,6 +188,13 @@ class AssertionsSpec extends AnyFunSpec {
assert(caught.isInstanceOf[TestFailedException])
}

it("should catch any exception if the type parameter is left blank") {
val result = assertThrows {
throw new IllegalArgumentException
}
assert(result eq Succeeded)
}

describe("when the bit of code throws the wrong exception") {
it("should include that wrong exception as the TFE's cause") {
val wrongException = new RuntimeException("oops!")
Expand Down