Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Bernoulli::new return a Result #803

Merged
merged 1 commit into from May 23, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion benches/distributions.rs
Expand Up @@ -206,7 +206,7 @@ distr_float!(distr_triangular, f64, Triangular::new(0., 1., 0.9).unwrap());
distr_int!(distr_binomial, u64, Binomial::new(20, 0.7).unwrap());
distr_int!(distr_binomial_small, u64, Binomial::new(1000000, 1e-30).unwrap());
distr_int!(distr_poisson, u64, Poisson::new(4.0).unwrap());
distr!(distr_bernoulli, bool, Bernoulli::new(0.18));
distr!(distr_bernoulli, bool, Bernoulli::new(0.18).unwrap());
distr_arr!(distr_circle, [f64; 2], UnitCircle);
distr_arr!(distr_sphere, [f64; 3], UnitSphere);

Expand Down
4 changes: 2 additions & 2 deletions benches/misc.rs
Expand Up @@ -72,7 +72,7 @@ fn misc_gen_ratio_var(b: &mut Bencher) {
fn misc_bernoulli_const(b: &mut Bencher) {
let mut rng = StdRng::from_rng(&mut thread_rng()).unwrap();
b.iter(|| {
let d = rand::distributions::Bernoulli::new(0.18);
let d = rand::distributions::Bernoulli::new(0.18).unwrap();
let mut accum = true;
for _ in 0..::RAND_BENCH_N {
accum ^= rng.sample(d);
Expand All @@ -88,7 +88,7 @@ fn misc_bernoulli_var(b: &mut Bencher) {
let mut accum = true;
let mut p = 0.18;
for _ in 0..::RAND_BENCH_N {
let d = Bernoulli::new(p);
let d = Bernoulli::new(p).unwrap();
accum ^= rng.sample(d);
p += 0.0001;
}
Expand Down
44 changes: 22 additions & 22 deletions src/distributions/bernoulli.rs
Expand Up @@ -20,7 +20,7 @@ use distributions::Distribution;
/// ```rust
/// use rand::distributions::{Bernoulli, Distribution};
///
/// let d = Bernoulli::new(0.3);
/// let d = Bernoulli::new(0.3).unwrap();
/// let v = d.sample(&mut rand::thread_rng());
/// println!("{} is from a Bernoulli distribution", v);
/// ```
Expand Down Expand Up @@ -61,13 +61,16 @@ const ALWAYS_TRUE: u64 = ::core::u64::MAX;
// in `no_std` mode.
const SCALE: f64 = 2.0 * (1u64 << 63) as f64;

/// Error type returned from `Bernoulli::new`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum BernoulliError {
/// `p < 0` or `p > 1`.
InvalidProbability,
}

impl Bernoulli {
/// Construct a new `Bernoulli` with the given probability of success `p`.
///
/// # Panics
///
/// If `p < 0` or `p > 1`.
///
/// # Precision
///
/// For `p = 1.0`, the resulting distribution will always generate true.
Expand All @@ -77,12 +80,12 @@ impl Bernoulli {
/// a multiple of 2<sup>-64</sup>. (Note that not all multiples of
/// 2<sup>-64</sup> in `[0, 1]` can be represented as a `f64`.)
#[inline]
pub fn new(p: f64) -> Bernoulli {
pub fn new(p: f64) -> Result<Bernoulli, BernoulliError> {
if p < 0.0 || p >= 1.0 {
if p == 1.0 { return Bernoulli { p_int: ALWAYS_TRUE } }
panic!("Bernoulli::new not called with 0.0 <= p <= 1.0");
if p == 1.0 { return Ok(Bernoulli { p_int: ALWAYS_TRUE }) }
return Err(BernoulliError::InvalidProbability);
}
Bernoulli { p_int: (p * SCALE) as u64 }
Ok(Bernoulli { p_int: (p * SCALE) as u64 })
}

/// Construct a new `Bernoulli` with the probability of success of
Expand All @@ -91,19 +94,16 @@ impl Bernoulli {
///
/// If `numerator == denominator` then the returned `Bernoulli` will always
/// return `true`. If `numerator == 0` it will always return `false`.
///
/// # Panics
///
/// If `denominator == 0` or `numerator > denominator`.
///
#[inline]
pub fn from_ratio(numerator: u32, denominator: u32) -> Bernoulli {
assert!(numerator <= denominator);
pub fn from_ratio(numerator: u32, denominator: u32) -> Result<Bernoulli, BernoulliError> {
if !(numerator <= denominator) {
return Err(BernoulliError::InvalidProbability);
}
if numerator == denominator {
return Bernoulli { p_int: ::core::u64::MAX }
return Ok(Bernoulli { p_int: ALWAYS_TRUE })
}
let p_int = ((numerator as f64 / denominator as f64) * SCALE) as u64;
Bernoulli { p_int }
Ok(Bernoulli { p_int })
}
}

Expand All @@ -126,8 +126,8 @@ mod test {
#[test]
fn test_trivial() {
let mut r = ::test::rng(1);
let always_false = Bernoulli::new(0.0);
let always_true = Bernoulli::new(1.0);
let always_false = Bernoulli::new(0.0).unwrap();
let always_true = Bernoulli::new(1.0).unwrap();
for _ in 0..5 {
assert_eq!(r.sample::<bool, _>(&always_false), false);
assert_eq!(r.sample::<bool, _>(&always_true), true);
Expand All @@ -142,8 +142,8 @@ mod test {
const P: f64 = 0.3;
const NUM: u32 = 3;
const DENOM: u32 = 10;
let d1 = Bernoulli::new(P);
let d2 = Bernoulli::from_ratio(NUM, DENOM);
let d1 = Bernoulli::new(P).unwrap();
let d2 = Bernoulli::from_ratio(NUM, DENOM).unwrap();
const N: u32 = 100_000;

let mut sum1: u32 = 0;
Expand Down
2 changes: 1 addition & 1 deletion src/distributions/mod.rs
Expand Up @@ -108,7 +108,7 @@ use Rng;
pub use self::other::Alphanumeric;
#[doc(inline)] pub use self::uniform::Uniform;
pub use self::float::{OpenClosed01, Open01};
pub use self::bernoulli::Bernoulli;
pub use self::bernoulli::{Bernoulli, BernoulliError};
#[cfg(feature="alloc")] pub use self::weighted::{WeightedIndex, WeightedError};

// The following are all deprecated after being moved to rand_distr
Expand Down
4 changes: 2 additions & 2 deletions src/lib.rs
Expand Up @@ -325,7 +325,7 @@ pub trait Rng: RngCore {
/// [`Bernoulli`]: distributions::bernoulli::Bernoulli
#[inline]
fn gen_bool(&mut self, p: f64) -> bool {
let d = distributions::Bernoulli::new(p);
let d = distributions::Bernoulli::new(p).unwrap();
self.sample(d)
}

Expand Down Expand Up @@ -354,7 +354,7 @@ pub trait Rng: RngCore {
/// [`Bernoulli`]: distributions::bernoulli::Bernoulli
#[inline]
fn gen_ratio(&mut self, numerator: u32, denominator: u32) -> bool {
let d = distributions::Bernoulli::from_ratio(numerator, denominator);
let d = distributions::Bernoulli::from_ratio(numerator, denominator).unwrap();
self.sample(d)
}
}
Expand Down