Skip to content

Commit

Permalink
sample_weighted: Make sure the correct IndexVec is generated
Browse files Browse the repository at this point in the history
Also add some tests.
  • Loading branch information
vks committed Aug 27, 2020
1 parent 03bd82a commit d73eac5
Showing 1 changed file with 50 additions and 10 deletions.
60 changes: 50 additions & 10 deletions src/seq/index.rs
Expand Up @@ -280,9 +280,10 @@ where
F: Fn(usize) -> X,
X: Into<f64>,
{
if length > (::core::u32::MAX as usize) {
if length > (core::u32::MAX as usize) {
sample_efraimidis_spirakis(rng, length, weight, amount)
} else {
assert!(amount <= core::u32::MAX as usize);
let amount = amount as u32;
let length = length as u32;
sample_efraimidis_spirakis(rng, length, weight, amount)
Expand Down Expand Up @@ -310,6 +311,7 @@ where
F: Fn(usize) -> X,
X: Into<f64>,
N: UInt,
IndexVec: From<Vec<N>>,
{
if amount == N::zero() {
return Ok(IndexVec::U32(Vec::new()));
Expand Down Expand Up @@ -345,14 +347,17 @@ where
#[cfg(feature = "nightly")]
{
let mut candidates = Vec::with_capacity(length.as_usize());
for index in 0..length.as_usize() {
let weight = weight(index).into();
let mut index = N::zero();
while index < length {
let weight = weight(index.as_usize()).into();
if !(weight >= 0.) {
return Err(WeightedError::InvalidWeight);
}

let key = rng.gen::<f64>().powf(1.0 / weight);
candidates.push(Element { index, key })
candidates.push(Element { index, key });

index += N::one();
}

// Partially sort the array to find the `amount` elements with the greatest
Expand All @@ -362,7 +367,7 @@ where
let (_, mid, greater)
= candidates.partition_at_index(length.as_usize() - amount.as_usize());

let mut result = Vec::with_capacity(amount.as_usize());
let mut result: Vec<N> = Vec::with_capacity(amount.as_usize());
result.push(mid.index);
for element in greater {
result.push(element.index);
Expand All @@ -380,17 +385,20 @@ where
// Partially sort the array such that the `amount` elements with the largest
// keys are first using a binary max heap.
let mut candidates = BinaryHeap::with_capacity(length.as_usize());
for index in 0..length.as_usize() {
let weight = weight(index).into();
if weight < 0.0 || weight.is_nan() {
let mut index = N::zero();
while index < length {
let weight = weight(index.as_usize()).into();
if !(weight >= 0.) {
return Err(WeightedError::InvalidWeight);
}

let key = rng.gen::<f64>().powf(1.0 / weight);
candidates.push(Element { index, key });

index += N::one();
}

let mut result = Vec::with_capacity(amount.as_usize());
let mut result: Vec<N> = Vec::with_capacity(amount.as_usize());
while result.len() < amount.as_usize() {
result.push(candidates.pop().unwrap().index);
}
Expand Down Expand Up @@ -462,8 +470,10 @@ where R: Rng + ?Sized {
IndexVec::from(indices)
}

trait UInt: Copy + PartialOrd + Ord + PartialEq + Eq + SampleUniform + core::hash::Hash {
trait UInt: Copy + PartialOrd + Ord + PartialEq + Eq + SampleUniform
+ core::hash::Hash + core::ops::AddAssign {
fn zero() -> Self;
fn one() -> Self;
fn as_usize(self) -> usize;
}
impl UInt for u32 {
Expand All @@ -472,6 +482,11 @@ impl UInt for u32 {
0
}

#[inline]
fn one() -> Self {
1
}

#[inline]
fn as_usize(self) -> usize {
self as usize
Expand All @@ -483,6 +498,11 @@ impl UInt for usize {
0
}

#[inline]
fn one() -> Self {
1
}

#[inline]
fn as_usize(self) -> usize {
self
Expand Down Expand Up @@ -602,6 +622,26 @@ mod test {
assert_eq!(v1, v2);
}

#[test]
fn test_sample_weighted() {
let seed_rng = crate::test::rng;
for &(amount, len) in &[(0, 10), (5, 10), (10, 10)] {
let v = sample_weighted(&mut seed_rng(423), len, |i| i as f64, amount).unwrap();
match v {
IndexVec::U32(mut indices) => {
assert_eq!(indices.len(), amount);
indices.sort();
indices.dedup();
assert_eq!(indices.len(), amount);
for &i in &indices {
assert!((i as usize) < len);
}
},
IndexVec::USize(_) => panic!("expected `IndexVec::U32`"),
}
}
}

#[test]
fn value_stability_sample() {
let do_test = |length, amount, values: &[u32]| {
Expand Down

0 comments on commit d73eac5

Please sign in to comment.