
WeightedChoice could use the Walker or Vose Alias method for O(1) sampling (instead of O(log n)) #601

Closed
huonw opened this issue Sep 7, 2018 · 8 comments

Comments

@huonw
Contributor

huonw commented Sep 7, 2018

https://en.wikipedia.org/wiki/Alias_method allows O(1) sampling of variates with still only O(n) preprocessing, which is significantly better than the O(log n) per sample of the current method (binary search over the cumulative weights, sometimes called roulette wheel selection).
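For context, here is a minimal sketch (not rand's actual code) of the O(log n) approach referred to above: build a cumulative-weight table once, then map a uniform draw in [0, 1) to an index via binary search. The function names are made up for illustration; only the standard library is used.

```rust
// Build the cumulative-weight (CDF) table for the given weights.
fn build_cdf(weights: &[f64]) -> Vec<f64> {
    let mut acc = 0.0;
    weights
        .iter()
        .map(|w| {
            acc += *w;
            acc
        })
        .collect()
}

// `u` is a uniform sample in [0, 1), e.g. from `rng.gen::<f64>()`.
// `partition_point` performs the O(log n) binary search: it returns the
// index of the first cumulative weight strictly greater than the target.
fn sample_index(cdf: &[f64], u: f64) -> usize {
    let target = u * cdf[cdf.len() - 1];
    cdf.partition_point(|&c| c <= target)
}

fn main() {
    let cdf = build_cdf(&[1.0, 3.0, 6.0]); // cumulative: [1.0, 4.0, 10.0]
    assert_eq!(sample_index(&cdf, 0.05), 0); // target 0.5 falls below 1.0
    assert_eq!(sample_index(&cdf, 0.30), 1); // target 3.0 falls in [1.0, 4.0)
    assert_eq!(sample_index(&cdf, 0.90), 2); // target 9.0 falls in [4.0, 10.0)
    println!("ok");
}
```

The alias method replaces the binary search with two O(1) table lookups, at the cost of the more involved table construction shown later in this thread.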

@dhardy
Member

dhardy commented Sep 7, 2018

Thanks. For now I'm more interested in getting the API right, and this isn't a core function, so it doesn't have the highest priority IMO.

@huonw
Contributor Author

huonw commented Sep 8, 2018

Yeah, definitely. That said, I suspect an implementation of this may require changing the current API (both the old WeightedChoice form and even master's allocating WeightedIndex form), so unfortunately it isn't purely an internal optimization.

@zroug
Contributor

zroug commented Jan 11, 2019

I created an implementation for myself and will put it here so that anyone who needs it can use it. It is based on this, and it should be possible to adapt it to integer weights by rescaling what counts as 100%.

```rust
use rand::distributions::Distribution;
use rand::Rng;
use std::collections::VecDeque;

pub struct AliasMethodWeightedIndex {
    aliases: Vec<usize>,
    no_alias_odds: Vec<f64>,
}

impl AliasMethodWeightedIndex {
    pub fn new(weights: Vec<f64>) -> Self {
        debug_assert!(weights.iter().all(|&w| w >= 0.0));

        let weight_sum = pairwise_sum_f64(weights.as_slice());
        if !weight_sum.is_finite() {
            panic!("Sum of weights not finite.");
        }
        let n = weights.len();

        // Scale the weights so that they average to 1.
        let mut no_alias_odds = weights;
        for p in no_alias_odds.iter_mut() {
            *p *= n as f64 / weight_sum;
        }

        // Split indices into indices with small weights and indices with big weights.
        // Instead of two `Vec`s with unknown capacity we use a single `VecDeque` with
        // known capacity. The front represents smalls and the back represents bigs. We
        // also need to keep track of the size of each virtual `Vec`.
        let mut smalls_bigs = VecDeque::with_capacity(n);
        let mut smalls_len = 0_usize;
        let mut bigs_len = 0_usize;
        for (index, &weight) in no_alias_odds.iter().enumerate() {
            if weight < 1.0 {
                smalls_bigs.push_front(index);
                smalls_len += 1;
            } else {
                smalls_bigs.push_back(index);
                bigs_len += 1;
            }
        }

        // Pair each small index with a big one: the big index donates the probability
        // mass the small one lacks, then is reclassified by its remaining mass.
        let mut aliases = vec![0; n];
        while smalls_len > 0 && bigs_len > 0 {
            let s = smalls_bigs.pop_front().unwrap();
            smalls_len -= 1;
            let b = smalls_bigs.pop_back().unwrap();
            bigs_len -= 1;

            aliases[s] = b;
            no_alias_odds[b] = no_alias_odds[s] + no_alias_odds[b] - 1.0;

            if no_alias_odds[b] < 1.0 {
                smalls_bigs.push_front(b);
                smalls_len += 1;
            } else {
                smalls_bigs.push_back(b);
                bigs_len += 1;
            }
        }

        // The remaining indices should have no-alias odds of exactly 1; they may only
        // be approximately 1 because of limited numeric accuracy.
        for index in smalls_bigs.into_iter() {
            // Because p = 1 we don't need to set an alias. It will never be accessed.
            no_alias_odds[index] = 1.0;
        }

        Self {
            aliases,
            no_alias_odds,
        }
    }
}

impl Distribution<usize> for AliasMethodWeightedIndex {
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
        // Note: this uses the rand 0.6 API; with rand 0.8+ it would be
        // `rng.gen_range(0..self.no_alias_odds.len())`.
        let candidate = rng.gen_range(0, self.no_alias_odds.len());
        if rng.gen_bool(self.no_alias_odds[candidate]) {
            candidate
        } else {
            self.aliases[candidate]
        }
    }
}

// Recursive pairwise (cascade) summation: more accurate than a simple
// sequential sum for long slices of floats.
pub fn pairwise_sum_f64(values: &[f64]) -> f64 {
    if values.len() <= 32 {
        values.iter().sum()
    } else {
        let mid = values.len() / 2;
        let (a, b) = values.split_at(mid);
        pairwise_sum_f64(a) + pairwise_sum_f64(b)
    }
}
```

Note that I used a pairwise summation algorithm to improve accuracy when there are many floating-point weights. I benchmarked it to find a good size for the base case and found that it is about twice as fast as simple loop/iterator summation on my machine. I don't know why; I would have expected a little bit of overhead instead, but I haven't investigated further.
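To illustrate the accuracy benefit of pairwise summation: summing ten million copies of 0.1 with a plain sequential sum accumulates noticeable rounding error, while the pairwise sum stays near machine precision. This is a self-contained demo (not part of rand); `pairwise_sum_f64` is the same function as in the snippet above.

```rust
// Same pairwise (cascade) summation as in the snippet above.
fn pairwise_sum_f64(values: &[f64]) -> f64 {
    if values.len() <= 32 {
        values.iter().sum()
    } else {
        let mid = values.len() / 2;
        let (a, b) = values.split_at(mid);
        pairwise_sum_f64(a) + pairwise_sum_f64(b)
    }
}

fn main() {
    // 10_000_000 * 0.1 should be exactly 1_000_000.
    let values = vec![0.1_f64; 10_000_000];
    let naive: f64 = values.iter().sum();
    let pairwise = pairwise_sum_f64(&values);

    let naive_err = (naive - 1.0e6).abs();
    let pairwise_err = (pairwise - 1.0e6).abs();

    // The pairwise error is orders of magnitude smaller than the naive one.
    assert!(pairwise_err < naive_err);
    println!("naive err = {:e}, pairwise err = {:e}", naive_err, pairwise_err);
}
```

Sequential summation loses accuracy because each partial sum grows large relative to the next addend; the pairwise tree keeps the operands of each addition at comparable magnitudes, so the error grows like O(log n) rather than O(n).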

@dhardy
Member

dhardy commented Jan 11, 2019

@zroug thanks for the code. We would welcome a PR if you can see how to integrate this.

@dhardy
Member

dhardy commented Jan 25, 2019

I believe we can close this now.

@dhardy dhardy closed this as completed Jan 25, 2019
@huonw
Contributor Author

huonw commented Jan 27, 2019

Hm, it doesn't look to me like the code has changed (e.g. #692 hasn't landed), so this doesn't seem fixed? Is the issue-management approach written down somewhere (so I can understand)?

@dhardy
Member

dhardy commented Jan 27, 2019

No, the code hasn't landed yet, but we have a PR in reasonable shape. Do we need a tracking issue too?

Maybe I should just follow "standard" practice and keep this open for now then.

@dhardy dhardy reopened this Jan 27, 2019
@dhardy dhardy mentioned this issue Jan 28, 2019
@vks
Collaborator

vks commented Apr 9, 2019

#692 is merged now.
