Skip to content

Commit

Permalink
Adjust keepAlive for test under JDK 9
Browse files Browse the repository at this point in the history
  • Loading branch information
som-snytt committed Mar 27, 2024
1 parent 4e03eb5 commit 436d7a7
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 43 deletions.
23 changes: 17 additions & 6 deletions src/reflect/scala/reflect/internal/util/ScalaClassLoader.scala
Expand Up @@ -51,17 +51,28 @@ final class RichClassLoader(private val self: JClassLoader) extends AnyVal {
tryToInitializeClass[AnyRef](path).map(_.getConstructor().newInstance()).orNull

/** Create an instance with ctor args, or invoke errorFn before throwing. */
def create[T <: AnyRef : ClassTag](path: String, errorFn: String => Unit)(args: AnyRef*): T = {
def create[T <: AnyRef : ClassTag](path: String, errorFn: String => Unit)(args: Any*): T = {
def fail(msg: String) = error(msg, new IllegalArgumentException(msg))
def error(msg: String, e: Throwable) = { errorFn(msg) ; throw e }
def error(msg: String, e: Throwable) = { errorFn(msg); throw e }
try {
val clazz = Class.forName(path, /*initialize =*/ true, /*loader =*/ self)
if (classTag[T].runtimeClass isAssignableFrom clazz) {
if (classTag[T].runtimeClass.isAssignableFrom(clazz)) {
val ctor = {
val maybes = clazz.getConstructors filter (c => c.getParameterCount == args.size &&
(c.getParameterTypes zip args).forall { case (k, a) => k isAssignableFrom a.getClass })
val bySize = clazz.getConstructors.filter(_.getParameterCount == args.size)
if (bySize.isEmpty) fail(s"No constructor takes ${args.size} parameters.")
def isAssignable(k: Class[?], a: Any): Boolean =
if (k == classOf[Int]) a.isInstanceOf[Integer]
else if (k == classOf[Boolean]) a.isInstanceOf[java.lang.Boolean]
else if (k == classOf[Long]) a.isInstanceOf[java.lang.Long]
else k.isAssignableFrom(a.getClass)
val maybes = bySize.filter(c => c.getParameterTypes.zip(args).forall { case (k, a) => isAssignable(k, a) })
if (maybes.size == 1) maybes.head
else fail(s"Constructor must accept arg list (${args map (_.getClass.getName) mkString ", "}): ${path}")
else if (bySize.size == 1)
fail(s"One constructor takes ${args.size} parameters but ${
bySize.head.getParameterTypes.zip(args).collect { case (k, a) if !isAssignable(k, a) => s"$k != ${a.getClass}" }.mkString("; ")
}.")
else
fail(s"Constructor must accept arg list (${args.map(_.getClass.getName).mkString(", ")}): ${path}")
}
(ctor.newInstance(args: _*)).asInstanceOf[T]
} else {
Expand Down
4 changes: 2 additions & 2 deletions test/files/jvm/scala-concurrent-tck.check
Expand Up @@ -124,8 +124,6 @@ starting rejectedExecutionException
finished rejectedExecutionException
starting testNameOfGlobalECThreads
finished testNameOfGlobalECThreads
starting testUncaughtExceptionReporting
finished testUncaughtExceptionReporting
starting testOnSuccessCustomEC
finished testOnSuccessCustomEC
starting testKeptPromiseCustomEC
Expand All @@ -136,3 +134,5 @@ starting testOnComplete
finished testOnComplete
starting testMap
finished testMap
starting testUncaughtExceptionReporting
finished testUncaughtExceptionReporting
143 changes: 108 additions & 35 deletions test/files/jvm/scala-concurrent-tck.scala
Expand Up @@ -4,6 +4,7 @@ import scala.concurrent.{
TimeoutException,
ExecutionException,
ExecutionContext,
ExecutionContextExecutorService,
CanAwait,
Await,
Awaitable,
Expand All @@ -15,7 +16,7 @@ import scala.reflect.{classTag, ClassTag}
import scala.tools.testkit.AssertUtil.{Fast, Slow, assertThrows, waitFor, waitForIt}
import scala.util.{Try, Success, Failure}
import scala.util.chaining._
import java.util.concurrent.CountDownLatch
import java.util.concurrent.{CountDownLatch, ThreadPoolExecutor}
import java.util.concurrent.TimeUnit.{MILLISECONDS => Milliseconds, SECONDS => Seconds}

trait TestBase {
Expand All @@ -29,7 +30,7 @@ trait TestBase {
def apply(proof: => Boolean): Unit = q offer Try(proof)
})
var tried: Try[Boolean] = null
def check = { tried = q.poll(5000L, Milliseconds) ; tried != null }
def check = q.poll(5000L, Milliseconds).tap(tried = _) != null
waitForIt(check, progress = Slow, label = "concurrent-tck")
assert(tried.isSuccess)
assert(tried.get)
Expand Down Expand Up @@ -747,7 +748,7 @@ class Blocking extends TestBase {

class BlockContexts extends TestBase {
import ExecutionContext.Implicits._
import scala.concurrent.{ Await, Awaitable, BlockContext }
import scala.concurrent.BlockContext

private def getBlockContext(body: => BlockContext): BlockContext = await(Future(body))

Expand Down Expand Up @@ -877,7 +878,6 @@ class GlobalExecutionContext extends TestBase {
}

class CustomExecutionContext extends TestBase {
import scala.concurrent.{ ExecutionContext, Awaitable }

def defaultEC = ExecutionContext.global

Expand Down Expand Up @@ -987,37 +987,6 @@ class CustomExecutionContext extends TestBase {
assert(count >= 1)
}

def testUncaughtExceptionReporting(): Unit = once { done =>
val example = new InterruptedException
val latch = new CountDownLatch(1)
@volatile var thread: Thread = null
@volatile var reported: Throwable = null
val ec = ExecutionContext.fromExecutorService(null, t => {
reported = t
latch.countDown()
})

@tailrec def waitForThreadDeath(turns: Int): Boolean =
turns > 0 && (thread != null && !thread.isAlive || { Thread.sleep(10L) ; waitForThreadDeath(turns - 1) })

def truthfully(b: Boolean): Option[Boolean] = if (b) Some(true) else None

// jdk17 thread receives pool exception handler, so wait for thread to die slow and painful expired keepalive
def threadIsDead =
waitFor(truthfully(waitForThreadDeath(turns = 100)), progress = Slow, label = "concurrent-tck-thread-death")

try {
ec.execute(() => {
thread = Thread.currentThread
throw example
})
latch.await(2, Seconds)
done(threadIsDead && (reported eq example))
}
finally ec.shutdown()
}

test("testUncaughtExceptionReporting")(testUncaughtExceptionReporting())
test("testOnSuccessCustomEC")(testOnSuccessCustomEC())
test("testKeptPromiseCustomEC")(testKeptPromiseCustomEC())
test("testCallbackChainCustomEC")(testCallbackChainCustomEC())
Expand Down Expand Up @@ -1076,6 +1045,103 @@ class ExecutionContextPrepare extends TestBase {
test("testMap")(testMap())
}

class ReportingExecutionContext extends TestBase {
final val slowly = false // true for using default FJP with long keepAlive (60 secs)
@volatile var thread: Thread = null
@volatile var reportedOn: Thread = null
@volatile var reported: Throwable = null
val latch = new CountDownLatch(1)

def report(t: Thread, e: Throwable): Unit = {
reportedOn = t
reported = e
latch.countDown()
}
def underlyingPool = {
import java.util.concurrent.{LinkedBlockingQueue, RejectedExecutionHandler, ThreadFactory, ThreadPoolExecutor}
val coreSize = 4
val maxSize = 4
val keepAlive = 2000L
val q = new LinkedBlockingQueue[Runnable]
val factory: ThreadFactory = (r: Runnable) => new Thread(r).tap(_.setUncaughtExceptionHandler(report))
val handler: RejectedExecutionHandler = (r: Runnable, x: ThreadPoolExecutor) => ???
new ThreadPoolExecutor(coreSize, maxSize, keepAlive, Milliseconds, q, factory, handler)
}
def ecesFromUnderlyingPool = ExecutionContext.fromExecutorService(underlyingPool, report(null, _))

def ecesUsingDefaultFactory = {
import java.util.concurrent.{ForkJoinPool, RejectedExecutionHandler, ThreadPoolExecutor}
import java.util.function.Predicate
import scala.reflect.internal.util.RichClassLoader._

val path = "java.util.concurrent.ForkJoinPool"
val n = 2 // parallelism
val factory = scala.concurrent.TestUtil.threadFactory(report)
val ueh: Thread.UncaughtExceptionHandler = report(_, _)
val async = true
val coreSize = 4
val maxSize = 4
val minRun = 1 // minimumRunnable for liveness
val saturate: Predicate[ForkJoinPool] = (fjp: ForkJoinPool) => false // whether to continue after blocking at maxSize
val keepAlive = 2000L
val fjp = this.getClass.getClassLoader.create[ForkJoinPool](path, _ => ())(n, factory, ueh, async, coreSize, maxSize, minRun, saturate, keepAlive, Milliseconds)
//ForkJoinPool(int parallelism, ForkJoinPool.ForkJoinWorkerThreadFactory factory, Thread.UncaughtExceptionHandler handler, boolean asyncMode, int corePoolSize, int maximumPoolSize, int minimumRunnable, Predicate<? super ForkJoinPool> saturate, long keepAliveTime, TimeUnit unit)
new ExecutionContextExecutorService {
// Members declared in scala.concurrent.ExecutionContext
def reportFailure(cause: Throwable): Unit = report(null, cause)

// Members declared in java.util.concurrent.Executor
def execute(r: Runnable): Unit = fjp.execute(r)

// Members declared in java.util.concurrent.ExecutorService
def awaitTermination(x$1: Long, x$2: java.util.concurrent.TimeUnit): Boolean = ???
def invokeAll[T](x$1: java.util.Collection[_ <: java.util.concurrent.Callable[T]], x$2: Long, x$3: java.util.concurrent.TimeUnit): java.util.List[java.util.concurrent.Future[T]] = ???
def invokeAll[T](x$1: java.util.Collection[_ <: java.util.concurrent.Callable[T]]): java.util.List[java.util.concurrent.Future[T]] = ???
def invokeAny[T](x$1: java.util.Collection[_ <: java.util.concurrent.Callable[T]], x$2: Long, x$3: java.util.concurrent.TimeUnit): T = ???
def invokeAny[T](x$1: java.util.Collection[_ <: java.util.concurrent.Callable[T]]): T = ???
def isShutdown(): Boolean = fjp.isShutdown
def isTerminated(): Boolean = fjp.isTerminated
def shutdown(): Unit = fjp.shutdown()
def shutdownNow(): java.util.List[Runnable] = fjp.shutdownNow()
def submit(r: Runnable): java.util.concurrent.Future[_] = fjp.submit(r)
def submit[T](task: Runnable, res: T): java.util.concurrent.Future[T] = fjp.submit(task, res)
def submit[T](task: java.util.concurrent.Callable[T]): java.util.concurrent.Future[T] = fjp.submit(task)
}
}

def ecesUsingDefaultFJP = ExecutionContext.fromExecutorService(null, report(null, _))

def testUncaughtExceptionReporting(ec: ExecutionContextExecutorService): Unit = once {
done =>
val example = new InterruptedException

@tailrec def spinForThreadDeath(turns: Int): Boolean =
turns > 0 && (thread != null && !thread.isAlive || { Thread.sleep(100L); spinForThreadDeath(turns - 1) })

def truthfully(b: Boolean): Option[Boolean] = if (b) Some(true) else None

// jdk17 thread receives pool exception handler, so wait for thread to die slow and painful expired keepalive
def threadIsDead = waitFor(truthfully(spinForThreadDeath(turns = 10)), progress = if (slowly) Slow else Fast, label = "concurrent-tck-thread-death")

try {
ec.execute(() => {
thread = Thread.currentThread
throw example
})
latch.await(2, Seconds)
done(threadIsDead && (reported eq example))
}
finally ec.shutdown()
}

test("testUncaughtExceptionReporting")(testUncaughtExceptionReporting {
import scala.util.Properties.isJavaAtLeast
if (slowly) ecesUsingDefaultFJP
else if (isJavaAtLeast(9)) ecesUsingDefaultFactory
else ecesFromUnderlyingPool
})
}

object Test
extends App {
new FutureCallbacks
Expand All @@ -1088,6 +1154,13 @@ extends App {
new GlobalExecutionContext
new CustomExecutionContext
new ExecutionContextPrepare
new ReportingExecutionContext

System.exit(0)
}

package scala.concurrent {
object TestUtil {
def threadFactory(uncaughtExceptionHandler: Thread.UncaughtExceptionHandler) = new impl.ExecutionContextImpl.DefaultThreadFactory(daemonic=true, maxBlockers=256, prefix="test-thread", uncaughtExceptionHandler)
}
}

0 comments on commit 436d7a7

Please sign in to comment.