From 60c2812bd53719e389d627102f0e3705f8006eb1 Mon Sep 17 00:00:00 2001 From: Mitchell Yuwono Date: Sat, 26 Mar 2022 12:04:40 +1100 Subject: [PATCH 1/2] implement trampolines for flatmap, map, filter, and merge. Remove suspension point allocation in single shot builder. --- .../io/kotest/property/arbitrary/builders.kt | 69 ++++++++----------- .../io/kotest/property/arbitrary/filter.kt | 22 +++--- .../io/kotest/property/arbitrary/map.kt | 69 +++++++++++++++---- .../io/kotest/property/arbitrary/merge.kt | 29 ++++---- .../kotest/property/arbitrary/BuilderTest.kt | 12 ++++ .../kotest/property/arbitrary/FilterTest.kt | 14 ++++ .../kotest/property/arbitrary/FlatMapTest.kt | 15 ++++ .../kotest/property/arbitrary/MapTest.kt | 11 +++ .../kotest/property/exhaustive/MergeTest.kt | 15 ++++ 9 files changed, 180 insertions(+), 76 deletions(-) diff --git a/kotest-property/src/commonMain/kotlin/io/kotest/property/arbitrary/builders.kt b/kotest-property/src/commonMain/kotlin/io/kotest/property/arbitrary/builders.kt index 0c9cfb2efae..395086224f7 100644 --- a/kotest-property/src/commonMain/kotlin/io/kotest/property/arbitrary/builders.kt +++ b/kotest-property/src/commonMain/kotlin/io/kotest/property/arbitrary/builders.kt @@ -10,7 +10,6 @@ import kotlin.coroutines.Continuation import kotlin.coroutines.CoroutineContext import kotlin.coroutines.EmptyCoroutineContext import kotlin.coroutines.RestrictsSuspension -import kotlin.coroutines.intrinsics.COROUTINE_SUSPENDED import kotlin.coroutines.intrinsics.startCoroutineUninterceptedOrReturn import kotlin.coroutines.intrinsics.suspendCoroutineUninterceptedOrReturn import kotlin.coroutines.resume @@ -213,8 +212,8 @@ fun arbitraryBuilder( edgecaseFn: EdgecaseFn? = null, builderFn: suspend ArbitraryBuilderContext.(RandomSource) -> A ): Arb = object : Arb() { - override fun edgecase(rs: RandomSource): A? = singleShotArb().edgecase(rs) - override fun sample(rs: RandomSource): Sample = singleShotArb().sample(rs) + override fun edgecase(rs: RandomSource): A? = singleShotArb(SingleShotGenerationMode.Edgecase, rs).edgecase(rs) + override fun sample(rs: RandomSource): Sample = singleShotArb(SingleShotGenerationMode.Sample, rs).sample(rs) override val classifier: Classifier? = classifier /** @@ -228,13 +227,13 @@ fun arbitraryBuilder( * will provide another single shot Arb. Hence the reason why this function is invoked * on every call to [sample] / [edgecase]. */ - private fun singleShotArb(): Arb { - val restrictedContinuation = SingleShotArbContinuation.Restricted { + private fun singleShotArb(mode: SingleShotGenerationMode, rs: RandomSource): Arb { + val restrictedContinuation = SingleShotArbContinuation.Restricted(mode, rs) { /** * At the end of the suspension we got a generated value [A] as a comprehension result. * This value can either be a sample, or an edgecase. */ - val value: A = builderFn(randomSource.bind()) + val value: A = builderFn(rs) /** * Here we point A into an Arb with the appropriate enrichments including @@ -263,8 +262,8 @@ suspend fun suspendArbitraryBuilder( fn: suspend GenerateArbitraryBuilderContext.(RandomSource) -> A ): Arb = suspendCoroutineUninterceptedOrReturn { cont -> val arb = object : Arb() { - override fun edgecase(rs: RandomSource): A? = singleShotArb().edgecase(rs) - override fun sample(rs: RandomSource): Sample = singleShotArb().sample(rs) + override fun edgecase(rs: RandomSource): A? = singleShotArb(SingleShotGenerationMode.Edgecase, rs).edgecase(rs) + override fun sample(rs: RandomSource): Sample = singleShotArb(SingleShotGenerationMode.Sample, rs).sample(rs) override val classifier: Classifier? = classifier /** @@ -278,13 +277,13 @@ suspend fun suspendArbitraryBuilder( * will provide another single shot Arb. Hence the reason why this function is invoked * on every call to [sample] / [edgecase]. */ - private fun singleShotArb(): Arb { - val suspendableContinuation = SingleShotArbContinuation.Suspendedable(cont.context) { + private fun singleShotArb(genMode: SingleShotGenerationMode, rs: RandomSource): Arb { + val suspendableContinuation = SingleShotArbContinuation.Suspendedable(genMode, rs, cont.context) { /** * At the end of the suspension we got a generated value [A] as a comprehension result. * This value can either be a sample, or an edgecase. */ - val value: A = fn(randomSource.bind()) + val value: A = fn(rs) /** * Here we point A into an Arb with the appropriate enrichments including @@ -303,13 +302,6 @@ suspend fun suspendArbitraryBuilder( cont.resume(arb) } -/** - * passthrough arb to extract the propagated RandomSource. It's important to pass rs through both the - * sample and the edgecases to ensure that flatMap can evaluate on both [sample] and [edgecase] - * regardless of any absence of edgecases in the firstly bound arb. - */ -private val randomSource: Arb = ArbitraryBuilder.create { it }.withEdgecaseFn { it }.build() - typealias SampleFn = (RandomSource) -> A typealias EdgecaseFn = (RandomSource) -> A? @@ -352,18 +344,29 @@ interface ArbitraryBuilderContext : BaseArbitraryBuilderSyntax interface GenerateArbitraryBuilderContext : BaseArbitraryBuilderSyntax +enum class SingleShotGenerationMode { Edgecase, Sample } + sealed class SingleShotArbContinuation( override val context: CoroutineContext, + private val generationMode: SingleShotGenerationMode, + private val randomSource: RandomSource, private val fn: suspend F.() -> Arb ) : Continuation>, BaseArbitraryBuilderSyntax { + class Restricted( + genMode: SingleShotGenerationMode, + rs: RandomSource, fn: suspend ArbitraryBuilderContext.() -> Arb - ) : SingleShotArbContinuation(EmptyCoroutineContext, fn), ArbitraryBuilderContext + ) : SingleShotArbContinuation(EmptyCoroutineContext, genMode, rs, fn), + ArbitraryBuilderContext class Suspendedable( + genMode: SingleShotGenerationMode, + rs: RandomSource, override val context: CoroutineContext, fn: suspend GenerateArbitraryBuilderContext.() -> Arb - ) : SingleShotArbContinuation(context, fn), GenerateArbitraryBuilderContext + ) : SingleShotArbContinuation(context, genMode, rs, fn), + GenerateArbitraryBuilderContext private lateinit var returnedArb: Arb private var hasExecuted: Boolean = false @@ -373,24 +376,9 @@ sealed class SingleShotArbContinuation( result.map { resultArb -> returnedArb = resultArb }.getOrThrow() } - override suspend fun Arb.bind(): T = suspendCoroutineUninterceptedOrReturn { c -> - // we call flatMap on the bound arb, and then returning the `returnedArb`, without modification - returnedArb = this.flatMap { value: T -> - /** - * we resume the suspension with the value passed inside the flatMap function. - * this [value] can be either sample or edgecases. This is important - * because from the point of view of a user of kotest, when we talk about transformation, - * we care about transforming the generated value of this arb for both sample and edgecases. - */ - c.resume(value) - returnedArb - } - /** - * Notice this block returns the special COROUTINE_SUSPENDED value - * this means the Continuation provided to the block shall be resumed by invoking [resumeWith] - * at some moment in the future when the result becomes available to resume the computation. - */ - COROUTINE_SUSPENDED + override suspend fun Arb.bind(): T = when (generationMode) { + SingleShotGenerationMode.Edgecase -> this.edgecase(randomSource) ?: this.sample(randomSource).value + SingleShotGenerationMode.Sample -> this.sample(randomSource).value } /** @@ -404,7 +392,10 @@ sealed class SingleShotArbContinuation( */ fun F.createSingleShotArb(): Arb { require(!hasExecuted) { "continuation has already been executed, if you see this error please raise a bug report" } - fn.startCoroutineUninterceptedOrReturn(this@createSingleShotArb, this@SingleShotArbContinuation) + val result = fn.startCoroutineUninterceptedOrReturn(this@createSingleShotArb, this@SingleShotArbContinuation) + + @Suppress("UNCHECKED_CAST") + returnedArb = result as Arb return returnedArb } } diff --git a/kotest-property/src/commonMain/kotlin/io/kotest/property/arbitrary/filter.kt b/kotest-property/src/commonMain/kotlin/io/kotest/property/arbitrary/filter.kt index 7f7b5a1b4c3..fe4cff629ee 100644 --- a/kotest-property/src/commonMain/kotlin/io/kotest/property/arbitrary/filter.kt +++ b/kotest-property/src/commonMain/kotlin/io/kotest/property/arbitrary/filter.kt @@ -12,17 +12,19 @@ import io.kotest.property.filter * predicate. This gen will continue to request elements from the underlying gen until one satisfies * the predicate. */ -fun Arb.filter(predicate: (A) -> Boolean): Arb = object : Arb() { +fun Arb.filter(predicate: (A) -> Boolean): Arb = trampoline { sampleA -> + object : Arb() { + override fun edgecase(rs: RandomSource): A? = + sequenceOf(sampleA.value) + .plus(generateSequence { this@filter.edgecase(rs) }) + .take(PropertyTesting.maxFilterAttempts) + .filter(predicate) + .firstOrNull() - override fun edgecase(rs: RandomSource): A? = - generateSequence { this@filter.edgecase(rs) } - .take(PropertyTesting.maxFilterAttempts) - .filter(predicate) - .firstOrNull() - - override fun sample(rs: RandomSource): Sample { - val sample = this@filter.samples(rs).filter { predicate(it.value) }.first() - return Sample(sample.value, sample.shrinks.filter(predicate) ?: RTree({ sample.value })) + override fun sample(rs: RandomSource): Sample { + val sample = sequenceOf(sampleA).plus(this@filter.samples(rs)).filter { predicate(it.value) }.first() + return Sample(sample.value, sample.shrinks.filter(predicate) ?: RTree({ sample.value })) + } } } diff --git a/kotest-property/src/commonMain/kotlin/io/kotest/property/arbitrary/map.kt b/kotest-property/src/commonMain/kotlin/io/kotest/property/arbitrary/map.kt index b91a59d2d8c..e13f78686a5 100644 --- a/kotest-property/src/commonMain/kotlin/io/kotest/property/arbitrary/map.kt +++ b/kotest-property/src/commonMain/kotlin/io/kotest/property/arbitrary/map.kt @@ -8,26 +8,69 @@ import io.kotest.property.map /** * Returns a new [Arb] which takes its elements from the receiver and maps them using the supplied function. */ -fun Arb.map(f: (A) -> B): Arb = object : Arb() { - - override fun edgecase(rs: RandomSource): B? = this@map.edgecase(rs)?.let(f) - - override fun sample(rs: RandomSource): Sample = - this@map.sample(rs).let { - Sample(f(it.value), it.shrinks.map(f)) +fun Arb.map(fn: (A) -> B): Arb = trampoline { sampleA -> + object : Arb() { + override fun edgecase(rs: RandomSource): B? = fn(sampleA.value) + override fun sample(rs: RandomSource): Sample { + val value = fn(sampleA.value) + val shrinks = sampleA.shrinks.map(fn) + return Sample(value, shrinks) } + } } /** * Returns a new [Arb] which takes its elements from the receiver and maps them using the supplied function. */ -fun Arb.flatMap(f: (A) -> Arb): Arb = object : Arb() { +fun Arb.flatMap(fn: (A) -> Arb): Arb = trampoline { fn(it.value) } - override fun edgecase(rs: RandomSource): B? { - // generate an edge case, map it to another arb, and generate an edge case again - val a = this@flatMap.edgecase(rs) ?: this@flatMap.next(rs) - return f(a).edgecase(rs) +/** + * Returns a new [TrampolineArb] from the receiver [Arb] which composes the operations of [next] lambda + * using a trampoline method. This allows [next] function to be executed without exhausting call stack. + */ +internal fun Arb.trampoline(next: (Sample) -> Arb): Arb = when (this) { + is TrampolineArb -> thunk(next) + else -> TrampolineArb(this).thunk(next) +} + +/** + * The [TrampolineArb] is a special Arb that exchanges call stack with heap. + * In a nutshell, this arb stores command chains to be applied to the original arb inside a list. + * This technique is an imperative reduction of Free Monads. This eliminates the need of creating intermediate + * Trampoline Monad and tail-recursive function on those which can be expensive. + * This minimizes the amount of code and unnecessary object allocation during sample generation in the expense of typesafety. + * + * This is an internal implementation. Do not use this TrampolineArb as is and please do not expose this + * to users outside of the library. For library maintainers, please use the [Arb.trampoline] extension function. + * The extension function will provide some type-guardrails to workaround the loss of types within this Arb. + */ +@Suppress("UNCHECKED_CAST") +internal class TrampolineArb(val first: Arb) : Arb() { + private val commandList: MutableList<(Sample) -> Arb> = mutableListOf() + + fun thunk(fn: (Sample) -> Arb): TrampolineArb { + val nextFn: (Sample) -> Arb = { fn(it) } + commandList.add(nextFn as (Sample) -> Arb) + return this as TrampolineArb + } + + override fun edgecase(rs: RandomSource): A? { + var currentArb = first as Arb + for (command in commandList) { + val currentEdge = currentArb.edgecase(rs) ?: currentArb.sample(rs).value + currentArb = command(Sample(currentEdge)) + } + + return currentArb.edgecase(rs) as A? } - override fun sample(rs: RandomSource): Sample = f(this@flatMap.sample(rs).value).sample(rs) + override fun sample(rs: RandomSource): Sample { + var currentArb = first as Arb + for (command in commandList) { + val currentSample = currentArb.sample(rs) + currentArb = command(currentSample) + } + + return currentArb.sample(rs) as Sample + } } diff --git a/kotest-property/src/commonMain/kotlin/io/kotest/property/arbitrary/merge.kt b/kotest-property/src/commonMain/kotlin/io/kotest/property/arbitrary/merge.kt index 64c485c3ef5..3b998311fbf 100644 --- a/kotest-property/src/commonMain/kotlin/io/kotest/property/arbitrary/merge.kt +++ b/kotest-property/src/commonMain/kotlin/io/kotest/property/arbitrary/merge.kt @@ -20,20 +20,21 @@ import io.kotest.property.Sample * @param other the arg to merge with this one * @return the merged arg. */ -fun Arb.merge(other: Gen): Arb = object : Arb() { - - override fun edgecase(rs: RandomSource): A? = when (other) { - is Arb -> listOf(this@merge, other).random(rs.random).edgecase(rs) - is Exhaustive -> this@merge.edgecase(rs) - } +fun Arb.merge(other: Gen): Arb = trampoline { sampleA -> + object : Arb() { + override fun edgecase(rs: RandomSource): A? = when (other) { + is Arb -> if (rs.random.nextBoolean()) sampleA.value else other.edgecase(rs) + is Exhaustive -> sampleA.value + } - override fun sample(rs: RandomSource): Sample = - if (rs.random.nextBoolean()) { - this@merge.sample(rs) - } else { - when (other) { - is Arb -> other.sample(rs) - is Exhaustive -> other.toArb().sample(rs) + override fun sample(rs: RandomSource): Sample = + if (rs.random.nextBoolean()) { + sampleA + } else { + when (other) { + is Arb -> other.sample(rs) + is Exhaustive -> other.toArb().sample(rs) + } } - } + } } diff --git a/kotest-property/src/jvmTest/kotlin/com/sksamuel/kotest/property/arbitrary/BuilderTest.kt b/kotest-property/src/jvmTest/kotlin/com/sksamuel/kotest/property/arbitrary/BuilderTest.kt index bbcbe8bf697..0d4d61ddb96 100644 --- a/kotest-property/src/jvmTest/kotlin/com/sksamuel/kotest/property/arbitrary/BuilderTest.kt +++ b/kotest-property/src/jvmTest/kotlin/com/sksamuel/kotest/property/arbitrary/BuilderTest.kt @@ -1,5 +1,6 @@ package com.sksamuel.kotest.property.arbitrary +import io.kotest.assertions.throwables.shouldNotThrowAny import io.kotest.assertions.throwables.shouldThrow import io.kotest.core.spec.style.FunSpec import io.kotest.matchers.collections.shouldContainExactly @@ -55,6 +56,17 @@ class BuilderTest : FunSpec() { } context("arbitrary builder using restricted continuation") { + test("should be stack safe") { + val arb: Arb = arbitrary { + (1..100000).map { + Arb.int().bind() + }.last() + } + + val result = shouldNotThrowAny { arb.single(RandomSource.seeded(1234)) } + result shouldBe -1486934023 + } + test("should be equivalent to chaining flatMaps") { val arbFlatMaps: Arb = Arb.string(5, Codepoint.alphanumeric()).withEdgecases("edge1", "edge2").flatMap { first -> diff --git a/kotest-property/src/jvmTest/kotlin/com/sksamuel/kotest/property/arbitrary/FilterTest.kt b/kotest-property/src/jvmTest/kotlin/com/sksamuel/kotest/property/arbitrary/FilterTest.kt index de6273d66cc..a98f764af92 100644 --- a/kotest-property/src/jvmTest/kotlin/com/sksamuel/kotest/property/arbitrary/FilterTest.kt +++ b/kotest-property/src/jvmTest/kotlin/com/sksamuel/kotest/property/arbitrary/FilterTest.kt @@ -1,16 +1,21 @@ package com.sksamuel.kotest.property.arbitrary import io.kotest.assertions.throwables.shouldNotThrow +import io.kotest.assertions.throwables.shouldNotThrowAny import io.kotest.core.spec.style.FunSpec import io.kotest.inspectors.forAll import io.kotest.matchers.collections.shouldContainExactly import io.kotest.matchers.collections.shouldNotBeIn +import io.kotest.matchers.shouldBe import io.kotest.property.Arb import io.kotest.property.EdgeConfig import io.kotest.property.RandomSource import io.kotest.property.Sample import io.kotest.property.arbitrary.filter import io.kotest.property.arbitrary.int +import io.kotest.property.arbitrary.map +import io.kotest.property.arbitrary.of +import io.kotest.property.arbitrary.single import io.kotest.property.arbitrary.take import io.kotest.property.arbitrary.withEdgecases @@ -54,4 +59,13 @@ class FilterTest : FunSpec({ } } } + + test("Arb.filter composition should not exhaust call stack") { + var arb: Arb = Arb.of(0, 1) + repeat(10000) { + arb = arb.filter { it == 0 } + } + val result = shouldNotThrowAny { arb.single(RandomSource.seeded(1234L)) } + result shouldBe 0 + } }) diff --git a/kotest-property/src/jvmTest/kotlin/com/sksamuel/kotest/property/arbitrary/FlatMapTest.kt b/kotest-property/src/jvmTest/kotlin/com/sksamuel/kotest/property/arbitrary/FlatMapTest.kt index e274ddb7177..a363b131b27 100644 --- a/kotest-property/src/jvmTest/kotlin/com/sksamuel/kotest/property/arbitrary/FlatMapTest.kt +++ b/kotest-property/src/jvmTest/kotlin/com/sksamuel/kotest/property/arbitrary/FlatMapTest.kt @@ -1,7 +1,9 @@ package com.sksamuel.kotest.property.arbitrary +import io.kotest.assertions.throwables.shouldNotThrowAny import io.kotest.core.spec.style.FunSpec import io.kotest.matchers.collections.shouldContainExactly +import io.kotest.matchers.shouldBe import io.kotest.property.Arb import io.kotest.property.EdgeConfig import io.kotest.property.RandomSource @@ -86,5 +88,18 @@ class FlatMapTest : FunSpec() { 22 ) } + + test("Arb.flatMap composition should not exhaust call stack") { + var arb: Arb = Arb.int(-3..3) + repeat(10000) { + arb = arb.flatMap { value -> + Arb.int(-3..3).flatMap { + Arb.of(value + it) + } + } + } + val result = shouldNotThrowAny { arb.single(RandomSource.seeded(1234L)) } + result shouldBe 49 + } } } diff --git a/kotest-property/src/jvmTest/kotlin/com/sksamuel/kotest/property/arbitrary/MapTest.kt b/kotest-property/src/jvmTest/kotlin/com/sksamuel/kotest/property/arbitrary/MapTest.kt index d444d63862b..664bef897b6 100644 --- a/kotest-property/src/jvmTest/kotlin/com/sksamuel/kotest/property/arbitrary/MapTest.kt +++ b/kotest-property/src/jvmTest/kotlin/com/sksamuel/kotest/property/arbitrary/MapTest.kt @@ -1,5 +1,6 @@ package com.sksamuel.kotest.property.arbitrary +import io.kotest.assertions.throwables.shouldNotThrowAny import io.kotest.core.spec.style.FunSpec import io.kotest.matchers.collections.shouldContainExactly import io.kotest.matchers.shouldBe @@ -10,6 +11,7 @@ import io.kotest.property.arbitrary.IntShrinker import io.kotest.property.arbitrary.arbitrary import io.kotest.property.arbitrary.int import io.kotest.property.arbitrary.map +import io.kotest.property.arbitrary.of import io.kotest.property.arbitrary.single import io.kotest.property.arbitrary.withEdgecases import java.util.concurrent.atomic.AtomicInteger @@ -47,4 +49,13 @@ class MapTest : FunSpec({ "120" ) } + + test("Arb.map composition should not exhaust call stack") { + var arb: Arb = Arb.of(0) + repeat(10000) { + arb = arb.map { value -> value + 1 } + } + val result = shouldNotThrowAny { arb.single(RandomSource.seeded(1234L)) } + result shouldBe 10000 + } }) diff --git a/kotest-property/src/jvmTest/kotlin/com/sksamuel/kotest/property/exhaustive/MergeTest.kt b/kotest-property/src/jvmTest/kotlin/com/sksamuel/kotest/property/exhaustive/MergeTest.kt index c364b75ec6c..f5a1764c158 100644 --- a/kotest-property/src/jvmTest/kotlin/com/sksamuel/kotest/property/exhaustive/MergeTest.kt +++ b/kotest-property/src/jvmTest/kotlin/com/sksamuel/kotest/property/exhaustive/MergeTest.kt @@ -1,7 +1,13 @@ package com.sksamuel.kotest.property.exhaustive +import io.kotest.assertions.throwables.shouldNotThrowAny import io.kotest.core.spec.style.FunSpec import io.kotest.matchers.shouldBe +import io.kotest.property.Arb +import io.kotest.property.RandomSource +import io.kotest.property.arbitrary.merge +import io.kotest.property.arbitrary.of +import io.kotest.property.arbitrary.single import io.kotest.property.exhaustive.exhaustive import io.kotest.property.exhaustive.merge @@ -27,4 +33,13 @@ class MergeTest : FunSpec({ Common.Bar(6) ) } + + test("Arb.merge composition should not exhaust call stack") { + var arb: Arb = Arb.of(0) + repeat(100) { + arb = arb.merge(Arb.of(1)) + } + val result = shouldNotThrowAny { arb.single(RandomSource.seeded(1234L)) } + result shouldBe 1 + } }) From ba11fb5f75ffe9d9e4efdd34ab591cc83934cb8d Mon Sep 17 00:00:00 2001 From: Mitchell Yuwono Date: Sat, 26 Mar 2022 14:07:35 +1100 Subject: [PATCH 2/2] make sure flatmap preserves immutability --- .../io/kotest/property/arbitrary/map.kt | 50 +++++++++---------- .../kotest/property/arbitrary/FlatMapTest.kt | 15 ++++++ 2 files changed, 40 insertions(+), 25 deletions(-) diff --git a/kotest-property/src/commonMain/kotlin/io/kotest/property/arbitrary/map.kt b/kotest-property/src/commonMain/kotlin/io/kotest/property/arbitrary/map.kt index e13f78686a5..cb3c6c6116c 100644 --- a/kotest-property/src/commonMain/kotlin/io/kotest/property/arbitrary/map.kt +++ b/kotest-property/src/commonMain/kotlin/io/kotest/property/arbitrary/map.kt @@ -29,7 +29,7 @@ fun Arb.flatMap(fn: (A) -> Arb): Arb = trampoline { fn(it.value) * using a trampoline method. This allows [next] function to be executed without exhausting call stack. */ internal fun Arb.trampoline(next: (Sample) -> Arb): Arb = when (this) { - is TrampolineArb -> thunk(next) + is TrampolineArb -> this.thunk(next) else -> TrampolineArb(this).thunk(next) } @@ -45,32 +45,32 @@ internal fun Arb.trampoline(next: (Sample) -> Arb): Arb = whe * The extension function will provide some type-guardrails to workaround the loss of types within this Arb. */ @Suppress("UNCHECKED_CAST") -internal class TrampolineArb(val first: Arb) : Arb() { - private val commandList: MutableList<(Sample) -> Arb> = mutableListOf() +internal class TrampolineArb private constructor( + private val first: Arb, + commands: List<(Sample) -> Arb> +) : Arb() { + constructor(first: Arb) : this(first, emptyList()) - fun thunk(fn: (Sample) -> Arb): TrampolineArb { - val nextFn: (Sample) -> Arb = { fn(it) } - commandList.add(nextFn as (Sample) -> Arb) - return this as TrampolineArb - } - - override fun edgecase(rs: RandomSource): A? { - var currentArb = first as Arb - for (command in commandList) { - val currentEdge = currentArb.edgecase(rs) ?: currentArb.sample(rs).value - currentArb = command(Sample(currentEdge)) - } + private val commandList: MutableList<(Sample) -> Arb> = commands.toMutableList() - return currentArb.edgecase(rs) as A? - } + fun thunk(fn: (Sample) -> Arb): TrampolineArb = + TrampolineArb( + first, + commandList.toList() + (fn as (Sample) -> Arb) + ) as TrampolineArb - override fun sample(rs: RandomSource): Sample { - var currentArb = first as Arb - for (command in commandList) { - val currentSample = currentArb.sample(rs) - currentArb = command(currentSample) - } + override fun edgecase(rs: RandomSource): A? = + commandList + .fold(first as Arb) { currentArb, next -> + val currentEdge = currentArb.edgecase(rs) ?: currentArb.sample(rs).value + next(Sample(currentEdge)) + } + .edgecase(rs) as A? - return currentArb.sample(rs) as Sample - } + override fun sample(rs: RandomSource): Sample = + commandList + .fold(first as Arb) { currentArb, next -> + next(currentArb.sample(rs)) + } + .sample(rs) as Sample } diff --git a/kotest-property/src/jvmTest/kotlin/com/sksamuel/kotest/property/arbitrary/FlatMapTest.kt b/kotest-property/src/jvmTest/kotlin/com/sksamuel/kotest/property/arbitrary/FlatMapTest.kt index a363b131b27..a26ed5124cf 100644 --- a/kotest-property/src/jvmTest/kotlin/com/sksamuel/kotest/property/arbitrary/FlatMapTest.kt +++ b/kotest-property/src/jvmTest/kotlin/com/sksamuel/kotest/property/arbitrary/FlatMapTest.kt @@ -4,6 +4,7 @@ import io.kotest.assertions.throwables.shouldNotThrowAny import io.kotest.core.spec.style.FunSpec import io.kotest.matchers.collections.shouldContainExactly import io.kotest.matchers.shouldBe +import io.kotest.matchers.types.shouldNotBeSameInstanceAs import io.kotest.property.Arb import io.kotest.property.EdgeConfig import io.kotest.property.RandomSource @@ -101,5 +102,19 @@ class FlatMapTest : FunSpec() { val result = shouldNotThrowAny { arb.single(RandomSource.seeded(1234L)) } result shouldBe 49 } + + test("should yield a new immutable arb") { + val firstArb: Arb = Arb.int(-3..3) + val secondArb: Arb = firstArb.flatMap { Arb.int(10..30) } + val thirdArb: Arb = secondArb.flatMap { Arb.string(3, Codepoint.alphanumeric()) } + + firstArb shouldNotBeSameInstanceAs secondArb + firstArb shouldNotBeSameInstanceAs thirdArb + secondArb shouldNotBeSameInstanceAs thirdArb + + firstArb.single(RandomSource.seeded(1234L)) shouldBe 3 + secondArb.single(RandomSource.seeded(1234L)) shouldBe 28 + thirdArb.single(RandomSource.seeded(1234L)) shouldBe "tID" + } } }