Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Re: weighted sampling without replacement #1013

Merged
merged 11 commits into from Aug 27, 2020
21 changes: 21 additions & 0 deletions benches/seq.rs
Expand Up @@ -177,3 +177,24 @@ sample_indices!(misc_sample_indices_100_of_1G, sample, 100, 1000_000_000);
sample_indices!(misc_sample_indices_200_of_1G, sample, 200, 1000_000_000);
sample_indices!(misc_sample_indices_400_of_1G, sample, 400, 1000_000_000);
sample_indices!(misc_sample_indices_600_of_1G, sample, 600, 1000_000_000);

macro_rules! sample_indices_rand_weights {
($name:ident, $amount:expr, $length:expr) => {
#[bench]
fn $name(b: &mut Bencher) {
let mut rng = SmallRng::from_rng(thread_rng()).unwrap();
b.iter(|| {
index::sample_weighted(&mut rng, $length, |idx| (1 + (idx % 100)) as u32, $amount)
})
}
};
}

sample_indices_rand_weights!(misc_sample_weighted_indices_1_of_1k, 1, 1000);
sample_indices_rand_weights!(misc_sample_weighted_indices_10_of_1k, 10, 1000);
sample_indices_rand_weights!(misc_sample_weighted_indices_100_of_1k, 100, 1000);
sample_indices_rand_weights!(misc_sample_weighted_indices_100_of_1M, 100, 1000_000);
sample_indices_rand_weights!(misc_sample_weighted_indices_200_of_1M, 200, 1000_000);
sample_indices_rand_weights!(misc_sample_weighted_indices_400_of_1M, 400, 1000_000);
sample_indices_rand_weights!(misc_sample_weighted_indices_600_of_1M, 600, 1000_000);
sample_indices_rand_weights!(misc_sample_weighted_indices_1k_of_1M, 1000, 1000_000);
1 change: 1 addition & 0 deletions src/lib.rs
Expand Up @@ -50,6 +50,7 @@
#![doc(test(attr(allow(unused_variables), deny(warnings))))]
#![cfg_attr(not(feature = "std"), no_std)]
#![cfg_attr(all(feature = "simd_support", feature = "nightly"), feature(stdsimd))]
#![cfg_attr(feature = "nightly", feature(slice_partition_at_index))]
#![cfg_attr(doc_cfg, feature(doc_cfg))]
#![allow(
clippy::excessive_precision,
Expand Down
184 changes: 182 additions & 2 deletions src/seq/index.rs
Expand Up @@ -19,7 +19,7 @@ use crate::alloc::collections::BTreeSet;
#[cfg(feature = "std")] use std::collections::HashSet;

#[cfg(feature = "alloc")]
use crate::distributions::{uniform::SampleUniform, Distribution, Uniform};
use crate::distributions::{uniform::SampleUniform, Distribution, Uniform, WeightedError};
use crate::Rng;

#[cfg(feature = "serde1")]
Expand Down Expand Up @@ -258,6 +258,154 @@ where R: Rng + ?Sized {
}
}

/// Randomly sample exactly `amount` distinct indices from `0..length`, and
/// return them in an arbitrary order (there is no guarantee of shuffling or
/// ordering). The weights are to be provided by the input function `weights`,
/// which will be called once for each index.
///
/// This method is used internally by the slice sampling methods, but it can
/// sometimes be useful to have the indices themselves so this is provided as
/// an alternative.
///
/// This implementation uses `O(length + amount)` space and `O(length)` time
/// if the "nightly" feature is enabled, or `O(length)` space and
/// `O(length + amount * log length)` time otherwise.
///
/// Panics if `amount > length`.
pub fn sample_weighted<R, F, X>(
rng: &mut R, length: usize, weight: F, amount: usize,
) -> Result<IndexVec, WeightedError>
where
R: Rng + ?Sized,
F: Fn(usize) -> X,
X: Into<f64>,
{
if length > (core::u32::MAX as usize) {
sample_efraimidis_spirakis(rng, length, weight, amount)
} else {
assert!(amount <= core::u32::MAX as usize);
let amount = amount as u32;
let length = length as u32;
sample_efraimidis_spirakis(rng, length, weight, amount)
}
}


/// Randomly sample exactly `amount` distinct indices from `0..length`, and
/// return them in an arbitrary order (there is no guarantee of shuffling or
/// ordering). The weights are to be provided by the input function `weights`,
/// which will be called once for each index.
///
/// This implementation uses the algorithm described by Efraimidis and Spirakis
/// in this paper: https://doi.org/10.1016/j.ipl.2005.11.003
/// It uses `O(length + amount)` space and `O(length)` time if the
/// "nightly" feature is enabled, or `O(length)` space and `O(length
/// + amount * log length)` time otherwise.
///
/// Panics if `amount > length`.
fn sample_efraimidis_spirakis<R, F, X, N>(
rng: &mut R, length: N, weight: F, amount: N,
) -> Result<IndexVec, WeightedError>
where
R: Rng + ?Sized,
F: Fn(usize) -> X,
X: Into<f64>,
N: UInt,
IndexVec: From<Vec<N>>,
{
if amount == N::zero() {
return Ok(IndexVec::U32(Vec::new()));
}

if amount > length {
panic!("`amount` of samples must be less than or equal to `length`");
}

struct Element<N> {
index: N,
key: f64,
}
impl<N> PartialOrd for Element<N> {
fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
self.key.partial_cmp(&other.key)
}
}
impl<N> Ord for Element<N> {
fn cmp(&self, other: &Self) -> core::cmp::Ordering {
// partial_cmp will always produce a value,
// because we check that the weights are not nan
self.partial_cmp(other).unwrap()
}
}
impl<N> PartialEq for Element<N> {
fn eq(&self, other: &Self) -> bool {
self.key == other.key
}
}
impl<N> Eq for Element<N> {}

#[cfg(feature = "nightly")]
{
let mut candidates = Vec::with_capacity(length.as_usize());
let mut index = N::zero();
while index < length {
let weight = weight(index.as_usize()).into();
if !(weight >= 0.) {
return Err(WeightedError::InvalidWeight);
}

let key = rng.gen::<f64>().powf(1.0 / weight);
candidates.push(Element { index, key });

index += N::one();
}

// Partially sort the array to find the `amount` elements with the greatest
// keys. Do this by using `partition_at_index` to put the elements with
// the *smallest* keys at the beginning of the list in `O(n)` time, which
// provides equivalent information about the elements with the *greatest* keys.
let (_, mid, greater)
= candidates.partition_at_index(length.as_usize() - amount.as_usize());

let mut result: Vec<N> = Vec::with_capacity(amount.as_usize());
result.push(mid.index);
for element in greater {
result.push(element.index);
}
Ok(IndexVec::from(result))
}

#[cfg(not(feature = "nightly"))]
{
#[cfg(all(feature = "alloc", not(feature = "std")))]
use crate::alloc::collections::BinaryHeap;
#[cfg(feature = "std")]
use std::collections::BinaryHeap;

// Partially sort the array such that the `amount` elements with the largest
// keys are first using a binary max heap.
let mut candidates = BinaryHeap::with_capacity(length.as_usize());
let mut index = N::zero();
while index < length {
let weight = weight(index.as_usize()).into();
if !(weight >= 0.) {
return Err(WeightedError::InvalidWeight);
}

let key = rng.gen::<f64>().powf(1.0 / weight);
candidates.push(Element { index, key });

index += N::one();
}

let mut result: Vec<N> = Vec::with_capacity(amount.as_usize());
while result.len() < amount.as_usize() {
result.push(candidates.pop().unwrap().index);
}
Ok(IndexVec::from(result))
}
}

/// Randomly sample exactly `amount` indices from `0..length`, using Floyd's
/// combination algorithm.
///
Expand Down Expand Up @@ -322,8 +470,10 @@ where R: Rng + ?Sized {
IndexVec::from(indices)
}

trait UInt: Copy + PartialOrd + Ord + PartialEq + Eq + SampleUniform + core::hash::Hash {
trait UInt: Copy + PartialOrd + Ord + PartialEq + Eq + SampleUniform
+ core::hash::Hash + core::ops::AddAssign {
fn zero() -> Self;
fn one() -> Self;
fn as_usize(self) -> usize;
}
impl UInt for u32 {
Expand All @@ -332,6 +482,11 @@ impl UInt for u32 {
0
}

#[inline]
fn one() -> Self {
1
}

#[inline]
fn as_usize(self) -> usize {
self as usize
Expand All @@ -343,6 +498,11 @@ impl UInt for usize {
0
}

#[inline]
fn one() -> Self {
1
}

#[inline]
fn as_usize(self) -> usize {
self
Expand Down Expand Up @@ -462,6 +622,26 @@ mod test {
assert_eq!(v1, v2);
}

#[test]
fn test_sample_weighted() {
let seed_rng = crate::test::rng;
for &(amount, len) in &[(0, 10), (5, 10), (10, 10)] {
let v = sample_weighted(&mut seed_rng(423), len, |i| i as f64, amount).unwrap();
match v {
IndexVec::U32(mut indices) => {
assert_eq!(indices.len(), amount);
indices.sort();
indices.dedup();
assert_eq!(indices.len(), amount);
for &i in &indices {
assert!((i as usize) < len);
}
},
IndexVec::USize(_) => panic!("expected `IndexVec::U32`"),
}
}
}

#[test]
fn value_stability_sample() {
let do_test = |length, amount, values: &[u32]| {
Expand Down