diff --git a/modules/mockk/src/commonTest/kotlin/io/mockk/it/SealedClassTest.kt b/modules/mockk/src/commonTest/kotlin/io/mockk/it/SealedClassTest.kt index efa39cab2..631532156 100644 --- a/modules/mockk/src/commonTest/kotlin/io/mockk/it/SealedClassTest.kt +++ b/modules/mockk/src/commonTest/kotlin/io/mockk/it/SealedClassTest.kt @@ -2,6 +2,8 @@ package io.mockk.it import io.mockk.every import io.mockk.mockk +import org.junit.jupiter.api.condition.DisabledForJreRange +import org.junit.jupiter.api.condition.JRE import kotlin.test.Test import kotlin.test.assertEquals @@ -30,6 +32,18 @@ class SealedClassTest { assertEquals(Leaf(1), result) } + @Test + fun serviceTakesSealedClassAsInput() { + val formattedNode = "Formatted node" + val factory = mockk { + every { format(any()) } answers { formattedNode } + } + + val result = factory.format(Root(0)) + + assertEquals(formattedNode, result) + } + companion object { sealed class Node @@ -39,10 +53,14 @@ class SealedClassTest { interface Factory { fun create(): Node + + fun format(node: Node): String } class FactoryImpl : Factory { override fun create(): Node = Root(0) + + override fun format(node: Node): String = node.toString() } } diff --git a/modules/mockk/src/commonTest/kotlin/io/mockk/it/SealedInterfaceTest.kt b/modules/mockk/src/commonTest/kotlin/io/mockk/it/SealedInterfaceTest.kt index 91f727abf..5001fe6aa 100644 --- a/modules/mockk/src/commonTest/kotlin/io/mockk/it/SealedInterfaceTest.kt +++ b/modules/mockk/src/commonTest/kotlin/io/mockk/it/SealedInterfaceTest.kt @@ -2,6 +2,8 @@ package io.mockk.it import io.mockk.every import io.mockk.mockk +import org.junit.jupiter.api.condition.DisabledForJreRange +import org.junit.jupiter.api.condition.JRE import kotlin.test.Test import kotlin.test.assertEquals @@ -30,6 +32,18 @@ class SealedInterfaceTest { assertEquals(Leaf(1), result) } + @Test + fun serviceTakesSealedInterfaceAsInput() { + val formattedNode = "Formatted node" + val factory = mockk { + every { format(any()) } answers { formattedNode } + } + + val result = factory.format(Root(0)) + + assertEquals(formattedNode, result) + } + companion object { sealed interface Node @@ -39,10 +53,14 @@ class SealedInterfaceTest { interface Factory { fun create(): Node + + fun format(node: Node): String } class FactoryImpl : Factory { override fun create(): Node = Root(0) + + override fun format(node: Node): String = node.toString() } } diff --git a/modules/mockk/src/jvmMain/kotlin/io/mockk/impl/recording/JvmSignatureValueGenerator.kt b/modules/mockk/src/jvmMain/kotlin/io/mockk/impl/recording/JvmSignatureValueGenerator.kt index 20a0278e2..e70e4e6b8 100644 --- a/modules/mockk/src/jvmMain/kotlin/io/mockk/impl/recording/JvmSignatureValueGenerator.kt +++ b/modules/mockk/src/jvmMain/kotlin/io/mockk/impl/recording/JvmSignatureValueGenerator.kt @@ -24,24 +24,34 @@ class JvmSignatureValueGenerator(val rnd: Random) : SignatureValueGenerator { return constructor.call(valueSig) } - return cls.cast( - when (cls) { - java.lang.Boolean::class -> rnd.nextBoolean() - java.lang.Byte::class -> rnd.nextInt().toByte() - java.lang.Short::class -> rnd.nextInt().toShort() - java.lang.Character::class -> rnd.nextInt().toChar() - java.lang.Integer::class -> rnd.nextInt() - java.lang.Long::class -> rnd.nextLong() - java.lang.Float::class -> rnd.nextFloat() - java.lang.Double::class -> rnd.nextDouble() - java.lang.String::class -> rnd.nextLong().toString(16) + return cls.cast(instantiate(cls, anyValueGeneratorProvider, instantiator)) + } + + private fun instantiate( + cls: KClass, + anyValueGeneratorProvider: () -> AnyValueGenerator, + instantiator: AbstractInstantiator + ): Any = when (cls) { + Boolean::class -> rnd.nextBoolean() + Byte::class -> rnd.nextInt().toByte() + Short::class -> rnd.nextInt().toShort() + Character::class -> rnd.nextInt().toChar() + Integer::class -> rnd.nextInt() + Long::class -> rnd.nextLong() + Float::class -> rnd.nextFloat() + Double::class -> rnd.nextDouble() + String::class -> rnd.nextLong().toString(16) - else -> - @Suppress("UNCHECKED_CAST") - anyValueGeneratorProvider().anyValue(cls, isNullable = false) { - instantiator.instantiate(cls) - } as T + else -> + if (cls.isSealed) { + cls.sealedSubclasses.firstNotNullOfOrNull { + instantiate(it, anyValueGeneratorProvider, instantiator) + } ?: error("Unable to create proxy for sealed class $cls, available subclasses: ${cls.sealedSubclasses}") + } else { + @Suppress("UNCHECKED_CAST") + anyValueGeneratorProvider().anyValue(cls, isNullable = false) { + instantiator.instantiate(cls) + } as T } - ) } }