Skip to content

Commit

Permalink
Migrate rand_distr to num-traits for no_std support (#987)
Browse files Browse the repository at this point in the history
* replace custom Float trait with num-traits::Float
* enable no_std support via num-traits math functions
* remove Distribution<u64> impl for poisson
* move stability tests
* add copyright notice
* tweak dirichlet and alias_method to use boxed slice instead of vec
  • Loading branch information
newpavlov committed Aug 1, 2020
1 parent dda1780 commit 05a7ab3
Show file tree
Hide file tree
Showing 22 changed files with 897 additions and 1,040 deletions.
10 changes: 9 additions & 1 deletion rand_distr/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,17 @@ travis-ci = { repository = "rust-random/rand" }
appveyor = { repository = "rust-random/rand" }

[dependencies]
rand = { path = "..", version = "0.7" }
rand = { path = "..", version = "0.7", default-features = false }
num-traits = { version = "0.2", default-features = false, features = ["libm"] }

[features]
default = ["std"]
std = ["alloc"]
alloc = []

[dev-dependencies]
rand_pcg = { version = "0.2", path = "../rand_pcg" }
# For inline examples
rand = { path = "..", version = "0.7", default-features = false, features = ["std_rng", "std"] }
# Histogram implementation for testing uniformity
average = "0.10.3"
27 changes: 5 additions & 22 deletions rand_distr/src/binomial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

use crate::{Distribution, Uniform};
use rand::Rng;
use std::{error, fmt};
use core::fmt;

/// The binomial distribution `Binomial(n, p)`.
///
Expand Down Expand Up @@ -53,7 +53,8 @@ impl fmt::Display for Error {
}
}

impl error::Error for Error {}
#[cfg(feature = "std")]
impl std::error::Error for Error {}

impl Binomial {
/// Construct a new `Binomial` with the given shape parameters `n` (number
Expand All @@ -72,7 +73,7 @@ impl Binomial {
/// Convert a `f64` to an `i64`, panicing on overflow.
// In the future (Rust 1.34), this might be replaced with `TryFrom`.
fn f64_to_i64(x: f64) -> i64 {
assert!(x < (::std::i64::MAX as f64));
assert!(x < (core::i64::MAX as f64));
x as i64
}

Expand Down Expand Up @@ -106,7 +107,7 @@ impl Distribution<u64> for Binomial {
// Ranlib uses 30, and GSL uses 14.
const BINV_THRESHOLD: f64 = 10.;

if (self.n as f64) * p < BINV_THRESHOLD && self.n <= (::std::i32::MAX as u64) {
if (self.n as f64) * p < BINV_THRESHOLD && self.n <= (core::i32::MAX as u64) {
// Use the BINV algorithm.
let s = p / q;
let a = ((self.n + 1) as f64) * s;
Expand Down Expand Up @@ -338,22 +339,4 @@ mod test {
fn test_binomial_invalid_lambda_neg() {
Binomial::new(20, -10.0).unwrap();
}

#[test]
fn value_stability() {
fn test_samples(n: u64, p: f64, expected: &[u64]) {
let distr = Binomial::new(n, p).unwrap();
let mut rng = crate::test::rng(353);
let mut buf = [0; 4];
for x in &mut buf {
*x = rng.sample(&distr);
}
assert_eq!(buf, expected);
}

// We have multiple code paths: np < 10, p > 0.5
test_samples(2, 0.7, &[1, 1, 2, 1]);
test_samples(20, 0.3, &[7, 7, 5, 7]);
test_samples(2000, 0.6, &[1194, 1208, 1192, 1210]);
}
}
41 changes: 23 additions & 18 deletions rand_distr/src/cauchy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@

//! The Cauchy distribution.

use crate::utils::Float;
use num_traits::{Float, FloatConst};
use crate::{Distribution, Standard};
use rand::Rng;
use std::{error, fmt};
use core::fmt;

/// The Cauchy distribution `Cauchy(median, scale)`.
///
Expand All @@ -32,9 +32,11 @@ use std::{error, fmt};
/// println!("{} is from a Cauchy(2, 5) distribution", v);
/// ```
#[derive(Clone, Copy, Debug)]
pub struct Cauchy<N> {
median: N,
scale: N,
pub struct Cauchy<F>
where F: Float + FloatConst, Standard: Distribution<F>
{
median: F,
scale: F,
}

/// Error type returned from `Cauchy::new`.
Expand All @@ -52,30 +54,31 @@ impl fmt::Display for Error {
}
}

impl error::Error for Error {}
#[cfg(feature = "std")]
impl std::error::Error for Error {}

impl<N: Float> Cauchy<N>
where Standard: Distribution<N>
impl<F> Cauchy<F>
where F: Float + FloatConst, Standard: Distribution<F>
{
/// Construct a new `Cauchy` with the given shape parameters
/// `median` the peak location and `scale` the scale factor.
pub fn new(median: N, scale: N) -> Result<Cauchy<N>, Error> {
if !(scale > N::from(0.0)) {
pub fn new(median: F, scale: F) -> Result<Cauchy<F>, Error> {
if !(scale > F::zero()) {
return Err(Error::ScaleTooSmall);
}
Ok(Cauchy { median, scale })
}
}

impl<N: Float> Distribution<N> for Cauchy<N>
where Standard: Distribution<N>
impl<F> Distribution<F> for Cauchy<F>
where F: Float + FloatConst, Standard: Distribution<F>
{
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> N {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
// sample from [0, 1)
let x = Standard.sample(rng);
// get standard cauchy random number
// note that π/2 is not exactly representable, even if x=0.5 the result is finite
let comp_dev = (N::pi() * x).tan();
let comp_dev = (F::PI() * x).tan();
// shift and scale according to parameters
self.median + self.scale * comp_dev
}
Expand Down Expand Up @@ -108,10 +111,12 @@ mod test {
sum += numbers[i];
}
let median = median(&mut numbers);
println!("Cauchy median: {}", median);
#[cfg(feature = "std")]
std::println!("Cauchy median: {}", median);
assert!((median - 10.0).abs() < 0.4); // not 100% certain, but probable enough
let mean = sum / 1000.0;
println!("Cauchy mean: {}", mean);
#[cfg(feature = "std")]
std::println!("Cauchy mean: {}", mean);
// for a Cauchy distribution the mean should not converge
assert!((mean - 10.0).abs() > 0.4); // not 100% certain, but probable enough
}
Expand All @@ -130,8 +135,8 @@ mod test {

#[test]
fn value_stability() {
fn gen_samples<N: Float + core::fmt::Debug>(m: N, s: N, buf: &mut [N])
where Standard: Distribution<N> {
fn gen_samples<F: Float + FloatConst + core::fmt::Debug>(m: F, s: F, buf: &mut [F])
where Standard: Distribution<F> {
let distr = Cauchy::new(m, s).unwrap();
let mut rng = crate::test::rng(353);
for x in buf {
Expand Down
89 changes: 41 additions & 48 deletions rand_distr/src/dirichlet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@
// except according to those terms.

//! The dirichlet distribution.

use crate::utils::Float;
#![cfg(feature = "alloc")]
use num_traits::Float;
use crate::{Distribution, Exp1, Gamma, Open01, StandardNormal};
use rand::Rng;
use std::{error, fmt};
use core::fmt;
use alloc::{boxed::Box, vec, vec::Vec};

/// The Dirichlet distribution `Dirichlet(alpha)`.
///
Expand All @@ -26,14 +27,20 @@ use std::{error, fmt};
/// use rand::prelude::*;
/// use rand_distr::Dirichlet;
///
/// let dirichlet = Dirichlet::new(vec![1.0, 2.0, 3.0]).unwrap();
/// let dirichlet = Dirichlet::new(&[1.0, 2.0, 3.0]).unwrap();
/// let samples = dirichlet.sample(&mut rand::thread_rng());
/// println!("{:?} is from a Dirichlet([1.0, 2.0, 3.0]) distribution", samples);
/// ```
#[derive(Clone, Debug)]
pub struct Dirichlet<N> {
pub struct Dirichlet<F>
where
F: Float,
StandardNormal: Distribution<F>,
Exp1: Distribution<F>,
Open01: Distribution<F>,
{
/// Concentration parameters (alpha)
alpha: Vec<N>,
alpha: Box<[F]>,
}

/// Error type returned from `Dirchlet::new`.
Expand All @@ -58,68 +65,70 @@ impl fmt::Display for Error {
}
}

impl error::Error for Error {}
#[cfg(feature = "std")]
impl std::error::Error for Error {}

impl<N: Float> Dirichlet<N>
impl<F> Dirichlet<F>
where
StandardNormal: Distribution<N>,
Exp1: Distribution<N>,
Open01: Distribution<N>,
F: Float,
StandardNormal: Distribution<F>,
Exp1: Distribution<F>,
Open01: Distribution<F>,
{
/// Construct a new `Dirichlet` with the given alpha parameter `alpha`.
///
/// Requires `alpha.len() >= 2`.
#[inline]
pub fn new<V: Into<Vec<N>>>(alpha: V) -> Result<Dirichlet<N>, Error> {
let a = alpha.into();
if a.len() < 2 {
pub fn new(alpha: &[F]) -> Result<Dirichlet<F>, Error> {
if alpha.len() < 2 {
return Err(Error::AlphaTooShort);
}
for &ai in &a {
if !(ai > N::from(0.0)) {
for &ai in alpha.iter() {
if !(ai > F::zero()) {
return Err(Error::AlphaTooSmall);
}
}

Ok(Dirichlet { alpha: a })
Ok(Dirichlet { alpha: alpha.to_vec().into_boxed_slice() })
}

/// Construct a new `Dirichlet` with the given shape parameter `alpha` and `size`.
///
/// Requires `size >= 2`.
#[inline]
pub fn new_with_size(alpha: N, size: usize) -> Result<Dirichlet<N>, Error> {
if !(alpha > N::from(0.0)) {
pub fn new_with_size(alpha: F, size: usize) -> Result<Dirichlet<F>, Error> {
if !(alpha > F::zero()) {
return Err(Error::AlphaTooSmall);
}
if size < 2 {
return Err(Error::SizeTooSmall);
}
Ok(Dirichlet {
alpha: vec![alpha; size],
alpha: vec![alpha; size].into_boxed_slice(),
})
}
}

impl<N: Float> Distribution<Vec<N>> for Dirichlet<N>
impl<F> Distribution<Vec<F>> for Dirichlet<F>
where
StandardNormal: Distribution<N>,
Exp1: Distribution<N>,
Open01: Distribution<N>,
F: Float,
StandardNormal: Distribution<F>,
Exp1: Distribution<F>,
Open01: Distribution<F>,
{
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Vec<N> {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Vec<F> {
let n = self.alpha.len();
let mut samples = vec![N::from(0.0); n];
let mut sum = N::from(0.0);
let mut samples = vec![F::zero(); n];
let mut sum = F::zero();

for (s, &a) in samples.iter_mut().zip(self.alpha.iter()) {
let g = Gamma::new(a, N::from(1.0)).unwrap();
let g = Gamma::new(a, F::one()).unwrap();
*s = g.sample(rng);
sum += *s;
sum = sum + (*s);
}
let invacc = N::from(1.0) / sum;
let invacc = F::one() / sum;
for s in samples.iter_mut() {
*s *= invacc;
*s = (*s)*invacc;
}
samples
}
Expand All @@ -131,7 +140,7 @@ mod test {

#[test]
fn test_dirichlet() {
let d = Dirichlet::new(vec![1.0, 2.0, 3.0]).unwrap();
let d = Dirichlet::new(&[1.0, 2.0, 3.0]).unwrap();
let mut rng = crate::test::rng(221);
let samples = d.sample(&mut rng);
let _: Vec<f64> = samples
Expand Down Expand Up @@ -170,20 +179,4 @@ mod test {
fn test_dirichlet_invalid_alpha() {
Dirichlet::new_with_size(0.0f64, 2).unwrap();
}

#[test]
fn value_stability() {
let mut rng = crate::test::rng(223);
assert_eq!(
rng.sample(Dirichlet::new(vec![1.0, 2.0, 3.0]).unwrap()),
vec![0.12941567177708177, 0.4702121891675036, 0.4003721390554146]
);
assert_eq!(rng.sample(Dirichlet::new_with_size(8.0, 5).unwrap()), vec![
0.17684200044809556,
0.29915953935953055,
0.1832858056608014,
0.1425623503573967,
0.19815030417417595
]);
}
}

0 comments on commit 05a7ab3

Please sign in to comment.