Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Making distributions comparable by deriving PartialEq #1218

Merged
merged 1 commit into from Feb 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 6 additions & 1 deletion rand_distr/src/binomial.rs
Expand Up @@ -30,7 +30,7 @@ use num_traits::Float;
/// let v = bin.sample(&mut rand::thread_rng());
/// println!("{} is from a binomial distribution", v);
/// ```
#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct Binomial {
/// Number of trials.
Expand Down Expand Up @@ -347,4 +347,9 @@ mod test {
fn test_binomial_invalid_lambda_neg() {
Binomial::new(20, -10.0).unwrap();
}

#[test]
fn binomial_distributions_can_be_compared() {
assert_eq!(Binomial::new(1, 1.0), Binomial::new(1, 1.0));
}
}
7 changes: 6 additions & 1 deletion rand_distr/src/cauchy.rs
Expand Up @@ -31,7 +31,7 @@ use core::fmt;
/// let v = cau.sample(&mut rand::thread_rng());
/// println!("{} is from a Cauchy(2, 5) distribution", v);
/// ```
#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct Cauchy<F>
where F: Float + FloatConst, Standard: Distribution<F>
Expand Down Expand Up @@ -164,4 +164,9 @@ mod test {
assert_almost_eq!(*a, *b, 1e-5);
}
}

#[test]
fn cauchy_distributions_can_be_compared() {
assert_eq!(Cauchy::new(1.0, 2.0), Cauchy::new(1.0, 2.0));
}
}
7 changes: 6 additions & 1 deletion rand_distr/src/dirichlet.rs
Expand Up @@ -32,7 +32,7 @@ use alloc::{boxed::Box, vec, vec::Vec};
/// println!("{:?} is from a Dirichlet([1.0, 2.0, 3.0]) distribution", samples);
/// ```
#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
#[derive(Clone, Debug)]
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct Dirichlet<F>
where
Expand Down Expand Up @@ -183,4 +183,9 @@ mod test {
fn test_dirichlet_invalid_alpha() {
Dirichlet::new_with_size(0.0f64, 2).unwrap();
}

#[test]
fn dirichlet_distributions_can_be_compared() {
assert_eq!(Dirichlet::new(&[1.0, 2.0]), Dirichlet::new(&[1.0, 2.0]));
}
}
7 changes: 6 additions & 1 deletion rand_distr/src/exponential.rs
Expand Up @@ -91,7 +91,7 @@ impl Distribution<f64> for Exp1 {
/// let v = exp.sample(&mut rand::thread_rng());
/// println!("{} is from a Exp(2) distribution", v);
/// ```
#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct Exp<F>
where F: Float, Exp1: Distribution<F>
Expand Down Expand Up @@ -178,4 +178,9 @@ mod test {
fn test_exp_invalid_lambda_nan() {
Exp::new(f64::nan()).unwrap();
}

#[test]
fn exponential_distributions_can_be_compared() {
assert_eq!(Exp::new(1.0), Exp::new(1.0));
}
}
7 changes: 6 additions & 1 deletion rand_distr/src/frechet.rs
Expand Up @@ -27,7 +27,7 @@ use rand::Rng;
/// let val: f64 = thread_rng().sample(Frechet::new(0.0, 1.0, 1.0).unwrap());
/// println!("{}", val);
/// ```
#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct Frechet<F>
where
Expand Down Expand Up @@ -182,4 +182,9 @@ mod tests {
.zip(&probabilities)
.all(|(p_hat, p)| (p_hat - p).abs() < 0.003))
}

#[test]
fn frechet_distributions_can_be_compared() {
assert_eq!(Frechet::new(1.0, 2.0, 3.0), Frechet::new(1.0, 2.0, 3.0));
}
}
49 changes: 37 additions & 12 deletions rand_distr/src/gamma.rs
Expand Up @@ -54,7 +54,7 @@ use serde::{Serialize, Deserialize};
/// Generating Gamma Variables" *ACM Trans. Math. Softw.* 26, 3
/// (September 2000), 363-372.
/// DOI:[10.1145/358407.358414](https://doi.acm.org/10.1145/358407.358414)
#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct Gamma<F>
where
Expand Down Expand Up @@ -91,7 +91,7 @@ impl fmt::Display for Error {
#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))]
impl std::error::Error for Error {}

#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
enum GammaRepr<F>
where
Expand Down Expand Up @@ -119,7 +119,7 @@ where
///
/// See `Gamma` for sampling from a Gamma distribution with general
/// shape parameters.
#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
struct GammaSmallShape<F>
where
Expand All @@ -135,7 +135,7 @@ where
///
/// See `Gamma` for sampling from a Gamma distribution with general
/// shape parameters.
#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
struct GammaLargeShape<F>
where
Expand Down Expand Up @@ -280,7 +280,7 @@ where
/// let v = chi.sample(&mut rand::thread_rng());
/// println!("{} is from a χ²(11) distribution", v)
/// ```
#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct ChiSquared<F>
where
Expand Down Expand Up @@ -314,7 +314,7 @@ impl fmt::Display for ChiSquaredError {
#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))]
impl std::error::Error for ChiSquaredError {}

#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
enum ChiSquaredRepr<F>
where
Expand Down Expand Up @@ -385,7 +385,7 @@ where
/// let v = f.sample(&mut rand::thread_rng());
/// println!("{} is from an F(2, 32) distribution", v)
/// ```
#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct FisherF<F>
where
Expand Down Expand Up @@ -472,7 +472,7 @@ where
/// let v = t.sample(&mut rand::thread_rng());
/// println!("{} is from a t(11) distribution", v)
/// ```
#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct StudentT<F>
where
Expand Down Expand Up @@ -522,15 +522,15 @@ where
/// Generating beta variates with nonintegral shape parameters.
/// Communications of the ACM 21, 317-322.
/// https://doi.org/10.1145/359460.359482
#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
enum BetaAlgorithm<N> {
BB(BB<N>),
BC(BC<N>),
}

/// Algorithm BB for `min(alpha, beta) > 1`.
#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
struct BB<N> {
alpha: N,
Expand All @@ -539,7 +539,7 @@ struct BB<N> {
}

/// Algorithm BC for `min(alpha, beta) <= 1`.
#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
struct BC<N> {
alpha: N,
Expand All @@ -560,7 +560,7 @@ struct BC<N> {
/// let v = beta.sample(&mut rand::thread_rng());
/// println!("{} is from a Beta(2, 5) distribution", v);
/// ```
#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct Beta<F>
where
Expand Down Expand Up @@ -811,4 +811,29 @@ mod test {
assert!(!beta.sample(&mut rng).is_nan(), "failed at i={}", i);
}
}

#[test]
fn gamma_distributions_can_be_compared() {
assert_eq!(Gamma::new(1.0, 2.0), Gamma::new(1.0, 2.0));
}

#[test]
fn beta_distributions_can_be_compared() {
assert_eq!(Beta::new(1.0, 2.0), Beta::new(1.0, 2.0));
}

#[test]
fn chi_squared_distributions_can_be_compared() {
assert_eq!(ChiSquared::new(1.0), ChiSquared::new(1.0));
}

#[test]
fn fisher_f_distributions_can_be_compared() {
assert_eq!(FisherF::new(1.0, 2.0), FisherF::new(1.0, 2.0));
}

#[test]
fn student_t_distributions_can_be_compared() {
assert_eq!(StudentT::new(1.0), StudentT::new(1.0));
}
}
7 changes: 6 additions & 1 deletion rand_distr/src/geometric.rs
Expand Up @@ -27,7 +27,7 @@ use num_traits::Float;
/// let v = geo.sample(&mut rand::thread_rng());
/// println!("{} is from a Geometric(0.25) distribution", v);
/// ```
#[derive(Copy, Clone, Debug)]
#[derive(Copy, Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct Geometric
{
Expand Down Expand Up @@ -235,4 +235,9 @@ mod test {
results.iter().map(|x| (x - mean) * (x - mean)).sum::<f64>() / results.len() as f64;
assert!((variance - expected_variance).abs() < expected_variance / 10.0);
}

#[test]
fn geometric_distributions_can_be_compared() {
assert_eq!(Geometric::new(1.0), Geometric::new(1.0));
}
}
7 changes: 6 additions & 1 deletion rand_distr/src/gumbel.rs
Expand Up @@ -27,7 +27,7 @@ use rand::Rng;
/// let val: f64 = thread_rng().sample(Gumbel::new(0.0, 1.0).unwrap());
/// println!("{}", val);
/// ```
#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct Gumbel<F>
where
Expand Down Expand Up @@ -152,4 +152,9 @@ mod tests {
.zip(&probabilities)
.all(|(p_hat, p)| (p_hat - p).abs() < 0.003))
}

#[test]
fn gumbel_distributions_can_be_compared() {
assert_eq!(Gumbel::new(1.0, 2.0), Gumbel::new(1.0, 2.0));
}
}
9 changes: 7 additions & 2 deletions rand_distr/src/hypergeometric.rs
Expand Up @@ -7,7 +7,7 @@ use core::fmt;
#[allow(unused_imports)]
use num_traits::Float;

#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
enum SamplingMethod {
InverseTransform{ initial_p: f64, initial_x: i64 },
Expand Down Expand Up @@ -45,7 +45,7 @@ enum SamplingMethod {
/// let v = hypergeo.sample(&mut rand::thread_rng());
/// println!("{} is from a hypergeometric distribution", v);
/// ```
#[derive(Copy, Clone, Debug)]
#[derive(Copy, Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct Hypergeometric {
n1: u64,
Expand Down Expand Up @@ -419,4 +419,9 @@ mod test {
test_hypergeometric_mean_and_variance(10100, 10000, 1000, &mut rng);
test_hypergeometric_mean_and_variance(100100, 100, 10000, &mut rng);
}

#[test]
fn hypergeometric_distributions_can_be_compared() {
assert_eq!(Hypergeometric::new(1, 2, 3), Hypergeometric::new(1, 2, 3));
}
}
7 changes: 6 additions & 1 deletion rand_distr/src/inverse_gaussian.rs
Expand Up @@ -26,7 +26,7 @@ impl fmt::Display for Error {
impl std::error::Error for Error {}

/// The [inverse Gaussian distribution](https://en.wikipedia.org/wiki/Inverse_Gaussian_distribution)
#[derive(Debug, Clone, Copy)]
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct InverseGaussian<F>
where
Expand Down Expand Up @@ -109,4 +109,9 @@ mod tests {
assert!(InverseGaussian::new(1.0, -1.0).is_err());
assert!(InverseGaussian::new(1.0, 1.0).is_ok());
}

#[test]
fn inverse_gaussian_distributions_can_be_compared() {
assert_eq!(InverseGaussian::new(1.0, 2.0), InverseGaussian::new(1.0, 2.0));
}
}
14 changes: 12 additions & 2 deletions rand_distr/src/normal.rs
Expand Up @@ -112,7 +112,7 @@ impl Distribution<f64> for StandardNormal {
/// ```
///
/// [`StandardNormal`]: crate::StandardNormal
#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct Normal<F>
where F: Float, StandardNormal: Distribution<F>
Expand Down Expand Up @@ -227,7 +227,7 @@ where F: Float, StandardNormal: Distribution<F>
/// let v = log_normal.sample(&mut rand::thread_rng());
/// println!("{} is from an ln N(2, 9) distribution", v)
/// ```
#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct LogNormal<F>
where F: Float, StandardNormal: Distribution<F>
Expand Down Expand Up @@ -368,4 +368,14 @@ mod tests {
assert!(LogNormal::from_mean_cv(0.0, 1.0).is_err());
assert!(LogNormal::from_mean_cv(1.0, -1.0).is_err());
}

#[test]
fn normal_distributions_can_be_compared() {
assert_eq!(Normal::new(1.0, 2.0), Normal::new(1.0, 2.0));
}

#[test]
fn log_normal_distributions_can_be_compared() {
assert_eq!(LogNormal::new(1.0, 2.0), LogNormal::new(1.0, 2.0));
}
}
7 changes: 6 additions & 1 deletion rand_distr/src/normal_inverse_gaussian.rs
Expand Up @@ -26,7 +26,7 @@ impl fmt::Display for Error {
impl std::error::Error for Error {}

/// The [normal-inverse Gaussian distribution](https://en.wikipedia.org/wiki/Normal-inverse_Gaussian_distribution)
#[derive(Debug, Clone, Copy)]
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct NormalInverseGaussian<F>
where
Expand Down Expand Up @@ -104,4 +104,9 @@ mod tests {
assert!(NormalInverseGaussian::new(1.0, 2.0).is_err());
assert!(NormalInverseGaussian::new(2.0, 1.0).is_ok());
}

#[test]
fn normal_inverse_gaussian_distributions_can_be_compared() {
assert_eq!(NormalInverseGaussian::new(1.0, 2.0), NormalInverseGaussian::new(1.0, 2.0));
}
}
7 changes: 6 additions & 1 deletion rand_distr/src/pareto.rs
Expand Up @@ -23,7 +23,7 @@ use core::fmt;
/// let val: f64 = thread_rng().sample(Pareto::new(1., 2.).unwrap());
/// println!("{}", val);
/// ```
#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct Pareto<F>
where F: Float, OpenClosed01: Distribution<F>
Expand Down Expand Up @@ -131,4 +131,9 @@ mod tests {
105.8826669383772,
]);
}

#[test]
fn pareto_distributions_can_be_compared() {
assert_eq!(Pareto::new(1.0, 2.0), Pareto::new(1.0, 2.0));
}
}