From 9e90df35f7d06a29709ef1ea2efabaf3d51463a8 Mon Sep 17 00:00:00 2001 From: zroug <37004975+zroug@users.noreply.github.com> Date: Mon, 14 Jan 2019 21:53:09 +0100 Subject: [PATCH] Added an implementation of alias method for weighted indices --- benches/distributions.rs | 15 +++ src/distributions/mod.rs | 5 +- src/distributions/weighted.rs | 208 ++++++++++++++++++++++++++++++++++ 3 files changed, 227 insertions(+), 1 deletion(-) diff --git a/benches/distributions.rs b/benches/distributions.rs index 069a82856a5..cf6e4a6558f 100644 --- a/benches/distributions.rs +++ b/benches/distributions.rs @@ -187,6 +187,21 @@ distr_int!(distr_weighted_u32, usize, WeightedIndex::new(&[1u32, 2, 3, 4, 12, 0, distr_int!(distr_weighted_f64, usize, WeightedIndex::new(&[1.0f64, 0.001, 1.0/3.0, 4.01, 0.0, 3.3, 22.0, 0.001]).unwrap()); distr_int!(distr_weighted_large_set, usize, WeightedIndex::new((0..10000).rev().chain(1..10001)).unwrap()); +distr_int!( + distr_weighted_alias_method, + usize, + AliasMethodWeightedIndex::new( + vec![1.0f64, 0.001, 1.0/3.0, 4.01, 0.0, 3.3, 22.0, 0.001] + ).unwrap() +); +distr_int!( + distr_weighted_alias_method_large_set, + usize, + AliasMethodWeightedIndex::new( + (0..10000).rev().chain(1..10001).map(|x| x as f64).collect() + ).unwrap() +); + // construct and sample from a range macro_rules! gen_range_int { ($fnn:ident, $ty:ident, $low:expr, $high:expr) => { diff --git a/src/distributions/mod.rs b/src/distributions/mod.rs index 6e2d6c7bad2..06b578e989f 100644 --- a/src/distributions/mod.rs +++ b/src/distributions/mod.rs @@ -181,7 +181,10 @@ pub use self::other::Alphanumeric; #[doc(inline)] pub use self::uniform::Uniform; pub use self::float::{OpenClosed01, Open01}; pub use self::bernoulli::Bernoulli; -#[cfg(feature="alloc")] pub use self::weighted::{WeightedIndex, WeightedError}; +#[cfg(feature = "alloc")] +pub use self::weighted::{ + AliasMethodWeightedIndex, AliasMethodWeightedIndexError, WeightedError, WeightedIndex, +}; #[cfg(feature="std")] pub use self::unit_sphere::UnitSphereSurface; #[cfg(feature="std")] pub use self::unit_circle::UnitCircle; #[cfg(feature="std")] pub use self::gamma::{Gamma, ChiSquared, FisherF, diff --git a/src/distributions/weighted.rs b/src/distributions/weighted.rs index d7499596e39..7606bcd10e6 100644 --- a/src/distributions/weighted.rs +++ b/src/distributions/weighted.rs @@ -13,7 +13,9 @@ use ::core::cmp::PartialOrd; use core::fmt; // Note that this whole module is only imported if feature="alloc" is enabled. +#[cfg(not(feature = "std"))] use alloc::collections::VecDeque; #[cfg(not(feature="std"))] use alloc::vec::Vec; +#[cfg(feature = "std")] use std::collections::VecDeque; /// A distribution using weighted sampling to pick a discretely selected /// item. @@ -130,6 +132,116 @@ impl Distribution for WeightedIndex where } } +#[allow(missing_docs)] // todo: add docs +#[derive(Debug, Clone)] +pub struct AliasMethodWeightedIndex { + aliases: Vec, + no_alias_odds: Vec, + uniform_index: super::Uniform, +} + +impl AliasMethodWeightedIndex { + #[allow(missing_docs)] // todo: add docs + pub fn new(weights: Vec) -> Result { + if weights.is_empty() { + return Err(AliasMethodWeightedIndexError::NoItem); + } + if !weights.iter().all(|&w| w >= 0.0) { + return Err(AliasMethodWeightedIndexError::InvalidWeight); + } + + let n = weights.len(); + let weight_sum = pairwise_sum_f64(weights.as_slice()); + if weight_sum.is_infinite() { + return Err(AliasMethodWeightedIndexError::WeightSumToBig); + } + + let weight_scale = n as f64 / weight_sum; + if weight_scale.is_infinite() { + return Err(AliasMethodWeightedIndexError::WeightSumToSmall); + } + + let mut no_alias_odds = weights; + for odds in no_alias_odds.iter_mut() { + *odds *= weight_scale; + } + + // Split indices into indices with small weights and indices with big weights. + // Instead of two `Vec` with unknown capacity we use a single `VecDeque` with + // known capacity. Front represents smalls and back represents bigs. We also + // need to keep track of the size of each virtual `Vec`. + let mut smalls_bigs = VecDeque::with_capacity(n); + let mut smalls_len = 0_usize; + let mut bigs_len = 0_usize; + for (index, &odds) in no_alias_odds.iter().enumerate() { + if odds < 1.0 { + smalls_bigs.push_front(index); + smalls_len += 1; + } else { + smalls_bigs.push_back(index); + bigs_len += 1; + } + } + + let mut aliases = vec![0; n]; + while smalls_len > 0 && bigs_len > 0 { + let s = smalls_bigs.pop_front().unwrap(); + smalls_len -= 1; + let b = smalls_bigs.pop_back().unwrap(); + bigs_len -= 1; + + aliases[s] = b; + no_alias_odds[b] = no_alias_odds[s] + no_alias_odds[b] - 1.0; + + if no_alias_odds[b] < 1.0 { + smalls_bigs.push_front(b); + smalls_len += 1; + } else { + smalls_bigs.push_back(b); + bigs_len += 1; + } + } + + // The remaining indices should have no alias odds of about 1. This is due to + // numeric accuracy. Otherwise they would be exactly 1. + for index in smalls_bigs.into_iter() { + // Because p = 1 we don't need to set an alias. It will never be accessed. + no_alias_odds[index] = 1.0; + } + + // Prepare a distribution to sample random indices. Creating it beforehand + // improves sampling performance. + let uniform_index = super::Uniform::new(0, no_alias_odds.len()); + + Ok(Self { + aliases, + no_alias_odds, + uniform_index, + }) + } +} + +impl Distribution for AliasMethodWeightedIndex { + fn sample(&self, rng: &mut R) -> usize { + let candidate = rng.sample(self.uniform_index); + if rng.sample::(super::Standard) < self.no_alias_odds[candidate] { + candidate + } else { + self.aliases[candidate] + } + } +} + +fn pairwise_sum_f64(values: &[f64]) -> f64 { + if values.len() <= 32 { + values.iter().sum() + } else { + let mid = values.len() / 2; + let (a, b) = values.split_at(mid); + pairwise_sum_f64(a) + pairwise_sum_f64(b) + } +} + #[cfg(test)] mod test { use super::*; @@ -188,6 +300,66 @@ mod test { assert_eq!(WeightedIndex::new(&[-10, 20, 1, 30]).unwrap_err(), WeightedError::NegativeWeight); assert_eq!(WeightedIndex::new(&[-10]).unwrap_err(), WeightedError::NegativeWeight); } + + #[test] + fn test_alias_method_weighted_index() { + const NUM_WEIGHTS: usize = 10; + const ZERO_WEIGHT_INDEX: usize = 3; + const NUM_SAMPLES: u32 = 10000; + let mut rng = ::test::rng(0x9c9fa0b0580a7031); + + let weights = { + let mut weights = Vec::new(); + weights.resize_with(NUM_WEIGHTS, || { + rng.sample::(::distributions::Standard) + }); + weights[ZERO_WEIGHT_INDEX] = 0.0; + weights + }; + let weight_sum = weights.iter().sum::(); + let expected_counts = weights + .iter() + .map(|&w| w / weight_sum * NUM_SAMPLES as f64) + .collect::>(); + let weight_distribution = AliasMethodWeightedIndex::new(weights).unwrap(); + + let mut counts = vec![0_usize; NUM_WEIGHTS]; + for _ in 0..NUM_SAMPLES { + counts[rng.sample(&weight_distribution)] += 1; + } + + assert_eq!(counts[ZERO_WEIGHT_INDEX], 0); + for (count, expected_count) in counts.into_iter().zip(expected_counts) { + let difference = (count as f64 - expected_count).abs(); + let max_allowed_difference = NUM_SAMPLES as f64 / NUM_WEIGHTS as f64 * 0.1; + assert!(difference <= max_allowed_difference); + } + + assert_eq!( + AliasMethodWeightedIndex::new(vec![]).unwrap_err(), + AliasMethodWeightedIndexError::NoItem + ); + assert_eq!( + AliasMethodWeightedIndex::new(vec![0.0]).unwrap_err(), + AliasMethodWeightedIndexError::WeightSumToSmall + ); + assert_eq!( + AliasMethodWeightedIndex::new(vec![core::f64::INFINITY]).unwrap_err(), + AliasMethodWeightedIndexError::WeightSumToBig + ); + assert_eq!( + AliasMethodWeightedIndex::new(vec![core::f64::MAX, core::f64::MAX]).unwrap_err(), + AliasMethodWeightedIndexError::WeightSumToBig + ); + assert_eq!( + AliasMethodWeightedIndex::new(vec![-1.0]).unwrap_err(), + AliasMethodWeightedIndexError::InvalidWeight + ); + assert_eq!( + AliasMethodWeightedIndex::new(vec![core::f64::NAN]).unwrap_err(), + AliasMethodWeightedIndexError::InvalidWeight + ); + } } /// Error type returned from `WeightedIndex::new`. @@ -228,3 +400,39 @@ impl fmt::Display for WeightedError { write!(f, "{}", self.msg()) } } + +#[allow(missing_docs)] // todo: add docs +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum AliasMethodWeightedIndexError { + NoItem, + InvalidWeight, + WeightSumToSmall, + WeightSumToBig, +} + +impl AliasMethodWeightedIndexError { + fn msg(&self) -> &str { + match *self { + AliasMethodWeightedIndexError::NoItem => "No items found.", + AliasMethodWeightedIndexError::InvalidWeight => "An item has an invalid weight.", + AliasMethodWeightedIndexError::WeightSumToSmall => "The sum of weights is to small.", + AliasMethodWeightedIndexError::WeightSumToBig => "The sum of weights is to big.", + } + } +} + +impl fmt::Display for AliasMethodWeightedIndexError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str(self.msg()) + } +} + +#[cfg(feature = "std")] +impl ::std::error::Error for AliasMethodWeightedIndexError { + fn description(&self) -> &str { + self.msg() + } + fn cause(&self) -> Option<&::std::error::Error> { + None + } +}