From 73d86e5ed40e8067dad44eccb131c98383322bd8 Mon Sep 17 00:00:00 2001 From: zroug <37004975+zroug@users.noreply.github.com> Date: Thu, 24 Jan 2019 14:41:44 +0100 Subject: [PATCH] AliasMethodWeightedIndex: Make implementation details more generic --- src/distributions/weighted.rs | 67 +++++++++++++++++++---------------- 1 file changed, 37 insertions(+), 30 deletions(-) diff --git a/src/distributions/weighted.rs b/src/distributions/weighted.rs index 06637a7e987..96c2fcc747d 100644 --- a/src/distributions/weighted.rs +++ b/src/distributions/weighted.rs @@ -136,32 +136,35 @@ pub struct AliasMethodWeightedIndex { aliases: Vec, no_alias_odds: Vec, uniform_index: super::Uniform, + uniform_within_weight_sum: super::Uniform, } impl AliasMethodWeightedIndex { #[allow(missing_docs)] // todo: add docs pub fn new(weights: Vec) -> Result { - if weights.is_empty() { + let n = weights.len(); + if n == 0 { return Err(AliasMethodWeightedIndexError::NoItem); } - if !weights.iter().all(|&w| w >= 0.0) { + + let max_weight_size = ::core::f64::MAX / n as f64; + if !weights.iter().all(|&w| 0_f64 <= w && w <= max_weight_size) { return Err(AliasMethodWeightedIndexError::InvalidWeight); } - let n = weights.len(); + // The sum of weights will represent 100% of no alias odds. let weight_sum = pairwise_sum_f64(weights.as_slice()); - if weight_sum.is_infinite() { - return Err(AliasMethodWeightedIndexError::WeightSumToBig); - } - - let weight_scale = n as f64 / weight_sum; - if weight_scale.is_infinite() { - return Err(AliasMethodWeightedIndexError::WeightSumToSmall); + // Prevent floating point overflow due to rounding errors. + let weight_sum = weight_sum.min(::core::f64::MAX); + if weight_sum == 0_f64 { + return Err(AliasMethodWeightedIndexError::AllWeightsZero); } let mut no_alias_odds = weights; for odds in no_alias_odds.iter_mut() { - *odds *= weight_scale; + *odds *= n as f64; + // Prevent floating point overflow due to rounding errors. + *odds = odds.min(::core::f64::MAX); } /// This struct is designed to contain three data structures at once, @@ -224,7 +227,7 @@ impl AliasMethodWeightedIndex { // Split indices into those with small weights and those with big weights. for (index, &odds) in no_alias_odds.iter().enumerate() { - if odds < 1.0 { + if odds < weight_sum { aliases.push_small(index); } else { aliases.push_big(index); @@ -238,32 +241,34 @@ impl AliasMethodWeightedIndex { let b = aliases.pop_big(); aliases.set_alias(s, b); - no_alias_odds[b] = no_alias_odds[b] - 1.0 + no_alias_odds[s]; + no_alias_odds[b] = no_alias_odds[b] - weight_sum + no_alias_odds[s]; - if no_alias_odds[b] < 1.0 { + if no_alias_odds[b] < weight_sum { aliases.push_small(b); } else { aliases.push_big(b); } } - // The remaining indices should have no alias odds of about 1. This is due to - // numeric accuracy. Otherwise they would be exactly 1. + // The remaining indices should have no alias odds of about 100%. This is due to + // numeric accuracy. Otherwise they would be exactly 100%. while !aliases.smalls_is_empty() { - no_alias_odds[aliases.pop_small()] = 1.0; + no_alias_odds[aliases.pop_small()] = weight_sum; } while !aliases.bigs_is_empty() { - no_alias_odds[aliases.pop_big()] = 1.0; + no_alias_odds[aliases.pop_big()] = weight_sum; } - // Prepare a distribution to sample random indices. Creating it beforehand - // improves sampling performance. + // Prepare distributions for sampling. Creating them beforehand improves + // sampling performance. let uniform_index = super::Uniform::new(0, n); + let uniform_within_weight_sum = super::Uniform::new(0_f64, weight_sum); Ok(Self { aliases: aliases.aliases, no_alias_odds, uniform_index, + uniform_within_weight_sum, }) } } @@ -271,7 +276,7 @@ impl AliasMethodWeightedIndex { impl Distribution for AliasMethodWeightedIndex { fn sample(&self, rng: &mut R) -> usize { let candidate = rng.sample(self.uniform_index); - if rng.sample::(super::Standard) < self.no_alias_odds[candidate] { + if rng.sample(self.uniform_within_weight_sum) < self.no_alias_odds[candidate] { candidate } else { self.aliases[candidate] @@ -357,8 +362,12 @@ mod test { let weights = { let mut weights = Vec::with_capacity(NUM_WEIGHTS); + let random_weight_distribution = ::distributions::Uniform::new_inclusive( + 0_f64, + ::core::f64::MAX / NUM_WEIGHTS as f64, + ); for _ in 0..NUM_WEIGHTS { - weights.push(rng.sample::(::distributions::Standard)); + weights.push(rng.sample(random_weight_distribution)); } weights[ZERO_WEIGHT_INDEX] = 0.0; weights @@ -388,19 +397,19 @@ mod test { ); assert_eq!( AliasMethodWeightedIndex::new(vec![0.0]).unwrap_err(), - AliasMethodWeightedIndexError::WeightSumToSmall + AliasMethodWeightedIndexError::AllWeightsZero ); assert_eq!( AliasMethodWeightedIndex::new(vec![-0.0]).unwrap_err(), - AliasMethodWeightedIndexError::WeightSumToSmall + AliasMethodWeightedIndexError::AllWeightsZero ); assert_eq!( AliasMethodWeightedIndex::new(vec![::core::f64::INFINITY]).unwrap_err(), - AliasMethodWeightedIndexError::WeightSumToBig + AliasMethodWeightedIndexError::InvalidWeight ); assert_eq!( AliasMethodWeightedIndex::new(vec![::core::f64::MAX, ::core::f64::MAX]).unwrap_err(), - AliasMethodWeightedIndexError::WeightSumToBig + AliasMethodWeightedIndexError::InvalidWeight ); assert_eq!( AliasMethodWeightedIndex::new(vec![-1.0]).unwrap_err(), @@ -461,8 +470,7 @@ impl fmt::Display for WeightedError { pub enum AliasMethodWeightedIndexError { NoItem, InvalidWeight, - WeightSumToSmall, - WeightSumToBig, + AllWeightsZero, } impl AliasMethodWeightedIndexError { @@ -470,8 +478,7 @@ impl AliasMethodWeightedIndexError { match *self { AliasMethodWeightedIndexError::NoItem => "No items found.", AliasMethodWeightedIndexError::InvalidWeight => "An item has an invalid weight.", - AliasMethodWeightedIndexError::WeightSumToSmall => "The sum of weights is to small.", - AliasMethodWeightedIndexError::WeightSumToBig => "The sum of weights is to big.", + AliasMethodWeightedIndexError::AllWeightsZero => "All weights are zero.", } } }