From 8711bcc53987384a2fb07830bca9dabbfadf1243 Mon Sep 17 00:00:00 2001 From: Diggory Hardy Date: Wed, 29 May 2019 15:05:21 +0100 Subject: [PATCH] Uniform: use u32 samples where possible for isize, usize --- src/distributions/uniform.rs | 128 +++++++++++++++++++++++++++++++++-- 1 file changed, 124 insertions(+), 4 deletions(-) diff --git a/src/distributions/uniform.rs b/src/distributions/uniform.rs index 8f51f5aa402..2d9b1280dda 100644 --- a/src/distributions/uniform.rs +++ b/src/distributions/uniform.rs @@ -376,9 +376,9 @@ macro_rules! uniform_int_impl { let high = *high_b.borrow(); assert!(low <= high, "Uniform::new_inclusive called with `low > high`"); - let unsigned_max = ::core::$u_large::MAX; - let range = high.wrapping_sub(low).wrapping_add(1) as $unsigned; + + let unsigned_max = ::core::$u_large::MAX; let ints_to_reject = if range > 0 { let range = range as $u_large; @@ -455,15 +455,135 @@ uniform_int_impl! { i32, u32, u32 } uniform_int_impl! { i64, u64, u64 } #[cfg(all(rustc_1_26, not(target_os = "emscripten")))] uniform_int_impl! { i128, u128, u128 } -uniform_int_impl! { isize, usize, usize } uniform_int_impl! { u8, u8, u32 } uniform_int_impl! { u16, u16, u32 } uniform_int_impl! { u32, u32, u32 } uniform_int_impl! { u64, u64, u64 } -uniform_int_impl! { usize, usize, usize } #[cfg(all(rustc_1_26, not(target_os = "emscripten")))] uniform_int_impl! { u128, u128, u128 } +// For isize and usize, we sample using u32 integers if range <= u32::MAX. +// The primary reason for this is to make results consistent across +// architectures where possible. +macro_rules! uniform_int_size_impl { + ($ty:ty) => { + impl SampleUniform for $ty { + type Sampler = UniformInt<$ty>; + } + + impl UniformSampler for UniformInt<$ty> { + type X = $ty; + + #[inline] // if the range is constant, this helps LLVM to do the + // calculations at compile-time. + fn new(low_b: B1, high_b: B2) -> Self + where B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); + assert!(low < high, "Uniform::new called with `low >= high`"); + UniformSampler::new_inclusive(low, high - 1) + } + + #[inline] // if the range is constant, this helps LLVM to do the + // calculations at compile-time. + fn new_inclusive(low_b: B1, high_b: B2) -> Self + where B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); + assert!(low <= high, + "Uniform::new_inclusive called with `low > high`"); + let range = high.wrapping_sub(low).wrapping_add(1) as usize; + + let ints_to_reject = if range == 0 { + 0 + } else if range <= ::core::u32::MAX as usize { + let unsigned_max = ::core::u32::MAX as usize; + (unsigned_max - range + 1) % range + } else { + (::core::usize::MAX - range + 1) % range + }; + + UniformInt { + low: low, + // These are really usize values, but store as $ty: + range: range as $ty, + ints_to_reject: ints_to_reject as usize as $ty + } + } + + fn sample(&self, rng: &mut R) -> Self::X { + let range = self.range as usize; + if range == 0 { + // Sample from the entire integer range. + rng.gen() + } else if range <= ::core::u32::MAX as usize { + let range = range as u32; + let unsigned_max = ::core::u32::MAX; + let zone = unsigned_max - (self.ints_to_reject as usize as u32); + loop { + let v: u32 = rng.gen(); + let (hi, lo) = v.wmul(range); + if lo <= zone { + return self.low.wrapping_add(hi as usize as $ty); + } + } + } else { + let unsigned_max = ::core::usize::MAX; + let zone = unsigned_max - (self.ints_to_reject as usize); + loop { + let v: usize = rng.gen(); + let (hi, lo) = v.wmul(range); + if lo <= zone { + return self.low.wrapping_add(hi as $ty); + } + } + } + } + + fn sample_single(low_b: B1, high_b: B2, rng: &mut R) + -> Self::X + where B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); + assert!(low < high, + "UniformSampler::sample_single: low >= high"); + let range = high.wrapping_sub(low) as usize; + if range <= ::core::u32::MAX as usize { + let range = range as u32; + let zone = (range << range.leading_zeros()).wrapping_sub(1); + + loop { + let v: u32 = rng.gen(); + let (hi, lo) = v.wmul(range); + if lo <= zone { + return low.wrapping_add(hi as usize as $ty); + } + } + } else { + let zone = (range << range.leading_zeros()).wrapping_sub(1); + + loop { + let v: usize = rng.gen(); + let (hi, lo) = v.wmul(range); + if lo <= zone { + return low.wrapping_add(hi as $ty); + } + } + } + } + } + } +} + +uniform_int_size_impl! { isize } +uniform_int_size_impl! { usize } + #[cfg(all(feature = "simd_support", feature = "nightly"))] macro_rules! uniform_simd_int_impl { ($ty:ident, $unsigned:ident, $u_scalar:ident) => {