Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for readers that implement Seek (#218) #266

Merged
merged 1 commit into from Oct 6, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
50 changes: 45 additions & 5 deletions src/archive.rs
@@ -1,5 +1,6 @@
use std::cell::{Cell, RefCell};
use std::cmp;
use std::convert::TryFrom;
use std::fs;
use std::io;
use std::io::prelude::*;
Expand Down Expand Up @@ -35,8 +36,12 @@ pub struct Entries<'a, R: 'a + Read> {
_ignored: marker::PhantomData<&'a Archive<R>>,
}

trait SeekRead: Read + Seek {}
impl<R: Read + Seek> SeekRead for R {}

struct EntriesFields<'a> {
archive: &'a Archive<dyn Read + 'a>,
seekable_archive: Option<&'a Archive<dyn SeekRead + 'a>>,
next: u64,
done: bool,
raw: bool,
Expand Down Expand Up @@ -71,7 +76,7 @@ impl<R: Read> Archive<R> {
/// corrupted.
pub fn entries(&mut self) -> io::Result<Entries<R>> {
let me: &mut Archive<dyn Read> = self;
me._entries().map(|fields| Entries {
me._entries(None).map(|fields| Entries {
fields: fields,
_ignored: marker::PhantomData,
})
Expand Down Expand Up @@ -143,8 +148,29 @@ impl<R: Read> Archive<R> {
}
}

impl<R: Seek + Read> Archive<R> {
/// Construct an iterator over the entries in this archive for a seekable
/// reader. Seek will be used to efficiently skip over file contents.
///
/// Note that care must be taken to consider each entry within an archive in
/// sequence. If entries are processed out of sequence (from what the
/// iterator returns), then the contents read for each entry may be
/// corrupted.
pub fn entries_with_seek(&mut self) -> io::Result<Entries<R>> {
let me: &Archive<dyn Read> = self;
let me_seekable: &Archive<dyn SeekRead> = self;
me._entries(Some(me_seekable)).map(|fields| Entries {
fields: fields,
_ignored: marker::PhantomData,
})
}
}

impl<'a> Archive<dyn Read + 'a> {
fn _entries(&mut self) -> io::Result<EntriesFields> {
fn _entries(
&'a self,
seekable_archive: Option<&'a Archive<dyn SeekRead + 'a>>,
) -> io::Result<EntriesFields> {
if self.inner.pos.get() != 0 {
return Err(other(
"cannot call entries unless archive is at \
Expand All @@ -153,13 +179,14 @@ impl<'a> Archive<dyn Read + 'a> {
}
Ok(EntriesFields {
archive: self,
seekable_archive,
done: false,
next: 0,
raw: false,
})
}

fn _unpack(&mut self, dst: &Path) -> io::Result<()> {
fn _unpack(&'a mut self, dst: &Path) -> io::Result<()> {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this 'a may no longer be necessary?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rustc seems to need this to determine the lifetime of the seekable_archive argument when calling _entries().

if dst.symlink_metadata().is_err() {
fs::create_dir_all(&dst)
.map_err(|e| TarError::new(&format!("failed to create `{}`", dst.display()), e))?;
Expand All @@ -176,7 +203,7 @@ impl<'a> Archive<dyn Read + 'a> {
// descendants), to ensure that directory permissions do not interfer with descendant
// extraction.
let mut directories = Vec::new();
for entry in self._entries()? {
for entry in self._entries(None)? {
let mut file = entry.map_err(|e| TarError::new("failed to iterate over archive", e))?;
if file.header().entry_type() == crate::EntryType::Directory {
directories.push(file);
Expand Down Expand Up @@ -241,7 +268,7 @@ impl<'a> EntriesFields<'a> {
loop {
// Seek to the start of the next header in the archive
let delta = self.next - self.archive.inner.pos.get();
self.archive.skip(delta)?;
self.skip(delta)?;

// EOF is an indicator that we are at the end of the archive.
if !try_read_all(&mut &self.archive.inner, header.as_mut_bytes())? {
Expand Down Expand Up @@ -476,6 +503,19 @@ impl<'a> EntriesFields<'a> {
}
Ok(())
}

fn skip(&mut self, amt: u64) -> io::Result<()> {
if let Some(seekable_archive) = self.seekable_archive {
let pos = io::SeekFrom::Current(
i64::try_from(amt).map_err(|_| other("seek position out of bounds"))?,
);
let i = seekable_archive.inner.obj.borrow_mut().seek(pos)?;
seekable_archive.inner.pos.set(i);
Ok(())
} else {
self.archive.skip(amt)
}
}
}

impl<'a> Iterator for EntriesFields<'a> {
Expand Down
71 changes: 64 additions & 7 deletions tests/all.rs
Expand Up @@ -11,7 +11,7 @@ use std::iter::repeat;
use std::path::{Path, PathBuf};

use filetime::FileTime;
use tar::{Archive, Builder, EntryType, Header, HeaderMode};
use tar::{Archive, Builder, Entries, EntryType, Header, HeaderMode};
use tempfile::{Builder as TempBuilder, TempDir};

macro_rules! t {
Expand Down Expand Up @@ -203,11 +203,7 @@ fn large_filename() {
assert!(entries.next().is_none());
}

#[test]
fn reading_entries() {
let rdr = Cursor::new(tar!("reading_files.tar"));
let mut ar = Archive::new(rdr);
let mut entries = t!(ar.entries());
fn reading_entries_common<R: Read>(mut entries: Entries<R>) {
let mut a = t!(entries.next().unwrap());
assert_eq!(&*a.header().path_bytes(), b"a");
let mut s = String::new();
Expand All @@ -216,15 +212,76 @@ fn reading_entries() {
s.truncate(0);
t!(a.read_to_string(&mut s));
assert_eq!(s, "");
let mut b = t!(entries.next().unwrap());

let mut b = t!(entries.next().unwrap());
assert_eq!(&*b.header().path_bytes(), b"b");
s.truncate(0);
t!(b.read_to_string(&mut s));
assert_eq!(s, "b\nb\nb\nb\nb\nb\nb\nb\nb\nb\nb\n");
assert!(entries.next().is_none());
}

#[test]
fn reading_entries() {
let rdr = Cursor::new(tar!("reading_files.tar"));
let mut ar = Archive::new(rdr);
reading_entries_common(t!(ar.entries()));
}

#[test]
fn reading_entries_with_seek() {
let rdr = Cursor::new(tar!("reading_files.tar"));
let mut ar = Archive::new(rdr);
reading_entries_common(t!(ar.entries_with_seek()));
}

struct LoggingReader<R> {
inner: R,
read_bytes: u64,
}

impl<R> LoggingReader<R> {
fn new(reader: R) -> LoggingReader<R> {
LoggingReader {
inner: reader,
read_bytes: 0,
}
}
}

impl<T: Read> Read for LoggingReader<T> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.inner.read(buf).map(|i| {
self.read_bytes += i as u64;
i
})
}
}

impl<T: Seek> Seek for LoggingReader<T> {
fn seek(&mut self, pos: io::SeekFrom) -> io::Result<u64> {
self.inner.seek(pos)
}
}

#[test]
fn skipping_entries_with_seek() {
let mut reader = LoggingReader::new(Cursor::new(tar!("reading_files.tar")));
let mut ar_reader = Archive::new(&mut reader);
let files: Vec<_> = t!(ar_reader.entries())
.map(|entry| entry.unwrap().path().unwrap().to_path_buf())
.collect();

let mut seekable_reader = LoggingReader::new(Cursor::new(tar!("reading_files.tar")));
let mut ar_seekable_reader = Archive::new(&mut seekable_reader);
let files_seekable: Vec<_> = t!(ar_seekable_reader.entries_with_seek())
.map(|entry| entry.unwrap().path().unwrap().to_path_buf())
.collect();

assert!(files == files_seekable);
assert!(seekable_reader.read_bytes < reader.read_bytes);
}

fn check_dirtree(td: &TempDir) {
let dir_a = td.path().join("a");
let dir_b = td.path().join("a/b");
Expand Down