Skip to content

Commit

Permalink
AliasMethod weighted index: use u32 internally
Browse files Browse the repository at this point in the history
Primarily for value stability, also slight performance boost.
  • Loading branch information
dhardy committed Jun 3, 2019
1 parent 1c57f70 commit 31348b7
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 48 deletions.
102 changes: 54 additions & 48 deletions src/distributions/weighted/alias_method.rs
Expand Up @@ -25,10 +25,10 @@ use Rng;
/// Given that `n` is the number of items in the vector used to create a
/// [`WeightedIndex<W>`], [`WeightedIndex<W>`] will require `O(n)` amount of
/// memory. More specifically it takes up some constant amount of memory plus
/// the vector used to create it and a [`Vec<usize>`] with capacity `n`.
/// the vector used to create it and a [`Vec<u32>`] with capacity `n`.
///
/// Time complexity for the creation of a [`WeightedIndex<W>`] is `O(n)`.
/// Sampling is `O(1)`, it makes a call to [`Uniform<usize>::sample`] and a call
/// Sampling is `O(1)`, it makes a call to [`Uniform<u32>::sample`] and a call
/// to [`Uniform<W>::sample`].
///
/// # Example
Expand Down Expand Up @@ -56,13 +56,13 @@ use Rng;
///
/// [`WeightedIndex<W>`]: crate::distributions::weighted::alias_method::WeightedIndex
/// [`Weight`]: crate::distributions::weighted::alias_method::Weight
/// [`Vec<usize>`]: Vec
/// [`Uniform<usize>::sample`]: Distribution::sample
/// [`Vec<u32>`]: Vec
/// [`Uniform<u32>::sample`]: Distribution::sample
/// [`Uniform<W>::sample`]: Distribution::sample
pub struct WeightedIndex<W: Weight> {
aliases: Vec<usize>,
aliases: Vec<u32>,
no_alias_odds: Vec<W>,
uniform_index: Uniform<usize>,
uniform_index: Uniform<u32>,
uniform_within_weight_sum: Uniform<W>,
}

Expand All @@ -71,16 +71,20 @@ impl<W: Weight> WeightedIndex<W> {
///
/// Returns an error if:
/// - The vector is empty.
/// - The vector is longer than `u32::MAX`.
/// - For any weight `w`: `w < 0` or `w > max` where `max = W::MAX /
/// weights.len()`.
/// - The sum of weights is zero.
pub fn new(weights: Vec<W>) -> Result<Self, WeightedError> {
let n = weights.len();
if n == 0 {
return Err(WeightedError::NoItem);
} else if n > ::core::u32::MAX as usize {
return Err(WeightedError::TooMany);
}
let n = n as u32;

let max_weight_size = W::try_from_usize_lossy(n)
let max_weight_size = W::try_from_u32_lossy(n)
.map(|n| W::MAX / n)
.unwrap_or(W::ZERO);
if !weights
Expand All @@ -103,7 +107,7 @@ impl<W: Weight> WeightedIndex<W> {
}

// `weight_sum` would have been zero if the lossy conversion of `n` had failed here.
let n_converted = W::try_from_usize_lossy(n).unwrap();
let n_converted = W::try_from_u32_lossy(n).unwrap();

let mut no_alias_odds = weights;
for odds in no_alias_odds.iter_mut() {
Expand All @@ -119,52 +123,52 @@ impl<W: Weight> WeightedIndex<W> {
/// be ensured that a single index is only ever in one of them at the
/// same time.
struct Aliases {
aliases: Vec<usize>,
smalls_head: usize,
bigs_head: usize,
aliases: Vec<u32>,
smalls_head: u32,
bigs_head: u32,
}

impl Aliases {
fn new(size: usize) -> Self {
fn new(size: u32) -> Self {
Aliases {
aliases: vec![0; size],
smalls_head: ::core::usize::MAX,
bigs_head: ::core::usize::MAX,
aliases: vec![0; size as usize],
smalls_head: ::core::u32::MAX,
bigs_head: ::core::u32::MAX,
}
}

fn push_small(&mut self, idx: usize) {
self.aliases[idx] = self.smalls_head;
fn push_small(&mut self, idx: u32) {
self.aliases[idx as usize] = self.smalls_head;
self.smalls_head = idx;
}

fn push_big(&mut self, idx: usize) {
self.aliases[idx] = self.bigs_head;
fn push_big(&mut self, idx: u32) {
self.aliases[idx as usize] = self.bigs_head;
self.bigs_head = idx;
}

fn pop_small(&mut self) -> usize {
fn pop_small(&mut self) -> u32 {
let popped = self.smalls_head;
self.smalls_head = self.aliases[popped];
self.smalls_head = self.aliases[popped as usize];
popped
}

fn pop_big(&mut self) -> usize {
fn pop_big(&mut self) -> u32 {
let popped = self.bigs_head;
self.bigs_head = self.aliases[popped];
self.bigs_head = self.aliases[popped as usize];
popped
}

fn smalls_is_empty(&self) -> bool {
self.smalls_head == ::core::usize::MAX
self.smalls_head == ::core::u32::MAX
}

fn bigs_is_empty(&self) -> bool {
self.bigs_head == ::core::usize::MAX
self.bigs_head == ::core::u32::MAX
}

fn set_alias(&mut self, idx: usize, alias: usize) {
self.aliases[idx] = alias;
fn set_alias(&mut self, idx: u32, alias: u32) {
self.aliases[idx as usize] = alias;
}
}

Expand All @@ -173,9 +177,9 @@ impl<W: Weight> WeightedIndex<W> {
// Split indices into those with small weights and those with big weights.
for (index, &odds) in no_alias_odds.iter().enumerate() {
if odds < weight_sum {
aliases.push_small(index);
aliases.push_small(index as u32);
} else {
aliases.push_big(index);
aliases.push_big(index as u32);
}
}

Expand All @@ -186,9 +190,11 @@ impl<W: Weight> WeightedIndex<W> {
let b = aliases.pop_big();

aliases.set_alias(s, b);
no_alias_odds[b] = no_alias_odds[b] - weight_sum + no_alias_odds[s];
no_alias_odds[b as usize] = no_alias_odds[b as usize]
- weight_sum
+ no_alias_odds[s as usize];

if no_alias_odds[b] < weight_sum {
if no_alias_odds[b as usize] < weight_sum {
aliases.push_small(b);
} else {
aliases.push_big(b);
Expand All @@ -198,10 +204,10 @@ impl<W: Weight> WeightedIndex<W> {
// The remaining indices should have no-alias odds of about 100%. Only limited
// numeric accuracy keeps them from being exactly 100%.
while !aliases.smalls_is_empty() {
no_alias_odds[aliases.pop_small()] = weight_sum;
no_alias_odds[aliases.pop_small() as usize] = weight_sum;
}
while !aliases.bigs_is_empty() {
no_alias_odds[aliases.pop_big()] = weight_sum;
no_alias_odds[aliases.pop_big() as usize] = weight_sum;
}

// Prepare distributions for sampling. Creating them beforehand improves
Expand All @@ -221,10 +227,10 @@ impl<W: Weight> WeightedIndex<W> {
impl<W: Weight> Distribution<usize> for WeightedIndex<W> {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
let candidate = rng.sample(self.uniform_index);
if rng.sample(&self.uniform_within_weight_sum) < self.no_alias_odds[candidate] {
candidate
if rng.sample(&self.uniform_within_weight_sum) < self.no_alias_odds[candidate as usize] {
candidate as usize
} else {
self.aliases[candidate]
self.aliases[candidate as usize] as usize
}
}
}
Expand Down Expand Up @@ -282,10 +288,10 @@ pub trait Weight:
/// Element of `Self` equivalent to 0.
const ZERO: Self;

/// Produce an instance of `Self` from a `usize` value, or return `None` if
/// Produce an instance of `Self` from a `u32` value, or return `None` if
/// out of range. Loss of precision (where `Self` is a floating point type)
/// is acceptable.
fn try_from_usize_lossy(n: usize) -> Option<Self>;
fn try_from_u32_lossy(n: u32) -> Option<Self>;

/// Sums all values in slice `values`.
fn sum(values: &[Self]) -> Self {
Expand All @@ -299,7 +305,7 @@ macro_rules! impl_weight_for_float {
const MAX: Self = ::core::$T::MAX;
const ZERO: Self = 0.0;

fn try_from_usize_lossy(n: usize) -> Option<Self> {
fn try_from_u32_lossy(n: u32) -> Option<Self> {
Some(n as $T)
}

Expand Down Expand Up @@ -328,9 +334,9 @@ macro_rules! impl_weight_for_int {
const MAX: Self = ::core::$T::MAX;
const ZERO: Self = 0;

fn try_from_usize_lossy(n: usize) -> Option<Self> {
fn try_from_u32_lossy(n: u32) -> Option<Self> {
let n_converted = n as Self;
if n_converted >= Self::ZERO && n_converted as usize == n {
if n_converted >= Self::ZERO && n_converted as u32 == n {
Some(n_converted)
} else {
None
Expand Down Expand Up @@ -439,21 +445,21 @@ mod test {
where
WeightedIndex<W>: fmt::Debug,
{
const NUM_WEIGHTS: usize = 10;
const ZERO_WEIGHT_INDEX: usize = 3;
const NUM_WEIGHTS: u32 = 10;
const ZERO_WEIGHT_INDEX: u32 = 3;
const NUM_SAMPLES: u32 = 15000;
let mut rng = ::test::rng(0x9c9fa0b0580a7031);

let weights = {
let mut weights = Vec::with_capacity(NUM_WEIGHTS);
let mut weights = Vec::with_capacity(NUM_WEIGHTS as usize);
let random_weight_distribution = ::distributions::Uniform::new_inclusive(
W::ZERO,
W::MAX / W::try_from_usize_lossy(NUM_WEIGHTS).unwrap(),
W::MAX / W::try_from_u32_lossy(NUM_WEIGHTS).unwrap(),
);
for _ in 0..NUM_WEIGHTS {
weights.push(rng.sample(&random_weight_distribution));
}
weights[ZERO_WEIGHT_INDEX] = W::ZERO;
weights[ZERO_WEIGHT_INDEX as usize] = W::ZERO;
weights
};
let weight_sum = weights.iter().map(|w| *w).sum::<W>();
Expand All @@ -463,12 +469,12 @@ mod test {
.collect::<Vec<f64>>();
let weight_distribution = WeightedIndex::new(weights).unwrap();

let mut counts = vec![0_usize; NUM_WEIGHTS];
let mut counts = vec![0; NUM_WEIGHTS as usize];
for _ in 0..NUM_SAMPLES {
counts[rng.sample(&weight_distribution)] += 1;
}

assert_eq!(counts[ZERO_WEIGHT_INDEX], 0);
assert_eq!(counts[ZERO_WEIGHT_INDEX as usize], 0);
for (count, expected_count) in counts.into_iter().zip(expected_counts) {
let difference = (count as f64 - expected_count).abs();
let max_allowed_difference = NUM_SAMPLES as f64 / NUM_WEIGHTS as f64 * 0.1;
Expand Down
4 changes: 4 additions & 0 deletions src/distributions/weighted/mod.rs
Expand Up @@ -208,6 +208,9 @@ pub enum WeightedError {

/// All items in the provided weight collection are zero.
AllWeightsZero,

/// Too many weights are provided (length greater than `u32::MAX`)
TooMany,
}

impl WeightedError {
Expand All @@ -216,6 +219,7 @@ impl WeightedError {
WeightedError::NoItem => "No weights provided.",
WeightedError::InvalidWeight => "A weight is invalid.",
WeightedError::AllWeightsZero => "All weights are zero.",
WeightedError::TooMany => "Too many weights (hit u32::MAX)",
}
}
}
Expand Down

0 comments on commit 31348b7

Please sign in to comment.