Skip to content

Commit

Permalink
Handle NaN in WeightedIndex with error instead of panic
Browse files Browse the repository at this point in the history
  • Loading branch information
wschella committed Jul 30, 2020
1 parent 39a37f0 commit 550a090
Showing 1 changed file with 30 additions and 7 deletions.
37 changes: 30 additions & 7 deletions src/distributions/weighted_index.rs
Expand Up @@ -11,7 +11,7 @@
use crate::distributions::uniform::{SampleBorrow, SampleUniform, UniformSampler};
use crate::distributions::Distribution;
use crate::Rng;
use core::cmp::PartialOrd;
use core::cmp::{Ordering, PartialOrd};
use core::fmt;

// Note that this whole module is only imported if feature="alloc" is enabled.
Expand Down Expand Up @@ -102,13 +102,13 @@ impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
let mut total_weight: X = iter.next().ok_or(WeightedError::NoItem)?.borrow().clone();

let zero = <X as Default>::default();
if total_weight < zero {
if let Some(Ordering::Less) | None = total_weight.partial_cmp(&zero) {
return Err(WeightedError::InvalidWeight);
}

let mut weights = Vec::<X>::with_capacity(iter.size_hint().0);
for w in iter {
if *w.borrow() < zero {
if let Some(Ordering::Less) | None = w.borrow().partial_cmp(&zero) {
return Err(WeightedError::InvalidWeight);
}
weights.push(total_weight.clone());
Expand Down Expand Up @@ -158,7 +158,7 @@ impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
return Err(WeightedError::InvalidWeight);
}
}
if *w < zero {
if let Some(Ordering::Less) | None = w.partial_cmp(&zero) {
return Err(WeightedError::InvalidWeight);
}
if i >= self.cumulative_weights.len() + 1 {
Expand Down Expand Up @@ -221,7 +221,6 @@ impl<X> Distribution<usize> for WeightedIndex<X>
where X: SampleUniform + PartialOrd
{
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
use ::core::cmp::Ordering;
let chosen_weight = self.weight_distribution.sample(rng);
// Find the first item which has a weight *higher* than the chosen weight.
self.cumulative_weights
Expand Down Expand Up @@ -256,6 +255,30 @@ mod test {
assert_eq!(de_weighted_index.total_weight, weighted_index.total_weight);
}

#[test]
fn test_accepting_nan(){
assert_eq!(
WeightedIndex::new(&[core::f32::NAN, 0.5]).unwrap_err(),
WeightedError::InvalidWeight,
);
assert_eq!(
WeightedIndex::new(&[core::f32::NAN]).unwrap_err(),
WeightedError::InvalidWeight,
);
assert_eq!(
WeightedIndex::new(&[0.5, core::f32::NAN]).unwrap_err(),
WeightedError::InvalidWeight,
);

assert_eq!(
WeightedIndex::new(&[0.5, 7.0])
.unwrap()
.update_weights(&[(0, &core::f32::NAN)])
.unwrap_err(),
WeightedError::InvalidWeight,
)
}


#[test]
#[cfg_attr(miri, ignore)] // Miri is too slow
Expand Down Expand Up @@ -399,8 +422,8 @@ pub enum WeightedError {
/// The provided weight collection contains no items.
NoItem,

/// A weight is either less than zero, greater than the supported maximum or
/// otherwise invalid.
/// A weight is either less than zero, greater than the supported maximum,
/// NaN, or otherwise invalid.
InvalidWeight,

/// All items in the provided weight collection are zero.
Expand Down

0 comments on commit 550a090

Please sign in to comment.