diff --git a/src/distributions/mod.rs b/src/distributions/mod.rs index 2684ae53f64..77f9cd5f730 100644 --- a/src/distributions/mod.rs +++ b/src/distributions/mod.rs @@ -177,7 +177,7 @@ use Rng; #[doc(inline)] pub use self::uniform::Uniform; #[doc(inline)] pub use self::float::{OpenClosed01, Open01}; #[cfg(feature="alloc")] -#[doc(inline)] pub use self::weighted::WeightedIndex; +#[doc(inline)] pub use self::weighted::{WeightedIndex, WeightedError}; #[cfg(feature="std")] #[doc(inline)] pub use self::gamma::{Gamma, ChiSquared, FisherF, StudentT}; #[cfg(feature="std")] diff --git a/src/distributions/weighted.rs b/src/distributions/weighted.rs index 749dd0e9343..64d862987c3 100644 --- a/src/distributions/weighted.rs +++ b/src/distributions/weighted.rs @@ -12,7 +12,7 @@ use Rng; use distributions::Distribution; use distributions::uniform::{UniformSampler, SampleUniform, SampleBorrow}; use ::core::cmp::PartialOrd; -use ::{Error, ErrorKind}; +use core::fmt; // Note that this whole module is only imported if feature="alloc" is enabled. #[cfg(not(feature="std"))] use alloc::vec::Vec; @@ -63,7 +63,7 @@ impl WeightedIndex { /// /// [`Distribution`]: trait.Distribution.html /// [`Uniform`]: struct.Uniform.html - pub fn new(weights: I) -> Result, Error> + pub fn new(weights: I) -> Result, WeightedError> where I: IntoIterator, I::Item: SampleBorrow, X: for<'a> ::core::ops::AddAssign<&'a X> + @@ -71,26 +71,26 @@ impl WeightedIndex { Default { let mut iter = weights.into_iter(); let mut total_weight: X = iter.next() - .ok_or(Error::new(ErrorKind::Unexpected, "Empty iterator in WeightedIndex::new"))? + .ok_or(WeightedError::NoItem)? .borrow() .clone(); let zero = ::default(); if total_weight < zero { - return Err(Error::new(ErrorKind::Unexpected, "Negative weight in WeightedIndex::new")); + return Err(WeightedError::NegativeWeight); } let mut weights = Vec::::with_capacity(iter.size_hint().0); for w in iter { if *w.borrow() < zero { - return Err(Error::new(ErrorKind::Unexpected, "Negative weight in WeightedIndex::new")); + return Err(WeightedError::NegativeWeight); } weights.push(total_weight.clone()); total_weight += w.borrow(); } if total_weight == zero { - return Err(Error::new(ErrorKind::Unexpected, "Total weight is zero in WeightedIndex::new")); + return Err(WeightedError::AllWeightsZero); } let distr = X::Sampler::new(zero, total_weight); @@ -161,10 +161,43 @@ mod test { assert_eq!(WeightedIndex::new(&[0, 0, 0, 0, 10, 0]).unwrap().sample(&mut r), 4); } - assert!(WeightedIndex::new(&[10][0..0]).is_err()); - assert!(WeightedIndex::new(&[0]).is_err()); - assert!(WeightedIndex::new(&[10, 20, -1, 30]).is_err()); - assert!(WeightedIndex::new(&[-10, 20, 1, 30]).is_err()); - assert!(WeightedIndex::new(&[-10]).is_err()); + assert_eq!(WeightedIndex::new(&[10][0..0]).unwrap_err(), WeightedError::NoItem); + assert_eq!(WeightedIndex::new(&[0]).unwrap_err(), WeightedError::AllWeightsZero); + assert_eq!(WeightedIndex::new(&[10, 20, -1, 30]).unwrap_err(), WeightedError::NegativeWeight); + assert_eq!(WeightedIndex::new(&[-10, 20, 1, 30]).unwrap_err(), WeightedError::NegativeWeight); + assert_eq!(WeightedIndex::new(&[-10]).unwrap_err(), WeightedError::NegativeWeight); + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum WeightedError { + NoItem, + NegativeWeight, + AllWeightsZero, +} + +impl WeightedError { + fn msg(&self) -> &str { + match *self { + WeightedError::NoItem => "No items found", + WeightedError::NegativeWeight => "Item has negative weight", + WeightedError::AllWeightsZero => "All items had weight zero", + } + } +} + +#[cfg(feature="std")] +impl ::std::error::Error for WeightedError { + fn description(&self) -> &str { + self.msg() + } + fn cause(&self) -> Option<&::std::error::Error> { + None + } +} + +impl fmt::Display for WeightedError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.msg()) } } diff --git a/src/seq.rs b/src/seq.rs index 0895b1afcc1..33d98061ada 100644 --- a/src/seq.rs +++ b/src/seq.rs @@ -21,6 +21,8 @@ #[cfg(feature="std")] use std::collections::HashMap; #[cfg(all(feature="alloc", not(feature="std")))] use alloc::collections::BTreeMap; +#[cfg(feature = "alloc")] use distributions::WeightedError; + use super::Rng; #[cfg(feature="alloc")] use distributions::uniform::{SampleUniform, SampleBorrow}; @@ -109,7 +111,7 @@ pub trait SliceRandom { /// ``` /// [`choose`]: trait.SliceRandom.html#method.choose #[cfg(feature = "alloc")] - fn choose_weighted(&self, rng: &mut R, weight: F) -> Option<&Self::Item> + fn choose_weighted(&self, rng: &mut R, weight: F) -> Result<&Self::Item, WeightedError> where R: Rng + ?Sized, F: Fn(&Self::Item) -> B, B: SampleBorrow, @@ -129,7 +131,7 @@ pub trait SliceRandom { /// [`choose_mut`]: trait.SliceRandom.html#method.choose_mut /// [`choose_weighted`]: trait.SliceRandom.html#method.choose_weighted #[cfg(feature = "alloc")] - fn choose_weighted_mut(&mut self, rng: &mut R, weight: F) -> Option<&mut Self::Item> + fn choose_weighted_mut(&mut self, rng: &mut R, weight: F) -> Result<&mut Self::Item, WeightedError> where R: Rng + ?Sized, F: Fn(&Self::Item) -> B, B: SampleBorrow, @@ -327,7 +329,7 @@ impl SliceRandom for [T] { } #[cfg(feature = "alloc")] - fn choose_weighted(&self, rng: &mut R, weight: F) -> Option<&Self::Item> + fn choose_weighted(&self, rng: &mut R, weight: F) -> Result<&Self::Item, WeightedError> where R: Rng + ?Sized, F: Fn(&Self::Item) -> B, B: SampleBorrow, @@ -337,12 +339,12 @@ impl SliceRandom for [T] { Clone + Default { use distributions::{Distribution, WeightedIndex}; - WeightedIndex::new(self.iter().map(weight)).ok() - .map(|distr| &self[distr.sample(rng)]) + let distr = WeightedIndex::new(self.iter().map(weight))?; + Ok(&self[distr.sample(rng)]) } #[cfg(feature = "alloc")] - fn choose_weighted_mut(&mut self, rng: &mut R, weight: F) -> Option<&mut Self::Item> + fn choose_weighted_mut(&mut self, rng: &mut R, weight: F) -> Result<&mut Self::Item, WeightedError> where R: Rng + ?Sized, F: Fn(&Self::Item) -> B, B: SampleBorrow, @@ -352,9 +354,8 @@ impl SliceRandom for [T] { Clone + Default { use distributions::{Distribution, WeightedIndex}; - WeightedIndex::new(self.iter().map(weight)).ok() - .map(|distr| distr.sample(rng)) - .map(move |ix| &mut self[ix]) + let distr = WeightedIndex::new(self.iter().map(weight))?; + Ok(&mut self[distr.sample(rng)]) } fn shuffle(&mut self, rng: &mut R) where R: Rng + ?Sized @@ -868,8 +869,10 @@ mod test { // Check error cases let empty_slice = &mut [10][0..0]; - assert_eq!(empty_slice.choose_weighted(&mut r, |_| 1), None); - assert_eq!(empty_slice.choose_weighted_mut(&mut r, |_| 1), None); - assert_eq!(['x'].choose_weighted_mut(&mut r, |_| 0), None); + assert_eq!(empty_slice.choose_weighted(&mut r, |_| 1), Err(WeightedError::NoItem)); + assert_eq!(empty_slice.choose_weighted_mut(&mut r, |_| 1), Err(WeightedError::NoItem)); + assert_eq!(['x'].choose_weighted_mut(&mut r, |_| 0), Err(WeightedError::AllWeightsZero)); + assert_eq!([0, -1].choose_weighted_mut(&mut r, |x| *x), Err(WeightedError::NegativeWeight)); + assert_eq!([-1, 0].choose_weighted_mut(&mut r, |x| *x), Err(WeightedError::NegativeWeight)); } }