From 24015e8ed17a4d880311247cba2b58b53deaa890 Mon Sep 17 00:00:00 2001 From: Riccardo Casatta Date: Tue, 20 Apr 2021 13:59:37 +0200 Subject: [PATCH] Introduce CappedRead --- src/blockdata/transaction.rs | 3 +- src/consensus/encode.rs | 53 +++++++++++++++++++++++++++++++++++- src/consensus/mod.rs | 2 +- src/internal_macros.rs | 1 + src/util/psbt/map/global.rs | 3 +- src/util/psbt/mod.rs | 4 ++- 6 files changed, 61 insertions(+), 5 deletions(-) diff --git a/src/blockdata/transaction.rs b/src/blockdata/transaction.rs index f45b8908ec..4b7885eddc 100644 --- a/src/blockdata/transaction.rs +++ b/src/blockdata/transaction.rs @@ -33,7 +33,7 @@ use util::endian; use blockdata::constants::WITNESS_SCALE_FACTOR; #[cfg(feature="bitcoinconsensus")] use blockdata::script; use blockdata::script::Script; -use consensus::{encode, Decodable, Encodable}; +use consensus::{encode, Decodable, Encodable, CappedRead}; use hash_types::{SigHash, Txid, Wtxid}; use VarInt; @@ -567,6 +567,7 @@ impl Encodable for Transaction { impl Decodable for Transaction { fn consensus_decode(mut d: D) -> Result { + let mut d = CappedRead::new(&mut d); let version = i32::consensus_decode(&mut d)?; let input = Vec::::consensus_decode(&mut d)?; // segwit diff --git a/src/consensus/encode.rs b/src/consensus/encode.rs index 54c8366f90..602aad64e1 100644 --- a/src/consensus/encode.rs +++ b/src/consensus/encode.rs @@ -547,6 +547,38 @@ impl Encodable for [u16; 8] { } } +/// Read wrapper capped to `remaining` bytes (default: [MAX_VEC_SIZE]) +pub struct CappedRead<'r> { + reader: &'r mut dyn io::Read, + remaining: usize, +} + +impl<'r> CappedRead<'r> { + /// New [CappedRead] from a Read type capped to [MAX_VEC_SIZE] bytes + pub fn new(reader: &'r mut R) -> Self { + Self::with_cap(reader, MAX_VEC_SIZE) + } + + /// New [CappedRead] from a Read type with custom max bytes + pub fn with_cap(reader: &'r mut R, cap: usize) -> Self { + CappedRead { + reader: reader as &'r mut dyn io::Read, + remaining: cap, + } + } +} + +impl<'r> io::Read for CappedRead<'r> { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + let bytes = self.reader.read(buf)?; + self.remaining = self.remaining + .checked_sub(bytes) + .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "cap read exceeded"))?; + + Ok(bytes) + } +} + // Vectors macro_rules! impl_vec { ($type: ty) => { @@ -575,9 +607,11 @@ macro_rules! impl_vec { return Err(self::Error::OversizedVectorAllocation { requested: byte_size, max: MAX_VEC_SIZE }) } let mut ret = Vec::with_capacity(len as usize); + let mut capped_reader = CappedRead::new(&mut d); for _ in 0..len { - ret.push(Decodable::consensus_decode(&mut d)?); + ret.push(Decodable::consensus_decode(&mut capped_reader)?); } + Ok(ret) } } @@ -764,6 +798,7 @@ mod tests { use secp256k1::rand::{thread_rng, Rng}; use network::message_blockdata::Inventory; use network::Address; + use consensus::encode::{CappedRead, MAX_VEC_SIZE}; #[test] fn serialize_int_test() { @@ -997,6 +1032,22 @@ mod tests { assert_eq!(cd.ok(), Some(CheckedData(vec![1u8, 2, 3, 4, 5]))); } + #[test] + fn capped_reader_test() { + let mut cursor = io::Cursor::new(vec![1u8]); + let mut count_read = CappedRead::new(&mut cursor); + let v = VarInt::consensus_decode(&mut count_read).unwrap(); + assert_eq!(count_read.remaining, MAX_VEC_SIZE - 1); + assert_eq!(v.0, 1u64); + + let witness = vec![vec![0u8; 3_999_999]; 2]; + let ser = serialize(&witness); + let mut cursor = io::Cursor::new(ser); + let mut count_read = CappedRead::new(&mut cursor); + let err = Vec::>::consensus_decode(&mut count_read); + assert!(err.is_err()); + } + #[test] fn serialization_round_trips() { macro_rules! round_trip { diff --git a/src/consensus/mod.rs b/src/consensus/mod.rs index 1e302333a2..67f73f70e9 100644 --- a/src/consensus/mod.rs +++ b/src/consensus/mod.rs @@ -21,6 +21,6 @@ pub mod encode; pub mod params; -pub use self::encode::{Encodable, Decodable, WriteExt, ReadExt}; +pub use self::encode::{Encodable, Decodable, WriteExt, ReadExt, CappedRead}; pub use self::encode::{serialize, deserialize, deserialize_partial}; pub use self::params::Params; diff --git a/src/internal_macros.rs b/src/internal_macros.rs index dc66ec602a..8a6a0d587e 100644 --- a/src/internal_macros.rs +++ b/src/internal_macros.rs @@ -35,6 +35,7 @@ macro_rules! impl_consensus_encoding { fn consensus_decode( mut d: D, ) -> Result<$thing, $crate::consensus::encode::Error> { + let mut d = $crate::consensus::CappedRead::new(&mut d); Ok($thing { $($field: $crate::consensus::Decodable::consensus_decode(&mut d)?),+ }) diff --git a/src/util/psbt/map/global.rs b/src/util/psbt/map/global.rs index 7369afa5fb..be316798fb 100644 --- a/src/util/psbt/map/global.rs +++ b/src/util/psbt/map/global.rs @@ -18,7 +18,7 @@ use std::io::{self, Cursor, Read}; use std::cmp; use blockdata::transaction::Transaction; -use consensus::{encode, Encodable, Decodable}; +use consensus::{encode, Encodable, Decodable, CappedRead}; use util::psbt::map::Map; use util::psbt::raw; use util::psbt; @@ -229,6 +229,7 @@ impl_psbtmap_consensus_encoding!(Global); impl Decodable for Global { fn consensus_decode(mut d: D) -> Result { + let mut d = CappedRead::new(&mut d); let mut tx: Option = None; let mut version: Option = None; diff --git a/src/util/psbt/mod.rs b/src/util/psbt/mod.rs index 3cd2ef894c..1e9d27ed50 100644 --- a/src/util/psbt/mod.rs +++ b/src/util/psbt/mod.rs @@ -20,7 +20,7 @@ use blockdata::script::Script; use blockdata::transaction::Transaction; -use consensus::{encode, Encodable, Decodable}; +use consensus::{encode, Encodable, Decodable, CappedRead}; use std::io; @@ -163,6 +163,8 @@ impl Encodable for PartiallySignedTransaction { impl Decodable for PartiallySignedTransaction { fn consensus_decode(mut d: D) -> Result { + let mut d = CappedRead::new(&mut d); + let magic: [u8; 4] = Decodable::consensus_decode(&mut d)?; if *b"psbt" != magic {