Skip to content

Commit

Permalink
Merge pull request #866 from vks/update-weights
Browse files Browse the repository at this point in the history
WeightedIndex: Make it possible to update a subset of weights
  • Loading branch information
dhardy committed Aug 22, 2019
2 parents 29056a0 + c9428a0 commit 8616945
Show file tree
Hide file tree
Showing 2 changed files with 153 additions and 2 deletions.
36 changes: 36 additions & 0 deletions benches/weighted.rs
@@ -0,0 +1,36 @@
// Copyright 2019 Developers of the Rand project.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

#![feature(test)]

extern crate test;

use test::Bencher;
use rand::Rng;
use rand::distributions::WeightedIndex;

#[bench]
fn weighted_index_creation(b: &mut Bencher) {
let mut rng = rand::thread_rng();
let weights = [1u32, 2, 4, 0, 5, 1, 7, 1, 2, 3, 4, 5, 6, 7];
b.iter(|| {
let distr = WeightedIndex::new(weights.to_vec()).unwrap();
rng.sample(distr)
})
}

#[bench]
fn weighted_index_modification(b: &mut Bencher) {
let mut rng = rand::thread_rng();
let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7];
let mut distr = WeightedIndex::new(weights.to_vec()).unwrap();
b.iter(|| {
distr.update_weights(&[(2, &4), (5, &1)]).unwrap();
rng.sample(&distr)
})
}
119 changes: 117 additions & 2 deletions src/distributions/weighted/mod.rs
Expand Up @@ -84,6 +84,7 @@ use core::fmt;
#[derive(Debug, Clone)]
pub struct WeightedIndex<X: SampleUniform + PartialOrd> {
cumulative_weights: Vec<X>,
total_weight: X,
weight_distribution: X::Sampler,
}

Expand Down Expand Up @@ -125,9 +126,98 @@ impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
if total_weight == zero {
return Err(WeightedError::AllWeightsZero);
}
let distr = X::Sampler::new(zero, total_weight);
let distr = X::Sampler::new(zero, total_weight.clone());

Ok(WeightedIndex { cumulative_weights: weights, weight_distribution: distr })
Ok(WeightedIndex { cumulative_weights: weights, total_weight, weight_distribution: distr })
}

/// Update a subset of weights, without changing the number of weights.
///
/// `new_weights` must be sorted by the index.
///
/// Using this method instead of `new` might be more efficient if only a small number of
/// weights is modified. No allocations are performed, unless the weight type `X` uses
/// allocation internally.
///
/// In case of error, `self` is not modified.
pub fn update_weights(&mut self, new_weights: &[(usize, &X)]) -> Result<(), WeightedError>
where X: for<'a> ::core::ops::AddAssign<&'a X> +
for<'a> ::core::ops::SubAssign<&'a X> +
Clone +
Default {
if new_weights.is_empty() {
return Ok(());
}

let zero = <X as Default>::default();

let mut total_weight = self.total_weight.clone();

// Check for errors first, so we don't modify `self` in case something
// goes wrong.
let mut prev_i = None;
for &(i, w) in new_weights {
if let Some(old_i) = prev_i {
if old_i >= i {
return Err(WeightedError::InvalidWeight);
}
}
if *w < zero {
return Err(WeightedError::InvalidWeight);
}
if i >= self.cumulative_weights.len() + 1 {
return Err(WeightedError::TooMany);
}

let mut old_w = if i < self.cumulative_weights.len() {
self.cumulative_weights[i].clone()
} else {
self.total_weight.clone()
};
if i > 0 {
old_w -= &self.cumulative_weights[i - 1];
}

total_weight -= &old_w;
total_weight += w;
prev_i = Some(i);
}
if total_weight == zero {
return Err(WeightedError::AllWeightsZero);
}

// Update the weights. Because we checked all the preconditions in the
// previous loop, this should never panic.
let mut iter = new_weights.iter();

let mut prev_weight = zero.clone();
let mut next_new_weight = iter.next();
let &(first_new_index, _) = next_new_weight.unwrap();
let mut cumulative_weight = if first_new_index > 0 {
self.cumulative_weights[first_new_index - 1].clone()
} else {
zero.clone()
};
for i in first_new_index..self.cumulative_weights.len() {
match next_new_weight {
Some(&(j, w)) if i == j => {
cumulative_weight += w;
next_new_weight = iter.next();
},
_ => {
let mut tmp = self.cumulative_weights[i].clone();
tmp -= &prev_weight; // We know this is positive.
cumulative_weight += &tmp;
}
}
prev_weight = cumulative_weight.clone();
core::mem::swap(&mut prev_weight, &mut self.cumulative_weights[i]);
}

self.total_weight = total_weight;
self.weight_distribution = X::Sampler::new(zero, self.total_weight.clone());

Ok(())
}
}

Expand Down Expand Up @@ -201,6 +291,31 @@ mod test {
assert_eq!(WeightedIndex::new(&[-10, 20, 1, 30]).unwrap_err(), WeightedError::InvalidWeight);
assert_eq!(WeightedIndex::new(&[-10]).unwrap_err(), WeightedError::InvalidWeight);
}

#[test]
fn test_update_weights() {
let data = [
(&[10u32, 2, 3, 4][..],
&[(1, &100), (2, &4)][..], // positive change
&[10, 100, 4, 4][..]),
(&[1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7][..],
&[(2, &1), (5, &1), (13, &100)][..], // negative change and last element
&[1u32, 2, 1, 0, 5, 1, 7, 1, 2, 3, 4, 5, 6, 100][..]),
];

for (weights, update, expected_weights) in data.into_iter() {
let total_weight = weights.iter().sum::<u32>();
let mut distr = WeightedIndex::new(weights.to_vec()).unwrap();
assert_eq!(distr.total_weight, total_weight);

distr.update_weights(update).unwrap();
let expected_total_weight = expected_weights.iter().sum::<u32>();
let expected_distr = WeightedIndex::new(expected_weights.to_vec()).unwrap();
assert_eq!(distr.total_weight, expected_total_weight);
assert_eq!(distr.total_weight, expected_distr.total_weight);
assert_eq!(distr.cumulative_weights, expected_distr.cumulative_weights);
}
}
}

/// Error type returned from `WeightedIndex::new`.
Expand Down

0 comments on commit 8616945

Please sign in to comment.