Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Distribution::sample_iter changes #758

Merged
merged 3 commits into from Apr 19, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
62 changes: 45 additions & 17 deletions src/distributions/mod.rs
Expand Up @@ -182,29 +182,35 @@ pub trait Distribution<T> {
/// Create an iterator that generates random values of `T`, using `rng` as
/// the source of randomness.
///
/// Note that this function takes `self` by value. This works since
/// `Distribution<T>` is impl'd for `&D` where `D: Distribution<T>`,
/// however borrowing is not automatic hence `distr.sample_iter(...)` may
/// need to be replaced with `(&distr).sample_iter(...)` to borrow or
/// `(&*distr).sample_iter(...)` to reborrow an existing reference.
///
/// # Example
///
/// ```
/// use rand::thread_rng;
/// use rand::distributions::{Distribution, Alphanumeric, Uniform, Standard};
///
/// let mut rng = thread_rng();
/// let rng = thread_rng();
///
/// // Vec of 16 x f32:
/// let v: Vec<f32> = Standard.sample_iter(&mut rng).take(16).collect();
/// let v: Vec<f32> = Standard.sample_iter(rng).take(16).collect();
///
/// // String:
/// let s: String = Alphanumeric.sample_iter(&mut rng).take(7).collect();
/// let s: String = Alphanumeric.sample_iter(rng).take(7).collect();
///
/// // Dice-rolling:
/// let die_range = Uniform::new_inclusive(1, 6);
/// let mut roll_die = die_range.sample_iter(&mut rng);
/// let mut roll_die = die_range.sample_iter(rng);
/// while roll_die.next().unwrap() != 6 {
/// println!("Not a 6; rolling again!");
/// }
/// ```
fn sample_iter<'a, R>(&'a self, rng: &'a mut R) -> DistIter<'a, Self, R, T>
where Self: Sized, R: Rng
fn sample_iter<R>(self, rng: R) -> DistIter<Self, R, T>
where R: Rng, Self: Sized
{
DistIter {
distr: self,
Expand All @@ -229,20 +235,23 @@ impl<'a, T, D: Distribution<T>> Distribution<T> for &'a D {
///
/// [`sample_iter`]: Distribution::sample_iter
#[derive(Debug)]
pub struct DistIter<'a, D: 'a, R: 'a, T> {
distr: &'a D,
rng: &'a mut R,
pub struct DistIter<D, R, T> {
distr: D,
rng: R,
phantom: ::core::marker::PhantomData<T>,
}

impl<'a, D, R, T> Iterator for DistIter<'a, D, R, T>
where D: Distribution<T>, R: Rng + 'a
impl<D, R, T> Iterator for DistIter<D, R, T>
where D: Distribution<T>, R: Rng
{
type Item = T;

#[inline(always)]
fn next(&mut self) -> Option<T> {
Some(self.distr.sample(self.rng))
// Here, self.rng may be a reference, but we must take &mut anyway.
// Even if sample could take an R: Rng by value, we would need to do this
// since Rng is not copyable and we cannot enforce that this is "reborrowable".
Some(self.distr.sample(&mut self.rng))
}

fn size_hint(&self) -> (usize, Option<usize>) {
Expand All @@ -251,12 +260,12 @@ impl<'a, D, R, T> Iterator for DistIter<'a, D, R, T>
}

#[cfg(rustc_1_26)]
impl<'a, D, R, T> iter::FusedIterator for DistIter<'a, D, R, T>
where D: Distribution<T>, R: Rng + 'a {}
impl<D, R, T> iter::FusedIterator for DistIter<D, R, T>
where D: Distribution<T>, R: Rng {}

#[cfg(features = "nightly")]
impl<'a, D, R, T> iter::TrustedLen for DistIter<'a, D, R, T>
where D: Distribution<T>, R: Rng + 'a {}
impl<D, R, T> iter::TrustedLen for DistIter<D, R, T>
where D: Distribution<T>, R: Rng {}


/// A generic random value distribution, implemented for many primitive types.
Expand Down Expand Up @@ -340,7 +349,8 @@ pub struct Standard;

#[cfg(all(test, feature = "std"))]
mod tests {
use super::Distribution;
use ::Rng;
use super::{Distribution, Uniform};

#[test]
fn test_distributions_iter() {
Expand All @@ -350,4 +360,22 @@ mod tests {
let results: Vec<f32> = distr.sample_iter(&mut rng).take(100).collect();
println!("{:?}", results);
}

#[test]
fn test_make_an_iter() {
fn ten_dice_rolls_other_than_five<'a, R: Rng>(rng: &'a mut R) -> impl Iterator<Item = i32> + 'a {
Uniform::new_inclusive(1, 6)
.sample_iter(rng)
.filter(|x| *x != 5)
.take(10)
}

let mut rng = ::test::rng(211);
let mut count = 0;
for val in ten_dice_rolls_other_than_five(&mut rng) {
assert!(val >= 1 && val <= 6 && val != 5);
count += 1;
}
assert_eq!(count, 10);
}
}
22 changes: 13 additions & 9 deletions src/lib.rs
Expand Up @@ -206,35 +206,39 @@ pub trait Rng: RngCore {

/// Create an iterator that generates values using the given distribution.
///
/// Note that this function takes its arguments by value. This works since
/// `(&mut R): Rng where R: Rng` and
/// `(&D): Distribution where D: Distribution`,
/// however borrowing is not automatic hence `rng.sample_iter(...)` may
/// need to be replaced with `(&mut rng).sample_iter(...)`.
///
/// # Example
///
/// ```
/// use rand::{thread_rng, Rng};
/// use rand::distributions::{Alphanumeric, Uniform, Standard};
///
/// let mut rng = thread_rng();
/// let rng = thread_rng();
///
/// // Vec of 16 x f32:
/// let v: Vec<f32> = thread_rng().sample_iter(&Standard).take(16).collect();
/// let v: Vec<f32> = rng.sample_iter(Standard).take(16).collect();
///
/// // String:
/// let s: String = rng.sample_iter(&Alphanumeric).take(7).collect();
/// let s: String = rng.sample_iter(Alphanumeric).take(7).collect();
///
/// // Combined values
/// println!("{:?}", thread_rng().sample_iter(&Standard).take(5)
/// println!("{:?}", rng.sample_iter(Standard).take(5)
/// .collect::<Vec<(f64, bool)>>());
///
/// // Dice-rolling:
/// let die_range = Uniform::new_inclusive(1, 6);
/// let mut roll_die = rng.sample_iter(&die_range);
/// let mut roll_die = rng.sample_iter(die_range);
/// while roll_die.next().unwrap() != 6 {
/// println!("Not a 6; rolling again!");
/// }
/// ```
fn sample_iter<'a, T, D: Distribution<T>>(
&'a mut self, distr: &'a D,
) -> distributions::DistIter<'a, D, Self, T>
where Self: Sized {
fn sample_iter<T, D>(self, distr: D) -> distributions::DistIter<D, Self, T>
where D: Distribution<T>, Self: Sized {
distr.sample_iter(self)
}

Expand Down
2 changes: 1 addition & 1 deletion src/rngs/thread.rs
Expand Up @@ -67,7 +67,7 @@ const THREAD_RNG_RESEED_THRESHOLD: u64 = 32*1024*1024; // 32 MiB
/// [`ReseedingRng`]: crate::rngs::adapter::ReseedingRng
/// [`StdRng`]: crate::rngs::StdRng
/// [HC-128]: rand_hc::Hc128Rng
#[derive(Clone, Debug)]
#[derive(Copy, Clone, Debug)]
pub struct ThreadRng {
// use of raw pointer implies type is neither Send nor Sync
rng: *mut ReseedingRng<Hc128Core, OsRng>,
Expand Down