Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement weighted sampling API #518

Merged
merged 1 commit into from Jul 1, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 5 additions & 0 deletions benches/distributions.rs
Expand Up @@ -115,6 +115,11 @@ distr_int!(distr_binomial, u64, Binomial::new(20, 0.7));
distr_int!(distr_poisson, u64, Poisson::new(4.0));
distr!(distr_bernoulli, bool, Bernoulli::new(0.18));

// Weighted
distr_int!(distr_weighted_i8, usize, WeightedIndex::new(&[1i8, 2, 3, 4, 12, 0, 2, 1]).unwrap());
distr_int!(distr_weighted_u32, usize, WeightedIndex::new(&[1u32, 2, 3, 4, 12, 0, 2, 1]).unwrap());
distr_int!(distr_weighted_f64, usize, WeightedIndex::new(&[1.0f64, 0.001, 1.0/3.0, 4.01, 0.0, 3.3, 22.0, 0.001]).unwrap());
distr_int!(distr_weighted_large_set, usize, WeightedIndex::new((0..10000).rev().chain(1..10001)).unwrap());

// construct and sample from a range
macro_rules! gen_range_int {
Expand Down
46 changes: 25 additions & 21 deletions src/distributions/mod.rs
Expand Up @@ -73,6 +73,8 @@
//! numbers of the `char` type; in contrast [`Standard`] may sample any valid
//! `char`.
//!
//! [`WeightedIndex`] can be used to do weighted sampling from a set of items,
//! such as from an array.
//!
//! # Non-uniform probability distributions
//!
Expand Down Expand Up @@ -167,12 +169,15 @@
//! [`Uniform`]: struct.Uniform.html
//! [`Uniform::new`]: struct.Uniform.html#method.new
//! [`Uniform::new_inclusive`]: struct.Uniform.html#method.new_inclusive
//! [`WeightedIndex`]: struct.WeightedIndex.html

use Rng;

#[doc(inline)] pub use self::other::Alphanumeric;
#[doc(inline)] pub use self::uniform::Uniform;
#[doc(inline)] pub use self::float::{OpenClosed01, Open01};
#[cfg(feature="alloc")]
#[doc(inline)] pub use self::weighted::WeightedIndex;
#[cfg(feature="std")]
#[doc(inline)] pub use self::gamma::{Gamma, ChiSquared, FisherF, StudentT};
#[cfg(feature="std")]
Expand All @@ -192,6 +197,8 @@ use Rng;
#[doc(inline)] pub use self::dirichlet::Dirichlet;

pub mod uniform;
#[cfg(feature="alloc")]
#[doc(hidden)] pub mod weighted;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have no reason to make this module public so better make it private I think. The doc(hidden) stuff is supposedly there for backwards compatibility, though probably it's been used erroneously in a couple of cases.

#[cfg(feature="std")]
#[doc(hidden)] pub mod gamma;
#[cfg(feature="std")]
Expand Down Expand Up @@ -373,6 +380,8 @@ pub struct Standard;


/// A value with a particular weight for use with `WeightedChoice`.
#[deprecated(since="0.6.0", note="use WeightedIndex instead")]
#[allow(deprecated)]
#[derive(Copy, Clone, Debug)]
pub struct Weighted<T> {
/// The numerical weight of this item
Expand All @@ -383,34 +392,19 @@ pub struct Weighted<T> {

/// A distribution that selects from a finite collection of weighted items.
///
/// Each item has an associated weight that influences how likely it
/// is to be chosen: higher weight is more likely.
/// Deprecated: use [`WeightedIndex`] instead.
///
/// The `Clone` restriction is a limitation of the `Distribution` trait.
/// Note that `&T` is (cheaply) `Clone` for all `T`, as is `u32`, so one can
/// store references or indices into another vector.
///
/// # Example
///
/// ```
/// use rand::distributions::{Weighted, WeightedChoice, Distribution};
///
/// let mut items = vec!(Weighted { weight: 2, item: 'a' },
/// Weighted { weight: 4, item: 'b' },
/// Weighted { weight: 1, item: 'c' });
/// let wc = WeightedChoice::new(&mut items);
/// let mut rng = rand::thread_rng();
/// for _ in 0..16 {
/// // on average prints 'a' 4 times, 'b' 8 and 'c' twice.
/// println!("{}", wc.sample(&mut rng));
/// }
/// ```
/// [`WeightedIndex`]: struct.WeightedIndex.html
#[deprecated(since="0.6.0", note="use WeightedIndex instead")]
#[allow(deprecated)]
#[derive(Debug)]
pub struct WeightedChoice<'a, T:'a> {
items: &'a mut [Weighted<T>],
weight_range: Uniform<u32>,
}

#[deprecated(since="0.6.0", note="use WeightedIndex instead")]
#[allow(deprecated)]
impl<'a, T: Clone> WeightedChoice<'a, T> {
/// Create a new `WeightedChoice`.
///
Expand Down Expand Up @@ -448,6 +442,8 @@ impl<'a, T: Clone> WeightedChoice<'a, T> {
}
}

#[deprecated(since="0.6.0", note="use WeightedIndex instead")]
#[allow(deprecated)]
impl<'a, T: Clone> Distribution<T> for WeightedChoice<'a, T> {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> T {
// we want to find the first element that has cumulative
Expand Down Expand Up @@ -557,9 +553,11 @@ fn ziggurat<R: Rng + ?Sized, P, Z>(
#[cfg(test)]
mod tests {
use rngs::mock::StepRng;
#[allow(deprecated)]
use super::{WeightedChoice, Weighted, Distribution};

#[test]
#[allow(deprecated)]
fn test_weighted_choice() {
// this makes assumptions about the internal implementation of
// WeightedChoice. It may fail when the implementation in
Expand Down Expand Up @@ -619,6 +617,7 @@ mod tests {
}

#[test]
#[allow(deprecated)]
fn test_weighted_clone_initialization() {
let initial : Weighted<u32> = Weighted {weight: 1, item: 1};
let clone = initial.clone();
Expand All @@ -627,6 +626,7 @@ mod tests {
}

#[test] #[should_panic]
#[allow(deprecated)]
fn test_weighted_clone_change_weight() {
let initial : Weighted<u32> = Weighted {weight: 1, item: 1};
let mut clone = initial.clone();
Expand All @@ -635,6 +635,7 @@ mod tests {
}

#[test] #[should_panic]
#[allow(deprecated)]
fn test_weighted_clone_change_item() {
let initial : Weighted<u32> = Weighted {weight: 1, item: 1};
let mut clone = initial.clone();
Expand All @@ -644,15 +645,18 @@ mod tests {
}

#[test] #[should_panic]
#[allow(deprecated)]
fn test_weighted_choice_no_items() {
WeightedChoice::<isize>::new(&mut []);
}
#[test] #[should_panic]
#[allow(deprecated)]
fn test_weighted_choice_zero_weight() {
WeightedChoice::new(&mut [Weighted { weight: 0, item: 0},
Weighted { weight: 0, item: 1}]);
}
#[test] #[should_panic]
#[allow(deprecated)]
fn test_weighted_choice_weight_overflows() {
let x = ::core::u32::MAX / 2; // x + x + 2 is the overflow
WeightedChoice::new(&mut [Weighted { weight: x, item: 0 },
Expand Down
170 changes: 170 additions & 0 deletions src/distributions/weighted.rs
@@ -0,0 +1,170 @@
// Copyright 2018 The Rust Project Developers. See the COPYRIGHT
// file at the top-level directory of this distribution and at
// https://rust-lang.org/COPYRIGHT.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

use Rng;
use distributions::Distribution;
use distributions::uniform::{UniformSampler, SampleUniform, SampleBorrow};
use ::core::cmp::PartialOrd;
use ::{Error, ErrorKind};

// Note that this whole module is only imported if feature="alloc" is enabled.
#[cfg(not(feature="std"))] use alloc::Vec;

/// A distribution using weighted sampling to pick a discretely selected item.
///
/// Sampling a `WeightedIndex` distribution returns the index of a randomly
/// selected element from the iterator used when the `WeightedIndex` was
/// created. The chance of a given element being picked is proportional to the
/// value of the element. The weights can use any type `X` for which an
/// implementation of [`Uniform<X>`] exists.
///
/// # Example
///
/// ```
/// use rand::prelude::*;
/// use rand::distributions::WeightedIndex;
///
/// let choices = ['a', 'b', 'c'];
/// let weights = [2, 1, 1];
/// let dist = WeightedIndex::new(&weights).unwrap();
/// let mut rng = thread_rng();
/// for _ in 0..100 {
/// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c'
/// println!("{}", choices[dist.sample(&mut rng)]);
/// }
///
/// let items = [('a', 0), ('b', 3), ('c', 7)];
/// let dist2 = WeightedIndex::new(items.iter().map(|item| item.1)).unwrap();
/// for _ in 0..100 {
/// // 0% chance to print 'a', 30% chance to print 'b', 70% chance to print 'c'
/// println!("{}", items[dist2.sample(&mut rng)].0);
/// }
/// ```
#[derive(Debug, Clone)]
pub struct WeightedIndex<X: SampleUniform + PartialOrd> {
cumulative_weights: Vec<X>,
weight_distribution: X::Sampler,
}

impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
/// Creates a new a `WeightedIndex` [`Distribution`] using the values
/// in `weights`. The weights can use any type `X` for which an
/// implementation of [`Uniform<X>`] exists.
///
/// Returns an error if the iterator is empty, if any weight is `< 0`, or
/// if its total value is 0.
///
/// [`Distribution`]: trait.Distribution.html
/// [`Uniform<X>`]: struct.Uniform.html
pub fn new<I>(weights: I) -> Result<WeightedIndex<X>, Error>
where I: IntoIterator,
I::Item: SampleBorrow<X>,
X: for<'a> ::core::ops::AddAssign<&'a X> +
Clone +
Default {
let mut iter = weights.into_iter();
let mut total_weight: X = iter.next()
.ok_or(Error::new(ErrorKind::Unexpected, "Empty iterator in WeightedIndex::new"))?
.borrow()
.clone();

let zero = <X as Default>::default();
if total_weight < zero {
return Err(Error::new(ErrorKind::Unexpected, "Negative weight in WeightedIndex::new"));
}

let mut weights = Vec::<X>::with_capacity(iter.size_hint().0);
for w in iter {
if *w.borrow() < zero {
return Err(Error::new(ErrorKind::Unexpected, "Negative weight in WeightedIndex::new"));
}
weights.push(total_weight.clone());
total_weight += w.borrow();
}

if total_weight == zero {
return Err(Error::new(ErrorKind::Unexpected, "Total weight is zero in WeightedIndex::new"));
}
let distr = X::Sampler::new(zero, total_weight);

Ok(WeightedIndex { cumulative_weights: weights, weight_distribution: distr })
}
}

impl<X> Distribution<usize> for WeightedIndex<X> where
X: SampleUniform + PartialOrd {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
use ::core::cmp::Ordering;
let chosen_weight = self.weight_distribution.sample(rng);
// Find the first item which has a weight *higher* than the chosen weight.
self.cumulative_weights.binary_search_by(
|w| if *w <= chosen_weight { Ordering::Less } else { Ordering::Greater }).unwrap_err()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't this panic if Ok? Will that just not happen? (Won't happen)

}
}

#[cfg(test)]
mod test {
use super::*;

#[test]
fn test_weightedindex() {
let mut r = ::test::rng(700);
const N_REPS: u32 = 5000;
let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7];
let total_weight = weights.iter().sum::<u32>() as f32;

let verify = |result: [i32; 14]| {
for (i, count) in result.iter().enumerate() {
let exp = (weights[i] * N_REPS) as f32 / total_weight;
let mut err = (*count as f32 - exp).abs();
if err != 0.0 {
err /= exp;
}
assert!(err <= 0.25);
}
};

// WeightedIndex from vec
let mut chosen = [0i32; 14];
let distr = WeightedIndex::new(weights.to_vec()).unwrap();
for _ in 0..N_REPS {
chosen[distr.sample(&mut r)] += 1;
}
verify(chosen);

// WeightedIndex from slice
chosen = [0i32; 14];
let distr = WeightedIndex::new(&weights[..]).unwrap();
for _ in 0..N_REPS {
chosen[distr.sample(&mut r)] += 1;
}
verify(chosen);

// WeightedIndex from iterator
chosen = [0i32; 14];
let distr = WeightedIndex::new(weights.iter()).unwrap();
for _ in 0..N_REPS {
chosen[distr.sample(&mut r)] += 1;
}
verify(chosen);

for _ in 0..5 {
assert_eq!(WeightedIndex::new(&[0, 1]).unwrap().sample(&mut r), 1);
assert_eq!(WeightedIndex::new(&[1, 0]).unwrap().sample(&mut r), 0);
assert_eq!(WeightedIndex::new(&[0, 0, 0, 0, 10, 0]).unwrap().sample(&mut r), 4);
}

assert!(WeightedIndex::new(&[10][0..0]).is_err());
assert!(WeightedIndex::new(&[0]).is_err());
assert!(WeightedIndex::new(&[10, 20, -1, 30]).is_err());
assert!(WeightedIndex::new(&[-10, 20, 1, 30]).is_err());
assert!(WeightedIndex::new(&[-10]).is_err());
}
}
5 changes: 0 additions & 5 deletions src/lib.rs
Expand Up @@ -134,10 +134,6 @@
//!
//! For more slice/sequence related functionality, look in the [`seq` module].
//!
//! There is also [`distributions::WeightedChoice`], which can be used to pick
//! elements at random with some probability. But it does not work well at the
//! moment and is going through a redesign.
//!
//!
//! # Error handling
//!
Expand Down Expand Up @@ -187,7 +183,6 @@
//!
//!
//! [`distributions` module]: distributions/index.html
//! [`distributions::WeightedChoice`]: distributions/struct.WeightedChoice.html
//! [`EntropyRng`]: rngs/struct.EntropyRng.html
//! [`Error`]: struct.Error.html
//! [`gen_range`]: trait.Rng.html#method.gen_range
Expand Down