Skip to content

Commit

Permalink
Fix DistString impl.
Browse files Browse the repository at this point in the history
  • Loading branch information
aobatact committed May 31, 2023
1 parent 7b5a417 commit 4d4f34d
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 14 deletions.
33 changes: 23 additions & 10 deletions src/distributions/slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,16 +123,29 @@ impl std::error::Error for EmptySlice {}
#[cfg(feature = "alloc")]
impl<'a> super::DistString for Slice<'a, char> {
/// Extends `string` with `len` chars taken from `self.sample_iter(rng)`.
///
/// Before extending, capacity is reserved from an estimate of the widest
/// UTF-8 encoding (1 to 4 bytes) among the chars in `self.slice`.
///
/// NOTE(review): this listing looks like a rendered diff without `+`/`-`
/// markers, so removed and added lines may be interleaved (for example the
/// two `max_char_len` bindings and the two reserve/extend sequences).
/// Confirm against the repository before treating the body as compilable.
fn append_string<R: crate::Rng + ?Sized>(&self, rng: &mut R, string: &mut String, len: usize) {
// Unconditional scan of the whole slice for the widest UTF-8 length.
let max_char_len = self
.slice
.iter()
.try_fold(1, |max_len, char| {
// Once max_len reaches 4 the filter yields None, which stops the fold
// early; `unwrap_or(4)` then supplies 4 as the result.
Some(max_len.max(char.len_utf8())).filter(|len| *len < 4)
})
.unwrap_or(4);
// Find the widest UTF-8 length in the slice so that no more space than
// needed is reserved.
// Only slices shorter than 200 elements are scanned, to avoid walking a
// long slice; longer slices assume the 4-byte maximum.
let max_char_len = if self.slice.len() < 200 {
self.slice
.iter()
.try_fold(1, |max_len, char| {
// Once max_len reaches 4 the filter yields None, which stops the fold
// early; `unwrap_or(4)` then supplies 4 as the result.
Some(max_len.max(char.len_utf8())).filter(|len| *len < 4)
})
.unwrap_or(4)
} else {
4
};

// Single-pass variant: reserve the worst case for all `len` chars at once.
string.reserve(max_char_len * len);
string.extend(self.sample_iter(rng).take(len))
// Extend the string in several passes so that capacity left unused by one
// pass can be reused by the next (`reserve` is relative to the current
// length, so existing spare capacity counts toward the request).
// A single pass is used when every char is 1 byte wide or when `len` is
// below 100; otherwise each pass covers `len / 4` chars.
let mut extend_len = if max_char_len == 1 || len < 100 { len } else { len / 4 };
let mut remain_len = len;
while extend_len > 0 {
string.reserve(max_char_len * extend_len);
// `&mut *rng` reborrows the generator so it is still usable on the next pass.
string.extend(self.sample_iter(&mut *rng).take(extend_len));
remain_len -= extend_len;
// Never request more than is still outstanding; reaching 0 ends the loop.
extend_len = extend_len.min(remain_len);
}
}
}
22 changes: 18 additions & 4 deletions src/distributions/uniform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -850,14 +850,12 @@ impl UniformSampler for UniformChar {
impl super::DistString for Uniform<char>{
/// Extends `string` with `len` chars taken from `self.sample_iter(rng)`.
///
/// Capacity of `max_char_len * len` bytes is reserved first, where
/// `max_char_len` is the UTF-8 length derived from the sampler's upper bound.
///
/// NOTE(review): this listing looks like a rendered diff without `+`/`-`
/// markers, so removed and added lines may be interleaved (for example the
/// two `hi` bindings and the two `max_char_len` bindings). Confirm against
/// the repository before treating the body as compilable.
fn append_string<R: Rng + ?Sized>(&self, rng: &mut R, string: &mut alloc::string::String, len: usize) {
// Compute the upper value `hi` of the sampler, used to estimate how many
// bytes to reserve in `string`.
let mut hi = self.0.sampler.low + self.0.sampler.range;
// NOTE(review): the `- 1` form presumably makes `hi` the last value that can
// be sampled rather than one past it, assuming `range` counts the values;
// confirm against the sampler definition.
let mut hi = self.0.sampler.low + self.0.sampler.range - 1;
// NOTE(review): presumably shifts `hi` past the surrogate gap that the
// sampler skips; the constants are defined outside this view.
if hi >= CHAR_SURROGATE_START {
hi += CHAR_SURROGATE_LEN;
}
// Take the UTF-8 length of `hi` so that no more space than needed is reserved.
// SAFETY: `hi` is expected to be a valid char.
// This relies on the range constructors accepting only char arguments.
let max_char_len = unsafe { char::from_u32_unchecked(hi).len_utf8() };
// Checked variant: `char::from_u32` yields None when `hi` is not a valid
// char, in which case the 4-byte maximum is assumed.
let max_char_len = char::from_u32(hi).map(char::len_utf8).unwrap_or(4);
string.reserve(max_char_len * len);
string.extend(self.sample_iter(rng).take(len))
}
Expand Down Expand Up @@ -1396,6 +1394,22 @@ mod tests {
let c = d.sample(&mut rng);
assert!((c as u32) < 0xD800 || (c as u32) > 0xDFFF);
}
#[cfg(feature = "alloc")]
{
use crate::distributions::DistString;
let string1 = d.sample_string(&mut rng, 100);
assert_eq!(string1.capacity(), 300);
let string2 = Uniform::new(
core::char::from_u32(0x0000).unwrap(),
core::char::from_u32(0x0080).unwrap(),
).unwrap().sample_string(&mut rng, 100);
assert_eq!(string2.capacity(), 100);
let string3 = Uniform::new_inclusive(
core::char::from_u32(0x0000).unwrap(),
core::char::from_u32(0x0080).unwrap(),
).unwrap().sample_string(&mut rng, 100);
assert_eq!(string3.capacity(), 200);
}
}

#[test]
Expand Down

0 comments on commit 4d4f34d

Please sign in to comment.