Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fs: future proof File #2930

Merged
merged 3 commits into from Oct 9, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
164 changes: 96 additions & 68 deletions tokio/src/fs/file.rs
Expand Up @@ -6,6 +6,7 @@ use self::State::*;
use crate::fs::{asyncify, sys};
use crate::io::blocking::Buf;
use crate::io::{AsyncRead, AsyncSeek, AsyncWrite, ReadBuf};
use crate::sync::Mutex;

use std::fmt;
use std::fs::{Metadata, Permissions};
Expand Down Expand Up @@ -80,6 +81,10 @@ use std::task::Poll::*;
/// ```
pub struct File {
std: Arc<sys::File>,
inner: Mutex<Inner>,
}

struct Inner {
state: State,

/// Errors from writes/flushes are returned in write/flush calls. If a write
Expand Down Expand Up @@ -199,9 +204,11 @@ impl File {
pub fn from_std(std: sys::File) -> File {
File {
std: Arc::new(std),
state: State::Idle(Some(Buf::with_capacity(0))),
last_write_err: None,
pos: 0,
inner: Mutex::new(Inner {
state: State::Idle(Some(Buf::with_capacity(0))),
last_write_err: None,
pos: 0,
}),
}
}

Expand All @@ -228,8 +235,9 @@ impl File {
///
/// [`write_all`]: fn@crate::io::AsyncWriteExt::write_all
/// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt
pub async fn sync_all(&mut self) -> io::Result<()> {
self.complete_inflight().await;
pub async fn sync_all(&self) -> io::Result<()> {
let mut inner = self.inner.lock().await;
inner.complete_inflight().await;

let std = self.std.clone();
asyncify(move || std.sync_all()).await
Expand Down Expand Up @@ -262,8 +270,9 @@ impl File {
///
/// [`write_all`]: fn@crate::io::AsyncWriteExt::write_all
/// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt
pub async fn sync_data(&mut self) -> io::Result<()> {
self.complete_inflight().await;
pub async fn sync_data(&self) -> io::Result<()> {
let mut inner = self.inner.lock().await;
inner.complete_inflight().await;

let std = self.std.clone();
asyncify(move || std.sync_data()).await
Expand Down Expand Up @@ -299,10 +308,11 @@ impl File {
///
/// [`write_all`]: fn@crate::io::AsyncWriteExt::write_all
/// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt
pub async fn set_len(&mut self, size: u64) -> io::Result<()> {
self.complete_inflight().await;
pub async fn set_len(&self, size: u64) -> io::Result<()> {
let mut inner = self.inner.lock().await;
inner.complete_inflight().await;

let mut buf = match self.state {
let mut buf = match inner.state {
Idle(ref mut buf_cell) => buf_cell.take().unwrap(),
_ => unreachable!(),
};
Expand All @@ -315,7 +325,7 @@ impl File {

let std = self.std.clone();

self.state = Busy(sys::run(move || {
inner.state = Busy(sys::run(move || {
let res = if let Some(seek) = seek {
(&*std).seek(seek).and_then(|_| std.set_len(size))
} else {
Expand All @@ -327,16 +337,16 @@ impl File {
(Operation::Seek(res), buf)
}));

let (op, buf) = match self.state {
let (op, buf) = match inner.state {
Idle(_) => unreachable!(),
Busy(ref mut rx) => rx.await?,
};

self.state = Idle(Some(buf));
inner.state = Idle(Some(buf));

match op {
Operation::Seek(res) => res.map(|pos| {
self.pos = pos;
inner.pos = pos;
}),
_ => unreachable!(),
}
Expand Down Expand Up @@ -402,7 +412,7 @@ impl File {
/// # }
/// ```
pub async fn into_std(mut self) -> sys::File {
self.complete_inflight().await;
self.inner.get_mut().complete_inflight().await;
Arc::try_unwrap(self.std).expect("Arc::try_unwrap failed")
}

Expand Down Expand Up @@ -469,24 +479,19 @@ impl File {
let std = self.std.clone();
asyncify(move || std.set_permissions(perm)).await
}

async fn complete_inflight(&mut self) {
use crate::future::poll_fn;

if let Err(e) = poll_fn(|cx| Pin::new(&mut *self).poll_flush(cx)).await {
self.last_write_err = Some(e.kind());
}
}
}

impl AsyncRead for File {
fn poll_read(
mut self: Pin<&mut Self>,
self: Pin<&mut Self>,
cx: &mut Context<'_>,
dst: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let me = self.get_mut();
let inner = me.inner.get_mut();

loop {
match self.state {
match inner.state {
Idle(ref mut buf_cell) => {
let mut buf = buf_cell.take().unwrap();

Expand All @@ -497,9 +502,9 @@ impl AsyncRead for File {
}

buf.ensure_capacity_for(dst);
let std = self.std.clone();
let std = me.std.clone();

self.state = Busy(sys::run(move || {
inner.state = Busy(sys::run(move || {
let res = buf.read_from(&mut &*std);
(Operation::Read(res), buf)
}));
Expand All @@ -510,30 +515,30 @@ impl AsyncRead for File {
match op {
Operation::Read(Ok(_)) => {
buf.copy_to(dst);
self.state = Idle(Some(buf));
inner.state = Idle(Some(buf));
return Ready(Ok(()));
}
Operation::Read(Err(e)) => {
assert!(buf.is_empty());

self.state = Idle(Some(buf));
inner.state = Idle(Some(buf));
return Ready(Err(e));
}
Operation::Write(Ok(_)) => {
assert!(buf.is_empty());
self.state = Idle(Some(buf));
inner.state = Idle(Some(buf));
continue;
}
Operation::Write(Err(e)) => {
assert!(self.last_write_err.is_none());
self.last_write_err = Some(e.kind());
self.state = Idle(Some(buf));
assert!(inner.last_write_err.is_none());
inner.last_write_err = Some(e.kind());
inner.state = Idle(Some(buf));
}
Operation::Seek(result) => {
assert!(buf.is_empty());
self.state = Idle(Some(buf));
inner.state = Idle(Some(buf));
if let Ok(pos) = result {
self.pos = pos;
inner.pos = pos;
}
continue;
}
Expand All @@ -545,9 +550,12 @@ impl AsyncRead for File {
}

impl AsyncSeek for File {
fn start_seek(mut self: Pin<&mut Self>, mut pos: SeekFrom) -> io::Result<()> {
fn start_seek(self: Pin<&mut Self>, mut pos: SeekFrom) -> io::Result<()> {
let me = self.get_mut();
let inner = me.inner.get_mut();

loop {
match self.state {
match inner.state {
Busy(_) => panic!("must wait for poll_complete before calling start_seek"),
Idle(ref mut buf_cell) => {
let mut buf = buf_cell.take().unwrap();
Expand All @@ -561,9 +569,9 @@ impl AsyncSeek for File {
}
}

let std = self.std.clone();
let std = me.std.clone();

self.state = Busy(sys::run(move || {
inner.state = Busy(sys::run(move || {
let res = (&*std).seek(pos);
(Operation::Seek(res), buf)
}));
Expand All @@ -574,23 +582,25 @@ impl AsyncSeek for File {
}

fn poll_complete(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
let inner = self.inner.get_mut();

loop {
match self.state {
Idle(_) => return Poll::Ready(Ok(self.pos)),
match inner.state {
Idle(_) => return Poll::Ready(Ok(inner.pos)),
Busy(ref mut rx) => {
let (op, buf) = ready!(Pin::new(rx).poll(cx))?;
self.state = Idle(Some(buf));
inner.state = Idle(Some(buf));

match op {
Operation::Read(_) => {}
Operation::Write(Err(e)) => {
assert!(self.last_write_err.is_none());
self.last_write_err = Some(e.kind());
assert!(inner.last_write_err.is_none());
inner.last_write_err = Some(e.kind());
}
Operation::Write(_) => {}
Operation::Seek(res) => {
if let Ok(pos) = res {
self.pos = pos;
inner.pos = pos;
}
return Ready(res);
}
Expand All @@ -603,16 +613,19 @@ impl AsyncSeek for File {

impl AsyncWrite for File {
fn poll_write(
mut self: Pin<&mut Self>,
self: Pin<&mut Self>,
cx: &mut Context<'_>,
src: &[u8],
) -> Poll<io::Result<usize>> {
if let Some(e) = self.last_write_err.take() {
let me = self.get_mut();
let inner = me.inner.get_mut();

if let Some(e) = inner.last_write_err.take() {
return Ready(Err(e.into()));
}

loop {
match self.state {
match inner.state {
Idle(ref mut buf_cell) => {
let mut buf = buf_cell.take().unwrap();

Expand All @@ -623,9 +636,9 @@ impl AsyncWrite for File {
};

let n = buf.copy_from(src);
let std = self.std.clone();
let std = me.std.clone();

self.state = Busy(sys::run(move || {
inner.state = Busy(sys::run(move || {
let res = if let Some(seek) = seek {
(&*std).seek(seek).and_then(|_| buf.write_to(&mut &*std))
} else {
Expand All @@ -639,7 +652,7 @@ impl AsyncWrite for File {
}
Busy(ref mut rx) => {
let (op, buf) = ready!(Pin::new(rx).poll(cx))?;
self.state = Idle(Some(buf));
inner.state = Idle(Some(buf));

match op {
Operation::Read(_) => {
Expand All @@ -665,23 +678,8 @@ impl AsyncWrite for File {
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
if let Some(e) = self.last_write_err.take() {
return Ready(Err(e.into()));
}

let (op, buf) = match self.state {
Idle(_) => return Ready(Ok(())),
Busy(ref mut rx) => ready!(Pin::new(rx).poll(cx))?,
};

// The buffer is not used here
self.state = Idle(Some(buf));

match op {
Operation::Read(_) => Ready(Ok(())),
Operation::Write(res) => Ready(res),
Operation::Seek(_) => Ready(Ok(())),
}
let inner = self.inner.get_mut();
inner.poll_flush(cx)
}

fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
Expand Down Expand Up @@ -731,3 +729,33 @@ impl std::os::windows::io::FromRawHandle for File {
sys::File::from_raw_handle(handle).into()
}
}

impl Inner {
async fn complete_inflight(&mut self) {
use crate::future::poll_fn;

if let Err(e) = poll_fn(|cx| Pin::new(&mut *self).poll_flush(cx)).await {
self.last_write_err = Some(e.kind());
}
}

fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
if let Some(e) = self.last_write_err.take() {
return Ready(Err(e.into()));
}

let (op, buf) = match self.state {
Idle(_) => return Ready(Ok(())),
Busy(ref mut rx) => ready!(Pin::new(rx).poll(cx))?,
};

// The buffer is not used here
self.state = Idle(Some(buf));

match op {
Operation::Read(_) => Ready(Ok(())),
Operation::Write(res) => Ready(res),
Operation::Seek(_) => Ready(Ok(())),
}
}
}
1 change: 1 addition & 0 deletions tokio/src/sync/batch_semaphore.rs
@@ -1,3 +1,4 @@
#![cfg_attr(not(feature = "sync"), allow(unreachable_pub, dead_code))]
//! # Implementation Details
//!
//! The semaphore is implemented using an intrusive linked list of waiters. An
Expand Down
9 changes: 8 additions & 1 deletion tokio/src/sync/mod.rs
Expand Up @@ -456,6 +456,14 @@ cfg_sync! {
}

cfg_not_sync! {
#[cfg(any(feature = "fs", feature = "signal", all(unix, feature = "process")))]
pub(crate) mod batch_semaphore;

cfg_fs! {
mod mutex;
pub(crate) use mutex::Mutex;
}

mod notify;
pub(crate) use notify::Notify;

Expand All @@ -472,7 +480,6 @@ cfg_not_sync! {

cfg_signal_internal! {
pub(crate) mod mpsc;
pub(crate) mod batch_semaphore;
}
}

Expand Down
2 changes: 2 additions & 0 deletions tokio/src/sync/mutex.rs
@@ -1,3 +1,5 @@
#![cfg_attr(not(feature = "sync"), allow(unreachable_pub, dead_code))]

use crate::sync::batch_semaphore as semaphore;

use std::cell::UnsafeCell;
Expand Down