Use Iterator::size_hint() to speed up IteratorRandom::choose #593

Merged
merged 1 commit into from Aug 24, 2018
73 changes: 68 additions & 5 deletions benches/seq.rs
@@ -8,6 +8,9 @@ use test::Bencher;

 use rand::prelude::*;
 use rand::seq::*;
+use std::mem::size_of;
+
+const RAND_BENCH_N: u64 = 1000;

 #[bench]
 fn seq_shuffle_100(b: &mut Bencher) {
@@ -22,10 +25,18 @@ fn seq_shuffle_100(b: &mut Bencher) {
 #[bench]
 fn seq_slice_choose_1_of_1000(b: &mut Bencher) {
     let mut rng = SmallRng::from_rng(thread_rng()).unwrap();
-    let x : &[usize] = &[1; 1000];
+    let x : &mut [usize] = &mut [1; 1000];
+    for i in 0..1000 {
+        x[i] = i;
+    }
     b.iter(|| {
-        x.choose(&mut rng)
-    })
+        let mut s = 0;
+        for _ in 0..RAND_BENCH_N {
+            s += x.choose(&mut rng).unwrap();
+        }
+        s
+    });
+    b.bytes = size_of::<usize>() as u64 * ::RAND_BENCH_N;
 }

 macro_rules! seq_slice_choose_multiple {
@@ -54,11 +65,63 @@ seq_slice_choose_multiple!(seq_slice_choose_multiple_10_of_100, 10, 100);
 seq_slice_choose_multiple!(seq_slice_choose_multiple_90_of_100, 90, 100);

 #[bench]
-fn seq_iter_choose_from_100(b: &mut Bencher) {
+fn seq_iter_choose_from_1000(b: &mut Bencher) {
     let mut rng = SmallRng::from_rng(thread_rng()).unwrap();
+    let x : &mut [usize] = &mut [1; 1000];
+    for i in 0..1000 {
+        x[i] = i;
+    }
+    b.iter(|| {
+        let mut s = 0;
+        for _ in 0..RAND_BENCH_N {
+            s += x.iter().choose(&mut rng).unwrap();
+        }
+        s
+    });
+    b.bytes = size_of::<usize>() as u64 * ::RAND_BENCH_N;
+}
+
+#[derive(Clone)]
+struct UnhintedIterator<I: Iterator + Clone> {
+    iter: I,
+}
+impl<I: Iterator + Clone> Iterator for UnhintedIterator<I> {
+    type Item = I::Item;
+    fn next(&mut self) -> Option<Self::Item> {
+        self.iter.next()
+    }
+}
+
+#[derive(Clone)]
+struct WindowHintedIterator<I: ExactSizeIterator + Iterator + Clone> {
+    iter: I,
+    window_size: usize,
+}
+impl<I: ExactSizeIterator + Iterator + Clone> Iterator for WindowHintedIterator<I> {
+    type Item = I::Item;
+    fn next(&mut self) -> Option<Self::Item> {
+        self.iter.next()
+    }
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        (std::cmp::min(self.iter.len(), self.window_size), None)
+    }
+}
+
+#[bench]
+fn seq_iter_unhinted_choose_from_100(b: &mut Bencher) {
+    let mut rng = SmallRng::from_rng(thread_rng()).unwrap();
+    let x : &[usize] = &[1; 1000];
+    b.iter(|| {
+        UnhintedIterator { iter: x.iter() }.choose(&mut rng).unwrap()
+    })
+}
+
+#[bench]
+fn seq_iter_window_hinted_choose_from_100(b: &mut Bencher) {
+    let mut rng = SmallRng::from_rng(thread_rng()).unwrap();
     let x : &[usize] = &[1; 100];
     b.iter(|| {
-        x.iter().cloned().choose(&mut rng)
+        WindowHintedIterator { iter: x.iter(), window_size: 7 }.choose(&mut rng)
     })
 }

142 changes: 122 additions & 20 deletions src/seq/mod.rs
@@ -188,20 +188,57 @@ pub trait IteratorRandom: Iterator + Sized {
     fn choose<R>(mut self, rng: &mut R) -> Option<Self::Item>
         where R: Rng + ?Sized
     {
-        if let Some(elem) = self.next() {
-            let mut result = elem;
-
-            // Continue until the iterator is exhausted
-            for (i, elem) in self.enumerate() {
-                let denom = (i + 2) as f64; // accurate to 2^53 elements
+        let (mut lower, mut upper) = self.size_hint();
+        let mut consumed = 0;
+        let mut result = None;
+
+        if upper == Some(lower) {
+            // Remove this once we can specialize on ExactSizeIterator
+            return if lower == 0 { None } else { self.nth(rng.gen_range(0, lower)) };
Member

I don't know if we should remove later — it's valid to give an exact size hint without implementing the latter trait (e.g. because the size bounds are not always known).
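
As a concrete case of an exact hint without the trait, here is a minimal sketch using only the standard library. `chained_hint` is a hypothetical helper name introduced for illustration; the point is that `Chain` reports matching lower and upper bounds for two ranges but does not implement `ExactSizeIterator`.

```rust
/// Hypothetical helper: size hint of two chained ranges.
fn chained_hint() -> (usize, Option<usize>) {
    // `Chain` adds the bounds of both halves, so the hint is exact here,
    // yet `Chain` itself does not implement `ExactSizeIterator`.
    (0..5).chain(5..10).size_hint()
}
```

`chained_hint()` evaluates to `(10, Some(10))`, so an `upper == Some(lower)` check would catch this iterator even though calling `.len()` on the chain would not compile.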

Contributor Author

Yeah, we can certainly go either way. The code should still compile to something quite similar since the upper == Some(lower) check further down is still there. But we do a few branches which the compiler likely won't be able to optimize away before getting there.

Member

It seems there's a big difference when this is removed, around 6 µs vs 3.7 µs on the fully hinted seq_iter_choose_from_1000 test, so worth keeping I think.

+        } else if lower <= 1 {
Member

What's this path for? It seems to me that its only purpose is to eliminate a possible None result early, yet it doesn't always do that. Does it improve performance?

Contributor Author

My goal was to ensure that the new code should reduce down to the old code once the compiler does constant folding of an "unhinted" iterator. And similarly that doing constant folding of a perfectly hinted iterator would reduce to SliceRandom::choose. I.e. the first two categories discussed in the initial comment should reduce to optimal code once constant folding is done.

That's why both of these paths are there. Both could be removed without losing any correctness, and in both cases we just end up adding a few branches to the runtime code.

(The difference for "perfectly hinted" is slightly bigger for iterators of size exactly 1, but that's a rare use case anyway).

We definitely could remove these paths. I don't have a strong opinion either way. The performance difference will be small and will mainly depend on how slow other operations are, i.e. how many items the "unhinted" iterator has and how fast the RngCore implementation in use is.

Member

Aha, makes some sense, though I still don't really understand the reason to use lower <= 1 here. You could simply drop the condition (equivalent of true) and would get the same behaviour on unhinted iterators.

My gut feeling is that the cases worth bothering with are (a) perfectly hinted iterators, (b) iterators where some lower bound is known, and (c) iterators with no hinting. So you could potentially insert a lower > 0 case (instead of the loop case) and use this code for the lower == 0 case — but I think only benchmarks can tell which one is best.

Are chunked iterators common? It's not a pattern I remember seeing with iterators, but I guess it may be relevant to buffered input.
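
To make the three categories concrete, the following sketch shows one standard-library iterator expression per category. `Unhinted` and `hint_categories` are hypothetical names introduced here, modelled on the `UnhintedIterator` helper in this PR; nothing in it depends on rand.

```rust
/// Hypothetical wrapper that hides the inner iterator's hint by relying on
/// the default `Iterator::size_hint`, which returns `(0, None)`.
struct Unhinted<I>(I);

impl<I: Iterator> Iterator for Unhinted<I> {
    type Item = I::Item;
    fn next(&mut self) -> Option<I::Item> {
        self.0.next()
    }
}

/// One example hint per category: (a) exact, (b) lower bound only, (c) none.
fn hint_categories() -> [(usize, Option<usize>); 3] {
    [
        // (a) a range knows its exact length
        (0..9).size_hint(),
        // (b) nine guaranteed items, followed by an unknown number more
        (0..9).chain(Unhinted(0..3)).size_hint(),
        // (c) no information at all
        Unhinted(0..9).size_hint(),
    ]
}
```

The three hints come out as `(9, Some(9))`, `(9, None)` and `(0, None)`, which correspond to the fast path, the `lower > 1` path and the one-item-at-a-time path respectively.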

Contributor Author

> Aha, makes some sense, though I still don't really understand the reason to use lower <= 1 here

When lower > 1, it is better to use gen_range to pick one of the first lower items than to unconditionally grab the first item and then iterate.

> Are chunked iterators common?

No idea.

As mentioned in other comments, the idea is that for all other cases the code will reduce to the optimal approach. So the cost of supporting chunked iterators is only one of source complexity, not of runtime cost or binary size.
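
The chunk-at-a-time idea can be sketched without the rand crate. This is an illustration under stated assumptions, not the PR's code: `ToyRng` is a hypothetical SplitMix64 stand-in for a real `Rng`, `below` ignores modulo bias, and `choose_hinted` is a simplified free function that re-reads the hint at the top of every loop iteration.

```rust
/// Toy SplitMix64 generator standing in for a real `Rng`; not part of rand.
struct ToyRng(u64);

impl ToyRng {
    fn next_u64(&mut self) -> u64 {
        self.0 = self.0.wrapping_add(0x9E37_79B9_7F4A_7C15);
        let mut z = self.0;
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^ (z >> 31)
    }

    /// Index in `0..n`; modulo bias is ignored in this sketch.
    fn below(&mut self, n: usize) -> usize {
        (self.next_u64() % n as u64) as usize
    }
}

/// Hint-aware selection: one draw per hinted chunk instead of one per item.
fn choose_hinted<I: Iterator>(mut iter: I, rng: &mut ToyRng) -> Option<I::Item> {
    let mut consumed = 0usize;
    let mut result = None;
    loop {
        let (lower, upper) = iter.size_hint();
        if lower > 1 {
            // Draw over the items already seen plus the `lower` guaranteed ones.
            let ix = rng.below(lower + consumed);
            let skip = if ix < lower {
                // The draw landed in the new chunk: take that item.
                result = iter.nth(ix);
                lower - (ix + 1)
            } else {
                // The draw landed on an earlier item: keep the current result.
                lower
            };
            if upper == Some(lower) {
                return result;
            }
            consumed += lower;
            if skip > 0 {
                // Discard the rest of the chunk.
                iter.nth(skip - 1);
            }
        } else {
            // No useful lower bound: classic one-item reservoir step.
            let elem = iter.next();
            if elem.is_none() {
                return result;
            }
            consumed += 1;
            if rng.below(consumed) == 0 {
                result = elem;
            }
        }
    }
}
```

For a range such as `0..9` the function returns after a single draw, while an iterator with a lower bound of 0 falls back to one draw per item. In both paths each of the `lower + consumed` candidates is kept with probability `1 / (lower + consumed)`, which is what keeps the selection uniform.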

Member

Removing this path improves benchmarks slightly. The only functional difference from the second path within the loop is that the gen_bool bit is skipped.

The Bernoulli code actually has special handling for the p==1.0 case anyway and doesn't touch the RNG.

Simpler code and faster benchmarks imply we're probably better off without this code?

Contributor Author

I don't have a strong preference either way. Generally it's a good idea to be careful about optimizing too much based on just microbenchmarks, since that often doesn't reflect real-world performance.

In theory this code should help "unhinted" iterators, though as you point out by very little since it mainly saves a few branches. In neither case will we call into the RNG to generate numbers. So yeah, probably worth removing since the win seems quite small.

I can see two reasons why removing this code would speed up anything:

  • We might end up inlining things better since the function is smaller.
  • The "window hinted" test might get faster since we do one branch less. However this benchmark seems less important to me than the "unhinted" and the fully hinted ones.

Which benchmarks specifically got faster with this removed?

Member

The unhinted one:

test seq_iter_choose_from_1000               ... bench:       3,756 ns/iter (+/- 108) = 2129 MB/s
test seq_iter_unhinted_choose_from_1000      ... bench:       5,311 ns/iter (+/- 98)
test seq_iter_window_hinted_choose_from_1000 ... bench:       1,789 ns/iter (+/- 44)
# after removing this branch:
test seq_iter_choose_from_1000               ... bench:       3,746 ns/iter (+/- 60) = 2135 MB/s
test seq_iter_unhinted_choose_from_1000      ... bench:       5,007 ns/iter (+/- 106)
test seq_iter_window_hinted_choose_from_1000 ... bench:       1,769 ns/iter (+/- 42)
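
To put the figures above in proportion, here is a small arithmetic sketch. `percent_faster` and `mb_per_s` are hypothetical helpers, and the 8000-byte figure assumes a 64-bit target where `size_of::<usize>()` is 8, so that `b.bytes` is 8 * RAND_BENCH_N.

```rust
/// Relative speedup in percent when time drops from `before_ns` to `after_ns`.
fn percent_faster(before_ns: f64, after_ns: f64) -> f64 {
    (before_ns - after_ns) / before_ns * 100.0
}

/// Throughput in MB/s (10^6 bytes per second), rounded down.
/// Bytes per nanosecond is GB/s, hence the factor of 1000.
fn mb_per_s(bytes_per_iter: u64, ns_per_iter: u64) -> u64 {
    bytes_per_iter * 1000 / ns_per_iter
}
```

With these, the unhinted benchmark (5,311 ns to 5,007 ns) comes out about 5.7% faster, while the other two move by roughly 0.3% and 1.1%, which is inside their reported +/- ranges. Under the 8-byte assumption, `mb_per_s(8000, 3756)` gives 2129 and `mb_per_s(8000, 3746)` gives 2135, matching the printed throughput; that also suggests each ns/iter figure for seq_iter_choose_from_1000 covers all 1000 calls of the inner loop.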

+            result = self.next();
+            if result.is_none() {
+                return result;
+            }
+            consumed = 1;
+            let hint = self.size_hint();
+            lower = hint.0;
+            upper = hint.1;
+        }
+
+        // Continue until the iterator is exhausted
+        loop {
+            if lower > 1 {
+                let ix = rng.gen_range(0, lower + consumed);
+                let skip;
+                if ix < lower {
+                    result = self.nth(ix);
+                    skip = lower - (ix + 1);
+                } else {
+                    skip = lower;
+                }
+                if upper == Some(lower) {
+                    return result;
+                }
+                consumed += lower;
+                if skip > 0 {
+                    self.nth(skip - 1);
+                }
+            } else {
+                let elem = self.next();
+                if elem.is_none() {
+                    return result;
+                }
+                consumed += 1;
+                let denom = consumed as f64; // accurate to 2^53 elements
                 if rng.gen_bool(1.0 / denom) {
                     result = elem;
                 }
             }
-            Some(result)
-        } else {
-            None
+
+            let hint = self.size_hint();
+            lower = hint.0;
+            upper = hint.1;
         }
     }

@@ -519,20 +556,85 @@ mod test {
         assert_eq!(v.choose_mut(&mut r), None);
     }

+    #[derive(Clone)]
+    struct UnhintedIterator<I: Iterator + Clone> {
+        iter: I,
+    }
+    impl<I: Iterator + Clone> Iterator for UnhintedIterator<I> {
+        type Item = I::Item;
+        fn next(&mut self) -> Option<Self::Item> {
+            self.iter.next()
+        }
+    }
+
+    #[derive(Clone)]
+    struct ChunkHintedIterator<I: ExactSizeIterator + Iterator + Clone> {
+        iter: I,
+        chunk_remaining: usize,
+        chunk_size: usize,
+        hint_total_size: bool,
+    }
+    impl<I: ExactSizeIterator + Iterator + Clone> Iterator for ChunkHintedIterator<I> {
+        type Item = I::Item;
+        fn next(&mut self) -> Option<Self::Item> {
+            if self.chunk_remaining == 0 {
+                self.chunk_remaining = ::core::cmp::min(self.chunk_size,
+                                                        self.iter.len());
+            }
+            self.chunk_remaining = self.chunk_remaining.saturating_sub(1);
+
+            self.iter.next()
+        }
+        fn size_hint(&self) -> (usize, Option<usize>) {
+            (self.chunk_remaining,
+             if self.hint_total_size { Some(self.iter.len()) } else { None })
+        }
+    }
+
+    #[derive(Clone)]
+    struct WindowHintedIterator<I: ExactSizeIterator + Iterator + Clone> {
+        iter: I,
+        window_size: usize,
+        hint_total_size: bool,
+    }
+    impl<I: ExactSizeIterator + Iterator + Clone> Iterator for WindowHintedIterator<I> {
+        type Item = I::Item;
+        fn next(&mut self) -> Option<Self::Item> {
+            self.iter.next()
+        }
+        fn size_hint(&self) -> (usize, Option<usize>) {
+            (::core::cmp::min(self.iter.len(), self.window_size),
+             if self.hint_total_size { Some(self.iter.len()) } else { None })
+        }
+    }
+
     #[test]
     fn test_iterator_choose() {
-        let mut r = ::test::rng(109);
-        let mut chosen = [0i32; 9];
-        for _ in 0..1000 {
-            let picked = (0..9).choose(&mut r).unwrap();
-            chosen[picked] += 1;
-        }
-        for count in chosen.iter() {
-            let err = *count - 1000 / 9;
-            assert!(-25 <= err && err <= 25);
+        let r = &mut ::test::rng(109);
+        fn test_iter<R: Rng + ?Sized, Iter: Iterator<Item=usize> + Clone>(r: &mut R, iter: Iter) {
+            let mut chosen = [0i32; 9];
+            for _ in 0..1000 {
+                let picked = iter.clone().choose(r).unwrap();
+                chosen[picked] += 1;
+            }
+            for count in chosen.iter() {
+                let err = *count - 1000 / 9;
+                assert!(-25 <= err && err <= 25);
+            }
         }

-        assert_eq!((0..0).choose(&mut r), None);
+        test_iter(r, 0..9);
+        test_iter(r, [0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned());
+        #[cfg(feature = "alloc")]
+        test_iter(r, (0..9).collect::<Vec<_>>().into_iter());
+        test_iter(r, UnhintedIterator { iter: 0..9 });
+        test_iter(r, ChunkHintedIterator { iter: 0..9, chunk_size: 4, chunk_remaining: 4, hint_total_size: false });
+        test_iter(r, ChunkHintedIterator { iter: 0..9, chunk_size: 4, chunk_remaining: 4, hint_total_size: true });
+        test_iter(r, WindowHintedIterator { iter: 0..9, window_size: 2, hint_total_size: false });
+        test_iter(r, WindowHintedIterator { iter: 0..9, window_size: 2, hint_total_size: true });
+
+        assert_eq!((0..0).choose(r), None);
+        assert_eq!(UnhintedIterator{ iter: 0..0 }.choose(r), None);
     }

     #[test]