Implement weighted sampling API #518

Merged
merged 1 commit into from Jul 1, 2018

sicking
Contributor

@sicking sicking commented Jun 20, 2018

Having experimented with various ways of doing weighted sampling, I think this API is the best I've thought of so far.

This PR adds the following APIs:

  • IntoIteratorRandom::choose_index_weighted. This is the most low-level way of doing a single sample. It consumes an iterable of weights and returns an index, which can then be used to index into a slice, another iterator, etc. This is to weighted sampling what gen_range is to uniform sampling.
  • IntoIteratorRandom::into_weighted_index_distribution. If you're sampling multiple times using the same set of weights, it is more efficient to build up an array of cumulative weights and then do a binary search to find the index corresponding to the randomly generated value. So this is the most low-level API for repeated sampling. This is to weighted sampling what Uniform::new is to uniform sampling.
  • SliceRandom::choose_weighted. This is a convenience function on top of IntoIteratorRandom::choose_index_weighted which allows using a slice and a mapping function to get a weighted sample from the slice. This function is to weighted sampling what SliceRandom::choose is to uniform sampling.
  • SliceRandom::choose_weighted_mut. Same as SliceRandom::choose_weighted, but returns a mutable reference. This function is to weighted sampling what SliceRandom::choose_mut is to uniform sampling.
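A minimal, dependency-free sketch of what a single-pass choose_index_weighted could look like, assuming u64 weights. It uses a weighted reservoir of size one, which is not necessarily how this PR implemented it, and XorShift64 is a hypothetical toy generator standing in for a real Rng.

```rust
/// Hypothetical xorshift64 generator; a toy stand-in for a real `Rng`.
struct XorShift64(u64);

impl XorShift64 {
    fn next_u64(&mut self) -> u64 {
        let mut x = self.0;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.0 = x;
        x
    }

    /// Value in `0..n` (modulo reduction is slightly biased; fine for a sketch).
    fn gen_below(&mut self, n: u64) -> u64 {
        self.next_u64() % n
    }
}

/// Sketch of a single-pass weighted choice over an iterator of weights.
/// Returns `None` if the iterator is empty or every weight is zero.
fn choose_index_weighted<I>(weights: I, rng: &mut XorShift64) -> Option<usize>
where
    I: IntoIterator<Item = u64>,
{
    let mut total = 0u64;
    let mut chosen = None;
    for (i, w) in weights.into_iter().enumerate() {
        if w == 0 {
            continue; // zero-weight items can never be picked
        }
        total += w;
        // Replace the current pick with probability w / total.
        if rng.gen_below(total) < w {
            chosen = Some(i);
        }
    }
    chosen
}
```

Because index i is picked at its own step with probability w_i / T_i and then survives every later step, it ends up chosen with probability w_i divided by the sum of all weights, so one pass over the weights is enough.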

There are still lots of other ways we could approach this, but this set of functions felt pretty good.

The last two felt like very pleasant-to-use APIs which should cover most common use cases, especially on modern versions of rustc, where closures automatically implement Clone when they can. They also nicely match SliceRandom::choose and SliceRandom::choose_mut.

The first two functions seem like pretty good low-level APIs for doing anything that choose_weighted and choose_weighted_mut can't. However, I'm less sure that these functions live where they should and have the correct names. In particular, maybe IntoIteratorRandom::into_weighted_index_distribution should be a separate struct with a constructor which takes an iterator, more similar to Uniform::new.

And for IntoIteratorRandom::choose_index_weighted, SliceRandom::choose_weighted and SliceRandom::choose_weighted_mut, there's the question of whether they should live on SliceRandom or IntoIteratorRandom.

The advantage of things living on IntoIteratorRandom is that they appear on both slices and iterators, which is a good thing. The downside is that my_vec.choose_index_weighted(...) will actually silently clone the Vec and its contents, which is unlikely to be expected. However, it will also consume the Vec, which will likely lead to compilation errors, allowing the developer to catch this. The fix is to write (&my_vec).choose_index_weighted(...), which neither clones nor consumes the Vec. This problem is unique to Vec objects and does not happen when calling on a slice or an array.
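To illustrate the consume-versus-borrow part of this, here is a sketch with a hypothetical TotalWeight trait that is blanket-implemented for every IntoIterator, much as IntoIteratorRandom would be. Calling the method on &v only borrows the Vec, while calling it on v moves it.

```rust
use std::borrow::Borrow;

/// Hypothetical stand-in for an `IntoIterator` extension trait such as the
/// proposed `IntoIteratorRandom`; the method takes `self` by value.
trait TotalWeight: IntoIterator + Sized {
    /// Sums the weights, consuming `self` (which may itself be a reference).
    fn total_weight(self) -> u64
    where
        Self::Item: Borrow<u64>,
    {
        self.into_iter()
            .map(|w| *<Self::Item as Borrow<u64>>::borrow(&w))
            .sum()
    }
}

// Blanket impl: `Vec<u64>` (moved) and `&Vec<u64>` (borrowed) both get the method.
impl<T: IntoIterator> TotalWeight for T {}
```

After `v.total_weight()`, any later use of `v` is a compile error, which is the safety net described above; `(&v).total_weight()` leaves `v` usable.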

Member

@dhardy dhardy left a comment

I'm less convinced about adding the IntoIteratorRandom trait because (1) the only advantage over IteratorRandom is implicit conversions (e.g. direct usage on a Vec) which aren't always a good idea (as you note) and (2) it's an extra trait. Also (3) this would let us add SliceRandom methods by the same name later without name conflicts.

src/seq.rs Outdated
/// use rand::prelude::*;
///
/// let choices = [('a', 2), ('b', 1), ('c', 1)];
/// // In rustc version XXX and newer, you can use a closure instead
Member

Any idea what XXX is? 1.26 is the only version since 1.22 mentioning closures in the release announcement, and doesn't mention cloning. We don't support older than 1.22 anyway.

src/seq.rs Outdated
@@ -91,6 +94,52 @@ pub trait SliceRandom {
fn choose_multiple<R>(&self, rng: &mut R, amount: usize) -> SliceChooseIter<Self, Self::Item>
where R: Rng + ?Sized;

/// Similar to [`choose`], but each item in the slice don't have the same
/// likelyhood of getting returned. The likelyhood of a given item getting
/// returned is proportional to the value returned by the mapping function
Member

I think it's spelled likelihood. Also, "don't have" is bad grammar and the sentence is indirectly describing the purpose. Try:

A variant of choose where the likelihood of each outcome may be specified. The specified function func maps items x to a relative likelihood func(x). The probability of each item being selected is therefore func(x) / S, where S is the sum of all func(x).

Perhaps also rename func to weight, w or f?
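The probability rule from the suggested wording, written out with f for func and applied to the choices array in the doc example above (assuming the mapping function returns each tuple's second field as its weight):

```latex
P(x_i) = \frac{f(x_i)}{S}, \qquad S = \sum_{j} f(x_j)

S = 2 + 1 + 1 = 4, \qquad
P(\texttt{'a'}) = \tfrac{2}{4} = \tfrac{1}{2}, \qquad
P(\texttt{'b'}) = P(\texttt{'c'}) = \tfrac{1}{4}
```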

src/seq.rs Outdated
/// Extension trait on IntoIterator, providing random sampling methods.
pub trait IntoIteratorRandom: IntoIterator + Sized {

/// Return a the index of a random element from this iterator where the
Member

'a the'

src/seq.rs Outdated
Self::Item: SampleBorrow<X> {
let mut iter = self.into_iter();
let mut total_weight: X = iter.next()
.expect("Can't create Distribution for empty set of weights")
Member

I don't think we should panic in this case. Maybe return some type of error instead.

@vks
Collaborator

vks commented Jun 20, 2018

> However I'm less sure that these functions live where they should and have the correct names. In particular, maybe IntoIteratorRandom::into_weighted_index_distribution should be a separate struct with a constructor which takes an iterator, more similar to Uniform::new.

I feel like this should be a distributions::Discrete struct, because it is more general: what you implemented can be used not just to generate indices into containers, but to sample from arbitrary discrete distributions. I'm not convinced we need something like gen_range for this case, and Discrete::new is a lot cleaner and more consistent with the other distributions than IntoIteratorRandom::into_weighted_index_distribution.

@dhardy
Member

dhardy commented Jun 20, 2018

Why not put the weight function func into choose_index_weighted and instead call it just choose_weighted? That way it could do the same thing if called like weights.zip(0..N).choose_weighted(rng, |(w, i)| w).1 but also be more flexible. Of course it's worth testing whether the extra flexibility would impair the performance. The motivation is that it would make the existing choose_weighted redundant (more accurately: the slice version would be redundant and the iterator equivalent would not need a third iteration over the input). The _mut variant could be dropped: it's likely less useful and can be constructed as above anyway.
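A sketch of this suggestion, assuming a hypothetical ChooseWeighted extension trait on Iterator (not rand's IteratorRandom) whose choose_weighted takes the weight function directly. Here the weight function receives a reference to each item rather than the item itself. Zipping the weights with 0..N then recovers the index, as in the one-liner above; XorShift64 is again a toy stand-in for an Rng.

```rust
/// Hypothetical xorshift64 generator; a toy stand-in for a real `Rng`.
struct XorShift64(u64);

impl XorShift64 {
    fn next_u64(&mut self) -> u64 {
        let mut x = self.0;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.0 = x;
        x
    }

    /// Value in `0..n` (modulo reduction is slightly biased; fine for a sketch).
    fn gen_below(&mut self, n: u64) -> u64 {
        self.next_u64() % n
    }
}

/// Hypothetical extension trait; not part of `rand`.
trait ChooseWeighted: Iterator + Sized {
    /// Picks one item in a single pass with probability proportional to
    /// `weight(&item)`; returns `None` if no item has a positive weight.
    fn choose_weighted<F>(self, rng: &mut XorShift64, mut weight: F) -> Option<Self::Item>
    where
        F: FnMut(&Self::Item) -> u64,
    {
        let mut total = 0u64;
        let mut chosen = None;
        for item in self {
            let w = weight(&item);
            if w == 0 {
                continue;
            }
            total += w;
            // Keep `item` with probability w / total (weighted reservoir of size one).
            if rng.gen_below(total) < w {
                chosen = Some(item);
            }
        }
        chosen
    }
}

impl<I: Iterator> ChooseWeighted for I {}
```

Usage mirroring the one-liner: `weights.iter().copied().zip(0..weights.len()).choose_weighted(&mut rng, |&(w, _)| w).map(|(_, i)| i)` yields the chosen index.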

One thing on my mind here is avoiding adding too much code/API for relatively obscure functionality. We may be able to reduce the API to just IteratorRandom::choose_weighted and IteratorRandom::weighted_distribution (or Weighted::from_iter).

Can you add some benchmarks?

@sicking
Contributor Author

sicking commented Jun 27, 2018

I pushed an updated API. It contains

  • WeightedIndex::new creates a new Distribution which generates weighted indexes.
  • SliceRandom::choose_weighted wraps WeightedIndex::new, using a passed-in mapping function to do a weighted selection from the slice.
  • SliceRandom::choose_weighted_mut, same but returns a mutable reference.

In this version SliceRandom::choose_weighted/SliceRandom::choose_weighted_mut is slightly slower than in the initial version. However I think that's ok since for performance critical scenarios you likely want to create a WeightedIndex and sample from that multiple times.

SliceRandom::choose_weighted/SliceRandom::choose_weighted_mut now also require #[cfg(feature = "alloc")] since WeightedIndex::new does, but I think that's ok too.

I don't think we can make this a lot smaller. The bulk of the code is in WeightedIndex::new, which I think is the minimal API for weighted sampling. The two new functions on SliceRandom could technically be dropped, but they feel very nice to use, and require very little code.
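One possible shape for the WeightedIndex distribution described above, sketched with u64 weights. new builds the running (cumulative) sums once, and each sample draws a value in 0..total and looks up the first running sum that exceeds it. WeightedIndexSketch and XorShift64 are hypothetical names. The lookup uses the standard library's slice::partition_point, which was added to Rust after this PR, in place of a hand-written binary search.

```rust
/// Hypothetical xorshift64 generator; a toy stand-in for a real `Rng`.
struct XorShift64(u64);

impl XorShift64 {
    fn next_u64(&mut self) -> u64 {
        let mut x = self.0;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.0 = x;
        x
    }

    /// Value in `0..n` (modulo reduction is slightly biased; fine for a sketch).
    fn gen_below(&mut self, n: u64) -> u64 {
        self.next_u64() % n
    }
}

/// Hypothetical sketch of a `WeightedIndex`-style distribution (not the rand type).
struct WeightedIndexSketch {
    /// Running sums of the weights; non-decreasing.
    cumulative: Vec<u64>,
    /// Sum of all weights (equal to the last running sum).
    total: u64,
}

impl WeightedIndexSketch {
    /// Returns `None` if there are no weights or they are all zero.
    fn new<I: IntoIterator<Item = u64>>(weights: I) -> Option<Self> {
        let mut total = 0u64;
        let cumulative: Vec<u64> = weights
            .into_iter()
            .map(|w| {
                total += w;
                total
            })
            .collect();
        if total == 0 {
            return None;
        }
        Some(WeightedIndexSketch { cumulative, total })
    }

    /// Maps `chosen` in `0..total` to the first index whose running sum exceeds it.
    fn index_for(&self, chosen: u64) -> usize {
        self.cumulative.partition_point(|&c| c <= chosen)
    }

    /// Draws one index; index `i` comes up with probability weight_i / total.
    fn sample(&self, rng: &mut XorShift64) -> usize {
        self.index_for(rng.gen_below(self.total))
    }
}
```

Building the table is O(n) and each sample is O(log n), which is why sampling repeatedly from one prebuilt distribution beats calling choose_weighted in a loop.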

@sicking
Contributor Author

sicking commented Jun 27, 2018

Oh, but on the flip side, SliceRandom::choose_weighted/SliceRandom::choose_weighted_mut now work with closures in all versions of rustc since the function no longer needs to be clonable. (Clonable closures were added in 1.26 FWIW. It's only mentioned in the detailed release notes).
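To illustrate why the earlier design depended on cloneable closures, here is a sketch of a hypothetical helper that needs two copies of its mapping function and therefore requires F: Clone. A closure that only captures Copy data satisfies that bound on rustc 1.26 and newer.

```rust
/// Hypothetical helper that makes two passes over `items`, one per copy of `f`,
/// so it needs `F: Clone` (as the earlier design needed for its mapping function).
fn sum_and_max_by<F>(items: &[u32], f: F) -> (u32, u32)
where
    F: Fn(&u32) -> u32 + Clone,
{
    // One copy of the closure per pass over the slice.
    let g = f.clone();
    let sum: u32 = items.iter().map(f).sum();
    let max = items.iter().map(g).max().unwrap_or(0);
    (sum, max)
}
```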

@dhardy dhardy added the D-review Do: needs review label Jun 27, 2018
@sicking sicking force-pushed the weighted branch 6 times, most recently from 261d7b0 to 8538619 on June 27, 2018 09:32
Member

@dhardy dhardy left a comment

Some typos I noticed.

I'm still not sure about the approach though.

@@ -0,0 +1,182 @@
// Copyright 2017 The Rust Project Developers. See the COPYRIGHT
Member

2018

// Note that this whole module is only imported if feature="alloc" is enabled.
#[cfg(not(feature="std"))] use alloc::Vec;

/// A distribution using weighted sampling to pick an discretely selected item.
Member

a discretely

/// of a random element from the iterator used when the `WeightedIndex` was
/// created. The chance of a given element being picked is proportional to the
/// value of the element. The weights can use any type `X` for which an
/// implementaiton of [`Uniform<X>`] exists.
Member

implementation


/// A distribution using weighted sampling to pick an discretely selected item.
///
/// When a `WeightedIndex` is sampled from, it returns the index
Member

Sampling a WeightedIndex distribution returns ...

impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
/// Creates a new a `WeightedIndex` [`Distribution`] using the values
/// in `weights`. The weights can use any type `X` for which an
/// implementaiton of [`Uniform<X>`] exists.
Member

same typo

@sicking
Contributor Author

sicking commented Jun 27, 2018

I pushed those comment fixes as well as a couple more tests.

One thing I wasn't sure which way to go on is what to do with negative weights. If that happens, it definitely feels like a pretty severe bug in the calling code, so panicking seems warranted. But it also seems somewhat strange to have some errors return a Result::Err and others panic.
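A sketch of the Result-based alternative the reviewers ask for, assuming f64 weights. The hypothetical WeightedError enum reports an empty input, a negative (or NaN) weight, and an all-zero weight set as errors instead of panicking; the variant names are illustrative, not necessarily what the PR ended up with.

```rust
/// Hypothetical error type for building cumulative weights.
#[derive(Debug, PartialEq)]
enum WeightedError {
    /// No weights were supplied.
    NoItem,
    /// A weight was negative or NaN.
    InvalidWeight,
    /// Every weight was zero, so nothing could ever be sampled.
    AllWeightsZero,
}

/// Validates `weights` and returns their running sums.
fn cumulative_weights(weights: &[f64]) -> Result<Vec<f64>, WeightedError> {
    if weights.is_empty() {
        return Err(WeightedError::NoItem);
    }
    let mut total = 0.0;
    let mut cumulative = Vec::with_capacity(weights.len());
    for &w in weights {
        // `!(w >= 0.0)` is true for negative weights and also for NaN.
        if !(w >= 0.0) {
            return Err(WeightedError::InvalidWeight);
        }
        total += w;
        cumulative.push(total);
    }
    if total == 0.0 {
        return Err(WeightedError::AllWeightsZero);
    }
    Ok(cumulative)
}
```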

@TheIronBorn
Collaborator

I'm curious: why are we re-implementing binary search?

@sicking
Contributor Author

sicking commented Jun 27, 2018

> I'm curious: why are we re-implementing binary search?

Good point! Using binary_search_by now.

let chosen_weight = self.weight_distribution.sample(rng);
// Find the first item which has a weight *higher* than the chosen weight.
self.cumulative_weights.binary_search_by(
|w| if *w <= chosen_weight { Ordering::Less } else { Ordering::Greater }).unwrap_err()
Collaborator

Doesn't this panic if Ok? Will that just not happen? (Won't happen)

@sicking
Contributor Author

sicking commented Jun 27, 2018

Since the closure never returns Ordering::Equal, binary_search_by will never return Ok.

In general binary_search_* is a little wonky in that it tries to both support searching ranges (which is what we're doing here), and searching for items (i.e. an item with a particular value). It works for both but is better at the latter IMO.
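A small check of the binary_search_by trick quoted above, using cumulative weights [2, 3, 4] (that is, weights 2, 1, 1). Because the comparator never returns Ordering::Equal, the search always yields Err(i), where i is the first index whose cumulative weight is strictly greater than the chosen value. The pick helper is a hypothetical wrapper; slice::partition_point (added to std after this PR) computes the same index.

```rust
use std::cmp::Ordering;

/// Index of the first cumulative weight strictly greater than `chosen`,
/// using the same comparator as the review snippet.
fn pick(cumulative: &[u64], chosen: u64) -> usize {
    cumulative
        .binary_search_by(|w| if *w <= chosen { Ordering::Less } else { Ordering::Greater })
        // Never `Ok`: the comparator never reports `Equal`, so `Err` holds the insertion point.
        .unwrap_err()
}
```

With cumulative weights [2, 3, 4], chosen values 0 and 1 map to index 0, 2 maps to index 1, and 3 maps to index 2, matching the weights 2, 1, 1.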

@dhardy dhardy mentioned this pull request Jun 29, 2018
Member

@dhardy dhardy left a comment

A couple of little things (sorry I didn't catch this before)!

I also opened #535 but that can be a separate PR.

@@ -192,6 +197,8 @@ use Rng;
#[doc(inline)] pub use self::dirichlet::Dirichlet;

pub mod uniform;
#[cfg(feature="alloc")]
#[doc(hidden)] pub mod weighted;
Member

We have no reason to make this module public so better make it private I think. The doc(hidden) stuff is supposedly there for backwards compatibility, though probably it's been used erroneously in a couple of cases.

///
/// # Panics
///
/// If a value in the iterator is `< 0`.
Member

I don't think it makes sense to panic on some errors while handling others via Result.


let zero = <X as Default>::default();
let weights = iter.map(|w| {
assert!(*w.borrow() >= zero, "Negative weight in WeightedIndex::new");
Member

i.e. this can just return an error

@sicking
Contributor Author

sicking commented Jun 30, 2018

Pushed an updated PR that fixes those comments. I'll do a separate PR to address #535.

@dhardy dhardy merged commit 0faff20 into rust-random:master Jul 1, 2018
@sicking sicking deleted the weighted branch July 12, 2018 13:03
@dhardy dhardy mentioned this pull request Jul 13, 2018
@dhardy dhardy added this to the 0.6 release milestone Aug 23, 2018
@mrLSD

mrLSD commented Aug 28, 2018

Will it be possible to use this like:
https://docs.scipy.org/doc/numpy/reference/generated/numpy.random.choice.html

The most important things are the size and replace options.

They are the most useful in practice.

P.S. AFAIK numpy.random.choice uses a CDF-based algorithm.

@dhardy
Member

dhardy commented Aug 28, 2018

@mrLSD we have very different requirements from Python/numpy, so I don't think there is much to take from their approach. Besides which this has already been merged.

This samples with replacement, with given weights. We also have code for sampling without replacement, but without user-defined weights: see choose_multiple in the SliceRandom doc. We do not have code for sampling without replacement using custom weights, though I'm not sure this is widely used?

I don't know why you would want a size parameter here; wrapper functions can be used to fill an array/matrix/etc. with values. If you really want a function to return a variadic multi-dimensional array, that doesn't belong in the Rand lib; you probably want to use ndarray.
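A sketch of the kind of wrapper meant here: a hypothetical fill_matrix helper that takes any sampling closure and fills a rows-by-cols grid with its output. With rand, the closure might be something like `|| dist.sample(&mut rng)` for a weighted distribution `dist`; a plain counter is used below so the sketch needs no dependencies.

```rust
/// Hypothetical helper: fills a `rows`-by-`cols` grid by calling `draw` once
/// per cell, in row-major order.
fn fill_matrix<T, F: FnMut() -> T>(rows: usize, cols: usize, mut draw: F) -> Vec<Vec<T>> {
    (0..rows)
        .map(|_| (0..cols).map(|_| draw()).collect::<Vec<T>>())
        .collect()
}
```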

@mrLSD

mrLSD commented Aug 29, 2018

choose_multiple and choose_weighted are very different functions.

About size, a very simple example: I have 1 million users with their weights, and I need to select only 100 unique users.

Do you really think that's a rare case?

@dhardy
Member

dhardy commented Aug 29, 2018

Fair point. Sounds like you have a simple feature request then, rather than wanting a different API or multi-dimensional matrices? Numpy embeds many things in one place, not always for the better.

I opened #596 (which you could have done to start with). I have no plans to work on this now, but at least now there is a visible ticket open.
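Rand had no weighted sampling without replacement at this point (that is what #596 tracks). One standard technique for it is the Efraimidis-Spirakis key method, swapped in here rather than taken from this PR. Each item with weight w > 0 gets the key ln(u) / w for a uniform u in (0, 1), which orders items the same way as u^(1/w), and the k largest keys win. XorShift64 and weighted_sample_without_replacement are hypothetical names.

```rust
/// Hypothetical xorshift64 generator; a toy stand-in for a real `Rng`.
struct XorShift64(u64);

impl XorShift64 {
    fn next_u64(&mut self) -> u64 {
        let mut x = self.0;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.0 = x;
        x
    }

    /// Uniform `f64` strictly inside (0, 1): 52 random bits plus half a step.
    fn next_f64_open(&mut self) -> f64 {
        ((self.next_u64() >> 12) as f64 + 0.5) / (1u64 << 52) as f64
    }
}

/// Picks up to `k` distinct indices without replacement, favouring index `i`
/// in proportion to `weights[i]`. Non-positive and NaN weights are never picked.
fn weighted_sample_without_replacement(
    weights: &[f64],
    k: usize,
    rng: &mut XorShift64,
) -> Vec<usize> {
    let mut keyed: Vec<(f64, usize)> = weights
        .iter()
        .enumerate()
        .filter(|&(_, &w)| w > 0.0)
        // ln(u) / w orders items exactly like u^(1/w) but is less prone to underflow.
        .map(|(i, &w)| (rng.next_f64_open().ln() / w, i))
        .collect();
    // Largest key first; keys are never NaN because u is in (0, 1) and w > 0.
    keyed.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap());
    keyed.into_iter().take(k).map(|(_, i)| i).collect()
}
```

Sorting makes this O(n log n). For something like picking 100 of a million weighted users, keeping only the k best keys in a binary heap would bring it down to O(n log k).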

@mrLSD

mrLSD commented Aug 29, 2018

@dhardy Thanks!

5 participants