diff --git a/CHANGELOG.md b/CHANGELOG.md index d69d7948a83..4e2838c3229 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,9 @@ You may also find the [Upgrade Guide](https://rust-random.github.io/book/update. (#1094, #1108) - Add range overflow check in `Uniform` float distributions (#1108) +### Distributions +- Add slice distribution (#1107) + ## [0.8.3] - 2021-01-25 ### Fixes - Fix `no-std` + `alloc` build by gating `choose_multiple_weighted` on `std` (#1088) diff --git a/src/distributions/mod.rs b/src/distributions/mod.rs index ddf4f4fbc2c..5444fb034b3 100644 --- a/src/distributions/mod.rs +++ b/src/distributions/mod.rs @@ -99,7 +99,9 @@ use core::iter; pub use self::bernoulli::{Bernoulli, BernoulliError}; pub use self::float::{Open01, OpenClosed01}; pub use self::other::Alphanumeric; -#[doc(inline)] pub use self::uniform::Uniform; +pub use self::slice::Slice; +#[doc(inline)] +pub use self::uniform::Uniform; #[cfg(feature = "alloc")] pub use self::weighted_index::{WeightedError, WeightedIndex}; @@ -107,14 +109,18 @@ pub use self::weighted_index::{WeightedError, WeightedIndex}; mod bernoulli; pub mod uniform; -#[deprecated(since = "0.8.0", note = "use rand::distributions::{WeightedIndex, WeightedError} instead")] +#[deprecated( + since = "0.8.0", + note = "use rand::distributions::{WeightedIndex, WeightedError} instead" +)] #[cfg(feature = "alloc")] #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] pub mod weighted; -#[cfg(feature = "alloc")] mod weighted_index; +#[cfg(feature = "alloc")] +mod weighted_index; #[cfg(feature = "serde1")] -use serde::{Serialize, Deserialize}; +use serde::{Deserialize, Serialize}; mod float; #[doc(hidden)] @@ -123,6 +129,7 @@ pub mod hidden_export { } mod integer; mod other; +mod slice; mod utils; /// Types (distributions) that can be used to create a random instance of `T`. @@ -200,7 +207,6 @@ impl<'a, T, D: Distribution> Distribution for &'a D { } } - /// An iterator that generates random values of `T` with distribution `D`, /// using `R` as the source of randomness. /// @@ -250,7 +256,6 @@ where { } - /// A generic random value distribution, implemented for many primitive types. /// Usually generates values with a numerically uniform distribution, and with a /// range appropriate to the type. @@ -337,7 +342,6 @@ where #[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] pub struct Standard; - #[cfg(test)] mod tests { use super::{Distribution, Uniform}; diff --git a/src/distributions/slice.rs b/src/distributions/slice.rs new file mode 100644 index 00000000000..3302deb2a40 --- /dev/null +++ b/src/distributions/slice.rs @@ -0,0 +1,117 @@ +// Copyright 2021 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use crate::distributions::{Distribution, Uniform}; + +/// A distribution to sample items uniformly from a slice. +/// +/// [`Slice::new`] constructs a distribution referencing a slice and uniformly +/// samples references from the items in the slice. It may do extra work up +/// front to make sampling of multiple values faster; if only one sample from +/// the slice is required, [`SliceRandom::choose`] can be more efficient. +/// +/// Steps are taken to avoid bias which might be present in naive +/// implementations; for example `slice[rng.gen() % slice.len()]` samples from +/// the slice, but may be more likely to select numbers in the low range than +/// other values. +/// +/// This distribution samples with replacement; each sample is independent. +/// Sampling without replacement requires state to be retained, and therefore +/// cannot be handled by a distribution; you should instead consider methods +/// on [`SliceRandom`], such as [`SliceRandom::choose_multiple`]. +/// +/// # Example +/// +/// ``` +/// use rand::Rng; +/// use rand::distributions::Slice; +/// +/// let vowels = ['a', 'e', 'i', 'o', 'u']; +/// let vowels_dist = Slice::new(&vowels).unwrap(); +/// let rng = rand::thread_rng(); +/// +/// // build a string of 10 vowels +/// let vowel_string: String = rng +/// .sample_iter(&vowels_dist) +/// .take(10) +/// .collect(); +/// +/// println!("{}", vowel_string); +/// assert_eq!(vowel_string.len(), 10); +/// assert!(vowel_string.chars().all(|c| vowels.contains(&c))); +/// ``` +/// +/// For a single sample, [`SliceRandom::choose`][crate::seq::SliceRandom::choose] +/// may be preferred: +/// +/// ``` +/// use rand::seq::SliceRandom; +/// +/// let vowels = ['a', 'e', 'i', 'o', 'u']; +/// let mut rng = rand::thread_rng(); +/// +/// println!("{}", vowels.choose(&mut rng).unwrap()) +/// ``` +/// +/// [`SliceRandom`]: crate::seq::SliceRandom +/// [`SliceRandom::choose`]: crate::seq::SliceRandom::choose +/// [`SliceRandom::choose_multiple`]: crate::seq::SliceRandom::choose_multiple +#[derive(Debug, Clone, Copy)] +pub struct Slice<'a, T> { + slice: &'a [T], + range: Uniform, +} + +impl<'a, T> Slice<'a, T> { + /// Create a new `Slice` instance which samples uniformly from the slice. + /// Returns `Err` if the slice is empty. + pub fn new(slice: &'a [T]) -> Result { + match slice.len() { + 0 => Err(EmptySlice), + len => Ok(Self { + slice, + range: Uniform::new(0, len), + }), + } + } +} + +impl<'a, T> Distribution<&'a T> for Slice<'a, T> { + fn sample(&self, rng: &mut R) -> &'a T { + let idx = self.range.sample(rng); + + debug_assert!( + idx < self.slice.len(), + "Uniform::new(0, {}) somehow returned {}", + self.slice.len(), + idx + ); + + // Safety: at construction time, it was ensured that the slice was + // non-empty, and that the `Uniform` range produces values in range + // for the slice + unsafe { self.slice.get_unchecked(idx) } + } +} + +/// Error type indicating that a [`Slice`] distribution was improperly +/// constructed with an empty slice. +#[derive(Debug, Clone, Copy)] +pub struct EmptySlice; + +impl core::fmt::Display for EmptySlice { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + write!( + f, + "Tried to create a `distributions::Slice` with an empty slice" + ) + } +} + +#[cfg(feature = "std")] +impl std::error::Error for EmptySlice {}