Skip to content

Commit

Permalink
Added new versions of choose and choose_stable (#1268)
Browse files Browse the repository at this point in the history
* Added new versions of choose and choose_stable

* Removed coin_flipper tests which were unnecessary and not building on ci

* Performance optimizations in coin_flipper

* Clippy fixes and more documentation

* Added a correctness fix for coin_flipper

* Update benches/seq.rs

Co-authored-by: Vinzent Steinberg <Vinzent.Steinberg@gmail.com>

* Update benches/seq.rs

Co-authored-by: Vinzent Steinberg <Vinzent.Steinberg@gmail.com>

* Removed old version of choose and choose stable and updated value stability tests

* Moved sequence choose benchmarks to their own file

* Reworked coin_flipper

* Use criterion for seq_choose benches

* Removed an old comment

* Change how c is estimated in coin_flipper

* Revert "Use criterion for seq_choose benches"

This reverts commit 2339539.

* Added seq_choose benches for smaller numbers

* Removed some unneeded lines from seq_choose

* Improvements in coin_flipper.rs

* Small refactor of coin_flipper

* Tidied comments in coin_flipper

* Use criterion for seq_choose benchmarks

* Made choose not generate a random number if len=1

* small change to IteratorRandom::choose

* Made it easier to change seq_choose benchmarks RNG

* Added Pcg64 benchmarks for seq_choose

* Added TODO to coin_flipper

* Changed criterion settings in seq_choose

Co-authored-by: Vinzent Steinberg <Vinzent.Steinberg@gmail.com>
  • Loading branch information
wainwrightmark and vks committed Jan 5, 2023
1 parent 3107a54 commit 1e96eb4
Show file tree
Hide file tree
Showing 5 changed files with 460 additions and 185 deletions.
5 changes: 5 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,8 @@ rand_pcg = { path = "rand_pcg", version = "0.4.0" }
bincode = "1.2.1"
rayon = "1.5.3"
criterion = { version = "0.4" }

[[bench]]
name = "seq_choose"
path = "benches/seq_choose.rs"
harness = false
72 changes: 1 addition & 71 deletions benches/seq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ extern crate test;

use test::Bencher;

use core::mem::size_of;
use rand::prelude::*;
use rand::seq::*;
use core::mem::size_of;

// We force use of 32-bit RNG since seq code is optimised for use with 32-bit
// generators on all platforms.
Expand Down Expand Up @@ -74,76 +74,6 @@ seq_slice_choose_multiple!(seq_slice_choose_multiple_950_of_1000, 950, 1000);
seq_slice_choose_multiple!(seq_slice_choose_multiple_10_of_100, 10, 100);
seq_slice_choose_multiple!(seq_slice_choose_multiple_90_of_100, 90, 100);

#[bench]
fn seq_iter_choose_from_1000(b: &mut Bencher) {
let mut rng = SmallRng::from_rng(thread_rng()).unwrap();
let x: &mut [usize] = &mut [1; 1000];
for (i, r) in x.iter_mut().enumerate() {
*r = i;
}
b.iter(|| {
let mut s = 0;
for _ in 0..RAND_BENCH_N {
s += x.iter().choose(&mut rng).unwrap();
}
s
});
b.bytes = size_of::<usize>() as u64 * crate::RAND_BENCH_N;
}

#[derive(Clone)]
struct UnhintedIterator<I: Iterator + Clone> {
iter: I,
}
impl<I: Iterator + Clone> Iterator for UnhintedIterator<I> {
type Item = I::Item;

fn next(&mut self) -> Option<Self::Item> {
self.iter.next()
}
}

#[derive(Clone)]
struct WindowHintedIterator<I: ExactSizeIterator + Iterator + Clone> {
iter: I,
window_size: usize,
}
impl<I: ExactSizeIterator + Iterator + Clone> Iterator for WindowHintedIterator<I> {
type Item = I::Item;

fn next(&mut self) -> Option<Self::Item> {
self.iter.next()
}

fn size_hint(&self) -> (usize, Option<usize>) {
(core::cmp::min(self.iter.len(), self.window_size), None)
}
}

#[bench]
fn seq_iter_unhinted_choose_from_1000(b: &mut Bencher) {
let mut rng = SmallRng::from_rng(thread_rng()).unwrap();
let x: &[usize] = &[1; 1000];
b.iter(|| {
UnhintedIterator { iter: x.iter() }
.choose(&mut rng)
.unwrap()
})
}

#[bench]
fn seq_iter_window_hinted_choose_from_1000(b: &mut Bencher) {
let mut rng = SmallRng::from_rng(thread_rng()).unwrap();
let x: &[usize] = &[1; 1000];
b.iter(|| {
WindowHintedIterator {
iter: x.iter(),
window_size: 7,
}
.choose(&mut rng)
})
}

#[bench]
fn seq_iter_choose_multiple_10_of_100(b: &mut Bencher) {
let mut rng = SmallRng::from_rng(thread_rng()).unwrap();
Expand Down
111 changes: 111 additions & 0 deletions benches/seq_choose.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
// Copyright 2018-2022 Developers of the Rand project.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use rand::prelude::*;
use rand::SeedableRng;

criterion_group!(
name = benches;
config = Criterion::default();
targets = bench
);
criterion_main!(benches);

pub fn bench(c: &mut Criterion) {
bench_rng::<rand_chacha::ChaCha20Rng>(c, "ChaCha20");
bench_rng::<rand_pcg::Pcg32>(c, "Pcg32");
bench_rng::<rand_pcg::Pcg64>(c, "Pcg64");
}

fn bench_rng<Rng: RngCore + SeedableRng>(c: &mut Criterion, rng_name: &'static str) {
for length in [1, 2, 3, 10, 100, 1000].map(|x| black_box(x)) {
c.bench_function(
format!("choose_size-hinted_from_{length}_{rng_name}").as_str(),
|b| {
let mut rng = Rng::seed_from_u64(123);
b.iter(|| choose_size_hinted(length, &mut rng))
},
);

c.bench_function(
format!("choose_stable_from_{length}_{rng_name}").as_str(),
|b| {
let mut rng = Rng::seed_from_u64(123);
b.iter(|| choose_stable(length, &mut rng))
},
);

c.bench_function(
format!("choose_unhinted_from_{length}_{rng_name}").as_str(),
|b| {
let mut rng = Rng::seed_from_u64(123);
b.iter(|| choose_unhinted(length, &mut rng))
},
);

c.bench_function(
format!("choose_windowed_from_{length}_{rng_name}").as_str(),
|b| {
let mut rng = Rng::seed_from_u64(123);
b.iter(|| choose_windowed(length, 7, &mut rng))
},
);
}
}

fn choose_size_hinted<R: Rng>(max: usize, rng: &mut R) -> Option<usize> {
let iterator = 0..max;
iterator.choose(rng)
}

fn choose_stable<R: Rng>(max: usize, rng: &mut R) -> Option<usize> {
let iterator = 0..max;
iterator.choose_stable(rng)
}

fn choose_unhinted<R: Rng>(max: usize, rng: &mut R) -> Option<usize> {
let iterator = UnhintedIterator { iter: (0..max) };
iterator.choose(rng)
}

fn choose_windowed<R: Rng>(max: usize, window_size: usize, rng: &mut R) -> Option<usize> {
let iterator = WindowHintedIterator {
iter: (0..max),
window_size,
};
iterator.choose(rng)
}

#[derive(Clone)]
struct UnhintedIterator<I: Iterator + Clone> {
iter: I,
}
impl<I: Iterator + Clone> Iterator for UnhintedIterator<I> {
type Item = I::Item;

fn next(&mut self) -> Option<Self::Item> {
self.iter.next()
}
}

#[derive(Clone)]
struct WindowHintedIterator<I: ExactSizeIterator + Iterator + Clone> {
iter: I,
window_size: usize,
}
impl<I: ExactSizeIterator + Iterator + Clone> Iterator for WindowHintedIterator<I> {
type Item = I::Item;

fn next(&mut self) -> Option<Self::Item> {
self.iter.next()
}

fn size_hint(&self) -> (usize, Option<usize>) {
(core::cmp::min(self.iter.len(), self.window_size), None)
}
}
152 changes: 152 additions & 0 deletions src/seq/coin_flipper.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
use crate::RngCore;

pub(crate) struct CoinFlipper<R: RngCore> {
pub rng: R,
chunk: u32, //TODO(opt): this should depend on RNG word size
chunk_remaining: u32,
}

impl<R: RngCore> CoinFlipper<R> {
pub fn new(rng: R) -> Self {
Self {
rng,
chunk: 0,
chunk_remaining: 0,
}
}

#[inline]
/// Returns true with a probability of 1 / d
/// Uses an expected two bits of randomness
/// Panics if d == 0
pub fn gen_ratio_one_over(&mut self, d: usize) -> bool {
debug_assert_ne!(d, 0);
// This uses the same logic as `gen_ratio` but is optimized for the case that
// the starting numerator is one (which it always is for `Sequence::Choose()`)

// In this case (but not `gen_ratio`), this way of calculating c is always accurate
let c = (usize::BITS - 1 - d.leading_zeros()).min(32);

if self.flip_c_heads(c) {
let numerator = 1 << c;
return self.gen_ratio(numerator, d);
} else {
return false;
}
}

#[inline]
/// Returns true with a probability of n / d
/// Uses an expected two bits of randomness
fn gen_ratio(&mut self, mut n: usize, d: usize) -> bool {
// Explanation:
// We are trying to return true with a probability of n / d
// If n >= d, we can just return true
// Otherwise there are two possibilities 2n < d and 2n >= d
// In either case we flip a coin.
// If 2n < d
// If it comes up tails, return false
// If it comes up heads, double n and start again
// This is fair because (0.5 * 0) + (0.5 * 2n / d) = n / d and 2n is less than d
// (if 2n was greater than d we would effectively round it down to 1
// by returning true)
// If 2n >= d
// If it comes up tails, set n to 2n - d and start again
// If it comes up heads, return true
// This is fair because (0.5 * 1) + (0.5 * (2n - d) / d) = n / d
// Note that if 2n = d and the coin comes up tails, n will be set to 0
// before restarting which is equivalent to returning false.

// As a performance optimization we can flip multiple coins at once
// This is efficient because we can use the `lzcnt` intrinsic
// We can check up to 32 flips at once but we only receive one bit of information
// - all heads or at least one tail.

// Let c be the number of coins to flip. 1 <= c <= 32
// If 2n < d, n * 2^c < d
// If the result is all heads, then set n to n * 2^c
// If there was at least one tail, return false
// If 2n >= d, the order of results matters so we flip one coin at a time so c = 1
// Ideally, c will be as high as possible within these constraints

while n < d {
// Find a good value for c by counting leading zeros
// This will either give the highest possible c, or 1 less than that
let c = n
.leading_zeros()
.saturating_sub(d.leading_zeros() + 1)
.clamp(1, 32);

if self.flip_c_heads(c) {
// All heads
// Set n to n * 2^c
// If 2n >= d, the while loop will exit and we will return `true`
// If n * 2^c > `usize::MAX` we always return `true` anyway
n = n.saturating_mul(2_usize.pow(c));
} else {
//At least one tail
if c == 1 {
// Calculate 2n - d.
// We need to use wrapping as 2n might be greater than `usize::MAX`
let next_n = n.wrapping_add(n).wrapping_sub(d);
if next_n == 0 || next_n > n {
// This will happen if 2n < d
return false;
}
n = next_n;
} else {
// c > 1 so 2n < d so we can return false
return false;
}
}
}
true
}

/// If the next `c` bits of randomness all represent heads, consume them, return true
/// Otherwise return false and consume the number of heads plus one.
/// Generates new bits of randomness when necessary (in 32 bit chunks)
/// Has a 1 in 2 to the `c` chance of returning true
/// `c` must be less than or equal to 32
fn flip_c_heads(&mut self, mut c: u32) -> bool {
debug_assert!(c <= 32);
// Note that zeros on the left of the chunk represent heads.
// It needs to be this way round because zeros are filled in when left shifting
loop {
let zeros = self.chunk.leading_zeros();

if zeros < c {
// The happy path - we found a 1 and can return false
// Note that because a 1 bit was detected,
// We cannot have run out of random bits so we don't need to check

// First consume all of the bits read
// Using shl seems to give worse performance for size-hinted iterators
self.chunk = self.chunk.wrapping_shl(zeros + 1);

self.chunk_remaining = self.chunk_remaining.saturating_sub(zeros + 1);
return false;
} else {
// The number of zeros is larger than `c`
// There are two possibilities
if let Some(new_remaining) = self.chunk_remaining.checked_sub(c) {
// Those zeroes were all part of our random chunk,
// throw away `c` bits of randomness and return true
self.chunk_remaining = new_remaining;
self.chunk <<= c;
return true;
} else {
// Some of those zeroes were part of the random chunk
// and some were part of the space behind it
// We need to take into account only the zeroes that were random
c -= self.chunk_remaining;

// Generate a new chunk
self.chunk = self.rng.next_u32();
self.chunk_remaining = 32;
// Go back to start of loop
}
}
}
}
}

0 comments on commit 1e96eb4

Please sign in to comment.