From db5328759d3a562385ef606433e5938b988de728 Mon Sep 17 00:00:00 2001 From: Mikhail Zabaluev Date: Sun, 1 Nov 2020 13:06:25 +0200 Subject: [PATCH 1/8] io: implement vectored output for BufWriter Implement AsyncWrite::poll_write_vectored for BufWriter, making use of the buffer to coalesce data from the passed slices. Make exceptions for cases when writing directly to the underlying object is more efficient, which differ depending on whether the underlying writer itself supports vectored output. Change the implementation of AsyncWrite::is_write_vectored to return true. --- tokio/src/io/util/buf_writer.rs | 68 ++++++++++++++++++++++++++++++++- 1 file changed, 67 insertions(+), 1 deletion(-) diff --git a/tokio/src/io/util/buf_writer.rs b/tokio/src/io/util/buf_writer.rs index 4e8e493cefe..9a188c26cae 100644 --- a/tokio/src/io/util/buf_writer.rs +++ b/tokio/src/io/util/buf_writer.rs @@ -2,8 +2,9 @@ use crate::io::util::DEFAULT_BUF_SIZE; use crate::io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, ReadBuf}; use pin_project_lite::pin_project; +use std::cmp; use std::fmt; -use std::io::{self, SeekFrom, Write}; +use std::io::{self, IoSlice, SeekFrom, Write}; use std::pin::Pin; use std::task::{Context, Poll}; @@ -133,6 +134,71 @@ impl AsyncWrite for BufWriter { } } + fn poll_write_vectored( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + if self.as_mut().project().inner.is_write_vectored() { + let total_len = bufs.iter().map(|b| b.len()).sum::(); + if self.buf.len() + total_len > self.buf.capacity() { + ready!(self.as_mut().flush_buf(cx))?; + } + let me = self.as_mut().project(); + if total_len >= me.buf.capacity() { + // It's more efficient to pass the slices directly to the + // underlying writer than to buffer them. + me.inner.poll_write_vectored(cx, bufs) + } else { + bufs.iter().for_each(|b| me.buf.extend_from_slice(b)); + Poll::Ready(Ok(total_len)) + } + } else { + let mut total_written = 0; + let mut iter = bufs.iter(); + if let Some(buf) = iter.by_ref().find(|&buf| !buf.is_empty()) { + // This is the first non-empty slice to write, so if it does + // not fit in the buffer, we still get to flush and proceed. + if self.buf.len() + buf.len() > self.buf.capacity() { + ready!(self.as_mut().flush_buf(cx))?; + } + let me = self.as_mut().project(); + if buf.len() >= me.buf.capacity() { + // The slice is at least as large as the buffering capacity, + // so it's better to write it directly, bypassing the buffer. + return me.inner.poll_write(cx, buf); + } else { + me.buf.extend_from_slice(buf); + total_written += buf.len(); + } + debug_assert!(total_written != 0); + } + for buf in iter { + let me = self.as_mut().project(); + if buf.len() >= me.buf.capacity() { + // This slice should be written directly, but we have already + // buffered some of the input. Bail out, expecting it to be + // handled as the first slice in the next call to + // poll_write_vectored. + break; + } else { + let write_to = cmp::min(buf.len(), me.buf.capacity() - me.buf.len()); + me.buf.extend_from_slice(&buf[..write_to]); + total_written += write_to; + if me.buf.capacity() == me.buf.len() { + // The buffer is full, bail out + break; + } + } + } + Poll::Ready(Ok(total_written)) + } + } + + fn is_write_vectored(&self) -> bool { + true + } + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { ready!(self.as_mut().flush_buf(cx))?; self.get_pin_mut().poll_flush(cx) From f805b69c2ce4468620bf278e6784392bc34aa8a0 Mon Sep 17 00:00:00 2001 From: Mikhail Zabaluev Date: Sun, 1 Nov 2020 13:36:04 +0200 Subject: [PATCH 2/8] io: vectored output for BufReader and BufStream Forward poll_write_vectored/is_write_vectored to the underlying object for BufReader. Do the same in BufStream, which has the effect of is_write_vectored returning true because BufWriter's implementation now does. --- tokio/src/io/util/buf_reader.rs | 14 +++++++++++++- tokio/src/io/util/buf_stream.rs | 14 +++++++++++++- 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/tokio/src/io/util/buf_reader.rs b/tokio/src/io/util/buf_reader.rs index c4d6842d480..7cfd46ce03e 100644 --- a/tokio/src/io/util/buf_reader.rs +++ b/tokio/src/io/util/buf_reader.rs @@ -2,7 +2,7 @@ use crate::io::util::DEFAULT_BUF_SIZE; use crate::io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, ReadBuf}; use pin_project_lite::pin_project; -use std::io::{self, SeekFrom}; +use std::io::{self, IoSlice, SeekFrom}; use std::pin::Pin; use std::task::{Context, Poll}; use std::{cmp, fmt, mem}; @@ -268,6 +268,18 @@ impl AsyncWrite for BufReader { self.get_pin_mut().poll_write(cx, buf) } + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + self.get_pin_mut().poll_write_vectored(cx, bufs) + } + + fn is_write_vectored(&self) -> bool { + self.get_ref().is_write_vectored() + } + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.get_pin_mut().poll_flush(cx) } diff --git a/tokio/src/io/util/buf_stream.rs b/tokio/src/io/util/buf_stream.rs index ff3d9dba86e..595c142aca5 100644 --- a/tokio/src/io/util/buf_stream.rs +++ b/tokio/src/io/util/buf_stream.rs @@ -2,7 +2,7 @@ use crate::io::util::{BufReader, BufWriter}; use crate::io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, ReadBuf}; use pin_project_lite::pin_project; -use std::io::{self, SeekFrom}; +use std::io::{self, IoSlice, SeekFrom}; use std::pin::Pin; use std::task::{Context, Poll}; @@ -127,6 +127,18 @@ impl AsyncWrite for BufStream { self.project().inner.poll_write(cx, buf) } + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + self.project().inner.poll_write_vectored(cx, bufs) + } + + fn is_write_vectored(&self) -> bool { + self.inner.is_write_vectored() + } + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.project().inner.poll_flush(cx) } From ab30c42a37ecf1e5c79a9e581d891449a14c0efe Mon Sep 17 00:00:00 2001 From: Mikhail Zabaluev Date: Sun, 22 Nov 2020 16:13:33 +0200 Subject: [PATCH 3/8] io: simplify poll_write_vectored for BufWriter Simplify branching and do what poll_write does: do not buffer slices partially, optimizing for the most likely case where slices are much smaller than the buffer, while retaining special treatment for oversized slices. --- tokio/src/io/util/buf_writer.rs | 27 +++++++++------------------ 1 file changed, 9 insertions(+), 18 deletions(-) diff --git a/tokio/src/io/util/buf_writer.rs b/tokio/src/io/util/buf_writer.rs index 9a188c26cae..c403da5e63a 100644 --- a/tokio/src/io/util/buf_writer.rs +++ b/tokio/src/io/util/buf_writer.rs @@ -2,7 +2,6 @@ use crate::io::util::DEFAULT_BUF_SIZE; use crate::io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, ReadBuf}; use pin_project_lite::pin_project; -use std::cmp; use std::fmt; use std::io::{self, IoSlice, SeekFrom, Write}; use std::pin::Pin; @@ -154,9 +153,8 @@ impl AsyncWrite for BufWriter { Poll::Ready(Ok(total_len)) } } else { - let mut total_written = 0; let mut iter = bufs.iter(); - if let Some(buf) = iter.by_ref().find(|&buf| !buf.is_empty()) { + let mut total_written = if let Some(buf) = iter.by_ref().find(|&buf| !buf.is_empty()) { // This is the first non-empty slice to write, so if it does // not fit in the buffer, we still get to flush and proceed. if self.buf.len() + buf.len() > self.buf.capacity() { @@ -169,26 +167,19 @@ impl AsyncWrite for BufWriter { return me.inner.poll_write(cx, buf); } else { me.buf.extend_from_slice(buf); - total_written += buf.len(); + buf.len() } - debug_assert!(total_written != 0); - } + } else { + return Poll::Ready(Ok(0)); + }; + debug_assert!(total_written != 0); for buf in iter { let me = self.as_mut().project(); - if buf.len() >= me.buf.capacity() { - // This slice should be written directly, but we have already - // buffered some of the input. Bail out, expecting it to be - // handled as the first slice in the next call to - // poll_write_vectored. + if me.buf.len() + buf.len() > me.buf.capacity() { break; } else { - let write_to = cmp::min(buf.len(), me.buf.capacity() - me.buf.len()); - me.buf.extend_from_slice(&buf[..write_to]); - total_written += write_to; - if me.buf.capacity() == me.buf.len() { - // The buffer is full, bail out - break; - } + me.buf.extend_from_slice(buf); + total_written += buf.len(); } } Poll::Ready(Ok(total_written)) From 32f8a8dd31fd2f9f8a022ee8bac207212d32fb63 Mon Sep 17 00:00:00 2001 From: Mikhail Zabaluev Date: Tue, 3 Nov 2020 06:04:25 +0200 Subject: [PATCH 4/8] test write_vectored on BufWriter --- tokio/tests/io_buf_writer.rs | 286 ++++++++++++++++++++++++++++++++++ tokio/tests/support/io_vec.rs | 45 ++++++ 2 files changed, 331 insertions(+) create mode 100644 tokio/tests/support/io_vec.rs diff --git a/tokio/tests/io_buf_writer.rs b/tokio/tests/io_buf_writer.rs index 6f4f10a8e2e..47a0d466f49 100644 --- a/tokio/tests/io_buf_writer.rs +++ b/tokio/tests/io_buf_writer.rs @@ -8,6 +8,17 @@ use std::io::{self, Cursor}; use std::pin::Pin; use tokio::io::{AsyncSeek, AsyncSeekExt, AsyncWrite, AsyncWriteExt, BufWriter, SeekFrom}; +use futures::future; +use tokio_test::assert_ok; + +use std::cmp; +use std::io::IoSlice; + +mod support { + pub(crate) mod io_vec; +} +use support::io_vec::IoBufs; + struct MaybePending { inner: Vec, ready: bool, @@ -47,6 +58,14 @@ impl AsyncWrite for MaybePending { } } +async fn write_vectored(writer: &mut W, bufs: &[IoSlice<'_>]) -> io::Result +where + W: AsyncWrite + Unpin, +{ + let mut writer = Pin::new(writer); + future::poll_fn(|cx| writer.as_mut().poll_write_vectored(cx, bufs)).await +} + #[tokio::test] async fn buf_writer() { let mut writer = BufWriter::with_capacity(2, Vec::new()); @@ -249,3 +268,270 @@ async fn maybe_pending_buf_writer_seek() { &[0, 1, 8, 9, 4, 5, 6, 7] ); } + +struct MockWriter { + data: Vec, + write_len: usize, + vectored: bool, +} + +impl MockWriter { + fn new(write_len: usize) -> Self { + MockWriter { + data: Vec::new(), + write_len, + vectored: false, + } + } + + fn vectored(write_len: usize) -> Self { + MockWriter { + data: Vec::new(), + write_len, + vectored: true, + } + } + + fn write_up_to(&mut self, buf: &[u8], limit: usize) -> usize { + let len = cmp::min(buf.len(), limit); + self.data.extend_from_slice(&buf[..len]); + len + } +} + +impl AsyncWrite for MockWriter { + fn poll_write( + self: Pin<&mut Self>, + _: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let this = self.get_mut(); + let n = this.write_up_to(buf, this.write_len); + Ok(n).into() + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + _: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + let this = self.get_mut(); + let mut total_written = 0; + for buf in bufs { + let n = this.write_up_to(buf, this.write_len - total_written); + total_written += n; + if total_written == this.write_len { + break; + } + } + Ok(total_written).into() + } + + fn is_write_vectored(&self) -> bool { + self.vectored + } + + fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + Ok(()).into() + } + + fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + Ok(()).into() + } +} + +#[tokio::test] +async fn write_vectored_empty_on_non_vectored() { + let mut w = BufWriter::new(MockWriter::new(4)); + let n = assert_ok!(write_vectored(&mut w, &[]).await); + assert_eq!(n, 0); + + let io_vec = [IoSlice::new(&[]); 3]; + let n = assert_ok!(write_vectored(&mut w, &io_vec).await); + assert_eq!(n, 0); + + assert_ok!(w.flush().await); + assert!(w.get_ref().data.is_empty()); +} + +#[tokio::test] +async fn write_vectored_empty_on_vectored() { + let mut w = BufWriter::new(MockWriter::vectored(4)); + let n = assert_ok!(write_vectored(&mut w, &[]).await); + assert_eq!(n, 0); + + let io_vec = [IoSlice::new(&[]); 3]; + let n = assert_ok!(write_vectored(&mut w, &io_vec).await); + assert_eq!(n, 0); + + assert_ok!(w.flush().await); + assert!(w.get_ref().data.is_empty()); +} + +#[tokio::test] +async fn write_vectored_basic_on_non_vectored() { + let msg = b"foo bar baz"; + let bufs = [ + IoSlice::new(&msg[0..4]), + IoSlice::new(&msg[4..8]), + IoSlice::new(&msg[8..]), + ]; + let mut w = BufWriter::new(MockWriter::new(4)); + let n = assert_ok!(write_vectored(&mut w, &bufs).await); + assert_eq!(n, msg.len()); + assert!(w.buffer() == &msg[..]); + assert_ok!(w.flush().await); + assert_eq!(w.get_ref().data, msg); +} + +#[tokio::test] +async fn write_vectored_basic_on_vectored() { + let msg = b"foo bar baz"; + let bufs = [ + IoSlice::new(&msg[0..4]), + IoSlice::new(&msg[4..8]), + IoSlice::new(&msg[8..]), + ]; + let mut w = BufWriter::new(MockWriter::vectored(4)); + let n = assert_ok!(write_vectored(&mut w, &bufs).await); + assert_eq!(n, msg.len()); + assert!(w.buffer() == &msg[..]); + assert_ok!(w.flush().await); + assert_eq!(w.get_ref().data, msg); +} + +#[tokio::test] +async fn write_vectored_large_total_on_non_vectored() { + let msg = b"foo bar baz"; + let mut bufs = [ + IoSlice::new(&msg[0..4]), + IoSlice::new(&msg[4..8]), + IoSlice::new(&msg[8..]), + ]; + let io_vec = IoBufs::new(&mut bufs); + let mut w = BufWriter::with_capacity(8, MockWriter::new(4)); + let n = assert_ok!(write_vectored(&mut w, &io_vec).await); + assert_eq!(n, 8); + assert!(w.buffer() == &msg[..8]); + let io_vec = io_vec.advance(n); + let n = assert_ok!(write_vectored(&mut w, &io_vec).await); + assert_eq!(n, 3); + assert!(w.get_ref().data.as_slice() == &msg[..8]); + assert!(w.buffer() == &msg[8..]); +} + +#[tokio::test] +async fn write_vectored_large_total_on_vectored() { + let msg = b"foo bar baz"; + let mut bufs = [ + IoSlice::new(&msg[0..4]), + IoSlice::new(&msg[4..8]), + IoSlice::new(&msg[8..]), + ]; + let io_vec = IoBufs::new(&mut bufs); + let mut w = BufWriter::with_capacity(8, MockWriter::vectored(10)); + let n = assert_ok!(write_vectored(&mut w, &io_vec).await); + assert_eq!(n, 10); + assert!(w.buffer().is_empty()); + let io_vec = io_vec.advance(n); + let n = assert_ok!(write_vectored(&mut w, &io_vec).await); + assert_eq!(n, 1); + assert!(w.get_ref().data.as_slice() == &msg[..10]); + assert!(w.buffer() == &msg[10..]); +} + +struct VectoredWriteHarness { + writer: BufWriter, + buf_capacity: usize, +} + +impl VectoredWriteHarness { + fn new(buf_capacity: usize) -> Self { + VectoredWriteHarness { + writer: BufWriter::with_capacity(buf_capacity, MockWriter::new(4)), + buf_capacity, + } + } + + fn with_vectored_backend(buf_capacity: usize) -> Self { + VectoredWriteHarness { + writer: BufWriter::with_capacity(buf_capacity, MockWriter::vectored(4)), + buf_capacity, + } + } + + async fn write_all<'a, 'b>(&mut self, mut io_vec: IoBufs<'a, 'b>) -> usize { + let mut total_written = 0; + while !io_vec.is_empty() { + let n = assert_ok!(write_vectored(&mut self.writer, &io_vec).await); + assert!(n != 0); + assert!(self.writer.buffer().len() <= self.buf_capacity); + total_written += n; + io_vec = io_vec.advance(n); + } + total_written + } + + async fn flush(&mut self) -> &[u8] { + assert_ok!(self.writer.flush().await); + &self.writer.get_ref().data + } +} + +#[tokio::test] +async fn write_vectored_odd_on_non_vectored() { + let msg = b"foo bar baz"; + let mut bufs = [ + IoSlice::new(&msg[0..4]), + IoSlice::new(&[]), + IoSlice::new(&msg[4..9]), + IoSlice::new(&msg[9..]), + ]; + let mut h = VectoredWriteHarness::new(8); + let bytes_written = h.write_all(IoBufs::new(&mut bufs)).await; + assert_eq!(bytes_written, msg.len()); + assert_eq!(h.flush().await, msg); +} + +#[tokio::test] +async fn write_vectored_odd_on_vectored() { + let msg = b"foo bar baz"; + let mut bufs = [ + IoSlice::new(&msg[0..4]), + IoSlice::new(&[]), + IoSlice::new(&msg[4..9]), + IoSlice::new(&msg[9..]), + ]; + let mut h = VectoredWriteHarness::with_vectored_backend(8); + let bytes_written = h.write_all(IoBufs::new(&mut bufs)).await; + assert_eq!(bytes_written, msg.len()); + assert_eq!(h.flush().await, msg); +} + +#[tokio::test] +async fn write_vectored_large_slice_on_non_vectored() { + let msg = b"foo bar baz"; + let mut bufs = [ + IoSlice::new(&[]), + IoSlice::new(&msg[..9]), + IoSlice::new(&msg[9..]), + ]; + let mut h = VectoredWriteHarness::new(8); + let bytes_written = h.write_all(IoBufs::new(&mut bufs)).await; + assert_eq!(bytes_written, msg.len()); + assert_eq!(h.flush().await, msg); +} + +#[tokio::test] +async fn write_vectored_large_slice_on_vectored() { + let msg = b"foo bar baz"; + let mut bufs = [ + IoSlice::new(&[]), + IoSlice::new(&msg[..9]), + IoSlice::new(&msg[9..]), + ]; + let mut h = VectoredWriteHarness::with_vectored_backend(8); + let bytes_written = h.write_all(IoBufs::new(&mut bufs)).await; + assert_eq!(bytes_written, msg.len()); + assert_eq!(h.flush().await, msg); +} diff --git a/tokio/tests/support/io_vec.rs b/tokio/tests/support/io_vec.rs new file mode 100644 index 00000000000..4ea47c748d1 --- /dev/null +++ b/tokio/tests/support/io_vec.rs @@ -0,0 +1,45 @@ +use std::io::IoSlice; +use std::ops::Deref; +use std::slice; + +pub struct IoBufs<'a, 'b>(&'b mut [IoSlice<'a>]); + +impl<'a, 'b> IoBufs<'a, 'b> { + pub fn new(slices: &'b mut [IoSlice<'a>]) -> Self { + IoBufs(slices) + } + + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } + + pub fn advance(mut self, n: usize) -> IoBufs<'a, 'b> { + let mut to_remove = 0; + let mut remaining_len = n; + for slice in self.0.iter() { + if remaining_len < slice.len() { + break; + } else { + remaining_len -= slice.len(); + to_remove += 1; + } + } + self.0 = self.0.split_at_mut(to_remove).1; + if let Some(slice) = self.0.first_mut() { + let tail = &slice[remaining_len..]; + // Safety: recasts slice to the original lifetime + let tail = unsafe { slice::from_raw_parts(tail.as_ptr(), tail.len()) }; + *slice = IoSlice::new(tail); + } else if remaining_len != 0 { + panic!("advance past the end of the slice vector"); + } + self + } +} + +impl<'a, 'b> Deref for IoBufs<'a, 'b> { + type Target = [IoSlice<'a>]; + fn deref(&self) -> &[IoSlice<'a>] { + self.0 + } +} From 30a6591b82e6f1f2dbed2c4f8c23ba7eaa155109 Mon Sep 17 00:00:00 2001 From: Mikhail Zabaluev Date: Mon, 4 Jan 2021 12:41:47 +0200 Subject: [PATCH 5/8] io: guard against integer overflow in BufWriter In the implementation of AsyncWrite::poll_write_vectored for BufWriter, the total length of the data can technically overflow usize even with safely obtained buffer slices, since slices may overlap. --- tokio/src/io/util/buf_writer.rs | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tokio/src/io/util/buf_writer.rs b/tokio/src/io/util/buf_writer.rs index c403da5e63a..b54f597d035 100644 --- a/tokio/src/io/util/buf_writer.rs +++ b/tokio/src/io/util/buf_writer.rs @@ -138,15 +138,19 @@ impl AsyncWrite for BufWriter { cx: &mut Context<'_>, bufs: &[IoSlice<'_>], ) -> Poll> { - if self.as_mut().project().inner.is_write_vectored() { - let total_len = bufs.iter().map(|b| b.len()).sum::(); - if self.buf.len() + total_len > self.buf.capacity() { + if self.inner.is_write_vectored() { + let total_len = bufs + .iter() + .fold(0usize, |acc, b| acc.saturating_add(b.len())); + if total_len > self.buf.capacity() - self.buf.len() { ready!(self.as_mut().flush_buf(cx))?; } let me = self.as_mut().project(); if total_len >= me.buf.capacity() { // It's more efficient to pass the slices directly to the // underlying writer than to buffer them. + // The case when the total_len calculation saturates at + // usize::MAX is also handled here. me.inner.poll_write_vectored(cx, bufs) } else { bufs.iter().for_each(|b| me.buf.extend_from_slice(b)); From f0d443d9693bc042a17e1c8b885193648f5a39be Mon Sep 17 00:00:00 2001 From: Mikhail Zabaluev Date: Tue, 29 Jun 2021 11:54:00 +0300 Subject: [PATCH 6/8] Rewrite non-vectored branch in poll_write_vectored Improve readability by updating the IOSlice vector slice as constituent slices are gone over by the iteration, rather than using an iterator. --- tokio/src/io/util/buf_writer.rs | 41 +++++++++++++++++++-------------- 1 file changed, 24 insertions(+), 17 deletions(-) diff --git a/tokio/src/io/util/buf_writer.rs b/tokio/src/io/util/buf_writer.rs index b54f597d035..e5dcef3cd15 100644 --- a/tokio/src/io/util/buf_writer.rs +++ b/tokio/src/io/util/buf_writer.rs @@ -136,7 +136,7 @@ impl AsyncWrite for BufWriter { fn poll_write_vectored( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - bufs: &[IoSlice<'_>], + mut bufs: &[IoSlice<'_>], ) -> Poll> { if self.inner.is_write_vectored() { let total_len = bufs @@ -157,27 +157,34 @@ impl AsyncWrite for BufWriter { Poll::Ready(Ok(total_len)) } } else { - let mut iter = bufs.iter(); - let mut total_written = if let Some(buf) = iter.by_ref().find(|&buf| !buf.is_empty()) { - // This is the first non-empty slice to write, so if it does - // not fit in the buffer, we still get to flush and proceed. - if self.buf.len() + buf.len() > self.buf.capacity() { - ready!(self.as_mut().flush_buf(cx))?; + let mut total_written = loop { + if bufs.is_empty() { + return Poll::Ready(Ok(0)); } - let me = self.as_mut().project(); - if buf.len() >= me.buf.capacity() { - // The slice is at least as large as the buffering capacity, - // so it's better to write it directly, bypassing the buffer. - return me.inner.poll_write(cx, buf); + if bufs[0].is_empty() { + bufs = &bufs[1..]; + continue; } else { - me.buf.extend_from_slice(buf); - buf.len() + let buf = &bufs[0]; + // This is the first non-empty slice to write, so if it does + // not fit in the buffer, we still get to flush and proceed. + if self.buf.len() + buf.len() > self.buf.capacity() { + ready!(self.as_mut().flush_buf(cx))?; + } + let me = self.as_mut().project(); + if buf.len() >= me.buf.capacity() { + // The slice is at least as large as the buffering capacity, + // so it's better to write it directly, bypassing the buffer. + return me.inner.poll_write(cx, buf); + } else { + me.buf.extend_from_slice(buf); + bufs = &bufs[1..]; + break buf.len(); + } } - } else { - return Poll::Ready(Ok(0)); }; debug_assert!(total_written != 0); - for buf in iter { + for buf in bufs { let me = self.as_mut().project(); if me.buf.len() + buf.len() > me.buf.capacity() { break; From 2dff0adede4eddc819cc0f9e2a92e93f379aaf46 Mon Sep 17 00:00:00 2001 From: Mikhail Zabaluev Date: Tue, 29 Jun 2021 17:23:04 +0300 Subject: [PATCH 7/8] More readability fixes in poll_write_vectored Partially incorporated code suggested by @Darksonn. --- tokio/src/io/util/buf_writer.rs | 53 ++++++++++++++++----------------- 1 file changed, 25 insertions(+), 28 deletions(-) diff --git a/tokio/src/io/util/buf_writer.rs b/tokio/src/io/util/buf_writer.rs index e5dcef3cd15..eaa4a054ed0 100644 --- a/tokio/src/io/util/buf_writer.rs +++ b/tokio/src/io/util/buf_writer.rs @@ -157,36 +157,33 @@ impl AsyncWrite for BufWriter { Poll::Ready(Ok(total_len)) } } else { - let mut total_written = loop { - if bufs.is_empty() { - return Poll::Ready(Ok(0)); - } - if bufs[0].is_empty() { - bufs = &bufs[1..]; - continue; - } else { - let buf = &bufs[0]; - // This is the first non-empty slice to write, so if it does - // not fit in the buffer, we still get to flush and proceed. - if self.buf.len() + buf.len() > self.buf.capacity() { - ready!(self.as_mut().flush_buf(cx))?; - } - let me = self.as_mut().project(); - if buf.len() >= me.buf.capacity() { - // The slice is at least as large as the buffering capacity, - // so it's better to write it directly, bypassing the buffer. - return me.inner.poll_write(cx, buf); - } else { - me.buf.extend_from_slice(buf); - bufs = &bufs[1..]; - break buf.len(); - } - } - }; + // Remove empty buffers at the beginning of bufs. + while bufs.first().map(|buf| buf.len()) == Some(0) { + bufs = &bufs[1..]; + } + if bufs.is_empty() { + return Poll::Ready(Ok(0)); + } + // Flush if the first buffer doesn't fit. + let first_len = bufs[0].len(); + if first_len > self.buf.capacity() - self.buf.len() { + ready!(self.as_mut().flush_buf(cx))?; + debug_assert!(self.buf.is_empty()); + } + let me = self.as_mut().project(); + if first_len >= me.buf.capacity() { + // The slice is at least as large as the buffering capacity, + // so it's better to write it directly, bypassing the buffer. + return me.inner.poll_write(cx, &bufs[0]); + } else { + me.buf.extend_from_slice(&bufs[0]); + bufs = &bufs[1..]; + } + let mut total_written = first_len; debug_assert!(total_written != 0); + // Append the buffers that fit in the internal buffer. for buf in bufs { - let me = self.as_mut().project(); - if me.buf.len() + buf.len() > me.buf.capacity() { + if buf.len() > me.buf.capacity() - me.buf.len() { break; } else { me.buf.extend_from_slice(buf); From 598560b2ac165cba0b2c7ae6a3ee56b9a6b4e8a2 Mon Sep 17 00:00:00 2001 From: Mikhail Zabaluev Date: Tue, 29 Jun 2021 20:01:21 +0300 Subject: [PATCH 8/8] BufWriter: assert the buffer is empty when bypassing Co-authored-by: Alice Ryhl --- tokio/src/io/util/buf_writer.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/tokio/src/io/util/buf_writer.rs b/tokio/src/io/util/buf_writer.rs index eaa4a054ed0..8dd1bba60ab 100644 --- a/tokio/src/io/util/buf_writer.rs +++ b/tokio/src/io/util/buf_writer.rs @@ -174,6 +174,7 @@ impl AsyncWrite for BufWriter { if first_len >= me.buf.capacity() { // The slice is at least as large as the buffering capacity, // so it's better to write it directly, bypassing the buffer. + debug_assert!(me.buf.is_empty()); return me.inner.poll_write(cx, &bufs[0]); } else { me.buf.extend_from_slice(&bufs[0]);