Skip to content

Commit

Permalink
Merge pull request #945 from dhardy/alias-method
Browse files Browse the repository at this point in the history
Move Alias-method WeightedIndex to rand_distr
  • Loading branch information
vks committed Mar 26, 2020
2 parents 0c86451 + 1a23bc8 commit 8592ad3
Show file tree
Hide file tree
Showing 6 changed files with 98 additions and 27 deletions.
7 changes: 5 additions & 2 deletions rand_distr/src/lib.rs
Expand Up @@ -68,8 +68,8 @@
//! - [`UnitDisc`] distribution

pub use rand::distributions::{
uniform, weighted, Alphanumeric, Bernoulli, BernoulliError, DistIter, Distribution, Open01,
OpenClosed01, Standard, Uniform,
uniform, Alphanumeric, Bernoulli, BernoulliError, DistIter, Distribution, Open01, OpenClosed01,
Standard, Uniform,
};

pub use self::binomial::{Binomial, Error as BinomialError};
Expand All @@ -91,6 +91,9 @@ pub use self::unit_disc::UnitDisc;
pub use self::unit_sphere::UnitSphere;
pub use self::utils::Float;
pub use self::weibull::{Error as WeibullError, Weibull};
pub use self::weighted::{WeightedError, WeightedIndex};

pub mod weighted;

mod binomial;
mod cauchy;
Expand Down
@@ -1,16 +1,20 @@
// Copyright 2019 Developers of the Rand project.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

//! This module contains an implementation of alias method for sampling random
//! indices with probabilities proportional to a collection of weights.

use super::WeightedError;
#[cfg(not(feature = "std"))] use crate::alloc::vec;
#[cfg(not(feature = "std"))] use crate::alloc::vec::Vec;
use crate::distributions::uniform::SampleUniform;
use crate::distributions::Distribution;
use crate::distributions::Uniform;
use crate::Rng;
use crate::{uniform::SampleUniform, Distribution, Uniform};
use core::fmt;
use core::iter::Sum;
use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Sub, SubAssign};
use rand::Rng;

/// A distribution using weighted sampling to pick a discretely selected item.
///
Expand All @@ -34,7 +38,7 @@ use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Sub, SubAssign};
/// # Example
///
/// ```
/// use rand::distributions::weighted::alias_method::WeightedIndex;
/// use rand_distr::weighted::alias_method::WeightedIndex;
/// use rand::prelude::*;
///
/// let choices = vec!['a', 'b', 'c'];
Expand Down Expand Up @@ -400,7 +404,7 @@ mod test {
test_weighted_index(|x: u128| x as f64);
}

#[cfg(all(rustc_1_26, not(target_os = "emscripten")))]
#[cfg(not(target_os = "emscripten"))]
#[test]
#[cfg_attr(miri, ignore)] // Miri is too slow
fn test_weighted_index_i128() {
Expand Down Expand Up @@ -448,7 +452,7 @@ mod test {

let weights = {
let mut weights = Vec::with_capacity(NUM_WEIGHTS as usize);
let random_weight_distribution = crate::distributions::Uniform::new_inclusive(
let random_weight_distribution = Uniform::new_inclusive(
W::ZERO,
W::MAX / W::try_from_u32_lossy(NUM_WEIGHTS).unwrap(),
);
Expand Down
21 changes: 21 additions & 0 deletions rand_distr/src/weighted/mod.rs
@@ -0,0 +1,21 @@
// Copyright 2018 Developers of the Rand project.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

//! Weighted index sampling
//!
//! This module provides two implementations for sampling indices:
//!
//! * [`WeightedIndex`] allows `O(log N)` sampling
//! * [`alias_method::WeightedIndex`] allows `O(1)` sampling, but with
//! much greater set-up cost
//!
//! [`alias_method::WeightedIndex`]: alias_method/struct.WeightedIndex.html

pub mod alias_method;

pub use rand::distributions::weighted::{WeightedError, WeightedIndex};
6 changes: 5 additions & 1 deletion src/distributions/mod.rs
Expand Up @@ -100,12 +100,16 @@ pub use self::bernoulli::{Bernoulli, BernoulliError};
pub use self::float::{Open01, OpenClosed01};
pub use self::other::Alphanumeric;
#[doc(inline)] pub use self::uniform::Uniform;

#[cfg(feature = "alloc")]
pub use self::weighted::{WeightedError, WeightedIndex};
pub use self::weighted_index::{WeightedError, WeightedIndex};

mod bernoulli;
pub mod uniform;

#[deprecated(since = "0.8.0", note = "use rand::distributions::{WeightedIndex, WeightedError} instead")]
#[cfg(feature = "alloc")] pub mod weighted;
#[cfg(feature = "alloc")] mod weighted_index;

mod float;
#[doc(hidden)]
Expand Down
48 changes: 48 additions & 0 deletions src/distributions/weighted.rs
@@ -0,0 +1,48 @@
// Copyright 2018 Developers of the Rand project.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

//! Weighted index sampling
//!
//! This module is deprecated. Use [`crate::distributions::WeightedIndex`] and
//! [`crate::distributions::WeightedError`] instead.

pub use super::{WeightedIndex, WeightedError};

#[allow(missing_docs)]
#[deprecated(since = "0.8.0", note = "moved to rand_distr crate")]
pub mod alias_method {
// This module exists to provide a deprecation warning which minimises
// compile errors, but still fails to compile if ever used.
use core::marker::PhantomData;
#[cfg(not(feature = "std"))] use crate::alloc::vec::Vec;
use super::WeightedError;

#[derive(Debug)]
pub struct WeightedIndex<W: Weight> {
_phantom: PhantomData<W>,
}
impl<W: Weight> WeightedIndex<W> {
pub fn new(_weights: Vec<W>) -> Result<Self, WeightedError> {
Err(WeightedError::NoItem)
}
}

pub trait Weight {}
macro_rules! impl_weight {
() => {};
($T:ident, $($more:ident,)*) => {
impl Weight for $T {}
impl_weight!($($more,)*);
};
}
impl_weight!(f64, f32,);
impl_weight!(u8, u16, u32, u64, usize,);
impl_weight!(i8, i16, i32, i64, isize,);
#[cfg(not(target_os = "emscripten"))]
impl_weight!(u128, i128,);
}
Expand Up @@ -7,16 +7,6 @@
// except according to those terms.

//! Weighted index sampling
//!
//! This module provides two implementations for sampling indices:
//!
//! * [`WeightedIndex`] allows `O(log N)` sampling
//! * [`alias_method::WeightedIndex`] allows `O(1)` sampling, but with
//! much greater set-up cost
//!
//! [`alias_method::WeightedIndex`]: alias_method/struct.WeightedIndex.html

pub mod alias_method;

use crate::distributions::uniform::{SampleBorrow, SampleUniform, UniformSampler};
use crate::distributions::Distribution;
Expand All @@ -27,8 +17,7 @@ use core::fmt;
// Note that this whole module is only imported if feature="alloc" is enabled.
#[cfg(not(feature = "std"))] use crate::alloc::vec::Vec;

/// A distribution using weighted sampling to pick a discretely selected
/// item.
/// A distribution using weighted sampling of discrete items
///
/// Sampling a `WeightedIndex` distribution returns the index of a randomly
/// selected element from the iterator used when the `WeightedIndex` was
Expand All @@ -38,6 +27,11 @@ use core::fmt;
///
/// # Performance
///
/// Time complexity of sampling from `WeightedIndex` is `O(log N)` where
/// `N` is the number of weights. As an alternative,
/// [`rand_distr::weighted::alias_method`](https://docs.rs/rand_distr/*/rand_distr/weighted/alias_method/index.html)
/// supports `O(1)` sampling, but with much higher initialisation cost.
///
/// A `WeightedIndex<X>` contains a `Vec<X>` and a [`Uniform<X>`] and so its
/// size is the sum of the size of those objects, possibly plus some alignment.
///
Expand All @@ -48,9 +42,6 @@ use core::fmt;
/// contains, this might cause additional allocations, though for primitive
/// types, ['Uniform<X>`] doesn't allocate any memory.
///
/// Time complexity of sampling from `WeightedIndex` is `O(log N)` where
/// `N` is the number of weights.
///
/// Sampling from `WeightedIndex` will result in a single call to
/// `Uniform<X>::sample` (method of the [`Distribution`] trait), which typically
/// will request a single value from the underlying [`RngCore`], though the
Expand Down

0 comments on commit 8592ad3

Please sign in to comment.