diff --git a/src/distributions/weighted/alias_method.rs b/src/distributions/weighted/alias_method.rs index fef4b619b0d..2e58d84ceca 100644 --- a/src/distributions/weighted/alias_method.rs +++ b/src/distributions/weighted/alias_method.rs @@ -25,10 +25,10 @@ use Rng; /// Given that `n` is the number of items in the vector used to create an /// [`WeightedIndex`], [`WeightedIndex`] will require `O(n)` amount of /// memory. More specifically it takes up some constant amount of memory plus -/// the vector used to create it and a [`Vec`] with capacity `n`. +/// the vector used to create it and a [`Vec`] with capacity `n`. /// /// Time complexity for the creation of a [`WeightedIndex`] is `O(n)`. -/// Sampling is `O(1)`, it makes a call to [`Uniform::sample`] and a call +/// Sampling is `O(1)`, it makes a call to [`Uniform::sample`] and a call /// to [`Uniform::sample`]. /// /// # Example @@ -56,13 +56,13 @@ use Rng; /// /// [`WeightedIndex`]: crate::distributions::weighted::alias_method::WeightedIndex /// [`Weight`]: crate::distributions::weighted::alias_method::Weight -/// [`Vec`]: Vec -/// [`Uniform::sample`]: Distribution::sample +/// [`Vec`]: Vec +/// [`Uniform::sample`]: Distribution::sample /// [`Uniform::sample`]: Distribution::sample pub struct WeightedIndex { - aliases: Vec, + aliases: Vec, no_alias_odds: Vec, - uniform_index: Uniform, + uniform_index: Uniform, uniform_within_weight_sum: Uniform, } @@ -71,6 +71,7 @@ impl WeightedIndex { /// /// Returns an error if: /// - The vector is empty. + /// - The vector is longer than `u32::MAX`. /// - For any weight `w`: `w < 0` or `w > max` where `max = W::MAX / /// weights.len()`. /// - The sum of weights is zero. @@ -78,9 +79,12 @@ impl WeightedIndex { let n = weights.len(); if n == 0 { return Err(WeightedError::NoItem); + } else if n > ::core::u32::MAX as usize { + return Err(WeightedError::TooMany); } + let n = n as u32; - let max_weight_size = W::try_from_usize_lossy(n) + let max_weight_size = W::try_from_u32_lossy(n) .map(|n| W::MAX / n) .unwrap_or(W::ZERO); if !weights @@ -103,7 +107,7 @@ impl WeightedIndex { } // `weight_sum` would have been zero if `try_from_lossy` causes an error here. - let n_converted = W::try_from_usize_lossy(n).unwrap(); + let n_converted = W::try_from_u32_lossy(n).unwrap(); let mut no_alias_odds = weights; for odds in no_alias_odds.iter_mut() { @@ -119,52 +123,52 @@ impl WeightedIndex { /// be ensured that a single index is only ever in one of them at the /// same time. struct Aliases { - aliases: Vec, - smalls_head: usize, - bigs_head: usize, + aliases: Vec, + smalls_head: u32, + bigs_head: u32, } impl Aliases { - fn new(size: usize) -> Self { + fn new(size: u32) -> Self { Aliases { - aliases: vec![0; size], - smalls_head: ::core::usize::MAX, - bigs_head: ::core::usize::MAX, + aliases: vec![0; size as usize], + smalls_head: ::core::u32::MAX, + bigs_head: ::core::u32::MAX, } } - fn push_small(&mut self, idx: usize) { - self.aliases[idx] = self.smalls_head; + fn push_small(&mut self, idx: u32) { + self.aliases[idx as usize] = self.smalls_head; self.smalls_head = idx; } - fn push_big(&mut self, idx: usize) { - self.aliases[idx] = self.bigs_head; + fn push_big(&mut self, idx: u32) { + self.aliases[idx as usize] = self.bigs_head; self.bigs_head = idx; } - fn pop_small(&mut self) -> usize { + fn pop_small(&mut self) -> u32 { let popped = self.smalls_head; - self.smalls_head = self.aliases[popped]; + self.smalls_head = self.aliases[popped as usize]; popped } - fn pop_big(&mut self) -> usize { + fn pop_big(&mut self) -> u32 { let popped = self.bigs_head; - self.bigs_head = self.aliases[popped]; + self.bigs_head = self.aliases[popped as usize]; popped } fn smalls_is_empty(&self) -> bool { - self.smalls_head == ::core::usize::MAX + self.smalls_head == ::core::u32::MAX } fn bigs_is_empty(&self) -> bool { - self.bigs_head == ::core::usize::MAX + self.bigs_head == ::core::u32::MAX } - fn set_alias(&mut self, idx: usize, alias: usize) { - self.aliases[idx] = alias; + fn set_alias(&mut self, idx: u32, alias: u32) { + self.aliases[idx as usize] = alias; } } @@ -173,9 +177,9 @@ impl WeightedIndex { // Split indices into those with small weights and those with big weights. for (index, &odds) in no_alias_odds.iter().enumerate() { if odds < weight_sum { - aliases.push_small(index); + aliases.push_small(index as u32); } else { - aliases.push_big(index); + aliases.push_big(index as u32); } } @@ -186,9 +190,11 @@ impl WeightedIndex { let b = aliases.pop_big(); aliases.set_alias(s, b); - no_alias_odds[b] = no_alias_odds[b] - weight_sum + no_alias_odds[s]; + no_alias_odds[b as usize] = no_alias_odds[b as usize] + - weight_sum + + no_alias_odds[s as usize]; - if no_alias_odds[b] < weight_sum { + if no_alias_odds[b as usize] < weight_sum { aliases.push_small(b); } else { aliases.push_big(b); @@ -198,10 +204,10 @@ impl WeightedIndex { // The remaining indices should have no alias odds of about 100%. This is due to // numeric accuracy. Otherwise they would be exactly 100%. while !aliases.smalls_is_empty() { - no_alias_odds[aliases.pop_small()] = weight_sum; + no_alias_odds[aliases.pop_small() as usize] = weight_sum; } while !aliases.bigs_is_empty() { - no_alias_odds[aliases.pop_big()] = weight_sum; + no_alias_odds[aliases.pop_big() as usize] = weight_sum; } // Prepare distributions for sampling. Creating them beforehand improves @@ -221,10 +227,10 @@ impl WeightedIndex { impl Distribution for WeightedIndex { fn sample(&self, rng: &mut R) -> usize { let candidate = rng.sample(self.uniform_index); - if rng.sample(&self.uniform_within_weight_sum) < self.no_alias_odds[candidate] { - candidate + if rng.sample(&self.uniform_within_weight_sum) < self.no_alias_odds[candidate as usize] { + candidate as usize } else { - self.aliases[candidate] + self.aliases[candidate as usize] as usize } } } @@ -282,10 +288,10 @@ pub trait Weight: /// Element of `Self` equivalent to 0. const ZERO: Self; - /// Produce an instance of `Self` from a `usize` value, or return `None` if + /// Produce an instance of `Self` from a `u32` value, or return `None` if /// out of range. Loss of precision (where `Self` is a floating point type) /// is acceptable. - fn try_from_usize_lossy(n: usize) -> Option; + fn try_from_u32_lossy(n: u32) -> Option; /// Sums all values in slice `values`. fn sum(values: &[Self]) -> Self { @@ -299,7 +305,7 @@ macro_rules! impl_weight_for_float { const MAX: Self = ::core::$T::MAX; const ZERO: Self = 0.0; - fn try_from_usize_lossy(n: usize) -> Option { + fn try_from_u32_lossy(n: u32) -> Option { Some(n as $T) } @@ -328,9 +334,9 @@ macro_rules! impl_weight_for_int { const MAX: Self = ::core::$T::MAX; const ZERO: Self = 0; - fn try_from_usize_lossy(n: usize) -> Option { + fn try_from_u32_lossy(n: u32) -> Option { let n_converted = n as Self; - if n_converted >= Self::ZERO && n_converted as usize == n { + if n_converted >= Self::ZERO && n_converted as u32 == n { Some(n_converted) } else { None @@ -439,21 +445,21 @@ mod test { where WeightedIndex: fmt::Debug, { - const NUM_WEIGHTS: usize = 10; - const ZERO_WEIGHT_INDEX: usize = 3; + const NUM_WEIGHTS: u32 = 10; + const ZERO_WEIGHT_INDEX: u32 = 3; const NUM_SAMPLES: u32 = 15000; let mut rng = ::test::rng(0x9c9fa0b0580a7031); let weights = { - let mut weights = Vec::with_capacity(NUM_WEIGHTS); + let mut weights = Vec::with_capacity(NUM_WEIGHTS as usize); let random_weight_distribution = ::distributions::Uniform::new_inclusive( W::ZERO, - W::MAX / W::try_from_usize_lossy(NUM_WEIGHTS).unwrap(), + W::MAX / W::try_from_u32_lossy(NUM_WEIGHTS).unwrap(), ); for _ in 0..NUM_WEIGHTS { weights.push(rng.sample(&random_weight_distribution)); } - weights[ZERO_WEIGHT_INDEX] = W::ZERO; + weights[ZERO_WEIGHT_INDEX as usize] = W::ZERO; weights }; let weight_sum = weights.iter().map(|w| *w).sum::(); @@ -463,12 +469,12 @@ mod test { .collect::>(); let weight_distribution = WeightedIndex::new(weights).unwrap(); - let mut counts = vec![0_usize; NUM_WEIGHTS]; + let mut counts = vec![0; NUM_WEIGHTS as usize]; for _ in 0..NUM_SAMPLES { counts[rng.sample(&weight_distribution)] += 1; } - assert_eq!(counts[ZERO_WEIGHT_INDEX], 0); + assert_eq!(counts[ZERO_WEIGHT_INDEX as usize], 0); for (count, expected_count) in counts.into_iter().zip(expected_counts) { let difference = (count as f64 - expected_count).abs(); let max_allowed_difference = NUM_SAMPLES as f64 / NUM_WEIGHTS as f64 * 0.1; diff --git a/src/distributions/weighted/mod.rs b/src/distributions/weighted/mod.rs index df388e70aab..6cd66d1bc15 100644 --- a/src/distributions/weighted/mod.rs +++ b/src/distributions/weighted/mod.rs @@ -208,6 +208,9 @@ pub enum WeightedError { /// All items in the provided weight collection are zero. AllWeightsZero, + + /// Too many weights are provided (length greater than `u32::MAX`) + TooMany, } impl WeightedError { @@ -216,6 +219,7 @@ impl WeightedError { WeightedError::NoItem => "No weights provided.", WeightedError::InvalidWeight => "A weight is invalid.", WeightedError::AllWeightsZero => "All weights are zero.", + WeightedError::TooMany => "Too many weights (hit u32::MAX)", } } }