diff --git a/benches/distributions.rs b/benches/distributions.rs
index 069a82856a5..ccf78b0310d 100644
--- a/benches/distributions.rs
+++ b/benches/distributions.rs
@@ -187,6 +187,11 @@ distr_int!(distr_weighted_u32, usize, WeightedIndex::new(&[1u32, 2, 3, 4, 12, 0,
 distr_int!(distr_weighted_f64, usize, WeightedIndex::new(&[1.0f64, 0.001, 1.0/3.0, 4.01, 0.0, 3.3, 22.0, 0.001]).unwrap());
 distr_int!(distr_weighted_large_set, usize, WeightedIndex::new((0..10000).rev().chain(1..10001)).unwrap());
 
+distr_int!(distr_weighted_alias_method_i8, usize, weighted::alias_method::WeightedIndex::new(vec![1i8, 2, 3, 4, 12, 0, 2, 1]).unwrap());
+distr_int!(distr_weighted_alias_method_u32, usize, weighted::alias_method::WeightedIndex::new(vec![1u32, 2, 3, 4, 12, 0, 2, 1]).unwrap());
+distr_int!(distr_weighted_alias_method_f64, usize, weighted::alias_method::WeightedIndex::new(vec![1.0f64, 0.001, 1.0/3.0, 4.01, 0.0, 3.3, 22.0, 0.001]).unwrap());
+distr_int!(distr_weighted_alias_method_large_set, usize, weighted::alias_method::WeightedIndex::new((0..10000).rev().chain(1..10001).collect()).unwrap());
+
 // construct and sample from a range
 macro_rules! gen_range_int {
     ($fnn:ident, $ty:ident, $low:expr, $high:expr) => {
diff --git a/src/distributions/mod.rs b/src/distributions/mod.rs
index 6e2d6c7bad2..c8ef0ef96d8 100644
--- a/src/distributions/mod.rs
+++ b/src/distributions/mod.rs
@@ -198,7 +198,7 @@ pub use self::bernoulli::Bernoulli;
 
 pub mod uniform;
 mod bernoulli;
-#[cfg(feature="alloc")] mod weighted;
+#[cfg(feature="alloc")] pub mod weighted;
 #[cfg(feature="std")] mod unit_sphere;
 #[cfg(feature="std")] mod unit_circle;
 #[cfg(feature="std")] mod gamma;
diff --git a/src/distributions/weighted/alias_method.rs b/src/distributions/weighted/alias_method.rs
new file mode 100644
index 00000000000..9fdba92ec77
--- /dev/null
+++ b/src/distributions/weighted/alias_method.rs
@@ -0,0 +1,484 @@
+//! This module contains an implementation of alias method for sampling random
+//! indices with probabilities proportional to a collection of weights.
+
+use super::WeightedError;
+#[cfg(not(feature = "std"))]
+use alloc::vec::Vec;
+use core::fmt;
+use core::iter::Sum;
+use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Sub, SubAssign};
+use distributions::uniform::SampleUniform;
+use distributions::Distribution;
+use distributions::Uniform;
+use Rng;
+
+/// A distribution using weighted sampling to pick a discretely selected item.
+///
+/// Sampling a [`WeightedIndex<W>`] distribution returns the index of a randomly
+/// selected element from the vector used to create the [`WeightedIndex<W>`].
+/// The chance of a given element being picked is proportional to the value of
+/// the element. The weights can have any type `W` for which an implementation of
+/// [`Weight`] exists.
+///
+/// # Performance
+///
+/// Given that `n` is the number of items in the vector used to create a
+/// [`WeightedIndex<W>`], [`WeightedIndex<W>`] will require `O(n)` amount of
+/// memory. More specifically it takes up some constant amount of memory plus
+/// the vector used to create it and a [`Vec<usize>`] with capacity `n`.
+///
+/// Time complexity for the creation of a [`WeightedIndex<W>`] is `O(n)`.
+/// Sampling is `O(1)`, it makes a call to [`Uniform<usize>::sample`] and a call
+/// to [`Uniform<W>::sample`].
+///
+/// # Example
+///
+/// ```
+/// use rand::distributions::weighted::alias_method::WeightedIndex;
+/// use rand::prelude::*;
+///
+/// let choices = vec!['a', 'b', 'c'];
+/// let weights = vec![2, 1, 1];
+/// let dist = WeightedIndex::new(weights).unwrap();
+/// let mut rng = thread_rng();
+/// for _ in 0..100 {
+///     // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c'
+///     println!("{}", choices[dist.sample(&mut rng)]);
+/// }
+///
+/// let items = [('a', 0), ('b', 3), ('c', 7)];
+/// let dist2 = WeightedIndex::new(items.iter().map(|item| item.1).collect()).unwrap();
+/// for _ in 0..100 {
+///     // 0% chance to print 'a', 30% chance to print 'b', 70% chance to print 'c'
+///     println!("{}", items[dist2.sample(&mut rng)].0);
+/// }
+/// ```
+///
+/// [`WeightedIndex<W>`]: WeightedIndex
+/// [`Weight`]: Weight
+/// [`Vec<usize>`]: Vec
+/// [`Uniform<usize>::sample`]: Distribution::sample
+/// [`Uniform<W>::sample`]: Distribution::sample
+pub struct WeightedIndex<W: Weight> {
+    aliases: Vec<usize>,
+    no_alias_odds: Vec<W>,
+    uniform_index: Uniform<usize>,
+    uniform_within_weight_sum: Uniform<W>,
+}
+
+impl<W: Weight> WeightedIndex<W> {
+    /// Creates a new [`WeightedIndex`].
+    ///
+    /// Returns an error if:
+    /// - The vector is empty.
+    /// - For any weight `w`: `w < 0` or `w > max` where `max = W::MAX /
+    ///   weights.len()`.
+    /// - The sum of weights is zero.
+    pub fn new(weights: Vec<W>) -> Result<Self, WeightedError> {
+        let n = weights.len();
+        if n == 0 {
+            return Err(WeightedError::NoItem);
+        }
+
+        let max_weight_size = W::try_from_usize_lossy(n)
+            .map(|n| W::MAX / n)
+            .unwrap_or(W::ZERO);
+        if !weights
+            .iter()
+            .all(|&w| W::ZERO <= w && w <= max_weight_size)
+        {
+            return Err(WeightedError::InvalidWeight);
+        }
+
+        // The sum of weights will represent 100% of no alias odds.
+        let weight_sum = Weight::sum(weights.as_slice());
+        // Prevent floating point overflow due to rounding errors.
+        let weight_sum = if weight_sum > W::MAX {
+            W::MAX
+        } else {
+            weight_sum
+        };
+        if weight_sum == W::ZERO {
+            return Err(WeightedError::AllWeightsZero);
+        }
+
+        // `weight_sum` would have been zero if `try_from_usize_lossy` causes an error here.
+        let n_converted = W::try_from_usize_lossy(n).unwrap();
+
+        let mut no_alias_odds = weights;
+        for odds in no_alias_odds.iter_mut() {
+            *odds *= n_converted;
+            // Prevent floating point overflow due to rounding errors.
+            *odds = if *odds > W::MAX { W::MAX } else { *odds };
+        }
+
+        /// This struct is designed to contain three data structures at once,
+        /// sharing the same memory. More precisely it contains two linked lists
+        /// and an alias map, which will be the output of this method. To keep
+        /// the three data structures from getting in each other's way, it must
+        /// be ensured that a single index is only ever in one of them at the
+        /// same time.
+        struct Aliases {
+            aliases: Vec<usize>,
+            smalls_head: usize,
+            bigs_head: usize,
+        }
+
+        impl Aliases {
+            fn new(size: usize) -> Self {
+                Aliases {
+                    aliases: vec![0; size],
+                    smalls_head: ::core::usize::MAX,
+                    bigs_head: ::core::usize::MAX,
+                }
+            }
+
+            fn push_small(&mut self, idx: usize) {
+                self.aliases[idx] = self.smalls_head;
+                self.smalls_head = idx;
+            }
+
+            fn push_big(&mut self, idx: usize) {
+                self.aliases[idx] = self.bigs_head;
+                self.bigs_head = idx;
+            }
+
+            fn pop_small(&mut self) -> usize {
+                let popped = self.smalls_head;
+                self.smalls_head = self.aliases[popped];
+                popped
+            }
+
+            fn pop_big(&mut self) -> usize {
+                let popped = self.bigs_head;
+                self.bigs_head = self.aliases[popped];
+                popped
+            }
+
+            fn smalls_is_empty(&self) -> bool {
+                self.smalls_head == ::core::usize::MAX
+            }
+
+            fn bigs_is_empty(&self) -> bool {
+                self.bigs_head == ::core::usize::MAX
+            }
+
+            fn set_alias(&mut self, idx: usize, alias: usize) {
+                self.aliases[idx] = alias;
+            }
+        }
+
+        let mut aliases = Aliases::new(n);
+
+        // Split indices into those with small weights and those with big weights.
+        for (index, &odds) in no_alias_odds.iter().enumerate() {
+            if odds < weight_sum {
+                aliases.push_small(index);
+            } else {
+                aliases.push_big(index);
+            }
+        }
+
+        // Build the alias map by finding an alias with big weight for each index with
+        // small weight.
+        while !aliases.smalls_is_empty() && !aliases.bigs_is_empty() {
+            let s = aliases.pop_small();
+            let b = aliases.pop_big();
+
+            aliases.set_alias(s, b);
+            no_alias_odds[b] = no_alias_odds[b] - weight_sum + no_alias_odds[s];
+
+            if no_alias_odds[b] < weight_sum {
+                aliases.push_small(b);
+            } else {
+                aliases.push_big(b);
+            }
+        }
+
+        // The remaining indices should have no alias odds of about 100%. This is due to
+        // numeric accuracy. Otherwise they would be exactly 100%.
+        while !aliases.smalls_is_empty() {
+            no_alias_odds[aliases.pop_small()] = weight_sum;
+        }
+        while !aliases.bigs_is_empty() {
+            no_alias_odds[aliases.pop_big()] = weight_sum;
+        }
+
+        // Prepare distributions for sampling. Creating them beforehand improves
+        // sampling performance.
+        let uniform_index = Uniform::new(0, n);
+        let uniform_within_weight_sum = Uniform::new(W::ZERO, weight_sum);
+
+        Ok(Self {
+            aliases: aliases.aliases,
+            no_alias_odds,
+            uniform_index,
+            uniform_within_weight_sum,
+        })
+    }
+}
+
+impl<W: Weight> Distribution<usize> for WeightedIndex<W> {
+    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
+        let candidate = rng.sample(&self.uniform_index);
+        if rng.sample(&self.uniform_within_weight_sum) < self.no_alias_odds[candidate] {
+            candidate
+        } else {
+            self.aliases[candidate]
+        }
+    }
+}
+
+impl<W: Weight> fmt::Debug for WeightedIndex<W>
+where
+    W: fmt::Debug,
+    Uniform<W>: fmt::Debug,
+{
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        f.debug_struct("WeightedIndex")
+            .field("aliases", &self.aliases)
+            .field("no_alias_odds", &self.no_alias_odds)
+            .field("uniform_index", &self.uniform_index)
+            .field("uniform_within_weight_sum", &self.uniform_within_weight_sum)
+            .finish()
+    }
+}
+
+impl<W: Weight> Clone for WeightedIndex<W>
+where
+    Uniform<W>: Clone,
+{
+    fn clone(&self) -> Self {
+        Self {
+            aliases: self.aliases.clone(),
+            no_alias_odds: self.no_alias_odds.clone(),
+            uniform_index: self.uniform_index.clone(),
+            uniform_within_weight_sum: self.uniform_within_weight_sum.clone(),
+        }
+    }
+}
+
+/// Trait that must be implemented for weights, that are used with
+/// [`WeightedIndex`]. Currently no guarantees on the correctness of
+/// [`WeightedIndex`] are given for custom implementations of this trait.
+pub trait Weight:
+    Sized
+    + Copy
+    + SampleUniform
+    + PartialOrd
+    + Add<Output = Self>
+    + AddAssign
+    + Sub<Output = Self>
+    + SubAssign
+    + Mul<Output = Self>
+    + MulAssign
+    + Div<Output = Self>
+    + DivAssign
+    + Sum
+{
+    /// Maximum number representable by `Self`.
+    const MAX: Self;
+
+    /// Element of `Self` equivalent to 0.
+    const ZERO: Self;
+
+    /// Converts a [`usize`] to a `Self`, rounding if necessary.
+    fn try_from_usize_lossy(n: usize) -> Option<Self>;
+
+    /// Sums all values in slice `values`.
+    fn sum(values: &[Self]) -> Self {
+        values.iter().cloned().sum()
+    }
+}
+
+macro_rules! impl_weight_for_float {
+    ($T: ident) => {
+        impl Weight for $T {
+            const MAX: Self = ::core::$T::MAX;
+            const ZERO: Self = 0.0;
+
+            fn try_from_usize_lossy(n: usize) -> Option<Self> {
+                Some(n as $T)
+            }
+
+            fn sum(values: &[Self]) -> Self {
+                pairwise_sum(values)
+            }
+        }
+    };
+}
+
+/// In comparison to naive accumulation, the pairwise sum algorithm reduces
+/// rounding errors when there are many floating point values.
+fn pairwise_sum<T: Weight>(values: &[T]) -> T {
+    if values.len() <= 32 {
+        values.iter().cloned().sum()
+    } else {
+        let mid = values.len() / 2;
+        let (a, b) = values.split_at(mid);
+        pairwise_sum(a) + pairwise_sum(b)
+    }
+}
+
+macro_rules! impl_weight_for_int {
+    ($T: ident) => {
+        impl Weight for $T {
+            const MAX: Self = ::core::$T::MAX;
+            const ZERO: Self = 0;
+
+            fn try_from_usize_lossy(n: usize) -> Option<Self> {
+                let n_converted = n as Self;
+                if n_converted >= Self::ZERO && n_converted as usize == n {
+                    Some(n_converted)
+                } else {
+                    None
+                }
+            }
+        }
+    };
+}
+
+impl_weight_for_float!(f64);
+impl_weight_for_float!(f32);
+impl_weight_for_int!(usize);
+#[cfg(all(rustc_1_26, not(target_os = "emscripten")))]
+impl_weight_for_int!(u128);
+impl_weight_for_int!(u64);
+impl_weight_for_int!(u32);
+impl_weight_for_int!(u16);
+impl_weight_for_int!(u8);
+impl_weight_for_int!(isize);
+#[cfg(all(rustc_1_26, not(target_os = "emscripten")))]
+impl_weight_for_int!(i128);
+impl_weight_for_int!(i64);
+impl_weight_for_int!(i32);
+impl_weight_for_int!(i16);
+impl_weight_for_int!(i8);
+
+#[cfg(test)]
+mod test {
+    use super::*;
+
+    #[test]
+    fn test_weighted_index_f32() {
+        test_weighted_index(f32::into);
+
+        // Floating point special cases
+        assert_eq!(
+            WeightedIndex::new(vec![::core::f32::INFINITY]).unwrap_err(),
+            WeightedError::InvalidWeight
+        );
+        assert_eq!(
+            WeightedIndex::new(vec![-0_f32]).unwrap_err(),
+            WeightedError::AllWeightsZero
+        );
+        assert_eq!(
+            WeightedIndex::new(vec![-1_f32]).unwrap_err(),
+            WeightedError::InvalidWeight
+        );
+        assert_eq!(
+            WeightedIndex::new(vec![-::core::f32::INFINITY]).unwrap_err(),
+            WeightedError::InvalidWeight
+        );
+        assert_eq!(
+            WeightedIndex::new(vec![::core::f32::NAN]).unwrap_err(),
+            WeightedError::InvalidWeight
+        );
+    }
+
+    #[cfg(all(rustc_1_26, not(target_os = "emscripten")))]
+    #[test]
+    fn test_weighted_index_u128() {
+        test_weighted_index(|x: u128| x as f64);
+    }
+
+    #[cfg(all(rustc_1_26, not(target_os = "emscripten")))]
+    #[test]
+    fn test_weighted_index_i128() {
+        test_weighted_index(|x: i128| x as f64);
+
+        // Signed integer special cases
+        assert_eq!(
+            WeightedIndex::new(vec![-1_i128]).unwrap_err(),
+            WeightedError::InvalidWeight
+        );
+        assert_eq!(
+            WeightedIndex::new(vec![::core::i128::MIN]).unwrap_err(),
+            WeightedError::InvalidWeight
+        );
+    }
+
+    #[test]
+    fn test_weighted_index_u8() {
+        test_weighted_index(u8::into);
+    }
+
+    #[test]
+    fn test_weighted_index_i8() {
+        test_weighted_index(i8::into);
+
+        // Signed integer special cases
+        assert_eq!(
+            WeightedIndex::new(vec![-1_i8]).unwrap_err(),
+            WeightedError::InvalidWeight
+        );
+        assert_eq!(
+            WeightedIndex::new(vec![::core::i8::MIN]).unwrap_err(),
+            WeightedError::InvalidWeight
+        );
+    }
+
+    fn test_weighted_index<W: Weight, F: Fn(W) -> f64>(w_to_f64: F)
+    where
+        WeightedIndex<W>: fmt::Debug,
+    {
+        const NUM_WEIGHTS: usize = 10;
+        const ZERO_WEIGHT_INDEX: usize = 3;
+        const NUM_SAMPLES: u32 = 15000;
+        let mut rng = ::test::rng(0x9c9fa0b0580a7031);
+
+        let weights = {
+            let mut weights = Vec::with_capacity(NUM_WEIGHTS);
+            let random_weight_distribution = ::distributions::Uniform::new_inclusive(
+                W::ZERO,
+                W::MAX / W::try_from_usize_lossy(NUM_WEIGHTS).unwrap(),
+            );
+            for _ in 0..NUM_WEIGHTS {
+                weights.push(rng.sample(&random_weight_distribution));
+            }
+            weights[ZERO_WEIGHT_INDEX] = W::ZERO;
+            weights
+        };
+        let weight_sum = weights.iter().cloned().sum::<W>();
+        let expected_counts = weights
+            .iter()
+            .map(|&w| w_to_f64(w) / w_to_f64(weight_sum) * NUM_SAMPLES as f64)
+            .collect::<Vec<f64>>();
+        let weight_distribution = WeightedIndex::new(weights).unwrap();
+
+        let mut counts = vec![0_usize; NUM_WEIGHTS];
+        for _ in 0..NUM_SAMPLES {
+            counts[rng.sample(&weight_distribution)] += 1;
+        }
+
+        assert_eq!(counts[ZERO_WEIGHT_INDEX], 0);
+        for (count, expected_count) in counts.into_iter().zip(expected_counts) {
+            let difference = (count as f64 - expected_count).abs();
+            let max_allowed_difference = NUM_SAMPLES as f64 / NUM_WEIGHTS as f64 * 0.1;
+            assert!(difference <= max_allowed_difference);
+        }
+
+        assert_eq!(
+            WeightedIndex::<W>::new(vec![]).unwrap_err(),
+            WeightedError::NoItem
+        );
+        assert_eq!(
+            WeightedIndex::new(vec![W::ZERO]).unwrap_err(),
+            WeightedError::AllWeightsZero
+        );
+        assert_eq!(
+            WeightedIndex::new(vec![W::MAX, W::MAX]).unwrap_err(),
+            WeightedError::InvalidWeight
+        );
+    }
+}
diff --git a/src/distributions/weighted.rs b/src/distributions/weighted/mod.rs
similarity index 90%
rename from src/distributions/weighted.rs
rename to src/distributions/weighted/mod.rs
index d7499596e39..b58cada9c25 100644
--- a/src/distributions/weighted.rs
+++ b/src/distributions/weighted/mod.rs
@@ -6,6 +6,11 @@
 // option. This file may not be copied, modified, or distributed
 // except according to those terms.
 
+//! This module contains different algorithms for sampling random indices with
+//! probabilities proportional to a collection of weights.
+
+pub mod alias_method;
+
 use Rng;
 use distributions::Distribution;
 use distributions::uniform::{UniformSampler, SampleUniform, SampleBorrow};
@@ -98,13 +103,13 @@ impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
 
         let zero = <X as Default>::default();
         if total_weight < zero {
-            return Err(WeightedError::NegativeWeight);
+            return Err(WeightedError::InvalidWeight);
         }
 
         let mut weights = Vec::<X>::with_capacity(iter.size_hint().0);
         for w in iter {
             if *w.borrow() < zero {
-                return Err(WeightedError::NegativeWeight);
+                return Err(WeightedError::InvalidWeight);
             }
             weights.push(total_weight.clone());
             total_weight += w.borrow();
@@ -184,31 +189,32 @@ mod test {
 
         assert_eq!(WeightedIndex::new(&[10][0..0]).unwrap_err(), WeightedError::NoItem);
         assert_eq!(WeightedIndex::new(&[0]).unwrap_err(), WeightedError::AllWeightsZero);
-        assert_eq!(WeightedIndex::new(&[10, 20, -1, 30]).unwrap_err(), WeightedError::NegativeWeight);
-        assert_eq!(WeightedIndex::new(&[-10, 20, 1, 30]).unwrap_err(), WeightedError::NegativeWeight);
-        assert_eq!(WeightedIndex::new(&[-10]).unwrap_err(), WeightedError::NegativeWeight);
+        assert_eq!(WeightedIndex::new(&[10, 20, -1, 30]).unwrap_err(), WeightedError::InvalidWeight);
+        assert_eq!(WeightedIndex::new(&[-10, 20, 1, 30]).unwrap_err(), WeightedError::InvalidWeight);
+        assert_eq!(WeightedIndex::new(&[-10]).unwrap_err(), WeightedError::InvalidWeight);
     }
 }
 
 /// Error type returned from `WeightedIndex::new`.
 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
 pub enum WeightedError {
-    /// The provided iterator contained no items.
+    /// The provided weight collection contains no items.
     NoItem,
 
-    /// A weight lower than zero was used.
-    NegativeWeight,
+    /// A weight is either less than zero, greater than the supported maximum or
+    /// otherwise invalid.
+    InvalidWeight,
 
-    /// All items in the provided iterator had a weight of zero.
+    /// All items in the provided weight collection are zero.
     AllWeightsZero,
 }
 
 impl WeightedError {
     fn msg(&self) -> &str {
         match *self {
-            WeightedError::NoItem => "No items found",
-            WeightedError::NegativeWeight => "Item has negative weight",
-            WeightedError::AllWeightsZero => "All items had weight zero",
+            WeightedError::NoItem => "No weights provided.",
+            WeightedError::InvalidWeight => "A weight is invalid.",
+            WeightedError::AllWeightsZero => "All weights are zero.",
         }
     }
 }
diff --git a/src/seq/mod.rs b/src/seq/mod.rs
index d0f83084a66..72fb211e08b 100644
--- a/src/seq/mod.rs
+++ b/src/seq/mod.rs
@@ -823,7 +823,7 @@ mod test {
         assert_eq!(empty_slice.choose_weighted(&mut r, |_| 1), Err(WeightedError::NoItem));
         assert_eq!(empty_slice.choose_weighted_mut(&mut r, |_| 1), Err(WeightedError::NoItem));
         assert_eq!(['x'].choose_weighted_mut(&mut r, |_| 0), Err(WeightedError::AllWeightsZero));
-        assert_eq!([0, -1].choose_weighted_mut(&mut r, |x| *x), Err(WeightedError::NegativeWeight));
-        assert_eq!([-1, 0].choose_weighted_mut(&mut r, |x| *x), Err(WeightedError::NegativeWeight));
+        assert_eq!([0, -1].choose_weighted_mut(&mut r, |x| *x), Err(WeightedError::InvalidWeight));
+        assert_eq!([-1, 0].choose_weighted_mut(&mut r, |x| *x), Err(WeightedError::InvalidWeight));
     }
 }