Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fs: add set_max_buf_size to tokio::fs::File #6411

Merged
merged 4 commits into from
Mar 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
38 changes: 34 additions & 4 deletions tokio/src/fs/file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
//! [`File`]: File

use crate::fs::{asyncify, OpenOptions};
use crate::io::blocking::Buf;
use crate::io::blocking::{Buf, DEFAULT_MAX_BUF_SIZE};
use crate::io::{AsyncRead, AsyncSeek, AsyncWrite, ReadBuf};
use crate::sync::Mutex;

Expand Down Expand Up @@ -90,6 +90,7 @@ use std::fs::File as StdFile;
pub struct File {
/// Shared handle to the underlying standard-library file; it is cloned and
/// moved into the closures that perform the blocking operations.
std: Arc<StdFile>,
/// Mutable I/O bookkeeping (operation state, position, last write error),
/// guarded by a mutex.
inner: Mutex<Inner>,
/// Upper bound, in bytes, passed to the internal buffer on each read or
/// write. Starts at `DEFAULT_MAX_BUF_SIZE` and is changed through
/// `set_max_buf_size`.
max_buf_size: usize,
}

struct Inner {
Expand Down Expand Up @@ -241,6 +242,7 @@ impl File {
last_write_err: None,
pos: 0,
}),
max_buf_size: DEFAULT_MAX_BUF_SIZE,
}
}

Expand Down Expand Up @@ -508,6 +510,34 @@ impl File {
let std = self.std.clone();
asyncify(move || std.set_permissions(perm)).await
}

/// Sets the maximum buffer size, in bytes, used by this file's
/// [`AsyncRead`] / [`AsyncWrite`] operations.
///
/// A single read or write call transfers at most `max_buf_size` bytes;
/// larger requests complete over several calls. Tokio uses a sensible
/// default (2 MiB), so this only needs to be called when a workload
/// benefits from larger or smaller chunks.
///
/// # Examples
///
/// ```no_run
/// use tokio::fs::File;
/// use tokio::io::AsyncWriteExt;
///
/// # async fn dox() -> std::io::Result<()> {
/// // `File::create` opens the file for writing; `File::open` would be
/// // read-only and the write below would fail.
/// let mut file = File::create("foo.txt").await?;
///
/// // Set maximum buffer size to 8 MiB
/// file.set_max_buf_size(8 * 1024 * 1024);
///
/// let buf = vec![1; 1024 * 1024 * 1024];
///
/// // Write the 1 GiB buffer in chunks up to 8 MiB each.
/// file.write_all(&buf).await?;
/// # Ok(())
/// # }
/// ```
pub fn set_max_buf_size(&mut self, max_buf_size: usize) {
// NOTE(review): the value is stored unchecked. A size of 0 presumably makes
// every read/write transfer zero bytes — confirm whether it should be rejected.
self.max_buf_size = max_buf_size;
}
}

impl AsyncRead for File {
Expand All @@ -531,7 +561,7 @@ impl AsyncRead for File {
return Poll::Ready(Ok(()));
}

buf.ensure_capacity_for(dst);
buf.ensure_capacity_for(dst, me.max_buf_size);
let std = me.std.clone();

inner.state = State::Busy(spawn_blocking(move || {
Expand Down Expand Up @@ -668,7 +698,7 @@ impl AsyncWrite for File {
None
};

let n = buf.copy_from(src);
let n = buf.copy_from(src, me.max_buf_size);
let std = me.std.clone();

let blocking_task_join_handle = spawn_mandatory_blocking(move || {
Expand Down Expand Up @@ -739,7 +769,7 @@ impl AsyncWrite for File {
None
};

let n = buf.copy_from_bufs(bufs);
let n = buf.copy_from_bufs(bufs, me.max_buf_size);
let std = me.std.clone();

let blocking_task_join_handle = spawn_mandatory_blocking(move || {
Expand Down
4 changes: 2 additions & 2 deletions tokio/src/fs/file/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ fn flush_while_idle() {
#[cfg_attr(miri, ignore)] // takes a really long time with miri
fn read_with_buffer_larger_than_max() {
// Chunks
let chunk_a = crate::io::blocking::MAX_BUF;
let chunk_a = crate::io::blocking::DEFAULT_MAX_BUF_SIZE;
let chunk_b = chunk_a * 2;
let chunk_c = chunk_a * 3;
let chunk_d = chunk_a * 4;
Expand Down Expand Up @@ -303,7 +303,7 @@ fn read_with_buffer_larger_than_max() {
#[cfg_attr(miri, ignore)] // takes a really long time with miri
fn write_with_buffer_larger_than_max() {
// Chunks
let chunk_a = crate::io::blocking::MAX_BUF;
let chunk_a = crate::io::blocking::DEFAULT_MAX_BUF_SIZE;
let chunk_b = chunk_a * 2;
let chunk_c = chunk_a * 3;
let chunk_d = chunk_a * 4;
Expand Down
1 change: 1 addition & 0 deletions tokio/src/fs/mocks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ mock! {
pub fn open(pb: PathBuf) -> io::Result<Self>;
pub fn set_len(&self, size: u64) -> io::Result<()>;
pub fn set_permissions(&self, _perm: Permissions) -> io::Result<()>;
pub fn set_max_buf_size(&self, max_buf_size: usize);
pub fn sync_all(&self) -> io::Result<()>;
pub fn sync_data(&self) -> io::Result<()>;
pub fn try_clone(&self) -> io::Result<Self>;
Expand Down
20 changes: 10 additions & 10 deletions tokio/src/io/blocking.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ pub(crate) struct Buf {
pos: usize,
}

pub(crate) const MAX_BUF: usize = 2 * 1024 * 1024;
pub(crate) const DEFAULT_MAX_BUF_SIZE: usize = 2 * 1024 * 1024;

#[derive(Debug)]
enum State<T> {
Expand Down Expand Up @@ -64,7 +64,7 @@ where
return Poll::Ready(Ok(()));
}

buf.ensure_capacity_for(dst);
buf.ensure_capacity_for(dst, DEFAULT_MAX_BUF_SIZE);
let mut inner = self.inner.take().unwrap();

self.state = State::Busy(sys::run(move || {
Expand Down Expand Up @@ -111,7 +111,7 @@ where

assert!(buf.is_empty());

let n = buf.copy_from(src);
let n = buf.copy_from(src, DEFAULT_MAX_BUF_SIZE);
let mut inner = self.inner.take().unwrap();

self.state = State::Busy(sys::run(move || {
Expand Down Expand Up @@ -214,10 +214,10 @@ impl Buf {
n
}

pub(crate) fn copy_from(&mut self, src: &[u8]) -> usize {
pub(crate) fn copy_from(&mut self, src: &[u8], max_buf_size: usize) -> usize {
assert!(self.is_empty());

let n = cmp::min(src.len(), MAX_BUF);
let n = cmp::min(src.len(), max_buf_size);

self.buf.extend_from_slice(&src[..n]);
n
Expand All @@ -227,10 +227,10 @@ impl Buf {
&self.buf[self.pos..]
}

pub(crate) fn ensure_capacity_for(&mut self, bytes: &ReadBuf<'_>) {
pub(crate) fn ensure_capacity_for(&mut self, bytes: &ReadBuf<'_>, max_buf_size: usize) {
assert!(self.is_empty());

let len = cmp::min(bytes.remaining(), MAX_BUF);
let len = cmp::min(bytes.remaining(), max_buf_size);

if self.buf.len() < len {
self.buf.reserve(len - self.buf.len());
Expand Down Expand Up @@ -274,10 +274,10 @@ cfg_fs! {
ret
}

pub(crate) fn copy_from_bufs(&mut self, bufs: &[io::IoSlice<'_>]) -> usize {
pub(crate) fn copy_from_bufs(&mut self, bufs: &[io::IoSlice<'_>], max_buf_size: usize) -> usize {
assert!(self.is_empty());

let mut rem = MAX_BUF;
let mut rem = max_buf_size;
for buf in bufs {
if rem == 0 {
break
Expand All @@ -288,7 +288,7 @@ cfg_fs! {
rem -= len;
}

MAX_BUF - rem
max_buf_size - rem
}
}
}
19 changes: 10 additions & 9 deletions tokio/src/io/stdio_common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::pin::Pin;
use std::task::{Context, Poll};
/// # Windows
/// [`AsyncWrite`] adapter that finds last char boundary in given buffer and does not write the rest,
/// if buffer contents seems to be `utf8`. Otherwise it only trims buffer down to `MAX_BUF`.
/// if buffer contents seems to be `utf8`. Otherwise it only trims buffer down to `DEFAULT_MAX_BUF_SIZE`.
/// That's why, wrapped writer will always receive well-formed utf-8 bytes.
/// # Other platforms
/// Passes data to `inner` as is.
Expand Down Expand Up @@ -45,12 +45,13 @@ where
// 2. If the buffer is small, it will not be shrunk.
// Therefore its "textness" will not change, so we don't have
// to fix it up.
if cfg!(not(any(target_os = "windows", test))) || buf.len() <= crate::io::blocking::MAX_BUF
if cfg!(not(any(target_os = "windows", test)))
|| buf.len() <= crate::io::blocking::DEFAULT_MAX_BUF_SIZE
{
return call_inner(buf);
}

buf = &buf[..crate::io::blocking::MAX_BUF];
buf = &buf[..crate::io::blocking::DEFAULT_MAX_BUF_SIZE];

// Now there are two possibilities.
// If the caller gave us a binary buffer, we **should not** shrink it
Expand Down Expand Up @@ -108,7 +109,7 @@ where
#[cfg(test)]
#[cfg(not(loom))]
mod tests {
use crate::io::blocking::MAX_BUF;
use crate::io::blocking::DEFAULT_MAX_BUF_SIZE;
use crate::io::AsyncWriteExt;
use std::io;
use std::pin::Pin;
Expand All @@ -123,7 +124,7 @@ mod tests {
_cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, io::Error>> {
assert!(buf.len() <= MAX_BUF);
assert!(buf.len() <= DEFAULT_MAX_BUF_SIZE);
assert!(std::str::from_utf8(buf).is_ok());
Poll::Ready(Ok(buf.len()))
}
Expand Down Expand Up @@ -158,7 +159,7 @@ mod tests {
_cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, io::Error>> {
assert!(buf.len() <= MAX_BUF);
assert!(buf.len() <= DEFAULT_MAX_BUF_SIZE);
self.write_history.push(buf.len());
Poll::Ready(Ok(buf.len()))
}
Expand All @@ -178,7 +179,7 @@ mod tests {
#[test]
#[cfg_attr(miri, ignore)]
fn test_splitter() {
let data = str::repeat("█", MAX_BUF);
let data = str::repeat("█", DEFAULT_MAX_BUF_SIZE);
let mut wr = super::SplitByUtf8BoundaryIfWindows::new(TextMockWriter);
let fut = async move {
wr.write_all(data.as_bytes()).await.unwrap();
Expand All @@ -197,7 +198,7 @@ mod tests {
// was not shrunk too much.
let checked_count = super::MAGIC_CONST * super::MAX_BYTES_PER_CHAR;
let mut data: Vec<u8> = str::repeat("a", checked_count).into();
data.extend(std::iter::repeat(0b1010_1010).take(MAX_BUF - checked_count + 1));
data.extend(std::iter::repeat(0b1010_1010).take(DEFAULT_MAX_BUF_SIZE - checked_count + 1));
let mut writer = LoggingMockWriter::new();
let mut splitter = super::SplitByUtf8BoundaryIfWindows::new(&mut writer);
crate::runtime::Builder::new_current_thread()
Expand All @@ -214,7 +215,7 @@ mod tests {
data.len()
);
// Check that at most MAX_BYTES_PER_CHAR + 1 (i.e. 5) bytes were shrunk
// from the buffer: one because it was outside of MAX_BUF boundary, and
// from the buffer: one because it was outside of DEFAULT_MAX_BUF_SIZE boundary, and
// up to one "utf8 code point".
assert!(data.len() - writer.write_history[0] <= super::MAX_BYTES_PER_CHAR + 1);
}
Expand Down
22 changes: 22 additions & 0 deletions tokio/tests/fs_file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,28 @@ fn tempfile() -> NamedTempFile {
NamedTempFile::new().unwrap()
}

#[tokio::test]
async fn set_max_buf_size_read() {
    // Prepare a temporary file that contains the test payload.
    let mut tmp = tempfile();
    tmp.write_all(HELLO).unwrap();

    // Open it asynchronously and cap the internal buffer at one byte.
    let mut file = File::open(tmp.path()).await.unwrap();
    file.set_max_buf_size(1);

    // The destination has room for 1024 bytes, yet a single read call
    // must hand back no more than the configured maximum of 1 byte.
    let mut dst = [0; 1024];
    let bytes_read = file.read(&mut dst).await.unwrap();
    assert_eq!(bytes_read, 1);
}

#[tokio::test]
async fn set_max_buf_size_write() {
    // Create a fresh file and cap the internal buffer at one byte.
    let tmp = tempfile();
    let mut file = File::create(tmp.path()).await.unwrap();
    file.set_max_buf_size(1);

    // A single write call must accept no more than the configured
    // maximum of 1 byte, regardless of how much data is offered.
    let bytes_written = file.write(HELLO).await.unwrap();
    assert_eq!(bytes_written, 1);
}

#[tokio::test]
#[cfg(unix)]
async fn file_debug_fmt() {
Expand Down