Skip to content

Commit

Permalink
Merge pull request #1129 from adamreichold/dist-map
Browse files Browse the repository at this point in the history
Create a distribution by mapping the output of another one
  • Loading branch information
dhardy committed May 24, 2021
2 parents ea26d87 + e0e13d7 commit a97d94a
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Expand Up @@ -12,6 +12,7 @@ You may also find the [Upgrade Guide](https://rust-random.github.io/book/update.
### Additions
- Use const-generics to support arrays of all sizes (#1104)
- Implement `Clone` and `Copy` for `Alphanumeric` (#1126)
- Add `Distribution::map` to derive a distribution using a closure (#1129)

### Other
- Reorder asserts in `Uniform` float distributions for easier debugging of non-finite arguments
Expand Down
60 changes: 60 additions & 0 deletions src/distributions/mod.rs
Expand Up @@ -199,6 +199,35 @@ pub trait Distribution<T> {
phantom: ::core::marker::PhantomData,
}
}

/// Create a distribution of values of 'S' by mapping the output of `Self`
/// through the closure `F`
///
/// # Example
///
/// ```
/// use rand::thread_rng;
/// use rand::distributions::{Distribution, Uniform};
///
/// let mut rng = thread_rng();
///
/// let die = Uniform::new_inclusive(1, 6);
/// let even_number = die.map(|num| num % 2 == 0);
/// while !even_number.sample(&mut rng) {
/// println!("Still odd; rolling again!");
/// }
/// ```
fn map<F, S>(self, func: F) -> DistMap<Self, F, T, S>
where
F: Fn(T) -> S,
Self: Sized,
{
DistMap {
distr: self,
func,
phantom: ::core::marker::PhantomData,
}
}
}

impl<'a, T, D: Distribution<T>> Distribution<T> for &'a D {
Expand Down Expand Up @@ -256,6 +285,28 @@ where
{
}

/// A distribution of values of type `S` derived from the distribution `D`
/// by mapping its output of type `T` through the closure `F`.
///
/// This `struct` is created by the [`Distribution::map`] method.
/// See its documentation for more.
#[derive(Debug)]
pub struct DistMap<D, F, T, S> {
distr: D,
func: F,
phantom: ::core::marker::PhantomData<fn(T) -> S>,
}

impl<D, F, T, S> Distribution<S> for DistMap<D, F, T, S>
where
D: Distribution<T>,
F: Fn(T) -> S,
{
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> S {
(self.func)(self.distr.sample(rng))
}
}

/// A generic random value distribution, implemented for many primitive types.
/// Usually generates values with a numerically uniform distribution, and with a
/// range appropriate to the type.
Expand Down Expand Up @@ -360,6 +411,15 @@ mod tests {
assert!(0. < sum && sum < 100.);
}

#[test]
fn test_distributions_map() {
let dist = Uniform::new_inclusive(0, 5).map(|val| val + 15);

let mut rng = crate::test::rng(212);
let val = dist.sample(&mut rng);
assert!(val >= 15 && val <= 20);
}

#[test]
fn test_make_an_iter() {
fn ten_dice_rolls_other_than_five<R: Rng>(
Expand Down

0 comments on commit a97d94a

Please sign in to comment.