From c2bed15554938551218cf54b3ba802b2b46d0e4b Mon Sep 17 00:00:00 2001 From: zroug <37004975+zroug@users.noreply.github.com> Date: Sat, 23 Feb 2019 13:25:30 +0100 Subject: [PATCH 01/10] Added an implementation of alias method for weighted indices --- src/distributions/mod.rs | 5 +- src/distributions/weighted.rs | 148 ++++++++++++++++++++++++++++++++++ 2 files changed, 152 insertions(+), 1 deletion(-) diff --git a/src/distributions/mod.rs b/src/distributions/mod.rs index 6e2d6c7bad2..06b578e989f 100644 --- a/src/distributions/mod.rs +++ b/src/distributions/mod.rs @@ -181,7 +181,10 @@ pub use self::other::Alphanumeric; #[doc(inline)] pub use self::uniform::Uniform; pub use self::float::{OpenClosed01, Open01}; pub use self::bernoulli::Bernoulli; -#[cfg(feature="alloc")] pub use self::weighted::{WeightedIndex, WeightedError}; +#[cfg(feature = "alloc")] +pub use self::weighted::{ + AliasMethodWeightedIndex, AliasMethodWeightedIndexError, WeightedError, WeightedIndex, +}; #[cfg(feature="std")] pub use self::unit_sphere::UnitSphereSurface; #[cfg(feature="std")] pub use self::unit_circle::UnitCircle; #[cfg(feature="std")] pub use self::gamma::{Gamma, ChiSquared, FisherF, diff --git a/src/distributions/weighted.rs b/src/distributions/weighted.rs index d7499596e39..667adfe64c8 100644 --- a/src/distributions/weighted.rs +++ b/src/distributions/weighted.rs @@ -13,7 +13,9 @@ use ::core::cmp::PartialOrd; use core::fmt; // Note that this whole module is only imported if feature="alloc" is enabled. +#[cfg(not(feature = "std"))] use alloc::collections::VecDeque; #[cfg(not(feature="std"))] use alloc::vec::Vec; +#[cfg(feature = "std")] use std::collections::VecDeque; /// A distribution using weighted sampling to pick a discretely selected /// item. @@ -130,6 +132,116 @@ impl Distribution for WeightedIndex where } } +#[allow(missing_docs)] // todo: add docs +#[derive(Debug, Clone)] +pub struct AliasMethodWeightedIndex { + aliases: Vec, + no_alias_odds: Vec, + uniform_index: super::Uniform, +} + +impl AliasMethodWeightedIndex { + #[allow(missing_docs)] // todo: add docs + pub fn new(weights: Vec) -> Result { + if weights.is_empty() { + return Err(AliasMethodWeightedIndexError::NoItem); + } + if !weights.iter().all(|&w| w >= 0.0) { + return Err(AliasMethodWeightedIndexError::InvalidWeight); + } + + let n = weights.len(); + let weight_sum = pairwise_sum_f64(weights.as_slice()); + if weight_sum.is_infinite() { + return Err(AliasMethodWeightedIndexError::WeightSumToBig); + } + + let weight_scale = n as f64 / weight_sum; + if weight_scale.is_infinite() { + return Err(AliasMethodWeightedIndexError::WeightSumToSmall); + } + + let mut no_alias_odds = weights; + for odds in no_alias_odds.iter_mut() { + *odds *= weight_scale; + } + + // Split indices into indices with small weights and indices with big weights. + // Instead of two `Vec` with unknown capacity we use a single `VecDeque` with + // known capacity. Front represents smalls and back represents bigs. We also + // need to keep track of the size of each virtual `Vec`. + let mut smalls_bigs = VecDeque::with_capacity(n); + let mut smalls_len = 0_usize; + let mut bigs_len = 0_usize; + for (index, &odds) in no_alias_odds.iter().enumerate() { + if odds < 1.0 { + smalls_bigs.push_front(index); + smalls_len += 1; + } else { + smalls_bigs.push_back(index); + bigs_len += 1; + } + } + + let mut aliases = vec![0; n]; + while smalls_len > 0 && bigs_len > 0 { + let s = smalls_bigs.pop_front().unwrap(); + smalls_len -= 1; + let b = smalls_bigs.pop_back().unwrap(); + bigs_len -= 1; + + aliases[s] = b; + no_alias_odds[b] = no_alias_odds[b] - 1.0 + no_alias_odds[s]; + + if no_alias_odds[b] < 1.0 { + smalls_bigs.push_front(b); + smalls_len += 1; + } else { + smalls_bigs.push_back(b); + bigs_len += 1; + } + } + + // The remaining indices should have no alias odds of about 1. This is due to + // numeric accuracy. Otherwise they would be exactly 1. + for index in smalls_bigs.into_iter() { + // Because p = 1 we don't need to set an alias. It will never be accessed. + no_alias_odds[index] = 1.0; + } + + // Prepare a distribution to sample random indices. Creating it beforehand + // improves sampling performance. + let uniform_index = super::Uniform::new(0, n); + + Ok(Self { + aliases, + no_alias_odds, + uniform_index, + }) + } +} + +impl Distribution for AliasMethodWeightedIndex { + fn sample(&self, rng: &mut R) -> usize { + let candidate = rng.sample(self.uniform_index); + if rng.sample::(super::Standard) < self.no_alias_odds[candidate] { + candidate + } else { + self.aliases[candidate] + } + } +} + +fn pairwise_sum_f64(values: &[f64]) -> f64 { + if values.len() <= 32 { + values.iter().sum() + } else { + let mid = values.len() / 2; + let (a, b) = values.split_at(mid); + pairwise_sum_f64(a) + pairwise_sum_f64(b) + } +} + #[cfg(test)] mod test { use super::*; @@ -228,3 +340,39 @@ impl fmt::Display for WeightedError { write!(f, "{}", self.msg()) } } + +#[allow(missing_docs)] // todo: add docs +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum AliasMethodWeightedIndexError { + NoItem, + InvalidWeight, + WeightSumToSmall, + WeightSumToBig, +} + +impl AliasMethodWeightedIndexError { + fn msg(&self) -> &str { + match *self { + AliasMethodWeightedIndexError::NoItem => "No items found.", + AliasMethodWeightedIndexError::InvalidWeight => "An item has an invalid weight.", + AliasMethodWeightedIndexError::WeightSumToSmall => "The sum of weights is to small.", + AliasMethodWeightedIndexError::WeightSumToBig => "The sum of weights is to big.", + } + } +} + +impl fmt::Display for AliasMethodWeightedIndexError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str(self.msg()) + } +} + +#[cfg(feature = "std")] +impl ::std::error::Error for AliasMethodWeightedIndexError { + fn description(&self) -> &str { + self.msg() + } + fn cause(&self) -> Option<&::std::error::Error> { + None + } +} From 1feb63317977f2aa90ff34b8726b2d0ff8a0a7ec Mon Sep 17 00:00:00 2001 From: zroug <37004975+zroug@users.noreply.github.com> Date: Sat, 23 Feb 2019 13:25:30 +0100 Subject: [PATCH 02/10] Added tests for AliasMethodWeightedIndex --- benches/distributions.rs | 15 ++++++++ src/distributions/weighted.rs | 68 +++++++++++++++++++++++++++++++++++ 2 files changed, 83 insertions(+) diff --git a/benches/distributions.rs b/benches/distributions.rs index 069a82856a5..cf6e4a6558f 100644 --- a/benches/distributions.rs +++ b/benches/distributions.rs @@ -187,6 +187,21 @@ distr_int!(distr_weighted_u32, usize, WeightedIndex::new(&[1u32, 2, 3, 4, 12, 0, distr_int!(distr_weighted_f64, usize, WeightedIndex::new(&[1.0f64, 0.001, 1.0/3.0, 4.01, 0.0, 3.3, 22.0, 0.001]).unwrap()); distr_int!(distr_weighted_large_set, usize, WeightedIndex::new((0..10000).rev().chain(1..10001)).unwrap()); +distr_int!( + distr_weighted_alias_method, + usize, + AliasMethodWeightedIndex::new( + vec![1.0f64, 0.001, 1.0/3.0, 4.01, 0.0, 3.3, 22.0, 0.001] + ).unwrap() +); +distr_int!( + distr_weighted_alias_method_large_set, + usize, + AliasMethodWeightedIndex::new( + (0..10000).rev().chain(1..10001).map(|x| x as f64).collect() + ).unwrap() +); + // construct and sample from a range macro_rules! gen_range_int { ($fnn:ident, $ty:ident, $low:expr, $high:expr) => { diff --git a/src/distributions/weighted.rs b/src/distributions/weighted.rs index 667adfe64c8..97e6ad24556 100644 --- a/src/distributions/weighted.rs +++ b/src/distributions/weighted.rs @@ -300,6 +300,74 @@ mod test { assert_eq!(WeightedIndex::new(&[-10, 20, 1, 30]).unwrap_err(), WeightedError::NegativeWeight); assert_eq!(WeightedIndex::new(&[-10]).unwrap_err(), WeightedError::NegativeWeight); } + + #[test] + fn test_alias_method_weighted_index() { + const NUM_WEIGHTS: usize = 10; + const ZERO_WEIGHT_INDEX: usize = 3; + const NUM_SAMPLES: u32 = 10000; + let mut rng = ::test::rng(0x9c9fa0b0580a7031); + + let weights = { + let mut weights = Vec::with_capacity(NUM_WEIGHTS); + for _ in 0..NUM_WEIGHTS { + weights.push(rng.sample::(::distributions::Standard)); + } + weights[ZERO_WEIGHT_INDEX] = 0.0; + weights + }; + let weight_sum = weights.iter().sum::(); + let expected_counts = weights + .iter() + .map(|&w| w / weight_sum * NUM_SAMPLES as f64) + .collect::>(); + let weight_distribution = AliasMethodWeightedIndex::new(weights).unwrap(); + + let mut counts = vec![0_usize; NUM_WEIGHTS]; + for _ in 0..NUM_SAMPLES { + counts[rng.sample(&weight_distribution)] += 1; + } + + assert_eq!(counts[ZERO_WEIGHT_INDEX], 0); + for (count, expected_count) in counts.into_iter().zip(expected_counts) { + let difference = (count as f64 - expected_count).abs(); + let max_allowed_difference = NUM_SAMPLES as f64 / NUM_WEIGHTS as f64 * 0.1; + assert!(difference <= max_allowed_difference); + } + + assert_eq!( + AliasMethodWeightedIndex::new(vec![]).unwrap_err(), + AliasMethodWeightedIndexError::NoItem + ); + assert_eq!( + AliasMethodWeightedIndex::new(vec![0.0]).unwrap_err(), + AliasMethodWeightedIndexError::WeightSumToSmall + ); + assert_eq!( + AliasMethodWeightedIndex::new(vec![-0.0]).unwrap_err(), + AliasMethodWeightedIndexError::WeightSumToSmall + ); + assert_eq!( + AliasMethodWeightedIndex::new(vec![::core::f64::INFINITY]).unwrap_err(), + AliasMethodWeightedIndexError::WeightSumToBig + ); + assert_eq!( + AliasMethodWeightedIndex::new(vec![::core::f64::MAX, ::core::f64::MAX]).unwrap_err(), + AliasMethodWeightedIndexError::WeightSumToBig + ); + assert_eq!( + AliasMethodWeightedIndex::new(vec![-1.0]).unwrap_err(), + AliasMethodWeightedIndexError::InvalidWeight + ); + assert_eq!( + AliasMethodWeightedIndex::new(vec![-::core::f64::INFINITY]).unwrap_err(), + AliasMethodWeightedIndexError::InvalidWeight + ); + assert_eq!( + AliasMethodWeightedIndex::new(vec![::core::f64::NAN]).unwrap_err(), + AliasMethodWeightedIndexError::InvalidWeight + ); + } } /// Error type returned from `WeightedIndex::new`. From 002a001e901cfbf7611d40852337c029d4a6f5a9 Mon Sep 17 00:00:00 2001 From: zroug <37004975+zroug@users.noreply.github.com> Date: Sat, 23 Feb 2019 13:25:30 +0100 Subject: [PATCH 03/10] Get rid of the extra VecDeque during creation of AliasMethodWeightedIndex --- src/distributions/weighted.rs | 103 +++++++++++++++++++++++++--------- 1 file changed, 75 insertions(+), 28 deletions(-) diff --git a/src/distributions/weighted.rs b/src/distributions/weighted.rs index 97e6ad24556..06637a7e987 100644 --- a/src/distributions/weighted.rs +++ b/src/distributions/weighted.rs @@ -13,9 +13,7 @@ use ::core::cmp::PartialOrd; use core::fmt; // Note that this whole module is only imported if feature="alloc" is enabled. -#[cfg(not(feature = "std"))] use alloc::collections::VecDeque; #[cfg(not(feature="std"))] use alloc::vec::Vec; -#[cfg(feature = "std")] use std::collections::VecDeque; /// A distribution using weighted sampling to pick a discretely selected /// item. @@ -166,47 +164,96 @@ impl AliasMethodWeightedIndex { *odds *= weight_scale; } - // Split indices into indices with small weights and indices with big weights. - // Instead of two `Vec` with unknown capacity we use a single `VecDeque` with - // known capacity. Front represents smalls and back represents bigs. We also - // need to keep track of the size of each virtual `Vec`. - let mut smalls_bigs = VecDeque::with_capacity(n); - let mut smalls_len = 0_usize; - let mut bigs_len = 0_usize; + /// This struct is designed to contain three data structures at once, + /// sharing the same memory. More precisely it contains two + /// linked lists and an alias map, which will be the output of this + /// method. To keep the three data structures from getting in + /// each other's way, it must be ensured that a single index is only + /// ever in one of them at the same time. + struct Aliases { + aliases: Vec, + smalls_head: usize, + bigs_head: usize, + } + + impl Aliases { + fn new(size: usize) -> Self { + Aliases { + aliases: vec![0; size], + smalls_head: ::core::usize::MAX, + bigs_head: ::core::usize::MAX, + } + } + + fn push_small(&mut self, idx: usize) { + self.aliases[idx] = self.smalls_head; + self.smalls_head = idx; + } + + fn push_big(&mut self, idx: usize) { + self.aliases[idx] = self.bigs_head; + self.bigs_head = idx; + } + + fn pop_small(&mut self) -> usize { + let popped = self.smalls_head; + self.smalls_head = self.aliases[popped]; + popped + } + + fn pop_big(&mut self) -> usize { + let popped = self.bigs_head; + self.bigs_head = self.aliases[popped]; + popped + } + + fn smalls_is_empty(&self) -> bool { + self.smalls_head == ::core::usize::MAX + } + + fn bigs_is_empty(&self) -> bool { + self.bigs_head == ::core::usize::MAX + } + + fn set_alias(&mut self, idx: usize, alias: usize) { + self.aliases[idx] = alias; + } + } + + let mut aliases = Aliases::new(n); + + // Split indices into those with small weights and those with big weights. for (index, &odds) in no_alias_odds.iter().enumerate() { if odds < 1.0 { - smalls_bigs.push_front(index); - smalls_len += 1; + aliases.push_small(index); } else { - smalls_bigs.push_back(index); - bigs_len += 1; + aliases.push_big(index); } } - let mut aliases = vec![0; n]; - while smalls_len > 0 && bigs_len > 0 { - let s = smalls_bigs.pop_front().unwrap(); - smalls_len -= 1; - let b = smalls_bigs.pop_back().unwrap(); - bigs_len -= 1; + // Build the alias map by finding an alias with big weight for each index with + // small weight. + while !aliases.smalls_is_empty() && !aliases.bigs_is_empty() { + let s = aliases.pop_small(); + let b = aliases.pop_big(); - aliases[s] = b; + aliases.set_alias(s, b); no_alias_odds[b] = no_alias_odds[b] - 1.0 + no_alias_odds[s]; if no_alias_odds[b] < 1.0 { - smalls_bigs.push_front(b); - smalls_len += 1; + aliases.push_small(b); } else { - smalls_bigs.push_back(b); - bigs_len += 1; + aliases.push_big(b); } } // The remaining indices should have no alias odds of about 1. This is due to // numeric accuracy. Otherwise they would be exactly 1. - for index in smalls_bigs.into_iter() { - // Because p = 1 we don't need to set an alias. It will never be accessed. - no_alias_odds[index] = 1.0; + while !aliases.smalls_is_empty() { + no_alias_odds[aliases.pop_small()] = 1.0; + } + while !aliases.bigs_is_empty() { + no_alias_odds[aliases.pop_big()] = 1.0; } // Prepare a distribution to sample random indices. Creating it beforehand @@ -214,7 +261,7 @@ impl AliasMethodWeightedIndex { let uniform_index = super::Uniform::new(0, n); Ok(Self { - aliases, + aliases: aliases.aliases, no_alias_odds, uniform_index, }) From ea8397492171390853ae8de454d069bfb48b55ff Mon Sep 17 00:00:00 2001 From: zroug <37004975+zroug@users.noreply.github.com> Date: Sat, 23 Feb 2019 13:25:30 +0100 Subject: [PATCH 04/10] Made implementation details of AliasMethodWeightedIndex more generic --- src/distributions/weighted.rs | 67 +++++++++++++++++++---------------- 1 file changed, 37 insertions(+), 30 deletions(-) diff --git a/src/distributions/weighted.rs b/src/distributions/weighted.rs index 06637a7e987..96c2fcc747d 100644 --- a/src/distributions/weighted.rs +++ b/src/distributions/weighted.rs @@ -136,32 +136,35 @@ pub struct AliasMethodWeightedIndex { aliases: Vec, no_alias_odds: Vec, uniform_index: super::Uniform, + uniform_within_weight_sum: super::Uniform, } impl AliasMethodWeightedIndex { #[allow(missing_docs)] // todo: add docs pub fn new(weights: Vec) -> Result { - if weights.is_empty() { + let n = weights.len(); + if n == 0 { return Err(AliasMethodWeightedIndexError::NoItem); } - if !weights.iter().all(|&w| w >= 0.0) { + + let max_weight_size = ::core::f64::MAX / n as f64; + if !weights.iter().all(|&w| 0_f64 <= w && w <= max_weight_size) { return Err(AliasMethodWeightedIndexError::InvalidWeight); } - let n = weights.len(); + // The sum of weights will represent 100% of no alias odds. let weight_sum = pairwise_sum_f64(weights.as_slice()); - if weight_sum.is_infinite() { - return Err(AliasMethodWeightedIndexError::WeightSumToBig); - } - - let weight_scale = n as f64 / weight_sum; - if weight_scale.is_infinite() { - return Err(AliasMethodWeightedIndexError::WeightSumToSmall); + // Prevent floating point overflow due to rounding errors. + let weight_sum = weight_sum.min(::core::f64::MAX); + if weight_sum == 0_f64 { + return Err(AliasMethodWeightedIndexError::AllWeightsZero); } let mut no_alias_odds = weights; for odds in no_alias_odds.iter_mut() { - *odds *= weight_scale; + *odds *= n as f64; + // Prevent floating point overflow due to rounding errors. + *odds = odds.min(::core::f64::MAX); } /// This struct is designed to contain three data structures at once, @@ -224,7 +227,7 @@ impl AliasMethodWeightedIndex { // Split indices into those with small weights and those with big weights. for (index, &odds) in no_alias_odds.iter().enumerate() { - if odds < 1.0 { + if odds < weight_sum { aliases.push_small(index); } else { aliases.push_big(index); @@ -238,32 +241,34 @@ impl AliasMethodWeightedIndex { let b = aliases.pop_big(); aliases.set_alias(s, b); - no_alias_odds[b] = no_alias_odds[b] - 1.0 + no_alias_odds[s]; + no_alias_odds[b] = no_alias_odds[b] - weight_sum + no_alias_odds[s]; - if no_alias_odds[b] < 1.0 { + if no_alias_odds[b] < weight_sum { aliases.push_small(b); } else { aliases.push_big(b); } } - // The remaining indices should have no alias odds of about 1. This is due to - // numeric accuracy. Otherwise they would be exactly 1. + // The remaining indices should have no alias odds of about 100%. This is due to + // numeric accuracy. Otherwise they would be exactly 100%. while !aliases.smalls_is_empty() { - no_alias_odds[aliases.pop_small()] = 1.0; + no_alias_odds[aliases.pop_small()] = weight_sum; } while !aliases.bigs_is_empty() { - no_alias_odds[aliases.pop_big()] = 1.0; + no_alias_odds[aliases.pop_big()] = weight_sum; } - // Prepare a distribution to sample random indices. Creating it beforehand - // improves sampling performance. + // Prepare distributions for sampling. Creating them beforehand improves + // sampling performance. let uniform_index = super::Uniform::new(0, n); + let uniform_within_weight_sum = super::Uniform::new(0_f64, weight_sum); Ok(Self { aliases: aliases.aliases, no_alias_odds, uniform_index, + uniform_within_weight_sum, }) } } @@ -271,7 +276,7 @@ impl AliasMethodWeightedIndex { impl Distribution for AliasMethodWeightedIndex { fn sample(&self, rng: &mut R) -> usize { let candidate = rng.sample(self.uniform_index); - if rng.sample::(super::Standard) < self.no_alias_odds[candidate] { + if rng.sample(self.uniform_within_weight_sum) < self.no_alias_odds[candidate] { candidate } else { self.aliases[candidate] @@ -357,8 +362,12 @@ mod test { let weights = { let mut weights = Vec::with_capacity(NUM_WEIGHTS); + let random_weight_distribution = ::distributions::Uniform::new_inclusive( + 0_f64, + ::core::f64::MAX / NUM_WEIGHTS as f64, + ); for _ in 0..NUM_WEIGHTS { - weights.push(rng.sample::(::distributions::Standard)); + weights.push(rng.sample(random_weight_distribution)); } weights[ZERO_WEIGHT_INDEX] = 0.0; weights @@ -388,19 +397,19 @@ mod test { ); assert_eq!( AliasMethodWeightedIndex::new(vec![0.0]).unwrap_err(), - AliasMethodWeightedIndexError::WeightSumToSmall + AliasMethodWeightedIndexError::AllWeightsZero ); assert_eq!( AliasMethodWeightedIndex::new(vec![-0.0]).unwrap_err(), - AliasMethodWeightedIndexError::WeightSumToSmall + AliasMethodWeightedIndexError::AllWeightsZero ); assert_eq!( AliasMethodWeightedIndex::new(vec![::core::f64::INFINITY]).unwrap_err(), - AliasMethodWeightedIndexError::WeightSumToBig + AliasMethodWeightedIndexError::InvalidWeight ); assert_eq!( AliasMethodWeightedIndex::new(vec![::core::f64::MAX, ::core::f64::MAX]).unwrap_err(), - AliasMethodWeightedIndexError::WeightSumToBig + AliasMethodWeightedIndexError::InvalidWeight ); assert_eq!( AliasMethodWeightedIndex::new(vec![-1.0]).unwrap_err(), @@ -461,8 +470,7 @@ impl fmt::Display for WeightedError { pub enum AliasMethodWeightedIndexError { NoItem, InvalidWeight, - WeightSumToSmall, - WeightSumToBig, + AllWeightsZero, } impl AliasMethodWeightedIndexError { @@ -470,8 +478,7 @@ impl AliasMethodWeightedIndexError { match *self { AliasMethodWeightedIndexError::NoItem => "No items found.", AliasMethodWeightedIndexError::InvalidWeight => "An item has an invalid weight.", - AliasMethodWeightedIndexError::WeightSumToSmall => "The sum of weights is to small.", - AliasMethodWeightedIndexError::WeightSumToBig => "The sum of weights is to big.", + AliasMethodWeightedIndexError::AllWeightsZero => "All weights are zero.", } } } From f392fb770895c8958ace0ec229e15470162c9bac Mon Sep 17 00:00:00 2001 From: zroug <37004975+zroug@users.noreply.github.com> Date: Sat, 23 Feb 2019 13:25:30 +0100 Subject: [PATCH 05/10] Made AliasMethodWeightedIndex generic --- benches/distributions.rs | 18 +-- src/distributions/mod.rs | 3 +- src/distributions/weighted.rs | 264 +++++++++++++++++++++++++++------- 3 files changed, 220 insertions(+), 65 deletions(-) diff --git a/benches/distributions.rs b/benches/distributions.rs index cf6e4a6558f..0efa50d71d3 100644 --- a/benches/distributions.rs +++ b/benches/distributions.rs @@ -187,20 +187,10 @@ distr_int!(distr_weighted_u32, usize, WeightedIndex::new(&[1u32, 2, 3, 4, 12, 0, distr_int!(distr_weighted_f64, usize, WeightedIndex::new(&[1.0f64, 0.001, 1.0/3.0, 4.01, 0.0, 3.3, 22.0, 0.001]).unwrap()); distr_int!(distr_weighted_large_set, usize, WeightedIndex::new((0..10000).rev().chain(1..10001)).unwrap()); -distr_int!( - distr_weighted_alias_method, - usize, - AliasMethodWeightedIndex::new( - vec![1.0f64, 0.001, 1.0/3.0, 4.01, 0.0, 3.3, 22.0, 0.001] - ).unwrap() -); -distr_int!( - distr_weighted_alias_method_large_set, - usize, - AliasMethodWeightedIndex::new( - (0..10000).rev().chain(1..10001).map(|x| x as f64).collect() - ).unwrap() -); +distr_int!(distr_weighted_alias_method_i8, usize, AliasMethodWeightedIndex::new(vec![1i8, 2, 3, 4, 12, 0, 2, 1]).unwrap()); +distr_int!(distr_weighted_alias_method_u32, usize, AliasMethodWeightedIndex::new(vec![1u32, 2, 3, 4, 12, 0, 2, 1]).unwrap()); +distr_int!(distr_weighted_alias_method_f64, usize, AliasMethodWeightedIndex::new(vec![1.0f64, 0.001, 1.0/3.0, 4.01, 0.0, 3.3, 22.0, 0.001]).unwrap()); +distr_int!(distr_weighted_alias_method_large_set, usize, AliasMethodWeightedIndex::new((0..10000).rev().chain(1..10001).collect()).unwrap()); // construct and sample from a range macro_rules! gen_range_int { diff --git a/src/distributions/mod.rs b/src/distributions/mod.rs index 06b578e989f..ab3f8541b70 100644 --- a/src/distributions/mod.rs +++ b/src/distributions/mod.rs @@ -183,7 +183,8 @@ pub use self::float::{OpenClosed01, Open01}; pub use self::bernoulli::Bernoulli; #[cfg(feature = "alloc")] pub use self::weighted::{ - AliasMethodWeightedIndex, AliasMethodWeightedIndexError, WeightedError, WeightedIndex, + AliasMethodWeight, AliasMethodWeightedIndex, AliasMethodWeightedIndexError, WeightedError, + WeightedIndex, }; #[cfg(feature="std")] pub use self::unit_sphere::UnitSphereSurface; #[cfg(feature="std")] pub use self::unit_circle::UnitCircle; diff --git a/src/distributions/weighted.rs b/src/distributions/weighted.rs index 96c2fcc747d..520c1fd1a97 100644 --- a/src/distributions/weighted.rs +++ b/src/distributions/weighted.rs @@ -11,6 +11,8 @@ use distributions::Distribution; use distributions::uniform::{UniformSampler, SampleUniform, SampleBorrow}; use ::core::cmp::PartialOrd; use core::fmt; +use core::iter::Sum; +use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Sub, SubAssign}; // Note that this whole module is only imported if feature="alloc" is enabled. #[cfg(not(feature="std"))] use alloc::vec::Vec; @@ -131,40 +133,51 @@ impl Distribution for WeightedIndex where } #[allow(missing_docs)] // todo: add docs -#[derive(Debug, Clone)] -pub struct AliasMethodWeightedIndex { +pub struct AliasMethodWeightedIndex { aliases: Vec, - no_alias_odds: Vec, + no_alias_odds: Vec, uniform_index: super::Uniform, - uniform_within_weight_sum: super::Uniform, + uniform_within_weight_sum: super::Uniform, } -impl AliasMethodWeightedIndex { +impl AliasMethodWeightedIndex { #[allow(missing_docs)] // todo: add docs - pub fn new(weights: Vec) -> Result { + pub fn new(weights: Vec) -> Result { let n = weights.len(); if n == 0 { return Err(AliasMethodWeightedIndexError::NoItem); } - let max_weight_size = ::core::f64::MAX / n as f64; - if !weights.iter().all(|&w| 0_f64 <= w && w <= max_weight_size) { + let max_weight_size = W::try_from_usize_lossy(n) + .map(|n| W::MAX / n) + .unwrap_or(W::ZERO); + if !weights + .iter() + .all(|&w| W::ZERO <= w && w <= max_weight_size) + { return Err(AliasMethodWeightedIndexError::InvalidWeight); } // The sum of weights will represent 100% of no alias odds. - let weight_sum = pairwise_sum_f64(weights.as_slice()); + let weight_sum = pairwise_sum(weights.as_slice()); // Prevent floating point overflow due to rounding errors. - let weight_sum = weight_sum.min(::core::f64::MAX); - if weight_sum == 0_f64 { + let weight_sum = if weight_sum > W::MAX { + W::MAX + } else { + weight_sum + }; + if weight_sum == W::ZERO { return Err(AliasMethodWeightedIndexError::AllWeightsZero); } + // `weight_sum` would have been zero if `try_from_lossy` causes an error here. + let n_converted = W::try_from_usize_lossy(n).unwrap(); + let mut no_alias_odds = weights; for odds in no_alias_odds.iter_mut() { - *odds *= n as f64; + *odds *= n_converted; // Prevent floating point overflow due to rounding errors. - *odds = odds.min(::core::f64::MAX); + *odds = if *odds > W::MAX { W::MAX } else { *odds }; } /// This struct is designed to contain three data structures at once, @@ -262,7 +275,7 @@ impl AliasMethodWeightedIndex { // Prepare distributions for sampling. Creating them beforehand improves // sampling performance. let uniform_index = super::Uniform::new(0, n); - let uniform_within_weight_sum = super::Uniform::new(0_f64, weight_sum); + let uniform_within_weight_sum = super::Uniform::new(W::ZERO, weight_sum); Ok(Self { aliases: aliases.aliases, @@ -273,10 +286,10 @@ impl AliasMethodWeightedIndex { } } -impl Distribution for AliasMethodWeightedIndex { +impl Distribution for AliasMethodWeightedIndex { fn sample(&self, rng: &mut R) -> usize { let candidate = rng.sample(self.uniform_index); - if rng.sample(self.uniform_within_weight_sum) < self.no_alias_odds[candidate] { + if rng.sample(&self.uniform_within_weight_sum) < self.no_alias_odds[candidate] { candidate } else { self.aliases[candidate] @@ -284,16 +297,116 @@ impl Distribution for AliasMethodWeightedIndex { } } -fn pairwise_sum_f64(values: &[f64]) -> f64 { +impl fmt::Debug for AliasMethodWeightedIndex +where + W: fmt::Debug, + super::Uniform: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("AliasMethodWeightedIndex") + .field("aliases", &self.aliases) + .field("no_alias_odds", &self.no_alias_odds) + .field("uniform_index", &self.uniform_index) + .field("uniform_within_weight_sum", &self.uniform_within_weight_sum) + .finish() + } +} + +impl Clone for AliasMethodWeightedIndex +where + super::Uniform: Clone, +{ + fn clone(&self) -> Self { + Self { + aliases: self.aliases.clone(), + no_alias_odds: self.no_alias_odds.clone(), + uniform_index: self.uniform_index.clone(), + uniform_within_weight_sum: self.uniform_within_weight_sum.clone(), + } + } +} + +/// In comparision to naive accumulation, the pairwise sum algorithm reduces +/// rounding errors when there are many floating point values. +fn pairwise_sum(values: &[T]) -> T { if values.len() <= 32 { - values.iter().sum() + values.iter().map(|x| *x).sum() } else { let mid = values.len() / 2; let (a, b) = values.split_at(mid); - pairwise_sum_f64(a) + pairwise_sum_f64(b) + pairwise_sum(a) + pairwise_sum(b) } } +pub trait AliasMethodWeight: + Sized + + Copy + + SampleUniform + + PartialOrd + + Add + + AddAssign + + Sub + + SubAssign + + Mul + + MulAssign + + Div + + DivAssign + + Sum +{ + const MAX: Self; + const ZERO: Self; + + fn try_from_usize_lossy(n: usize) -> Option; +} + +macro_rules! impl_alias_method_weight_for_float { + ($T: ident) => { + impl AliasMethodWeight for $T { + const MAX: Self = ::core::$T::MAX; + const ZERO: Self = 0.0; + + fn try_from_usize_lossy(n: usize) -> Option { + Some(n as $T) + } + } + }; +} + +macro_rules! impl_alias_method_weight_for_int { + ($T: ident) => { + impl AliasMethodWeight for $T { + const MAX: Self = ::core::$T::MAX; + const ZERO: Self = 0; + + fn try_from_usize_lossy(n: usize) -> Option { + let n_converted = n as Self; + if n_converted >= Self::ZERO && n_converted as usize == n { + Some(n_converted) + } else { + None + } + } + } + }; +} + +impl_alias_method_weight_for_float!(f64); +impl_alias_method_weight_for_float!(f32); +impl_alias_method_weight_for_int!(usize); +#[cfg(all(rustc_1_26, not(target_os = "emscripten")))] +impl_alias_method_weight_for_int!(u128); +impl_alias_method_weight_for_int!(u64); +impl_alias_method_weight_for_int!(u32); +impl_alias_method_weight_for_int!(u16); +impl_alias_method_weight_for_int!(u8); +impl_alias_method_weight_for_int!(isize); +#[cfg(all(rustc_1_26, not(target_os = "emscripten")))] +impl_alias_method_weight_for_int!(i128); +impl_alias_method_weight_for_int!(i64); +impl_alias_method_weight_for_int!(i32); +impl_alias_method_weight_for_int!(i16); +impl_alias_method_weight_for_int!(i8); + #[cfg(test)] mod test { use super::*; @@ -354,28 +467,99 @@ mod test { } #[test] - fn test_alias_method_weighted_index() { + fn test_alias_method_weighted_index_f32() { + test_alias_method_weighted_index(f32::into); + + // Floating point special cases + assert_eq!( + AliasMethodWeightedIndex::new(vec![::core::f32::INFINITY]).unwrap_err(), + AliasMethodWeightedIndexError::InvalidWeight + ); + assert_eq!( + AliasMethodWeightedIndex::new(vec![-0_f32]).unwrap_err(), + AliasMethodWeightedIndexError::AllWeightsZero + ); + assert_eq!( + AliasMethodWeightedIndex::new(vec![-1_f32]).unwrap_err(), + AliasMethodWeightedIndexError::InvalidWeight + ); + assert_eq!( + AliasMethodWeightedIndex::new(vec![-::core::f32::INFINITY]).unwrap_err(), + AliasMethodWeightedIndexError::InvalidWeight + ); + assert_eq!( + AliasMethodWeightedIndex::new(vec![::core::f32::NAN]).unwrap_err(), + AliasMethodWeightedIndexError::InvalidWeight + ); + } + + #[cfg(all(rustc_1_26, not(target_os = "emscripten")))] + #[test] + fn test_alias_method_weighted_index_u128() { + test_alias_method_weighted_index(|x: u128| x as f64); + } + + #[cfg(all(rustc_1_26, not(target_os = "emscripten")))] + #[test] + fn test_alias_method_weighted_index_i128() { + test_alias_method_weighted_index(|x: i128| x as f64); + + // Signed integer special cases + assert_eq!( + AliasMethodWeightedIndex::new(vec![-1_i128]).unwrap_err(), + AliasMethodWeightedIndexError::InvalidWeight + ); + assert_eq!( + AliasMethodWeightedIndex::new(vec![::core::i128::MIN]).unwrap_err(), + AliasMethodWeightedIndexError::InvalidWeight + ); + } + + #[test] + fn test_alias_method_weighted_index_u8() { + test_alias_method_weighted_index(u8::into); + } + + #[test] + fn test_alias_method_weighted_index_i8() { + test_alias_method_weighted_index(i8::into); + + // Signed integer special cases + assert_eq!( + AliasMethodWeightedIndex::new(vec![-1_i8]).unwrap_err(), + AliasMethodWeightedIndexError::InvalidWeight + ); + assert_eq!( + AliasMethodWeightedIndex::new(vec![::core::i8::MIN]).unwrap_err(), + AliasMethodWeightedIndexError::InvalidWeight + ); + } + + fn test_alias_method_weighted_index f64>(w_to_f64: F) + where + AliasMethodWeightedIndex: fmt::Debug, + { const NUM_WEIGHTS: usize = 10; const ZERO_WEIGHT_INDEX: usize = 3; - const NUM_SAMPLES: u32 = 10000; + const NUM_SAMPLES: u32 = 15000; let mut rng = ::test::rng(0x9c9fa0b0580a7031); let weights = { let mut weights = Vec::with_capacity(NUM_WEIGHTS); let random_weight_distribution = ::distributions::Uniform::new_inclusive( - 0_f64, - ::core::f64::MAX / NUM_WEIGHTS as f64, + W::ZERO, + W::MAX / W::try_from_usize_lossy(NUM_WEIGHTS).unwrap(), ); for _ in 0..NUM_WEIGHTS { - weights.push(rng.sample(random_weight_distribution)); + weights.push(rng.sample(&random_weight_distribution)); } - weights[ZERO_WEIGHT_INDEX] = 0.0; + weights[ZERO_WEIGHT_INDEX] = W::ZERO; weights }; - let weight_sum = weights.iter().sum::(); + let weight_sum = weights.iter().map(|w| *w).sum::(); let expected_counts = weights .iter() - .map(|&w| w / weight_sum * NUM_SAMPLES as f64) + .map(|&w| w_to_f64(w) / w_to_f64(weight_sum) * NUM_SAMPLES as f64) .collect::>(); let weight_distribution = AliasMethodWeightedIndex::new(weights).unwrap(); @@ -392,35 +576,15 @@ mod test { } assert_eq!( - AliasMethodWeightedIndex::new(vec![]).unwrap_err(), + AliasMethodWeightedIndex::::new(vec![]).unwrap_err(), AliasMethodWeightedIndexError::NoItem ); assert_eq!( - AliasMethodWeightedIndex::new(vec![0.0]).unwrap_err(), + AliasMethodWeightedIndex::new(vec![W::ZERO]).unwrap_err(), AliasMethodWeightedIndexError::AllWeightsZero ); assert_eq!( - AliasMethodWeightedIndex::new(vec![-0.0]).unwrap_err(), - AliasMethodWeightedIndexError::AllWeightsZero - ); - assert_eq!( - AliasMethodWeightedIndex::new(vec![::core::f64::INFINITY]).unwrap_err(), - AliasMethodWeightedIndexError::InvalidWeight - ); - assert_eq!( - AliasMethodWeightedIndex::new(vec![::core::f64::MAX, ::core::f64::MAX]).unwrap_err(), - AliasMethodWeightedIndexError::InvalidWeight - ); - assert_eq!( - AliasMethodWeightedIndex::new(vec![-1.0]).unwrap_err(), - AliasMethodWeightedIndexError::InvalidWeight - ); - assert_eq!( - AliasMethodWeightedIndex::new(vec![-::core::f64::INFINITY]).unwrap_err(), - AliasMethodWeightedIndexError::InvalidWeight - ); - assert_eq!( - AliasMethodWeightedIndex::new(vec![::core::f64::NAN]).unwrap_err(), + AliasMethodWeightedIndex::new(vec![W::MAX, W::MAX]).unwrap_err(), AliasMethodWeightedIndexError::InvalidWeight ); } From 2af10fa047dae02907aeb93bd16fbdc4c4318b12 Mon Sep 17 00:00:00 2001 From: zroug <37004975+zroug@users.noreply.github.com> Date: Sat, 23 Feb 2019 13:25:30 +0100 Subject: [PATCH 06/10] Added documentation for AliasMethodWeightedIndex --- src/distributions/weighted.rs | 70 +++++++++++++++++++++++++++++++++-- 1 file changed, 67 insertions(+), 3 deletions(-) diff --git a/src/distributions/weighted.rs b/src/distributions/weighted.rs index 520c1fd1a97..0a09c698f9c 100644 --- a/src/distributions/weighted.rs +++ b/src/distributions/weighted.rs @@ -132,7 +132,54 @@ impl Distribution for WeightedIndex where } } -#[allow(missing_docs)] // todo: add docs +/// A distribution using weighted sampling to pick a discretely selected item. +/// +/// Sampling an [`AliasMethodWeightedIndex`] distribution returns the index +/// of a randomly selected element from the vector used to create the +/// [`AliasMethodWeightedIndex`]. The chance of a given element being picked +/// is proportional to the value of the element. The weights can have any type +/// `W` for which an implementation of [`AliasMethodWeight`] exists. +/// +/// # Performance +/// +/// Given that `n` is the number of items in the vector used to create an +/// [`AliasMethodWeightedIndex`], [`AliasMethodWeightedIndex`] will take +/// up `O(n)` amount of memory. More specifically it takes up some constant +/// amount of memory plus the vector used to create it and a [`Vec`] with +/// capacity `n`. +/// +/// Time complexity for the creation of an [`AliasMethodWeightedIndex`] is +/// `O(n)`. Sampling is `O(1)`, it makes a call to [`Uniform::sample`] +/// and a call to [`Uniform::sample`]. +/// +/// # Example +/// +/// ``` +/// use rand::distributions::AliasMethodWeightedIndex; +/// use rand::prelude::*; +/// +/// let choices = vec!['a', 'b', 'c']; +/// let weights = vec![2, 1, 1]; +/// let dist = AliasMethodWeightedIndex::new(weights).unwrap(); +/// let mut rng = thread_rng(); +/// for _ in 0..100 { +/// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c' +/// println!("{}", choices[dist.sample(&mut rng)]); +/// } +/// +/// let items = [('a', 0), ('b', 3), ('c', 7)]; +/// let dist2 = AliasMethodWeightedIndex::new(items.iter().map(|item| item.1).collect()).unwrap(); +/// for _ in 0..100 { +/// // 0% chance to print 'a', 30% chance to print 'b', 70% chance to print 'c' +/// println!("{}", items[dist2.sample(&mut rng)].0); +/// } +/// ``` +/// +/// [`AliasMethodWeightedIndex`]: AliasMethodWeightedIndex +/// [`AliasMethodWeight`]: AliasMethodWeight +/// [`Vec`]: Vec +/// [`Uniform::sample`]: Distribution::sample +/// [`Uniform::sample`]: Distribution::sample pub struct AliasMethodWeightedIndex { aliases: Vec, no_alias_odds: Vec, @@ -141,7 +188,13 @@ pub struct AliasMethodWeightedIndex { } impl AliasMethodWeightedIndex { - #[allow(missing_docs)] // todo: add docs + /// Creates an new [`AliasMethodWeightedIndex`]. + /// + /// Returns an error if: + /// - The vector is empty. + /// - For any weight `w`: `w < 0` or `w > max` where `max = W::MAX / + /// weights.len()`. + /// - The sum of weights is zero. pub fn new(weights: Vec) -> Result { let n = weights.len(); if n == 0 { @@ -338,6 +391,10 @@ fn pairwise_sum(values: &[T]) -> T { } } +/// Trait that must be implemented for weights, that are used with +/// [`AliasMethodWeightedIndex`]. Currently no guarantees on the correctness of +/// [`AliasMethodWeightedIndex`] are given for custom implementations of this +/// trait. pub trait AliasMethodWeight: Sized + Copy @@ -353,9 +410,13 @@ pub trait AliasMethodWeight: + DivAssign + Sum { + /// Maximum number representable by `Self`. const MAX: Self; + + /// Element of `Self` equivalent to 0. const ZERO: Self; + /// Converts an [`usize`] to a `Self`, rounding if necessary. fn try_from_usize_lossy(n: usize) -> Option; } @@ -629,11 +690,14 @@ impl fmt::Display for WeightedError { } } -#[allow(missing_docs)] // todo: add docs +/// Error type returned from [`AliasMethodWeightedIndex::new`]. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum AliasMethodWeightedIndexError { + /// The weight vector is empty. NoItem, + /// A weight is either less than zero or greater than the supported maximum. InvalidWeight, + /// All weights in the provided vector are zero. AllWeightsZero, } From 9c44b6aa510b25408ce050d2785b5a9fcb4e98de Mon Sep 17 00:00:00 2001 From: zroug <37004975+zroug@users.noreply.github.com> Date: Mon, 25 Feb 2019 22:35:31 +0100 Subject: [PATCH 07/10] Addressed documentation issues from review --- src/distributions/weighted.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/distributions/weighted.rs b/src/distributions/weighted.rs index 0a09c698f9c..2addc6280ed 100644 --- a/src/distributions/weighted.rs +++ b/src/distributions/weighted.rs @@ -137,14 +137,14 @@ impl Distribution for WeightedIndex where /// Sampling an [`AliasMethodWeightedIndex`] distribution returns the index /// of a randomly selected element from the vector used to create the /// [`AliasMethodWeightedIndex`]. The chance of a given element being picked -/// is proportional to the value of the element. The weights can have any type +/// is proportional to the value of the element. The weights can have any type /// `W` for which an implementation of [`AliasMethodWeight`] exists. /// /// # Performance /// /// Given that `n` is the number of items in the vector used to create an -/// [`AliasMethodWeightedIndex`], [`AliasMethodWeightedIndex`] will take -/// up `O(n)` amount of memory. More specifically it takes up some constant +/// [`AliasMethodWeightedIndex`], [`AliasMethodWeightedIndex`] will +/// require `O(n)` amount of memory. More specifically it takes up some constant /// amount of memory plus the vector used to create it and a [`Vec`] with /// capacity `n`. /// @@ -379,7 +379,7 @@ where } } -/// In comparision to naive accumulation, the pairwise sum algorithm reduces +/// In comparison to naive accumulation, the pairwise sum algorithm reduces /// rounding errors when there are many floating point values. fn pairwise_sum(values: &[T]) -> T { if values.len() <= 32 { @@ -416,7 +416,7 @@ pub trait AliasMethodWeight: /// Element of `Self` equivalent to 0. const ZERO: Self; - /// Converts an [`usize`] to a `Self`, rounding if necessary. + /// Converts a [`usize`] to a `Self`, rounding if necessary. fn try_from_usize_lossy(n: usize) -> Option; } From 8641a9bb1624e9a984a69ae0c24fcf16ee846bc8 Mon Sep 17 00:00:00 2001 From: zroug <37004975+zroug@users.noreply.github.com> Date: Mon, 25 Feb 2019 22:35:31 +0100 Subject: [PATCH 08/10] Use pairwise sum only for floating point weights --- src/distributions/weighted.rs | 35 ++++++++++++++++++++++------------- 1 file changed, 22 insertions(+), 13 deletions(-) diff --git a/src/distributions/weighted.rs b/src/distributions/weighted.rs index 2addc6280ed..b46b00fe72c 100644 --- a/src/distributions/weighted.rs +++ b/src/distributions/weighted.rs @@ -212,7 +212,7 @@ impl AliasMethodWeightedIndex { } // The sum of weights will represent 100% of no alias odds. - let weight_sum = pairwise_sum(weights.as_slice()); + let weight_sum = AliasMethodWeight::sum(weights.as_slice()); // Prevent floating point overflow due to rounding errors. let weight_sum = if weight_sum > W::MAX { W::MAX @@ -379,18 +379,6 @@ where } } -/// In comparison to naive accumulation, the pairwise sum algorithm reduces -/// rounding errors when there are many floating point values. -fn pairwise_sum(values: &[T]) -> T { - if values.len() <= 32 { - values.iter().map(|x| *x).sum() - } else { - let mid = values.len() / 2; - let (a, b) = values.split_at(mid); - pairwise_sum(a) + pairwise_sum(b) - } -} - /// Trait that must be implemented for weights, that are used with /// [`AliasMethodWeightedIndex`]. Currently no guarantees on the correctness of /// [`AliasMethodWeightedIndex`] are given for custom implementations of this @@ -418,6 +406,11 @@ pub trait AliasMethodWeight: /// Converts a [`usize`] to a `Self`, rounding if necessary. fn try_from_usize_lossy(n: usize) -> Option; + + /// Sums all values in slice `values`. + fn sum(values: &[Self]) -> Self { + values.iter().map(|x| *x).sum() + } } macro_rules! impl_alias_method_weight_for_float { @@ -429,10 +422,26 @@ macro_rules! impl_alias_method_weight_for_float { fn try_from_usize_lossy(n: usize) -> Option { Some(n as $T) } + + fn sum(values: &[Self]) -> Self { + pairwise_sum(values) + } } }; } +/// In comparison to naive accumulation, the pairwise sum algorithm reduces +/// rounding errors when there are many floating point values. +fn pairwise_sum(values: &[T]) -> T { + if values.len() <= 32 { + values.iter().map(|x| *x).sum() + } else { + let mid = values.len() / 2; + let (a, b) = values.split_at(mid); + pairwise_sum(a) + pairwise_sum(b) + } +} + macro_rules! impl_alias_method_weight_for_int { ($T: ident) => { impl AliasMethodWeight for $T { From 5b2934122f04e1ff41024b3c0597ace711d21f08 Mon Sep 17 00:00:00 2001 From: zroug <37004975+zroug@users.noreply.github.com> Date: Wed, 6 Mar 2019 17:49:50 +0100 Subject: [PATCH 09/10] Reorganized distributions::weighted module --- benches/distributions.rs | 8 +- src/distributions/mod.rs | 8 +- src/distributions/weighted.rs | 737 --------------------- src/distributions/weighted/alias_method.rs | 484 ++++++++++++++ src/distributions/weighted/mod.rs | 236 +++++++ src/seq/mod.rs | 4 +- 6 files changed, 728 insertions(+), 749 deletions(-) delete mode 100644 src/distributions/weighted.rs create mode 100644 src/distributions/weighted/alias_method.rs create mode 100644 src/distributions/weighted/mod.rs diff --git a/benches/distributions.rs b/benches/distributions.rs index 0efa50d71d3..ccf78b0310d 100644 --- a/benches/distributions.rs +++ b/benches/distributions.rs @@ -187,10 +187,10 @@ distr_int!(distr_weighted_u32, usize, WeightedIndex::new(&[1u32, 2, 3, 4, 12, 0, distr_int!(distr_weighted_f64, usize, WeightedIndex::new(&[1.0f64, 0.001, 1.0/3.0, 4.01, 0.0, 3.3, 22.0, 0.001]).unwrap()); distr_int!(distr_weighted_large_set, usize, WeightedIndex::new((0..10000).rev().chain(1..10001)).unwrap()); -distr_int!(distr_weighted_alias_method_i8, usize, AliasMethodWeightedIndex::new(vec![1i8, 2, 3, 4, 12, 0, 2, 1]).unwrap()); -distr_int!(distr_weighted_alias_method_u32, usize, AliasMethodWeightedIndex::new(vec![1u32, 2, 3, 4, 12, 0, 2, 1]).unwrap()); -distr_int!(distr_weighted_alias_method_f64, usize, AliasMethodWeightedIndex::new(vec![1.0f64, 0.001, 1.0/3.0, 4.01, 0.0, 3.3, 22.0, 0.001]).unwrap()); -distr_int!(distr_weighted_alias_method_large_set, usize, AliasMethodWeightedIndex::new((0..10000).rev().chain(1..10001).collect()).unwrap()); +distr_int!(distr_weighted_alias_method_i8, usize, weighted::alias_method::WeightedIndex::new(vec![1i8, 2, 3, 4, 12, 0, 2, 1]).unwrap()); +distr_int!(distr_weighted_alias_method_u32, usize, weighted::alias_method::WeightedIndex::new(vec![1u32, 2, 3, 4, 12, 0, 2, 1]).unwrap()); +distr_int!(distr_weighted_alias_method_f64, usize, weighted::alias_method::WeightedIndex::new(vec![1.0f64, 0.001, 1.0/3.0, 4.01, 0.0, 3.3, 22.0, 0.001]).unwrap()); +distr_int!(distr_weighted_alias_method_large_set, usize, weighted::alias_method::WeightedIndex::new((0..10000).rev().chain(1..10001).collect()).unwrap()); // construct and sample from a range macro_rules! gen_range_int { diff --git a/src/distributions/mod.rs b/src/distributions/mod.rs index ab3f8541b70..c8ef0ef96d8 100644 --- a/src/distributions/mod.rs +++ b/src/distributions/mod.rs @@ -181,11 +181,7 @@ pub use self::other::Alphanumeric; #[doc(inline)] pub use self::uniform::Uniform; pub use self::float::{OpenClosed01, Open01}; pub use self::bernoulli::Bernoulli; -#[cfg(feature = "alloc")] -pub use self::weighted::{ - AliasMethodWeight, AliasMethodWeightedIndex, AliasMethodWeightedIndexError, WeightedError, - WeightedIndex, -}; +#[cfg(feature="alloc")] pub use self::weighted::{WeightedIndex, WeightedError}; #[cfg(feature="std")] pub use self::unit_sphere::UnitSphereSurface; #[cfg(feature="std")] pub use self::unit_circle::UnitCircle; #[cfg(feature="std")] pub use self::gamma::{Gamma, ChiSquared, FisherF, @@ -202,7 +198,7 @@ pub use self::weighted::{ pub mod uniform; mod bernoulli; -#[cfg(feature="alloc")] mod weighted; +#[cfg(feature="alloc")] pub mod weighted; #[cfg(feature="std")] mod unit_sphere; #[cfg(feature="std")] mod unit_circle; #[cfg(feature="std")] mod gamma; diff --git a/src/distributions/weighted.rs b/src/distributions/weighted.rs deleted file mode 100644 index b46b00fe72c..00000000000 --- a/src/distributions/weighted.rs +++ /dev/null @@ -1,737 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -use Rng; -use distributions::Distribution; -use distributions::uniform::{UniformSampler, SampleUniform, SampleBorrow}; -use ::core::cmp::PartialOrd; -use core::fmt; -use core::iter::Sum; -use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Sub, SubAssign}; - -// Note that this whole module is only imported if feature="alloc" is enabled. -#[cfg(not(feature="std"))] use alloc::vec::Vec; - -/// A distribution using weighted sampling to pick a discretely selected -/// item. -/// -/// Sampling a `WeightedIndex` distribution returns the index of a randomly -/// selected element from the iterator used when the `WeightedIndex` was -/// created. The chance of a given element being picked is proportional to the -/// value of the element. The weights can use any type `X` for which an -/// implementation of [`Uniform`] exists. -/// -/// # Performance -/// -/// A `WeightedIndex` contains a `Vec` and a [`Uniform`] and so its -/// size is the sum of the size of those objects, possibly plus some alignment. -/// -/// Creating a `WeightedIndex` will allocate enough space to hold `N - 1` -/// weights of type `X`, where `N` is the number of weights. However, since -/// `Vec` doesn't guarantee a particular growth strategy, additional memory -/// might be allocated but not used. Since the `WeightedIndex` object also -/// contains, this might cause additional allocations, though for primitive -/// types, ['Uniform`] doesn't allocate any memory. -/// -/// Time complexity of sampling from `WeightedIndex` is `O(log N)` where -/// `N` is the number of weights. -/// -/// Sampling from `WeightedIndex` will result in a single call to -/// `Uniform::sample` (method of the [`Distribution`] trait), which typically -/// will request a single value from the underlying [`RngCore`], though the -/// exact number depends on the implementaiton of `Uniform::sample`. -/// -/// # Example -/// -/// ``` -/// use rand::prelude::*; -/// use rand::distributions::WeightedIndex; -/// -/// let choices = ['a', 'b', 'c']; -/// let weights = [2, 1, 1]; -/// let dist = WeightedIndex::new(&weights).unwrap(); -/// let mut rng = thread_rng(); -/// for _ in 0..100 { -/// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c' -/// println!("{}", choices[dist.sample(&mut rng)]); -/// } -/// -/// let items = [('a', 0), ('b', 3), ('c', 7)]; -/// let dist2 = WeightedIndex::new(items.iter().map(|item| item.1)).unwrap(); -/// for _ in 0..100 { -/// // 0% chance to print 'a', 30% chance to print 'b', 70% chance to print 'c' -/// println!("{}", items[dist2.sample(&mut rng)].0); -/// } -/// ``` -/// -/// [`Uniform`]: crate::distributions::uniform::Uniform -/// [`RngCore`]: rand_core::RngCore -#[derive(Debug, Clone)] -pub struct WeightedIndex { - cumulative_weights: Vec, - weight_distribution: X::Sampler, -} - -impl WeightedIndex { - /// Creates a new a `WeightedIndex` [`Distribution`] using the values - /// in `weights`. The weights can use any type `X` for which an - /// implementation of [`Uniform`] exists. - /// - /// Returns an error if the iterator is empty, if any weight is `< 0`, or - /// if its total value is 0. - /// - /// [`Uniform`]: crate::distributions::uniform::Uniform - pub fn new(weights: I) -> Result, WeightedError> - where I: IntoIterator, - I::Item: SampleBorrow, - X: for<'a> ::core::ops::AddAssign<&'a X> + - Clone + - Default { - let mut iter = weights.into_iter(); - let mut total_weight: X = iter.next() - .ok_or(WeightedError::NoItem)? - .borrow() - .clone(); - - let zero = ::default(); - if total_weight < zero { - return Err(WeightedError::NegativeWeight); - } - - let mut weights = Vec::::with_capacity(iter.size_hint().0); - for w in iter { - if *w.borrow() < zero { - return Err(WeightedError::NegativeWeight); - } - weights.push(total_weight.clone()); - total_weight += w.borrow(); - } - - if total_weight == zero { - return Err(WeightedError::AllWeightsZero); - } - let distr = X::Sampler::new(zero, total_weight); - - Ok(WeightedIndex { cumulative_weights: weights, weight_distribution: distr }) - } -} - -impl Distribution for WeightedIndex where - X: SampleUniform + PartialOrd { - fn sample(&self, rng: &mut R) -> usize { - use ::core::cmp::Ordering; - let chosen_weight = self.weight_distribution.sample(rng); - // Find the first item which has a weight *higher* than the chosen weight. - self.cumulative_weights.binary_search_by( - |w| if *w <= chosen_weight { Ordering::Less } else { Ordering::Greater }).unwrap_err() - } -} - -/// A distribution using weighted sampling to pick a discretely selected item. -/// -/// Sampling an [`AliasMethodWeightedIndex`] distribution returns the index -/// of a randomly selected element from the vector used to create the -/// [`AliasMethodWeightedIndex`]. The chance of a given element being picked -/// is proportional to the value of the element. The weights can have any type -/// `W` for which an implementation of [`AliasMethodWeight`] exists. -/// -/// # Performance -/// -/// Given that `n` is the number of items in the vector used to create an -/// [`AliasMethodWeightedIndex`], [`AliasMethodWeightedIndex`] will -/// require `O(n)` amount of memory. More specifically it takes up some constant -/// amount of memory plus the vector used to create it and a [`Vec`] with -/// capacity `n`. -/// -/// Time complexity for the creation of an [`AliasMethodWeightedIndex`] is -/// `O(n)`. Sampling is `O(1)`, it makes a call to [`Uniform::sample`] -/// and a call to [`Uniform::sample`]. -/// -/// # Example -/// -/// ``` -/// use rand::distributions::AliasMethodWeightedIndex; -/// use rand::prelude::*; -/// -/// let choices = vec!['a', 'b', 'c']; -/// let weights = vec![2, 1, 1]; -/// let dist = AliasMethodWeightedIndex::new(weights).unwrap(); -/// let mut rng = thread_rng(); -/// for _ in 0..100 { -/// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c' -/// println!("{}", choices[dist.sample(&mut rng)]); -/// } -/// -/// let items = [('a', 0), ('b', 3), ('c', 7)]; -/// let dist2 = AliasMethodWeightedIndex::new(items.iter().map(|item| item.1).collect()).unwrap(); -/// for _ in 0..100 { -/// // 0% chance to print 'a', 30% chance to print 'b', 70% chance to print 'c' -/// println!("{}", items[dist2.sample(&mut rng)].0); -/// } -/// ``` -/// -/// [`AliasMethodWeightedIndex`]: AliasMethodWeightedIndex -/// [`AliasMethodWeight`]: AliasMethodWeight -/// [`Vec`]: Vec -/// [`Uniform::sample`]: Distribution::sample -/// [`Uniform::sample`]: Distribution::sample -pub struct AliasMethodWeightedIndex { - aliases: Vec, - no_alias_odds: Vec, - uniform_index: super::Uniform, - uniform_within_weight_sum: super::Uniform, -} - -impl AliasMethodWeightedIndex { - /// Creates an new [`AliasMethodWeightedIndex`]. - /// - /// Returns an error if: - /// - The vector is empty. - /// - For any weight `w`: `w < 0` or `w > max` where `max = W::MAX / - /// weights.len()`. - /// - The sum of weights is zero. - pub fn new(weights: Vec) -> Result { - let n = weights.len(); - if n == 0 { - return Err(AliasMethodWeightedIndexError::NoItem); - } - - let max_weight_size = W::try_from_usize_lossy(n) - .map(|n| W::MAX / n) - .unwrap_or(W::ZERO); - if !weights - .iter() - .all(|&w| W::ZERO <= w && w <= max_weight_size) - { - return Err(AliasMethodWeightedIndexError::InvalidWeight); - } - - // The sum of weights will represent 100% of no alias odds. - let weight_sum = AliasMethodWeight::sum(weights.as_slice()); - // Prevent floating point overflow due to rounding errors. - let weight_sum = if weight_sum > W::MAX { - W::MAX - } else { - weight_sum - }; - if weight_sum == W::ZERO { - return Err(AliasMethodWeightedIndexError::AllWeightsZero); - } - - // `weight_sum` would have been zero if `try_from_lossy` causes an error here. - let n_converted = W::try_from_usize_lossy(n).unwrap(); - - let mut no_alias_odds = weights; - for odds in no_alias_odds.iter_mut() { - *odds *= n_converted; - // Prevent floating point overflow due to rounding errors. - *odds = if *odds > W::MAX { W::MAX } else { *odds }; - } - - /// This struct is designed to contain three data structures at once, - /// sharing the same memory. More precisely it contains two - /// linked lists and an alias map, which will be the output of this - /// method. To keep the three data structures from getting in - /// each other's way, it must be ensured that a single index is only - /// ever in one of them at the same time. - struct Aliases { - aliases: Vec, - smalls_head: usize, - bigs_head: usize, - } - - impl Aliases { - fn new(size: usize) -> Self { - Aliases { - aliases: vec![0; size], - smalls_head: ::core::usize::MAX, - bigs_head: ::core::usize::MAX, - } - } - - fn push_small(&mut self, idx: usize) { - self.aliases[idx] = self.smalls_head; - self.smalls_head = idx; - } - - fn push_big(&mut self, idx: usize) { - self.aliases[idx] = self.bigs_head; - self.bigs_head = idx; - } - - fn pop_small(&mut self) -> usize { - let popped = self.smalls_head; - self.smalls_head = self.aliases[popped]; - popped - } - - fn pop_big(&mut self) -> usize { - let popped = self.bigs_head; - self.bigs_head = self.aliases[popped]; - popped - } - - fn smalls_is_empty(&self) -> bool { - self.smalls_head == ::core::usize::MAX - } - - fn bigs_is_empty(&self) -> bool { - self.bigs_head == ::core::usize::MAX - } - - fn set_alias(&mut self, idx: usize, alias: usize) { - self.aliases[idx] = alias; - } - } - - let mut aliases = Aliases::new(n); - - // Split indices into those with small weights and those with big weights. - for (index, &odds) in no_alias_odds.iter().enumerate() { - if odds < weight_sum { - aliases.push_small(index); - } else { - aliases.push_big(index); - } - } - - // Build the alias map by finding an alias with big weight for each index with - // small weight. - while !aliases.smalls_is_empty() && !aliases.bigs_is_empty() { - let s = aliases.pop_small(); - let b = aliases.pop_big(); - - aliases.set_alias(s, b); - no_alias_odds[b] = no_alias_odds[b] - weight_sum + no_alias_odds[s]; - - if no_alias_odds[b] < weight_sum { - aliases.push_small(b); - } else { - aliases.push_big(b); - } - } - - // The remaining indices should have no alias odds of about 100%. This is due to - // numeric accuracy. Otherwise they would be exactly 100%. - while !aliases.smalls_is_empty() { - no_alias_odds[aliases.pop_small()] = weight_sum; - } - while !aliases.bigs_is_empty() { - no_alias_odds[aliases.pop_big()] = weight_sum; - } - - // Prepare distributions for sampling. Creating them beforehand improves - // sampling performance. - let uniform_index = super::Uniform::new(0, n); - let uniform_within_weight_sum = super::Uniform::new(W::ZERO, weight_sum); - - Ok(Self { - aliases: aliases.aliases, - no_alias_odds, - uniform_index, - uniform_within_weight_sum, - }) - } -} - -impl Distribution for AliasMethodWeightedIndex { - fn sample(&self, rng: &mut R) -> usize { - let candidate = rng.sample(self.uniform_index); - if rng.sample(&self.uniform_within_weight_sum) < self.no_alias_odds[candidate] { - candidate - } else { - self.aliases[candidate] - } - } -} - -impl fmt::Debug for AliasMethodWeightedIndex -where - W: fmt::Debug, - super::Uniform: fmt::Debug, -{ - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.debug_struct("AliasMethodWeightedIndex") - .field("aliases", &self.aliases) - .field("no_alias_odds", &self.no_alias_odds) - .field("uniform_index", &self.uniform_index) - .field("uniform_within_weight_sum", &self.uniform_within_weight_sum) - .finish() - } -} - -impl Clone for AliasMethodWeightedIndex -where - super::Uniform: Clone, -{ - fn clone(&self) -> Self { - Self { - aliases: self.aliases.clone(), - no_alias_odds: self.no_alias_odds.clone(), - uniform_index: self.uniform_index.clone(), - uniform_within_weight_sum: self.uniform_within_weight_sum.clone(), - } - } -} - -/// Trait that must be implemented for weights, that are used with -/// [`AliasMethodWeightedIndex`]. Currently no guarantees on the correctness of -/// [`AliasMethodWeightedIndex`] are given for custom implementations of this -/// trait. -pub trait AliasMethodWeight: - Sized - + Copy - + SampleUniform - + PartialOrd - + Add - + AddAssign - + Sub - + SubAssign - + Mul - + MulAssign - + Div - + DivAssign - + Sum -{ - /// Maximum number representable by `Self`. - const MAX: Self; - - /// Element of `Self` equivalent to 0. - const ZERO: Self; - - /// Converts a [`usize`] to a `Self`, rounding if necessary. - fn try_from_usize_lossy(n: usize) -> Option; - - /// Sums all values in slice `values`. - fn sum(values: &[Self]) -> Self { - values.iter().map(|x| *x).sum() - } -} - -macro_rules! impl_alias_method_weight_for_float { - ($T: ident) => { - impl AliasMethodWeight for $T { - const MAX: Self = ::core::$T::MAX; - const ZERO: Self = 0.0; - - fn try_from_usize_lossy(n: usize) -> Option { - Some(n as $T) - } - - fn sum(values: &[Self]) -> Self { - pairwise_sum(values) - } - } - }; -} - -/// In comparison to naive accumulation, the pairwise sum algorithm reduces -/// rounding errors when there are many floating point values. -fn pairwise_sum(values: &[T]) -> T { - if values.len() <= 32 { - values.iter().map(|x| *x).sum() - } else { - let mid = values.len() / 2; - let (a, b) = values.split_at(mid); - pairwise_sum(a) + pairwise_sum(b) - } -} - -macro_rules! impl_alias_method_weight_for_int { - ($T: ident) => { - impl AliasMethodWeight for $T { - const MAX: Self = ::core::$T::MAX; - const ZERO: Self = 0; - - fn try_from_usize_lossy(n: usize) -> Option { - let n_converted = n as Self; - if n_converted >= Self::ZERO && n_converted as usize == n { - Some(n_converted) - } else { - None - } - } - } - }; -} - -impl_alias_method_weight_for_float!(f64); -impl_alias_method_weight_for_float!(f32); -impl_alias_method_weight_for_int!(usize); -#[cfg(all(rustc_1_26, not(target_os = "emscripten")))] -impl_alias_method_weight_for_int!(u128); -impl_alias_method_weight_for_int!(u64); -impl_alias_method_weight_for_int!(u32); -impl_alias_method_weight_for_int!(u16); -impl_alias_method_weight_for_int!(u8); -impl_alias_method_weight_for_int!(isize); -#[cfg(all(rustc_1_26, not(target_os = "emscripten")))] -impl_alias_method_weight_for_int!(i128); -impl_alias_method_weight_for_int!(i64); -impl_alias_method_weight_for_int!(i32); -impl_alias_method_weight_for_int!(i16); -impl_alias_method_weight_for_int!(i8); - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_weightedindex() { - let mut r = ::test::rng(700); - const N_REPS: u32 = 5000; - let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7]; - let total_weight = weights.iter().sum::() as f32; - - let verify = |result: [i32; 14]| { - for (i, count) in result.iter().enumerate() { - let exp = (weights[i] * N_REPS) as f32 / total_weight; - let mut err = (*count as f32 - exp).abs(); - if err != 0.0 { - err /= exp; - } - assert!(err <= 0.25); - } - }; - - // WeightedIndex from vec - let mut chosen = [0i32; 14]; - let distr = WeightedIndex::new(weights.to_vec()).unwrap(); - for _ in 0..N_REPS { - chosen[distr.sample(&mut r)] += 1; - } - verify(chosen); - - // WeightedIndex from slice - chosen = [0i32; 14]; - let distr = WeightedIndex::new(&weights[..]).unwrap(); - for _ in 0..N_REPS { - chosen[distr.sample(&mut r)] += 1; - } - verify(chosen); - - // WeightedIndex from iterator - chosen = [0i32; 14]; - let distr = WeightedIndex::new(weights.iter()).unwrap(); - for _ in 0..N_REPS { - chosen[distr.sample(&mut r)] += 1; - } - verify(chosen); - - for _ in 0..5 { - assert_eq!(WeightedIndex::new(&[0, 1]).unwrap().sample(&mut r), 1); - assert_eq!(WeightedIndex::new(&[1, 0]).unwrap().sample(&mut r), 0); - assert_eq!(WeightedIndex::new(&[0, 0, 0, 0, 10, 0]).unwrap().sample(&mut r), 4); - } - - assert_eq!(WeightedIndex::new(&[10][0..0]).unwrap_err(), WeightedError::NoItem); - assert_eq!(WeightedIndex::new(&[0]).unwrap_err(), WeightedError::AllWeightsZero); - assert_eq!(WeightedIndex::new(&[10, 20, -1, 30]).unwrap_err(), WeightedError::NegativeWeight); - assert_eq!(WeightedIndex::new(&[-10, 20, 1, 30]).unwrap_err(), WeightedError::NegativeWeight); - assert_eq!(WeightedIndex::new(&[-10]).unwrap_err(), WeightedError::NegativeWeight); - } - - #[test] - fn test_alias_method_weighted_index_f32() { - test_alias_method_weighted_index(f32::into); - - // Floating point special cases - assert_eq!( - AliasMethodWeightedIndex::new(vec![::core::f32::INFINITY]).unwrap_err(), - AliasMethodWeightedIndexError::InvalidWeight - ); - assert_eq!( - AliasMethodWeightedIndex::new(vec![-0_f32]).unwrap_err(), - AliasMethodWeightedIndexError::AllWeightsZero - ); - assert_eq!( - AliasMethodWeightedIndex::new(vec![-1_f32]).unwrap_err(), - AliasMethodWeightedIndexError::InvalidWeight - ); - assert_eq!( - AliasMethodWeightedIndex::new(vec![-::core::f32::INFINITY]).unwrap_err(), - AliasMethodWeightedIndexError::InvalidWeight - ); - assert_eq!( - AliasMethodWeightedIndex::new(vec![::core::f32::NAN]).unwrap_err(), - AliasMethodWeightedIndexError::InvalidWeight - ); - } - - #[cfg(all(rustc_1_26, not(target_os = "emscripten")))] - #[test] - fn test_alias_method_weighted_index_u128() { - test_alias_method_weighted_index(|x: u128| x as f64); - } - - #[cfg(all(rustc_1_26, not(target_os = "emscripten")))] - #[test] - fn test_alias_method_weighted_index_i128() { - test_alias_method_weighted_index(|x: i128| x as f64); - - // Signed integer special cases - assert_eq!( - AliasMethodWeightedIndex::new(vec![-1_i128]).unwrap_err(), - AliasMethodWeightedIndexError::InvalidWeight - ); - assert_eq!( - AliasMethodWeightedIndex::new(vec![::core::i128::MIN]).unwrap_err(), - AliasMethodWeightedIndexError::InvalidWeight - ); - } - - #[test] - fn test_alias_method_weighted_index_u8() { - test_alias_method_weighted_index(u8::into); - } - - #[test] - fn test_alias_method_weighted_index_i8() { - test_alias_method_weighted_index(i8::into); - - // Signed integer special cases - assert_eq!( - AliasMethodWeightedIndex::new(vec![-1_i8]).unwrap_err(), - AliasMethodWeightedIndexError::InvalidWeight - ); - assert_eq!( - AliasMethodWeightedIndex::new(vec![::core::i8::MIN]).unwrap_err(), - AliasMethodWeightedIndexError::InvalidWeight - ); - } - - fn test_alias_method_weighted_index f64>(w_to_f64: F) - where - AliasMethodWeightedIndex: fmt::Debug, - { - const NUM_WEIGHTS: usize = 10; - const ZERO_WEIGHT_INDEX: usize = 3; - const NUM_SAMPLES: u32 = 15000; - let mut rng = ::test::rng(0x9c9fa0b0580a7031); - - let weights = { - let mut weights = Vec::with_capacity(NUM_WEIGHTS); - let random_weight_distribution = ::distributions::Uniform::new_inclusive( - W::ZERO, - W::MAX / W::try_from_usize_lossy(NUM_WEIGHTS).unwrap(), - ); - for _ in 0..NUM_WEIGHTS { - weights.push(rng.sample(&random_weight_distribution)); - } - weights[ZERO_WEIGHT_INDEX] = W::ZERO; - weights - }; - let weight_sum = weights.iter().map(|w| *w).sum::(); - let expected_counts = weights - .iter() - .map(|&w| w_to_f64(w) / w_to_f64(weight_sum) * NUM_SAMPLES as f64) - .collect::>(); - let weight_distribution = AliasMethodWeightedIndex::new(weights).unwrap(); - - let mut counts = vec![0_usize; NUM_WEIGHTS]; - for _ in 0..NUM_SAMPLES { - counts[rng.sample(&weight_distribution)] += 1; - } - - assert_eq!(counts[ZERO_WEIGHT_INDEX], 0); - for (count, expected_count) in counts.into_iter().zip(expected_counts) { - let difference = (count as f64 - expected_count).abs(); - let max_allowed_difference = NUM_SAMPLES as f64 / NUM_WEIGHTS as f64 * 0.1; - assert!(difference <= max_allowed_difference); - } - - assert_eq!( - AliasMethodWeightedIndex::::new(vec![]).unwrap_err(), - AliasMethodWeightedIndexError::NoItem - ); - assert_eq!( - AliasMethodWeightedIndex::new(vec![W::ZERO]).unwrap_err(), - AliasMethodWeightedIndexError::AllWeightsZero - ); - assert_eq!( - AliasMethodWeightedIndex::new(vec![W::MAX, W::MAX]).unwrap_err(), - AliasMethodWeightedIndexError::InvalidWeight - ); - } -} - -/// Error type returned from `WeightedIndex::new`. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum WeightedError { - /// The provided iterator contained no items. - NoItem, - - /// A weight lower than zero was used. - NegativeWeight, - - /// All items in the provided iterator had a weight of zero. - AllWeightsZero, -} - -impl WeightedError { - fn msg(&self) -> &str { - match *self { - WeightedError::NoItem => "No items found", - WeightedError::NegativeWeight => "Item has negative weight", - WeightedError::AllWeightsZero => "All items had weight zero", - } - } -} - -#[cfg(feature="std")] -impl ::std::error::Error for WeightedError { - fn description(&self) -> &str { - self.msg() - } - fn cause(&self) -> Option<&::std::error::Error> { - None - } -} - -impl fmt::Display for WeightedError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self.msg()) - } -} - -/// Error type returned from [`AliasMethodWeightedIndex::new`]. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum AliasMethodWeightedIndexError { - /// The weight vector is empty. - NoItem, - /// A weight is either less than zero or greater than the supported maximum. - InvalidWeight, - /// All weights in the provided vector are zero. - AllWeightsZero, -} - -impl AliasMethodWeightedIndexError { - fn msg(&self) -> &str { - match *self { - AliasMethodWeightedIndexError::NoItem => "No items found.", - AliasMethodWeightedIndexError::InvalidWeight => "An item has an invalid weight.", - AliasMethodWeightedIndexError::AllWeightsZero => "All weights are zero.", - } - } -} - -impl fmt::Display for AliasMethodWeightedIndexError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.write_str(self.msg()) - } -} - -#[cfg(feature = "std")] -impl ::std::error::Error for AliasMethodWeightedIndexError { - fn description(&self) -> &str { - self.msg() - } - fn cause(&self) -> Option<&::std::error::Error> { - None - } -} diff --git a/src/distributions/weighted/alias_method.rs b/src/distributions/weighted/alias_method.rs new file mode 100644 index 00000000000..4e6dce7adaf --- /dev/null +++ b/src/distributions/weighted/alias_method.rs @@ -0,0 +1,484 @@ +//! This module contains an implementation of alias method for sampling random +//! indices with probabilities proportional to a collection of weights. + +use super::WeightedError; +#[cfg(not(feature = "std"))] +use alloc::vec::Vec; +use core::fmt; +use core::iter::Sum; +use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Sub, SubAssign}; +use distributions::uniform::SampleUniform; +use distributions::Distribution; +use distributions::Uniform; +use Rng; + +/// A distribution using weighted sampling to pick a discretely selected item. +/// +/// Sampling an [`WeightedIndex`] distribution returns the index of a +/// randomly selected element from the vector used to create the +/// [`WeightedIndex`]. The chance of a given element being picked is +/// proportional to the value of the element. The weights can have any type `W` +/// for which an implementation of [`Weight`] exists. +/// +/// # Performance +/// +/// Given that `n` is the number of items in the vector used to create an +/// [`WeightedIndex`], [`WeightedIndex`] will require `O(n)` amount of +/// memory. More specifically it takes up some constant amount of memory plus +/// the vector used to create it and a [`Vec`] with capacity `n`. +/// +/// Time complexity for the creation of an [`WeightedIndex`] is `O(n)`. +/// Sampling is `O(1)`, it makes a call to [`Uniform::sample`] and a call +/// to [`Uniform::sample`]. +/// +/// # Example +/// +/// ``` +/// use rand::distributions::weighted::alias_method::WeightedIndex; +/// use rand::prelude::*; +/// +/// let choices = vec!['a', 'b', 'c']; +/// let weights = vec![2, 1, 1]; +/// let dist = WeightedIndex::new(weights).unwrap(); +/// let mut rng = thread_rng(); +/// for _ in 0..100 { +/// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c' +/// println!("{}", choices[dist.sample(&mut rng)]); +/// } +/// +/// let items = [('a', 0), ('b', 3), ('c', 7)]; +/// let dist2 = WeightedIndex::new(items.iter().map(|item| item.1).collect()).unwrap(); +/// for _ in 0..100 { +/// // 0% chance to print 'a', 30% chance to print 'b', 70% chance to print 'c' +/// println!("{}", items[dist2.sample(&mut rng)].0); +/// } +/// ``` +/// +/// [`WeightedIndex`]: WeightedIndex +/// [`Weight`]: Weight +/// [`Vec`]: Vec +/// [`Uniform::sample`]: Distribution::sample +/// [`Uniform::sample`]: Distribution::sample +pub struct WeightedIndex { + aliases: Vec, + no_alias_odds: Vec, + uniform_index: Uniform, + uniform_within_weight_sum: Uniform, +} + +impl WeightedIndex { + /// Creates an new [`WeightedIndex`]. + /// + /// Returns an error if: + /// - The vector is empty. + /// - For any weight `w`: `w < 0` or `w > max` where `max = W::MAX / + /// weights.len()`. + /// - The sum of weights is zero. + pub fn new(weights: Vec) -> Result { + let n = weights.len(); + if n == 0 { + return Err(WeightedError::NoItem); + } + + let max_weight_size = W::try_from_usize_lossy(n) + .map(|n| W::MAX / n) + .unwrap_or(W::ZERO); + if !weights + .iter() + .all(|&w| W::ZERO <= w && w <= max_weight_size) + { + return Err(WeightedError::InvalidWeight); + } + + // The sum of weights will represent 100% of no alias odds. + let weight_sum = Weight::sum(weights.as_slice()); + // Prevent floating point overflow due to rounding errors. + let weight_sum = if weight_sum > W::MAX { + W::MAX + } else { + weight_sum + }; + if weight_sum == W::ZERO { + return Err(WeightedError::AllWeightsZero); + } + + // `weight_sum` would have been zero if `try_from_lossy` causes an error here. + let n_converted = W::try_from_usize_lossy(n).unwrap(); + + let mut no_alias_odds = weights; + for odds in no_alias_odds.iter_mut() { + *odds *= n_converted; + // Prevent floating point overflow due to rounding errors. + *odds = if *odds > W::MAX { W::MAX } else { *odds }; + } + + /// This struct is designed to contain three data structures at once, + /// sharing the same memory. More precisely it contains two linked lists + /// and an alias map, which will be the output of this method. To keep + /// the three data structures from getting in each other's way, it must + /// be ensured that a single index is only ever in one of them at the + /// same time. + struct Aliases { + aliases: Vec, + smalls_head: usize, + bigs_head: usize, + } + + impl Aliases { + fn new(size: usize) -> Self { + Aliases { + aliases: vec![0; size], + smalls_head: ::core::usize::MAX, + bigs_head: ::core::usize::MAX, + } + } + + fn push_small(&mut self, idx: usize) { + self.aliases[idx] = self.smalls_head; + self.smalls_head = idx; + } + + fn push_big(&mut self, idx: usize) { + self.aliases[idx] = self.bigs_head; + self.bigs_head = idx; + } + + fn pop_small(&mut self) -> usize { + let popped = self.smalls_head; + self.smalls_head = self.aliases[popped]; + popped + } + + fn pop_big(&mut self) -> usize { + let popped = self.bigs_head; + self.bigs_head = self.aliases[popped]; + popped + } + + fn smalls_is_empty(&self) -> bool { + self.smalls_head == ::core::usize::MAX + } + + fn bigs_is_empty(&self) -> bool { + self.bigs_head == ::core::usize::MAX + } + + fn set_alias(&mut self, idx: usize, alias: usize) { + self.aliases[idx] = alias; + } + } + + let mut aliases = Aliases::new(n); + + // Split indices into those with small weights and those with big weights. + for (index, &odds) in no_alias_odds.iter().enumerate() { + if odds < weight_sum { + aliases.push_small(index); + } else { + aliases.push_big(index); + } + } + + // Build the alias map by finding an alias with big weight for each index with + // small weight. + while !aliases.smalls_is_empty() && !aliases.bigs_is_empty() { + let s = aliases.pop_small(); + let b = aliases.pop_big(); + + aliases.set_alias(s, b); + no_alias_odds[b] = no_alias_odds[b] - weight_sum + no_alias_odds[s]; + + if no_alias_odds[b] < weight_sum { + aliases.push_small(b); + } else { + aliases.push_big(b); + } + } + + // The remaining indices should have no alias odds of about 100%. This is due to + // numeric accuracy. Otherwise they would be exactly 100%. + while !aliases.smalls_is_empty() { + no_alias_odds[aliases.pop_small()] = weight_sum; + } + while !aliases.bigs_is_empty() { + no_alias_odds[aliases.pop_big()] = weight_sum; + } + + // Prepare distributions for sampling. Creating them beforehand improves + // sampling performance. + let uniform_index = Uniform::new(0, n); + let uniform_within_weight_sum = Uniform::new(W::ZERO, weight_sum); + + Ok(Self { + aliases: aliases.aliases, + no_alias_odds, + uniform_index, + uniform_within_weight_sum, + }) + } +} + +impl Distribution for WeightedIndex { + fn sample(&self, rng: &mut R) -> usize { + let candidate = rng.sample(self.uniform_index); + if rng.sample(&self.uniform_within_weight_sum) < self.no_alias_odds[candidate] { + candidate + } else { + self.aliases[candidate] + } + } +} + +impl fmt::Debug for WeightedIndex +where + W: fmt::Debug, + Uniform: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("WeightedIndex") + .field("aliases", &self.aliases) + .field("no_alias_odds", &self.no_alias_odds) + .field("uniform_index", &self.uniform_index) + .field("uniform_within_weight_sum", &self.uniform_within_weight_sum) + .finish() + } +} + +impl Clone for WeightedIndex +where + Uniform: Clone, +{ + fn clone(&self) -> Self { + Self { + aliases: self.aliases.clone(), + no_alias_odds: self.no_alias_odds.clone(), + uniform_index: self.uniform_index.clone(), + uniform_within_weight_sum: self.uniform_within_weight_sum.clone(), + } + } +} + +/// Trait that must be implemented for weights, that are used with +/// [`WeightedIndex`]. Currently no guarantees on the correctness of +/// [`WeightedIndex`] are given for custom implementations of this trait. +pub trait Weight: + Sized + + Copy + + SampleUniform + + PartialOrd + + Add + + AddAssign + + Sub + + SubAssign + + Mul + + MulAssign + + Div + + DivAssign + + Sum +{ + /// Maximum number representable by `Self`. + const MAX: Self; + + /// Element of `Self` equivalent to 0. + const ZERO: Self; + + /// Converts a [`usize`] to a `Self`, rounding if necessary. + fn try_from_usize_lossy(n: usize) -> Option; + + /// Sums all values in slice `values`. + fn sum(values: &[Self]) -> Self { + values.iter().map(|x| *x).sum() + } +} + +macro_rules! impl_weight_for_float { + ($T: ident) => { + impl Weight for $T { + const MAX: Self = ::core::$T::MAX; + const ZERO: Self = 0.0; + + fn try_from_usize_lossy(n: usize) -> Option { + Some(n as $T) + } + + fn sum(values: &[Self]) -> Self { + pairwise_sum(values) + } + } + }; +} + +/// In comparison to naive accumulation, the pairwise sum algorithm reduces +/// rounding errors when there are many floating point values. +fn pairwise_sum(values: &[T]) -> T { + if values.len() <= 32 { + values.iter().map(|x| *x).sum() + } else { + let mid = values.len() / 2; + let (a, b) = values.split_at(mid); + pairwise_sum(a) + pairwise_sum(b) + } +} + +macro_rules! impl_weight_for_int { + ($T: ident) => { + impl Weight for $T { + const MAX: Self = ::core::$T::MAX; + const ZERO: Self = 0; + + fn try_from_usize_lossy(n: usize) -> Option { + let n_converted = n as Self; + if n_converted >= Self::ZERO && n_converted as usize == n { + Some(n_converted) + } else { + None + } + } + } + }; +} + +impl_weight_for_float!(f64); +impl_weight_for_float!(f32); +impl_weight_for_int!(usize); +#[cfg(all(rustc_1_26, not(target_os = "emscripten")))] +impl_weight_for_int!(u128); +impl_weight_for_int!(u64); +impl_weight_for_int!(u32); +impl_weight_for_int!(u16); +impl_weight_for_int!(u8); +impl_weight_for_int!(isize); +#[cfg(all(rustc_1_26, not(target_os = "emscripten")))] +impl_weight_for_int!(i128); +impl_weight_for_int!(i64); +impl_weight_for_int!(i32); +impl_weight_for_int!(i16); +impl_weight_for_int!(i8); + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_weighted_index_f32() { + test_weighted_index(f32::into); + + // Floating point special cases + assert_eq!( + WeightedIndex::new(vec![::core::f32::INFINITY]).unwrap_err(), + WeightedError::InvalidWeight + ); + assert_eq!( + WeightedIndex::new(vec![-0_f32]).unwrap_err(), + WeightedError::AllWeightsZero + ); + assert_eq!( + WeightedIndex::new(vec![-1_f32]).unwrap_err(), + WeightedError::InvalidWeight + ); + assert_eq!( + WeightedIndex::new(vec![-::core::f32::INFINITY]).unwrap_err(), + WeightedError::InvalidWeight + ); + assert_eq!( + WeightedIndex::new(vec![::core::f32::NAN]).unwrap_err(), + WeightedError::InvalidWeight + ); + } + + #[cfg(all(rustc_1_26, not(target_os = "emscripten")))] + #[test] + fn test_weighted_index_u128() { + test_weighted_index(|x: u128| x as f64); + } + + #[cfg(all(rustc_1_26, not(target_os = "emscripten")))] + #[test] + fn test_weighted_index_i128() { + test_weighted_index(|x: i128| x as f64); + + // Signed integer special cases + assert_eq!( + WeightedIndex::new(vec![-1_i128]).unwrap_err(), + WeightedError::InvalidWeight + ); + assert_eq!( + WeightedIndex::new(vec![::core::i128::MIN]).unwrap_err(), + WeightedError::InvalidWeight + ); + } + + #[test] + fn test_weighted_index_u8() { + test_weighted_index(u8::into); + } + + #[test] + fn test_weighted_index_i8() { + test_weighted_index(i8::into); + + // Signed integer special cases + assert_eq!( + WeightedIndex::new(vec![-1_i8]).unwrap_err(), + WeightedError::InvalidWeight + ); + assert_eq!( + WeightedIndex::new(vec![::core::i8::MIN]).unwrap_err(), + WeightedError::InvalidWeight + ); + } + + fn test_weighted_index f64>(w_to_f64: F) + where + WeightedIndex: fmt::Debug, + { + const NUM_WEIGHTS: usize = 10; + const ZERO_WEIGHT_INDEX: usize = 3; + const NUM_SAMPLES: u32 = 15000; + let mut rng = ::test::rng(0x9c9fa0b0580a7031); + + let weights = { + let mut weights = Vec::with_capacity(NUM_WEIGHTS); + let random_weight_distribution = ::distributions::Uniform::new_inclusive( + W::ZERO, + W::MAX / W::try_from_usize_lossy(NUM_WEIGHTS).unwrap(), + ); + for _ in 0..NUM_WEIGHTS { + weights.push(rng.sample(&random_weight_distribution)); + } + weights[ZERO_WEIGHT_INDEX] = W::ZERO; + weights + }; + let weight_sum = weights.iter().map(|w| *w).sum::(); + let expected_counts = weights + .iter() + .map(|&w| w_to_f64(w) / w_to_f64(weight_sum) * NUM_SAMPLES as f64) + .collect::>(); + let weight_distribution = WeightedIndex::new(weights).unwrap(); + + let mut counts = vec![0_usize; NUM_WEIGHTS]; + for _ in 0..NUM_SAMPLES { + counts[rng.sample(&weight_distribution)] += 1; + } + + assert_eq!(counts[ZERO_WEIGHT_INDEX], 0); + for (count, expected_count) in counts.into_iter().zip(expected_counts) { + let difference = (count as f64 - expected_count).abs(); + let max_allowed_difference = NUM_SAMPLES as f64 / NUM_WEIGHTS as f64 * 0.1; + assert!(difference <= max_allowed_difference); + } + + assert_eq!( + WeightedIndex::::new(vec![]).unwrap_err(), + WeightedError::NoItem + ); + assert_eq!( + WeightedIndex::new(vec![W::ZERO]).unwrap_err(), + WeightedError::AllWeightsZero + ); + assert_eq!( + WeightedIndex::new(vec![W::MAX, W::MAX]).unwrap_err(), + WeightedError::InvalidWeight + ); + } +} diff --git a/src/distributions/weighted/mod.rs b/src/distributions/weighted/mod.rs new file mode 100644 index 00000000000..b58cada9c25 --- /dev/null +++ b/src/distributions/weighted/mod.rs @@ -0,0 +1,236 @@ +// Copyright 2018 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! This module contains different algorithms for sampling random indices with +//! probabilities proportional to a collection of weights. + +pub mod alias_method; + +use Rng; +use distributions::Distribution; +use distributions::uniform::{UniformSampler, SampleUniform, SampleBorrow}; +use ::core::cmp::PartialOrd; +use core::fmt; + +// Note that this whole module is only imported if feature="alloc" is enabled. +#[cfg(not(feature="std"))] use alloc::vec::Vec; + +/// A distribution using weighted sampling to pick a discretely selected +/// item. +/// +/// Sampling a `WeightedIndex` distribution returns the index of a randomly +/// selected element from the iterator used when the `WeightedIndex` was +/// created. The chance of a given element being picked is proportional to the +/// value of the element. The weights can use any type `X` for which an +/// implementation of [`Uniform`] exists. +/// +/// # Performance +/// +/// A `WeightedIndex` contains a `Vec` and a [`Uniform`] and so its +/// size is the sum of the size of those objects, possibly plus some alignment. +/// +/// Creating a `WeightedIndex` will allocate enough space to hold `N - 1` +/// weights of type `X`, where `N` is the number of weights. However, since +/// `Vec` doesn't guarantee a particular growth strategy, additional memory +/// might be allocated but not used. Since the `WeightedIndex` object also +/// contains, this might cause additional allocations, though for primitive +/// types, ['Uniform`] doesn't allocate any memory. +/// +/// Time complexity of sampling from `WeightedIndex` is `O(log N)` where +/// `N` is the number of weights. +/// +/// Sampling from `WeightedIndex` will result in a single call to +/// `Uniform::sample` (method of the [`Distribution`] trait), which typically +/// will request a single value from the underlying [`RngCore`], though the +/// exact number depends on the implementaiton of `Uniform::sample`. +/// +/// # Example +/// +/// ``` +/// use rand::prelude::*; +/// use rand::distributions::WeightedIndex; +/// +/// let choices = ['a', 'b', 'c']; +/// let weights = [2, 1, 1]; +/// let dist = WeightedIndex::new(&weights).unwrap(); +/// let mut rng = thread_rng(); +/// for _ in 0..100 { +/// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c' +/// println!("{}", choices[dist.sample(&mut rng)]); +/// } +/// +/// let items = [('a', 0), ('b', 3), ('c', 7)]; +/// let dist2 = WeightedIndex::new(items.iter().map(|item| item.1)).unwrap(); +/// for _ in 0..100 { +/// // 0% chance to print 'a', 30% chance to print 'b', 70% chance to print 'c' +/// println!("{}", items[dist2.sample(&mut rng)].0); +/// } +/// ``` +/// +/// [`Uniform`]: crate::distributions::uniform::Uniform +/// [`RngCore`]: rand_core::RngCore +#[derive(Debug, Clone)] +pub struct WeightedIndex { + cumulative_weights: Vec, + weight_distribution: X::Sampler, +} + +impl WeightedIndex { + /// Creates a new a `WeightedIndex` [`Distribution`] using the values + /// in `weights`. The weights can use any type `X` for which an + /// implementation of [`Uniform`] exists. + /// + /// Returns an error if the iterator is empty, if any weight is `< 0`, or + /// if its total value is 0. + /// + /// [`Uniform`]: crate::distributions::uniform::Uniform + pub fn new(weights: I) -> Result, WeightedError> + where I: IntoIterator, + I::Item: SampleBorrow, + X: for<'a> ::core::ops::AddAssign<&'a X> + + Clone + + Default { + let mut iter = weights.into_iter(); + let mut total_weight: X = iter.next() + .ok_or(WeightedError::NoItem)? + .borrow() + .clone(); + + let zero = ::default(); + if total_weight < zero { + return Err(WeightedError::InvalidWeight); + } + + let mut weights = Vec::::with_capacity(iter.size_hint().0); + for w in iter { + if *w.borrow() < zero { + return Err(WeightedError::InvalidWeight); + } + weights.push(total_weight.clone()); + total_weight += w.borrow(); + } + + if total_weight == zero { + return Err(WeightedError::AllWeightsZero); + } + let distr = X::Sampler::new(zero, total_weight); + + Ok(WeightedIndex { cumulative_weights: weights, weight_distribution: distr }) + } +} + +impl Distribution for WeightedIndex where + X: SampleUniform + PartialOrd { + fn sample(&self, rng: &mut R) -> usize { + use ::core::cmp::Ordering; + let chosen_weight = self.weight_distribution.sample(rng); + // Find the first item which has a weight *higher* than the chosen weight. + self.cumulative_weights.binary_search_by( + |w| if *w <= chosen_weight { Ordering::Less } else { Ordering::Greater }).unwrap_err() + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_weightedindex() { + let mut r = ::test::rng(700); + const N_REPS: u32 = 5000; + let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7]; + let total_weight = weights.iter().sum::() as f32; + + let verify = |result: [i32; 14]| { + for (i, count) in result.iter().enumerate() { + let exp = (weights[i] * N_REPS) as f32 / total_weight; + let mut err = (*count as f32 - exp).abs(); + if err != 0.0 { + err /= exp; + } + assert!(err <= 0.25); + } + }; + + // WeightedIndex from vec + let mut chosen = [0i32; 14]; + let distr = WeightedIndex::new(weights.to_vec()).unwrap(); + for _ in 0..N_REPS { + chosen[distr.sample(&mut r)] += 1; + } + verify(chosen); + + // WeightedIndex from slice + chosen = [0i32; 14]; + let distr = WeightedIndex::new(&weights[..]).unwrap(); + for _ in 0..N_REPS { + chosen[distr.sample(&mut r)] += 1; + } + verify(chosen); + + // WeightedIndex from iterator + chosen = [0i32; 14]; + let distr = WeightedIndex::new(weights.iter()).unwrap(); + for _ in 0..N_REPS { + chosen[distr.sample(&mut r)] += 1; + } + verify(chosen); + + for _ in 0..5 { + assert_eq!(WeightedIndex::new(&[0, 1]).unwrap().sample(&mut r), 1); + assert_eq!(WeightedIndex::new(&[1, 0]).unwrap().sample(&mut r), 0); + assert_eq!(WeightedIndex::new(&[0, 0, 0, 0, 10, 0]).unwrap().sample(&mut r), 4); + } + + assert_eq!(WeightedIndex::new(&[10][0..0]).unwrap_err(), WeightedError::NoItem); + assert_eq!(WeightedIndex::new(&[0]).unwrap_err(), WeightedError::AllWeightsZero); + assert_eq!(WeightedIndex::new(&[10, 20, -1, 30]).unwrap_err(), WeightedError::InvalidWeight); + assert_eq!(WeightedIndex::new(&[-10, 20, 1, 30]).unwrap_err(), WeightedError::InvalidWeight); + assert_eq!(WeightedIndex::new(&[-10]).unwrap_err(), WeightedError::InvalidWeight); + } +} + +/// Error type returned from `WeightedIndex::new`. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum WeightedError { + /// The provided weight collection contains no items. + NoItem, + + /// A weight is either less than zero, greater than the supported maximum or + /// otherwise invalid. + InvalidWeight, + + /// All items in the provided weight collection are zero. + AllWeightsZero, +} + +impl WeightedError { + fn msg(&self) -> &str { + match *self { + WeightedError::NoItem => "No weights provided.", + WeightedError::InvalidWeight => "A weight is invalid.", + WeightedError::AllWeightsZero => "All weights are zero.", + } + } +} + +#[cfg(feature="std")] +impl ::std::error::Error for WeightedError { + fn description(&self) -> &str { + self.msg() + } + fn cause(&self) -> Option<&::std::error::Error> { + None + } +} + +impl fmt::Display for WeightedError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.msg()) + } +} diff --git a/src/seq/mod.rs b/src/seq/mod.rs index d0f83084a66..72fb211e08b 100644 --- a/src/seq/mod.rs +++ b/src/seq/mod.rs @@ -823,7 +823,7 @@ mod test { assert_eq!(empty_slice.choose_weighted(&mut r, |_| 1), Err(WeightedError::NoItem)); assert_eq!(empty_slice.choose_weighted_mut(&mut r, |_| 1), Err(WeightedError::NoItem)); assert_eq!(['x'].choose_weighted_mut(&mut r, |_| 0), Err(WeightedError::AllWeightsZero)); - assert_eq!([0, -1].choose_weighted_mut(&mut r, |x| *x), Err(WeightedError::NegativeWeight)); - assert_eq!([-1, 0].choose_weighted_mut(&mut r, |x| *x), Err(WeightedError::NegativeWeight)); + assert_eq!([0, -1].choose_weighted_mut(&mut r, |x| *x), Err(WeightedError::InvalidWeight)); + assert_eq!([-1, 0].choose_weighted_mut(&mut r, |x| *x), Err(WeightedError::InvalidWeight)); } } From 950afb3e1b1975dbf35bb1be8fb5f981709fb5cd Mon Sep 17 00:00:00 2001 From: zroug <37004975+zroug@users.noreply.github.com> Date: Wed, 6 Mar 2019 18:07:32 +0100 Subject: [PATCH 10/10] Use a instead of an when appropriate --- src/distributions/weighted/alias_method.rs | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/distributions/weighted/alias_method.rs b/src/distributions/weighted/alias_method.rs index 4e6dce7adaf..9fdba92ec77 100644 --- a/src/distributions/weighted/alias_method.rs +++ b/src/distributions/weighted/alias_method.rs @@ -14,11 +14,11 @@ use Rng; /// A distribution using weighted sampling to pick a discretely selected item. /// -/// Sampling an [`WeightedIndex`] distribution returns the index of a -/// randomly selected element from the vector used to create the -/// [`WeightedIndex`]. The chance of a given element being picked is -/// proportional to the value of the element. The weights can have any type `W` -/// for which an implementation of [`Weight`] exists. +/// Sampling a [`WeightedIndex`] distribution returns the index of a randomly +/// selected element from the vector used to create the [`WeightedIndex`]. +/// The chance of a given element being picked is proportional to the value of +/// the element. The weights can have any type `W` for which a implementation of +/// [`Weight`] exists. /// /// # Performance /// @@ -27,7 +27,7 @@ use Rng; /// memory. More specifically it takes up some constant amount of memory plus /// the vector used to create it and a [`Vec`] with capacity `n`. /// -/// Time complexity for the creation of an [`WeightedIndex`] is `O(n)`. +/// Time complexity for the creation of a [`WeightedIndex`] is `O(n)`. /// Sampling is `O(1)`, it makes a call to [`Uniform::sample`] and a call /// to [`Uniform::sample`]. /// @@ -67,7 +67,7 @@ pub struct WeightedIndex { } impl WeightedIndex { - /// Creates an new [`WeightedIndex`]. + /// Creates a new [`WeightedIndex`]. /// /// Returns an error if: /// - The vector is empty.