Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Custom any value generator #643

Merged
merged 4 commits into from Jun 25, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
15 changes: 12 additions & 3 deletions agent/android/src/main/kotlin/io/mockk/ValueClassSupport.kt
Expand Up @@ -30,10 +30,19 @@ fun <T : Any> T.boxedValue(): Any? {
* @return class of boxed value, if this is value class, else just class of itself
*/
fun <T : Any> T.boxedClass(): KClass<*> {
if (!this::class.isValueClass()) return this::class
return this::class.boxedClass()
}

/**
* Get the KClass of boxed value if this is a value class.
*
* @return class of boxed value, if this is value class, else just class of itself
*/
fun KClass<*>.boxedClass(): KClass<*> {
if (!this.isValueClass()) return this

// get backing field
val backingField = this::class.valueField()
val backingField = this.valueField()

// get boxed value
return backingField.returnType.classifier as KClass<*>
Expand All @@ -60,7 +69,7 @@ private fun <T : Any> KClass<T>.valueField(): KProperty1<out T, *> {

private fun <T : Any> KClass<T>.isValueClass() = try {
this.isValue
} catch (_: UnsupportedOperationException) {
} catch (_: Throwable) {
false
}

Expand Down
15 changes: 12 additions & 3 deletions agent/jvm/src/main/kotlin/io/mockk/ValueClassSupport.kt
Expand Up @@ -30,10 +30,19 @@ fun <T : Any> T.boxedValue(): Any? {
* @return class of boxed value, if this is value class, else just class of itself
*/
fun <T : Any> T.boxedClass(): KClass<*> {
if (!this::class.isValueClass()) return this::class
return this::class.boxedClass()
}

/**
* Get the KClass of boxed value if this is a value class.
*
* @return class of boxed value, if this is value class, else just class of itself
*/
fun KClass<*>.boxedClass(): KClass<*> {
if (!this.isValueClass()) return this

// get backing field
val backingField = this::class.valueField()
val backingField = this.valueField()

// get boxed value
return backingField.returnType.classifier as KClass<*>
Expand All @@ -60,7 +69,7 @@ private fun <T : Any> KClass<T>.valueField(): KProperty1<out T, *> {

private fun <T : Any> KClass<T>.isValueClass() = try {
this.isValue
} catch (_: UnsupportedOperationException) {
} catch (_: Throwable) {
false
}

Expand Down
1 change: 1 addition & 0 deletions dsl/common/src/main/kotlin/io/mockk/API.kt
Expand Up @@ -3600,6 +3600,7 @@ data class Call(
data class MethodDescription(
val name: String,
val returnType: KClass<*>,
val returnTypeNullable: Boolean,
val returnsUnit: Boolean,
val returnsNothing: Boolean,
val isSuspend: Boolean,
Expand Down
Expand Up @@ -3,7 +3,7 @@ package io.mockk.impl.instantiation
import kotlin.reflect.KClass

open class AnyValueGenerator {
open fun anyValue(cls: KClass<*>, orInstantiateVia: () -> Any?): Any? {
open fun anyValue(cls: KClass<*>, isNullable: Boolean, orInstantiateVia: () -> Any?): Any? {
return when (cls) {
Boolean::class -> false
Byte::class -> 0.toByte()
Expand Down
Expand Up @@ -17,7 +17,7 @@ class CommonCallRecorder(
val instantiator: AbstractInstantiator,
val signatureValueGenerator: SignatureValueGenerator,
val mockFactory: MockFactory,
val anyValueGenerator: AnyValueGenerator,
val anyValueGenerator: () -> AnyValueGenerator,
val safeToString: SafeToString,
val factories: CallRecorderFactories,
val initialState: (CommonCallRecorder) -> CallRecordingState,
Expand Down
Expand Up @@ -6,6 +6,7 @@ object WasNotCalled {
val method = MethodDescription(
"wasNot Called",
Unit::class,
false,
true,
false,
false,
Expand All @@ -15,4 +16,4 @@ object WasNotCalled {
-1,
false
)
}
}
Expand Up @@ -45,7 +45,7 @@ abstract class RecordingState(recorder: CommonCallRecorder) : CallRecordingState
@Suppress("UNCHECKED_CAST")
override fun <T : Any> matcher(matcher: Matcher<*>, cls: KClass<T>): T {
val signatureValue = recorder.signatureValueGenerator.signatureValue(cls) {
recorder.anyValueGenerator.anyValue(cls) {
recorder.anyValueGenerator().anyValue(cls, isNullable = false) {
recorder.instantiator.instantiate(cls)
} as T
}
Expand All @@ -67,7 +67,7 @@ abstract class RecordingState(recorder: CommonCallRecorder) : CallRecordingState
if (invocation.method.isToString()) {
recorder.stubRepo[invocation.self]?.toStr() ?: ""
} else {
recorder.anyValueGenerator.anyValue(retType) {
recorder.anyValueGenerator().anyValue(retType, invocation.method.returnTypeNullable) {
isTemporaryMock = true
recorder.mockFactory.temporaryMock(retType)
}
Expand Down
5 changes: 4 additions & 1 deletion mockk/common/src/main/kotlin/io/mockk/impl/stub/MockKStub.kt
Expand Up @@ -83,7 +83,10 @@ open class MockKStub(
return stdObjectFunctions(invocation.self, invocation.method, invocation.args) {
if (shouldRelax(invocation)) {
if (invocation.method.returnsUnit) return Unit
return gatewayAccess.anyValueGenerator.anyValue(invocation.method.returnType) {
return gatewayAccess.anyValueGenerator().anyValue(
invocation.method.returnType,
invocation.method.returnTypeNullable
) {
childMockK(invocation.allEqMatcher(), invocation.method.returnType)
}
} else {
Expand Down
Expand Up @@ -7,7 +7,7 @@ import io.mockk.impl.log.SafeToString

data class StubGatewayAccess(
val callRecorder: () -> CallRecorder,
val anyValueGenerator: AnyValueGenerator,
val anyValueGenerator: () -> AnyValueGenerator,
val stubRepository: StubRepository,
val safeToString: SafeToString,
val mockFactory: MockKGateway.MockFactory? = null
Expand Down
Expand Up @@ -12,111 +12,111 @@ class AnyValueGeneratorTest {

@Test
fun givenByteClassWhenRequestedForAnyValueThen0IsReturned() {
assertEquals(0.toByte(), generator.anyValue(Byte::class, failOnPassThrough))
assertEquals(0.toByte(), generator.anyValue(Byte::class, false, failOnPassThrough))
}

@Test
fun givenShortClassWhenRequestedForAnyValueThen0IsReturned() {
assertEquals(0.toShort(), generator.anyValue(Short::class, failOnPassThrough))
assertEquals(0.toShort(), generator.anyValue(Short::class, false, failOnPassThrough))
}

@Test
fun givenCharClassWhenRequestedForAnyValueThen0IsReturned() {
assertEquals(0.toChar(), generator.anyValue(Char::class, failOnPassThrough))
assertEquals(0.toChar(), generator.anyValue(Char::class, false, failOnPassThrough))
}

@Test
fun givenIntClassWhenRequestedForAnyValueThen0IsReturned() {
assertEquals(0, generator.anyValue(Int::class, failOnPassThrough))
assertEquals(0, generator.anyValue(Int::class, false, failOnPassThrough))
}

@Test
fun givenLongClassWhenRequestedForAnyValueThen0IsReturned() {
assertEquals(0L, generator.anyValue(Long::class, failOnPassThrough))
assertEquals(0L, generator.anyValue(Long::class, false, failOnPassThrough))
}

@Test
fun givenFloatClassWhenRequestedForAnyValueThen0IsReturned() {
assertEquals(0F, generator.anyValue(Float::class, failOnPassThrough))
assertEquals(0F, generator.anyValue(Float::class, false, failOnPassThrough))
}

@Test
fun givenDoubleClassWhenRequestedForAnyValueThen0IsReturned() {
assertEquals(0.0, generator.anyValue(Double::class, failOnPassThrough))
assertEquals(0.0, generator.anyValue(Double::class, false, failOnPassThrough))
}

@Test
fun givenStringClassWhenRequestedForAnyValueThenEmptyStringIsReturned() {
assertEquals("", generator.anyValue(String::class, failOnPassThrough))
assertEquals("", generator.anyValue(String::class, false, failOnPassThrough))
}

@Test
fun givenBooleanArrayClassWhenRequestedForAnyValueThenEmptyBooleanArrayIsReturned() {
assertArrayEquals(BooleanArray(0), generator.anyValue(BooleanArray::class, failOnPassThrough) as BooleanArray)
assertArrayEquals(BooleanArray(0), generator.anyValue(BooleanArray::class, false, failOnPassThrough) as BooleanArray)
}

@Test
fun givenByteArrayClassWhenRequestedForAnyValueThenEmptyByteArrayIsReturned() {
assertArrayEquals(ByteArray(0), generator.anyValue(ByteArray::class, failOnPassThrough) as ByteArray)
assertArrayEquals(ByteArray(0), generator.anyValue(ByteArray::class, false, failOnPassThrough) as ByteArray)
}

@Test
fun givenCharArrayClassWhenRequestedForAnyValueThenEmptyCharArrayIsReturned() {
assertArrayEquals(CharArray(0), generator.anyValue(CharArray::class, failOnPassThrough) as CharArray)
assertArrayEquals(CharArray(0), generator.anyValue(CharArray::class, false, failOnPassThrough) as CharArray)
}

@Test
fun givenShortArrayClassWhenRequestedForAnyValueThenEmptyShortArrayIsReturned() {
assertArrayEquals(ShortArray(0), generator.anyValue(ShortArray::class, failOnPassThrough) as ShortArray)
assertArrayEquals(ShortArray(0), generator.anyValue(ShortArray::class, false, failOnPassThrough) as ShortArray)
}

@Test
fun givenIntArrayClassWhenRequestedForAnyValueThenEmptyIntArrayIsReturned() {
assertArrayEquals(IntArray(0), generator.anyValue(IntArray::class, failOnPassThrough) as IntArray)
assertArrayEquals(IntArray(0), generator.anyValue(IntArray::class, false, failOnPassThrough) as IntArray)
}

@Test
fun givenLongArrayClassWhenRequestedForAnyValueThenEmptyLongArrayIsReturned() {
assertArrayEquals(LongArray(0), generator.anyValue(LongArray::class, failOnPassThrough) as LongArray)
assertArrayEquals(LongArray(0), generator.anyValue(LongArray::class, false, failOnPassThrough) as LongArray)
}

@Test
fun givenFloatArrayClassWhenRequestedForAnyValueThenEmptyFloatArrayIsReturned() {
assertArrayEquals(FloatArray(0), generator.anyValue(FloatArray::class, failOnPassThrough) as FloatArray, 1e-6f)
assertArrayEquals(FloatArray(0), generator.anyValue(FloatArray::class, false, failOnPassThrough) as FloatArray, 1e-6f)
}

@Test
fun givenDoubleArrayClassWhenRequestedForAnyValueThenEmptyDoubleArrayIsReturned() {
assertArrayEquals(DoubleArray(0), generator.anyValue(DoubleArray::class, failOnPassThrough) as DoubleArray, 1e-6)
assertArrayEquals(DoubleArray(0), generator.anyValue(DoubleArray::class, false, failOnPassThrough) as DoubleArray, 1e-6)
}

@Test
fun givenListClassWhenRequestedForAnyValueThenEmptyListIsReturned() {
assertEquals(listOf<Any>(), generator.anyValue(List::class, failOnPassThrough) as List<*>)
assertEquals(listOf<Any>(), generator.anyValue(List::class, false, failOnPassThrough) as List<*>)
}

@Test
fun givenMapClassWhenRequestedForAnyValueThenEmptyMapIsReturned() {
assertEquals(mapOf<Any, Any>(), generator.anyValue(Map::class, failOnPassThrough) as Map<*, *>)
assertEquals(mapOf<Any, Any>(), generator.anyValue(Map::class, false, failOnPassThrough) as Map<*, *>)
}

@Test
fun givenSetClassWhenRequestedForAnyValueThenEmptySetIsReturned() {
assertEquals(setOf<Any>(), generator.anyValue(Set::class, failOnPassThrough) as Set<*>)
assertEquals(setOf<Any>(), generator.anyValue(Set::class, false, failOnPassThrough) as Set<*>)
}

@Test
fun givenArrayListClassWhenRequestedForAnyValueThenEmptyArrayListIsReturned() {
assertEquals(arrayListOf<Any>(), generator.anyValue(ArrayList::class, failOnPassThrough) as ArrayList<*>)
assertEquals(arrayListOf<Any>(), generator.anyValue(ArrayList::class, false, failOnPassThrough) as ArrayList<*>)
}

@Test
fun givenHashMapClassWhenRequestedForAnyValueThenEmptyHashMapIsReturned() {
assertEquals(hashMapOf<Any, Any>(), generator.anyValue(HashMap::class, failOnPassThrough) as HashMap<*, *>)
assertEquals(hashMapOf<Any, Any>(), generator.anyValue(HashMap::class, false, failOnPassThrough) as HashMap<*, *>)
}

@Test
fun givenHashSetClassWhenRequestedForAnyValueThenEmptyHashSetIsReturned() {
assertEquals(hashSetOf<Any>(), generator.anyValue(HashSet::class, failOnPassThrough) as HashSet<*>)
assertEquals(hashSetOf<Any>(), generator.anyValue(HashSet::class, false, failOnPassThrough) as HashSet<*>)
}
}
Expand Up @@ -51,5 +51,12 @@ class RelaxedMockingTest {
assertEquals(2, slot.captured)
}

@Test
fun testRelaxedFunction() {
val block = mockk<() -> Unit>(relaxed = true)
block()
verify { block.invoke() }
}

private fun mockCls() = mockk<MockCls>(relaxUnitFun = true)
}
}
4 changes: 2 additions & 2 deletions mockk/js/src/main/kotlin/io/mockk/impl/JsMockKGateway.kt
Expand Up @@ -34,7 +34,7 @@ class JsMockKGateway : MockKGateway {
override val mockFactory: MockFactory = JsMockFactory(
stubRepo,
instantiator,
StubGatewayAccess({ callRecorder }, anyValueGenerator, stubRepo, safeToString)
StubGatewayAccess({ callRecorder }, { anyValueGenerator }, stubRepo, safeToString)
)

override val clearer = CommonClearer(stubRepo, safeToString)
Expand Down Expand Up @@ -89,7 +89,7 @@ class JsMockKGateway : MockKGateway {
instantiator,
signatureValueGenerator,
mockFactory,
anyValueGenerator,
{ anyValueGenerator },
safeToString,
callRecorderFactories,
{ recorder -> callRecorderFactories.answeringState(recorder) },
Expand Down
Expand Up @@ -87,6 +87,7 @@ internal class StubProxyHandler(
false,
false,
false,
false,
cls,
listOf(),
-1,
Expand All @@ -112,6 +113,7 @@ internal class StubProxyHandler(
false,
false,
false,
false,
cls,
listOf(),
-1,
Expand Down
2 changes: 1 addition & 1 deletion mockk/jvm/src/main/kotlin/io/mockk/ValueClassSupport.kt
Expand Up @@ -69,7 +69,7 @@ private fun <T : Any> KClass<T>.valueField(): KProperty1<out T, *> {

private fun <T : Any> KClass<T>.isValueClass() = try {
this.isValue
} catch (_: UnsupportedOperationException) {
} catch (_: Throwable) {
false
}

Expand Down
19 changes: 16 additions & 3 deletions mockk/jvm/src/main/kotlin/io/mockk/impl/JvmMockKGateway.kt
Expand Up @@ -55,11 +55,16 @@ class JvmMockKGateway : MockKGateway {
instanceFactoryRegistryIntrnl
)

val anyValueGenerator = JvmAnyValueGenerator(instantiator.instantiate(Void::class))
val anyValueGeneratorProvider: () -> AnyValueGenerator = {
if (anyValueGenerator == null) {
anyValueGenerator = anyValueGeneratorFactory.invoke(instantiator.instantiate(Void::class))
}
anyValueGenerator!!
}
val signatureValueGenerator = JvmSignatureValueGenerator(Random())


val gatewayAccess = StubGatewayAccess({ callRecorder }, anyValueGenerator, stubRepo, safeToString)
val gatewayAccess = StubGatewayAccess({ callRecorder }, anyValueGeneratorProvider, stubRepo, safeToString)

override val mockFactory: AbstractMockFactory = JvmMockFactory(
agentFactory.proxyMaker,
Expand Down Expand Up @@ -137,7 +142,7 @@ class JvmMockKGateway : MockKGateway {
instantiator,
signatureValueGenerator,
mockFactory,
anyValueGenerator,
anyValueGeneratorProvider,
safeToString,
callRecorderFactories,
{ recorder -> callRecorderFactories.answeringState(recorder) },
Expand Down Expand Up @@ -170,6 +175,14 @@ class JvmMockKGateway : MockKGateway {
}
}

private var anyValueGenerator: AnyValueGenerator? = null
var anyValueGeneratorFactory: (voidInstance: Any) -> JvmAnyValueGenerator =
{ voidInstance -> JvmAnyValueGenerator(voidInstance) }
set(value) {
anyValueGenerator = null
field = value
}

val defaultImplementation = JvmMockKGateway()
val defaultImplementationBuilder = { defaultImplementation }
}
Expand Down
Expand Up @@ -2,11 +2,11 @@ package io.mockk.impl.instantiation

import kotlin.reflect.KClass

class JvmAnyValueGenerator(
open class JvmAnyValueGenerator(
private val voidInstance: Any
) : AnyValueGenerator() {

override fun anyValue(cls: KClass<*>, orInstantiateVia: () -> Any?): Any? {
override fun anyValue(cls: KClass<*>, isNullable: Boolean, orInstantiateVia: () -> Any?): Any? {
return when (cls) {
Void.TYPE.kotlin -> voidInstance
Void::class -> voidInstance
Expand All @@ -28,7 +28,7 @@ class JvmAnyValueGenerator(
java.util.HashMap::class -> HashMap<Any, Any>()
java.util.HashSet::class -> HashSet<Any>()

else -> super.anyValue(cls) {
else -> super.anyValue(cls, isNullable) {
if (cls.java.isArray) {
java.lang.reflect.Array.newInstance(cls.java.componentType, 0)
} else {
Expand Down