From 390dfa27ad7f8b1e20ef364757ba7c0b4620367e Mon Sep 17 00:00:00 2001 From: Vinzent Steinberg Date: Tue, 13 Aug 2019 17:40:00 +0200 Subject: [PATCH 1/9] WeightedIndex: Make it possible to update a subset of weights --- src/distributions/weighted/mod.rs | 75 ++++++++++++++++++++++++++++++- 1 file changed, 73 insertions(+), 2 deletions(-) diff --git a/src/distributions/weighted/mod.rs b/src/distributions/weighted/mod.rs index 5c2cd97c21f..7dbefffd41c 100644 --- a/src/distributions/weighted/mod.rs +++ b/src/distributions/weighted/mod.rs @@ -84,6 +84,7 @@ use core::fmt; #[derive(Debug, Clone)] pub struct WeightedIndex { cumulative_weights: Vec, + total_weight: X, weight_distribution: X::Sampler, } @@ -125,9 +126,63 @@ impl WeightedIndex { if total_weight == zero { return Err(WeightedError::AllWeightsZero); } - let distr = X::Sampler::new(zero, total_weight); + let distr = X::Sampler::new(zero, total_weight.clone()); - Ok(WeightedIndex { cumulative_weights: weights, weight_distribution: distr }) + Ok(WeightedIndex { cumulative_weights: weights, total_weight, weight_distribution: distr }) + } + + /// Update a subset of weights, without changing the number of weights. + /// + /// Using this method instead of `new` might be more efficient if only a small number of + /// weights is modified. For weights that are `Copy`, no allocations are performed. + /// + /// In case of error, `self` is not modified. 
+ pub fn update_weights(&mut self, new_weights: &[(usize, &X)]) -> Result<(), WeightedError> + where X: for<'a> ::core::ops::AddAssign<&'a X> + + for<'a> ::core::ops::SubAssign<&'a X> + + Clone + + Default { + let zero = ::default(); + + let mut total_weight = self.total_weight.clone(); + + for &(i, w) in new_weights { + if *w < zero { + return Err(WeightedError::InvalidWeight); + } + if i >= self.cumulative_weights.len() { + return Err(WeightedError::TooMany); + } + + // Unfortunately, we will have to calculate the non-cumulative weight a second time, to + // avoid producing an invalid state of `self`. + let mut old_w = self.cumulative_weights[i].clone(); + if i > 0 { + old_w -= &self.cumulative_weights[i - 1]; + } + + total_weight -= &old_w; + total_weight += w; + } + if total_weight == zero { + return Err(WeightedError::AllWeightsZero); + } + + for &(i, w) in new_weights { + let mut old_w = self.cumulative_weights[i].clone(); + if i > 0 { + old_w -= &self.cumulative_weights[i - 1]; + } + + for j in i..self.cumulative_weights.len() { + self.cumulative_weights[j] -= &old_w; + self.cumulative_weights[j] += w; + } + } + self.total_weight = total_weight; + self.weight_distribution = X::Sampler::new(zero, self.total_weight.clone()); + + Ok(()) } } @@ -201,6 +256,22 @@ mod test { assert_eq!(WeightedIndex::new(&[-10, 20, 1, 30]).unwrap_err(), WeightedError::InvalidWeight); assert_eq!(WeightedIndex::new(&[-10]).unwrap_err(), WeightedError::InvalidWeight); } + + #[test] + fn test_update_weights() { + let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7]; + let total_weight = weights.iter().sum::(); + let mut distr = WeightedIndex::new(weights.to_vec()).unwrap(); + assert_eq!(distr.total_weight, total_weight); + + distr.update_weights(&[(2, &4), (5, &1)]).unwrap(); + let expected_weights = [1u32, 2, 4, 0, 5, 1, 7, 1, 2, 3, 4, 5, 6, 7]; + let expected_total_weight = expected_weights.iter().sum::(); + let expected_distr = 
WeightedIndex::new(expected_weights.to_vec()).unwrap(); + assert_eq!(distr.total_weight, expected_total_weight); + assert_eq!(distr.total_weight, expected_distr.total_weight); + assert_eq!(distr.cumulative_weights, expected_distr.cumulative_weights); + } } /// Error type returned from `WeightedIndex::new`. From 7351358e5a3515019ba1a22556adc559569ef556 Mon Sep 17 00:00:00 2001 From: Vinzent Steinberg Date: Wed, 14 Aug 2019 15:57:14 +0200 Subject: [PATCH 2/9] Benchmark creation vs. modification of `WeightedIndex` --- benches/weighted.rs | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 benches/weighted.rs diff --git a/benches/weighted.rs b/benches/weighted.rs new file mode 100644 index 00000000000..611c762470c --- /dev/null +++ b/benches/weighted.rs @@ -0,0 +1,38 @@ +// Copyright 2019 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms.
+ +#![feature(test)] + +extern crate test; + +const RAND_BENCH_N: u64 = 1000; + +use test::Bencher; +use rand::Rng; +use rand::distributions::WeightedIndex; + +#[bench] +fn weighted_index_creation(b: &mut Bencher) { + let mut rng = rand::thread_rng(); + b.iter(|| { + let weights = [1u32, 2, 4, 0, 5, 1, 7, 1, 2, 3, 4, 5, 6, 7]; + let distr = WeightedIndex::new(weights.to_vec()).unwrap(); + rng.sample(distr) + }) +} + +#[bench] +fn weighted_index_modification(b: &mut Bencher) { + let mut rng = rand::thread_rng(); + let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7]; + let mut distr = WeightedIndex::new(weights.to_vec()).unwrap(); + b.iter(|| { + distr.update_weights(&[(2, &4), (5, &1)]).unwrap(); + rng.sample(&distr) + }) +} From 8d846a388ad3c805a1690eb50d101e1dc320d486 Mon Sep 17 00:00:00 2001 From: Vinzent Steinberg Date: Thu, 15 Aug 2019 14:56:23 +0200 Subject: [PATCH 3/9] WeightedIndex::update_weights: Correct comment --- src/distributions/weighted/mod.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/distributions/weighted/mod.rs b/src/distributions/weighted/mod.rs index 7dbefffd41c..e3d81864cce 100644 --- a/src/distributions/weighted/mod.rs +++ b/src/distributions/weighted/mod.rs @@ -134,7 +134,8 @@ impl WeightedIndex { /// Update a subset of weights, without changing the number of weights. /// /// Using this method instead of `new` might be more efficient if only a small number of - /// weights is modified. For weights that are `Copy`, no allocations are performed. + /// weights is modified. No allocations are performed, unless the weight type `X` uses + /// allocation internally. /// /// In case of error, `self` is not modified. 
pub fn update_weights(&mut self, new_weights: &[(usize, &X)]) -> Result<(), WeightedError> From 304ba15835ef3c3176dab97864634ab13223521a Mon Sep 17 00:00:00 2001 From: Vinzent Steinberg Date: Thu, 15 Aug 2019 18:21:56 +0200 Subject: [PATCH 4/9] WeightedIndex::update_weights: More efficient implementation --- src/distributions/weighted/mod.rs | 135 +++++++++++++++++++++++++----- 1 file changed, 112 insertions(+), 23 deletions(-) diff --git a/src/distributions/weighted/mod.rs b/src/distributions/weighted/mod.rs index e3d81864cce..f4e7f8ee0dc 100644 --- a/src/distributions/weighted/mod.rs +++ b/src/distributions/weighted/mod.rs @@ -133,6 +133,8 @@ impl WeightedIndex { /// Update a subset of weights, without changing the number of weights. /// + /// `new_weights` must be sorted by the index. + /// /// Using this method instead of `new` might be more efficient if only a small number of /// weights is modified. No allocations are performed, unless the weight type `X` uses /// allocation internally. @@ -143,43 +145,121 @@ impl WeightedIndex { for<'a> ::core::ops::SubAssign<&'a X> + Clone + Default { + if new_weights.is_empty() { + return Ok(()); + } + let zero = ::default(); let mut total_weight = self.total_weight.clone(); + // Check for errors first, so we don't modify `self` in case something + // goes wrong. + let mut prev_i = None; for &(i, w) in new_weights { + if let Some(old_i) = prev_i { + if old_i >= i { + return Err(WeightedError::InvalidWeight); + } + } if *w < zero { return Err(WeightedError::InvalidWeight); } - if i >= self.cumulative_weights.len() { + if i >= self.cumulative_weights.len() + 1 { return Err(WeightedError::TooMany); } - // Unfortunately, we will have to calculate the non-cumulative weight a second time, to - // avoid producing an invalid state of `self`. 
- let mut old_w = self.cumulative_weights[i].clone(); + // Unfortunately, we will have to calculate the non-cumulative + // weight a second time, to avoid producing an invalid state of + // `self`. + let mut old_w = if i < self.cumulative_weights.len() { + self.cumulative_weights[i].clone() + } else { + self.total_weight.clone() + }; if i > 0 { old_w -= &self.cumulative_weights[i - 1]; } total_weight -= &old_w; total_weight += w; + prev_i = Some(i); } if total_weight == zero { return Err(WeightedError::AllWeightsZero); } - for &(i, w) in new_weights { - let mut old_w = self.cumulative_weights[i].clone(); - if i > 0 { - old_w -= &self.cumulative_weights[i - 1]; + // Update the weights. Because we checked all the preconditions in the + // previous loop, this should never panic. + let mut iter = new_weights.iter(); + let &(first_new_index, first_new_weight) = iter.next().unwrap(); + + // `X` might be an unsigned type, so we have to be careful to avoid + // negative numbers. This is done by tracking the sign of `change` using + // `pos_sign`. 
+ let mut change: X = first_new_weight.clone(); + let mut pos_sign = true; + if first_new_index > 0 { + change += &self.cumulative_weights[first_new_index - 1]; + } + + let add = |x: &mut X, pos_sign: &mut bool, y: &X| { + if !*pos_sign { + if *x < *y { + let tmp = x.clone(); + *x = y.clone(); + *x -= &tmp; + *pos_sign = !*pos_sign; + } else { + *x -= y; + } + } else { + *x += y; } + }; - for j in i..self.cumulative_weights.len() { - self.cumulative_weights[j] -= &old_w; - self.cumulative_weights[j] += w; + let sub = |x: &mut X, pos_sign: &mut bool, y: &X| { + if *pos_sign { + if *x < *y { + let tmp = x.clone(); + *x = y.clone(); + *x -= &tmp; + *pos_sign = !*pos_sign; + } else { + *x -= y; + } + } else { + *x += y; } + }; + + { + let first_new_cweight = if first_new_index < self.cumulative_weights.len() { + &self.cumulative_weights[first_new_index] + } else { + &self.total_weight + }; + sub(&mut change, &mut pos_sign, first_new_cweight); } + let mut next_new_weight = iter.next(); + for i in first_new_index..self.cumulative_weights.len() { + if let Some(&(j, w)) = next_new_weight { + // j > first_new_index >= 0 + if i >= j { + change = w.clone(); + pos_sign = true; + add(&mut change, &mut pos_sign, &self.cumulative_weights[j - 1]); + sub(&mut change, &mut pos_sign, &self.cumulative_weights[j]); + next_new_weight = iter.next(); + } + } + if pos_sign { + self.cumulative_weights[i] += &change; + } else { + self.cumulative_weights[i] -= &change; + } + } + self.total_weight = total_weight; self.weight_distribution = X::Sampler::new(zero, self.total_weight.clone()); @@ -260,18 +340,27 @@ mod test { #[test] fn test_update_weights() { - let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7]; - let total_weight = weights.iter().sum::(); - let mut distr = WeightedIndex::new(weights.to_vec()).unwrap(); - assert_eq!(distr.total_weight, total_weight); - - distr.update_weights(&[(2, &4), (5, &1)]).unwrap(); - let expected_weights = [1u32, 2, 4, 0, 5, 1, 7, 1, 2, 3, 4, 5, 6, 
7]; - let expected_total_weight = expected_weights.iter().sum::(); - let expected_distr = WeightedIndex::new(expected_weights.to_vec()).unwrap(); - assert_eq!(distr.total_weight, expected_total_weight); - assert_eq!(distr.total_weight, expected_distr.total_weight); - assert_eq!(distr.cumulative_weights, expected_distr.cumulative_weights); + let data = [ + (&[1u32, 2, 3, 4][..], + &[(1, &100), (2, &4)][..], // positive change + &[1, 100, 4, 4][..]), + (&[1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7][..], + &[(2, &1), (5, &1), (13, &100)][..], // negative change and last element + &[1u32, 2, 1, 0, 5, 1, 7, 1, 2, 3, 4, 5, 6, 100][..]), + ]; + + for (weights, update, expected_weights) in data.into_iter() { + let total_weight = weights.iter().sum::(); + let mut distr = WeightedIndex::new(weights.to_vec()).unwrap(); + assert_eq!(distr.total_weight, total_weight); + + distr.update_weights(update).unwrap(); + let expected_total_weight = expected_weights.iter().sum::(); + let expected_distr = WeightedIndex::new(expected_weights.to_vec()).unwrap(); + assert_eq!(distr.total_weight, expected_total_weight); + assert_eq!(distr.total_weight, expected_distr.total_weight); + assert_eq!(distr.cumulative_weights, expected_distr.cumulative_weights); + } } } From ee6f33453e9e08e5616b7f4dee8e883ebc55d977 Mon Sep 17 00:00:00 2001 From: Vinzent Steinberg Date: Thu, 15 Aug 2019 18:24:00 +0200 Subject: [PATCH 5/9] WeightedIndex: Clean up benchmark --- benches/weighted.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/benches/weighted.rs b/benches/weighted.rs index 611c762470c..5ddca3f2284 100644 --- a/benches/weighted.rs +++ b/benches/weighted.rs @@ -10,8 +10,6 @@ extern crate test; -const RAND_BENCH_N: u64 = 1000; - use test::Bencher; use rand::Rng; use rand::distributions::WeightedIndex; @@ -19,8 +17,8 @@ use rand::distributions::WeightedIndex; #[bench] fn weighted_index_creation(b: &mut Bencher) { let mut rng = rand::thread_rng(); + let weights = [1u32, 2, 4, 0, 5, 
1, 7, 1, 2, 3, 4, 5, 6, 7]; b.iter(|| { - let weights = [1u32, 2, 4, 0, 5, 1, 7, 1, 2, 3, 4, 5, 6, 7]; let distr = WeightedIndex::new(weights.to_vec()).unwrap(); rng.sample(distr) }) From 5ef9bdeeea98818b861d1c456366f3bea9c5f354 Mon Sep 17 00:00:00 2001 From: Vinzent Steinberg Date: Fri, 16 Aug 2019 11:03:31 +0200 Subject: [PATCH 6/9] Avoid an unnecessary clone --- src/distributions/weighted/mod.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/distributions/weighted/mod.rs b/src/distributions/weighted/mod.rs index f4e7f8ee0dc..f20c5138ad7 100644 --- a/src/distributions/weighted/mod.rs +++ b/src/distributions/weighted/mod.rs @@ -206,9 +206,9 @@ impl WeightedIndex { let add = |x: &mut X, pos_sign: &mut bool, y: &X| { if !*pos_sign { if *x < *y { - let tmp = x.clone(); - *x = y.clone(); - *x -= &tmp; + let mut tmp = y.clone(); + tmp -= x; + std::mem::swap(x, &mut tmp); *pos_sign = !*pos_sign; } else { *x -= y; @@ -221,9 +221,9 @@ impl WeightedIndex { let sub = |x: &mut X, pos_sign: &mut bool, y: &X| { if *pos_sign { if *x < *y { - let tmp = x.clone(); - *x = y.clone(); - *x -= &tmp; + let mut tmp = y.clone(); + tmp -= x; + std::mem::swap(x, &mut tmp); *pos_sign = !*pos_sign; } else { *x -= y; From 8c258dc7cc4845d9c861aeb66fe4dece091447bc Mon Sep 17 00:00:00 2001 From: Vinzent Steinberg Date: Fri, 16 Aug 2019 13:11:12 +0200 Subject: [PATCH 7/9] Simplify `WeightedIndex::update_weights` --- src/distributions/weighted/mod.rs | 82 ++++++++----------------------- 1 file changed, 20 insertions(+), 62 deletions(-) diff --git a/src/distributions/weighted/mod.rs b/src/distributions/weighted/mod.rs index f20c5138ad7..bbed659b694 100644 --- a/src/distributions/weighted/mod.rs +++ b/src/distributions/weighted/mod.rs @@ -192,72 +192,30 @@ impl WeightedIndex { // Update the weights. Because we checked all the preconditions in the // previous loop, this should never panic. 
let mut iter = new_weights.iter(); - let &(first_new_index, first_new_weight) = iter.next().unwrap(); - - // `X` might be an unsigned type, so we have to be careful to avoid - // negative numbers. This is done by tracking the sign of `change` using - // `pos_sign`. - let mut change: X = first_new_weight.clone(); - let mut pos_sign = true; - if first_new_index > 0 { - change += &self.cumulative_weights[first_new_index - 1]; - } - let add = |x: &mut X, pos_sign: &mut bool, y: &X| { - if !*pos_sign { - if *x < *y { - let mut tmp = y.clone(); - tmp -= x; - std::mem::swap(x, &mut tmp); - *pos_sign = !*pos_sign; - } else { - *x -= y; - } - } else { - *x += y; - } - }; - - let sub = |x: &mut X, pos_sign: &mut bool, y: &X| { - if *pos_sign { - if *x < *y { - let mut tmp = y.clone(); - tmp -= x; - std::mem::swap(x, &mut tmp); - *pos_sign = !*pos_sign; - } else { - *x -= y; - } - } else { - *x += y; - } - }; - - { - let first_new_cweight = if first_new_index < self.cumulative_weights.len() { - &self.cumulative_weights[first_new_index] - } else { - &self.total_weight - }; - sub(&mut change, &mut pos_sign, first_new_cweight); - } + let mut prev_weight = zero.clone(); let mut next_new_weight = iter.next(); + let &(first_new_index, _) = next_new_weight.unwrap(); + let mut cumulative_weight = if first_new_index > 0 { + self.cumulative_weights[first_new_index - 1].clone() + } else { + zero.clone() + }; for i in first_new_index..self.cumulative_weights.len() { - if let Some(&(j, w)) = next_new_weight { - // j > first_new_index >= 0 - if i >= j { - change = w.clone(); - pos_sign = true; - add(&mut change, &mut pos_sign, &self.cumulative_weights[j - 1]); - sub(&mut change, &mut pos_sign, &self.cumulative_weights[j]); + //if next_new_weight.is_some() && i == next_new_weight.unwrap().0 { + match next_new_weight { + Some(&(j, w)) if i == j => { + cumulative_weight += w; next_new_weight = iter.next(); + }, + _ => { + let mut tmp = self.cumulative_weights[i].clone(); + tmp -= 
&prev_weight; // We know this is positive. + cumulative_weight += &tmp; } } - if pos_sign { - self.cumulative_weights[i] += &change; - } else { - self.cumulative_weights[i] -= &change; - } + prev_weight = cumulative_weight.clone(); + std::mem::swap(&mut prev_weight, &mut self.cumulative_weights[i]); } self.total_weight = total_weight; @@ -341,9 +299,9 @@ mod test { #[test] fn test_update_weights() { let data = [ - (&[1u32, 2, 3, 4][..], + (&[10u32, 2, 3, 4][..], &[(1, &100), (2, &4)][..], // positive change - &[1, 100, 4, 4][..]), + &[10, 100, 4, 4][..]), (&[1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7][..], &[(2, &1), (5, &1), (13, &100)][..], // negative change and last element &[1u32, 2, 1, 0, 5, 1, 7, 1, 2, 3, 4, 5, 6, 100][..]), From 20ebbd9def3342a18388883917f45f6abab57e4e Mon Sep 17 00:00:00 2001 From: Vinzent Steinberg Date: Fri, 16 Aug 2019 13:34:20 +0200 Subject: [PATCH 8/9] Remove outdated comments --- src/distributions/weighted/mod.rs | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/distributions/weighted/mod.rs b/src/distributions/weighted/mod.rs index bbed659b694..ad16988c984 100644 --- a/src/distributions/weighted/mod.rs +++ b/src/distributions/weighted/mod.rs @@ -169,9 +169,6 @@ impl WeightedIndex { return Err(WeightedError::TooMany); } - // Unfortunately, we will have to calculate the non-cumulative - // weight a second time, to avoid producing an invalid state of - // `self`. 
let mut old_w = if i < self.cumulative_weights.len() { self.cumulative_weights[i].clone() } else { @@ -202,7 +199,6 @@ impl WeightedIndex { zero.clone() }; for i in first_new_index..self.cumulative_weights.len() { - //if next_new_weight.is_some() && i == next_new_weight.unwrap().0 { match next_new_weight { Some(&(j, w)) if i == j => { cumulative_weight += w; From c9428a0423c74c22bfee7f3c6ea3e81068e21e8c Mon Sep 17 00:00:00 2001 From: Vinzent Steinberg Date: Fri, 16 Aug 2019 15:01:10 +0200 Subject: [PATCH 9/9] Fix `WeightedIndex` for `alloc` builds --- src/distributions/weighted/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/distributions/weighted/mod.rs b/src/distributions/weighted/mod.rs index ad16988c984..27116375fd4 100644 --- a/src/distributions/weighted/mod.rs +++ b/src/distributions/weighted/mod.rs @@ -211,7 +211,7 @@ impl WeightedIndex { } } prev_weight = cumulative_weight.clone(); - std::mem::swap(&mut prev_weight, &mut self.cumulative_weights[i]); + core::mem::swap(&mut prev_weight, &mut self.cumulative_weights[i]); } self.total_weight = total_weight;