Skip to content

Commit

Permalink
Added serde1 feature to Serialize/Deserialize WeightedIndex (#974)
Browse files Browse the repository at this point in the history
Re-enable serde1 for all distributions, StepRng and IndexVec
  • Loading branch information
CGMossa committed May 18, 2020
1 parent 6a5f0d4 commit 7ede440
Show file tree
Hide file tree
Showing 10 changed files with 152 additions and 2 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Expand Up @@ -8,6 +8,10 @@ A [separate changelog is kept for rand_core](rand_core/CHANGELOG.md).

You may also find the [Upgrade Guide](https://rust-random.github.io/book/update.html) useful.

## [Unreleased]
### Additions
- Added a `serde1` feature and added Serialize/Deserialize to `UniformInt` and `WeightedIndex` (#974)

## [0.7.3] - 2020-01-10
### Fixes
- The `Bernoulli` distribution constructors now reports an error on NaN and on
Expand Down
5 changes: 4 additions & 1 deletion Cargo.toml
Expand Up @@ -24,7 +24,7 @@ appveyor = { repository = "rust-random/rand" }
# Meta-features:
default = ["std", "std_rng"]
nightly = ["simd_support"] # enables all features requiring nightly rust
serde1 = [] # does nothing, deprecated
serde1 = ["serde"]

# Option (enabled by default): without "std" rand uses libcore; this option
# enables functionality expected to be available on a standard platform.
Expand Down Expand Up @@ -58,6 +58,7 @@ members = [
rand_core = { path = "rand_core", version = "0.5.1" }
rand_pcg = { path = "rand_pcg", version = "0.2", optional = true }
log = { version = "0.4.4", optional = true }
serde = { version = "1.0.103", features = ["derive"], optional = true }

[dependencies.packed_simd]
# NOTE: so far no version works reliably due to dependence on unstable features
Expand All @@ -81,6 +82,8 @@ rand_hc = { path = "rand_hc", version = "0.2", optional = true }
rand_pcg = { path = "rand_pcg", version = "0.2" }
# Only for benches:
rand_hc = { path = "rand_hc", version = "0.2" }
# Only to test serde1
bincode = "1.2.1"

[package.metadata.docs.rs]
all-features = true
12 changes: 12 additions & 0 deletions src/distributions/bernoulli.rs
Expand Up @@ -12,6 +12,8 @@ use crate::distributions::Distribution;
use crate::Rng;
use core::{fmt, u64};

#[cfg(feature = "serde1")]
use serde::{Serialize, Deserialize};
/// The Bernoulli distribution.
///
/// This is a special case of the Binomial distribution where `n = 1`.
Expand All @@ -32,6 +34,7 @@ use core::{fmt, u64};
/// so only probabilities that are multiples of 2<sup>-64</sup> can be
/// represented.
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct Bernoulli {
/// Probability of success, relative to the maximal integer.
p_int: u64,
Expand Down Expand Up @@ -143,6 +146,15 @@ mod test {
use crate::distributions::Distribution;
use crate::Rng;

#[test]
#[cfg(feature="serde1")]
fn test_serializing_deserializing_bernoulli() {
let coin_flip = Bernoulli::new(0.5).unwrap();
let de_coin_flip : Bernoulli = bincode::deserialize(&bincode::serialize(&coin_flip).unwrap()).unwrap();

assert_eq!(coin_flip.p_int, de_coin_flip.p_int);
}

#[test]
fn test_trivial() {
let mut r = crate::test::rng(1);
Expand Down
5 changes: 5 additions & 0 deletions src/distributions/float.rs
Expand Up @@ -14,6 +14,9 @@ use crate::Rng;
use core::mem;
#[cfg(feature = "simd_support")] use packed_simd::*;

#[cfg(feature = "serde1")]
use serde::{Serialize, Deserialize};

/// A distribution to sample floating point numbers uniformly in the half-open
/// interval `(0, 1]`, i.e. including 1 but not 0.
///
Expand All @@ -39,6 +42,7 @@ use core::mem;
/// [`Open01`]: crate::distributions::Open01
/// [`Uniform`]: crate::distributions::uniform::Uniform
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct OpenClosed01;

/// A distribution to sample floating point numbers uniformly in the open
Expand All @@ -65,6 +69,7 @@ pub struct OpenClosed01;
/// [`OpenClosed01`]: crate::distributions::OpenClosed01
/// [`Uniform`]: crate::distributions::uniform::Uniform
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct Open01;


Expand Down
4 changes: 4 additions & 0 deletions src/distributions/mod.rs
Expand Up @@ -111,6 +111,9 @@ pub mod uniform;
#[cfg(feature = "alloc")] pub mod weighted;
#[cfg(feature = "alloc")] mod weighted_index;

#[cfg(feature = "serde1")]
use serde::{Serialize, Deserialize};

mod float;
#[doc(hidden)]
pub mod hidden_export {
Expand Down Expand Up @@ -320,6 +323,7 @@ where
///
/// [`Uniform`]: uniform::Uniform
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct Standard;


Expand Down
4 changes: 4 additions & 0 deletions src/distributions/other.rs
Expand Up @@ -14,6 +14,9 @@ use core::num::Wrapping;
use crate::distributions::{Distribution, Standard, Uniform};
use crate::Rng;

#[cfg(feature = "serde1")]
use serde::{Serialize, Deserialize};

// ----- Sampling distributions -----

/// Sample a `char`, uniformly distributed over ASCII letters and numbers:
Expand All @@ -34,6 +37,7 @@ use crate::Rng;
/// println!("Random chars: {}", chars);
/// ```
#[derive(Debug)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct Alphanumeric;


Expand Down
59 changes: 58 additions & 1 deletion src/distributions/uniform.rs
Expand Up @@ -115,9 +115,11 @@ use crate::Rng;
#[allow(unused_imports)] // rustc doesn't detect that this is actually used
use crate::distributions::utils::Float;


#[cfg(feature = "simd_support")] use packed_simd::*;

#[cfg(feature = "serde1")]
use serde::{Serialize, Deserialize};

/// Sample values uniformly between two bounds.
///
/// [`Uniform::new`] and [`Uniform::new_inclusive`] construct a uniform
Expand Down Expand Up @@ -159,6 +161,7 @@ use crate::distributions::utils::Float;
/// [`new`]: Uniform::new
/// [`new_inclusive`]: Uniform::new_inclusive
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct Uniform<X: SampleUniform>(X::Sampler);

impl<X: SampleUniform> Uniform<X> {
Expand Down Expand Up @@ -347,6 +350,7 @@ where Borrowed: SampleUniform
/// multiply by `range`, the result is in the high word. Then comparing the low
/// word against `zone` makes sure our distribution is uniform.
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct UniformInt<X> {
low: X,
range: X,
Expand Down Expand Up @@ -644,6 +648,7 @@ uniform_simd_int_impl! {
/// [`new_inclusive`]: UniformSampler::new_inclusive
/// [`Standard`]: crate::distributions::Standard
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct UniformFloat<X> {
low: X,
scale: X,
Expand Down Expand Up @@ -837,12 +842,14 @@ uniform_float_impl! { f64x8, u64x8, f64, u64, 64 - 52 }
/// Unless you are implementing [`UniformSampler`] for your own types, this type
/// should not be used directly, use [`Uniform`] instead.
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct UniformDuration {
mode: UniformDurationMode,
offset: u32,
}

#[derive(Debug, Copy, Clone)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
enum UniformDurationMode {
Small {
secs: u64,
Expand Down Expand Up @@ -967,6 +974,56 @@ mod tests {
use super::*;
use crate::rngs::mock::StepRng;

#[test]
#[cfg(feature = "serde1")]
fn test_serialization_uniform_duration() {
let distr = UniformDuration::new(std::time::Duration::from_secs(10), std::time::Duration::from_secs(60));
let de_distr: UniformDuration = bincode::deserialize(&bincode::serialize(&distr).unwrap()).unwrap();
assert_eq!(
distr.offset, de_distr.offset
);
match (distr.mode, de_distr.mode) {
(UniformDurationMode::Small {secs: a_secs, nanos: a_nanos}, UniformDurationMode::Small {secs, nanos}) => {
assert_eq!(a_secs, secs);

assert_eq!(a_nanos.0.low, nanos.0.low);
assert_eq!(a_nanos.0.range, nanos.0.range);
assert_eq!(a_nanos.0.z, nanos.0.z);
}
(UniformDurationMode::Medium {nanos: a_nanos} , UniformDurationMode::Medium {nanos}) => {
assert_eq!(a_nanos.0.low, nanos.0.low);
assert_eq!(a_nanos.0.range, nanos.0.range);
assert_eq!(a_nanos.0.z, nanos.0.z);
}
(UniformDurationMode::Large {max_secs:a_max_secs, max_nanos:a_max_nanos, secs:a_secs}, UniformDurationMode::Large {max_secs, max_nanos, secs} ) => {
assert_eq!(a_max_secs, max_secs);
assert_eq!(a_max_nanos, max_nanos);

assert_eq!(a_secs.0.low, secs.0.low);
assert_eq!(a_secs.0.range, secs.0.range);
assert_eq!(a_secs.0.z, secs.0.z);
}
_ => panic!("`UniformDurationMode` was not serialized/deserialized correctly")
}
}

#[test]
#[cfg(feature = "serde1")]
fn test_uniform_serialization() {
let unit_box: Uniform<i32> = Uniform::new(-1, 1);
let de_unit_box: Uniform<i32> = bincode::deserialize(&bincode::serialize(&unit_box).unwrap()).unwrap();

assert_eq!(unit_box.0.low, de_unit_box.0.low);
assert_eq!(unit_box.0.range, de_unit_box.0.range);
assert_eq!(unit_box.0.z, de_unit_box.0.z);

let unit_box: Uniform<f32> = Uniform::new(-1., 1.);
let de_unit_box: Uniform<f32> = bincode::deserialize(&bincode::serialize(&unit_box).unwrap()).unwrap();

assert_eq!(unit_box.0.low, de_unit_box.0.low);
assert_eq!(unit_box.0.scale, de_unit_box.0.scale);
}

#[should_panic]
#[test]
fn test_uniform_bad_limits_equal_int() {
Expand Down
21 changes: 21 additions & 0 deletions src/distributions/weighted_index.rs
Expand Up @@ -17,6 +17,9 @@ use core::fmt;
// Note that this whole module is only imported if feature="alloc" is enabled.
#[cfg(not(feature = "std"))] use crate::alloc::vec::Vec;

#[cfg(feature = "serde1")]
use serde::{Serialize, Deserialize};

/// A distribution using weighted sampling of discrete items
///
/// Sampling a `WeightedIndex` distribution returns the index of a randomly
Expand Down Expand Up @@ -73,6 +76,7 @@ use core::fmt;
/// [`Uniform<X>`]: crate::distributions::uniform::Uniform
/// [`RngCore`]: crate::RngCore
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct WeightedIndex<X: SampleUniform + PartialOrd> {
cumulative_weights: Vec<X>,
total_weight: X,
Expand Down Expand Up @@ -236,6 +240,23 @@ where X: SampleUniform + PartialOrd
mod test {
use super::*;

#[cfg(feature = "serde1")]
#[test]
fn test_weightedindex_serde1() {
let weighted_index = WeightedIndex::new(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).unwrap();

let ser_weighted_index = bincode::serialize(&weighted_index).unwrap();
let de_weighted_index: WeightedIndex<i32> =
bincode::deserialize(&ser_weighted_index).unwrap();

assert_eq!(
de_weighted_index.cumulative_weights,
weighted_index.cumulative_weights
);
assert_eq!(de_weighted_index.total_weight, weighted_index.total_weight);
}


#[test]
#[cfg_attr(miri, ignore)] // Miri is too slow
fn test_weightedindex() {
Expand Down
19 changes: 19 additions & 0 deletions src/rngs/mock.rs
Expand Up @@ -10,6 +10,9 @@

use rand_core::{impls, Error, RngCore};

#[cfg(feature = "serde1")]
use serde::{Serialize, Deserialize};

/// A simple implementation of `RngCore` for testing purposes.
///
/// This generates an arithmetic sequence (i.e. adds a constant each step)
Expand All @@ -25,6 +28,7 @@ use rand_core::{impls, Error, RngCore};
/// assert_eq!(sample, [2, 3, 4]);
/// ```
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct StepRng {
v: u64,
a: u64,
Expand Down Expand Up @@ -65,3 +69,18 @@ impl RngCore for StepRng {
Ok(())
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
#[cfg(feature = "serde1")]
fn test_serialization_step_rng() {
let some_rng = StepRng::new(42, 7);
let de_some_rng: StepRng = bincode::deserialize(&bincode::serialize(&some_rng).unwrap()).unwrap();
assert_eq!(some_rng.v, de_some_rng.v);
assert_eq!(some_rng.a, de_some_rng.a);

}
}
21 changes: 21 additions & 0 deletions src/seq/index.rs
Expand Up @@ -22,10 +22,14 @@ use crate::alloc::collections::BTreeSet;
use crate::distributions::{uniform::SampleUniform, Distribution, Uniform};
use crate::Rng;

#[cfg(feature = "serde1")]
use serde::{Serialize, Deserialize};

/// A vector of indices.
///
/// Multiple internal representations are possible.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub enum IndexVec {
#[doc(hidden)]
U32(Vec<u32>),
Expand Down Expand Up @@ -376,6 +380,23 @@ where
#[cfg(test)]
mod test {
use super::*;

#[test]
#[cfg(feature = "serde1")]
fn test_serialization_index_vec() {
let some_index_vec = IndexVec::from(vec![254_usize, 234, 2, 1]);
let de_some_index_vec: IndexVec = bincode::deserialize(&bincode::serialize(&some_index_vec).unwrap()).unwrap();
match (some_index_vec, de_some_index_vec) {
(IndexVec::U32(a), IndexVec::U32(b)) => {
assert_eq!(a, b);
},
(IndexVec::USize(a), IndexVec::USize(b)) => {
assert_eq!(a, b);
},
_ => {panic!("failed to seralize/deserialize `IndexVec`")}
}
}

#[cfg(all(feature = "alloc", not(feature = "std")))] use crate::alloc::vec;
#[cfg(feature = "std")] use std::vec;

Expand Down

0 comments on commit 7ede440

Please sign in to comment.