Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ThreadRng: fix use-after-free in TL-dtor; doc for sized iterators #1035

Merged
merged 4 commits into from Sep 2, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Expand Up @@ -16,6 +16,7 @@ You may also find the [Upgrade Guide](https://rust-random.github.io/book/update.
- Implement weighted sampling without replacement (#976, #1013)

### Changes
- `ThreadRng` is no longer `Copy` to enable safe usage within thread-local destructors (see #968)
- `gen_range(a, b)` was replaced with `gen_range(a..b)`, and `gen_range(a..=b)`
is supported (#744, #1003). Note that `a` and `b` can no longer be references or SIMD types.
- Replace `AsByteSliceMut` with `Fill` (#940)
Expand Down
12 changes: 8 additions & 4 deletions src/distributions/mod.rs
Expand Up @@ -162,17 +162,21 @@ pub trait Distribution<T> {
/// use rand::thread_rng;
/// use rand::distributions::{Distribution, Alphanumeric, Uniform, Standard};
///
/// let rng = thread_rng();
/// let mut rng = thread_rng();
///
/// // Vec of 16 x f32:
/// let v: Vec<f32> = Standard.sample_iter(rng).take(16).collect();
/// let v: Vec<f32> = Standard.sample_iter(&mut rng).take(16).collect();
///
/// // String:
/// let s: String = Alphanumeric.sample_iter(rng).take(7).map(char::from).collect();
/// let s: String = Alphanumeric
/// .sample_iter(&mut rng)
/// .take(7)
/// .map(char::from)
/// .collect();
///
/// // Dice-rolling:
/// let die_range = Uniform::new_inclusive(1, 6);
/// let mut roll_die = die_range.sample_iter(rng);
/// let mut roll_die = die_range.sample_iter(&mut rng);
/// while roll_die.next().unwrap() != 6 {
/// println!("Not a 6; rolling again!");
/// }
Expand Down
13 changes: 8 additions & 5 deletions src/rng.rs
Expand Up @@ -165,21 +165,24 @@ pub trait Rng: RngCore {
/// use rand::{thread_rng, Rng};
/// use rand::distributions::{Alphanumeric, Uniform, Standard};
///
/// let rng = thread_rng();
/// let mut rng = thread_rng();
///
/// // Vec of 16 x f32:
/// let v: Vec<f32> = rng.sample_iter(Standard).take(16).collect();
/// let v: Vec<f32> = (&mut rng).sample_iter(Standard).take(16).collect();
///
/// // String:
/// let s: String = rng.sample_iter(Alphanumeric).take(7).map(char::from).collect();
/// let s: String = (&mut rng).sample_iter(Alphanumeric)
/// .take(7)
/// .map(char::from)
/// .collect();
///
/// // Combined values
/// println!("{:?}", rng.sample_iter(Standard).take(5)
/// println!("{:?}", (&mut rng).sample_iter(Standard).take(5)
/// .collect::<Vec<(f64, bool)>>());
///
/// // Dice-rolling:
/// let die_range = Uniform::new_inclusive(1, 6);
/// let mut roll_die = rng.sample_iter(die_range);
/// let mut roll_die = (&mut rng).sample_iter(die_range);
/// while roll_die.next().unwrap() != 6 {
/// println!("Not a 6; rolling again!");
/// }
Expand Down
52 changes: 31 additions & 21 deletions src/rngs/thread.rs
Expand Up @@ -9,7 +9,7 @@
//! Thread-local random number generator

use core::cell::UnsafeCell;
use core::ptr::NonNull;
use std::rc::Rc;
use std::thread_local;

use super::std::Core;
Expand Down Expand Up @@ -37,38 +37,42 @@ use crate::{CryptoRng, Error, RngCore, SeedableRng};
// of 32 kB and less. We choose 64 kB to avoid significant overhead.
const THREAD_RNG_RESEED_THRESHOLD: u64 = 1024 * 64;

/// The type returned by [`thread_rng`], essentially just a reference to the
/// PRNG in thread-local memory.
/// A reference to the thread-local generator
///
/// `ThreadRng` uses the same PRNG as [`StdRng`] for security and performance.
/// As hinted by the name, the generator is thread-local. `ThreadRng` is a
/// handle to this generator and thus supports `Copy`, but not `Send` or `Sync`.
/// An instance can be obtained via [`thread_rng`] or via `ThreadRng::default()`.
/// This handle is safe to use everywhere (including thread-local destructors)
/// but cannot be passed between threads (is not `Send` or `Sync`).
///
/// Unlike `StdRng`, `ThreadRng` uses the [`ReseedingRng`] wrapper to reseed
/// the PRNG from fresh entropy every 64 kiB of random data.
/// [`OsRng`] is used to provide seed data.
/// `ThreadRng` uses the same PRNG as [`StdRng`] for security and performance
/// and is automatically seeded from [`OsRng`].
///
/// Unlike `StdRng`, `ThreadRng` uses the [`ReseedingRng`] wrapper to reseed
/// the PRNG from fresh entropy every 64 kiB of random data as well as after a
/// fork on Unix (though not quite immediately; see documentation of
/// [`ReseedingRng`]).
/// Note that the reseeding is done as an extra precaution against side-channel
/// attacks and mis-use (e.g. if somehow weak entropy were supplied initially).
/// The PRNG algorithms used are assumed to be secure.
///
/// [`ReseedingRng`]: crate::rngs::adapter::ReseedingRng
/// [`StdRng`]: crate::rngs::StdRng
#[cfg_attr(doc_cfg, doc(cfg(all(feature = "std", feature = "std_rng"))))]
#[derive(Copy, Clone, Debug)]
#[derive(Clone, Debug)]
pub struct ThreadRng {
// inner raw pointer implies type is neither Send nor Sync
rng: NonNull<ReseedingRng<Core, OsRng>>,
// Rc is explictly !Send and !Sync
rng: Rc<UnsafeCell<ReseedingRng<Core, OsRng>>>,
}

thread_local!(
static THREAD_RNG_KEY: UnsafeCell<ReseedingRng<Core, OsRng>> = {
// We require Rc<..> to avoid premature freeing when thread_rng is used
// within thread-local destructors. See #968.
static THREAD_RNG_KEY: Rc<UnsafeCell<ReseedingRng<Core, OsRng>>> = {
let r = Core::from_rng(OsRng).unwrap_or_else(|err|
panic!("could not initialize thread_rng: {}", err));
let rng = ReseedingRng::new(r,
THREAD_RNG_RESEED_THRESHOLD,
OsRng);
UnsafeCell::new(rng)
Rc::new(UnsafeCell::new(rng))
}
);

Expand All @@ -81,9 +85,8 @@ thread_local!(
/// For more information see [`ThreadRng`].
#[cfg_attr(doc_cfg, doc(cfg(all(feature = "std", feature = "std_rng"))))]
pub fn thread_rng() -> ThreadRng {
let raw = THREAD_RNG_KEY.with(|t| t.get());
let nn = NonNull::new(raw).unwrap();
ThreadRng { rng: nn }
let rng = THREAD_RNG_KEY.with(|t| t.clone());
ThreadRng { rng }
}

impl Default for ThreadRng {
Expand All @@ -92,23 +95,30 @@ impl Default for ThreadRng {
}
}

impl ThreadRng {
#[inline(always)]
fn rng(&mut self) -> &mut ReseedingRng<Core, OsRng> {
unsafe { &mut *self.rng.get() }
Copy link
Contributor

@RalfJung RalfJung Sep 2, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there's a subtle safety argument here: the caller is required to stop using the return mutable reference before anyone else calls this method. That is crucial to avoid aliasing mutable references.

That might be worth documenting? In particular it might be worth making the function unsafe as it is not in general safe to use by arbitrary callers.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right of course. It's not a public function so I didn't put much thought into its signature.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I submitted #1037 to fix this.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@RalfJung Which function should be unsafe?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(I answered there.)

}
}

impl RngCore for ThreadRng {
#[inline(always)]
fn next_u32(&mut self) -> u32 {
unsafe { self.rng.as_mut().next_u32() }
self.rng().next_u32()
}

#[inline(always)]
fn next_u64(&mut self) -> u64 {
unsafe { self.rng.as_mut().next_u64() }
self.rng().next_u64()
}

fn fill_bytes(&mut self, dest: &mut [u8]) {
unsafe { self.rng.as_mut().fill_bytes(dest) }
self.rng().fill_bytes(dest)
}

fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> {
unsafe { self.rng.as_mut().try_fill_bytes(dest) }
self.rng().try_fill_bytes(dest)
}
}

Expand Down
3 changes: 2 additions & 1 deletion src/seq/mod.rs
Expand Up @@ -265,7 +265,8 @@ pub trait SliceRandom {

/// Extension trait on iterators, providing random sampling methods.
///
/// This trait is implemented on all sized iterators, providing methods for
/// This trait is implemented on all iterators `I` where `I: Iterator + Sized`
/// and provides methods for
/// choosing one or more elements. You must `use` this trait:
dhardy marked this conversation as resolved.
Show resolved Hide resolved
///
/// ```
Expand Down