diff --git a/src/distributions/weighted_index.rs b/src/distributions/weighted_index.rs index a75d41eae0c..1dc003bdc0e 100644 --- a/src/distributions/weighted_index.rs +++ b/src/distributions/weighted_index.rs @@ -102,13 +102,15 @@ impl WeightedIndex { let mut total_weight: X = iter.next().ok_or(WeightedError::NoItem)?.borrow().clone(); let zero = ::default(); - if total_weight < zero { + if !(total_weight >= zero) { return Err(WeightedError::InvalidWeight); } let mut weights = Vec::::with_capacity(iter.size_hint().0); for w in iter { - if *w.borrow() < zero { + // Note that `!(w >= x)` is not equivalent to `w < x` for partially + // ordered types due to NaNs which are equal to nothing. + if !(w.borrow() >= &zero) { return Err(WeightedError::InvalidWeight); } weights.push(total_weight.clone()); @@ -158,7 +160,7 @@ impl WeightedIndex { return Err(WeightedError::InvalidWeight); } } - if *w < zero { + if !(*w >= zero) { return Err(WeightedError::InvalidWeight); } if i >= self.cumulative_weights.len() + 1 { @@ -256,6 +258,30 @@ mod test { assert_eq!(de_weighted_index.total_weight, weighted_index.total_weight); } + #[test] + fn test_accepting_nan(){ + assert_eq!( + WeightedIndex::new(&[core::f32::NAN, 0.5]).unwrap_err(), + WeightedError::InvalidWeight, + ); + assert_eq!( + WeightedIndex::new(&[core::f32::NAN]).unwrap_err(), + WeightedError::InvalidWeight, + ); + assert_eq!( + WeightedIndex::new(&[0.5, core::f32::NAN]).unwrap_err(), + WeightedError::InvalidWeight, + ); + + assert_eq!( + WeightedIndex::new(&[0.5, 7.0]) + .unwrap() + .update_weights(&[(0, &core::f32::NAN)]) + .unwrap_err(), + WeightedError::InvalidWeight, + ) + } + #[test] #[cfg_attr(miri, ignore)] // Miri is too slow @@ -399,8 +425,8 @@ pub enum WeightedError { /// The provided weight collection contains no items. NoItem, - /// A weight is either less than zero, greater than the supported maximum or - /// otherwise invalid. + /// A weight is either less than zero, greater than the supported maximum, + /// NaN, or otherwise invalid. InvalidWeight, /// All items in the provided weight collection are zero.