Add SliceRandom::choose_multiple_weighted, implementing weighted sampling without replacement #976

Closed. Wants to merge 6 commits.
5 changes: 4 additions & 1 deletion Cargo.toml
@@ -23,7 +23,7 @@ appveyor = { repository = "rust-random/rand" }
[features]
# Meta-features:
default = ["std", "std_rng"]
nightly = ["simd_support"] # enables all features requiring nightly rust
nightly = ["simd_support", "partition_at_index"] # enables all features requiring nightly rust
Collaborator:

A simpler approach might be to just use the nightly feature instead of introducing a new one. This would have the advantage that we don't have to worry about dropping the feature in the future without breaking code.

Contributor Author:

My reasoning here is that this allows leaving the feature flag in place even once "slice_partition_at_index" stabilizes, which would allow supporting older rustc versions. I also see your point about dropping the feature being a breaking change, though...

Collaborator:

Fair enough. I think we will likely just raise the minimum Rust version. @dhardy What do you think?

Member @dhardy, May 29, 2020:

In general, it is not possible to reliably support old nightlies, and not very useful either. So removing a feature flag only usable on nightly compilers once that feature has stabilised is not an issue.

serde1 = [] # does nothing, deprecated

# Option (enabled by default): without "std" rand uses libcore; this option
@@ -45,6 +45,9 @@ std_rng = ["rand_chacha", "rand_hc"]
# Option: enable SmallRng
small_rng = ["rand_pcg"]

# Option (requires nightly): better performance of choose_multiple_weighted
partition_at_index = []

[workspace]
members = [
"rand_core",
21 changes: 21 additions & 0 deletions benches/seq.rs
@@ -177,3 +177,24 @@ sample_indices!(misc_sample_indices_100_of_1G, sample, 100, 1000_000_000);
sample_indices!(misc_sample_indices_200_of_1G, sample, 200, 1000_000_000);
sample_indices!(misc_sample_indices_400_of_1G, sample, 400, 1000_000_000);
sample_indices!(misc_sample_indices_600_of_1G, sample, 600, 1000_000_000);

macro_rules! sample_indices_rand_weights {
($name:ident, $amount:expr, $length:expr) => {
#[bench]
fn $name(b: &mut Bencher) {
let mut rng = SmallRng::from_rng(thread_rng()).unwrap();
b.iter(|| {
index::sample_weighted(&mut rng, $length, |idx| (1 + (idx % 100)) as u32, $amount)
})
}
};
}

sample_indices_rand_weights!(misc_sample_weighted_indices_1_of_1k, 1, 1000);
sample_indices_rand_weights!(misc_sample_weighted_indices_10_of_1k, 10, 1000);
sample_indices_rand_weights!(misc_sample_weighted_indices_100_of_1k, 100, 1000);
sample_indices_rand_weights!(misc_sample_weighted_indices_100_of_1M, 100, 1000_000);
sample_indices_rand_weights!(misc_sample_weighted_indices_200_of_1M, 200, 1000_000);
sample_indices_rand_weights!(misc_sample_weighted_indices_400_of_1M, 400, 1000_000);
sample_indices_rand_weights!(misc_sample_weighted_indices_600_of_1M, 600, 1000_000);
sample_indices_rand_weights!(misc_sample_weighted_indices_1k_of_1M, 1000, 1000_000);
1 change: 1 addition & 0 deletions src/lib.rs
@@ -50,6 +50,7 @@
#![doc(test(attr(allow(unused_variables), deny(warnings))))]
#![cfg_attr(not(feature = "std"), no_std)]
#![cfg_attr(all(feature = "simd_support", feature = "nightly"), feature(stdsimd))]
#![cfg_attr(all(feature = "partition_at_index", feature = "nightly"), feature(slice_partition_at_index))]
Member:

This however makes separate use of partition_at_index pointless — I think unlike simd_support we should not depend on nightly here. As a bonus, this makes it impossible to use partition_at_index on stable compilers.

Collaborator:

I still don't understand what the advantage of a separate partition_at_index is.

#![allow(
clippy::excessive_precision,
clippy::unreadable_literal,
119 changes: 115 additions & 4 deletions src/seq/index.rs
@@ -8,18 +8,21 @@

//! Low-level API for sampling indices

#[cfg(feature = "alloc")] use core::slice;
#[cfg(feature = "alloc")]
use core::slice;

#[cfg(all(feature = "alloc", not(feature = "std")))]
use crate::alloc::vec::{self, Vec};
#[cfg(feature = "std")] use std::vec;
#[cfg(feature = "std")]
use std::vec;
// BTreeMap is not as fast in tests, but better than nothing.
#[cfg(all(feature = "alloc", not(feature = "std")))]
use crate::alloc::collections::BTreeSet;
#[cfg(feature = "std")] use std::collections::HashSet;
#[cfg(feature = "std")]
use std::collections::HashSet;
Member @dhardy, May 29, 2020 (on lines -11 to +22):

Leave the reformatting out please.


#[cfg(feature = "alloc")]
use crate::distributions::{uniform::SampleUniform, Distribution, Uniform};
use crate::distributions::{uniform::SampleUniform, Distribution, Uniform, WeightedError};
use crate::Rng;

/// A vector of indices.
@@ -249,6 +252,114 @@ where R: Rng + ?Sized {
}
}

/// Randomly sample exactly `amount` distinct indices from `0..length`, and
/// return them in an arbitrary order (there is no guarantee of shuffling or
/// ordering). The weights are to be provided by the input function `weight`,
/// which will be called once for each index.
///
/// This method is used internally by the slice sampling methods, but it can
/// sometimes be useful to have the indices themselves, so this is provided as
/// an alternative.
///
/// This implementation uses `O(length)` space and `O(length)` time if the
/// "partition_at_index" feature is enabled, or `O(length)` space and
/// `O(length * log amount)` time otherwise.
///
/// Panics if `amount > length`.
pub fn sample_weighted<R, F, X>(
rng: &mut R, length: usize, weight: F, amount: usize,
) -> Result<IndexVecIntoIter, WeightedError>
where
R: Rng + ?Sized,
F: Fn(usize) -> X,
X: Into<f64>,
{
if amount > length {
panic!("`amount` of samples must be less than or equal to `length`");
}

// This implementation uses the algorithm described by Efraimidis and Spirakis
// in this paper: https://doi.org/10.1016/j.ipl.2005.11.003

struct Element {
index: usize,
key: f64,
}
impl PartialOrd for Element {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
self.key
.partial_cmp(&other.key)
.or(Some(std::cmp::Ordering::Less))
}
}
impl Ord for Element {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.partial_cmp(other).unwrap() // partial_cmp will always produce a value
}
}
impl PartialEq for Element {
fn eq(&self, other: &Self) -> bool {
self.key == other.key
}
}
impl Eq for Element {}

#[cfg(feature = "partition_at_index")]
{
if length == 0 {
return Ok(IndexVecIntoIter::USize(Vec::new().into_iter()));
}

let mut candidates = Vec::with_capacity(length);
for index in 0..length {
let weight = weight(index).into();
if weight < 0.0 || weight.is_nan() {
Member:

I prefer !(weight >= 0.0), though it is equivalent.
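As a side note, those two predicates agree for every `f64` value, since any comparison involving NaN evaluates to false. A quick standalone check (illustrative only, not part of the PR):

```rust
fn main() {
    // !(w >= 0.0) is true exactly when w is negative or NaN, because any
    // comparison with NaN evaluates to false.
    let cases = [0.0_f64, -0.0, 1.0, -1.0, f64::NAN, f64::INFINITY, f64::NEG_INFINITY];
    for w in cases.iter().copied() {
        assert_eq!(!(w >= 0.0), w < 0.0 || w.is_nan());
    }
    println!("predicates agree on all {} cases", cases.len());
}
```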

return Err(WeightedError::InvalidWeight);
}

let key = rng.gen::<f64>().powf(1.0 / weight);
candidates.push(Element { index, key })
}

// Partially sort the array to find the `amount` elements with the greatest
// keys. Do this by using `partition_at_index` to put the elements with
// the *smallest* keys at the beginning of the list in `O(n)` time, which
// provides equivalent information about the elements with the *greatest* keys.
let (_, mid, greater) = candidates.partition_at_index(length - amount);

let mut result = Vec::with_capacity(amount);
result.push(mid.index);
for element in greater {
result.push(element.index);
}
Ok(IndexVecIntoIter::USize(result.into_iter()))
}

#[cfg(not(feature = "partition_at_index"))]
{
use std::collections::BinaryHeap;

// Partially sort the array such that the `amount` elements with the largest
// keys are first using a binary max heap.
let mut candidates = BinaryHeap::with_capacity(length);
for index in 0..length {
let weight = weight(index).into();
if weight < 0.0 || weight.is_nan() {
return Err(WeightedError::InvalidWeight);
}

let key = rng.gen::<f64>().powf(1.0 / weight);
candidates.push(Element { index, key });
}

let mut result = Vec::with_capacity(amount);
while result.len() < amount {
result.push(candidates.pop().unwrap().index);
}
Ok(IndexVecIntoIter::USize(result.into_iter()))
}
}
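As an aside for readers, the key trick from the Efraimidis and Spirakis paper cited above can be sketched standalone. Every name here (`sample_weighted_sketch`, the toy `next_uniform` generator) is illustrative and not part of the PR; a full sort stands in for the `partition_at_index`/heap variants for clarity:

```rust
// Standalone sketch of weighted sampling without replacement, following the
// key-based scheme of Efraimidis and Spirakis: draw u uniform in [0, 1) for
// each index and use the key u^(1/w); the `amount` largest keys win.
// `next_uniform` is a toy LCG standing in for a real RNG (illustrative only).
fn sample_weighted_sketch(seed: &mut u64, weights: &[f64], amount: usize) -> Vec<usize> {
    let next_uniform = |s: &mut u64| -> f64 {
        *s = s.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
        (*s >> 11) as f64 / (1u64 << 53) as f64 // top 53 bits -> uniform [0, 1)
    };
    let mut keyed: Vec<(f64, usize)> = weights
        .iter()
        .enumerate()
        .map(|(i, &w)| (next_uniform(seed).powf(1.0 / w), i))
        .collect();
    // Full sort for clarity; the PR uses `partition_at_index` or a heap instead.
    keyed.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap());
    keyed.truncate(amount);
    keyed.into_iter().map(|(_, i)| i).collect()
}

fn main() {
    let mut seed = 0x853c49e6748fea9b_u64;
    // With weights [2, 1, 1] and amount 2, two distinct indices come back.
    let picked = sample_weighted_sketch(&mut seed, &[2.0, 1.0, 1.0], 2);
    assert_eq!(picked.len(), 2);
    assert_ne!(picked[0], picked[1]);
    println!("picked indices: {:?}", picked);
}
```

Larger weights shrink the exponent `1.0 / w`, pushing the key toward 1 and making that index more likely to survive the cut, which is exactly why the PR's two branches only differ in how they select the top `amount` keys.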

/// Randomly sample exactly `amount` indices from `0..length`, using Floyd's
/// combination algorithm.
///
180 changes: 180 additions & 0 deletions src/seq/mod.rs
@@ -178,6 +178,46 @@ pub trait SliceRandom {
+ Clone
+ Default;

/// Similar to [`choose_multiple`], but where the likelihood of each element's
/// inclusion in the output may be specified. The elements are returned in an
/// arbitrary, unspecified order.
///
/// The specified function `weight` maps each item `x` to a relative
/// likelihood `weight(x)`. The probability of each item being selected is
/// therefore `weight(x) / s`, where `s` is the sum of all `weight(x)`.
///
/// If all of the weights are equal, even if they are all zero, each element has
/// an equal likelihood of being selected.
///
/// The complexity of this method depends on the feature `partition_at_index`.
/// If the feature is enabled, then for slices of length `n`, the complexity
/// is `O(n)` space and `O(n)` time. Otherwise, the complexity is `O(n)` space and
/// `O(n * log amount)` time.
///
/// # Example
///
/// ```
/// use rand::prelude::*;
///
/// let choices = [('a', 2), ('b', 1), ('c', 1)];
/// let mut rng = thread_rng();
/// // First Draw * Second Draw = total odds
/// // -----------------------
/// // (50% * 50%) + (25% * 67%) = 41.7% chance that the output is `['a', 'b']` in some order.
/// // (50% * 50%) + (25% * 67%) = 41.7% chance that the output is `['a', 'c']` in some order.
/// // (25% * 33%) + (25% * 33%) = 16.6% chance that the output is `['b', 'c']` in some order.
/// println!("{:?}", choices.choose_multiple_weighted(&mut rng, 2, |item| item.1).unwrap().collect::<Vec<_>>());
/// ```
/// [`choose_multiple`]: SliceRandom::choose_multiple
#[cfg(feature = "alloc")]
fn choose_multiple_weighted<R, F, X>(
&self, rng: &mut R, amount: usize, weight: F,
) -> Result<SliceChooseIter<Self, Self::Item>, WeightedError>
where
R: Rng + ?Sized,
F: Fn(&Self::Item) -> X,
X: Into<f64>;

/// Shuffle a mutable slice in place.
///
/// For slices of length `n`, complexity is `O(n)`.
@@ -450,6 +490,28 @@ impl<T> SliceRandom for [T] {
Ok(&mut self[distr.sample(rng)])
}

#[cfg(feature = "alloc")]
fn choose_multiple_weighted<R, F, X>(
&self, rng: &mut R, amount: usize, weight: F,
) -> Result<SliceChooseIter<Self, Self::Item>, WeightedError>
where
R: Rng + ?Sized,
F: Fn(&Self::Item) -> X,
X: Into<f64>,
{
let amount = ::core::cmp::min(amount, self.len());
Ok(SliceChooseIter {
slice: self,
_phantom: Default::default(),
indices: index::sample_weighted(
rng,
self.len(),
|idx| weight(&self[idx]).into(),
amount,
)?,
})
}

fn shuffle<R>(&mut self, rng: &mut R)
where R: Rng + ?Sized {
for i in (1..self.len()).rev() {
@@ -953,4 +1015,122 @@ mod test {
do_test(0..100, &[58, 78, 80, 92, 43, 8, 96, 7]);
}
}

#[test]
fn test_multiple_weighted_edge_cases() {
use super::*;

let mut rng = crate::test::rng(413);

// Case 1: One of the weights is 0
let choices = [('a', 2), ('b', 1), ('c', 0)];
for _ in 0..100 {
let result = choices
.choose_multiple_weighted(&mut rng, 2, |item| item.1)
.unwrap()
.collect::<Vec<_>>();

assert_eq!(result.len(), 2);
assert!(!result.iter().any(|val| val.0 == 'c'));
}

// Case 2: All of the weights are 0
let choices = [('a', 0), ('b', 0), ('c', 0)];
let result = choices
.choose_multiple_weighted(&mut rng, 2, |item| item.1)
.unwrap()
.collect::<Vec<_>>();
assert_eq!(result.len(), 2);

// Case 3: Negative weights
let choices = [('a', -1), ('b', 1), ('c', 1)];
assert!(matches!(
choices.choose_multiple_weighted(&mut rng, 2, |item| item.1),
Err(WeightedError::InvalidWeight)
));

// Case 4: Empty list
let choices = [];
let result = choices
.choose_multiple_weighted(&mut rng, 0, |_: &()| 0)
.unwrap()
.collect::<Vec<_>>();
assert_eq!(result.len(), 0);

// Case 5: NaN weights
let choices = [('a', std::f64::NAN), ('b', 1.0), ('c', 1.0)];
assert!(matches!(
choices.choose_multiple_weighted(&mut rng, 2, |item| item.1),
Err(WeightedError::InvalidWeight)
));

// Case 6: +infinity weights
let choices = [('a', std::f64::INFINITY), ('b', 1.0), ('c', 1.0)];
for _ in 0..100 {
let result = choices
.choose_multiple_weighted(&mut rng, 2, |item| item.1)
.unwrap()
.collect::<Vec<_>>();
assert_eq!(result.len(), 2);
assert!(result.iter().any(|val| val.0 == 'a'));
}

// Case 7: -infinity weights
let choices = [('a', std::f64::NEG_INFINITY), ('b', 1.0), ('c', 1.0)];
assert!(matches!(
choices.choose_multiple_weighted(&mut rng, 2, |item| item.1),
Err(WeightedError::InvalidWeight)
));

// Case 8: -0 weights
let choices = [('a', -0.0), ('b', 1.0), ('c', 1.0)];
assert!(choices
.choose_multiple_weighted(&mut rng, 2, |item| item.1)
.is_ok());
}

#[test]
fn test_multiple_weighted_distributions() {
use super::*;

// The theoretical probabilities of the different outcomes are:
// AB: 0.5 * 0.5 = 0.250
// AC: 0.5 * 0.5 = 0.250
// BA: 0.25 * 0.67 = 0.167
// BC: 0.25 * 0.33 = 0.082
// CA: 0.25 * 0.67 = 0.167
// CB: 0.25 * 0.33 = 0.082
let choices = [('a', 2), ('b', 1), ('c', 1)];
let mut rng = crate::test::rng(414);

let mut results = [0i32; 3];
let expected_results = [4167, 4167, 1666];
for _ in 0..10000 {
let result = choices
.choose_multiple_weighted(&mut rng, 2, |item| item.1)
.unwrap()
.collect::<Vec<_>>();

assert_eq!(result.len(), 2);

match (result[0].0, result[1].0) {
('a', 'b') | ('b', 'a') => {
results[0] += 1;
}
('a', 'c') | ('c', 'a') => {
results[1] += 1;
}
('b', 'c') | ('c', 'b') => {
results[2] += 1;
}
(_, _) => panic!("unexpected result"),
}
}

let mut diffs = results
.iter()
.zip(&expected_results)
.map(|(a, b)| (a - b).abs());
assert!(!diffs.any(|deviation| deviation > 100));
}
}