From d30c85ef5280c5f97fafe5fb1888f3e520c3be46 Mon Sep 17 00:00:00 2001 From: Marshall Pierce Date: Sun, 23 Dec 2018 15:19:42 -0700 Subject: [PATCH] Streaming encoder: keep track of encoded bytes that weren't written. They'll be retried on subsequent writes, but this plays poorly with write_all. See https://github.com/rust-lang/rust/issues/56889. Hat tip to #90. --- RELEASE-NOTES.md | 4 + src/write/encoder.rs | 196 +++++++++++++++++++++++++++---------- src/write/encoder_tests.rs | 109 ++++++++++++++++++--- 3 files changed, 247 insertions(+), 62 deletions(-) diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index 39f80e1d..1a993477 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -2,6 +2,10 @@ - Minimum rust version 1.27.2 +# 0.10.1 + +- Fix bug in streaming encoding ([#90](https://github.com/alicemaz/rust-base64/pull/90)): if the underlying writer didn't write all the bytes given to it, the remaining bytes would not be retried later. See the docs on `EncoderWriter::write`. + # 0.10.0 - Remove line wrapping. Line wrapping was never a great conceptual fit in this library, and other features (streaming encoding, etc) either couldn't support it or could support only special cases of it with a great increase in complexity. Line wrapping has been pulled out into a [line-wrap](https://crates.io/crates/line-wrap) crate, so it's still available if you need it. diff --git a/src/write/encoder.rs b/src/write/encoder.rs index b54fca45..23de075e 100644 --- a/src/write/encoder.rs +++ b/src/write/encoder.rs @@ -1,5 +1,5 @@ use encode::encode_to_slice; -use std::io::{Result, Write}; +use std::io::{ErrorKind, Result, Write}; use std::{cmp, fmt}; use {encode_config_slice, Config}; @@ -60,11 +60,14 @@ pub struct EncoderWriter<'a, W: 'a + Write> { w: &'a mut W, /// Holds a partial chunk, if any, after the last `write()`, so that we may then fill the chunk /// with the next `write()`, encode it, then proceed with the rest of the input normally. - extra: [u8; MIN_ENCODE_CHUNK_SIZE], + extra_input: [u8; MIN_ENCODE_CHUNK_SIZE], /// How much of `extra` is occupied, in `[0, MIN_ENCODE_CHUNK_SIZE]`. - extra_len: usize, - /// Buffer to encode into. + extra_input_occupied_len: usize, + /// Buffer to encode into. May hold leftover encoded bytes from a previous write call that the underlying writer + /// did not write last time. output: [u8; BUF_SIZE], + /// How much of `output` is occupied with encoded data that couldn't be written last time + output_occupied_len: usize, /// True iff padding / partial last chunk has been written. finished: bool, /// panic safety: don't write again in destructor if writer panicked while we were writing to it @@ -75,23 +78,25 @@ impl<'a, W: Write> fmt::Debug for EncoderWriter<'a, W> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!( f, - "extra:{:?} extra_len:{:?} output[..5]: {:?}", - self.extra, - self.extra_len, - &self.output[0..5] + "extra_input: {:?} extra_input_occupied_len:{:?} output[..5]: {:?} output_occupied_len: {:?}", + self.extra_input, + self.extra_input_occupied_len, + &self.output[0..5], + self.output_occupied_len ) } } impl<'a, W: Write> EncoderWriter<'a, W> { - /// Create a new encoder around an existing writer. + /// Create a new encoder that will write to the provided delegate writer `w`. pub fn new(w: &'a mut W, config: Config) -> EncoderWriter<'a, W> { EncoderWriter { config, w, - extra: [0u8; MIN_ENCODE_CHUNK_SIZE], - extra_len: 0, + extra_input: [0u8; MIN_ENCODE_CHUNK_SIZE], + extra_input_occupied_len: 0, output: [0u8; BUF_SIZE], + output_occupied_len: 0, finished: false, panicked: false, } @@ -103,34 +108,112 @@ impl<'a, W: Write> EncoderWriter<'a, W> { /// Once this succeeds, no further writes can be performed, as that would produce invalid /// base64. /// + /// This may write to the delegate writer multiple times if the delegate writer does not accept all input provided + /// to its `write` each invocation. + /// /// # Errors /// - /// Assuming the wrapped writer obeys the `Write` contract, if this returns `Err`, no data was - /// written, and `finish()` may be retried if appropriate for the type of error, etc. + /// The first error that is not of [`ErrorKind::Interrupted`] will be returned. pub fn finish(&mut self) -> Result<()> { if self.finished { return Ok(()); }; - if self.extra_len > 0 { + self.write_all_encoded_output()?; + + if self.extra_input_occupied_len > 0 { let encoded_len = encode_config_slice( - &self.extra[..self.extra_len], + &self.extra_input[..self.extra_input_occupied_len], self.config, &mut self.output[..], ); - self.panicked = true; - let _ = self.w.write(&self.output[..encoded_len])?; - self.panicked = false; + + self.output_occupied_len = encoded_len; + + self.write_all_encoded_output()?; + // write succeeded, do not write the encoding of extra again if finish() is retried - self.extra_len = 0; + self.extra_input_occupied_len = 0; } self.finished = true; Ok(()) } + + /// Write as much of the encoded output to the delegate writer as it will accept, and store the + /// leftovers to be attempted at the next write() call. Updates `self.output_occupied_len`. + /// + /// # Errors + /// + /// Errors from the delegate writer are returned. In the case of an error, + /// `self.output_occupied_len` will not be updated, as errors from `write` are specified to mean + /// that no write took place. + fn write_to_delegate(&mut self, current_output_len: usize) -> Result<()> { + self.panicked = true; + let res = self.w.write(&self.output[..current_output_len]); + self.panicked = false; + + return res.map(|consumed| { + debug_assert!(consumed <= current_output_len); + + if consumed < current_output_len { + self.output_occupied_len = current_output_len.checked_sub(consumed).unwrap(); + // If we're blocking on I/O, the minor inefficiency of copying bytes to the + // start of the buffer is the least of our concerns... + // Rotate moves more than we need to, but copy_within isn't stabilized yet. + self.output.rotate_left(consumed); + } else { + self.output_occupied_len = 0; + } + + () + }); + } + + /// Write all buffered encoded output. If this returns `Ok`, `self.output_occupied_len` is `0`. + /// + /// This is basically write_all for the remaining buffered data but without the undesirable + /// abort-on-`Ok(0)` behavior. + /// + /// # Errors + /// + /// Any error emitted by the delegate writer abort the write loop and is returned, unless it's + /// `Interrupted`, in which case the error is ignored and writes will continue. + fn write_all_encoded_output(&mut self) -> Result<()> { + while self.output_occupied_len > 0 { + let remaining_len = self.output_occupied_len; + match self.write_to_delegate(remaining_len) { + // try again on interrupts ala write_all + Err(ref e) if e.kind() == ErrorKind::Interrupted => {} + // other errors return + Err(e) => return Err(e), + // success no-ops because remaining length is already updated + Ok(_) => {} + }; + } + + debug_assert_eq!(0, self.output_occupied_len); + Ok(()) + } } impl<'a, W: Write> Write for EncoderWriter<'a, W> { + /// Encode input and then write to the delegate writer. + /// + /// Under non-error circumstances, this returns `Ok` with the value being the number of bytes + /// of `input` consumed. The value may be `0`, which interacts poorly with `write_all`, which + /// interprets `Ok(0)` as an error, despite it being allowed by the contract of `write`. See + /// https://github.com/rust-lang/rust/issues/56889 for more on that. + /// + /// If the previous call to `write` provided more (encoded) data than the delegate writer could + /// accept in a single call to its `write`, the remaining data is buffered. As long as buffered + /// data is present, subsequent calls to `write` will try to write the remaining buffered data + /// to the delegate and return either `Ok(0)` -- and therefore not consume any of `input` -- or + /// an error. + /// + /// # Errors + /// + /// Any errors emitted by the delegate writer are returned. fn write(&mut self, input: &[u8]) -> Result { if self.finished { panic!("Cannot write more after calling finish()"); @@ -146,34 +229,45 @@ impl<'a, W: Write> Write for EncoderWriter<'a, W> { // - Errors mean that "no bytes were written to this writer", so we need to reset the // internal state to what it was before the error occurred + // before reading any input, write any leftover encoded output from last time + if self.output_occupied_len > 0 { + let current_len = self.output_occupied_len; + return self.write_to_delegate(current_len) + // did not read any input + .map(|_| 0) + + } + + debug_assert_eq!(0, self.output_occupied_len); + // how many bytes, if any, were read into `extra` to create a triple to encode let mut extra_input_read_len = 0; let mut input = input; - let orig_extra_len = self.extra_len; + let orig_extra_len = self.extra_input_occupied_len; let mut encoded_size = 0; // always a multiple of MIN_ENCODE_CHUNK_SIZE let mut max_input_len = MAX_INPUT_LEN; - // process leftover stuff from last write - if self.extra_len > 0 { - debug_assert!(self.extra_len < 3); - if input.len() + self.extra_len >= MIN_ENCODE_CHUNK_SIZE { + // process leftover un-encoded input from last write + if self.extra_input_occupied_len > 0 { + debug_assert!(self.extra_input_occupied_len < 3); + if input.len() + self.extra_input_occupied_len >= MIN_ENCODE_CHUNK_SIZE { // Fill up `extra`, encode that into `output`, and consume as much of the rest of // `input` as possible. // We could write just the encoding of `extra` by itself but then we'd have to // return after writing only 4 bytes, which is inefficient if the underlying writer // would make a syscall. - extra_input_read_len = MIN_ENCODE_CHUNK_SIZE - self.extra_len; + extra_input_read_len = MIN_ENCODE_CHUNK_SIZE - self.extra_input_occupied_len; debug_assert!(extra_input_read_len > 0); // overwrite only bytes that weren't already used. If we need to rollback extra_len // (when the subsequent write errors), the old leading bytes will still be there. - self.extra[self.extra_len..MIN_ENCODE_CHUNK_SIZE] + self.extra_input[self.extra_input_occupied_len..MIN_ENCODE_CHUNK_SIZE] .copy_from_slice(&input[0..extra_input_read_len]); let len = encode_to_slice( - &self.extra[0..MIN_ENCODE_CHUNK_SIZE], + &self.extra_input[0..MIN_ENCODE_CHUNK_SIZE], &mut self.output[..], self.config.char_set.encode_table(), ); @@ -182,38 +276,40 @@ impl<'a, W: Write> Write for EncoderWriter<'a, W> { input = &input[extra_input_read_len..]; // consider extra to be used up, since we encoded it - self.extra_len = 0; + self.extra_input_occupied_len = 0; // don't clobber where we just encoded to encoded_size = 4; // and don't read more than can be encoded max_input_len = MAX_INPUT_LEN - MIN_ENCODE_CHUNK_SIZE; - // fall through to normal encoding + // fall through to normal encoding } else { // `extra` and `input` are non empty, but `|extra| + |input| < 3`, so there must be // 1 byte in each. debug_assert_eq!(1, input.len()); - debug_assert_eq!(1, self.extra_len); + debug_assert_eq!(1, self.extra_input_occupied_len); - self.extra[self.extra_len] = input[0]; - self.extra_len += 1; + self.extra_input[self.extra_input_occupied_len] = input[0]; + self.extra_input_occupied_len += 1; return Ok(1); }; } else if input.len() < MIN_ENCODE_CHUNK_SIZE { // `extra` is empty, and `input` fits inside it - self.extra[0..input.len()].copy_from_slice(input); - self.extra_len = input.len(); + self.extra_input[0..input.len()].copy_from_slice(input); + self.extra_input_occupied_len = input.len(); return Ok(input.len()); }; // either 0 or 1 complete chunks encoded from extra debug_assert!(encoded_size == 0 || encoded_size == 4); debug_assert!( - MAX_INPUT_LEN - max_input_len == 0 - || MAX_INPUT_LEN - max_input_len == MIN_ENCODE_CHUNK_SIZE + // didn't encode extra input + MAX_INPUT_LEN == max_input_len + // encoded one triple + || MAX_INPUT_LEN == max_input_len + MIN_ENCODE_CHUNK_SIZE ); - // handle complete triples + // encode complete triples only let input_complete_chunks_len = input.len() - (input.len() % MIN_ENCODE_CHUNK_SIZE); let input_chunks_to_encode_len = cmp::min(input_complete_chunks_len, max_input_len); debug_assert_eq!(0, max_input_len % MIN_ENCODE_CHUNK_SIZE); @@ -224,25 +320,27 @@ impl<'a, W: Write> Write for EncoderWriter<'a, W> { &mut self.output[encoded_size..], self.config.char_set.encode_table(), ); - self.panicked = true; - let r = self.w.write(&self.output[..encoded_size]); - self.panicked = false; - match r { - Ok(_) => Ok(extra_input_read_len + input_chunks_to_encode_len), - Err(_) => { + + // not updating `self.output_occupied_len` here because if the below write fails, it should + // "never take place" -- the buffer contents we encoded are ignored and perhaps retried + // later, if the consumer chooses. + + self.write_to_delegate(encoded_size) + // no matter whether we wrote the full encoded buffer or not, we consumed the same + // input + .map(|_| extra_input_read_len + input_chunks_to_encode_len) + .map_err( |e| { // in case we filled and encoded `extra`, reset extra_len - self.extra_len = orig_extra_len; - r - } - } + self.extra_input_occupied_len = orig_extra_len; - // we could hypothetically copy a few more bytes into `extra` but the extra 1-2 bytes - // are not worth all the complexity (and branches) + e + }) } /// Because this is usually treated as OK to call multiple times, it will *not* flush any /// incomplete chunks of input or write padding. fn flush(&mut self) -> Result<()> { + self.write_all_encoded_output()?; self.w.flush() } } diff --git a/src/write/encoder_tests.rs b/src/write/encoder_tests.rs index 6897c5cd..681235b5 100644 --- a/src/write/encoder_tests.rs +++ b/src/write/encoder_tests.rs @@ -209,8 +209,7 @@ fn write_2_partials_to_exactly_complete_chunk_encodes_complete_chunk() { } #[test] -fn write_partial_then_enough_to_complete_chunk_but_not_complete_another_chunk_encodes_complete_chunk_without_consuming_remaining( -) { +fn write_partial_then_enough_to_complete_chunk_but_not_complete_another_chunk_encodes_complete_chunk_without_consuming_remaining() { let mut c = Cursor::new(Vec::new()); { let mut enc = EncoderWriter::new(&mut c, STANDARD_NO_PAD); @@ -246,8 +245,7 @@ fn write_partial_then_enough_to_complete_chunk_and_another_chunk_encodes_complet } #[test] -fn write_partial_then_enough_to_complete_chunk_and_another_chunk_and_another_partial_chunk_encodes_only_complete_chunks( -) { +fn write_partial_then_enough_to_complete_chunk_and_another_chunk_and_another_partial_chunk_encodes_only_complete_chunks() { let mut c = Cursor::new(Vec::new()); { let mut enc = EncoderWriter::new(&mut c, STANDARD_NO_PAD); @@ -359,7 +357,6 @@ fn retrying_writes_that_error_with_interrupted_works() { // when errors occur let input_len: usize = cmp::min(rng.gen_range(0, 10), orig_len - bytes_consumed); - // write a little bit of the data retry_interrupted_write_all( &mut stream_encoder, &orig_data[bytes_consumed..bytes_consumed + input_len], @@ -386,21 +383,77 @@ fn retrying_writes_that_error_with_interrupted_works() { } } +#[test] +fn writes_that_only_write_part_of_input_and_sometimes_interrupt_produce_correct_encoded_data() { + let mut rng = rand::thread_rng(); + let mut orig_data = Vec::::new(); + let mut stream_encoded = Vec::::new(); + let mut normal_encoded = String::new(); + + for _ in 0..1_000 { + orig_data.clear(); + stream_encoded.clear(); + normal_encoded.clear(); + + let orig_len: usize = rng.gen_range(100, 20_000); + for _ in 0..orig_len { + orig_data.push(rng.gen()); + } + + // encode the normal way + let config = random_config(&mut rng); + encode_config_buf(&orig_data, config, &mut normal_encoded); + + // encode via the stream encoder + { + let mut partial_rng = rand::thread_rng(); + let mut partial_writer = PartialInterruptingWriter { + w: &mut stream_encoded, + rng: &mut partial_rng, + full_input_fraction: 0.1, + no_interrupt_fraction: 0.1 + }; + + let mut stream_encoder = EncoderWriter::new(&mut partial_writer, config); + let mut bytes_consumed = 0; + while bytes_consumed < orig_len { + // use at most medium-length inputs to exercise retry logic more aggressively + let input_len: usize = cmp::min(rng.gen_range(0, 100), orig_len - bytes_consumed); + + let res = stream_encoder.write(&orig_data[bytes_consumed..bytes_consumed + input_len]); + + // retry on interrupt + match res { + Ok(len) => bytes_consumed += len, + Err(e) => match e.kind() { + io::ErrorKind::Interrupted => continue, + _ => { panic!("should not see other errors"); } + }, + } + }; + + stream_encoder.finish().unwrap(); + + assert_eq!(orig_len, bytes_consumed); + } + + assert_eq!(normal_encoded, str::from_utf8(&stream_encoded).unwrap()); + } +} + + /// Retry writes until all the data is written or an error that isn't Interrupted is returned. fn retry_interrupted_write_all(w: &mut W, buf: &[u8]) -> io::Result<()> { - let mut written = 0; + let mut bytes_consumed = 0; - while written < buf.len() { - let res = w.write(&buf[written..]); + while bytes_consumed < buf.len() { + let res = w.write(&buf[bytes_consumed..]); match res { - Ok(len) => written += len, + Ok(len) => bytes_consumed += len, Err(e) => match e.kind() { io::ErrorKind::Interrupted => continue, - _ => { - println!("got kind: {:?}", e.kind()); - return Err(e); - } + _ => { return Err(e) } }, } } @@ -479,3 +532,33 @@ impl<'a, W: Write, R: Rng> Write for InterruptingWriter<'a, W, R> { self.w.flush() } } + +/// A `Write` implementation that sometimes will only write part of its input. +struct PartialInterruptingWriter<'a, W: 'a + Write, R: 'a + Rng> { + w: &'a mut W, + rng: &'a mut R, + /// In [0, 1]. If a random number in [0, 1] is `<= threshold`, `write()` will write all its + /// input. Otherwise, it will write a random substring + full_input_fraction: f64, + no_interrupt_fraction: f64 +} + +impl<'a, W: Write, R: Rng> Write for PartialInterruptingWriter<'a, W, R> { + fn write(&mut self, buf: &[u8]) -> io::Result { + if self.rng.gen_range(0.0, 1.0) > self.no_interrupt_fraction{ + return Err(io::Error::new(io::ErrorKind::Interrupted, "interrupted")); + } + + if self.rng.gen_range(0.0, 1.0) <= self.full_input_fraction || buf.len() == 0 { + // pass through the buf untouched + self.w.write(buf) + } else { + // only use a prefix of it + self.w.write(&buf[0..(self.rng.gen_range(0, buf.len() - 1))]) + } + } + + fn flush(&mut self) -> io::Result<()> { + self.w.flush() + } +} \ No newline at end of file