Skip to content

Commit

Permalink
Make seq::index::sample_rejection generic over uint index types
Browse files Browse the repository at this point in the history
  • Loading branch information
dhardy committed Jun 3, 2019
1 parent 80b5be9 commit 1c57f70
Showing 1 changed file with 28 additions and 14 deletions.
42 changes: 28 additions & 14 deletions src/seq/index.rs
Expand Up @@ -16,7 +16,7 @@
#[cfg(feature="std")] use std::collections::{HashSet};
#[cfg(all(feature="alloc", not(feature="std")))] use alloc::collections::BTreeSet;

#[cfg(feature="alloc")] use distributions::{Distribution, Uniform};
#[cfg(feature="alloc")] use distributions::{Distribution, Uniform, uniform::SampleUniform};
use Rng;

/// A vector of indices.
Expand Down Expand Up @@ -212,9 +212,7 @@ where R: Rng + ?Sized {
if (length as f32) < C[j] * (amount as f32) {
sample_inplace(rng, length, amount)
} else {
// note: could have a specific u32 impl, but I'm lazy and
// generics don't have usable conversions
sample_rejection(rng, length as usize, amount as usize)
sample_rejection(rng, length, amount)
}
}
}
Expand Down Expand Up @@ -285,28 +283,44 @@ where R: Rng + ?Sized {
IndexVec::from(indices)
}

/// Unsigned integer types that the index-sampling routines can be generic over.
///
/// An implementor supplies its zero value and a conversion to `usize`; the
/// supertraits let values be drawn from a uniform distribution and stored in
/// either a hashed or an ordered set.
// `Ord` already implies `PartialOrd`, `Eq` and `PartialEq`, so they are not
// repeated in the supertrait list.
trait UInt: Copy + Ord + core::hash::Hash + SampleUniform {
    /// Returns the value `0` of this type.
    fn zero() -> Self;

    /// Converts this value to `usize` with an `as` cast.
    fn as_usize(self) -> usize;
}

impl UInt for u32 {
    #[inline]
    fn zero() -> Self {
        0
    }

    #[inline]
    fn as_usize(self) -> usize {
        self as usize
    }
}

impl UInt for usize {
    #[inline]
    fn zero() -> Self {
        0
    }

    #[inline]
    fn as_usize(self) -> usize {
        self
    }
}

/// Randomly sample exactly `amount` indices from `0..length`, using rejection
/// sampling.
///
/// Since `amount <<< length` there is a low chance of a random sample in
/// `0..length` being a duplicate. We test for duplicates and resample where
/// necessary. The algorithm is `O(amount)` time and memory.
fn sample_rejection<R>(rng: &mut R, length: usize, amount: usize) -> IndexVec
where R: Rng + ?Sized {
///
/// This function is generic over X primarily so that results are value-stable
/// over 32-bit and 64-bit platforms.
fn sample_rejection<X: UInt, R>(rng: &mut R, length: X, amount: X) -> IndexVec
where R: Rng + ?Sized, IndexVec: From<Vec<X>> {
debug_assert!(amount < length);
#[cfg(feature="std")] let mut cache = HashSet::with_capacity(amount);
#[cfg(feature="std")] let mut cache = HashSet::with_capacity(amount.as_usize());
#[cfg(not(feature="std"))] let mut cache = BTreeSet::new();
let distr = Uniform::new(0, length);
let mut indices = Vec::with_capacity(amount);
for _ in 0..amount {
let distr = Uniform::new(X::zero(), length);
let mut indices = Vec::with_capacity(amount.as_usize());
for _ in 0..amount.as_usize() {
let mut pos = distr.sample(rng);
while !cache.insert(pos) {
pos = distr.sample(rng);
}
indices.push(pos);
}

debug_assert_eq!(indices.len(), amount);
debug_assert_eq!(indices.len(), amount.as_usize());
IndexVec::from(indices)
}

Expand All @@ -322,14 +336,14 @@ mod test {
assert_eq!(sample_inplace(&mut r, 1, 0).len(), 0);
assert_eq!(sample_inplace(&mut r, 1, 1).into_vec(), vec![0]);

assert_eq!(sample_rejection(&mut r, 1, 0).len(), 0);
assert_eq!(sample_rejection(&mut r, 1u32, 0).len(), 0);

assert_eq!(sample_floyd(&mut r, 0, 0).len(), 0);
assert_eq!(sample_floyd(&mut r, 1, 0).len(), 0);
assert_eq!(sample_floyd(&mut r, 1, 1).into_vec(), vec![0]);

// These algorithms should be fast with big numbers. Test average.
let sum: usize = sample_rejection(&mut r, 1 << 25, 10)
let sum: usize = sample_rejection(&mut r, 1 << 25, 10u32)
.into_iter().sum();
assert!(1 << 25 < sum && sum < (1 << 25) * 25);

Expand Down Expand Up @@ -368,7 +382,7 @@ mod test {
// A large length and larger amount should use cache
let (length, amount): (usize, usize) = (1<<20, 600);
let v1 = sample(&mut seed_rng(422), length, amount);
let v2 = sample_rejection(&mut seed_rng(422), length, amount);
let v2 = sample_rejection(&mut seed_rng(422), length as u32, amount as u32);
assert!(v1.iter().all(|e| e < length));
assert_eq!(v1, v2);
}
Expand Down

0 comments on commit 1c57f70

Please sign in to comment.