diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index a5386b6187a..8355d6b7c23 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -92,7 +92,7 @@ jobs: cargo test --target ${{ matrix.target }} --manifest-path rand_core/Cargo.toml --no-default-features --features=alloc,getrandom - name: Test rand_distr run: | - cargo test --target ${{ matrix.target }} --manifest-path rand_distr/Cargo.toml + cargo test --target ${{ matrix.target }} --manifest-path rand_distr/Cargo.toml --features=serde1 cargo test --target ${{ matrix.target }} --manifest-path rand_distr/Cargo.toml --no-default-features cargo test --target ${{ matrix.target }} --manifest-path rand_distr/Cargo.toml --no-default-features --features=std,std_math - name: Test rand_pcg @@ -134,7 +134,7 @@ jobs: cross test --no-fail-fast --target ${{ matrix.target }} --features=serde1,log,small_rng cross test --no-fail-fast --target ${{ matrix.target }} --examples cross test --no-fail-fast --target ${{ matrix.target }} --manifest-path rand_core/Cargo.toml - cross test --no-fail-fast --target ${{ matrix.target }} --manifest-path rand_distr/Cargo.toml + cross test --no-fail-fast --target ${{ matrix.target }} --manifest-path rand_distr/Cargo.toml --features=serde1 cross test --no-fail-fast --target ${{ matrix.target }} --manifest-path rand_pcg/Cargo.toml --features=serde1 cross test --no-fail-fast --target ${{ matrix.target }} --manifest-path rand_chacha/Cargo.toml cross test --no-fail-fast --target ${{ matrix.target }} --manifest-path rand_hc/Cargo.toml diff --git a/rand_distr/CHANGELOG.md b/rand_distr/CHANGELOG.md index 27fc9a8ddd3..e9d1391091b 100644 --- a/rand_distr/CHANGELOG.md +++ b/rand_distr/CHANGELOG.md @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Make sure all distributions and their error types implement `Error`, `Display`, `Clone`, `Copy`, `PartialEq` and `Eq` as appropriate (#1126) - Port benchmarks to use Criterion crate (#1116) +- Support serde for distributions (#1141) ## [0.4.0] - 2020-12-18 - Bump `rand` to v0.8.0 diff --git a/rand_distr/Cargo.toml b/rand_distr/Cargo.toml index 7bbc5c4d400..317168e9d14 100644 --- a/rand_distr/Cargo.toml +++ b/rand_distr/Cargo.toml @@ -15,15 +15,17 @@ categories = ["algorithms", "no-std"] edition = "2018" include = ["src/", "LICENSE-*", "README.md", "CHANGELOG.md", "COPYRIGHT"] -[dependencies] -rand = { path = "..", version = "0.8.0", default-features = false } -num-traits = { version = "0.2", default-features = false, features = ["libm"] } - [features] default = ["std"] std = ["alloc", "rand/std"] alloc = ["rand/alloc"] std_math = ["num-traits/std"] +serde1 = ["serde", "rand/serde1"] + +[dependencies] +rand = { path = "..", version = "0.8.0", default-features = false } +num-traits = { version = "0.2", default-features = false, features = ["libm"] } +serde = { version = "1.0.103", features = ["derive"], optional = true } [dev-dependencies] rand_pcg = { version = "0.3.0", path = "../rand_pcg" } diff --git a/rand_distr/README.md b/rand_distr/README.md index 35f5dcaa1a8..3fc2ea62ef9 100644 --- a/rand_distr/README.md +++ b/rand_distr/README.md @@ -20,11 +20,7 @@ It is worth mentioning the [statrs] crate which provides similar functionality along with various support functions, including PDF and CDF computation. In contrast, this `rand_distr` crate focuses on sampling from distributions. -If the `std` default feature is enabled, `rand_distr` implements the `Error` -trait for its error types. - -The default `alloc` feature (which is implied by the `std` feature) is required -for some distributions (in particular, `Dirichlet` and `WeightedAliasIndex`). +## Portability and libm The floating point functions from `num_traits` and `libm` are used to support `no_std` environments and ensure reproducibility. If the floating point @@ -32,7 +28,16 @@ functions from `std` are preferred, which may provide better accuracy and performance but may produce different random values, the `std_math` feature can be enabled. -Links: +## Crate features + +- `std` (enabled by default): `rand_distr` implements the `Error` trait for + its error types. Implies `alloc` and `rand/std`. +- `alloc` (enabled by default): required for some distributions when not using + `std` (in particular, `Dirichlet` and `WeightedAliasIndex`). +- `std_math`: see above on portability and libm +- `serde1`: implement (de)seriaialization using `serde` + +## Links - [API documentation (master)](https://rust-random.github.io/rand/rand_distr) - [API documentation (docs.rs)](https://docs.rs/rand_distr) diff --git a/rand_distr/src/binomial.rs b/rand_distr/src/binomial.rs index a701e6bb684..5efe367e126 100644 --- a/rand_distr/src/binomial.rs +++ b/rand_distr/src/binomial.rs @@ -29,6 +29,7 @@ use core::cmp::Ordering; /// println!("{} is from a binomial distribution", v); /// ``` #[derive(Clone, Copy, Debug)] +#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] pub struct Binomial { /// Number of trials. n: u64, diff --git a/rand_distr/src/cauchy.rs b/rand_distr/src/cauchy.rs index 66dc09b3e7b..49b121c41d2 100644 --- a/rand_distr/src/cauchy.rs +++ b/rand_distr/src/cauchy.rs @@ -32,6 +32,7 @@ use core::fmt; /// println!("{} is from a Cauchy(2, 5) distribution", v); /// ``` #[derive(Clone, Copy, Debug)] +#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] pub struct Cauchy where F: Float + FloatConst, Standard: Distribution { diff --git a/rand_distr/src/dirichlet.rs b/rand_distr/src/dirichlet.rs index 6286e935742..0ffbc40a049 100644 --- a/rand_distr/src/dirichlet.rs +++ b/rand_distr/src/dirichlet.rs @@ -33,6 +33,7 @@ use alloc::{boxed::Box, vec, vec::Vec}; /// ``` #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] #[derive(Clone, Debug)] +#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] pub struct Dirichlet where F: Float, diff --git a/rand_distr/src/exponential.rs b/rand_distr/src/exponential.rs index e389319d91f..4e33c3cac6e 100644 --- a/rand_distr/src/exponential.rs +++ b/rand_distr/src/exponential.rs @@ -39,6 +39,7 @@ use core::fmt; /// println!("{}", val); /// ``` #[derive(Clone, Copy, Debug)] +#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] pub struct Exp1; impl Distribution for Exp1 { @@ -91,6 +92,7 @@ impl Distribution for Exp1 { /// println!("{} is from a Exp(2) distribution", v); /// ``` #[derive(Clone, Copy, Debug)] +#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] pub struct Exp where F: Float, Exp1: Distribution { diff --git a/rand_distr/src/gamma.rs b/rand_distr/src/gamma.rs index 45181b9204c..87faf11c893 100644 --- a/rand_distr/src/gamma.rs +++ b/rand_distr/src/gamma.rs @@ -21,6 +21,8 @@ use num_traits::Float; use crate::{Distribution, Exp, Exp1, Open01}; use rand::Rng; use core::fmt; +#[cfg(feature = "serde1")] +use serde::{Serialize, Deserialize}; /// The Gamma distribution `Gamma(shape, scale)` distribution. /// @@ -53,6 +55,7 @@ use core::fmt; /// (September 2000), 363-372. /// DOI:[10.1145/358407.358414](https://doi.acm.org/10.1145/358407.358414) #[derive(Clone, Copy, Debug)] +#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] pub struct Gamma where F: Float, @@ -89,6 +92,7 @@ impl fmt::Display for Error { impl std::error::Error for Error {} #[derive(Clone, Copy, Debug)] +#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] enum GammaRepr where F: Float, @@ -116,6 +120,7 @@ where /// See `Gamma` for sampling from a Gamma distribution with general /// shape parameters. #[derive(Clone, Copy, Debug)] +#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] struct GammaSmallShape where F: Float, @@ -131,6 +136,7 @@ where /// See `Gamma` for sampling from a Gamma distribution with general /// shape parameters. #[derive(Clone, Copy, Debug)] +#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] struct GammaLargeShape where F: Float, @@ -275,6 +281,7 @@ where /// println!("{} is from a χ²(11) distribution", v) /// ``` #[derive(Clone, Copy, Debug)] +#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] pub struct ChiSquared where F: Float, @@ -287,6 +294,7 @@ where /// Error type returned from `ChiSquared::new` and `StudentT::new`. #[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] pub enum ChiSquaredError { /// `0.5 * k <= 0` or `nan`. DoFTooSmall, @@ -307,6 +315,7 @@ impl fmt::Display for ChiSquaredError { impl std::error::Error for ChiSquaredError {} #[derive(Clone, Copy, Debug)] +#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] enum ChiSquaredRepr where F: Float, @@ -377,6 +386,7 @@ where /// println!("{} is from an F(2, 32) distribution", v) /// ``` #[derive(Clone, Copy, Debug)] +#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] pub struct FisherF where F: Float, @@ -393,6 +403,7 @@ where /// Error type returned from `FisherF::new`. #[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] pub enum FisherFError { /// `m <= 0` or `nan`. MTooSmall, @@ -462,6 +473,7 @@ where /// println!("{} is from a t(11) distribution", v) /// ``` #[derive(Clone, Copy, Debug)] +#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] pub struct StudentT where F: Float, @@ -511,6 +523,7 @@ where /// Communications of the ACM 21, 317-322. /// https://doi.org/10.1145/359460.359482 #[derive(Clone, Copy, Debug)] +#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] enum BetaAlgorithm { BB(BB), BC(BC), @@ -518,6 +531,7 @@ enum BetaAlgorithm { /// Algorithm BB for `min(alpha, beta) > 1`. #[derive(Clone, Copy, Debug)] +#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] struct BB { alpha: N, beta: N, @@ -526,6 +540,7 @@ struct BB { /// Algorithm BC for `min(alpha, beta) <= 1`. #[derive(Clone, Copy, Debug)] +#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] struct BC { alpha: N, beta: N, @@ -546,6 +561,7 @@ struct BC { /// println!("{} is from a Beta(2, 5) distribution", v); /// ``` #[derive(Clone, Copy, Debug)] +#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] pub struct Beta where F: Float, @@ -557,6 +573,7 @@ where /// Error type returned from `Beta::new`. #[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] pub enum BetaError { /// `alpha <= 0` or `nan`. AlphaTooSmall, diff --git a/rand_distr/src/geometric.rs b/rand_distr/src/geometric.rs index 8cb10576e6a..31bf98c896e 100644 --- a/rand_distr/src/geometric.rs +++ b/rand_distr/src/geometric.rs @@ -26,6 +26,7 @@ use core::fmt; /// println!("{} is from a Geometric(0.25) distribution", v); /// ``` #[derive(Copy, Clone, Debug)] +#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] pub struct Geometric { p: f64, @@ -151,6 +152,7 @@ impl Distribution for Geometric /// println!("{} is from a Geometric(0.5) distribution", v); /// ``` #[derive(Copy, Clone, Debug)] +#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] pub struct StandardGeometric; impl Distribution for StandardGeometric { @@ -231,4 +233,4 @@ mod test { results.iter().map(|x| (x - mean) * (x - mean)).sum::() / results.len() as f64; assert!((variance - expected_variance).abs() < expected_variance / 10.0); } -} \ No newline at end of file +} diff --git a/rand_distr/src/hypergeometric.rs b/rand_distr/src/hypergeometric.rs index 7a5eedc3600..8ab2dca0333 100644 --- a/rand_distr/src/hypergeometric.rs +++ b/rand_distr/src/hypergeometric.rs @@ -6,6 +6,7 @@ use rand::distributions::uniform::Uniform; use core::fmt; #[derive(Clone, Copy, Debug)] +#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] enum SamplingMethod { InverseTransform{ initial_p: f64, initial_x: i64 }, RejectionAcceptance{ @@ -43,6 +44,7 @@ enum SamplingMethod { /// println!("{} is from a hypergeometric distribution", v); /// ``` #[derive(Copy, Clone, Debug)] +#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] pub struct Hypergeometric { n1: u64, n2: u64, diff --git a/rand_distr/src/inverse_gaussian.rs b/rand_distr/src/inverse_gaussian.rs index becb02b64f8..58986a769aa 100644 --- a/rand_distr/src/inverse_gaussian.rs +++ b/rand_distr/src/inverse_gaussian.rs @@ -27,6 +27,7 @@ impl std::error::Error for Error {} /// The [inverse Gaussian distribution](https://en.wikipedia.org/wiki/Inverse_Gaussian_distribution) #[derive(Debug, Clone, Copy)] +#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] pub struct InverseGaussian where F: Float, diff --git a/rand_distr/src/normal.rs b/rand_distr/src/normal.rs index c3e99010c6f..7078a894f43 100644 --- a/rand_distr/src/normal.rs +++ b/rand_distr/src/normal.rs @@ -37,6 +37,7 @@ use core::fmt; /// println!("{}", val); /// ``` #[derive(Clone, Copy, Debug)] +#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] pub struct StandardNormal; impl Distribution for StandardNormal { @@ -112,6 +113,7 @@ impl Distribution for StandardNormal { /// /// [`StandardNormal`]: crate::StandardNormal #[derive(Clone, Copy, Debug)] +#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] pub struct Normal where F: Float, StandardNormal: Distribution { @@ -226,6 +228,7 @@ where F: Float, StandardNormal: Distribution /// println!("{} is from an ln N(2, 9) distribution", v) /// ``` #[derive(Clone, Copy, Debug)] +#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] pub struct LogNormal where F: Float, StandardNormal: Distribution { diff --git a/rand_distr/src/normal_inverse_gaussian.rs b/rand_distr/src/normal_inverse_gaussian.rs index d8e44587d05..c4d693d031d 100644 --- a/rand_distr/src/normal_inverse_gaussian.rs +++ b/rand_distr/src/normal_inverse_gaussian.rs @@ -27,6 +27,7 @@ impl std::error::Error for Error {} /// The [normal-inverse Gaussian distribution](https://en.wikipedia.org/wiki/Normal-inverse_Gaussian_distribution) #[derive(Debug, Clone, Copy)] +#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] pub struct NormalInverseGaussian where F: Float, diff --git a/rand_distr/src/pareto.rs b/rand_distr/src/pareto.rs index 53e2987fce1..cd61894c526 100644 --- a/rand_distr/src/pareto.rs +++ b/rand_distr/src/pareto.rs @@ -24,6 +24,7 @@ use core::fmt; /// println!("{}", val); /// ``` #[derive(Clone, Copy, Debug)] +#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] pub struct Pareto where F: Float, OpenClosed01: Distribution { diff --git a/rand_distr/src/pert.rs b/rand_distr/src/pert.rs index e53ea0b89e8..4ead1fb8f74 100644 --- a/rand_distr/src/pert.rs +++ b/rand_distr/src/pert.rs @@ -31,6 +31,7 @@ use core::fmt; /// /// [`Triangular`]: crate::Triangular #[derive(Clone, Copy, Debug)] +#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] pub struct Pert where F: Float, diff --git a/rand_distr/src/poisson.rs b/rand_distr/src/poisson.rs index cf20bce5b5d..dc355258dfe 100644 --- a/rand_distr/src/poisson.rs +++ b/rand_distr/src/poisson.rs @@ -29,6 +29,7 @@ use core::fmt; /// println!("{} is from a Poisson(2) distribution", v); /// ``` #[derive(Clone, Copy, Debug)] +#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] pub struct Poisson where F: Float + FloatConst, Standard: Distribution { diff --git a/rand_distr/src/triangular.rs b/rand_distr/src/triangular.rs index 97693b32fcc..ba6d36445ce 100644 --- a/rand_distr/src/triangular.rs +++ b/rand_distr/src/triangular.rs @@ -32,6 +32,7 @@ use core::fmt; /// /// [`Pert`]: crate::Pert #[derive(Clone, Copy, Debug)] +#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] pub struct Triangular where F: Float, Standard: Distribution { diff --git a/rand_distr/src/unit_ball.rs b/rand_distr/src/unit_ball.rs index e5585a1e677..8a4b4fbf3d1 100644 --- a/rand_distr/src/unit_ball.rs +++ b/rand_distr/src/unit_ball.rs @@ -25,6 +25,7 @@ use rand::Rng; /// println!("{:?} is from the unit ball.", v) /// ``` #[derive(Clone, Copy, Debug)] +#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] pub struct UnitBall; impl Distribution<[F; 3]> for UnitBall { diff --git a/rand_distr/src/unit_circle.rs b/rand_distr/src/unit_circle.rs index 29e5c9a5939..24a06f3f4de 100644 --- a/rand_distr/src/unit_circle.rs +++ b/rand_distr/src/unit_circle.rs @@ -29,6 +29,7 @@ use rand::Rng; /// NBS Appl. Math. Ser., No. 12. Washington, DC: U.S. Government Printing /// Office, pp. 36-38. #[derive(Clone, Copy, Debug)] +#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] pub struct UnitCircle; impl Distribution<[F; 2]> for UnitCircle { diff --git a/rand_distr/src/unit_disc.rs b/rand_distr/src/unit_disc.rs index ced548b4dc0..937c1d01b84 100644 --- a/rand_distr/src/unit_disc.rs +++ b/rand_distr/src/unit_disc.rs @@ -24,6 +24,7 @@ use rand::Rng; /// println!("{:?} is from the unit Disc.", v) /// ``` #[derive(Clone, Copy, Debug)] +#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] pub struct UnitDisc; impl Distribution<[F; 2]> for UnitDisc { diff --git a/rand_distr/src/unit_sphere.rs b/rand_distr/src/unit_sphere.rs index b167a5d5d63..2b299239f49 100644 --- a/rand_distr/src/unit_sphere.rs +++ b/rand_distr/src/unit_sphere.rs @@ -28,6 +28,7 @@ use rand::Rng; /// Sphere.*](https://doi.org/10.1214/aoms/1177692644) /// Ann. Math. Statist. 43, no. 2, 645--646. #[derive(Clone, Copy, Debug)] +#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] pub struct UnitSphere; impl Distribution<[F; 3]> for UnitSphere { diff --git a/rand_distr/src/weibull.rs b/rand_distr/src/weibull.rs index aa9bdc44405..b390ad3ff2c 100644 --- a/rand_distr/src/weibull.rs +++ b/rand_distr/src/weibull.rs @@ -24,6 +24,7 @@ use core::fmt; /// println!("{}", val); /// ``` #[derive(Clone, Copy, Debug)] +#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] pub struct Weibull where F: Float, OpenClosed01: Distribution { diff --git a/rand_distr/src/weighted_alias.rs b/rand_distr/src/weighted_alias.rs index 53a9c2713d4..2cd90c52a32 100644 --- a/rand_distr/src/weighted_alias.rs +++ b/rand_distr/src/weighted_alias.rs @@ -16,6 +16,8 @@ use core::iter::Sum; use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Sub, SubAssign}; use rand::Rng; use alloc::{boxed::Box, vec, vec::Vec}; +#[cfg(feature = "serde1")] +use serde::{Serialize, Deserialize}; /// A distribution using weighted sampling to pick a discretely selected item. /// @@ -64,6 +66,9 @@ use alloc::{boxed::Box, vec, vec::Vec}; /// [`Uniform::sample`]: Distribution::sample /// [`Uniform::sample`]: Distribution::sample #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] +#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde1", serde(bound(serialize = "W: Serialize, W::Sampler: Serialize")))] +#[cfg_attr(feature = "serde1", serde(bound(deserialize = "W: Deserialize<'de>, W::Sampler: Deserialize<'de>")))] pub struct WeightedAliasIndex { aliases: Box<[u32]>, no_alias_odds: Box<[W]>, diff --git a/src/distributions/uniform.rs b/src/distributions/uniform.rs index 516a58c7072..066ae0df1d0 100644 --- a/src/distributions/uniform.rs +++ b/src/distributions/uniform.rs @@ -172,6 +172,8 @@ use serde::{Serialize, Deserialize}; /// [`Rng::gen_range`]: Rng::gen_range #[derive(Clone, Copy, Debug)] #[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde1", serde(bound(serialize = "X::Sampler: Serialize")))] +#[cfg_attr(feature = "serde1", serde(bound(deserialize = "X::Sampler: Deserialize<'de>")))] pub struct Uniform(X::Sampler); impl Uniform {