Add DistString impl to Uniform and Slice #1315

Merged
merged 2 commits into from
Jul 14, 2023
21 changes: 21 additions & 0 deletions src/distributions/slice.rs
@@ -7,6 +7,8 @@
// except according to those terms.

use crate::distributions::{Distribution, Uniform};
#[cfg(feature = "alloc")]
use alloc::string::String;

/// A distribution to sample items uniformly from a slice.
///
@@ -115,3 +117,22 @@ impl core::fmt::Display for EmptySlice {

#[cfg(feature = "std")]
impl std::error::Error for EmptySlice {}

/// Note: the `String` is potentially left with excess capacity; optionally the
/// user may call `string.shrink_to_fit()` afterwards.
#[cfg(feature = "alloc")]
impl<'a> super::DistString for Slice<'a, char> {
    fn append_string<R: crate::Rng + ?Sized>(&self, rng: &mut R, string: &mut String, len: usize) {
        let max_char_len = self
            .slice
            .iter()
            .try_fold(1, |max_len, char| {
                // When the current max_len is 4, the result max_char_len will be 4.
                Some(max_len.max(char.len_utf8())).filter(|len| *len < 4)
            })
            .unwrap_or(4);

        string.reserve(max_char_len * len);
        string.extend(self.sample_iter(rng).take(len))
Member

If the slice is large, this could take significant time (and the only purpose is to potentially reduce the max-char-len used for space reservation below 4). I suggest only iterating when the length is under some bound (1000 maybe?).

Further, a slice could contain one 4-byte character but mostly 1-byte chars so this could massively over-reserve (perhaps significant if len is large). It isn't obvious what the best approach is (or how to test — it could be very context dependent — perhaps try an approach which isn't terrible anywhere rather than to perfectly optimise).

I mean it's arguable whether just reserving len bytes would be better. Or you could sample len/4, check the length and repeat... but probably over-complex.
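
To make the over-reservation concern concrete, here is a rough sketch comparing a worst-case reservation against the expected byte count for uniform sampling. The helper names `worst_case_reserve` and `expected_bytes` are hypothetical and not part of this PR or of `rand`; the sketch assumes a non-empty alphabet and only does arithmetic, no sampling.

```rust
/// Hypothetical helper: bytes reserved when every sampled char is assumed
/// to need the widest UTF-8 encoding found in the alphabet.
fn worst_case_reserve(alphabet: &[char], len: usize) -> usize {
    let max_char_len = alphabet.iter().map(|c| c.len_utf8()).max().unwrap_or(1);
    max_char_len * len
}

/// Hypothetical helper: expected number of bytes (rounded up) when `len`
/// chars are drawn uniformly from the alphabet.
fn expected_bytes(alphabet: &[char], len: usize) -> usize {
    if alphabet.is_empty() {
        // Assumption: an empty alphabet yields nothing to sample.
        return 0;
    }
    let total: usize = alphabet.iter().map(|c| c.len_utf8()).sum();
    // Ceiling of total * len / alphabet.len().
    (total * len + alphabet.len() - 1) / alphabet.len()
}
```

With the 26 lowercase ASCII letters plus a single 4-byte char and `len = 2700`, the worst-case reservation is 10800 bytes while the expected size is 3000 bytes, which is the kind of gap described above.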

Collaborator

What is the goal here? How many allocations are acceptable? We could always use a conservative lower limit and let extend deal with any additional allocations. If we reserve the minimal len/4, this should result in at most 3 allocations, right? I think this is good enough for a "general-purpose" implementation.

Are there use cases where the slice is large? Typically, I would expect the "alphabet" to be small, unless it is abused for weighted sampling, for which we have better implementations. Another use case could be "valid Unicode except a few characters", but for this, rejection sampling should be used.

Member

> Typically, I would expect the "alphabet" to be small

As would I, but it could be large and this is easy to test. The simplest solution would be to measure max-char-size only where the alphabet is reasonably small, e.g. under 200 items.
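
A minimal sketch of that suggestion, assuming a standalone function over `&[char]` rather than the real `Slice` type. The name `bounded_max_char_len` and the constant `MAX_SCAN_LEN` are hypothetical; the value 200 is simply the bound floated in this thread, not a tuned number.

```rust
/// Hypothetical bound: alphabets with at least this many items are not scanned.
const MAX_SCAN_LEN: usize = 200;

/// Returns the widest UTF-8 encoding (in bytes) among the chars of a small
/// alphabet, or the worst case of 4 when the alphabet is too large to scan.
fn bounded_max_char_len(alphabet: &[char]) -> usize {
    if alphabet.len() >= MAX_SCAN_LEN {
        // Too large to scan cheaply: assume the UTF-8 worst case.
        return 4;
    }
    alphabet
        .iter()
        .try_fold(1usize, |max_len, c| {
            // Returning None stops the scan early once a 4-byte char is seen,
            // since no char can be wider than that.
            Some(max_len.max(c.len_utf8())).filter(|len| *len < 4)
        })
        .unwrap_or(4)
}
```

The early exit keeps the scan cheap for small alphabets, and the size check caps the cost for large ones at a single length comparison.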

Contributor Author

As suggested, I now skip the check for long slices. The limit is currently 200, but I don't know what the best value is.
I also split the sampling when the requested length is long or the slice contains non-ASCII chars.

    }
}
20 changes: 20 additions & 0 deletions src/distributions/uniform.rs
@@ -843,6 +843,26 @@ impl UniformSampler for UniformChar {
    }
}

/// Note: the `String` is potentially left with excess capacity if the range
/// includes non ascii chars; optionally the user may call
/// `string.shrink_to_fit()` afterwards.
#[cfg(feature = "alloc")]
impl super::DistString for Uniform<char>{
    fn append_string<R: Rng + ?Sized>(&self, rng: &mut R, string: &mut alloc::string::String, len: usize) {
        // Getting the hi value to assume the required length to reserve in string.
        let mut hi = self.0.sampler.low + self.0.sampler.range;
        if hi >= CHAR_SURROGATE_START {
            hi += CHAR_SURROGATE_LEN;
        }
Member

The largest possible result is one less than this, I think?
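
The off-by-one can be sketched as follows. This assumes, without confirmation from the diff, that `low` is stored with the surrogate gap removed and that `range` counts the number of values (so it is at least 1). The helpers `to_compressed` and `largest_code_point` are hypothetical, and the constant values are assumptions mirroring the names used in the diff.

```rust
// Assumed values for the constants named in the diff.
const CHAR_SURROGATE_START: u32 = 0xD800;
const CHAR_SURROGATE_LEN: u32 = 0xE000 - CHAR_SURROGATE_START;

/// Hypothetical helper: map a char to a u32 with the surrogate gap removed.
fn to_compressed(c: char) -> u32 {
    match c as u32 {
        v if v >= CHAR_SURROGATE_START => v - CHAR_SURROGATE_LEN,
        v => v,
    }
}

/// Hypothetical helper: largest code point the sampler can yield.
/// Assumes `range >= 1` counts the values, so the last one is `low + range - 1`.
fn largest_code_point(low: u32, range: u32) -> u32 {
    let mut hi = low + range - 1;
    if hi >= CHAR_SURROGATE_START {
        // Re-insert the surrogate gap to get back a real code point.
        hi += CHAR_SURROGATE_LEN;
    }
    hi
}
```

For the inclusive range `'a'..='\u{7F}'`, `low + range` is `0x80`, which encodes to 2 bytes, while the true maximum `0x7F` encodes to 1 byte, so the off-by-one doubles the reservation in that corner case.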

Collaborator

It would be nice to have some tests for such corner cases.

Contributor Author

I have fixed this and added some tests.

        // Get the utf8 length of hi to minimize extra space.
        // SAFETY: hi used to be valid char.
        // This relies on range constructors which accept char arguments.
        let max_char_len = unsafe { char::from_u32_unchecked(hi).len_utf8() };
Member

I don't believe char::len_utf8 even cares whether hi is a valid char so this should be safe, but the use of unsafe still feels unnecessary. Suggestion:

let max_char_len = char::from_u32(hi).map(char::len_utf8).unwrap_or(4);
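
As an illustration of how the suggested expression behaves, here it is wrapped in a standalone function. The function name `max_utf8_len_for` is hypothetical; the body is the expression from the suggestion, using only `char::from_u32` and `char::len_utf8` from the standard library.

```rust
/// Hypothetical wrapper around the suggested expression: the UTF-8 length of
/// `hi` if it is a valid char, otherwise the conservative worst case of 4.
fn max_utf8_len_for(hi: u32) -> usize {
    char::from_u32(hi).map(char::len_utf8).unwrap_or(4)
}
```

Values that are not valid chars, such as surrogates or anything above `0x10FFFF`, fall back to 4 bytes, so the reservation may be larger than needed but never too small.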

Contributor Author

I have applied your suggestion.

        string.reserve(max_char_len * len);
        string.extend(self.sample_iter(rng).take(len))
    }
}

/// The back-end implementing [`UniformSampler`] for floating-point types.
///
/// Unless you are implementing [`UniformSampler`] for your own type, this type