From e0e13d79c60d72dee2ae417c4b7f320437f1fc6a Mon Sep 17 00:00:00 2001 From: Adam Reichold Date: Sun, 23 May 2021 08:09:41 +0200 Subject: [PATCH] Create a distribution by mapping the output of another one This is useful if consumers are to be given an opaque type implementing the Distribution trait, but the output of the provided implementations needs additional post processing, e.g. to attach compile time units of measurement. --- CHANGELOG.md | 1 + src/distributions/mod.rs | 60 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4945f5e8390..e8baa39e002 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ You may also find the [Upgrade Guide](https://rust-random.github.io/book/update. ### Additions - Use const-generics to support arrays of all sizes (#1104) - Implement `Clone` and `Copy` for `Alphanumeric` (#1126) +- Add `Distribution::map` to derive a distribution using a closure (#1129) ### Other - Reorder asserts in `Uniform` float distributions for easier debugging of non-finite arguments diff --git a/src/distributions/mod.rs b/src/distributions/mod.rs index 8171e30e43c..d1ae30c9402 100644 --- a/src/distributions/mod.rs +++ b/src/distributions/mod.rs @@ -199,6 +199,35 @@ pub trait Distribution { phantom: ::core::marker::PhantomData, } } + + /// Create a distribution of values of 'S' by mapping the output of `Self` + /// through the closure `F` + /// + /// # Example + /// + /// ``` + /// use rand::thread_rng; + /// use rand::distributions::{Distribution, Uniform}; + /// + /// let mut rng = thread_rng(); + /// + /// let die = Uniform::new_inclusive(1, 6); + /// let even_number = die.map(|num| num % 2 == 0); + /// while !even_number.sample(&mut rng) { + /// println!("Still odd; rolling again!"); + /// } + /// ``` + fn map(self, func: F) -> DistMap + where + F: Fn(T) -> S, + Self: Sized, + { + DistMap { + distr: self, + func, + phantom: ::core::marker::PhantomData, + } + } } impl<'a, T, D: Distribution> Distribution for &'a D { @@ -256,6 +285,28 @@ where { } +/// A distribution of values of type `S` derived from the distribution `D` +/// by mapping its output of type `T` through the closure `F`. +/// +/// This `struct` is created by the [`Distribution::map`] method. +/// See its documentation for more. +#[derive(Debug)] +pub struct DistMap { + distr: D, + func: F, + phantom: ::core::marker::PhantomData S>, +} + +impl Distribution for DistMap +where + D: Distribution, + F: Fn(T) -> S, +{ + fn sample(&self, rng: &mut R) -> S { + (self.func)(self.distr.sample(rng)) + } +} + /// A generic random value distribution, implemented for many primitive types. /// Usually generates values with a numerically uniform distribution, and with a /// range appropriate to the type. @@ -360,6 +411,15 @@ mod tests { assert!(0. < sum && sum < 100.); } + #[test] + fn test_distributions_map() { + let dist = Uniform::new_inclusive(0, 5).map(|val| val + 15); + + let mut rng = crate::test::rng(212); + let val = dist.sample(&mut rng); + assert!(val >= 15 && val <= 20); + } + #[test] fn test_make_an_iter() { fn ten_dice_rolls_other_than_five(