Skip to content

Commit

Permalink
Add support for readers that implement Seek (#218)
Browse files Browse the repository at this point in the history
`Archive::entries_with_seek` can be used to get an iterator over
entries for a reader that implements `Seek`.
  • Loading branch information
fermeise committed Oct 6, 2021
1 parent 60c6bd8 commit 396c55d
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 12 deletions.
50 changes: 45 additions & 5 deletions src/archive.rs
@@ -1,5 +1,6 @@
use std::cell::{Cell, RefCell};
use std::cmp;
use std::convert::TryFrom;
use std::fs;
use std::io;
use std::io::prelude::*;
Expand Down Expand Up @@ -35,8 +36,12 @@ pub struct Entries<'a, R: 'a + Read> {
_ignored: marker::PhantomData<&'a Archive<R>>,
}

trait SeekRead: Read + Seek {}
impl<R: Read + Seek> SeekRead for R {}

struct EntriesFields<'a> {
archive: &'a Archive<dyn Read + 'a>,
seekable_archive: Option<&'a Archive<dyn SeekRead + 'a>>,
next: u64,
done: bool,
raw: bool,
Expand Down Expand Up @@ -71,7 +76,7 @@ impl<R: Read> Archive<R> {
/// corrupted.
pub fn entries(&mut self) -> io::Result<Entries<R>> {
let me: &mut Archive<dyn Read> = self;
me._entries().map(|fields| Entries {
me._entries(None).map(|fields| Entries {
fields: fields,
_ignored: marker::PhantomData,
})
Expand Down Expand Up @@ -143,8 +148,29 @@ impl<R: Read> Archive<R> {
}
}

impl<R: Seek + Read> Archive<R> {
/// Construct an iterator over the entries in this archive for a seekable
/// reader. Seek will be used to efficiently skip over file contents.
///
/// Note that care must be taken to consider each entry within an archive in
/// sequence. If entries are processed out of sequence (from what the
/// iterator returns), then the contents read for each entry may be
/// corrupted.
pub fn entries_with_seek(&mut self) -> io::Result<Entries<R>> {
let me: &Archive<dyn Read> = self;
let me_seekable: &Archive<dyn SeekRead> = self;
me._entries(Some(me_seekable)).map(|fields| Entries {
fields: fields,
_ignored: marker::PhantomData,
})
}
}

impl<'a> Archive<dyn Read + 'a> {
fn _entries(&mut self) -> io::Result<EntriesFields> {
fn _entries(
&'a self,
seekable_archive: Option<&'a Archive<dyn SeekRead + 'a>>,
) -> io::Result<EntriesFields> {
if self.inner.pos.get() != 0 {
return Err(other(
"cannot call entries unless archive is at \
Expand All @@ -153,13 +179,14 @@ impl<'a> Archive<dyn Read + 'a> {
}
Ok(EntriesFields {
archive: self,
seekable_archive,
done: false,
next: 0,
raw: false,
})
}

fn _unpack(&mut self, dst: &Path) -> io::Result<()> {
fn _unpack(&'a mut self, dst: &Path) -> io::Result<()> {
if dst.symlink_metadata().is_err() {
fs::create_dir_all(&dst)
.map_err(|e| TarError::new(&format!("failed to create `{}`", dst.display()), e))?;
Expand All @@ -176,7 +203,7 @@ impl<'a> Archive<dyn Read + 'a> {
// descendants), to ensure that directory permissions do not interfer with descendant
// extraction.
let mut directories = Vec::new();
for entry in self._entries()? {
for entry in self._entries(None)? {
let mut file = entry.map_err(|e| TarError::new("failed to iterate over archive", e))?;
if file.header().entry_type() == crate::EntryType::Directory {
directories.push(file);
Expand Down Expand Up @@ -241,7 +268,7 @@ impl<'a> EntriesFields<'a> {
loop {
// Seek to the start of the next header in the archive
let delta = self.next - self.archive.inner.pos.get();
self.archive.skip(delta)?;
self.skip(delta)?;

// EOF is an indicator that we are at the end of the archive.
if !try_read_all(&mut &self.archive.inner, header.as_mut_bytes())? {
Expand Down Expand Up @@ -476,6 +503,19 @@ impl<'a> EntriesFields<'a> {
}
Ok(())
}

fn skip(&mut self, amt: u64) -> io::Result<()> {
if let Some(seekable_archive) = self.seekable_archive {
let pos = io::SeekFrom::Current(
i64::try_from(amt).map_err(|_| other("seek position out of bounds"))?,
);
let i = seekable_archive.inner.obj.borrow_mut().seek(pos)?;
seekable_archive.inner.pos.set(i);
Ok(())
} else {
self.archive.skip(amt)
}
}
}

impl<'a> Iterator for EntriesFields<'a> {
Expand Down
71 changes: 64 additions & 7 deletions tests/all.rs
Expand Up @@ -11,7 +11,7 @@ use std::iter::repeat;
use std::path::{Path, PathBuf};

use filetime::FileTime;
use tar::{Archive, Builder, EntryType, Header, HeaderMode};
use tar::{Archive, Builder, Entries, EntryType, Header, HeaderMode};
use tempfile::{Builder as TempBuilder, TempDir};

macro_rules! t {
Expand Down Expand Up @@ -203,11 +203,7 @@ fn large_filename() {
assert!(entries.next().is_none());
}

#[test]
fn reading_entries() {
let rdr = Cursor::new(tar!("reading_files.tar"));
let mut ar = Archive::new(rdr);
let mut entries = t!(ar.entries());
fn reading_entries_common<R: Read>(mut entries: Entries<R>) {
let mut a = t!(entries.next().unwrap());
assert_eq!(&*a.header().path_bytes(), b"a");
let mut s = String::new();
Expand All @@ -216,15 +212,76 @@ fn reading_entries() {
s.truncate(0);
t!(a.read_to_string(&mut s));
assert_eq!(s, "");
let mut b = t!(entries.next().unwrap());

let mut b = t!(entries.next().unwrap());
assert_eq!(&*b.header().path_bytes(), b"b");
s.truncate(0);
t!(b.read_to_string(&mut s));
assert_eq!(s, "b\nb\nb\nb\nb\nb\nb\nb\nb\nb\nb\n");
assert!(entries.next().is_none());
}

#[test]
fn reading_entries() {
let rdr = Cursor::new(tar!("reading_files.tar"));
let mut ar = Archive::new(rdr);
reading_entries_common(t!(ar.entries()));
}

#[test]
fn reading_entries_with_seek() {
let rdr = Cursor::new(tar!("reading_files.tar"));
let mut ar = Archive::new(rdr);
reading_entries_common(t!(ar.entries_with_seek()));
}

struct LoggingReader<R> {
inner: R,
read_bytes: u64,
}

impl<R> LoggingReader<R> {
fn new(reader: R) -> LoggingReader<R> {
LoggingReader {
inner: reader,
read_bytes: 0,
}
}
}

impl<T: Read> Read for LoggingReader<T> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.inner.read(buf).map(|i| {
self.read_bytes += i as u64;
i
})
}
}

impl<T: Seek> Seek for LoggingReader<T> {
fn seek(&mut self, pos: io::SeekFrom) -> io::Result<u64> {
self.inner.seek(pos)
}
}

#[test]
fn skipping_entries_with_seek() {
let mut reader = LoggingReader::new(Cursor::new(tar!("reading_files.tar")));
let mut ar_reader = Archive::new(&mut reader);
let files: Vec<_> = t!(ar_reader.entries())
.map(|entry| entry.unwrap().path().unwrap().to_path_buf())
.collect();

let mut seekable_reader = LoggingReader::new(Cursor::new(tar!("reading_files.tar")));
let mut ar_seekable_reader = Archive::new(&mut seekable_reader);
let files_seekable: Vec<_> = t!(ar_seekable_reader.entries_with_seek())
.map(|entry| entry.unwrap().path().unwrap().to_path_buf())
.collect();

assert!(files == files_seekable);
assert!(seekable_reader.read_bytes < reader.read_bytes);
}

fn check_dirtree(td: &TempDir) {
let dir_a = td.path().join("a");
let dir_b = td.path().join("a/b");
Expand Down

0 comments on commit 396c55d

Please sign in to comment.