From f58091721b762688082c0bf593af1eb0f1019cfe Mon Sep 17 00:00:00 2001 From: Diggory Hardy Date: Sat, 1 Jun 2019 12:54:31 +0100 Subject: [PATCH] Make seq::index::sample_rejection generic over uint index types --- src/seq/index.rs | 42 ++++++++++++++++++++++++++++-------------- 1 file changed, 28 insertions(+), 14 deletions(-) diff --git a/src/seq/index.rs b/src/seq/index.rs index 79ed6c0ec7e..b6fc81e1bb1 100644 --- a/src/seq/index.rs +++ b/src/seq/index.rs @@ -16,7 +16,7 @@ #[cfg(feature="std")] use std::collections::{HashSet}; #[cfg(all(feature="alloc", not(feature="std")))] use alloc::collections::BTreeSet; -#[cfg(feature="alloc")] use distributions::{Distribution, Uniform}; +#[cfg(feature="alloc")] use distributions::{Distribution, Uniform, uniform::SampleUniform}; use Rng; /// A vector of indices. @@ -212,9 +212,7 @@ where R: Rng + ?Sized { if (length as f32) < C[j] * (amount as f32) { sample_inplace(rng, length, amount) } else { - // note: could have a specific u32 impl, but I'm lazy and - // generics don't have usable conversions - sample_rejection(rng, length as usize, amount as usize) + sample_rejection(rng, length, amount) } } } @@ -285,20 +283,36 @@ where R: Rng + ?Sized { IndexVec::from(indices) } +trait UInt: Copy + PartialOrd + Ord + PartialEq + Eq + SampleUniform + core::hash::Hash { + fn zero() -> Self; + fn as_usize(self) -> usize; +} +impl UInt for u32 { + #[inline] fn zero() -> Self { 0 } + #[inline] fn as_usize(self) -> usize { self as usize } +} +impl UInt for usize { + #[inline] fn zero() -> Self { 0 } + #[inline] fn as_usize(self) -> usize { self } +} + /// Randomly sample exactly `amount` indices from `0..length`, using rejection /// sampling. /// /// Since `amount <<< length` there is a low chance of a random sample in /// `0..length` being a duplicate. We test for duplicates and resample where /// necessary. The algorithm is `O(amount)` time and memory. -fn sample_rejection(rng: &mut R, length: usize, amount: usize) -> IndexVec -where R: Rng + ?Sized { +/// +/// This function is generic over X primarily so that results are value-stable +/// over 32-bit and 64-bit platforms. +fn sample_rejection(rng: &mut R, length: X, amount: X) -> IndexVec +where R: Rng + ?Sized, IndexVec: From> { debug_assert!(amount < length); - #[cfg(feature="std")] let mut cache = HashSet::with_capacity(amount); + #[cfg(feature="std")] let mut cache = HashSet::with_capacity(amount.as_usize()); #[cfg(not(feature="std"))] let mut cache = BTreeSet::new(); - let distr = Uniform::new(0, length); - let mut indices = Vec::with_capacity(amount); - for _ in 0..amount { + let distr = Uniform::new(X::zero(), length); + let mut indices = Vec::with_capacity(amount.as_usize()); + for _ in 0..amount.as_usize() { let mut pos = distr.sample(rng); while !cache.insert(pos) { pos = distr.sample(rng); @@ -306,7 +320,7 @@ where R: Rng + ?Sized { indices.push(pos); } - debug_assert_eq!(indices.len(), amount); + debug_assert_eq!(indices.len(), amount.as_usize()); IndexVec::from(indices) } @@ -322,14 +336,14 @@ mod test { assert_eq!(sample_inplace(&mut r, 1, 0).len(), 0); assert_eq!(sample_inplace(&mut r, 1, 1).into_vec(), vec![0]); - assert_eq!(sample_rejection(&mut r, 1, 0).len(), 0); + assert_eq!(sample_rejection(&mut r, 1u32, 0).len(), 0); assert_eq!(sample_floyd(&mut r, 0, 0).len(), 0); assert_eq!(sample_floyd(&mut r, 1, 0).len(), 0); assert_eq!(sample_floyd(&mut r, 1, 1).into_vec(), vec![0]); // These algorithms should be fast with big numbers. Test average. - let sum: usize = sample_rejection(&mut r, 1 << 25, 10) + let sum: usize = sample_rejection(&mut r, 1 << 25, 10u32) .into_iter().sum(); assert!(1 << 25 < sum && sum < (1 << 25) * 25); @@ -368,7 +382,7 @@ mod test { // A large length and larger amount should use cache let (length, amount): (usize, usize) = (1<<20, 600); let v1 = sample(&mut seed_rng(422), length, amount); - let v2 = sample_rejection(&mut seed_rng(422), length, amount); + let v2 = sample_rejection(&mut seed_rng(422), length as u32, amount as u32); assert!(v1.iter().all(|e| e < length)); assert_eq!(v1, v2); }