diff --git a/Cargo.toml b/Cargo.toml
index e0581952329..0f81c9fc7d0 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -23,7 +23,7 @@ appveyor = { repository = "rust-random/rand" }
 [features]
 # Meta-features:
 default = ["std", "std_rng"]
-nightly = ["simd_support"] # enables all features requiring nightly rust
+nightly = ["simd_support", "partition_at_index"] # enables all features requiring nightly rust
 serde1 = [] # does nothing, deprecated
 
 # Option (enabled by default): without "std" rand uses libcore; this option
@@ -45,6 +45,9 @@ std_rng = ["rand_chacha", "rand_hc"]
 # Option: enable SmallRng
 small_rng = ["rand_pcg"]
 
+# Option (requires nightly): better performance of choose_multiple_weighted
+partition_at_index = []
+
 [workspace]
 members = [
     "rand_core",
diff --git a/benches/seq.rs b/benches/seq.rs
index 7da2ff8a0fd..5b6fccf51ee 100644
--- a/benches/seq.rs
+++ b/benches/seq.rs
@@ -177,3 +177,24 @@ sample_indices!(misc_sample_indices_100_of_1G, sample, 100, 1000_000_000);
 sample_indices!(misc_sample_indices_200_of_1G, sample, 200, 1000_000_000);
 sample_indices!(misc_sample_indices_400_of_1G, sample, 400, 1000_000_000);
 sample_indices!(misc_sample_indices_600_of_1G, sample, 600, 1000_000_000);
+
+macro_rules! sample_indices_rand_weights {
+    ($name:ident, $amount:expr, $length:expr) => {
+        #[bench]
+        fn $name(b: &mut Bencher) {
+            let mut rng = SmallRng::from_rng(thread_rng()).unwrap();
+            b.iter(|| {
+                index::sample_weighted(&mut rng, $length, |idx| (1 + (idx % 100)) as u32, $amount)
+            })
+        }
+    };
+}
+
+sample_indices_rand_weights!(misc_sample_weighted_indices_1_of_1k, 1, 1000);
+sample_indices_rand_weights!(misc_sample_weighted_indices_10_of_1k, 10, 1000);
+sample_indices_rand_weights!(misc_sample_weighted_indices_100_of_1k, 100, 1000);
+sample_indices_rand_weights!(misc_sample_weighted_indices_100_of_1M, 100, 1000_000);
+sample_indices_rand_weights!(misc_sample_weighted_indices_200_of_1M, 200, 1000_000);
+sample_indices_rand_weights!(misc_sample_weighted_indices_400_of_1M, 400, 1000_000);
+sample_indices_rand_weights!(misc_sample_weighted_indices_600_of_1M, 600, 1000_000);
+sample_indices_rand_weights!(misc_sample_weighted_indices_1k_of_1M, 1000, 1000_000);
diff --git a/src/lib.rs b/src/lib.rs
index a8483945ecb..3563a17d1ab 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -50,6 +50,7 @@
 #![doc(test(attr(allow(unused_variables), deny(warnings))))]
 #![cfg_attr(not(feature = "std"), no_std)]
 #![cfg_attr(all(feature = "simd_support", feature = "nightly"), feature(stdsimd))]
+#![cfg_attr(feature = "partition_at_index", feature(slice_partition_at_index))]
 #![allow(
     clippy::excessive_precision,
     clippy::unreadable_literal,
diff --git a/src/seq/index.rs b/src/seq/index.rs
index 79cc5e9c11c..c58423277c3 100644
--- a/src/seq/index.rs
+++ b/src/seq/index.rs
@@ -8,18 +8,21 @@
 
 //! Low-level API for sampling indices
 
-#[cfg(feature = "alloc")] use core::slice;
+#[cfg(feature = "alloc")]
+use core::slice;
 
 #[cfg(all(feature = "alloc", not(feature = "std")))]
 use crate::alloc::vec::{self, Vec};
-#[cfg(feature = "std")] use std::vec;
+#[cfg(feature = "std")]
+use std::vec;
 // BTreeMap is not as fast in tests, but better than nothing.
 #[cfg(all(feature = "alloc", not(feature = "std")))]
 use crate::alloc::collections::BTreeSet;
-#[cfg(feature = "std")] use std::collections::HashSet;
+#[cfg(feature = "std")]
+use std::collections::HashSet;
 
 #[cfg(feature = "alloc")]
-use crate::distributions::{uniform::SampleUniform, Distribution, Uniform};
+use crate::distributions::{uniform::SampleUniform, Distribution, Uniform, WeightedError};
 use crate::Rng;
 
 /// A vector of indices.
@@ -249,6 +252,121 @@ where R: Rng + ?Sized {
     }
 }
 
+/// Randomly sample exactly `amount` distinct indices from `0..length`, and
+/// return them in an arbitrary order (there is no guarantee of shuffling or
+/// ordering). The weights are to be provided by the input function `weight`,
+/// which will be called once for each index.
+///
+/// This method is used internally by the slice sampling methods, but it can
+/// sometimes be useful to have the indices themselves so this is provided as
+/// an alternative.
+///
+/// This implementation uses `O(length + amount)` space and `O(length)` time
+/// if the "partition_at_index" feature is enabled, or `O(length)` space and
+/// `O(length + amount * log length)` time otherwise.
+///
+/// Panics if `amount > length`.
+///
+/// Returns [`WeightedError::InvalidWeight`] if any weight is negative or NaN.
+pub fn sample_weighted<R, F, X>(
+    rng: &mut R, length: usize, weight: F, amount: usize,
+) -> Result<IndexVec, WeightedError>
+where
+    R: Rng + ?Sized,
+    F: Fn(usize) -> X,
+    X: Into<f64>,
+{
+    if amount > length {
+        panic!("`amount` of samples must be less than or equal to `length`");
+    }
+
+    // This implementation uses the algorithm described by Efraimidis and Spirakis
+    // in this paper: https://doi.org/10.1016/j.ipl.2005.11.003
+
+    struct Element {
+        index: usize,
+        key: f64,
+    }
+    impl PartialOrd for Element {
+        fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
+            self.key
+                .partial_cmp(&other.key)
+                .or(Some(core::cmp::Ordering::Less))
+        }
+    }
+    impl Ord for Element {
+        fn cmp(&self, other: &Self) -> core::cmp::Ordering {
+            self.partial_cmp(other).unwrap() // partial_cmp will always produce a value
+        }
+    }
+    impl PartialEq for Element {
+        fn eq(&self, other: &Self) -> bool {
+            self.key == other.key
+        }
+    }
+    impl Eq for Element {}
+
+    #[cfg(feature = "partition_at_index")]
+    {
+        let mut candidates = Vec::with_capacity(length);
+        for index in 0..length {
+            let weight = weight(index).into();
+            if weight < 0.0 || weight.is_nan() {
+                return Err(WeightedError::InvalidWeight);
+            }
+
+            let key = rng.gen::<f64>().powf(1.0 / weight);
+            candidates.push(Element { index, key })
+        }
+
+        // `partition_at_index(length)` would be out of bounds, so an empty
+        // selection (which also covers `length == 0`) must return before it.
+        if amount == 0 {
+            return Ok(IndexVec::USize(Vec::new()));
+        }
+
+        // Partially sort the array to find the `amount` elements with the greatest
+        // keys. Do this by using `partition_at_index` to put the elements with
+        // the *smallest* keys at the beginning of the list in `O(n)` time, which
+        // provides equivalent information about the elements with the *greatest* keys.
+        let (_, mid, greater) = candidates.partition_at_index(length - amount);
+
+        let mut result = Vec::with_capacity(amount);
+        result.push(mid.index);
+        for element in greater {
+            result.push(element.index);
+        }
+        Ok(IndexVec::USize(result))
+    }
+
+    #[cfg(not(feature = "partition_at_index"))]
+    {
+        #[cfg(all(feature = "alloc", not(feature = "std")))]
+        use crate::alloc::collections::BinaryHeap;
+        #[cfg(feature = "std")]
+        use std::collections::BinaryHeap;
+
+        // Partially sort the array such that the `amount` elements with the largest
+        // keys are first using a binary max heap.
+        let mut candidates = BinaryHeap::with_capacity(length);
+        for index in 0..length {
+            let weight = weight(index).into();
+            if weight < 0.0 || weight.is_nan() {
+                return Err(WeightedError::InvalidWeight);
+            }
+
+            let key = rng.gen::<f64>().powf(1.0 / weight);
+            candidates.push(Element { index, key });
+        }
+
+        let mut result = Vec::with_capacity(amount);
+        while result.len() < amount {
+            result.push(candidates.pop().unwrap().index);
+        }
+        Ok(IndexVec::USize(result))
+    }
+}
+
 /// Randomly sample exactly `amount` indices from `0..length`, using Floyd's
 /// combination algorithm.
 ///
diff --git a/src/seq/mod.rs b/src/seq/mod.rs
index d7a12e1617e..299b66edfca 100644
--- a/src/seq/mod.rs
+++ b/src/seq/mod.rs
@@ -178,6 +178,46 @@ pub trait SliceRandom {
         + Clone
         + Default;
 
+    /// Similar to [`choose_multiple`], but where the likelihood of each element's
+    /// inclusion in the output may be specified. The elements are returned in an
+    /// arbitrary, unspecified order.
+    ///
+    /// The specified function `weight` maps each item `x` to a relative
+    /// likelihood `weight(x)`. The probability of each item being selected is
+    /// therefore `weight(x) / s`, where `s` is the sum of all `weight(x)`.
+    ///
+    /// If all of the weights are equal and positive, each element has an equal
+    /// likelihood of being selected. If all are zero, the selection is arbitrary.
+    ///
+    /// The complexity of this method depends on the feature `partition_at_index`.
+    /// If the feature is enabled, then for slices of length `n`, the complexity
+    /// is `O(n)` space and `O(n)` time. Otherwise, the complexity is `O(n)` space and
+    /// `O(n + amount * log n)` time.
+    ///
+    /// # Example
+    ///
+    /// ```
+    /// use rand::prelude::*;
+    ///
+    /// let choices = [('a', 2), ('b', 1), ('c', 1)];
+    /// let mut rng = thread_rng();
+    /// // First Draw * Second Draw = total odds
+    /// // -----------------------
+    /// // (50% * 50%) + (25% * 67%) = 41.7% chance that the output is `['a', 'b']` in some order.
+    /// // (50% * 50%) + (25% * 67%) = 41.7% chance that the output is `['a', 'c']` in some order.
+    /// // (25% * 33%) + (25% * 33%) = 16.6% chance that the output is `['b', 'c']` in some order.
+    /// println!("{:?}", choices.choose_multiple_weighted(&mut rng, 2, |item| item.1).unwrap().collect::<Vec<_>>());
+    /// ```
+    /// [`choose_multiple`]: SliceRandom::choose_multiple
+    #[cfg(feature = "alloc")]
+    fn choose_multiple_weighted<R, F, X>(
+        &self, rng: &mut R, amount: usize, weight: F,
+    ) -> Result<SliceChooseIter<Self, Self::Item>, WeightedError>
+    where
+        R: Rng + ?Sized,
+        F: Fn(&Self::Item) -> X,
+        X: Into<f64>;
+
     /// Shuffle a mutable slice in place.
     ///
     /// For slices of length `n`, complexity is `O(n)`.
@@ -450,6 +490,29 @@ impl<T> SliceRandom for [T] {
         Ok(&mut self[distr.sample(rng)])
     }
 
+    #[cfg(feature = "alloc")]
+    fn choose_multiple_weighted<R, F, X>(
+        &self, rng: &mut R, amount: usize, weight: F,
+    ) -> Result<SliceChooseIter<Self, Self::Item>, WeightedError>
+    where
+        R: Rng + ?Sized,
+        F: Fn(&Self::Item) -> X,
+        X: Into<f64>,
+    {
+        let amount = ::core::cmp::min(amount, self.len());
+        Ok(SliceChooseIter {
+            slice: self,
+            _phantom: Default::default(),
+            indices: index::sample_weighted(
+                rng,
+                self.len(),
+                |idx| weight(&self[idx]).into(),
+                amount,
+            )?
+            .into_iter(),
+        })
+    }
+
     fn shuffle<R>(&mut self, rng: &mut R)
     where R: Rng + ?Sized {
         for i in (1..self.len()).rev() {
@@ -953,4 +1016,138 @@ mod test {
             do_test(0..100, &[58, 78, 80, 92, 43, 8, 96, 7]);
         }
     }
+
+    #[test]
+    #[cfg(feature = "alloc")]
+    fn test_multiple_weighted_edge_cases() {
+        use super::*;
+
+        let mut rng = crate::test::rng(413);
+
+        // Case 1: One of the weights is 0
+        let choices = [('a', 2), ('b', 1), ('c', 0)];
+        for _ in 0..100 {
+            let result = choices
+                .choose_multiple_weighted(&mut rng, 2, |item| item.1)
+                .unwrap()
+                .collect::<Vec<_>>();
+
+            assert_eq!(result.len(), 2);
+            assert!(!result.iter().any(|val| val.0 == 'c'));
+        }
+
+        // Case 2: All of the weights are 0
+        let choices = [('a', 0), ('b', 0), ('c', 0)];
+        let result = choices
+            .choose_multiple_weighted(&mut rng, 2, |item| item.1)
+            .unwrap()
+            .collect::<Vec<_>>();
+        assert_eq!(result.len(), 2);
+
+        // Case 3: Negative weights
+        let choices = [('a', -1), ('b', 1), ('c', 1)];
+        assert_eq!(
+            choices
+                .choose_multiple_weighted(&mut rng, 2, |item| item.1)
+                .unwrap_err(),
+            WeightedError::InvalidWeight
+        );
+
+        // Case 4: Empty list
+        let choices = [];
+        let result = choices
+            .choose_multiple_weighted(&mut rng, 0, |_: &()| 0)
+            .unwrap()
+            .collect::<Vec<_>>();
+        assert_eq!(result.len(), 0);
+
+        // Case 5: NaN weights
+        let choices = [('a', core::f64::NAN), ('b', 1.0), ('c', 1.0)];
+        assert_eq!(
+            choices
+                .choose_multiple_weighted(&mut rng, 2, |item| item.1)
+                .unwrap_err(),
+            WeightedError::InvalidWeight
+        );
+
+        // Case 6: +infinity weights
+        let choices = [('a', core::f64::INFINITY), ('b', 1.0), ('c', 1.0)];
+        for _ in 0..100 {
+            let result = choices
+                .choose_multiple_weighted(&mut rng, 2, |item| item.1)
+                .unwrap()
+                .collect::<Vec<_>>();
+            assert_eq!(result.len(), 2);
+            assert!(result.iter().any(|val| val.0 == 'a'));
+        }
+
+        // Case 7: -infinity weights
+        let choices = [('a', core::f64::NEG_INFINITY), ('b', 1.0), ('c', 1.0)];
+        assert_eq!(
+            choices
+                .choose_multiple_weighted(&mut rng, 2, |item| item.1)
+                .unwrap_err(),
+            WeightedError::InvalidWeight
+        );
+
+        // Case 8: -0 weights
+        let choices = [('a', -0.0), ('b', 1.0), ('c', 1.0)];
+        assert!(choices
+            .choose_multiple_weighted(&mut rng, 2, |item| item.1)
+            .is_ok());
+
+        // Case 9: Zero samples requested from a non-empty list
+        let choices = [('a', 1), ('b', 1), ('c', 1)];
+        let result = choices
+            .choose_multiple_weighted(&mut rng, 0, |item| item.1)
+            .unwrap()
+            .collect::<Vec<_>>();
+        assert_eq!(result.len(), 0);
+    }
+
+    #[test]
+    #[cfg(feature = "alloc")]
+    fn test_multiple_weighted_distributions() {
+        use super::*;
+
+        // The theoretical probabilities of the different outcomes are:
+        // AB: 0.5 * 0.5 = 0.250
+        // AC: 0.5 * 0.5 = 0.250
+        // BA: 0.25 * 0.67 = 0.167
+        // BC: 0.25 * 0.33 = 0.082
+        // CA: 0.25 * 0.67 = 0.167
+        // CB: 0.25 * 0.33 = 0.082
+        let choices = [('a', 2), ('b', 1), ('c', 1)];
+        let mut rng = crate::test::rng(414);
+
+        let mut results = [0i32; 3];
+        let expected_results = [4167, 4167, 1666];
+        for _ in 0..10000 {
+            let result = choices
+                .choose_multiple_weighted(&mut rng, 2, |item| item.1)
+                .unwrap()
+                .collect::<Vec<_>>();
+
+            assert_eq!(result.len(), 2);
+
+            match (result[0].0, result[1].0) {
+                ('a', 'b') | ('b', 'a') => {
+                    results[0] += 1;
+                }
+                ('a', 'c') | ('c', 'a') => {
+                    results[1] += 1;
+                }
+                ('b', 'c') | ('c', 'b') => {
+                    results[2] += 1;
+                }
+                (_, _) => panic!("unexpected result"),
+            }
+        }
+
+        let mut diffs = results
+            .iter()
+            .zip(&expected_results)
+            .map(|(a, b)| (a - b).abs());
+        assert!(!diffs.any(|deviation| deviation > 100));
+    }
 }