WeightedIndex: Make it possible to update a subset of weights #866

Merged: 9 commits, Aug 22, 2019
36 changes: 36 additions & 0 deletions benches/weighted.rs
@@ -0,0 +1,36 @@
// Copyright 2019 Developers of the Rand project.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

#![feature(test)]

extern crate test;

use test::Bencher;
use rand::Rng;
use rand::distributions::WeightedIndex;

#[bench]
fn weighted_index_creation(b: &mut Bencher) {
    let mut rng = rand::thread_rng();
    let weights = [1u32, 2, 4, 0, 5, 1, 7, 1, 2, 3, 4, 5, 6, 7];
    b.iter(|| {
        let distr = WeightedIndex::new(weights.to_vec()).unwrap();
        rng.sample(distr)
    })
}

#[bench]
fn weighted_index_modification(b: &mut Bencher) {
    let mut rng = rand::thread_rng();
    let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7];
    let mut distr = WeightedIndex::new(weights.to_vec()).unwrap();
    b.iter(|| {
        distr.update_weights(&[(2, &4), (5, &1)]).unwrap();
        rng.sample(&distr)
    })
}
119 changes: 117 additions & 2 deletions src/distributions/weighted/mod.rs
@@ -84,6 +84,7 @@ use core::fmt;
#[derive(Debug, Clone)]
pub struct WeightedIndex<X: SampleUniform + PartialOrd> {
    cumulative_weights: Vec<X>,
    total_weight: X,
    weight_distribution: X::Sampler,
}

@@ -125,9 +126,98 @@ impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
        if total_weight == zero {
            return Err(WeightedError::AllWeightsZero);
        }
        let distr = X::Sampler::new(zero, total_weight);
        let distr = X::Sampler::new(zero, total_weight.clone());

        Ok(WeightedIndex { cumulative_weights: weights, weight_distribution: distr })
        Ok(WeightedIndex { cumulative_weights: weights, total_weight, weight_distribution: distr })
    }

    /// Update a subset of weights, without changing the number of weights.
    ///
    /// `new_weights` must be sorted by the index.
    ///
    /// Using this method instead of `new` might be more efficient if only a small number of
    /// weights is modified. No allocations are performed, unless the weight type `X` uses
    /// allocation internally.
    ///
    /// In case of error, `self` is not modified.
    pub fn update_weights(&mut self, new_weights: &[(usize, &X)]) -> Result<(), WeightedError>
        where X: for<'a> ::core::ops::AddAssign<&'a X> +
                 for<'a> ::core::ops::SubAssign<&'a X> +
                 Clone +
                 Default {
        if new_weights.is_empty() {
            return Ok(());
        }

        let zero = <X as Default>::default();

        let mut total_weight = self.total_weight.clone();

        // Check for errors first, so we don't modify `self` in case something
        // goes wrong.
        let mut prev_i = None;
        for &(i, w) in new_weights {
            if let Some(old_i) = prev_i {
                if old_i >= i {
                    return Err(WeightedError::InvalidWeight);
                }
            }
            if *w < zero {
                return Err(WeightedError::InvalidWeight);
            }
            if i >= self.cumulative_weights.len() + 1 {
                return Err(WeightedError::TooMany);
Member

Would be worth adding InvalidIndex, except that it's a breaking change. Perhaps do so in a separate PR which we don't land until we start preparing the next Rand version?

Collaborator Author

Yeah, I thought about this as well. Will do once this is merged.

            }

            let mut old_w = if i < self.cumulative_weights.len() {
                self.cumulative_weights[i].clone()
            } else {
                self.total_weight.clone()
            };
            if i > 0 {
                old_w -= &self.cumulative_weights[i - 1];
            }

            total_weight -= &old_w;
            total_weight += w;
            prev_i = Some(i);
        }
        if total_weight == zero {
            return Err(WeightedError::AllWeightsZero);
        }

        // Update the weights. Because we checked all the preconditions in the
        // previous loop, this should never panic.
        let mut iter = new_weights.iter();

        let mut prev_weight = zero.clone();
        let mut next_new_weight = iter.next();
        let &(first_new_index, _) = next_new_weight.unwrap();
        let mut cumulative_weight = if first_new_index > 0 {
            self.cumulative_weights[first_new_index - 1].clone()
        } else {
            zero.clone()
        };
        for i in first_new_index..self.cumulative_weights.len() {
            match next_new_weight {
                Some(&(j, w)) if i == j => {
                    cumulative_weight += w;
                    next_new_weight = iter.next();
                },
                _ => {
                    let mut tmp = self.cumulative_weights[i].clone();
                    tmp -= &prev_weight; // We know this is positive.
                    cumulative_weight += &tmp;
                }
            }
            prev_weight = cumulative_weight.clone();
            core::mem::swap(&mut prev_weight, &mut self.cumulative_weights[i]);
        }

        self.total_weight = total_weight;
        self.weight_distribution = X::Sampler::new(zero, self.total_weight.clone());

        Ok(())
    }
}
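
For readers following the hunk above, here is a minimal standalone sketch of the same two-pass idea (validate first, then rewrite the prefix sums from the first changed index onward). It is not the code from this PR: the `Cumulative` struct, the `UpdateError` enum and its `InvalidIndex` variant are hypothetical names made up for illustration, the weights are fixed to `u32`, and no sampler is rebuilt. It assumes the same storage convention as the diff, where the cumulative vector has one entry fewer than the number of weights and the last weight is covered only by the total.

```rust
/// Hypothetical error type for this sketch (not the crate's `WeightedError`).
#[derive(Debug, PartialEq)]
enum UpdateError {
    InvalidIndex,   // index out of range, or indices not strictly increasing
    AllWeightsZero, // the update would leave a total weight of zero
}

/// `cumulative[i]` is the sum of weights `0..=i`; the last weight is only
/// reflected in `total`, so `cumulative.len()` is the weight count minus one.
#[derive(Debug, PartialEq)]
struct Cumulative {
    cumulative: Vec<u32>,
    total: u32,
}

impl Cumulative {
    fn new(weights: &[u32]) -> Self {
        let mut cumulative = Vec::with_capacity(weights.len().saturating_sub(1));
        let mut total = 0u32;
        for (i, &w) in weights.iter().enumerate() {
            total += w;
            if i + 1 < weights.len() {
                cumulative.push(total);
            }
        }
        Cumulative { cumulative, total }
    }

    /// Recover the individual weight at index `i` by differencing prefix sums.
    fn weight(&self, i: usize) -> u32 {
        let upper = if i < self.cumulative.len() { self.cumulative[i] } else { self.total };
        let lower = if i > 0 { self.cumulative[i - 1] } else { 0 };
        upper - lower
    }

    fn update(&mut self, updates: &[(usize, u32)]) -> Result<(), UpdateError> {
        if updates.is_empty() {
            return Ok(());
        }
        let n = self.cumulative.len() + 1;

        // Pass 1: validate and compute the new total without touching `self`.
        let mut total = self.total;
        let mut prev: Option<usize> = None;
        for &(i, w) in updates {
            if i >= n || prev.map_or(false, |p| p >= i) {
                return Err(UpdateError::InvalidIndex);
            }
            total = total - self.weight(i) + w;
            prev = Some(i);
        }
        if total == 0 {
            return Err(UpdateError::AllWeightsZero);
        }

        // Pass 2: rewrite prefix sums starting at the first changed index.
        let first = updates[0].0;
        let mut pending = updates.iter().peekable();
        let mut running = if first > 0 { self.cumulative[first - 1] } else { 0 };
        let mut old_prev = running; // old prefix sum just before index `i`
        for i in first..self.cumulative.len() {
            let old_here = self.cumulative[i];
            let w = match pending.peek() {
                // This index has a replacement weight: consume it.
                Some(&&(j, w)) if j == i => {
                    pending.next();
                    w
                }
                // Otherwise keep the old weight, recovered from old prefix sums.
                _ => old_here - old_prev,
            };
            running += w;
            self.cumulative[i] = running;
            old_prev = old_here;
        }
        self.total = total;
        Ok(())
    }
}
```

Under these assumptions, starting from weights `[10, 2, 3, 4]` and calling `update(&[(1, 100), (2, 4)])` should leave the same state as building from `[10, 100, 4, 4]`, which mirrors the first case in this PR's `test_update_weights`. Entries before the first updated index are never touched, which is where the saving over a full rebuild comes from.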

@@ -201,6 +291,31 @@ mod test {
        assert_eq!(WeightedIndex::new(&[-10, 20, 1, 30]).unwrap_err(), WeightedError::InvalidWeight);
        assert_eq!(WeightedIndex::new(&[-10]).unwrap_err(), WeightedError::InvalidWeight);
    }

    #[test]
    fn test_update_weights() {
        let data = [
            (&[10u32, 2, 3, 4][..],
             &[(1, &100), (2, &4)][..], // positive change
             &[10, 100, 4, 4][..]),
            (&[1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7][..],
             &[(2, &1), (5, &1), (13, &100)][..], // negative change and last element
             &[1u32, 2, 1, 0, 5, 1, 7, 1, 2, 3, 4, 5, 6, 100][..]),
        ];

        for (weights, update, expected_weights) in data.into_iter() {
            let total_weight = weights.iter().sum::<u32>();
            let mut distr = WeightedIndex::new(weights.to_vec()).unwrap();
            assert_eq!(distr.total_weight, total_weight);

            distr.update_weights(update).unwrap();
            let expected_total_weight = expected_weights.iter().sum::<u32>();
            let expected_distr = WeightedIndex::new(expected_weights.to_vec()).unwrap();
            assert_eq!(distr.total_weight, expected_total_weight);
            assert_eq!(distr.total_weight, expected_distr.total_weight);
            assert_eq!(distr.cumulative_weights, expected_distr.cumulative_weights);
        }
    }
}

/// Error type returned from `WeightedIndex::new`.