Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace internal buffer in decoder with BufRead #231

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 3 additions & 3 deletions examples/base64.rs
@@ -1,5 +1,5 @@
use std::fs::File;
use std::io::{self, Read};
use std::io::{self, BufRead, BufReader};
use std::path::PathBuf;
use std::process;
use std::str::FromStr;
Expand Down Expand Up @@ -48,7 +48,7 @@ struct Opt {
fn main() {
let opt = Opt::from_args();
let stdin;
let mut input: Box<dyn Read> = match opt.file {
let mut input: Box<dyn BufRead> = match opt.file {
None => {
stdin = io::stdin();
Box::new(stdin.lock())
Expand All @@ -57,7 +57,7 @@ fn main() {
stdin = io::stdin();
Box::new(stdin.lock())
}
Some(f) => Box::new(File::open(f).unwrap()),
Some(f) => Box::new(File::open(f).map(BufReader::new).unwrap()),
};

let alphabet = opt.alphabet.unwrap_or_default();
Expand Down
190 changes: 77 additions & 113 deletions src/read/decoder.rs
@@ -1,9 +1,6 @@
use crate::{engine::Engine, DecodeError};
use std::{cmp, fmt, io};

// This should be large, but it has to fit on the stack.
pub(crate) const BUF_SIZE: usize = 1024;

// 4 bytes of base64 data encode 3 bytes of raw data (modulo padding).
const BASE64_CHUNK_SIZE: usize = 4;
const DECODED_CHUNK_SIZE: usize = 3;
Expand All @@ -30,17 +27,11 @@ const DECODED_CHUNK_SIZE: usize = 3;
/// assert_eq!(b"asdf", &result[..]);
///
/// ```
pub struct DecoderReader<'e, E: Engine, R: io::Read> {
pub struct DecoderReader<'e, E: Engine, R: io::BufRead> {
engine: &'e E,
/// Where b64 data is read from
inner: R,

// Holds b64 data read from the delegate reader.
b64_buffer: [u8; BUF_SIZE],
// The start of the pending buffered data in b64_buffer.
b64_offset: usize,
// The amount of buffered b64 data.
b64_len: usize,
// Since the caller may provide us with a buffer of size 1 or 2 that's too small to copy a
// decoded chunk in to, we have to be able to hang on to a few decoded bytes.
// Technically we only need to hold 2 bytes but then we'd need a separate temporary buffer to
Expand All @@ -55,11 +46,9 @@ pub struct DecoderReader<'e, E: Engine, R: io::Read> {
total_b64_decoded: usize,
}

impl<'e, E: Engine, R: io::Read> fmt::Debug for DecoderReader<'e, E, R> {
impl<'e, E: Engine, R: io::BufRead> fmt::Debug for DecoderReader<'e, E, R> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_struct("DecoderReader")
.field("b64_offset", &self.b64_offset)
.field("b64_len", &self.b64_len)
.field("decoded_buffer", &self.decoded_buffer)
.field("decoded_offset", &self.decoded_offset)
.field("decoded_len", &self.decoded_len)
Expand All @@ -68,15 +57,12 @@ impl<'e, E: Engine, R: io::Read> fmt::Debug for DecoderReader<'e, E, R> {
}
}

impl<'e, E: Engine, R: io::Read> DecoderReader<'e, E, R> {
impl<'e, E: Engine, R: io::BufRead> DecoderReader<'e, E, R> {
/// Create a new decoder that will read from the provided reader `r`.
pub fn new(reader: R, engine: &'e E) -> Self {
DecoderReader {
engine,
inner: reader,
b64_buffer: [0; BUF_SIZE],
b64_offset: 0,
b64_len: 0,
decoded_buffer: [0; DECODED_CHUNK_SIZE],
decoded_offset: 0,
decoded_len: 0,
Expand Down Expand Up @@ -107,59 +93,6 @@ impl<'e, E: Engine, R: io::Read> DecoderReader<'e, E, R> {
Ok(copy_len)
}

/// Read into the remaining space in the buffer after the current contents.
/// Must only be called when there is space to read into in the buffer.
/// Returns the number of bytes read.
fn read_from_delegate(&mut self) -> io::Result<usize> {
debug_assert!(self.b64_offset + self.b64_len < BUF_SIZE);

let read = self
.inner
.read(&mut self.b64_buffer[self.b64_offset + self.b64_len..])?;
self.b64_len += read;

debug_assert!(self.b64_offset + self.b64_len <= BUF_SIZE);

Ok(read)
}

/// Decode the requested number of bytes from the b64 buffer into the provided buffer. It's the
/// caller's responsibility to choose the number of b64 bytes to decode correctly.
///
/// Returns a Result with the number of decoded bytes written to `buf`.
fn decode_to_buf(&mut self, num_bytes: usize, buf: &mut [u8]) -> io::Result<usize> {
debug_assert!(self.b64_len >= num_bytes);
debug_assert!(self.b64_offset + self.b64_len <= BUF_SIZE);
debug_assert!(!buf.is_empty());

let decoded = self
.engine
.internal_decode(
&self.b64_buffer[self.b64_offset..self.b64_offset + num_bytes],
buf,
self.engine.internal_decoded_len_estimate(num_bytes),
)
.map_err(|e| match e {
DecodeError::InvalidByte(offset, byte) => {
DecodeError::InvalidByte(self.total_b64_decoded + offset, byte)
}
DecodeError::InvalidLength => DecodeError::InvalidLength,
DecodeError::InvalidLastSymbol(offset, byte) => {
DecodeError::InvalidLastSymbol(self.total_b64_decoded + offset, byte)
}
DecodeError::InvalidPadding => DecodeError::InvalidPadding,
})
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;

self.total_b64_decoded += num_bytes;
self.b64_offset += num_bytes;
self.b64_len -= num_bytes;

debug_assert!(self.b64_offset + self.b64_len <= BUF_SIZE);

Ok(decoded)
}

/// Unwraps this `DecoderReader`, returning the base reader which it reads base64 encoded
/// input from.
///
Expand All @@ -171,7 +104,23 @@ impl<'e, E: Engine, R: io::Read> DecoderReader<'e, E, R> {
}
}

impl<'e, E: Engine, R: io::Read> io::Read for DecoderReader<'e, E, R> {
fn map_error_offset(total_b64_decoded: usize) -> impl FnOnce(DecodeError) -> io::Error {
move |error| {
let error = match error {
DecodeError::InvalidByte(offset, byte) => {
DecodeError::InvalidByte(total_b64_decoded + offset, byte)
}
DecodeError::InvalidLength => DecodeError::InvalidLength,
DecodeError::InvalidLastSymbol(offset, byte) => {
DecodeError::InvalidLastSymbol(total_b64_decoded + offset, byte)
}
DecodeError::InvalidPadding => DecodeError::InvalidPadding,
};
io::Error::new(io::ErrorKind::InvalidData, error)
}
}

impl<'e, E: Engine, R: io::BufRead> io::Read for DecoderReader<'e, E, R> {
/// Decode input from the wrapped reader.
///
/// Under non-error circumstances, this returns `Ok` with the value being the number of bytes
Expand All @@ -189,15 +138,6 @@ impl<'e, E: Engine, R: io::Read> io::Read for DecoderReader<'e, E, R> {
return Ok(0);
}

// offset == BUF_SIZE when we copied it all last time
debug_assert!(self.b64_offset <= BUF_SIZE);
debug_assert!(self.b64_offset + self.b64_len <= BUF_SIZE);
debug_assert!(if self.b64_offset == BUF_SIZE {
self.b64_len == 0
} else {
self.b64_len <= BUF_SIZE
});

debug_assert!(if self.decoded_len == 0 {
// can be = when we were able to copy the complete chunk
self.decoded_offset <= DECODED_CHUNK_SIZE
Expand All @@ -215,61 +155,74 @@ impl<'e, E: Engine, R: io::Read> io::Read for DecoderReader<'e, E, R> {
// we have a few leftover decoded bytes; flush that rather than pull in more b64
self.flush_decoded_buf(buf)
} else {
let mut at_eof = false;
while self.b64_len < BASE64_CHUNK_SIZE {
// Work around lack of copy_within, which is only present in 1.37
// Copy any bytes we have to the start of the buffer.
// We know we have < 1 chunk, so we can use a tiny tmp buffer.
let mut memmove_buf = [0_u8; BASE64_CHUNK_SIZE];
memmove_buf[..self.b64_len].copy_from_slice(
&self.b64_buffer[self.b64_offset..self.b64_offset + self.b64_len],
);
self.b64_buffer[0..self.b64_len].copy_from_slice(&memmove_buf[..self.b64_len]);
self.b64_offset = 0;
let mut b64_bytes = self.inner.fill_buf()?;

// then fill in more data
let read = self.read_from_delegate()?;
if read == 0 {
// we never pass in an empty buf, so 0 => we've hit EOF
at_eof = true;
break;
}
}

if self.b64_len == 0 {
debug_assert!(at_eof);
// we must be at EOF, and we have no data left to decode
if b64_bytes.is_empty() {
return Ok(0);
};

let mut b64_bytes_tmp;
let mut at_eof = false;
let mut short = false;
if b64_bytes.len() < BASE64_CHUNK_SIZE {
short = true;
// Read as much as we can, trying to have a full chunk.
b64_bytes_tmp = [0; BASE64_CHUNK_SIZE];
b64_bytes_tmp[..b64_bytes.len()].copy_from_slice(b64_bytes);
let mut pos = b64_bytes.len();
self.inner.consume(pos);
while pos < BASE64_CHUNK_SIZE {
let bytes_read = match self.inner.read(&mut b64_bytes_tmp[pos..]) {
Ok(len) => len,
Err(error) if error.kind() == io::ErrorKind::Interrupted => continue,
Err(error) => return Err(error),
};
if bytes_read == 0 {
at_eof = true;
break;
}
pos += bytes_read;
}
b64_bytes = &b64_bytes_tmp[..pos];
}

debug_assert!(if at_eof {
// if we are at eof, we may not have a complete chunk
self.b64_len > 0
b64_bytes.len() > 0
} else {
// otherwise, we must have at least one chunk
self.b64_len >= BASE64_CHUNK_SIZE
b64_bytes.len() >= BASE64_CHUNK_SIZE
});

debug_assert_eq!(0, self.decoded_len);

if buf.len() < DECODED_CHUNK_SIZE {
// caller requested an annoyingly short read
// have to write to a tmp buf first to avoid double mutable borrow
let mut decoded_chunk = [0_u8; DECODED_CHUNK_SIZE];
// if we are at eof, could have less than BASE64_CHUNK_SIZE, in which case we have
// to assume that these last few tokens are, in fact, valid (i.e. must be 2-4 b64
// tokens, not 1, since 1 token can't decode to 1 byte).
let to_decode = cmp::min(self.b64_len, BASE64_CHUNK_SIZE);
let to_decode = cmp::min(b64_bytes.len(), BASE64_CHUNK_SIZE);
debug_assert!(b64_bytes.len() > BASE64_CHUNK_SIZE || to_decode == b64_bytes.len());

let decoded = self.decode_to_buf(to_decode, &mut decoded_chunk[..])?;
self.decoded_buffer[..decoded].copy_from_slice(&decoded_chunk[..decoded]);
let decoded = self
.engine
.internal_decode(
&b64_bytes[..to_decode],
&mut self.decoded_buffer,
self.engine.internal_decoded_len_estimate(to_decode),
)
.map_err(map_error_offset(self.total_b64_decoded))?;

self.total_b64_decoded += to_decode;
if !short { self.inner.consume(to_decode); }

self.decoded_offset = 0;
self.decoded_len = decoded;

// can be less than 3 on last block due to padding
debug_assert!(decoded <= 3);


self.flush_decoded_buf(buf)
} else {
let b64_bytes_that_can_decode_into_buf = (buf.len() / DECODED_CHUNK_SIZE)
Expand All @@ -278,17 +231,28 @@ impl<'e, E: Engine, R: io::Read> io::Read for DecoderReader<'e, E, R> {
debug_assert!(b64_bytes_that_can_decode_into_buf >= BASE64_CHUNK_SIZE);

let b64_bytes_available_to_decode = if at_eof {
self.b64_len
b64_bytes.len()
} else {
// only use complete chunks
self.b64_len - self.b64_len % 4
b64_bytes.len() - b64_bytes.len() % 4
};

let actual_decode_len = cmp::min(
b64_bytes_that_can_decode_into_buf,
b64_bytes_available_to_decode,
);
self.decode_to_buf(actual_decode_len, buf)
let decoded = self
.engine
.internal_decode(
&b64_bytes[..actual_decode_len],
buf,
self.engine.internal_decoded_len_estimate(actual_decode_len),
)
.map_err(map_error_offset(self.total_b64_decoded))?;

self.total_b64_decoded += actual_decode_len;
if !short { self.inner.consume(actual_decode_len); }
Ok(decoded)
}
}
}
Expand Down
22 changes: 19 additions & 3 deletions src/read/decoder_tests.rs
Expand Up @@ -6,13 +6,15 @@ use std::{

use rand::{Rng as _, RngCore as _};

use super::decoder::{DecoderReader, BUF_SIZE};
use super::decoder::DecoderReader;
use crate::{
engine::{general_purpose::STANDARD, Engine, GeneralPurpose},
tests::{random_alphabet, random_config, random_engine},
DecodeError,
};

const BUF_SIZE: usize = 1024;

#[test]
fn simple() {
let tests: &[(&[u8], &[u8])] = &[
Expand Down Expand Up @@ -113,7 +115,6 @@ fn handles_short_read_from_delegate() {
};

let mut decoder = DecoderReader::new(&mut short_reader, &engine);

let decoded_len = decoder.read_to_end(&mut decoded).unwrap();
assert_eq!(size, decoded_len);
assert_eq!(&bytes[..], &decoded[..]);
Expand Down Expand Up @@ -341,6 +342,21 @@ impl<'a, 'b, R: io::Read, N: rand::Rng> io::Read for RandomShortRead<'a, 'b, R,
// avoid 0 since it means EOF for non-empty buffers
let effective_len = cmp::min(self.rng.gen_range(1..20), buf.len());

self.delegate.read(&mut buf[..effective_len])
self.delegate.read(&mut buf [..effective_len])
}
}

impl<'a, 'b, R: io::BufRead, N: rand::Rng> io::BufRead for RandomShortRead<'a, 'b, R, N> {
fn fill_buf(&mut self) -> Result<&[u8], io::Error> {
self.delegate.fill_buf().map(|buf| {
// avoid 0 since it means EOF for non-empty buffers
let effective_len = cmp::min(self.rng.gen_range(1..20), buf.len());

&buf[..effective_len]
})
}

fn consume(&mut self, amount: usize) {
self.delegate.consume(amount)
}
}