Skip to content

Commit

Permalink
Added an implementation of alias method for weighted indices
Browse files Browse the repository at this point in the history
  • Loading branch information
zroug committed Jan 14, 2019
1 parent a7c2eae commit 50d3f4b
Show file tree
Hide file tree
Showing 3 changed files with 227 additions and 1 deletion.
15 changes: 15 additions & 0 deletions benches/distributions.rs
Expand Up @@ -187,6 +187,21 @@ distr_int!(distr_weighted_u32, usize, WeightedIndex::new(&[1u32, 2, 3, 4, 12, 0,
distr_int!(distr_weighted_f64, usize, WeightedIndex::new(&[1.0f64, 0.001, 1.0/3.0, 4.01, 0.0, 3.3, 22.0, 0.001]).unwrap());
distr_int!(distr_weighted_large_set, usize, WeightedIndex::new((0..10000).rev().chain(1..10001)).unwrap());

// Alias-method counterpart of `distr_weighted_f64` above: the same eight `f64`
// weights (including one zero weight), but sampled through
// `AliasMethodWeightedIndex` instead of `WeightedIndex`.
distr_int!(
distr_weighted_alias_method,
usize,
AliasMethodWeightedIndex::new(
vec![1.0f64, 0.001, 1.0/3.0, 4.01, 0.0, 3.3, 22.0, 0.001]
).unwrap()
);
// Alias-method counterpart of `distr_weighted_large_set` above: the same
// 20000-element integer sequence, converted to `f64` because
// `AliasMethodWeightedIndex::new` takes a `Vec<f64>`.
distr_int!(
distr_weighted_alias_method_large_set,
usize,
AliasMethodWeightedIndex::new(
(0..10000).rev().chain(1..10001).map(|x| x as f64).collect()
).unwrap()
);

// construct and sample from a range
macro_rules! gen_range_int {
($fnn:ident, $ty:ident, $low:expr, $high:expr) => {
Expand Down
5 changes: 4 additions & 1 deletion src/distributions/mod.rs
Expand Up @@ -181,7 +181,10 @@ pub use self::other::Alphanumeric;
#[doc(inline)] pub use self::uniform::Uniform;
pub use self::float::{OpenClosed01, Open01};
pub use self::bernoulli::Bernoulli;
#[cfg(feature="alloc")] pub use self::weighted::{WeightedIndex, WeightedError};
#[cfg(feature = "alloc")]
pub use self::weighted::{
AliasMethodWeightedIndex, AliasMethodWeightedIndexError, WeightedError, WeightedIndex,
};
#[cfg(feature="std")] pub use self::unit_sphere::UnitSphereSurface;
#[cfg(feature="std")] pub use self::unit_circle::UnitCircle;
#[cfg(feature="std")] pub use self::gamma::{Gamma, ChiSquared, FisherF,
Expand Down
208 changes: 208 additions & 0 deletions src/distributions/weighted.rs
Expand Up @@ -13,7 +13,9 @@ use ::core::cmp::PartialOrd;
use core::fmt;

// Note that this whole module is only imported if feature="alloc" is enabled.
#[cfg(not(feature = "std"))] use alloc::collections::VecDeque;
#[cfg(not(feature="std"))] use alloc::vec::Vec;
#[cfg(feature = "std")] use std::collections::VecDeque;

/// A distribution using weighted sampling to pick a discretely selected
/// item.
Expand Down Expand Up @@ -130,6 +132,116 @@ impl<X> Distribution<usize> for WeightedIndex<X> where
}
}

/// A distribution that samples an index in proportion to a list of `f64`
/// weights, using the alias method.
///
/// All of the work is done once in [`AliasMethodWeightedIndex::new`], which
/// turns the weights into two tables of the same length as the input. Drawing
/// a sample then needs just one uniformly chosen index, one uniformly chosen
/// `f64` and a single comparison, regardless of how many weights there are.
///
/// Indices with a weight of zero are never returned.
#[derive(Debug, Clone)]
pub struct AliasMethodWeightedIndex {
    // `aliases[i]` is the index returned instead of `i` when column `i` is
    // picked but the biased coin flip for that column fails.
    aliases: Vec<usize>,
    // `no_alias_odds[i]` is the probability (in `[0, 1]`) that picking column
    // `i` yields `i` itself rather than `aliases[i]`.
    no_alias_odds: Vec<f64>,
    // Uniform distribution over `0..n`, built once so that sampling does not
    // have to reconstruct it on every call.
    uniform_index: super::Uniform<usize>,
}

impl AliasMethodWeightedIndex {
#[allow(missing_docs)] // todo: add docs
pub fn new(weights: Vec<f64>) -> Result<Self, AliasMethodWeightedIndexError> {
if weights.is_empty() {
return Err(AliasMethodWeightedIndexError::NoItem);
}
if !weights.iter().all(|&w| w >= 0.0) {
return Err(AliasMethodWeightedIndexError::InvalidWeight);
}

let n = weights.len();
let weight_sum = pairwise_sum_f64(weights.as_slice());
if weight_sum.is_infinite() {
return Err(AliasMethodWeightedIndexError::WeightSumToBig);
}

let weight_scale = n as f64 / weight_sum;
if weight_scale.is_infinite() {
return Err(AliasMethodWeightedIndexError::WeightSumToSmall);
}

let mut no_alias_odds = weights;
for odds in no_alias_odds.iter_mut() {
*odds *= weight_scale;
}

// Split indices into indices with small weights and indices with big weights.
// Instead of two `Vec` with unknown capacity we use a single `VecDeque` with
// known capacity. Front represents smalls and back represents bigs. We also
// need to keep track of the size of each virtual `Vec`.
let mut smalls_bigs = VecDeque::with_capacity(n);
let mut smalls_len = 0_usize;
let mut bigs_len = 0_usize;
for (index, &odds) in no_alias_odds.iter().enumerate() {
if odds < 1.0 {
smalls_bigs.push_front(index);
smalls_len += 1;
} else {
smalls_bigs.push_back(index);
bigs_len += 1;
}
}

let mut aliases = vec![0; n];
while smalls_len > 0 && bigs_len > 0 {
let s = smalls_bigs.pop_front().unwrap();
smalls_len -= 1;
let b = smalls_bigs.pop_back().unwrap();
bigs_len -= 1;

aliases[s] = b;
no_alias_odds[b] = no_alias_odds[s] + no_alias_odds[b] - 1.0;

if no_alias_odds[b] < 1.0 {
smalls_bigs.push_front(b);
smalls_len += 1;
} else {
smalls_bigs.push_back(b);
bigs_len += 1;
}
}

// The remaining indices should have no alias odds of about 1. This is due to
// numeric accuracy. Otherwise they would be exactly 1.
for index in smalls_bigs.into_iter() {
// Because p = 1 we don't need to set an alias. It will never be accessed.
no_alias_odds[index] = 1.0;
}

// Prepare a distribution to sample random indices. Creating it beforehand
// improves sampling performance.
let uniform_index = super::Uniform::new(0, no_alias_odds.len());

Ok(Self {
aliases,
no_alias_odds,
uniform_index,
})
}
}

impl Distribution<usize> for AliasMethodWeightedIndex {
    /// Draws one index: picks a column uniformly, then flips that column's
    /// biased coin to choose between the column itself and its alias.
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
        // The column is drawn before the coin; both draws always happen.
        let column = rng.sample(self.uniform_index);
        let coin: f64 = rng.sample(super::Standard);
        match coin < self.no_alias_odds[column] {
            true => column,
            false => self.aliases[column],
        }
    }
}

/// Sums `values` by recursively halving the slice and adding the two partial
/// sums; slices of at most 32 elements are summed sequentially.
fn pairwise_sum_f64(values: &[f64]) -> f64 {
    // Largest slice length that is still summed in one sequential pass.
    const SEQUENTIAL_LIMIT: usize = 32;

    let len = values.len();
    if len > SEQUENTIAL_LIMIT {
        let (lower, upper) = values.split_at(len / 2);
        return pairwise_sum_f64(lower) + pairwise_sum_f64(upper);
    }
    values.iter().sum::<f64>()
}

#[cfg(test)]
mod test {
use super::*;
Expand Down Expand Up @@ -188,6 +300,66 @@ mod test {
assert_eq!(WeightedIndex::new(&[-10, 20, 1, 30]).unwrap_err(), WeightedError::NegativeWeight);
assert_eq!(WeightedIndex::new(&[-10]).unwrap_err(), WeightedError::NegativeWeight);
}

#[test]
fn test_alias_method_weighted_index() {
    const NUM_WEIGHTS: usize = 10;
    const ZERO_WEIGHT_INDEX: usize = 3;
    const NUM_SAMPLES: u32 = 10000;
    let mut rng = ::test::rng(0x9c9fa0b0580a7031);

    // Random weights with one entry forced to zero; that index must never be
    // sampled.
    let mut weights: Vec<f64> = (0..NUM_WEIGHTS)
        .map(|_| rng.sample::<f64, _>(::distributions::Standard))
        .collect();
    weights[ZERO_WEIGHT_INDEX] = 0.0;

    let total: f64 = weights.iter().sum();
    let expected_counts: Vec<f64> = weights
        .iter()
        .map(|&w| w / total * NUM_SAMPLES as f64)
        .collect();
    let distribution = AliasMethodWeightedIndex::new(weights).unwrap();

    let mut counts = vec![0_usize; NUM_WEIGHTS];
    for _ in 0..NUM_SAMPLES {
        let index = rng.sample(&distribution);
        counts[index] += 1;
    }

    assert_eq!(counts[ZERO_WEIGHT_INDEX], 0);

    // Each observed count may deviate from its expectation by at most 10% of
    // the average count per index.
    let tolerance = NUM_SAMPLES as f64 / NUM_WEIGHTS as f64 * 0.1;
    for (&observed, &expected) in counts.iter().zip(expected_counts.iter()) {
        let deviation = (observed as f64 - expected).abs();
        assert!(deviation <= tolerance);
    }

    // Every invalid input must be rejected with its specific error.
    let rejected = vec![
        (vec![], AliasMethodWeightedIndexError::NoItem),
        (vec![0.0], AliasMethodWeightedIndexError::WeightSumToSmall),
        (
            vec![::core::f64::INFINITY],
            AliasMethodWeightedIndexError::WeightSumToBig,
        ),
        (
            vec![::core::f64::MAX, ::core::f64::MAX],
            AliasMethodWeightedIndexError::WeightSumToBig,
        ),
        (vec![-1.0], AliasMethodWeightedIndexError::InvalidWeight),
        (
            vec![::core::f64::NAN],
            AliasMethodWeightedIndexError::InvalidWeight,
        ),
    ];
    for (bad_weights, expected_error) in rejected {
        assert_eq!(
            AliasMethodWeightedIndex::new(bad_weights).unwrap_err(),
            expected_error
        );
    }
}
}

/// Error type returned from `WeightedIndex::new`.
Expand Down Expand Up @@ -228,3 +400,39 @@ impl fmt::Display for WeightedError {
write!(f, "{}", self.msg())
}
}

/// Error type returned from `AliasMethodWeightedIndex::new`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AliasMethodWeightedIndexError {
    /// The provided weight vector was empty.
    NoItem,
    /// At least one weight was negative or NaN.
    InvalidWeight,
    /// The weights cannot be scaled because their sum is too small; this is
    /// the case when every weight is zero.
    WeightSumToSmall,
    /// The sum of the weights is infinite.
    WeightSumToBig,
}

impl AliasMethodWeightedIndexError {
    /// Returns the human-readable message shared by `Display` and
    /// `Error::description`.
    fn msg(&self) -> &str {
        match *self {
            AliasMethodWeightedIndexError::NoItem => "No items found.",
            AliasMethodWeightedIndexError::InvalidWeight => "An item has an invalid weight.",
            AliasMethodWeightedIndexError::WeightSumToSmall => "The sum of weights is too small.",
            AliasMethodWeightedIndexError::WeightSumToBig => "The sum of weights is too big.",
        }
    }
}

impl fmt::Display for AliasMethodWeightedIndexError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(self.msg())
    }
}

// `std::error::Error` only exists with the standard library, so this impl is
// compiled out of `alloc`-only builds.
#[cfg(feature = "std")]
impl ::std::error::Error for AliasMethodWeightedIndexError {
    fn description(&self) -> &str {
        self.msg()
    }
    fn cause(&self) -> Option<&::std::error::Error> {
        None
    }
}

0 comments on commit 50d3f4b

Please sign in to comment.