Skip to content

Commit

Permalink
allow custom AnyValueGenerator
Browse files Browse the repository at this point in the history
  • Loading branch information
hrach committed Jun 23, 2021
1 parent a48a030 commit 070ab2d
Show file tree
Hide file tree
Showing 9 changed files with 73 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class CommonCallRecorder(
val instantiator: AbstractInstantiator,
val signatureValueGenerator: SignatureValueGenerator,
val mockFactory: MockFactory,
val anyValueGenerator: AnyValueGenerator,
val anyValueGenerator: () -> AnyValueGenerator,
val safeToString: SafeToString,
val factories: CallRecorderFactories,
val initialState: (CommonCallRecorder) -> CallRecordingState,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ abstract class RecordingState(recorder: CommonCallRecorder) : CallRecordingState
@Suppress("UNCHECKED_CAST")
override fun <T : Any> matcher(matcher: Matcher<*>, cls: KClass<T>): T {
val signatureValue = recorder.signatureValueGenerator.signatureValue(cls) {
recorder.anyValueGenerator.anyValue(cls, isNullable = false) {
recorder.anyValueGenerator().anyValue(cls, isNullable = false) {
recorder.instantiator.instantiate(cls)
} as T
}
Expand All @@ -67,7 +67,7 @@ abstract class RecordingState(recorder: CommonCallRecorder) : CallRecordingState
if (invocation.method.isToString()) {
recorder.stubRepo[invocation.self]?.toStr() ?: ""
} else {
recorder.anyValueGenerator.anyValue(retType, invocation.method.returnTypeNullable) {
recorder.anyValueGenerator().anyValue(retType, invocation.method.returnTypeNullable) {
isTemporaryMock = true
recorder.mockFactory.temporaryMock(retType)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ open class MockKStub(
return stdObjectFunctions(invocation.self, invocation.method, invocation.args) {
if (shouldRelax(invocation)) {
if (invocation.method.returnsUnit) return Unit
return gatewayAccess.anyValueGenerator.anyValue(
return gatewayAccess.anyValueGenerator().anyValue(
invocation.method.returnType,
invocation.method.returnTypeNullable
) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import io.mockk.impl.log.SafeToString

data class StubGatewayAccess(
val callRecorder: () -> CallRecorder,
val anyValueGenerator: AnyValueGenerator,
val anyValueGenerator: () -> AnyValueGenerator,
val stubRepository: StubRepository,
val safeToString: SafeToString,
val mockFactory: MockKGateway.MockFactory? = null
Expand Down
4 changes: 2 additions & 2 deletions mockk/js/src/main/kotlin/io/mockk/impl/JsMockKGateway.kt
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class JsMockKGateway : MockKGateway {
override val mockFactory: MockFactory = JsMockFactory(
stubRepo,
instantiator,
StubGatewayAccess({ callRecorder }, anyValueGenerator, stubRepo, safeToString)
StubGatewayAccess({ callRecorder }, { anyValueGenerator }, stubRepo, safeToString)
)

override val clearer = CommonClearer(stubRepo, safeToString)
Expand Down Expand Up @@ -89,7 +89,7 @@ class JsMockKGateway : MockKGateway {
instantiator,
signatureValueGenerator,
mockFactory,
anyValueGenerator,
{ anyValueGenerator },
safeToString,
callRecorderFactories,
{ recorder -> callRecorderFactories.answeringState(recorder) },
Expand Down
19 changes: 16 additions & 3 deletions mockk/jvm/src/main/kotlin/io/mockk/impl/JvmMockKGateway.kt
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,16 @@ class JvmMockKGateway : MockKGateway {
instanceFactoryRegistryIntrnl
)

val anyValueGenerator = JvmAnyValueGenerator(instantiator.instantiate(Void::class))
val anyValueGeneratorProvider: () -> AnyValueGenerator = {
if (anyValueGenerator == null) {
anyValueGenerator = anyValueGeneratorFactory.invoke(instantiator.instantiate(Void::class))
}
anyValueGenerator!!
}
val signatureValueGenerator = JvmSignatureValueGenerator(Random())


val gatewayAccess = StubGatewayAccess({ callRecorder }, anyValueGenerator, stubRepo, safeToString)
val gatewayAccess = StubGatewayAccess({ callRecorder }, anyValueGeneratorProvider, stubRepo, safeToString)

override val mockFactory: AbstractMockFactory = JvmMockFactory(
agentFactory.proxyMaker,
Expand Down Expand Up @@ -137,7 +142,7 @@ class JvmMockKGateway : MockKGateway {
instantiator,
signatureValueGenerator,
mockFactory,
anyValueGenerator,
anyValueGeneratorProvider,
safeToString,
callRecorderFactories,
{ recorder -> callRecorderFactories.answeringState(recorder) },
Expand Down Expand Up @@ -170,6 +175,14 @@ class JvmMockKGateway : MockKGateway {
}
}

private var anyValueGenerator: AnyValueGenerator? = null
var anyValueGeneratorFactory: (voidInstance: Any) -> JvmAnyValueGenerator =
{ voidInstance -> JvmAnyValueGenerator(voidInstance) }
set(value) {
anyValueGenerator = null
field = value
}

val defaultImplementation = JvmMockKGateway()
val defaultImplementationBuilder = { defaultImplementation }
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package io.mockk.impl.instantiation

import kotlin.reflect.KClass

class JvmAnyValueGenerator(
open class JvmAnyValueGenerator(
private val voidInstance: Any
) : AnyValueGenerator() {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class JvmMockFactory(
"This can help if it's last call in the chain"
}

gatewayAccess.anyValueGenerator.anyValue(cls, isNullable = false) {
gatewayAccess.anyValueGenerator().anyValue(cls, isNullable = false) {
instantiator.instantiate(cls)
} as T
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package io.mockk.it

import io.mockk.impl.JvmMockKGateway
import io.mockk.impl.instantiation.JvmAnyValueGenerator
import io.mockk.mockk
import kotlinx.coroutines.runBlocking
import kotlin.reflect.KClass
import kotlin.test.Test
import kotlin.test.assertEquals

@Suppress("UNUSED_PARAMETER")
class NullableValueGeneratorTest {
class NullableValueGenerator(
voidInstance: Any
) : JvmAnyValueGenerator(voidInstance) {
override fun anyValue(cls: KClass<*>, isNullable: Boolean, orInstantiateVia: () -> Any?): Any? {
if (isNullable) return null
return super.anyValue(cls, isNullable, orInstantiateVia)
}
}

@Test
fun testRelaxedMockReturnsNull() {
JvmMockKGateway.anyValueGeneratorFactory = { voidInstance ->
NullableValueGenerator(voidInstance)
}

class Bar

@Suppress("RedundantNullableReturnType", "RedundantSuspendModifier")
class Foo {
val property: Bar? = Bar()
val isEnabled: Boolean? = false
fun getSomething(): Bar? = Bar()
suspend fun getOtherThing(): Bar? = Bar()
}

val mock = mockk<Foo>(relaxed = true)
assertEquals(null, mock.property)
assertEquals(null, mock.isEnabled)
assertEquals(null, mock.getSomething())
assertEquals(null, runBlocking { mock.getOtherThing() })

JvmMockKGateway.anyValueGeneratorFactory = { voidInstance ->
JvmAnyValueGenerator(voidInstance)
}
}
}

0 comments on commit 070ab2d

Please sign in to comment.