Skip to content

Commit

Permalink
AliasMethodWeightedIndex: Make implementation details more generic
Browse files Browse the repository at this point in the history
  • Loading branch information
zroug committed Jan 24, 2019
1 parent fd5fafe commit 73d86e5
Showing 1 changed file with 37 additions and 30 deletions.
67 changes: 37 additions & 30 deletions src/distributions/weighted.rs
Expand Up @@ -136,32 +136,35 @@ pub struct AliasMethodWeightedIndex {
aliases: Vec<usize>,
no_alias_odds: Vec<f64>,
uniform_index: super::Uniform<usize>,
uniform_within_weight_sum: super::Uniform<f64>,
}

impl AliasMethodWeightedIndex {
#[allow(missing_docs)] // todo: add docs
pub fn new(weights: Vec<f64>) -> Result<Self, AliasMethodWeightedIndexError> {
if weights.is_empty() {
let n = weights.len();
if n == 0 {
return Err(AliasMethodWeightedIndexError::NoItem);
}
if !weights.iter().all(|&w| w >= 0.0) {

let max_weight_size = ::core::f64::MAX / n as f64;
if !weights.iter().all(|&w| 0_f64 <= w && w <= max_weight_size) {
return Err(AliasMethodWeightedIndexError::InvalidWeight);
}

let n = weights.len();
// The sum of weights will represent 100% of no alias odds.
let weight_sum = pairwise_sum_f64(weights.as_slice());
if weight_sum.is_infinite() {
return Err(AliasMethodWeightedIndexError::WeightSumToBig);
}

let weight_scale = n as f64 / weight_sum;
if weight_scale.is_infinite() {
return Err(AliasMethodWeightedIndexError::WeightSumToSmall);
// Prevent floating point overflow due to rounding errors.
let weight_sum = weight_sum.min(::core::f64::MAX);
if weight_sum == 0_f64 {
return Err(AliasMethodWeightedIndexError::AllWeightsZero);
}

let mut no_alias_odds = weights;
for odds in no_alias_odds.iter_mut() {
*odds *= weight_scale;
*odds *= n as f64;
// Prevent floating point overflow due to rounding errors.
*odds = odds.min(::core::f64::MAX);
}

/// This struct is designed to contain three data structures at once,
Expand Down Expand Up @@ -224,7 +227,7 @@ impl AliasMethodWeightedIndex {

// Split indices into those with small weights and those with big weights.
for (index, &odds) in no_alias_odds.iter().enumerate() {
if odds < 1.0 {
if odds < weight_sum {
aliases.push_small(index);
} else {
aliases.push_big(index);
Expand All @@ -238,40 +241,42 @@ impl AliasMethodWeightedIndex {
let b = aliases.pop_big();

aliases.set_alias(s, b);
no_alias_odds[b] = no_alias_odds[b] - 1.0 + no_alias_odds[s];
no_alias_odds[b] = no_alias_odds[b] - weight_sum + no_alias_odds[s];

if no_alias_odds[b] < 1.0 {
if no_alias_odds[b] < weight_sum {
aliases.push_small(b);
} else {
aliases.push_big(b);
}
}

// The remaining indices should have no alias odds of about 1. They are only
// approximately 1 because of limited numeric accuracy; otherwise they would
// be exactly 1.
// The remaining indices should have no alias odds of about 100%. They are only
// approximately 100% because of limited numeric accuracy; otherwise they would
// be exactly 100%.
while !aliases.smalls_is_empty() {
no_alias_odds[aliases.pop_small()] = 1.0;
no_alias_odds[aliases.pop_small()] = weight_sum;
}
while !aliases.bigs_is_empty() {
no_alias_odds[aliases.pop_big()] = 1.0;
no_alias_odds[aliases.pop_big()] = weight_sum;
}

// Prepare a distribution to sample random indices. Creating it beforehand
// improves sampling performance.
// Prepare distributions for sampling. Creating them beforehand improves
// sampling performance.
let uniform_index = super::Uniform::new(0, n);
let uniform_within_weight_sum = super::Uniform::new(0_f64, weight_sum);

Ok(Self {
aliases: aliases.aliases,
no_alias_odds,
uniform_index,
uniform_within_weight_sum,
})
}
}

impl Distribution<usize> for AliasMethodWeightedIndex {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
let candidate = rng.sample(self.uniform_index);
if rng.sample::<f64, _>(super::Standard) < self.no_alias_odds[candidate] {
if rng.sample(self.uniform_within_weight_sum) < self.no_alias_odds[candidate] {
candidate
} else {
self.aliases[candidate]
Expand Down Expand Up @@ -357,8 +362,12 @@ mod test {

let weights = {
let mut weights = Vec::with_capacity(NUM_WEIGHTS);
let random_weight_distribution = ::distributions::Uniform::new_inclusive(
0_f64,
::core::f64::MAX / NUM_WEIGHTS as f64,
);
for _ in 0..NUM_WEIGHTS {
weights.push(rng.sample::<f64, _>(::distributions::Standard));
weights.push(rng.sample(random_weight_distribution));
}
weights[ZERO_WEIGHT_INDEX] = 0.0;
weights
Expand Down Expand Up @@ -388,19 +397,19 @@ mod test {
);
assert_eq!(
AliasMethodWeightedIndex::new(vec![0.0]).unwrap_err(),
AliasMethodWeightedIndexError::WeightSumToSmall
AliasMethodWeightedIndexError::AllWeightsZero
);
assert_eq!(
AliasMethodWeightedIndex::new(vec![-0.0]).unwrap_err(),
AliasMethodWeightedIndexError::WeightSumToSmall
AliasMethodWeightedIndexError::AllWeightsZero
);
assert_eq!(
AliasMethodWeightedIndex::new(vec![::core::f64::INFINITY]).unwrap_err(),
AliasMethodWeightedIndexError::WeightSumToBig
AliasMethodWeightedIndexError::InvalidWeight
);
assert_eq!(
AliasMethodWeightedIndex::new(vec![::core::f64::MAX, ::core::f64::MAX]).unwrap_err(),
AliasMethodWeightedIndexError::WeightSumToBig
AliasMethodWeightedIndexError::InvalidWeight
);
assert_eq!(
AliasMethodWeightedIndex::new(vec![-1.0]).unwrap_err(),
Expand Down Expand Up @@ -461,17 +470,15 @@ impl fmt::Display for WeightedError {
/// Error values returned by `AliasMethodWeightedIndex::new` when the supplied
/// weights cannot be turned into a sampling table.
pub enum AliasMethodWeightedIndexError {
/// The weight list was empty.
NoItem,
/// At least one weight failed the range check in `new` (for example a
/// negative weight; a NaN also fails because every comparison with it is false).
InvalidWeight,
/// `n / weight_sum` evaluated to infinity, i.e. the weight sum was too small
/// to scale by.
WeightSumToSmall,
/// The sum of all weights evaluated to infinity.
WeightSumToBig,
/// The sum of all weights was exactly zero.
AllWeightsZero,
}

impl AliasMethodWeightedIndexError {
/// Returns a fixed, human-readable description of this error variant.
///
/// Every variant maps to a string literal; no data is formatted in.
fn msg(&self) -> &str {
// The match has no catch-all arm, so adding a variant to the enum forces
// this function to be updated as well.
match *self {
AliasMethodWeightedIndexError::NoItem => "No items found.",
AliasMethodWeightedIndexError::InvalidWeight => "An item has an invalid weight.",
// NOTE(review): "to small" / "to big" below should read "too small" /
// "too big". Left unchanged here because the strings are runtime output.
AliasMethodWeightedIndexError::WeightSumToSmall => "The sum of weights is to small.",
AliasMethodWeightedIndexError::WeightSumToBig => "The sum of weights is to big.",
AliasMethodWeightedIndexError::AllWeightsZero => "All weights are zero.",
}
}
}
Expand Down

0 comments on commit 73d86e5

Please sign in to comment.