Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create a distribution by mapping the output of another one #1129

Merged
merged 1 commit into from May 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Expand Up @@ -12,6 +12,7 @@ You may also find the [Upgrade Guide](https://rust-random.github.io/book/update.
### Additions
- Use const-generics to support arrays of all sizes (#1104)
- Implement `Clone` and `Copy` for `Alphanumeric` (#1126)
- Add `Distribution::map` to derive a distribution using a closure (#1129)

### Other
- Reorder asserts in `Uniform` float distributions for easier debugging of non-finite arguments
Expand Down
60 changes: 60 additions & 0 deletions src/distributions/mod.rs
Expand Up @@ -199,6 +199,35 @@ pub trait Distribution<T> {
phantom: ::core::marker::PhantomData,
}
}

/// Create a distribution of values of 'S' by mapping the output of `Self`
/// through the closure `F`
///
/// # Example
///
/// ```
/// use rand::thread_rng;
/// use rand::distributions::{Distribution, Uniform};
///
/// let mut rng = thread_rng();
///
/// let die = Uniform::new_inclusive(1, 6);
/// let even_number = die.map(|num| num % 2 == 0);
/// while !even_number.sample(&mut rng) {
/// println!("Still odd; rolling again!");
/// }
/// ```
fn map<F, S>(self, func: F) -> DistMap<Self, F, T, S>
where
F: Fn(T) -> S,
Self: Sized,
{
DistMap {
distr: self,
func,
phantom: ::core::marker::PhantomData,
}
}
}

impl<'a, T, D: Distribution<T>> Distribution<T> for &'a D {
Expand Down Expand Up @@ -256,6 +285,28 @@ where
{
}

/// A distribution of values of type `S` derived from the distribution `D`
/// by mapping its output of type `T` through the closure `F`.
///
/// This `struct` is created by the [`Distribution::map`] method.
/// See its documentation for more.
#[derive(Debug)]
pub struct DistMap<D, F, T, S> {
distr: D,
func: F,
phantom: ::core::marker::PhantomData<fn(T) -> S>,
}

impl<D, F, T, S> Distribution<S> for DistMap<D, F, T, S>
where
D: Distribution<T>,
F: Fn(T) -> S,
{
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> S {
(self.func)(self.distr.sample(rng))
}
}

/// A generic random value distribution, implemented for many primitive types.
/// Usually generates values with a numerically uniform distribution, and with a
/// range appropriate to the type.
Expand Down Expand Up @@ -360,6 +411,15 @@ mod tests {
assert!(0. < sum && sum < 100.);
}

#[test]
fn test_distributions_map() {
let dist = Uniform::new_inclusive(0, 5).map(|val| val + 15);

let mut rng = crate::test::rng(212);
let val = dist.sample(&mut rng);
assert!(val >= 15 && val <= 20);
adamreichold marked this conversation as resolved.
Show resolved Hide resolved
}

#[test]
fn test_make_an_iter() {
fn ten_dice_rolls_other_than_five<R: Rng>(
Expand Down