Skip to content

Commit

Permalink
Special error types for BinDecoder and read_inner
Browse files Browse the repository at this point in the history
  • Loading branch information
saethlin committed Mar 10, 2021
1 parent 7e55a50 commit a1f94c2
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 66 deletions.
7 changes: 7 additions & 0 deletions crates/proto/src/error.rs
Expand Up @@ -26,6 +26,7 @@ use ring::error::Unspecified;
use thiserror::Error;

use crate::rr::{Name, RecordType};
use crate::serialize::binary::DecodeError;

#[cfg(feature = "backtrace")]
lazy_static! {
Expand Down Expand Up @@ -269,6 +270,12 @@ impl From<ProtoErrorKind> for ProtoError {
}
}

impl From<DecodeError> for ProtoError {
fn from(err: DecodeError) -> ProtoError {
ProtoErrorKind::Msg(err.to_string()).into()
}
}

impl From<&'static str> for ProtoError {
fn from(msg: &'static str) -> ProtoError {
ProtoErrorKind::Message(msg).into()
Expand Down
24 changes: 7 additions & 17 deletions crates/proto/src/rr/domain/name.rs
Expand Up @@ -1124,7 +1124,7 @@ fn read_inner(
label_data: &mut TinyVec<[u8; 32]>,
label_ends: &mut TinyVec<[u8; 24]>,
max_idx: Option<usize>,
) -> ProtoResult<()> {
) -> Result<(), DecodeError> {
let mut state: LabelParseState = LabelParseState::LabelLengthOrPointer;
let name_start = decoder.index();

Expand All @@ -1137,18 +1137,14 @@ fn read_inner(
// this protects against overlapping labels
if let Some(max_idx) = max_idx {
if decoder.index() >= max_idx {
return Err(ProtoErrorKind::LabelOverlapsWithOther {
label: name_start,
other: max_idx,
}
.into());
return Err(DecodeError::LabelOverlapsWithOther);
}
}

// enforce max length of name
let cur_len = label_data.len() + label_ends.len();
if cur_len > 255 {
return Err(ProtoErrorKind::DomainNameTooLong(cur_len).into());
return Err(DecodeError::DomainNameTooLong);
}

state = match state {
Expand All @@ -1161,15 +1157,15 @@ fn read_inner(
Some(0) | None => LabelParseState::Root,
Some(byte) if byte & 0b1100_0000 == 0b1100_0000 => LabelParseState::Pointer,
Some(byte) if byte & 0b1100_0000 == 0b0000_0000 => LabelParseState::Label,
Some(byte) => return Err(ProtoErrorKind::UnrecognizedLabelCode(byte).into()),
Some(_) => return Err(DecodeError::UnrecognizedLabelCode),
}
}
// labels must have a maximum length of 63
LabelParseState::Label => {
let label = decoder
.read_character_data_max(Some(63))?
.read_character_data()?
.verify_unwrap(|l| l.len() <= 63)
.map_err(|_| ProtoError::from("label exceeds maximum length of 63"))?;
.map_err(|_| DecodeError::LabelTooLong)?;

label_data.extend_from_slice(label);
label_ends.push(label_data.len() as u8);
Expand Down Expand Up @@ -1199,7 +1195,6 @@ fn read_inner(
// domain header). A zero offset specifies the first byte of the ID field,
// etc.
LabelParseState::Pointer => {
let pointer_location = decoder.index();
let location = decoder
.read_u16()?
.map(|u| {
Expand All @@ -1210,12 +1205,7 @@ fn read_inner(
// all labels must appear "prior" to this Name
(*ptr as usize) < name_start
})
.map_err(|e| {
ProtoError::from(ProtoErrorKind::PointerNotPriorToLabel {
idx: pointer_location,
ptr: e,
})
})?;
.map_err(|_| DecodeError::PointerNotPriorToLabel)?;

let mut pointer = decoder.clone(location);
read_inner(&mut pointer, label_data, label_ends, Some(name_start))?;
Expand Down
4 changes: 2 additions & 2 deletions crates/proto/src/rr/record_type.rs
Expand Up @@ -249,12 +249,12 @@ impl BinEncodable for RecordType {

impl<'r> BinDecodable<'r> for RecordType {
fn read(decoder: &mut BinDecoder<'_>) -> ProtoResult<Self> {
decoder
Ok(decoder
.read_u16()
.map(
Restrict::unverified, /*RecordType is safe with any u16*/
)
.map(Self::from)
.map(Self::from)?)
}
}

Expand Down
6 changes: 3 additions & 3 deletions crates/proto/src/serialize/binary/bin_tests.rs
Expand Up @@ -58,7 +58,7 @@ fn get_u16_data() -> Vec<(u16, Vec<u8>)> {
#[test]
fn read_u16() {
test_read_data_set(get_u16_data(), |mut d| {
d.read_u16().map(Restrict::unverified)
d.read_u16().map(Restrict::unverified).map_err(Into::into)
});
}

Expand All @@ -83,7 +83,7 @@ fn get_i32_data() -> Vec<(i32, Vec<u8>)> {
#[test]
fn read_i32() {
test_read_data_set(get_i32_data(), |mut d| {
d.read_i32().map(Restrict::unverified)
d.read_i32().map(Restrict::unverified).map_err(Into::into)
});
}

Expand All @@ -109,7 +109,7 @@ fn get_u32_data() -> Vec<(u32, Vec<u8>)> {
#[test]
fn read_u32() {
test_read_data_set(get_u32_data(), |mut d| {
d.read_u32().map(Restrict::unverified)
d.read_u32().map(Restrict::unverified).map_err(Into::into)
});
}

Expand Down
92 changes: 52 additions & 40 deletions crates/proto/src/serialize/binary/decoder.rs
Expand Up @@ -14,8 +14,8 @@
* limitations under the License.
*/

use crate::error::{ProtoError, ProtoErrorKind, ProtoResult};
use crate::serialize::binary::Restrict;
use thiserror::Error;

/// This is non-destructive to the inner buffer, b/c for pointer types we need to perform a reverse
/// seek to lookup names
Expand All @@ -25,8 +25,43 @@ use crate::serialize::binary::Restrict;
/// this is a simpler implementation without the cruft, at least for serializing to/from the
/// binary DNS protocols.
pub struct BinDecoder<'a> {
buffer: &'a [u8],
remaining: &'a [u8],
buffer: &'a [u8], // The entire original buffer
remaining: &'a [u8], // The unread section of the original buffer, so that reads do not cause a bounds check at the current seek offset
}

pub type DecodeResult<T> = Result<T, DecodeError>;

/// An error that can occur deep in a decoder
/// This type is kept very small so that function that use it inline often
#[derive(Debug, Error)]
pub enum DecodeError {
/// Insufficient data in the buffer for a read operation
#[error("unexpected end of input reached")]
InsufficientBytes,

/// slice_from was called with an invalid index
#[error("index antecedes upper bound")]
InvalidPreviousIndex,

/// Pointer points to an index within or after the current label
#[error("pointer does not point to an index before the current label")]
PointerNotPriorToLabel,

/// Unreachable
#[error("label exceeds length of 63")]
LabelTooLong,

/// Invalid code for a label
#[error("the start of a label must be a pointer or a length less than 64")]
UnrecognizedLabelCode,

/// Domain names cannot exceed 255
#[error("domain name exceeds maximum length")]
DomainNameTooLong,

/// Labels may not overlap
#[error("label overlaps with other")]
LabelOverlapsWithOther,
}

impl<'a> BinDecoder<'a> {
Expand All @@ -44,9 +79,9 @@ impl<'a> BinDecoder<'a> {
}

/// Pop one byte from the buffer
pub fn pop(&mut self) -> ProtoResult<Restrict<u8>> {
pub fn pop(&mut self) -> DecodeResult<Restrict<u8>> {
if self.remaining.is_empty() {
return Err("insufficient bytes".into());
return Err(DecodeError::InsufficientBytes);
}
let (first, remaining) = self.remaining.split_at(1);
self.remaining = remaining;
Expand All @@ -65,7 +100,7 @@ impl<'a> BinDecoder<'a> {
/// assert_eq!(decoder.len(), 1);
/// ```
pub fn len(&self) -> usize {
self.buffer.len().saturating_sub(self.index())
self.remaining.len()
}

/// Returns `true` if the buffer is empty
Expand Down Expand Up @@ -104,31 +139,8 @@ impl<'a> BinDecoder<'a> {
/// # Returns
///
/// A String version of the character data
pub fn read_character_data(&mut self) -> ProtoResult<Restrict<&[u8]>> {
self.read_character_data_max(None)
}

/// Reads to a maximum length of data, returns an error if this is exceeded
pub fn read_character_data_max(
&mut self,
max_len: Option<usize>,
) -> ProtoResult<Restrict<&[u8]>> {
let length = self
.pop()?
.map(|u| u as usize)
.verify_unwrap(|length| {
if let Some(max_len) = max_len {
*length <= max_len
} else {
true
}
})
.map_err(|length| {
ProtoError::from(ProtoErrorKind::CharacterDataTooLong {
max: max_len.unwrap_or_default(),
len: length,
})
})?;
pub fn read_character_data(&mut self) -> DecodeResult<Restrict<&[u8]>> {
let length = self.pop()?.unverified() as usize;
self.read_slice(length)
}

Expand All @@ -141,7 +153,7 @@ impl<'a> BinDecoder<'a> {
/// # Returns
///
/// The Vec of the specified length, otherwise an error
pub fn read_vec(&mut self, len: usize) -> ProtoResult<Restrict<Vec<u8>>> {
pub fn read_vec(&mut self, len: usize) -> DecodeResult<Restrict<Vec<u8>>> {
self.read_slice(len).map(|s| s.map(ToOwned::to_owned))
}

Expand All @@ -154,26 +166,26 @@ impl<'a> BinDecoder<'a> {
/// # Returns
///
/// The slice of the specified length, otherwise an error
pub fn read_slice(&mut self, len: usize) -> ProtoResult<Restrict<&'a [u8]>> {
pub fn read_slice(&mut self, len: usize) -> DecodeResult<Restrict<&'a [u8]>> {
if len > self.remaining.len() {
return Err("buffer exhausted".into());
return Err(DecodeError::InsufficientBytes);
}
let (read, remaining) = self.remaining.split_at(len);
self.remaining = remaining;
Ok(Restrict::new(read))
}

/// Reads a slice from a previous index to the current
pub fn slice_from(&self, index: usize) -> ProtoResult<&'a [u8]> {
pub fn slice_from(&self, index: usize) -> DecodeResult<&'a [u8]> {
if index > self.index() {
return Err("index antecedes upper bound".into());
return Err(DecodeError::InvalidPreviousIndex);
}

Ok(&self.buffer[index..self.index()])
}

/// Reads a byte from the buffer, equivalent to `Self::pop()`
pub fn read_u8(&mut self) -> ProtoResult<Restrict<u8>> {
pub fn read_u8(&mut self) -> DecodeResult<Restrict<u8>> {
self.pop()
}

Expand All @@ -185,7 +197,7 @@ impl<'a> BinDecoder<'a> {
/// # Return
///
/// Return the u16 from the buffer
pub fn read_u16(&mut self) -> ProtoResult<Restrict<u16>> {
pub fn read_u16(&mut self) -> DecodeResult<Restrict<u16>> {
Ok(self
.read_slice(2)?
.map(|s| u16::from_be_bytes([s[0], s[1]])))
Expand All @@ -199,7 +211,7 @@ impl<'a> BinDecoder<'a> {
/// # Return
///
/// Return the i32 from the buffer
pub fn read_i32(&mut self) -> ProtoResult<Restrict<i32>> {
pub fn read_i32(&mut self) -> DecodeResult<Restrict<i32>> {
Ok(self.read_slice(4)?.map(|s| {
assert!(s.len() == 4);
i32::from_be_bytes([s[0], s[1], s[2], s[3]])
Expand All @@ -214,7 +226,7 @@ impl<'a> BinDecoder<'a> {
/// # Return
///
/// Return the u32 from the buffer
pub fn read_u32(&mut self) -> ProtoResult<Restrict<u32>> {
pub fn read_u32(&mut self) -> DecodeResult<Restrict<u32>> {
Ok(self.read_slice(4)?.map(|s| {
assert!(s.len() == 4);
u32::from_be_bytes([s[0], s[1], s[2], s[3]])
Expand Down
17 changes: 13 additions & 4 deletions crates/proto/src/serialize/binary/mod.rs
Expand Up @@ -20,7 +20,7 @@ mod decoder;
mod encoder;
mod restrict;

pub use self::decoder::BinDecoder;
pub use self::decoder::{BinDecoder, DecodeError};
pub use self::encoder::BinEncoder;
pub use self::encoder::EncodeMode;
pub use self::restrict::{Restrict, RestrictedMath, Verified};
Expand Down Expand Up @@ -67,7 +67,10 @@ impl BinEncodable for u16 {

impl<'r> BinDecodable<'r> for u16 {
fn read(decoder: &mut BinDecoder<'_>) -> ProtoResult<Self> {
decoder.read_u16().map(Restrict::unverified)
decoder
.read_u16()
.map(Restrict::unverified)
.map_err(Into::into)
}
}

Expand All @@ -79,7 +82,10 @@ impl BinEncodable for i32 {

impl<'r> BinDecodable<'r> for i32 {
fn read(decoder: &mut BinDecoder<'_>) -> ProtoResult<i32> {
decoder.read_i32().map(Restrict::unverified)
decoder
.read_i32()
.map(Restrict::unverified)
.map_err(Into::into)
}
}

Expand All @@ -91,7 +97,10 @@ impl BinEncodable for u32 {

impl<'r> BinDecodable<'r> for u32 {
fn read(decoder: &mut BinDecoder<'_>) -> ProtoResult<Self> {
decoder.read_u32().map(Restrict::unverified)
decoder
.read_u32()
.map(Restrict::unverified)
.map_err(Into::into)
}
}

Expand Down

0 comments on commit a1f94c2

Please sign in to comment.