Skip to content

Commit

Permalink
Merge pull request #83 from coreylowman/add-sampling
Browse files Browse the repository at this point in the history
Adds rand_distr::Distribution implementations for f16 & bf16
  • Loading branch information
starkat99 committed Apr 26, 2023
2 parents 6ed9b43 + c3ee4ea commit 1d7f862
Show file tree
Hide file tree
Showing 5 changed files with 140 additions and 3 deletions.
7 changes: 5 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ default = ["std"]
std = ["alloc"]
use-intrinsics = []
alloc = []
rand_distr = ["dep:rand", "dep:rand_distr"]

[dependencies]
cfg-if = "1.0.0"
Expand All @@ -29,6 +30,8 @@ serde = { version = "1.0", default-features = false, features = [
], optional = true }
num-traits = { version = "0.2.14", default-features = false, features = ["libm"], optional = true }
zerocopy = { version = "0.6.0", default-features = false, optional = true }
rand = { version = "0.8.5", default-features = false, optional = true }
rand_distr = { version = "0.4.3", default-features = false, optional = true }

[target.'cfg(target_arch = "spirv")'.dependencies]
crunchy = "0.2.2"
Expand All @@ -37,7 +40,7 @@ crunchy = "0.2.2"
criterion = "0.4.0"
quickcheck = "1.0"
quickcheck_macros = "1.0"
rand = "0.8.4"
rand = "0.8.5"
crunchy = "0.2.2"

[[bench]]
Expand All @@ -46,4 +49,4 @@ harness = false

[package.metadata.docs.rs]
rustc-args = ["--cfg", "docsrs"]
features = ["std", "serde", "bytemuck", "num-traits", "zerocopy"]
features = ["std", "serde", "bytemuck", "num-traits", "zerocopy", "rand_distr"]
2 changes: 1 addition & 1 deletion Makefile.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ min_version = "0.35.0"
CI_CARGO_TEST_FLAGS = { value = "--locked -- --nocapture", condition = { env_true = [
"CARGO_MAKE_CI",
] } }
CARGO_MAKE_CARGO_ALL_FEATURES = { source = "${CARGO_MAKE_RUST_CHANNEL}", default_value = "--features=std,serde,num-traits,bytemuck,zerocopy", mapping = { "nightly" = "--all-features" } }
CARGO_MAKE_CARGO_ALL_FEATURES = { source = "${CARGO_MAKE_RUST_CHANNEL}", default_value = "--features=std,serde,num-traits,bytemuck,zerocopy,rand_distr", mapping = { "nightly" = "--all-features" } }
CARGO_MAKE_CLIPPY_ARGS = { value = "${CARGO_MAKE_CLIPPY_ALL_FEATURES_WARN}", condition = { env_true = [
"CARGO_MAKE_CI",
] } }
Expand Down
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ See the [crate documentation](https://docs.rs/half/) for more details.
- **`zerocopy`** — Enable `AsBytes` and `FromBytes` trait implementations from the
[`zerocopy`](https://crates.io/crates/zerocopy) crate.

- **`rand_distr`** — Enable sampling from distributions like `Uniform` and `Normal` from the
[`rand_distr`](https://crates.io/crates/rand_distr) crate.

### Hardware support

The following list details hardware support for floating point types in this crate. When using `std`
Expand Down
7 changes: 7 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@
//! - **`zerocopy`** — Adds support for the [`zerocopy`] crate by implementing [`AsBytes`] and
//! [`FromBytes`] traits for both [`f16`] and [`bf16`].
//!
//! - **`rand_distr`** — Adds support for the [`rand_distr`] crate by implementing [`rand::distributions::Distribution`]
//! and other traits for both [`f16`] and [`bf16`].
//!
//! [`alloc`]: https://doc.rust-lang.org/alloc/
//! [`std`]: https://doc.rust-lang.org/std/
//! [`binary16`]: https://en.wikipedia.org/wiki/Half-precision_floating-point_format
Expand All @@ -102,6 +105,7 @@
//! [`bytemuck`]: https://crates.io/crates/bytemuck
//! [`num-traits`]: https://crates.io/crates/num-traits
//! [`zerocopy`]: https://crates.io/crates/zerocopy
//! [`rand_distr`]: https://crates.io/crates/rand_distr
#![cfg_attr(
feature = "alloc",
doc = "
Expand Down Expand Up @@ -209,6 +213,9 @@ pub mod vec;
pub use bfloat::bf16;
pub use binary16::f16;

#[cfg(feature = "rand_distr")]
mod rand_distr;

/// A collection of the most used items and traits in this crate for easy importing.
///
/// # Examples
Expand Down
124 changes: 124 additions & 0 deletions src/rand_distr.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
use crate::{bf16, f16};

use rand::{distributions::Distribution, Rng};
use rand_distr::uniform::UniformFloat;

macro_rules! impl_distribution_via_f32 {
($Ty:ty, $Distr:ty) => {
impl Distribution<$Ty> for $Distr {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> $Ty {
<$Ty>::from_f32(<Self as Distribution<f32>>::sample(self, rng))
}
}
};
}

impl_distribution_via_f32!(f16, rand_distr::Standard);
impl_distribution_via_f32!(f16, rand_distr::StandardNormal);
impl_distribution_via_f32!(f16, rand_distr::Exp1);
impl_distribution_via_f32!(f16, rand_distr::Open01);
impl_distribution_via_f32!(f16, rand_distr::OpenClosed01);

impl_distribution_via_f32!(bf16, rand_distr::Standard);
impl_distribution_via_f32!(bf16, rand_distr::StandardNormal);
impl_distribution_via_f32!(bf16, rand_distr::Exp1);
impl_distribution_via_f32!(bf16, rand_distr::Open01);
impl_distribution_via_f32!(bf16, rand_distr::OpenClosed01);

#[derive(Debug, Clone, Copy)]
pub struct Float16Sampler(UniformFloat<f32>);

impl rand_distr::uniform::SampleUniform for f16 {
type Sampler = Float16Sampler;
}

impl rand_distr::uniform::UniformSampler for Float16Sampler {
type X = f16;
fn new<B1, B2>(low: B1, high: B2) -> Self
where
B1: rand_distr::uniform::SampleBorrow<Self::X> + Sized,
B2: rand_distr::uniform::SampleBorrow<Self::X> + Sized,
{
Self(UniformFloat::new(
low.borrow().to_f32(),
high.borrow().to_f32(),
))
}
fn new_inclusive<B1, B2>(low: B1, high: B2) -> Self
where
B1: rand_distr::uniform::SampleBorrow<Self::X> + Sized,
B2: rand_distr::uniform::SampleBorrow<Self::X> + Sized,
{
Self(UniformFloat::new_inclusive(
low.borrow().to_f32(),
high.borrow().to_f32(),
))
}
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X {
f16::from_f32(self.0.sample(rng))
}
}

#[derive(Debug, Clone, Copy)]
pub struct BFloat16Sampler(UniformFloat<f32>);

impl rand_distr::uniform::SampleUniform for bf16 {
type Sampler = BFloat16Sampler;
}

impl rand_distr::uniform::UniformSampler for BFloat16Sampler {
type X = bf16;
fn new<B1, B2>(low: B1, high: B2) -> Self
where
B1: rand_distr::uniform::SampleBorrow<Self::X> + Sized,
B2: rand_distr::uniform::SampleBorrow<Self::X> + Sized,
{
Self(UniformFloat::new(
low.borrow().to_f32(),
high.borrow().to_f32(),
))
}
fn new_inclusive<B1, B2>(low: B1, high: B2) -> Self
where
B1: rand_distr::uniform::SampleBorrow<Self::X> + Sized,
B2: rand_distr::uniform::SampleBorrow<Self::X> + Sized,
{
Self(UniformFloat::new_inclusive(
low.borrow().to_f32(),
high.borrow().to_f32(),
))
}
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X {
bf16::from_f32(self.0.sample(rng))
}
}

#[cfg(test)]
mod tests {
use super::*;

use rand::{thread_rng, Rng};
use rand_distr::{Standard, StandardNormal, Uniform};

#[test]
fn test_sample_f16() {
let mut rng = thread_rng();
let _: f16 = rng.sample(Standard);
let _: f16 = rng.sample(StandardNormal);
let _: f16 = rng.sample(Uniform::new(f16::from_f32(0.0), f16::from_f32(1.0)));
#[cfg(feature = "num-traits")]
let _: f16 =
rng.sample(rand_distr::Normal::new(f16::from_f32(0.0), f16::from_f32(1.0)).unwrap());
}

#[test]
fn test_sample_bf16() {
let mut rng = thread_rng();
let _: bf16 = rng.sample(Standard);
let _: bf16 = rng.sample(StandardNormal);
let _: bf16 = rng.sample(Uniform::new(bf16::from_f32(0.0), bf16::from_f32(1.0)));
#[cfg(feature = "num-traits")]
let _: bf16 =
rng.sample(rand_distr::Normal::new(bf16::from_f32(0.0), bf16::from_f32(1.0)).unwrap());
}
}

0 comments on commit 1d7f862

Please sign in to comment.