Skip to content

Commit

Permalink
Special error types for BinDecoder and read_inner
Browse files Browse the repository at this point in the history
  • Loading branch information
saethlin committed Mar 5, 2021
1 parent 8a63948 commit d147588
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 95 deletions.
7 changes: 7 additions & 0 deletions crates/proto/src/error.rs
Expand Up @@ -26,6 +26,7 @@ use ring::error::Unspecified;
use thiserror::Error;

use crate::rr::{Name, RecordType};
use crate::serialize::binary::DecodeError;

#[cfg(feature = "backtrace")]
lazy_static! {
Expand Down Expand Up @@ -269,6 +270,12 @@ impl From<ProtoErrorKind> for ProtoError {
}
}

impl From<DecodeError> for ProtoError {
fn from(err: DecodeError) -> ProtoError {
ProtoErrorKind::Msg(err.to_string()).into()
}
}

impl From<&'static str> for ProtoError {
fn from(msg: &'static str) -> ProtoError {
ProtoErrorKind::Message(msg).into()
Expand Down
24 changes: 7 additions & 17 deletions crates/proto/src/rr/domain/name.rs
Expand Up @@ -1128,7 +1128,7 @@ fn read_inner(
label_data: &mut TinyVec<[u8; 32]>,
label_ends: &mut TinyVec<[u8; 16]>,
max_idx: Option<usize>,
) -> ProtoResult<()> {
) -> Result<(), DecodeError> {
let mut state: LabelParseState = LabelParseState::LabelLengthOrPointer;
let name_start = decoder.index();

Expand All @@ -1141,18 +1141,14 @@ fn read_inner(
// this protects against overlapping labels
if let Some(max_idx) = max_idx {
if decoder.index() >= max_idx {
return Err(ProtoErrorKind::LabelOverlapsWithOther {
label: name_start,
other: max_idx,
}
.into());
return Err(DecodeError::LabelOverlapsWithOther);
}
}

// enforce max length of name
let cur_len = label_data.len() + label_ends.len();
if cur_len > 255 {
return Err(ProtoErrorKind::DomainNameTooLong(cur_len).into());
return Err(DecodeError::DomainNameTooLong);
}

state = match state {
Expand All @@ -1165,15 +1161,15 @@ fn read_inner(
Some(0) | None => LabelParseState::Root,
Some(byte) if byte & 0b1100_0000 == 0b1100_0000 => LabelParseState::Pointer,
Some(byte) if byte & 0b1100_0000 == 0b0000_0000 => LabelParseState::Label,
Some(byte) => return Err(ProtoErrorKind::UnrecognizedLabelCode(byte).into()),
Some(_) => return Err(DecodeError::UnrecognizedLabelCode),
}
}
// labels must have a maximum length of 63
LabelParseState::Label => {
let label = decoder
.read_character_data_max(Some(63))?
.read_character_data()?
.verify_unwrap(|l| l.len() <= 63)
.map_err(|_| ProtoError::from("label exceeds maximum length of 63"))?;
.map_err(|_| DecodeError::LabelTooLong)?;

label_data.extend_from_slice(label);
label_ends.push(label_data.len() as u8);
Expand Down Expand Up @@ -1203,7 +1199,6 @@ fn read_inner(
// domain header). A zero offset specifies the first byte of the ID field,
// etc.
LabelParseState::Pointer => {
let pointer_location = decoder.index();
let location = decoder
.read_u16()?
.map(|u| {
Expand All @@ -1214,12 +1209,7 @@ fn read_inner(
// all labels must appear "prior" to this Name
(*ptr as usize) < name_start
})
.map_err(|e| {
ProtoError::from(ProtoErrorKind::PointerNotPriorToLabel {
idx: pointer_location,
ptr: e,
})
})?;
.map_err(|_| DecodeError::PointerNotPriorToLabel)?;

let mut pointer = decoder.clone(location);
read_inner(&mut pointer, label_data, label_ends, Some(name_start))?;
Expand Down
4 changes: 2 additions & 2 deletions crates/proto/src/rr/record_type.rs
Expand Up @@ -249,12 +249,12 @@ impl BinEncodable for RecordType {

impl<'r> BinDecodable<'r> for RecordType {
fn read(decoder: &mut BinDecoder<'_>) -> ProtoResult<Self> {
decoder
Ok(decoder
.read_u16()
.map(
Restrict::unverified, /*RecordType is safe with any u16*/
)
.map(Self::from)
.map(Self::from)?)
}
}

Expand Down
6 changes: 3 additions & 3 deletions crates/proto/src/serialize/binary/bin_tests.rs
Expand Up @@ -58,7 +58,7 @@ fn get_u16_data() -> Vec<(u16, Vec<u8>)> {
#[test]
fn read_u16() {
test_read_data_set(get_u16_data(), |mut d| {
d.read_u16().map(Restrict::unverified)
d.read_u16().map(Restrict::unverified).map_err(Into::into)
});
}

Expand All @@ -83,7 +83,7 @@ fn get_i32_data() -> Vec<(i32, Vec<u8>)> {
#[test]
fn read_i32() {
test_read_data_set(get_i32_data(), |mut d| {
d.read_i32().map(Restrict::unverified)
d.read_i32().map(Restrict::unverified).map_err(Into::into)
});
}

Expand All @@ -109,7 +109,7 @@ fn get_u32_data() -> Vec<(u32, Vec<u8>)> {
#[test]
fn read_u32() {
test_read_data_set(get_u32_data(), |mut d| {
d.read_u32().map(Restrict::unverified)
d.read_u32().map(Restrict::unverified).map_err(Into::into)
});
}

Expand Down
146 changes: 77 additions & 69 deletions crates/proto/src/serialize/binary/decoder.rs
Expand Up @@ -14,8 +14,8 @@
* limitations under the License.
*/

use crate::error::{ProtoError, ProtoErrorKind, ProtoResult};
use crate::serialize::binary::Restrict;
use thiserror::Error;

/// This is non-destructive to the inner buffer, b/c for pointer types we need to perform a reverse
/// seek to lookup names
Expand All @@ -26,7 +26,42 @@ use crate::serialize::binary::Restrict;
/// binary DNS protocols.
pub struct BinDecoder<'a> {
buffer: &'a [u8],
index: usize,
remaining: &'a [u8],
}

pub type DecodeResult<T> = Result<T, DecodeError>;

/// And error that can occur deep in a decoder
/// This type is kept very small so that function that use it inline often
#[derive(Debug, Error)]
pub enum DecodeError {
/// Insufficient data in the buffer for a read operation
#[error("unexpected end of input reached")]
InsufficientBytes,

/// slice_from was called with an invalid index
#[error("index antecedes upper bound")]
InvalidPreviousIndex,

/// Pointer points to an index within or after the current label
#[error("pointer does not point to an index before the current label")]
PointerNotPriorToLabel,

/// Unreachable
#[error("label exceeds length of 63")]
LabelTooLong,

/// Invalid code for a label
#[error("the start of a label must be a pointer or a length less than 64")]
UnrecognizedLabelCode,

/// Domain names cannot exceed 255
#[error("domain name exceeds maximum length")]
DomainNameTooLong,

/// Labels may not overlap
#[error("label overlaps with other")]
LabelOverlapsWithOther,
}

impl<'a> BinDecoder<'a> {
Expand All @@ -35,19 +70,22 @@ impl<'a> BinDecoder<'a> {
/// # Arguments
///
/// * `buffer` - buffer from which all data will be read
#[inline]
pub fn new(buffer: &'a [u8]) -> Self {
BinDecoder { buffer, index: 0 }
BinDecoder {
buffer,
remaining: buffer,
}
}

/// Pop one byte from the buffer
pub fn pop(&mut self) -> ProtoResult<Restrict<u8>> {
if self.index < self.buffer.len() {
let byte = self.buffer[self.index];
self.index += 1;
Ok(Restrict::new(byte))
} else {
Err("unexpected end of input reached".into())
pub fn pop(&mut self) -> DecodeResult<Restrict<u8>> {
if self.remaining.is_empty() {
return Err(DecodeError::InsufficientBytes);
}
let (first, remaining) = self.remaining.split_at(1);
self.remaining = remaining;
Ok(Restrict::new(first[0]))
}

/// Returns the number of bytes in the buffer
Expand All @@ -62,7 +100,7 @@ impl<'a> BinDecoder<'a> {
/// assert_eq!(decoder.len(), 1);
/// ```
pub fn len(&self) -> usize {
self.buffer.len().saturating_sub(self.index)
self.remaining.len()
}

/// Returns `true` if the buffer is empty
Expand All @@ -72,24 +110,20 @@ impl<'a> BinDecoder<'a> {

/// Peed one byte forward, without moving the current index forward
pub fn peek(&self) -> Option<Restrict<u8>> {
if self.index < self.buffer.len() {
Some(Restrict::new(self.buffer[self.index]))
} else {
None
}
Some(Restrict::new(*self.remaining.get(0)?))
}

/// Returns the current index in the buffer
pub fn index(&self) -> usize {
self.index
self.buffer.len() - self.remaining.len()
}

/// This is a pretty efficient clone, as the buffer is never cloned, and only the index is set
/// to the value passed in
pub fn clone(&self, index_at: u16) -> BinDecoder<'a> {
BinDecoder {
buffer: self.buffer,
index: index_at as usize,
remaining: &self.buffer[index_at as usize..],
}
}

Expand All @@ -105,32 +139,8 @@ impl<'a> BinDecoder<'a> {
/// # Returns
///
/// A String version of the character data
pub fn read_character_data(&mut self) -> ProtoResult<Restrict<&[u8]>> {
self.read_character_data_max(None)
}

/// Reads to a maximum length of data, returns an error if this is exceeded
pub fn read_character_data_max(
&mut self,
max_len: Option<usize>,
) -> ProtoResult<Restrict<&[u8]>> {
let length = self
.pop()?
.map(|u| u as usize)
.verify_unwrap(|length| {
if let Some(max_len) = max_len {
*length <= max_len
} else {
true
}
})
.map_err(|length| {
ProtoError::from(ProtoErrorKind::CharacterDataTooLong {
max: max_len.unwrap_or_default(),
len: length,
})
})?;

pub fn read_character_data(&mut self) -> DecodeResult<Restrict<&[u8]>> {
let length = self.pop()?.unverified() as usize;
self.read_slice(length)
}

Expand All @@ -143,7 +153,7 @@ impl<'a> BinDecoder<'a> {
/// # Returns
///
/// The Vec of the specified length, otherwise an error
pub fn read_vec(&mut self, len: usize) -> ProtoResult<Restrict<Vec<u8>>> {
pub fn read_vec(&mut self, len: usize) -> DecodeResult<Restrict<Vec<u8>>> {
self.read_slice(len).map(|s| s.map(ToOwned::to_owned))
}

Expand All @@ -156,30 +166,26 @@ impl<'a> BinDecoder<'a> {
/// # Returns
///
/// The slice of the specified length, otherwise an error
pub fn read_slice(&mut self, len: usize) -> ProtoResult<Restrict<&'a [u8]>> {
let end = self
.index
.checked_add(len)
.ok_or_else(|| ProtoError::from("invalid length for slice"))?;
if end > self.buffer.len() {
return Err("buffer exhausted".into());
pub fn read_slice(&mut self, len: usize) -> DecodeResult<Restrict<&'a [u8]>> {
if len > self.remaining.len() {
return Err(DecodeError::InsufficientBytes);
}
let slice: &'a [u8] = &self.buffer[self.index..end];
self.index += len;
Ok(Restrict::new(slice))
let (read, remaining) = self.remaining.split_at(len);
self.remaining = remaining;
Ok(Restrict::new(read))
}

/// Reads a slice from a previous index to the current
pub fn slice_from(&self, index: usize) -> ProtoResult<&'a [u8]> {
if index > self.index {
return Err("index antecedes upper bound".into());
pub fn slice_from(&self, index: usize) -> DecodeResult<&'a [u8]> {
if index > self.index() {
return Err(DecodeError::InvalidPreviousIndex);
}

Ok(&self.buffer[index..self.index])
Ok(&self.buffer[index..self.index()])
}

/// Reads a byte from the buffer, equivalent to `Self::pop()`
pub fn read_u8(&mut self) -> ProtoResult<Restrict<u8>> {
pub fn read_u8(&mut self) -> DecodeResult<Restrict<u8>> {
self.pop()
}

Expand All @@ -191,7 +197,7 @@ impl<'a> BinDecoder<'a> {
/// # Return
///
/// Return the u16 from the buffer
pub fn read_u16(&mut self) -> ProtoResult<Restrict<u16>> {
pub fn read_u16(&mut self) -> DecodeResult<Restrict<u16>> {
Ok(self
.read_slice(2)?
.map(|s| u16::from_be_bytes([s[0], s[1]])))
Expand All @@ -205,10 +211,11 @@ impl<'a> BinDecoder<'a> {
/// # Return
///
/// Return the i32 from the buffer
pub fn read_i32(&mut self) -> ProtoResult<Restrict<i32>> {
Ok(self
.read_slice(4)?
.map(|s| i32::from_be_bytes([s[0], s[1], s[2], s[3]])))
pub fn read_i32(&mut self) -> DecodeResult<Restrict<i32>> {
Ok(self.read_slice(4)?.map(|s| {
assert!(s.len() == 4);
i32::from_be_bytes([s[0], s[1], s[2], s[3]])
}))
}

/// Reads the next four bytes into u32.
Expand All @@ -219,10 +226,11 @@ impl<'a> BinDecoder<'a> {
/// # Return
///
/// Return the u32 from the buffer
pub fn read_u32(&mut self) -> ProtoResult<Restrict<u32>> {
Ok(self
.read_slice(4)?
.map(|s| u32::from_be_bytes([s[0], s[1], s[2], s[3]])))
pub fn read_u32(&mut self) -> DecodeResult<Restrict<u32>> {
Ok(self.read_slice(4)?.map(|s| {
assert!(s.len() == 4);
u32::from_be_bytes([s[0], s[1], s[2], s[3]])
}))
}
}

Expand Down

0 comments on commit d147588

Please sign in to comment.