Skip to content

Commit

Permalink
Use Iterator::size_hint() to speed up IteratorRandom::choose
Browse files Browse the repository at this point in the history
  • Loading branch information
sicking committed Aug 21, 2018
1 parent 22d6607 commit 9ff000a
Show file tree
Hide file tree
Showing 2 changed files with 192 additions and 25 deletions.
73 changes: 68 additions & 5 deletions benches/seq.rs
Expand Up @@ -8,6 +8,9 @@ use test::Bencher;

use rand::prelude::*;
use rand::seq::*;
use std::mem::size_of;

const RAND_BENCH_N: u64 = 1000;

#[bench]
fn seq_shuffle_100(b: &mut Bencher) {
Expand All @@ -22,10 +25,18 @@ fn seq_shuffle_100(b: &mut Bencher) {
#[bench]
fn seq_slice_choose_1_of_1000(b: &mut Bencher) {
let mut rng = SmallRng::from_rng(thread_rng()).unwrap();
let x : &[usize] = &[1; 1000];
let x : &mut [usize] = &mut [1; 1000];
for i in 0..1000 {
x[i] = i;
}
b.iter(|| {
x.choose(&mut rng)
})
let mut s = 0;
for _ in 0..RAND_BENCH_N {
s += x.choose(&mut rng).unwrap();
}
s
});
b.bytes = size_of::<usize>() as u64 * ::RAND_BENCH_N;
}

macro_rules! seq_slice_choose_multiple {
Expand Down Expand Up @@ -54,11 +65,63 @@ seq_slice_choose_multiple!(seq_slice_choose_multiple_10_of_100, 10, 100);
seq_slice_choose_multiple!(seq_slice_choose_multiple_90_of_100, 90, 100);

#[bench]
fn seq_iter_choose_from_100(b: &mut Bencher) {
fn seq_iter_choose_from_1000(b: &mut Bencher) {
let mut rng = SmallRng::from_rng(thread_rng()).unwrap();
let x : &mut [usize] = &mut [1; 1000];
for i in 0..1000 {
x[i] = i;
}
b.iter(|| {
let mut s = 0;
for _ in 0..RAND_BENCH_N {
s += x.iter().choose(&mut rng).unwrap();
}
s
});
b.bytes = size_of::<usize>() as u64 * ::RAND_BENCH_N;
}

/// Iterator wrapper that forwards `next` to the inner iterator but does not
/// override `size_hint`, so consumers only see the trait's default hint.
#[derive(Clone)]
struct UnhintedIterator<I>
where
    I: Iterator + Clone,
{
    iter: I,
}

impl<I> Iterator for UnhintedIterator<I>
where
    I: Iterator + Clone,
{
    type Item = <I as Iterator>::Item;

    /// Yields the wrapped iterator's next item unchanged.
    fn next(&mut self) -> Option<Self::Item> {
        Iterator::next(&mut self.iter)
    }
}

/// Iterator wrapper whose `size_hint` reports a lower bound capped at
/// `window_size` and never reports an upper bound.
#[derive(Clone)]
struct WindowHintedIterator<I>
where
    I: ExactSizeIterator + Iterator + Clone,
{
    iter: I,
    window_size: usize,
}

impl<I> Iterator for WindowHintedIterator<I>
where
    I: ExactSizeIterator + Iterator + Clone,
{
    type Item = <I as Iterator>::Item;

    /// Yields the wrapped iterator's next item unchanged.
    fn next(&mut self) -> Option<Self::Item> {
        Iterator::next(&mut self.iter)
    }

    /// Lower bound is the smaller of the remaining length and the window;
    /// the upper bound is always withheld.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.iter.len();
        let lower = self.window_size.min(remaining);
        (lower, None)
    }
}

// Benchmarks `choose` on a slice iterator wrapped in `UnhintedIterator`,
// which does not override `size_hint`.
// NOTE(review): the name says "from_100" but the slice below holds 1000
// elements — confirm which size is intended and align the name or the data.
#[bench]
fn seq_iter_unhinted_choose_from_100(b: &mut Bencher) {
let mut rng = SmallRng::from_rng(thread_rng()).unwrap();
let x : &[usize] = &[1; 1000];
b.iter(|| {
// A fresh wrapper is built on every iteration because `choose` takes the iterator by value.
UnhintedIterator { iter: x.iter() }.choose(&mut rng).unwrap()
})
}

#[bench]
fn seq_iter_window_hinted_choose_from_100(b: &mut Bencher) {
let mut rng = SmallRng::from_rng(thread_rng()).unwrap();
let x : &[usize] = &[1; 100];
b.iter(|| {
x.iter().cloned().choose(&mut rng)
WindowHintedIterator { iter: x.iter(), window_size: 7 }.choose(&mut rng)
})
}

Expand Down
144 changes: 124 additions & 20 deletions src/seq/mod.rs
Expand Up @@ -188,20 +188,57 @@ pub trait IteratorRandom: Iterator + Sized {
fn choose<R>(mut self, rng: &mut R) -> Option<Self::Item>
where R: Rng + ?Sized
{
if let Some(elem) = self.next() {
let mut result = elem;

// Continue until the iterator is exhausted
for (i, elem) in self.enumerate() {
let denom = (i + 2) as f64; // accurate to 2^53 elements
let (mut lower, mut upper) = self.size_hint();
let mut consumed = 0;
let mut result = None;

if upper == Some(lower) {
// Remove this once we can specialize on ExactSizeIterator
return if lower == 0 { None } else { self.nth(rng.gen_range(0, lower)) };
} else if lower <= 1 {
result = self.next();
if result.is_none() {
return result;
}
consumed = 1;
let hint = self.size_hint();
lower = hint.0;
upper = hint.1;
}

// Continue until the iterator is exhausted
loop {
if lower > 1 {
let ix = rng.gen_range(0, lower + consumed);
let skip;
if ix < lower {
result = self.nth(ix);
skip = lower - (ix + 1);
} else {
skip = lower;
}
if upper == Some(lower) {
return result;
}
consumed += lower;
if skip > 0 {
self.nth(skip - 1);
}
} else {
let elem = self.next();
if elem.is_none() {
return result;
}
consumed += 1;
let denom = consumed as f64; // accurate to 2^53 elements
if rng.gen_bool(1.0 / denom) {
result = elem;
}
}
Some(result)
} else {
None

let hint = self.size_hint();
lower = hint.0;
upper = hint.1;
}
}

Expand Down Expand Up @@ -519,20 +556,87 @@ mod test {
assert_eq!(v.choose_mut(&mut r), None);
}

/// Test helper: passes items through from the inner iterator while leaving
/// `size_hint` at the trait default, hiding any length information.
#[derive(Clone)]
struct UnhintedIterator<I>
where
    I: Iterator + Clone,
{
    iter: I,
}

impl<I> Iterator for UnhintedIterator<I>
where
    I: Iterator + Clone,
{
    type Item = <I as Iterator>::Item;

    /// Delegates directly to the inner iterator.
    fn next(&mut self) -> Option<Self::Item> {
        let inner = &mut self.iter;
        inner.next()
    }
}

/// Test helper whose lower `size_hint` bound counts down through successive
/// chunks of at most `chunk_size` items; the upper bound is the exact
/// remaining length when `hint_total_size` is set, otherwise absent.
#[derive(Clone)]
struct ChunkHintedIterator<I>
where
    I: ExactSizeIterator + Iterator + Clone,
{
    iter: I,
    chunk_remaining: usize,
    chunk_size: usize,
    hint_total_size: bool,
}

impl<I> Iterator for ChunkHintedIterator<I>
where
    I: ExactSizeIterator + Iterator + Clone,
{
    type Item = <I as Iterator>::Item;

    /// Advances the inner iterator, decrementing the current chunk counter.
    fn next(&mut self) -> Option<Self::Item> {
        // When the current chunk is used up, open a new one clipped to the
        // number of items the inner iterator still holds.
        let available = match self.chunk_remaining {
            0 => self.chunk_size.min(self.iter.len()),
            left => left,
        };
        // Saturating: `available` is 0 once the inner iterator is exhausted.
        self.chunk_remaining = available.saturating_sub(1);

        self.iter.next()
    }

    /// Reports `(items left in the current chunk, optional exact total)`.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let upper = if self.hint_total_size {
            Some(self.iter.len())
        } else {
            None
        };
        (self.chunk_remaining, upper)
    }
}

/// Test helper whose lower `size_hint` bound is the remaining length capped
/// at `window_size`; the upper bound is the exact remaining length when
/// `hint_total_size` is set, otherwise absent.
#[derive(Clone)]
struct WindowHintedIterator<I>
where
    I: ExactSizeIterator + Iterator + Clone,
{
    iter: I,
    window_size: usize,
    hint_total_size: bool,
}

impl<I> Iterator for WindowHintedIterator<I>
where
    I: ExactSizeIterator + Iterator + Clone,
{
    type Item = <I as Iterator>::Item;

    /// Delegates directly to the inner iterator.
    fn next(&mut self) -> Option<Self::Item> {
        Iterator::next(&mut self.iter)
    }

    /// Reports `(min(remaining, window_size), optional exact total)`.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.iter.len();
        let lower = self.window_size.min(remaining);
        let upper = if self.hint_total_size {
            Some(remaining)
        } else {
            None
        };
        (lower, upper)
    }
}

#[test]
fn test_iterator_choose() {
let mut r = ::test::rng(109);
let mut chosen = [0i32; 9];
for _ in 0..1000 {
let picked = (0..9).choose(&mut r).unwrap();
chosen[picked] += 1;
}
for count in chosen.iter() {
let err = *count - 1000 / 9;
assert!(-25 <= err && err <= 25);
let r = &mut ::test::rng(109);
fn test_iter<R: Rng + ?Sized, Iter: Iterator<Item=usize> + Clone>(r: &mut R, iter: Iter) {
let mut chosen = [0i32; 9];
for _ in 0..1000 {
let picked = iter.clone().choose(r).unwrap();
chosen[picked] += 1;
}
for count in chosen.iter() {
let err = *count - 1000 / 9;
if !(-25 <= err && err <= 25) {
println!("err is {}", err)
}
assert!(-25 <= err && err <= 25);
}
}

assert_eq!((0..0).choose(&mut r), None);
test_iter(r, 0..9);
test_iter(r, (0..9).collect::<Vec<_>>().iter().cloned());
test_iter(r, (0..9).collect::<Vec<_>>().into_iter());
test_iter(r, UnhintedIterator { iter: 0..9 });
test_iter(r, ChunkHintedIterator { iter: 0..9, chunk_size: 4, chunk_remaining: 4, hint_total_size: false });
test_iter(r, ChunkHintedIterator { iter: 0..9, chunk_size: 4, chunk_remaining: 4, hint_total_size: true });
test_iter(r, WindowHintedIterator { iter: 0..9, window_size: 2, hint_total_size: false });
test_iter(r, WindowHintedIterator { iter: 0..9, window_size: 2, hint_total_size: true });

assert_eq!((0..0).choose(r), None);
assert_eq!(UnhintedIterator{ iter: 0..0 }.choose(r), None);
}

#[test]
Expand Down

0 comments on commit 9ff000a

Please sign in to comment.