Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rename alias_method::WeightedIndex to WeightedAliasIndex #1008

Merged
merged 4 commits into from
Aug 2, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
10 changes: 6 additions & 4 deletions rand_distr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@
//! The following are re-exported:
//!
//! - The [`Distribution`] trait and [`DistIter`] helper type
//! - The [`Standard`], [`Alphanumeric`], [`Uniform`], [`OpenClosed01`], [`Open01`] and [`Bernoulli`] distributions
//! - The [`weighted`] sub-module
//! - The [`Standard`], [`Alphanumeric`], [`Uniform`], [`OpenClosed01`],
//! [`Open01`], [`Bernoulli`], and [`WeightedIndex`] distributions
//!
//! ## Distributions
//!
Expand Down Expand Up @@ -107,12 +107,14 @@ pub use self::unit_disc::UnitDisc;
pub use self::unit_sphere::UnitSphere;
pub use self::weibull::{Error as WeibullError, Weibull};
#[cfg(feature = "alloc")]
pub use self::weighted::{WeightedError, WeightedIndex};
pub use rand::distributions::weighted::{WeightedError, WeightedIndex};
#[cfg(feature = "alloc")]
pub use weighted_alias::WeightedAliasIndex;
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately this results in WeightedAliasIndex being shown as re-exported item in documentation.


pub use num_traits;

#[cfg(feature = "alloc")]
pub mod weighted;
pub mod weighted_alias;

mod binomial;
mod cauchy;
Expand Down
21 changes: 0 additions & 21 deletions rand_distr/src/weighted/mod.rs

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -19,59 +19,59 @@ use alloc::{boxed::Box, vec, vec::Vec};

/// A distribution using weighted sampling to pick a discretely selected item.
///
/// Sampling a [`WeightedIndex<W>`] distribution returns the index of a randomly
/// selected element from the vector used to create the [`WeightedIndex<W>`].
/// Sampling a [`WeightedAliasIndex<W>`] distribution returns the index of a randomly
/// selected element from the vector used to create the [`WeightedAliasIndex<W>`].
/// The chance of a given element being picked is proportional to the value of
/// the element. The weights can have any type `W` for which a implementation of
/// [`Weight`] exists.
/// [`AliasableWeight`] exists.
///
/// # Performance
///
/// Given that `n` is the number of items in the vector used to create an
/// [`WeightedIndex<W>`], [`WeightedIndex<W>`] will require `O(n)` amount of
/// memory. More specifically it takes up some constant amount of memory plus
/// [`WeightedAliasIndex<W>`], it will require `O(n)` amount of memory.
/// More specifically it takes up some constant amount of memory plus
/// the vector used to create it and a [`Vec<u32>`] with capacity `n`.
///
/// Time complexity for the creation of a [`WeightedIndex<W>`] is `O(n)`.
/// Time complexity for the creation of a [`WeightedAliasIndex<W>`] is `O(n)`.
/// Sampling is `O(1)`, it makes a call to [`Uniform<u32>::sample`] and a call
/// to [`Uniform<W>::sample`].
///
/// # Example
///
/// ```
/// use rand_distr::weighted::alias_method::WeightedIndex;
/// use rand_distr::weighted::WeightedAliasIndex;
/// use rand::prelude::*;
///
/// let choices = vec!['a', 'b', 'c'];
/// let weights = vec![2, 1, 1];
/// let dist = WeightedIndex::new(weights).unwrap();
/// let dist = WeightedAliasIndex::new(weights).unwrap();
/// let mut rng = thread_rng();
/// for _ in 0..100 {
/// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c'
/// println!("{}", choices[dist.sample(&mut rng)]);
/// }
///
/// let items = [('a', 0), ('b', 3), ('c', 7)];
/// let dist2 = WeightedIndex::new(items.iter().map(|item| item.1).collect()).unwrap();
/// let dist2 = WeightedAliasIndex::new(items.iter().map(|item| item.1).collect()).unwrap();
/// for _ in 0..100 {
/// // 0% chance to print 'a', 30% chance to print 'b', 70% chance to print 'c'
/// println!("{}", items[dist2.sample(&mut rng)].0);
/// }
/// ```
///
/// [`WeightedIndex<W>`]: WeightedIndex
/// [`WeightedAliasIndex<W>`]: WeightedAliasIndex
/// [`Vec<u32>`]: Vec
/// [`Uniform<u32>::sample`]: Distribution::sample
/// [`Uniform<W>::sample`]: Distribution::sample
pub struct WeightedIndex<W: Weight> {
pub struct WeightedAliasIndex<W: AliasableWeight> {
aliases: Box<[u32]>,
no_alias_odds: Box<[W]>,
uniform_index: Uniform<u32>,
uniform_within_weight_sum: Uniform<W>,
}

impl<W: Weight> WeightedIndex<W> {
/// Creates a new [`WeightedIndex`].
impl<W: AliasableWeight> WeightedAliasIndex<W> {
/// Creates a new [`WeightedAliasIndex`].
///
/// Returns an error if:
/// - The vector is empty.
Expand Down Expand Up @@ -99,7 +99,7 @@ impl<W: Weight> WeightedIndex<W> {
}

// The sum of weights will represent 100% of no alias odds.
let weight_sum = Weight::sum(weights.as_slice());
let weight_sum = AliasableWeight::sum(weights.as_slice());
// Prevent floating point overflow due to rounding errors.
let weight_sum = if weight_sum > W::MAX {
W::MAX
Expand Down Expand Up @@ -227,7 +227,7 @@ impl<W: Weight> WeightedIndex<W> {
}
}

impl<W: Weight> Distribution<usize> for WeightedIndex<W> {
impl<W: AliasableWeight> Distribution<usize> for WeightedAliasIndex<W> {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
let candidate = rng.sample(self.uniform_index);
if rng.sample(&self.uniform_within_weight_sum) < self.no_alias_odds[candidate as usize] {
Expand All @@ -238,13 +238,13 @@ impl<W: Weight> Distribution<usize> for WeightedIndex<W> {
}
}

impl<W: Weight> fmt::Debug for WeightedIndex<W>
impl<W: AliasableWeight> fmt::Debug for WeightedAliasIndex<W>
where
W: fmt::Debug,
Uniform<W>: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_struct("WeightedIndex")
f.debug_struct("WeightedAliasIndex")
.field("aliases", &self.aliases)
.field("no_alias_odds", &self.no_alias_odds)
.field("uniform_index", &self.uniform_index)
Expand All @@ -253,7 +253,7 @@ where
}
}

impl<W: Weight> Clone for WeightedIndex<W>
impl<W: AliasableWeight> Clone for WeightedAliasIndex<W>
where Uniform<W>: Clone
{
fn clone(&self) -> Self {
Expand All @@ -267,9 +267,9 @@ where Uniform<W>: Clone
}

/// Trait that must be implemented for weights, that are used with
/// [`WeightedIndex`]. Currently no guarantees on the correctness of
/// [`WeightedIndex`] are given for custom implementations of this trait.
pub trait Weight:
/// [`WeightedAliasIndex`]. Currently no guarantees on the correctness of
/// [`WeightedAliasIndex`] are given for custom implementations of this trait.
pub trait AliasableWeight:
Sized
+ Copy
+ SampleUniform
Expand Down Expand Up @@ -303,7 +303,7 @@ pub trait Weight:

macro_rules! impl_weight_for_float {
($T: ident) => {
impl Weight for $T {
impl AliasableWeight for $T {
const MAX: Self = ::core::$T::MAX;
const ZERO: Self = 0.0;

Expand All @@ -320,7 +320,7 @@ macro_rules! impl_weight_for_float {

/// In comparison to naive accumulation, the pairwise sum algorithm reduces
/// rounding errors when there are many floating point values.
fn pairwise_sum<T: Weight>(values: &[T]) -> T {
fn pairwise_sum<T: AliasableWeight>(values: &[T]) -> T {
if values.len() <= 32 {
values.iter().map(|x| *x).sum()
} else {
Expand All @@ -332,7 +332,7 @@ fn pairwise_sum<T: Weight>(values: &[T]) -> T {

macro_rules! impl_weight_for_int {
($T: ident) => {
impl Weight for $T {
impl AliasableWeight for $T {
const MAX: Self = ::core::$T::MAX;
const ZERO: Self = 0;

Expand Down Expand Up @@ -376,23 +376,23 @@ mod test {

// Floating point special cases
assert_eq!(
WeightedIndex::new(vec![::core::f32::INFINITY]).unwrap_err(),
WeightedAliasIndex::new(vec![::core::f32::INFINITY]).unwrap_err(),
WeightedError::InvalidWeight
);
assert_eq!(
WeightedIndex::new(vec![-0_f32]).unwrap_err(),
WeightedAliasIndex::new(vec![-0_f32]).unwrap_err(),
WeightedError::AllWeightsZero
);
assert_eq!(
WeightedIndex::new(vec![-1_f32]).unwrap_err(),
WeightedAliasIndex::new(vec![-1_f32]).unwrap_err(),
WeightedError::InvalidWeight
);
assert_eq!(
WeightedIndex::new(vec![-::core::f32::INFINITY]).unwrap_err(),
WeightedAliasIndex::new(vec![-::core::f32::INFINITY]).unwrap_err(),
WeightedError::InvalidWeight
);
assert_eq!(
WeightedIndex::new(vec![::core::f32::NAN]).unwrap_err(),
WeightedAliasIndex::new(vec![::core::f32::NAN]).unwrap_err(),
WeightedError::InvalidWeight
);
}
Expand All @@ -412,11 +412,11 @@ mod test {

// Signed integer special cases
assert_eq!(
WeightedIndex::new(vec![-1_i128]).unwrap_err(),
WeightedAliasIndex::new(vec![-1_i128]).unwrap_err(),
WeightedError::InvalidWeight
);
assert_eq!(
WeightedIndex::new(vec![::core::i128::MIN]).unwrap_err(),
WeightedAliasIndex::new(vec![::core::i128::MIN]).unwrap_err(),
WeightedError::InvalidWeight
);
}
Expand All @@ -434,17 +434,17 @@ mod test {

// Signed integer special cases
assert_eq!(
WeightedIndex::new(vec![-1_i8]).unwrap_err(),
WeightedAliasIndex::new(vec![-1_i8]).unwrap_err(),
WeightedError::InvalidWeight
);
assert_eq!(
WeightedIndex::new(vec![::core::i8::MIN]).unwrap_err(),
WeightedAliasIndex::new(vec![::core::i8::MIN]).unwrap_err(),
WeightedError::InvalidWeight
);
}

fn test_weighted_index<W: Weight, F: Fn(W) -> f64>(w_to_f64: F)
where WeightedIndex<W>: fmt::Debug {
fn test_weighted_index<W: AliasableWeight, F: Fn(W) -> f64>(w_to_f64: F)
where WeightedAliasIndex<W>: fmt::Debug {
const NUM_WEIGHTS: u32 = 10;
const ZERO_WEIGHT_INDEX: u32 = 3;
const NUM_SAMPLES: u32 = 15000;
Expand All @@ -467,7 +467,7 @@ mod test {
.iter()
.map(|&w| w_to_f64(w) / w_to_f64(weight_sum) * NUM_SAMPLES as f64)
.collect::<Vec<f64>>();
let weight_distribution = WeightedIndex::new(weights).unwrap();
let weight_distribution = WeightedAliasIndex::new(weights).unwrap();

let mut counts = vec![0; NUM_WEIGHTS as usize];
for _ in 0..NUM_SAMPLES {
Expand All @@ -482,24 +482,24 @@ mod test {
}

assert_eq!(
WeightedIndex::<W>::new(vec![]).unwrap_err(),
WeightedAliasIndex::<W>::new(vec![]).unwrap_err(),
WeightedError::NoItem
);
assert_eq!(
WeightedIndex::new(vec![W::ZERO]).unwrap_err(),
WeightedAliasIndex::new(vec![W::ZERO]).unwrap_err(),
WeightedError::AllWeightsZero
);
assert_eq!(
WeightedIndex::new(vec![W::MAX, W::MAX]).unwrap_err(),
WeightedAliasIndex::new(vec![W::MAX, W::MAX]).unwrap_err(),
WeightedError::InvalidWeight
);
}

#[test]
fn value_stability() {
fn test_samples<W: Weight>(weights: Vec<W>, buf: &mut [usize], expected: &[usize]) {
fn test_samples<W: AliasableWeight>(weights: Vec<W>, buf: &mut [usize], expected: &[usize]) {
assert_eq!(buf.len(), expected.len());
let distr = WeightedIndex::new(weights).unwrap();
let distr = WeightedAliasIndex::new(weights).unwrap();
let mut rng = crate::test::rng(0x9c9fa0b0580a7031);
for r in buf.iter_mut() {
*r = rng.sample(&distr);
Expand Down