Skip to content

Commit

Permalink
sample_weighted: Use less memory for length <= u32::MAX
Browse files Browse the repository at this point in the history
  • Loading branch information
vks committed Aug 4, 2020
1 parent 9cfe1ab commit 1d7471f
Showing 1 changed file with 50 additions and 21 deletions.
71 changes: 50 additions & 21 deletions src/seq/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -280,44 +280,72 @@ where
F: Fn(usize) -> X,
X: Into<f64>,
{
if amount == 0 {
return Ok(IndexVec::USize(Vec::new()));
if length > (::core::u32::MAX as usize) {
sample_efraimidis_spirakis(rng, length, weight, amount)
} else {
let amount = amount as u32;
let length = length as u32;
sample_efraimidis_spirakis(rng, length, weight, amount)
}
}


/// Randomly sample exactly `amount` distinct indices from `0..length`, and
/// return them in an arbitrary order (there is no guarantee of shuffling or
/// ordering). The weights are to be provided by the input function `weight`,
/// which will be called once for each index.
///
/// This implementation uses the algorithm described by Efraimidis and Spirakis
/// in this paper: <https://doi.org/10.1016/j.ipl.2005.11.003>.
/// It uses `O(length + amount)` space and `O(length)` time if the
/// "partition_at_index" feature is enabled, or `O(length)` space and `O(length
/// + amount * log length)` time otherwise.
///
/// Panics if `amount > length`.
fn sample_efraimidis_spirakis<R, F, X, N>(
rng: &mut R, length: N, weight: F, amount: N,
) -> Result<IndexVec, WeightedError>
where
R: Rng + ?Sized,
F: Fn(usize) -> X,
X: Into<f64>,
N: UInt,
{
if amount == N::zero() {
return Ok(IndexVec::U32(Vec::new()));
}

if amount > length {
panic!("`amount` of samples must be less than or equal to `length`");
}

// This implementation uses the algorithm described by Efraimidis and Spirakis
// in this paper: https://doi.org/10.1016/j.ipl.2005.11.003

struct Element {
index: usize,
struct Element<N> {
index: N,
key: f64,
}
impl PartialOrd for Element {
impl<N> PartialOrd for Element<N> {
fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
self.key.partial_cmp(&other.key)
}
}
impl Ord for Element {
impl<N> Ord for Element<N> {
fn cmp(&self, other: &Self) -> core::cmp::Ordering {
// partial_cmp will always produce a value,
// because we check that the weights are not nan
self.partial_cmp(other).unwrap()
}
}
impl PartialEq for Element {
impl<N> PartialEq for Element<N> {
fn eq(&self, other: &Self) -> bool {
self.key == other.key
}
}
impl Eq for Element {}
impl<N> Eq for Element<N> {}

#[cfg(feature = "partition_at_index")]
{
let mut candidates = Vec::with_capacity(length);
for index in 0..length {
let mut candidates = Vec::with_capacity(length.as_usize());
for index in 0..length.as_usize() {
let weight = weight(index).into();
if !(weight >= 0.) {
return Err(WeightedError::InvalidWeight);
Expand All @@ -331,14 +359,15 @@ where
// keys. Do this by using `partition_at_index` to put the elements with
// the *smallest* keys at the beginning of the list in `O(n)` time, which
// provides equivalent information about the elements with the *greatest* keys.
let (_, mid, greater) = candidates.partition_at_index(length - amount);
let (_, mid, greater)
= candidates.partition_at_index(length.as_usize() - amount.as_usize());

let mut result = Vec::with_capacity(amount);
let mut result = Vec::with_capacity(amount.as_usize());
result.push(mid.index);
for element in greater {
result.push(element.index);
}
Ok(IndexVec::USize(result))
Ok(IndexVec::from(result))
}

#[cfg(not(feature = "partition_at_index"))]
Expand All @@ -350,8 +379,8 @@ where

// Partially sort the array such that the `amount` elements with the largest
// keys are first using a binary max heap.
let mut candidates = BinaryHeap::with_capacity(length);
for index in 0..length {
let mut candidates = BinaryHeap::with_capacity(length.as_usize());
for index in 0..length.as_usize() {
let weight = weight(index).into();
if weight < 0.0 || weight.is_nan() {
return Err(WeightedError::InvalidWeight);
Expand All @@ -361,11 +390,11 @@ where
candidates.push(Element { index, key });
}

let mut result = Vec::with_capacity(amount);
while result.len() < amount {
let mut result = Vec::with_capacity(amount.as_usize());
while result.len() < amount.as_usize() {
result.push(candidates.pop().unwrap().index);
}
Ok(IndexVec::USize(result))
Ok(IndexVec::from(result))
}
}

Expand Down

0 comments on commit 1d7471f

Please sign in to comment.