diff --git a/benches/weighted.rs b/benches/weighted.rs new file mode 100644 index 00000000000..5ddca3f2284 --- /dev/null +++ b/benches/weighted.rs @@ -0,0 +1,36 @@ +// Copyright 2019 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +#![feature(test)] + +extern crate test; + +use test::Bencher; +use rand::Rng; +use rand::distributions::WeightedIndex; + +#[bench] +fn weighted_index_creation(b: &mut Bencher) { + let mut rng = rand::thread_rng(); + let weights = [1u32, 2, 4, 0, 5, 1, 7, 1, 2, 3, 4, 5, 6, 7]; + b.iter(|| { + let distr = WeightedIndex::new(weights.to_vec()).unwrap(); + rng.sample(distr) + }) +} + +#[bench] +fn weighted_index_modification(b: &mut Bencher) { + let mut rng = rand::thread_rng(); + let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7]; + let mut distr = WeightedIndex::new(weights.to_vec()).unwrap(); + b.iter(|| { + distr.update_weights(&[(2, &4), (5, &1)]).unwrap(); + rng.sample(&distr) + }) +} diff --git a/src/distributions/weighted/mod.rs b/src/distributions/weighted/mod.rs index 5c2cd97c21f..27116375fd4 100644 --- a/src/distributions/weighted/mod.rs +++ b/src/distributions/weighted/mod.rs @@ -84,6 +84,7 @@ use core::fmt; #[derive(Debug, Clone)] pub struct WeightedIndex { cumulative_weights: Vec, + total_weight: X, weight_distribution: X::Sampler, } @@ -125,9 +126,98 @@ impl WeightedIndex { if total_weight == zero { return Err(WeightedError::AllWeightsZero); } - let distr = X::Sampler::new(zero, total_weight); + let distr = X::Sampler::new(zero, total_weight.clone()); - Ok(WeightedIndex { cumulative_weights: weights, weight_distribution: distr }) + Ok(WeightedIndex { cumulative_weights: weights, total_weight, weight_distribution: distr }) + } + + /// Update a subset of weights, without changing the number of weights. + /// + /// `new_weights` must be sorted by the index. + /// + /// Using this method instead of `new` might be more efficient if only a small number of + /// weights is modified. No allocations are performed, unless the weight type `X` uses + /// allocation internally. + /// + /// In case of error, `self` is not modified. + pub fn update_weights(&mut self, new_weights: &[(usize, &X)]) -> Result<(), WeightedError> + where X: for<'a> ::core::ops::AddAssign<&'a X> + + for<'a> ::core::ops::SubAssign<&'a X> + + Clone + + Default { + if new_weights.is_empty() { + return Ok(()); + } + + let zero = ::default(); + + let mut total_weight = self.total_weight.clone(); + + // Check for errors first, so we don't modify `self` in case something + // goes wrong. + let mut prev_i = None; + for &(i, w) in new_weights { + if let Some(old_i) = prev_i { + if old_i >= i { + return Err(WeightedError::InvalidWeight); + } + } + if *w < zero { + return Err(WeightedError::InvalidWeight); + } + if i >= self.cumulative_weights.len() + 1 { + return Err(WeightedError::TooMany); + } + + let mut old_w = if i < self.cumulative_weights.len() { + self.cumulative_weights[i].clone() + } else { + self.total_weight.clone() + }; + if i > 0 { + old_w -= &self.cumulative_weights[i - 1]; + } + + total_weight -= &old_w; + total_weight += w; + prev_i = Some(i); + } + if total_weight == zero { + return Err(WeightedError::AllWeightsZero); + } + + // Update the weights. Because we checked all the preconditions in the + // previous loop, this should never panic. + let mut iter = new_weights.iter(); + + let mut prev_weight = zero.clone(); + let mut next_new_weight = iter.next(); + let &(first_new_index, _) = next_new_weight.unwrap(); + let mut cumulative_weight = if first_new_index > 0 { + self.cumulative_weights[first_new_index - 1].clone() + } else { + zero.clone() + }; + for i in first_new_index..self.cumulative_weights.len() { + match next_new_weight { + Some(&(j, w)) if i == j => { + cumulative_weight += w; + next_new_weight = iter.next(); + }, + _ => { + let mut tmp = self.cumulative_weights[i].clone(); + tmp -= &prev_weight; // We know this is positive. + cumulative_weight += &tmp; + } + } + prev_weight = cumulative_weight.clone(); + core::mem::swap(&mut prev_weight, &mut self.cumulative_weights[i]); + } + + self.total_weight = total_weight; + self.weight_distribution = X::Sampler::new(zero, self.total_weight.clone()); + + Ok(()) } } @@ -201,6 +291,31 @@ mod test { assert_eq!(WeightedIndex::new(&[-10, 20, 1, 30]).unwrap_err(), WeightedError::InvalidWeight); assert_eq!(WeightedIndex::new(&[-10]).unwrap_err(), WeightedError::InvalidWeight); } + + #[test] + fn test_update_weights() { + let data = [ + (&[10u32, 2, 3, 4][..], + &[(1, &100), (2, &4)][..], // positive change + &[10, 100, 4, 4][..]), + (&[1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7][..], + &[(2, &1), (5, &1), (13, &100)][..], // negative change and last element + &[1u32, 2, 1, 0, 5, 1, 7, 1, 2, 3, 4, 5, 6, 100][..]), + ]; + + for (weights, update, expected_weights) in data.into_iter() { + let total_weight = weights.iter().sum::(); + let mut distr = WeightedIndex::new(weights.to_vec()).unwrap(); + assert_eq!(distr.total_weight, total_weight); + + distr.update_weights(update).unwrap(); + let expected_total_weight = expected_weights.iter().sum::(); + let expected_distr = WeightedIndex::new(expected_weights.to_vec()).unwrap(); + assert_eq!(distr.total_weight, expected_total_weight); + assert_eq!(distr.total_weight, expected_distr.total_weight); + assert_eq!(distr.cumulative_weights, expected_distr.cumulative_weights); + } + } } /// Error type returned from `WeightedIndex::new`.