Skip to content

Commit

Permalink
Make AliasMethodWeightedIndex generic
Browse files Browse the repository at this point in the history
  • Loading branch information
zroug committed Jan 24, 2019
1 parent 73d86e5 commit 14feb99
Show file tree
Hide file tree
Showing 2 changed files with 199 additions and 64 deletions.
18 changes: 4 additions & 14 deletions benches/distributions.rs
Expand Up @@ -187,20 +187,10 @@ distr_int!(distr_weighted_u32, usize, WeightedIndex::new(&[1u32, 2, 3, 4, 12, 0,
distr_int!(distr_weighted_f64, usize, WeightedIndex::new(&[1.0f64, 0.001, 1.0/3.0, 4.01, 0.0, 3.3, 22.0, 0.001]).unwrap());
distr_int!(distr_weighted_large_set, usize, WeightedIndex::new((0..10000).rev().chain(1..10001)).unwrap());

distr_int!(
distr_weighted_alias_method,
usize,
AliasMethodWeightedIndex::new(
vec![1.0f64, 0.001, 1.0/3.0, 4.01, 0.0, 3.3, 22.0, 0.001]
).unwrap()
);
distr_int!(
distr_weighted_alias_method_large_set,
usize,
AliasMethodWeightedIndex::new(
(0..10000).rev().chain(1..10001).map(|x| x as f64).collect()
).unwrap()
);
distr_int!(distr_weighted_alias_method_i8, usize, AliasMethodWeightedIndex::new(vec![1i8, 2, 3, 4, 12, 0, 2, 1]).unwrap());
distr_int!(distr_weighted_alias_method_u32, usize, AliasMethodWeightedIndex::new(vec![1u32, 2, 3, 4, 12, 0, 2, 1]).unwrap());
distr_int!(distr_weighted_alias_method_f64, usize, AliasMethodWeightedIndex::new(vec![1.0f64, 0.001, 1.0/3.0, 4.01, 0.0, 3.3, 22.0, 0.001]).unwrap());
distr_int!(distr_weighted_alias_method_large_set, usize, AliasMethodWeightedIndex::new((0..10000).rev().chain(1..10001).collect()).unwrap());

// construct and sample from a range
macro_rules! gen_range_int {
Expand Down
245 changes: 195 additions & 50 deletions src/distributions/weighted.rs
Expand Up @@ -11,6 +11,8 @@ use distributions::Distribution;
use distributions::uniform::{UniformSampler, SampleUniform, SampleBorrow};
use ::core::cmp::PartialOrd;
use core::fmt;
use core::iter::Sum;
use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Sub, SubAssign};

// Note that this whole module is only imported if feature="alloc" is enabled.
#[cfg(not(feature="std"))] use alloc::vec::Vec;
Expand Down Expand Up @@ -131,40 +133,52 @@ impl<X> Distribution<usize> for WeightedIndex<X> where
}

#[allow(missing_docs)] // todo: add docs
#[derive(Debug, Clone)]
pub struct AliasMethodWeightedIndex {
#[allow(missing_debug_implementations)] // todo: why does `#[derive(Debug)]` not work?
pub struct AliasMethodWeightedIndex<W: AliasMethodWeight> {
aliases: Vec<usize>,
no_alias_odds: Vec<f64>,
no_alias_odds: Vec<W>,
uniform_index: super::Uniform<usize>,
uniform_within_weight_sum: super::Uniform<f64>,
uniform_within_weight_sum: super::Uniform<W>,
}

impl AliasMethodWeightedIndex {
impl<W: AliasMethodWeight> AliasMethodWeightedIndex<W> {
#[allow(missing_docs)] // todo: add docs
pub fn new(weights: Vec<f64>) -> Result<Self, AliasMethodWeightedIndexError> {
pub fn new(weights: Vec<W>) -> Result<Self, AliasMethodWeightedIndexError> {
let n = weights.len();
if n == 0 {
return Err(AliasMethodWeightedIndexError::NoItem);
}

let max_weight_size = ::core::f64::MAX / n as f64;
if !weights.iter().all(|&w| 0_f64 <= w && w <= max_weight_size) {
let max_weight_size = W::try_from_usize_lossy(n)
.map(|n| W::MAX / n)
.unwrap_or(W::ZERO);
if !weights
.iter()
.all(|&w| W::ZERO <= w && w <= max_weight_size)
{
return Err(AliasMethodWeightedIndexError::InvalidWeight);
}

// The sum of weights will represent 100% of no alias odds.
let weight_sum = pairwise_sum_f64(weights.as_slice());
let weight_sum = pairwise_sum(weights.as_slice());
// Prevent floating point overflow due to rounding errors.
let weight_sum = weight_sum.min(::core::f64::MAX);
if weight_sum == 0_f64 {
let weight_sum = if weight_sum > W::MAX {
W::MAX
} else {
weight_sum
};
if weight_sum == W::ZERO {
return Err(AliasMethodWeightedIndexError::AllWeightsZero);
}

// `weight_sum` would have been zero (and we would have returned above) if `try_from_usize_lossy` had returned `None`, so this `unwrap` cannot panic.
let n_converted = W::try_from_usize_lossy(n).unwrap();

let mut no_alias_odds = weights;
for odds in no_alias_odds.iter_mut() {
*odds *= n as f64;
*odds *= n_converted;
// Prevent floating point overflow due to rounding errors.
*odds = odds.min(::core::f64::MAX);
*odds = if *odds > W::MAX { W::MAX } else { *odds };
}

/// This struct is designed to contain three data structures at once,
Expand Down Expand Up @@ -262,7 +276,7 @@ impl AliasMethodWeightedIndex {
// Prepare distributions for sampling. Creating them beforehand improves
// sampling performance.
let uniform_index = super::Uniform::new(0, n);
let uniform_within_weight_sum = super::Uniform::new(0_f64, weight_sum);
let uniform_within_weight_sum = super::Uniform::new(W::ZERO, weight_sum);

Ok(Self {
aliases: aliases.aliases,
Expand All @@ -273,27 +287,98 @@ impl AliasMethodWeightedIndex {
}
}

impl Distribution<usize> for AliasMethodWeightedIndex {
impl<W: AliasMethodWeight> Distribution<usize> for AliasMethodWeightedIndex<W> {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
let candidate = rng.sample(self.uniform_index);
if rng.sample(self.uniform_within_weight_sum) < self.no_alias_odds[candidate] {
if rng.sample(&self.uniform_within_weight_sum) < self.no_alias_odds[candidate] {
candidate
} else {
self.aliases[candidate]
}
}
}

fn pairwise_sum_f64(values: &[f64]) -> f64 {
/// In comparison to naive accumulation, the pairwise sum algorithm reduces
/// rounding errors when there are many floating point values.
fn pairwise_sum<T: AliasMethodWeight>(values: &[T]) -> T {
if values.len() <= 32 {
values.iter().sum()
values.iter().map(|x| *x).sum()
} else {
let mid = values.len() / 2;
let (a, b) = values.split_at(mid);
pairwise_sum_f64(a) + pairwise_sum_f64(b)
pairwise_sum(a) + pairwise_sum(b)
}
}

/// Numeric type usable as a weight by `AliasMethodWeightedIndex`.
///
/// A weight type must be copyable, comparable, uniformly sampleable, summable
/// from an iterator, and support `+`, `-`, `*` and `/` together with their
/// compound-assignment forms.
// `Sized` is not listed explicitly: it is already implied by `Copy`.
pub trait AliasMethodWeight:
    Copy
    + PartialOrd
    + SampleUniform
    + Sum
    + Add<Output = Self>
    + Sub<Output = Self>
    + Mul<Output = Self>
    + Div<Output = Self>
    + AddAssign
    + SubAssign
    + MulAssign
    + DivAssign
{
    /// Largest value of this type; used as the upper clamp for weights.
    const MAX: Self;
    /// The zero value of this type; used as the lower bound for weights.
    const ZERO: Self;

    /// Converts `n` into `Self`.
    ///
    /// Returns `None` when the implementation cannot represent `n`. The
    /// conversion is allowed to lose precision (the float implementations in
    /// this file use a plain `as` cast).
    fn try_from_usize_lossy(n: usize) -> Option<Self>;
}

/// Implements `AliasMethodWeight` for a primitive floating point type.
///
/// The argument must be the bare identifier of the type (e.g. `f64`), because
/// it is also used as the module name in `::core::<type>::MAX`.
macro_rules! impl_alias_method_weight_for_float {
    ($float:ident) => {
        impl AliasMethodWeight for $float {
            const ZERO: Self = 0.0;
            const MAX: Self = ::core::$float::MAX;

            fn try_from_usize_lossy(value: usize) -> Option<Self> {
                // An `as` cast from `usize` to a float always produces a
                // value (rounded if needed), so this never yields `None`.
                Some(value as Self)
            }
        }
    };
}

/// Implements `AliasMethodWeight` for a primitive integer type.
///
/// The argument must be the bare identifier of the type (e.g. `u32`), because
/// it is also used as the module name in `::core::<type>::MAX`.
macro_rules! impl_alias_method_weight_for_int {
    ($int:ident) => {
        impl AliasMethodWeight for $int {
            const ZERO: Self = 0;
            const MAX: Self = ::core::$int::MAX;

            fn try_from_usize_lossy(value: usize) -> Option<Self> {
                // `as` wraps silently on overflow, so the cast is accepted
                // only if the result is non-negative and converts back to
                // exactly the original `usize`.
                let cast = value as Self;
                if cast < Self::ZERO || cast as usize != value {
                    return None;
                }
                Some(cast)
            }
        }
    };
}

// Provide `AliasMethodWeight` for the primitive floating point types.
impl_alias_method_weight_for_float!(f64);
impl_alias_method_weight_for_float!(f32);
// Provide `AliasMethodWeight` for the unsigned integer types.
impl_alias_method_weight_for_int!(usize);
// The 128-bit implementations are compiled only when the `rustc_1_26` cfg is
// set and the target OS is not emscripten.
// NOTE(review): presumably because 128-bit integers are unavailable
// otherwise — confirm against the build script.
#[cfg(all(rustc_1_26, not(target_os = "emscripten")))]
impl_alias_method_weight_for_int!(u128);
impl_alias_method_weight_for_int!(u64);
impl_alias_method_weight_for_int!(u32);
impl_alias_method_weight_for_int!(u16);
impl_alias_method_weight_for_int!(u8);
// Provide `AliasMethodWeight` for the signed integer types; `i128` is gated
// by the same cfg condition as `u128`.
impl_alias_method_weight_for_int!(isize);
#[cfg(all(rustc_1_26, not(target_os = "emscripten")))]
impl_alias_method_weight_for_int!(i128);
impl_alias_method_weight_for_int!(i64);
impl_alias_method_weight_for_int!(i32);
impl_alias_method_weight_for_int!(i16);
impl_alias_method_weight_for_int!(i8);

#[cfg(test)]
mod test {
use super::*;
Expand Down Expand Up @@ -354,28 +439,106 @@ mod test {
}

#[test]
fn test_alias_method_weighted_index() {
fn test_alias_method_weighted_index_f32() {
test_alias_method_weighted_index(f32::into);

// Floating point special cases
assert_eq!(
AliasMethodWeightedIndex::new(vec![::core::f32::INFINITY])
.err()
.unwrap(),
AliasMethodWeightedIndexError::InvalidWeight
);
assert_eq!(
AliasMethodWeightedIndex::new(vec![-0_f32]).err().unwrap(),
AliasMethodWeightedIndexError::AllWeightsZero
);
assert_eq!(
AliasMethodWeightedIndex::new(vec![-1_f32]).err().unwrap(),
AliasMethodWeightedIndexError::InvalidWeight
);
assert_eq!(
AliasMethodWeightedIndex::new(vec![-::core::f32::INFINITY])
.err()
.unwrap(),
AliasMethodWeightedIndexError::InvalidWeight
);
assert_eq!(
AliasMethodWeightedIndex::new(vec![::core::f32::NAN])
.err()
.unwrap(),
AliasMethodWeightedIndexError::InvalidWeight
);
}

#[cfg(all(rustc_1_26, not(target_os = "emscripten")))]
#[test]
fn test_alias_method_weighted_index_u128() {
test_alias_method_weighted_index(|x: u128| x as f64);
}

#[cfg(all(rustc_1_26, not(target_os = "emscripten")))]
#[test]
fn test_alias_method_weighted_index_i128() {
test_alias_method_weighted_index(|x: i128| x as f64);

// Signed integer special cases
assert_eq!(
AliasMethodWeightedIndex::new(vec![-1_i128]).err().unwrap(),
AliasMethodWeightedIndexError::InvalidWeight
);
assert_eq!(
AliasMethodWeightedIndex::new(vec![::core::i128::MIN])
.err()
.unwrap(),
AliasMethodWeightedIndexError::InvalidWeight
);
}

#[test]
fn test_alias_method_weighted_index_u8() {
test_alias_method_weighted_index(u8::into);
}

#[test]
fn test_alias_method_weighted_index_i8() {
test_alias_method_weighted_index(i8::into);

// Signed integer special cases
assert_eq!(
AliasMethodWeightedIndex::new(vec![-1_i8]).err().unwrap(),
AliasMethodWeightedIndexError::InvalidWeight
);
assert_eq!(
AliasMethodWeightedIndex::new(vec![::core::i8::MIN])
.err()
.unwrap(),
AliasMethodWeightedIndexError::InvalidWeight
);
}

fn test_alias_method_weighted_index<W: AliasMethodWeight, F: Fn(W) -> f64>(w_to_f64: F) {
const NUM_WEIGHTS: usize = 10;
const ZERO_WEIGHT_INDEX: usize = 3;
const NUM_SAMPLES: u32 = 10000;
const NUM_SAMPLES: u32 = 15000;
let mut rng = ::test::rng(0x9c9fa0b0580a7031);

let weights = {
let mut weights = Vec::with_capacity(NUM_WEIGHTS);
let random_weight_distribution = ::distributions::Uniform::new_inclusive(
0_f64,
::core::f64::MAX / NUM_WEIGHTS as f64,
W::ZERO,
W::MAX / W::try_from_usize_lossy(NUM_WEIGHTS).unwrap(),
);
for _ in 0..NUM_WEIGHTS {
weights.push(rng.sample(random_weight_distribution));
weights.push(rng.sample(&random_weight_distribution));
}
weights[ZERO_WEIGHT_INDEX] = 0.0;
weights[ZERO_WEIGHT_INDEX] = W::ZERO;
weights
};
let weight_sum = weights.iter().sum::<f64>();
let weight_sum = weights.iter().map(|w| *w).sum::<W>();
let expected_counts = weights
.iter()
.map(|&w| w / weight_sum * NUM_SAMPLES as f64)
.map(|&w| w_to_f64(w) / w_to_f64(weight_sum) * NUM_SAMPLES as f64)
.collect::<Vec<f64>>();
let weight_distribution = AliasMethodWeightedIndex::new(weights).unwrap();

Expand All @@ -392,35 +555,17 @@ mod test {
}

assert_eq!(
AliasMethodWeightedIndex::new(vec![]).unwrap_err(),
AliasMethodWeightedIndex::<W>::new(vec![]).err().unwrap(),
AliasMethodWeightedIndexError::NoItem
);
assert_eq!(
AliasMethodWeightedIndex::new(vec![0.0]).unwrap_err(),
AliasMethodWeightedIndexError::AllWeightsZero
);
assert_eq!(
AliasMethodWeightedIndex::new(vec![-0.0]).unwrap_err(),
AliasMethodWeightedIndex::new(vec![W::ZERO]).err().unwrap(),
AliasMethodWeightedIndexError::AllWeightsZero
);
assert_eq!(
AliasMethodWeightedIndex::new(vec![::core::f64::INFINITY]).unwrap_err(),
AliasMethodWeightedIndexError::InvalidWeight
);
assert_eq!(
AliasMethodWeightedIndex::new(vec![::core::f64::MAX, ::core::f64::MAX]).unwrap_err(),
AliasMethodWeightedIndexError::InvalidWeight
);
assert_eq!(
AliasMethodWeightedIndex::new(vec![-1.0]).unwrap_err(),
AliasMethodWeightedIndexError::InvalidWeight
);
assert_eq!(
AliasMethodWeightedIndex::new(vec![-::core::f64::INFINITY]).unwrap_err(),
AliasMethodWeightedIndexError::InvalidWeight
);
assert_eq!(
AliasMethodWeightedIndex::new(vec![::core::f64::NAN]).unwrap_err(),
AliasMethodWeightedIndex::new(vec![W::MAX, W::MAX])
.err()
.unwrap(),
AliasMethodWeightedIndexError::InvalidWeight
);
}
Expand Down

0 comments on commit 14feb99

Please sign in to comment.