Skip to content

Commit

Permalink
Improve FieldWalker, don't access JDK classes (#1799)
Browse files Browse the repository at this point in the history
* Improve FieldWalker, don't access JDK classes

* Works on future JDKs that forbid reflective access to JDK classes
* Show human-readable path to field is something fails
  • Loading branch information
elizarov committed Feb 13, 2020
1 parent 4aa3880 commit b64a23b
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 84 deletions.
185 changes: 112 additions & 73 deletions kotlinx-coroutines-core/jvm/test/FieldWalker.kt
@@ -1,115 +1,154 @@
/*
* Copyright 2016-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
* Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

package kotlinx.coroutines

import java.lang.reflect.*
import java.util.*
import java.util.Collections.*
import java.util.concurrent.atomic.*
import kotlin.collections.ArrayList
import kotlin.test.*

object FieldWalker {
sealed class Ref {
object RootRef : Ref()
class FieldRef(val parent: Any, val name: String) : Ref()
class ArrayRef(val parent: Any, val index: Int) : Ref()
}

private val fieldsCache = HashMap<Class<*>, List<Field>>()

init {
// excluded/terminal classes (don't walk them)
fieldsCache += listOf(Any::class, String::class, Thread::class, Throwable::class)
.map { it.java }
.associateWith { emptyList<Field>() }
}

/*
* Reflectively starts to walk through object graph and returns identity set of all reachable objects.
* Use [walkRefs] if you need a path from root for debugging.
*/
public fun walk(root: Any?): Set<Any> = walkRefs(root).keys

public fun assertReachableCount(expected: Int, root: Any?, predicate: (Any) -> Boolean) {
val visited = walkRefs(root)
val actual = visited.keys.filter(predicate)
if (actual.size != expected) {
val textDump = actual.joinToString("") { "\n\t" + showPath(it, visited) }
assertEquals(
expected, actual.size,
"Unexpected number objects. Expected $expected, found ${actual.size}$textDump"
)
}
}

/*
* Reflectively starts to walk through object graph and map to all the reached object to their path
* in from root. Use [showPath] do display a path if needed.
*/
public fun walk(root: Any): Set<Any> {
val result = newSetFromMap<Any>(IdentityHashMap())
result.add(root)
private fun walkRefs(root: Any?): Map<Any, Ref> {
val visited = IdentityHashMap<Any, Ref>()
if (root == null) return visited
visited[root] = Ref.RootRef
val stack = ArrayDeque<Any>()
stack.addLast(root)
while (stack.isNotEmpty()) {
val element = stack.removeLast()
val type = element.javaClass
type.visit(element, result, stack)
try {
visit(element, visited, stack)
} catch (e: Exception) {
error("Failed to visit element ${showPath(element, visited)}: $e")
}
}
return result
return visited
}

private fun Class<*>.visit(
element: Any,
result: MutableSet<Any>,
stack: ArrayDeque<Any>
) {
val fields = fields()
fields.forEach {
it.isAccessible = true
val value = it.get(element) ?: return@forEach
if (result.add(value)) {
stack.addLast(value)
private fun showPath(element: Any, visited: Map<Any, Ref>): String {
val path = ArrayList<String>()
var cur = element
while (true) {
val ref = visited.getValue(cur)
if (ref is Ref.RootRef) break
when (ref) {
is Ref.FieldRef -> {
cur = ref.parent
path += ".${ref.name}"
}
is Ref.ArrayRef -> {
cur = ref.parent
path += "[${ref.index}]"
}
}
}
path.reverse()
return path.joinToString("")
}

if (isArray && !componentType.isPrimitive) {
val array = element as Array<Any?>
array.filterNotNull().forEach {
if (result.add(it)) {
stack.addLast(it)
private fun visit(element: Any, visited: IdentityHashMap<Any, Ref>, stack: ArrayDeque<Any>) {
val type = element.javaClass
when {
// Special code for arrays
type.isArray && !type.componentType.isPrimitive -> {
@Suppress("UNCHECKED_CAST")
val array = element as Array<Any?>
array.forEachIndexed { index, value ->
push(value, visited, stack) { Ref.ArrayRef(element, index) }
}
}
// Special code for platform types that cannot be reflectively accessed on modern JDKs
type.name.startsWith("java.") && element is Collection<*> -> {
element.forEachIndexed { index, value ->
push(value, visited, stack) { Ref.ArrayRef(element, index) }
}
}
type.name.startsWith("java.") && element is Map<*, *> -> {
push(element.keys, visited, stack) { Ref.FieldRef(element, "keys") }
push(element.values, visited, stack) { Ref.FieldRef(element, "values") }
}
element is AtomicReference<*> -> {
push(element.get(), visited, stack) { Ref.FieldRef(element, "value") }
}
// All the other classes are reflectively scanned
else -> fields(type).forEach { field ->
push(field.get(element), visited, stack) { Ref.FieldRef(element, field.name) }
// special case to scan Throwable cause (cannot get it reflectively)
if (element is Throwable) {
push(element.cause, visited, stack) { Ref.FieldRef(element, "cause") }
}
}
}
}

private fun Class<*>.fields(): List<Field> {
private inline fun push(value: Any?, visited: IdentityHashMap<Any, Ref>, stack: ArrayDeque<Any>, ref: () -> Ref) {
if (value != null && !visited.containsKey(value)) {
visited[value] = ref()
stack.addLast(value)
}
}

private fun fields(type0: Class<*>): List<Field> {
fieldsCache[type0]?.let { return it }
val result = ArrayList<Field>()
var type = this
while (type != Any::class.java) {
var type = type0
while (true) {
val fields = type.declaredFields.filter {
!it.type.isPrimitive
&& !Modifier.isStatic(it.modifiers)
&& !(it.type.isArray && it.type.componentType.isPrimitive)
}
fields.forEach { it.isAccessible = true } // make them all accessible
result.addAll(fields)
type = type.superclass
}

return result
}

// Debugging-only
@Suppress("UNUSED")
fun printPath(from: Any, to: Any) {
val pathNodes = ArrayList<String>()
val visited = newSetFromMap<Any>(IdentityHashMap())
visited.add(from)
if (findPath(from, to, visited, pathNodes)) {
pathNodes.reverse()
println(pathNodes.joinToString(" -> ", from.javaClass.simpleName + " -> ", "-> " + to.javaClass.simpleName))
} else {
println("Path from $from to $to not found")
}
}

private fun findPath(from: Any, to: Any, visited: MutableSet<Any>, pathNodes: MutableList<String>): Boolean {
if (from === to) {
return true
}

val type = from.javaClass
if (type.isArray) {
if (type.componentType.isPrimitive) return false
val array = from as Array<Any?>
array.filterNotNull().forEach {
if (findPath(it, to, visited, pathNodes)) {
return true
}
val superFields = fieldsCache[type] // will stop at Any anyway
if (superFields != null) {
result.addAll(superFields)
break
}
return false
}

val fields = type.fields()
fields.forEach {
it.isAccessible = true
val value = it.get(from) ?: return@forEach
if (!visited.add(value)) return@forEach
val found = findPath(value, to, visited, pathNodes)
if (found) {
pathNodes += from.javaClass.simpleName + ":" + it.name
return true
}
}

return false
fieldsCache[type0] = result
return result
}
}
Expand Up @@ -76,7 +76,7 @@ class ReusableCancellableContinuationTest : TestBase() {
expect(4)
ensureActive()
// Verify child was bound
assertNotNull(FieldWalker.walk(coroutineContext[Job]!!).single { it === continuation })
FieldWalker.assertReachableCount(1, coroutineContext[Job]) { it === continuation }
suspendAtomicCancellableCoroutineReusable<Unit> {
expect(5)
coroutineContext[Job]!!.cancel()
Expand All @@ -97,7 +97,7 @@ class ReusableCancellableContinuationTest : TestBase() {
cont = it
}
ensureActive()
assertTrue { FieldWalker.walk(coroutineContext[Job]!!).contains(cont!!) }
assertTrue { FieldWalker.walk(coroutineContext[Job]).contains(cont!!) }
finish(2)
}

Expand All @@ -112,7 +112,7 @@ class ReusableCancellableContinuationTest : TestBase() {
cont = it
}
ensureActive()
assertFalse { FieldWalker.walk(coroutineContext[Job]!!).contains(cont!!) }
FieldWalker.assertReachableCount(0, coroutineContext[Job]) { it === cont }
finish(2)
}

Expand All @@ -127,7 +127,7 @@ class ReusableCancellableContinuationTest : TestBase() {
}
expectUnreached()
} catch (e: CancellationException) {
assertFalse { FieldWalker.walk(coroutineContext[Job]!!).contains(cont!!) }
FieldWalker.assertReachableCount(0, coroutineContext[Job]) { it === cont }
finish(2)
}
}
Expand All @@ -148,19 +148,19 @@ class ReusableCancellableContinuationTest : TestBase() {
expect(4)
ensureActive()
// Verify child was bound
assertEquals(1, FieldWalker.walk(currentJob).count { it is CancellableContinuation<*> })
FieldWalker.assertReachableCount(1, currentJob) { it is CancellableContinuation<*> }
currentJob.cancel()
assertFalse(isActive)
// Child detached
assertEquals(0, FieldWalker.walk(currentJob).count { it is CancellableContinuation<*> })
FieldWalker.assertReachableCount(0, currentJob) { it is CancellableContinuation<*> }
suspendAtomicCancellableCoroutineReusable<Unit> { it.resume(Unit) }
suspendAtomicCancellableCoroutineReusable<Unit> { it.resume(Unit) }
assertEquals(0, FieldWalker.walk(currentJob).count { it is CancellableContinuation<*> })
FieldWalker.assertReachableCount(0, currentJob) { it is CancellableContinuation<*> }

try {
suspendAtomicCancellableCoroutineReusable<Unit> {}
} catch (e: CancellationException) {
assertEquals(0, FieldWalker.walk(currentJob).count { it is CancellableContinuation<*> })
FieldWalker.assertReachableCount(0, currentJob) { it is CancellableContinuation<*> }
finish(5)
}
}
Expand All @@ -184,12 +184,12 @@ class ReusableCancellableContinuationTest : TestBase() {
expect(2)
val job = coroutineContext[Job]!!
// 1 for reusable CC, another one for outer joiner
assertEquals(2, FieldWalker.walk(job).count { it is CancellableContinuation<*> })
FieldWalker.assertReachableCount(2, job) { it is CancellableContinuation<*> }
}
expect(1)
receiver.join()
// Reference should be claimed at this point
assertEquals(0, FieldWalker.walk(receiver).count { it is CancellableContinuation<*> })
FieldWalker.assertReachableCount(0, receiver) { it is CancellableContinuation<*> }
finish(3)
}
}
Expand Up @@ -41,7 +41,7 @@ class ConsumeAsFlowLeakTest : TestBase() {
if (shouldSuspendOnSend) yield()
channel.send(second)
yield()
assertEquals(0, FieldWalker.walk(channel).count { it === second })
FieldWalker.assertReachableCount(0, channel) { it === second }
finish(6)
job.cancelAndJoin()
}
Expand Down

0 comments on commit b64a23b

Please sign in to comment.