Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Efficient implementation of vectored writes for BufWriter #3163

Merged
merged 8 commits into from Jun 29, 2021
14 changes: 13 additions & 1 deletion tokio/src/io/util/buf_reader.rs
Expand Up @@ -2,7 +2,7 @@ use crate::io::util::DEFAULT_BUF_SIZE;
use crate::io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, ReadBuf};

use pin_project_lite::pin_project;
use std::io::{self, SeekFrom};
use std::io::{self, IoSlice, SeekFrom};
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{cmp, fmt, mem};
Expand Down Expand Up @@ -268,6 +268,18 @@ impl<R: AsyncRead + AsyncWrite> AsyncWrite for BufReader<R> {
self.get_pin_mut().poll_write(cx, buf)
}

fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[IoSlice<'_>],
) -> Poll<io::Result<usize>> {
self.get_pin_mut().poll_write_vectored(cx, bufs)
}

fn is_write_vectored(&self) -> bool {
self.get_ref().is_write_vectored()
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.get_pin_mut().poll_flush(cx)
}
Expand Down
14 changes: 13 additions & 1 deletion tokio/src/io/util/buf_stream.rs
Expand Up @@ -2,7 +2,7 @@ use crate::io::util::{BufReader, BufWriter};
use crate::io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, ReadBuf};

use pin_project_lite::pin_project;
use std::io::{self, SeekFrom};
use std::io::{self, IoSlice, SeekFrom};
use std::pin::Pin;
use std::task::{Context, Poll};

Expand Down Expand Up @@ -127,6 +127,18 @@ impl<RW: AsyncRead + AsyncWrite> AsyncWrite for BufStream<RW> {
self.project().inner.poll_write(cx, buf)
}

fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[IoSlice<'_>],
) -> Poll<io::Result<usize>> {
self.project().inner.poll_write_vectored(cx, bufs)
}

fn is_write_vectored(&self) -> bool {
self.inner.is_write_vectored()
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.project().inner.poll_flush(cx)
}
Expand Down
68 changes: 67 additions & 1 deletion tokio/src/io/util/buf_writer.rs
Expand Up @@ -3,7 +3,7 @@ use crate::io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, ReadBuf};

use pin_project_lite::pin_project;
use std::fmt;
use std::io::{self, SeekFrom, Write};
use std::io::{self, IoSlice, SeekFrom, Write};
use std::pin::Pin;
use std::task::{Context, Poll};

Expand Down Expand Up @@ -133,6 +133,72 @@ impl<W: AsyncWrite> AsyncWrite for BufWriter<W> {
}
}

fn poll_write_vectored(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
mut bufs: &[IoSlice<'_>],
) -> Poll<io::Result<usize>> {
if self.inner.is_write_vectored() {
let total_len = bufs
.iter()
.fold(0usize, |acc, b| acc.saturating_add(b.len()));
if total_len > self.buf.capacity() - self.buf.len() {
ready!(self.as_mut().flush_buf(cx))?;
}
let me = self.as_mut().project();
if total_len >= me.buf.capacity() {
// It's more efficient to pass the slices directly to the
// underlying writer than to buffer them.
// The case when the total_len calculation saturates at
// usize::MAX is also handled here.
me.inner.poll_write_vectored(cx, bufs)
} else {
bufs.iter().for_each(|b| me.buf.extend_from_slice(b));
Darksonn marked this conversation as resolved.
Show resolved Hide resolved
Poll::Ready(Ok(total_len))
}
} else {
// Remove empty buffers at the beginning of bufs.
while bufs.first().map(|buf| buf.len()) == Some(0) {
bufs = &bufs[1..];
}
if bufs.is_empty() {
return Poll::Ready(Ok(0));
}
// Flush if the first buffer doesn't fit.
let first_len = bufs[0].len();
if first_len > self.buf.capacity() - self.buf.len() {
ready!(self.as_mut().flush_buf(cx))?;
debug_assert!(self.buf.is_empty());
}
let me = self.as_mut().project();
if first_len >= me.buf.capacity() {
// The slice is at least as large as the buffering capacity,
// so it's better to write it directly, bypassing the buffer.
debug_assert!(me.buf.is_empty());
return me.inner.poll_write(cx, &bufs[0]);
mzabaluev marked this conversation as resolved.
Show resolved Hide resolved
} else {
me.buf.extend_from_slice(&bufs[0]);
bufs = &bufs[1..];
}
let mut total_written = first_len;
debug_assert!(total_written != 0);
// Append the buffers that fit in the internal buffer.
for buf in bufs {
if buf.len() > me.buf.capacity() - me.buf.len() {
break;
} else {
me.buf.extend_from_slice(buf);
total_written += buf.len();
}
}
Poll::Ready(Ok(total_written))
}
}

fn is_write_vectored(&self) -> bool {
true
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
ready!(self.as_mut().flush_buf(cx))?;
self.get_pin_mut().poll_flush(cx)
Expand Down