Skip to content

Commit

Permalink
Add another test
Browse files Browse the repository at this point in the history
  • Loading branch information
marshallpierce committed Mar 8, 2020
1 parent 61af8bc commit 0a020ee
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 29 deletions.
25 changes: 15 additions & 10 deletions src/read/decoder.rs
Expand Up @@ -126,7 +126,7 @@ impl<'a, R: io::Read> DecoderReader<'a, R> {
/// Decode the requested number of bytes from the b64 buffer into the provided buffer. It's the
/// caller's responsibility to choose the number of b64 bytes to decode correctly.
///
/// Returns a Result with the number of decoded bytes written.
/// Returns a Result with the number of decoded bytes written to `buf`.
fn decode_to_buf(&mut self, num_bytes: usize, buf: &mut [u8]) -> io::Result<usize> {
debug_assert!(self.b64_len >= num_bytes);
debug_assert!(self.b64_offset + self.b64_len <= BUF_SIZE);
Expand All @@ -149,8 +149,8 @@ impl<'a, R: io::Read> DecoderReader<'a, R> {
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;

self.total_b64_decoded += num_bytes;
self.b64_len -= num_bytes;
self.b64_offset += num_bytes;
self.b64_len -= num_bytes;

debug_assert!(self.b64_offset + self.b64_len <= BUF_SIZE);

Expand Down Expand Up @@ -224,6 +224,7 @@ impl<'a, R: Read> Read for DecoderReader<'a, R> {
}

if self.b64_len == 0 {
debug_assert!(at_eof);
// we must be at EOF, and we have no data left to decode
return Ok(0);
};
Expand All @@ -236,13 +237,15 @@ impl<'a, R: Read> Read for DecoderReader<'a, R> {
self.b64_len >= BASE64_CHUNK_SIZE
});

debug_assert_eq!(0, self.decoded_len);

if buf.len() < DECODED_CHUNK_SIZE {
// caller requested an annoyingly short read
debug_assert_eq!(0, self.decoded_len);

// have to write to a tmp buf first to avoid double mutable borrow
let mut decoded_chunk = [0_u8; DECODED_CHUNK_SIZE];
// if we are at eof, could have less than BASE64_CHUNK_SIZE
// if we are at eof, could have less than BASE64_CHUNK_SIZE, in which case we have
// to assume that these last few tokens are, in fact, valid (i.e. must be 2-4 b64
// tokens, not 1, since 1 token can't decode to 1 byte).
let to_decode = cmp::min(self.b64_len, BASE64_CHUNK_SIZE);

let decoded = self.decode_to_buf(to_decode, &mut decoded_chunk[..])?;
Expand All @@ -256,20 +259,22 @@ impl<'a, R: Read> Read for DecoderReader<'a, R> {

self.flush_decoded_buf(buf)
} else {
let bytes_that_can_decode_into_buf = (buf.len() / DECODED_CHUNK_SIZE)
let b64_bytes_that_can_decode_into_buf = (buf.len() / DECODED_CHUNK_SIZE)
.checked_mul(BASE64_CHUNK_SIZE)
.expect("too many chunks");
debug_assert!(bytes_that_can_decode_into_buf >= BASE64_CHUNK_SIZE);
debug_assert!(b64_bytes_that_can_decode_into_buf >= BASE64_CHUNK_SIZE);

let bytes_available_to_decode = if at_eof {
let b64_bytes_available_to_decode = if at_eof {
self.b64_len
} else {
// only use complete chunks
self.b64_len - self.b64_len % 4
};

let actual_decode_len =
cmp::min(bytes_that_can_decode_into_buf, bytes_available_to_decode);
let actual_decode_len = cmp::min(
b64_bytes_that_can_decode_into_buf,
b64_bytes_available_to_decode,
);
self.decode_to_buf(actual_decode_len, buf)
}
}
Expand Down
87 changes: 68 additions & 19 deletions src/read/decoder_tests.rs
Expand Up @@ -142,25 +142,42 @@ fn read_in_short_increments() {
let mut wrapped_reader = io::Cursor::new(&b64[..]);
let mut decoder = DecoderReader::new(&mut wrapped_reader, config);

let mut total_read = 0_usize;
loop {
assert!(total_read <= size, "tr {} size {}", total_read, size);
if total_read == size {
assert_eq!(bytes, &decoded[..total_read]);
// should be done
assert_eq!(0, decoder.read(&mut decoded[..]).unwrap());
// didn't write anything
assert_eq!(bytes, &decoded[..total_read]);

break;
}
let decode_len = rng.gen_range(1, cmp::max(2, size * 2));
consume_with_short_reads_and_validate(&mut rng, &bytes[..], &mut decoded, &mut decoder);
}
}

let read = decoder
.read(&mut decoded[total_read..total_read + decode_len])
.unwrap();
total_read += read;
}
#[test]
fn read_in_short_increments_with_short_delegate_reads() {
let mut rng = rand::thread_rng();
let mut bytes = Vec::new();
let mut b64 = String::new();
let mut decoded = Vec::new();

for _ in 0..10_000 {
bytes.clear();
b64.clear();
decoded.clear();

let size = rng.gen_range(0, 10 * BUF_SIZE);
bytes.extend(iter::repeat(0).take(size));
// leave room to play around with larger buffers
decoded.extend(iter::repeat(0).take(size * 3));

rng.fill_bytes(&mut bytes[..]);
assert_eq!(size, bytes.len());

let config = random_config(&mut rng);

encode_config_buf(&bytes[..], config, &mut b64);

let mut base_reader = io::Cursor::new(&b64[..]);
let mut decoder = DecoderReader::new(&mut base_reader, config);
let mut short_reader = RandomShortRead {
delegate: &mut decoder,
rng: &mut rand::thread_rng(),
};

consume_with_short_reads_and_validate(&mut rng, &bytes[..], &mut decoded, &mut short_reader)
}
}

Expand Down Expand Up @@ -268,6 +285,38 @@ fn reports_invalid_byte_correctly() {
}
}

fn consume_with_short_reads_and_validate<R: Read>(
rng: &mut rand::rngs::ThreadRng,
expected_bytes: &[u8],
decoded: &mut Vec<u8>,
short_reader: &mut R,
) -> () {
let mut total_read = 0_usize;
loop {
assert!(
total_read <= expected_bytes.len(),
"tr {} size {}",
total_read,
expected_bytes.len()
);
if total_read == expected_bytes.len() {
assert_eq!(expected_bytes, &decoded[..total_read]);
// should be done
assert_eq!(0, short_reader.read(&mut decoded[..]).unwrap());
// didn't write anything
assert_eq!(expected_bytes, &decoded[..total_read]);

break;
}
let decode_len = rng.gen_range(1, cmp::max(2, expected_bytes.len() * 2));

let read = short_reader
.read(&mut decoded[total_read..total_read + decode_len])
.unwrap();
total_read += read;
}
}

/// Limits how many bytes a reader will provide in each read call.
/// Useful for shaking out code that may work fine only with typical input sources that always fill
/// the buffer.
Expand All @@ -279,7 +328,7 @@ struct RandomShortRead<'a, 'b, R: io::Read, N: rand::Rng> {
impl<'a, 'b, R: io::Read, N: rand::Rng> io::Read for RandomShortRead<'a, 'b, R, N> {
fn read(&mut self, buf: &mut [u8]) -> Result<usize, io::Error> {
// avoid 0 since it means EOF for non-empty buffers
let effective_len = self.rng.gen_range(1, 20);
let effective_len = cmp::min(self.rng.gen_range(1, 20), buf.len());

self.delegate.read(&mut buf[..effective_len])
}
Expand Down

0 comments on commit 0a020ee

Please sign in to comment.