Skip to content

Commit

Permalink
Use static dispatch for the argument to gen_range
Browse files Browse the repository at this point in the history
Note that this means that `Rng::gen_range(&0, &10)` and similar are no
longer supported, so this is a breaking change.
  • Loading branch information
vks committed Jul 31, 2020
1 parent f67517b commit 683b0b5
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 17 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ You may also find the [Upgrade Guide](https://rust-random.github.io/book/update.
- Added a `serde1` feature and added Serialize/Deserialize to `UniformInt` and `WeightedIndex` (#974)

### Changes
- `gen_range(a, b)` was replaced with `gen_range(a..b)`, and `gen_range(a..=b)` is supported
- `gen_range(a, b)` was replaced with `gen_range(a..b)`, and `gen_range(a..=b)`
is supported (#744, #1003). Note that `a` and `b` can no longer be references.

## [0.7.3] - 2020-01-10
### Fixes
Expand Down
2 changes: 1 addition & 1 deletion src/distributions/uniform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1275,7 +1275,7 @@ mod tests {
fn test_float_assertions() {
use super::SampleUniform;
use std::panic::catch_unwind;
fn range<T: SampleUniform>(low: T, high: T) {
fn range<T: SampleUniform + PartialOrd>(low: T, high: T) {
let mut rng = crate::test::rng(253);
rng.gen_range(low..high);
}
Expand Down
58 changes: 43 additions & 15 deletions src/rng.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,43 @@ use crate::distributions::uniform::{SampleUniform, UniformSampler};
use crate::distributions::{self, Distribution, Standard};
use core::num::Wrapping;
use core::{mem, slice};
use core::ops::RangeBounds;
use core::ops::{Range, RangeInclusive};

/// Range that supports generating a single sample efficiently.
///
/// Any type implementing this trait can be used to specify the sampled range
/// for `Rng::gen_range`.
pub trait SampleRange<T> {
/// Generate a sample from the given range.
fn sample_single<R: RngCore + ?Sized>(self, rng: &mut R) -> T;

/// Check whether the range is empty.
fn is_empty(&self) -> bool;
}

impl<T: SampleUniform + PartialOrd> SampleRange<T> for Range<T> {
#[inline]
fn sample_single<R: RngCore + ?Sized>(self, rng: &mut R) -> T {
T::Sampler::sample_single(self.start, self.end, rng)
}

#[inline]
fn is_empty(&self) -> bool {
!(self.start < self.end)
}
}

impl<T: SampleUniform + PartialOrd> SampleRange<T> for RangeInclusive<T> {
#[inline]
fn sample_single<R: RngCore + ?Sized>(self, rng: &mut R) -> T {
T::Sampler::sample_single_inclusive(self.start(), self.end(), rng)
}

#[inline]
fn is_empty(&self) -> bool {
!(self.start() <= self.end())
}
}

/// An automatically-implemented extension trait on [`RngCore`] providing high-level
/// generic methods for sampling values and other convenience methods.
Expand Down Expand Up @@ -104,7 +140,7 @@ pub trait Rng: RngCore {
///
/// # Panics
///
/// Panics if the range is not `low..high` or `low..=high`.
/// Panics if the range is empty.
///
/// # Example
///
Expand All @@ -128,18 +164,10 @@ pub trait Rng: RngCore {
fn gen_range<T, R>(&mut self, range: R) -> T
where
T: SampleUniform,
R: RangeBounds<T>
R: SampleRange<T>
{
use core::ops::Bound;
if let Bound::Included(low) = range.start_bound() {
match range.end_bound() {
Bound::Excluded(high) => T::Sampler::sample_single(low, high, self),
Bound::Included(high) => T::Sampler::sample_single_inclusive(low, high, self),
Bound::Unbounded => panic!("invalid upper bound"),
}
} else {
panic!("invalid lower bound");
}
assert!(!range.is_empty(), "cannot sample empty range");
range.sample_single(self)
}

/// Sample a new value, using the given distribution.
Expand Down Expand Up @@ -496,11 +524,11 @@ mod test {
assert!(a >= -4711 && a < 17);
let a = r.gen_range(-3i8..42);
assert!(a >= -3i8 && a < 42i8);
let a: u16 = r.gen_range(&10..&99);
let a: u16 = r.gen_range(10..99);
assert!(a >= 10u16 && a < 99u16);
let a = r.gen_range(-100i32..2000);
assert!(a >= -100i32 && a < 2000i32);
let a: u32 = r.gen_range(&12..&24);
let a: u32 = r.gen_range(12..24);
assert!(a >= 12u32 && a < 24u32);

assert_eq!(r.gen_range(0u32..1), 0u32);
Expand Down

0 comments on commit 683b0b5

Please sign in to comment.