diff --git a/benches/distributions.rs b/benches/distributions.rs index cf6e4a6558f..0efa50d71d3 100644 --- a/benches/distributions.rs +++ b/benches/distributions.rs @@ -187,20 +187,10 @@ distr_int!(distr_weighted_u32, usize, WeightedIndex::new(&[1u32, 2, 3, 4, 12, 0, distr_int!(distr_weighted_f64, usize, WeightedIndex::new(&[1.0f64, 0.001, 1.0/3.0, 4.01, 0.0, 3.3, 22.0, 0.001]).unwrap()); distr_int!(distr_weighted_large_set, usize, WeightedIndex::new((0..10000).rev().chain(1..10001)).unwrap()); -distr_int!( - distr_weighted_alias_method, - usize, - AliasMethodWeightedIndex::new( - vec![1.0f64, 0.001, 1.0/3.0, 4.01, 0.0, 3.3, 22.0, 0.001] - ).unwrap() -); -distr_int!( - distr_weighted_alias_method_large_set, - usize, - AliasMethodWeightedIndex::new( - (0..10000).rev().chain(1..10001).map(|x| x as f64).collect() - ).unwrap() -); +distr_int!(distr_weighted_alias_method_i8, usize, AliasMethodWeightedIndex::new(vec![1i8, 2, 3, 4, 12, 0, 2, 1]).unwrap()); +distr_int!(distr_weighted_alias_method_u32, usize, AliasMethodWeightedIndex::new(vec![1u32, 2, 3, 4, 12, 0, 2, 1]).unwrap()); +distr_int!(distr_weighted_alias_method_f64, usize, AliasMethodWeightedIndex::new(vec![1.0f64, 0.001, 1.0/3.0, 4.01, 0.0, 3.3, 22.0, 0.001]).unwrap()); +distr_int!(distr_weighted_alias_method_large_set, usize, AliasMethodWeightedIndex::new((0..10000).rev().chain(1..10001).collect()).unwrap()); // construct and sample from a range macro_rules! gen_range_int { diff --git a/src/distributions/weighted.rs b/src/distributions/weighted.rs index 96c2fcc747d..e823ad8b0a4 100644 --- a/src/distributions/weighted.rs +++ b/src/distributions/weighted.rs @@ -11,6 +11,8 @@ use distributions::Distribution; use distributions::uniform::{UniformSampler, SampleUniform, SampleBorrow}; use ::core::cmp::PartialOrd; use core::fmt; +use core::iter::Sum; +use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Sub, SubAssign}; // Note that this whole module is only imported if feature="alloc" is enabled. 
#[cfg(not(feature="std"))] use alloc::vec::Vec; @@ -131,40 +133,52 @@ impl Distribution for WeightedIndex where } #[allow(missing_docs)] // todo: add docs -#[derive(Debug, Clone)] -pub struct AliasMethodWeightedIndex { +#[allow(missing_debug_implementations)] // todo: why does `#[derive(Debug)]` not work? +pub struct AliasMethodWeightedIndex { aliases: Vec, - no_alias_odds: Vec, + no_alias_odds: Vec, uniform_index: super::Uniform, - uniform_within_weight_sum: super::Uniform, + uniform_within_weight_sum: super::Uniform, } -impl AliasMethodWeightedIndex { +impl AliasMethodWeightedIndex { #[allow(missing_docs)] // todo: add docs - pub fn new(weights: Vec) -> Result { + pub fn new(weights: Vec) -> Result { let n = weights.len(); if n == 0 { return Err(AliasMethodWeightedIndexError::NoItem); } - let max_weight_size = ::core::f64::MAX / n as f64; - if !weights.iter().all(|&w| 0_f64 <= w && w <= max_weight_size) { + let max_weight_size = W::try_from_usize_lossy(n) + .map(|n| W::MAX / n) + .unwrap_or(W::ZERO); + if !weights + .iter() + .all(|&w| W::ZERO <= w && w <= max_weight_size) + { return Err(AliasMethodWeightedIndexError::InvalidWeight); } // The sum of weights will represent 100% of no alias odds. - let weight_sum = pairwise_sum_f64(weights.as_slice()); + let weight_sum = pairwise_sum(weights.as_slice()); // Prevent floating point overflow due to rounding errors. - let weight_sum = weight_sum.min(::core::f64::MAX); - if weight_sum == 0_f64 { + let weight_sum = if weight_sum > W::MAX { + W::MAX + } else { + weight_sum + }; + if weight_sum == W::ZERO { return Err(AliasMethodWeightedIndexError::AllWeightsZero); } + // `try_from_usize_lossy` cannot fail here: a failed conversion makes `max_weight_size` zero, so `weight_sum` would have been zero and `new` would already have returned an error. + let n_converted = W::try_from_usize_lossy(n).unwrap(); + let mut no_alias_odds = weights; for odds in no_alias_odds.iter_mut() { - *odds *= n as f64; + *odds *= n_converted; // Prevent floating point overflow due to rounding errors. 
- *odds = odds.min(::core::f64::MAX); + *odds = if *odds > W::MAX { W::MAX } else { *odds }; } /// This struct is designed to contain three data structures at once, @@ -262,7 +276,7 @@ impl AliasMethodWeightedIndex { // Prepare distributions for sampling. Creating them beforehand improves // sampling performance. let uniform_index = super::Uniform::new(0, n); - let uniform_within_weight_sum = super::Uniform::new(0_f64, weight_sum); + let uniform_within_weight_sum = super::Uniform::new(W::ZERO, weight_sum); Ok(Self { aliases: aliases.aliases, @@ -273,10 +287,10 @@ impl AliasMethodWeightedIndex { } } -impl Distribution for AliasMethodWeightedIndex { +impl Distribution for AliasMethodWeightedIndex { fn sample(&self, rng: &mut R) -> usize { let candidate = rng.sample(self.uniform_index); - if rng.sample(self.uniform_within_weight_sum) < self.no_alias_odds[candidate] { + if rng.sample(&self.uniform_within_weight_sum) < self.no_alias_odds[candidate] { candidate } else { self.aliases[candidate] @@ -284,16 +298,87 @@ impl Distribution for AliasMethodWeightedIndex { } } -fn pairwise_sum_f64(values: &[f64]) -> f64 { +/// In comparison to naive accumulation, the pairwise sum algorithm reduces +/// rounding errors when there are many floating point values. +fn pairwise_sum(values: &[T]) -> T { if values.len() <= 32 { - values.iter().sum() + values.iter().map(|x| *x).sum() } else { let mid = values.len() / 2; let (a, b) = values.split_at(mid); - pairwise_sum_f64(a) + pairwise_sum_f64(b) + pairwise_sum(a) + pairwise_sum(b) } } +pub trait AliasMethodWeight: + Sized + + Copy + + SampleUniform + + PartialOrd + + Add + + AddAssign + + Sub + + SubAssign + + Mul + + MulAssign + + Div + + DivAssign + + Sum +{ + const MAX: Self; + const ZERO: Self; + + fn try_from_usize_lossy(n: usize) -> Option; +} + +macro_rules! 
impl_alias_method_weight_for_float { + ($T: ident) => { + impl AliasMethodWeight for $T { + const MAX: Self = ::core::$T::MAX; + const ZERO: Self = 0.0; + + fn try_from_usize_lossy(n: usize) -> Option { + Some(n as $T) + } + } + }; +} + +macro_rules! impl_alias_method_weight_for_int { + ($T: ident) => { + impl AliasMethodWeight for $T { + const MAX: Self = ::core::$T::MAX; + const ZERO: Self = 0; + + fn try_from_usize_lossy(n: usize) -> Option { + let n_converted = n as Self; + if n_converted >= Self::ZERO && n_converted as usize == n { + Some(n_converted) + } else { + None + } + } + } + }; +} + +impl_alias_method_weight_for_float!(f64); +impl_alias_method_weight_for_float!(f32); +impl_alias_method_weight_for_int!(usize); +#[cfg(all(rustc_1_26, not(target_os = "emscripten")))] +impl_alias_method_weight_for_int!(u128); +impl_alias_method_weight_for_int!(u64); +impl_alias_method_weight_for_int!(u32); +impl_alias_method_weight_for_int!(u16); +impl_alias_method_weight_for_int!(u8); +impl_alias_method_weight_for_int!(isize); +#[cfg(all(rustc_1_26, not(target_os = "emscripten")))] +impl_alias_method_weight_for_int!(i128); +impl_alias_method_weight_for_int!(i64); +impl_alias_method_weight_for_int!(i32); +impl_alias_method_weight_for_int!(i16); +impl_alias_method_weight_for_int!(i8); + #[cfg(test)] mod test { use super::*; @@ -354,28 +439,106 @@ mod test { } #[test] - fn test_alias_method_weighted_index() { + fn test_alias_method_weighted_index_f32() { + test_alias_method_weighted_index(f32::into); + + // Floating point special cases + assert_eq!( + AliasMethodWeightedIndex::new(vec![::core::f32::INFINITY]) + .err() + .unwrap(), + AliasMethodWeightedIndexError::InvalidWeight + ); + assert_eq!( + AliasMethodWeightedIndex::new(vec![-0_f32]).err().unwrap(), + AliasMethodWeightedIndexError::AllWeightsZero + ); + assert_eq!( + AliasMethodWeightedIndex::new(vec![-1_f32]).err().unwrap(), + AliasMethodWeightedIndexError::InvalidWeight + ); + assert_eq!( + 
AliasMethodWeightedIndex::new(vec![-::core::f32::INFINITY]) + .err() + .unwrap(), + AliasMethodWeightedIndexError::InvalidWeight + ); + assert_eq!( + AliasMethodWeightedIndex::new(vec![::core::f32::NAN]) + .err() + .unwrap(), + AliasMethodWeightedIndexError::InvalidWeight + ); + } + + #[cfg(all(rustc_1_26, not(target_os = "emscripten")))] + #[test] + fn test_alias_method_weighted_index_u128() { + test_alias_method_weighted_index(|x: u128| x as f64); + } + + #[cfg(all(rustc_1_26, not(target_os = "emscripten")))] + #[test] + fn test_alias_method_weighted_index_i128() { + test_alias_method_weighted_index(|x: i128| x as f64); + + // Signed integer special cases + assert_eq!( + AliasMethodWeightedIndex::new(vec![-1_i128]).err().unwrap(), + AliasMethodWeightedIndexError::InvalidWeight + ); + assert_eq!( + AliasMethodWeightedIndex::new(vec![::core::i128::MIN]) + .err() + .unwrap(), + AliasMethodWeightedIndexError::InvalidWeight + ); + } + + #[test] + fn test_alias_method_weighted_index_u8() { + test_alias_method_weighted_index(u8::into); + } + + #[test] + fn test_alias_method_weighted_index_i8() { + test_alias_method_weighted_index(i8::into); + + // Signed integer special cases + assert_eq!( + AliasMethodWeightedIndex::new(vec![-1_i8]).err().unwrap(), + AliasMethodWeightedIndexError::InvalidWeight + ); + assert_eq!( + AliasMethodWeightedIndex::new(vec![::core::i8::MIN]) + .err() + .unwrap(), + AliasMethodWeightedIndexError::InvalidWeight + ); + } + + fn test_alias_method_weighted_index f64>(w_to_f64: F) { const NUM_WEIGHTS: usize = 10; const ZERO_WEIGHT_INDEX: usize = 3; - const NUM_SAMPLES: u32 = 10000; + const NUM_SAMPLES: u32 = 15000; let mut rng = ::test::rng(0x9c9fa0b0580a7031); let weights = { let mut weights = Vec::with_capacity(NUM_WEIGHTS); let random_weight_distribution = ::distributions::Uniform::new_inclusive( - 0_f64, - ::core::f64::MAX / NUM_WEIGHTS as f64, + W::ZERO, + W::MAX / W::try_from_usize_lossy(NUM_WEIGHTS).unwrap(), ); for _ in 0..NUM_WEIGHTS { - 
weights.push(rng.sample(random_weight_distribution)); + weights.push(rng.sample(&random_weight_distribution)); } - weights[ZERO_WEIGHT_INDEX] = 0.0; + weights[ZERO_WEIGHT_INDEX] = W::ZERO; weights }; - let weight_sum = weights.iter().sum::(); + let weight_sum = weights.iter().map(|w| *w).sum::(); let expected_counts = weights .iter() - .map(|&w| w / weight_sum * NUM_SAMPLES as f64) + .map(|&w| w_to_f64(w) / w_to_f64(weight_sum) * NUM_SAMPLES as f64) .collect::>(); let weight_distribution = AliasMethodWeightedIndex::new(weights).unwrap(); @@ -392,35 +555,17 @@ mod test { } assert_eq!( - AliasMethodWeightedIndex::new(vec![]).unwrap_err(), + AliasMethodWeightedIndex::::new(vec![]).err().unwrap(), AliasMethodWeightedIndexError::NoItem ); assert_eq!( - AliasMethodWeightedIndex::new(vec![0.0]).unwrap_err(), - AliasMethodWeightedIndexError::AllWeightsZero - ); - assert_eq!( - AliasMethodWeightedIndex::new(vec![-0.0]).unwrap_err(), + AliasMethodWeightedIndex::new(vec![W::ZERO]).err().unwrap(), AliasMethodWeightedIndexError::AllWeightsZero ); assert_eq!( - AliasMethodWeightedIndex::new(vec![::core::f64::INFINITY]).unwrap_err(), - AliasMethodWeightedIndexError::InvalidWeight - ); - assert_eq!( - AliasMethodWeightedIndex::new(vec![::core::f64::MAX, ::core::f64::MAX]).unwrap_err(), - AliasMethodWeightedIndexError::InvalidWeight - ); - assert_eq!( - AliasMethodWeightedIndex::new(vec![-1.0]).unwrap_err(), - AliasMethodWeightedIndexError::InvalidWeight - ); - assert_eq!( - AliasMethodWeightedIndex::new(vec![-::core::f64::INFINITY]).unwrap_err(), - AliasMethodWeightedIndexError::InvalidWeight - ); - assert_eq!( - AliasMethodWeightedIndex::new(vec![::core::f64::NAN]).unwrap_err(), + AliasMethodWeightedIndex::new(vec![W::MAX, W::MAX]) + .err() + .unwrap(), AliasMethodWeightedIndexError::InvalidWeight ); }