Skip to content

Commit

Permalink
Handle NaN in WeightedIndex with error instead of panic
Browse files Browse the repository at this point in the history
  • Loading branch information
wschella committed Jul 31, 2020
1 parent 39a37f0 commit fc1fc9b
Showing 1 changed file with 31 additions and 5 deletions.
36 changes: 31 additions & 5 deletions src/distributions/weighted_index.rs
Expand Up @@ -102,13 +102,15 @@ impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
let mut total_weight: X = iter.next().ok_or(WeightedError::NoItem)?.borrow().clone();

let zero = <X as Default>::default();
if total_weight < zero {
if !(total_weight >= zero) {
return Err(WeightedError::InvalidWeight);
}

let mut weights = Vec::<X>::with_capacity(iter.size_hint().0);
for w in iter {
if *w.borrow() < zero {
// Note that `!(w >= x)` is not equivalent to `w < x` for partially
// ordered types due to NaNs which are equal to nothing.
if !(w.borrow() >= &zero) {
return Err(WeightedError::InvalidWeight);
}
weights.push(total_weight.clone());
Expand Down Expand Up @@ -158,7 +160,7 @@ impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
return Err(WeightedError::InvalidWeight);
}
}
if *w < zero {
if !(*w >= zero) {
return Err(WeightedError::InvalidWeight);
}
if i >= self.cumulative_weights.len() + 1 {
Expand Down Expand Up @@ -256,6 +258,30 @@ mod test {
assert_eq!(de_weighted_index.total_weight, weighted_index.total_weight);
}

#[test]
fn test_accepting_nan(){
assert_eq!(
WeightedIndex::new(&[core::f32::NAN, 0.5]).unwrap_err(),
WeightedError::InvalidWeight,
);
assert_eq!(
WeightedIndex::new(&[core::f32::NAN]).unwrap_err(),
WeightedError::InvalidWeight,
);
assert_eq!(
WeightedIndex::new(&[0.5, core::f32::NAN]).unwrap_err(),
WeightedError::InvalidWeight,
);

assert_eq!(
WeightedIndex::new(&[0.5, 7.0])
.unwrap()
.update_weights(&[(0, &core::f32::NAN)])
.unwrap_err(),
WeightedError::InvalidWeight,
)
}


#[test]
#[cfg_attr(miri, ignore)] // Miri is too slow
Expand Down Expand Up @@ -399,8 +425,8 @@ pub enum WeightedError {
/// The provided weight collection contains no items.
NoItem,

/// A weight is either less than zero, greater than the supported maximum or
/// otherwise invalid.
/// A weight is either less than zero, greater than the supported maximum,
/// NaN, or otherwise invalid.
InvalidWeight,

/// All items in the provided weight collection are zero.
Expand Down

0 comments on commit fc1fc9b

Please sign in to comment.