Skip to content

Commit

Permalink
Add async LineWriter (#2477)
Browse files Browse the repository at this point in the history
  • Loading branch information
FelipeLema committed Oct 8, 2021
1 parent ee23679 commit 3601bb7
Show file tree
Hide file tree
Showing 4 changed files with 295 additions and 1 deletion.
65 changes: 64 additions & 1 deletion futures-util/src/io/buf_writer.rs
Expand Up @@ -6,6 +6,7 @@ use pin_project_lite::pin_project;
use std::fmt;
use std::io::{self, Write};
use std::pin::Pin;
use std::ptr;

pin_project! {
/// Wraps a writer and buffers its output.
Expand Down Expand Up @@ -49,7 +50,7 @@ impl<W: AsyncWrite> BufWriter<W> {
Self { inner, buf: Vec::with_capacity(cap), written: 0 }
}

fn flush_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
pub(super) fn flush_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let mut this = self.project();

let len = this.buf.len();
Expand Down Expand Up @@ -83,6 +84,68 @@ impl<W: AsyncWrite> BufWriter<W> {
pub fn buffer(&self) -> &[u8] {
    // Hand out the buffered bytes as a plain slice; the `Vec` itself stays private.
    self.buf.as_slice()
}

/// Capacity of the internal buffer `buf`, in bytes, as reported by `Vec::capacity`.
pub(super) fn capacity(&self) -> usize {
    Vec::capacity(&self.buf)
}

/// Number of bytes that can still be appended before `buf` reaches its capacity.
#[inline]
pub(super) fn spare_capacity(&self) -> usize {
    let used = self.buf.len();
    let total = self.buf.capacity();
    total - used
}

/// Copies as much of `buf` as fits into the internal buffer.
///
/// At most `spare_capacity()` bytes are taken; anything beyond that is left out, so
/// callers must use the returned count to learn how many bytes were accepted.
///
/// Based on `std::io::BufWriter`
pub(super) fn write_to_buf(self: Pin<&mut Self>, buf: &[u8]) -> usize {
    let accepted = buf.len().min(self.spare_capacity());
    let chunk = &buf[..accepted];

    // SAFETY: `chunk.len()` is at most the buffer's spare capacity by construction.
    unsafe {
        self.write_to_buffer_unchecked(chunk);
    }

    accepted
}

/// Write byte slice directly into `self.buf`
///
/// Appends `buf` after the current contents of `self.buf` with a raw copy and then
/// raises the length accordingly. No capacity check is made outside of debug builds.
///
/// Based on `std::io::BufWriter`
///
/// # Safety
///
/// The caller must guarantee `buf.len() <= self.spare_capacity()`. Otherwise the copy
/// writes past the end of the buffer's allocation and the new length exceeds its capacity.
#[inline]
unsafe fn write_to_buffer_unchecked(self: Pin<&mut Self>, buf: &[u8]) {
// Debug-only check of the caller's contract.
debug_assert!(buf.len() <= self.spare_capacity());
let this = self.project();
let old_len = this.buf.len();
let buf_len = buf.len();
let src = buf.as_ptr();
// Destination starts right after the bytes that are already buffered.
let dst = this.buf.as_mut_ptr().add(old_len);
// `buf` is a shared borrow while `self` is borrowed exclusively, so the two regions
// cannot overlap.
ptr::copy_nonoverlapping(src, dst, buf_len);
// Make the freshly copied bytes part of the vector's contents.
this.buf.set_len(old_len + buf_len);
}

/// Forwards `buf` straight to `inner`'s `poll_write`, bypassing the buffer.
pub(super) fn inner_poll_write(
    self: Pin<&mut Self>,
    cx: &mut Context<'_>,
    buf: &[u8],
) -> Poll<io::Result<usize>> {
    let this = self.project();
    this.inner.poll_write(cx, buf)
}

/// Forwards `bufs` straight to `inner`'s `poll_write_vectored`, bypassing the buffer.
pub(super) fn inner_poll_write_vectored(
    self: Pin<&mut Self>,
    cx: &mut Context<'_>,
    bufs: &[IoSlice<'_>],
) -> Poll<io::Result<usize>> {
    let this = self.project();
    this.inner.poll_write_vectored(cx, bufs)
}
}

impl<W: AsyncWrite> AsyncWrite for BufWriter<W> {
Expand Down
155 changes: 155 additions & 0 deletions futures-util/src/io/line_writer.rs
@@ -0,0 +1,155 @@
use super::buf_writer::BufWriter;
use futures_core::ready;
use futures_core::task::{Context, Poll};
use futures_io::AsyncWrite;
use futures_io::IoSlice;
use pin_project_lite::pin_project;
use std::io;
use std::pin::Pin;

pin_project! {
/// Wraps a writer like [`BufWriter`] does, but passes data on to it line by line.
///
/// On each write, the data up to and including the last newline (`b'\n'`) is written
/// through to the inner writer, while bytes after that newline are kept in the buffer.
///
/// This was written based on `std::io::LineWriter`, whose source goes into further
/// detail explaining the approach.
///
/// Buffering itself is done by an internal `BufWriter`; this type decides what is
/// written through and what is buffered.
#[derive(Debug)]
pub struct LineWriter<W: AsyncWrite> {
// The buffering layer that owns the wrapped writer; pinned so it can be polled.
#[pin]
buf_writer: BufWriter<W>,
}
}

impl<W: AsyncWrite> LineWriter<W> {
    /// Creates a `LineWriter` with the default buffer capacity, currently 1024 bytes
    /// (a value taken from `std::io::LineWriter`).
    pub fn new(inner: W) -> LineWriter<W> {
        Self::with_capacity(1024, inner)
    }

    /// Creates a `LineWriter` whose internal buffer has the given capacity.
    pub fn with_capacity(capacity: usize, inner: W) -> LineWriter<W> {
        let buf_writer = BufWriter::with_capacity(capacity, inner);
        Self { buf_writer }
    }

    /// Flushes the buffer of `buf_writer`, but only when the buffered data ends with a
    /// newline; otherwise this is a no-op that reports success.
    fn flush_if_completed_line(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let buf_writer = self.project().buf_writer;
        if buf_writer.buffer().ends_with(b"\n") {
            buf_writer.flush_buf(cx)
        } else {
            Poll::Ready(Ok(()))
        }
    }

    /// Returns the data currently held in `buf_writer`'s internal buffer.
    pub fn buffer(&self) -> &[u8] {
        let inner_buffer = &self.buf_writer;
        inner_buffer.buffer()
    }

    /// Returns a shared reference to the wrapped writer.
    pub fn get_ref(&self) -> &W {
        let inner_buffer = &self.buf_writer;
        inner_buffer.get_ref()
    }
}

// `AsyncWrite` for `LineWriter`: data up to the last newline of a write goes directly to
// the inner writer, the remainder is placed in `buf_writer`'s buffer.
impl<W: AsyncWrite> AsyncWrite for LineWriter<W> {
/// Writes `buf`, sending complete lines to the inner writer and buffering the rest.
///
/// Returns the number of bytes written through plus the number of bytes buffered.
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let mut this = self.as_mut().project();
// Locate the LAST newline; `newline_index` ends up one past it, so
// `buf[..newline_index]` covers every complete line in `buf`.
let newline_index = match memchr::memrchr(b'\n', buf) {
None => {
// No newline in `buf`: flush the buffer only if it already ends with a
// completed line, then let `BufWriter::poll_write` handle `buf`.
ready!(self.as_mut().flush_if_completed_line(cx)?);
return self.project().buf_writer.poll_write(cx, buf);
}
Some(newline_index) => newline_index + 1,
};

// Get previously buffered data out before writing the new lines.
// NOTE(review): this calls `poll_flush` rather than `flush_buf`; std's `LineWriter`
// only drains its buffer at this point — confirm flushing `inner` here is intended.
ready!(this.buf_writer.as_mut().poll_flush(cx)?);

let lines = &buf[..newline_index];

// Offer the complete lines directly to the inner writer, bypassing the buffer.
let flushed = { ready!(this.buf_writer.as_mut().inner_poll_write(cx, lines))? };

// Nothing was accepted: report zero bytes written and buffer nothing.
if flushed == 0 {
return Poll::Ready(Ok(0));
}

// Pick the bytes that are candidates for buffering.
let tail = if flushed >= newline_index {
// All lines were written; the candidate is whatever follows them.
&buf[flushed..]
} else if newline_index - flushed <= this.buf_writer.capacity() {
// Partial write, and the unwritten part of the lines is no larger than the
// buffer's capacity: buffer exactly that part.
&buf[flushed..newline_index]
} else {
// Partial write with more unwritten line data than the buffer's capacity:
// look at the first `capacity()` bytes only and, if that window contains a
// newline, cut right after its last one.
let scan_area = &buf[flushed..];
let scan_area = &scan_area[..this.buf_writer.capacity()];
match memchr::memrchr(b'\n', scan_area) {
Some(newline_index) => &scan_area[..newline_index + 1],
None => scan_area,
}
};

// `write_to_buf` copies at most `spare_capacity()` bytes and returns that count.
let buffered = this.buf_writer.as_mut().write_to_buf(tail);
Poll::Ready(Ok(flushed + buffered))
}

/// Vectored counterpart of `poll_write`, working at the granularity of whole slices.
///
/// The slices up to and including the last one that contains a newline are written to
/// the inner writer; later slices are buffered.
fn poll_write_vectored(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[IoSlice<'_>],
) -> Poll<io::Result<usize>> {
let mut this = self.as_mut().project();
// `is_write_vectored()` is handled in original code, but not in this crate
// see https://github.com/rust-lang/rust/issues/70436

// Index of the last slice in `bufs` that contains a newline anywhere in it.
let last_newline_buf_idx = bufs
.iter()
.enumerate()
.rev()
.find_map(|(i, buf)| memchr::memchr(b'\n', buf).map(|_| i));
let last_newline_buf_idx = match last_newline_buf_idx {
None => {
// No newline in any slice: same handling as in `poll_write`.
ready!(self.as_mut().flush_if_completed_line(cx)?);
return self.project().buf_writer.poll_write_vectored(cx, bufs);
}
Some(i) => i,
};

// NOTE(review): as in `poll_write`, this is `poll_flush`, not `flush_buf` — confirm.
ready!(this.buf_writer.as_mut().poll_flush(cx)?);

// `lines`: slices through the last one with a newline; `tail`: the slices after it.
let (lines, tail) = bufs.split_at(last_newline_buf_idx + 1);

let flushed = { ready!(this.buf_writer.as_mut().inner_poll_write_vectored(cx, lines))? };
if flushed == 0 {
return Poll::Ready(Ok(0));
}

// Partial write of `lines`: report it as is and do not buffer any of `tail`.
let lines_len = lines.iter().map(|buf| buf.len()).sum();
if flushed < lines_len {
return Poll::Ready(Ok(flushed));
}

// Copy the non-empty tail slices into the buffer, stopping at the first slice from
// which `write_to_buf` could take nothing.
let buffered: usize = tail
.iter()
.filter(|buf| !buf.is_empty())
.map(|buf| this.buf_writer.as_mut().write_to_buf(buf))
.take_while(|&n| n > 0)
.sum();

Poll::Ready(Ok(flushed + buffered))
}

/// Forward to `buf_writer` 's `BufWriter::poll_flush()`
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.as_mut().project().buf_writer.poll_flush(cx)
}

/// Forward to `buf_writer` 's `BufWriter::poll_close()`
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.as_mut().project().buf_writer.poll_close(cx)
}
}
3 changes: 3 additions & 0 deletions futures-util/src/io/mod.rs
Expand Up @@ -61,6 +61,9 @@ pub use self::buf_reader::{BufReader, SeeKRelative};
mod buf_writer;
pub use self::buf_writer::BufWriter;

mod line_writer;
pub use self::line_writer::LineWriter;

mod chain;
pub use self::chain::Chain;

Expand Down
73 changes: 73 additions & 0 deletions futures/tests/io_line_writer.rs
@@ -0,0 +1,73 @@
use futures::executor::block_on;
use futures::io::{AsyncWriteExt, LineWriter};
use std::io;

// Exercises `LineWriter::poll_write` through `AsyncWriteExt::write` on a `Vec<u8>` sink,
// checking after each step which bytes have reached the sink.
#[test]
fn line_writer() {
let mut writer = LineWriter::new(Vec::new());

// Writes without a newline are expected to stay in the buffer.
block_on(writer.write(&[0])).unwrap();
assert_eq!(*writer.get_ref(), []);

block_on(writer.write(&[1])).unwrap();
assert_eq!(*writer.get_ref(), []);

// An explicit flush pushes the buffered bytes to the sink.
block_on(writer.flush()).unwrap();
assert_eq!(*writer.get_ref(), [0, 1]);

// Everything through the last newline is expected in the sink; the trailing `2` is not.
block_on(writer.write(&[0, b'\n', 1, b'\n', 2])).unwrap();
assert_eq!(*writer.get_ref(), [0, 1, 0, b'\n', 1, b'\n']);

block_on(writer.flush()).unwrap();
assert_eq!(*writer.get_ref(), [0, 1, 0, b'\n', 1, b'\n', 2]);

// A write ending in a newline is expected to reach the sink without an explicit flush.
block_on(writer.write(&[3, b'\n'])).unwrap();
assert_eq!(*writer.get_ref(), [0, 1, 0, b'\n', 1, b'\n', 2, 3, b'\n']);
}

// Exercises `LineWriter::poll_write_vectored` through `AsyncWriteExt::write_vectored`,
// including empty slices and an empty slice list.
#[test]
fn line_vectored() {
let mut line_writer = LineWriter::new(Vec::new());
// Two bytes are reported as written: "\n" reaches the sink, "a" is only buffered.
assert_eq!(
block_on(line_writer.write_vectored(&[
io::IoSlice::new(&[]),
io::IoSlice::new(b"\n"),
io::IoSlice::new(&[]),
io::IoSlice::new(b"a"),
]))
.unwrap(),
2
);
assert_eq!(line_writer.get_ref(), b"\n");

// No newline in any slice: all three bytes are accepted but the sink is unchanged.
assert_eq!(
block_on(line_writer.write_vectored(&[
io::IoSlice::new(&[]),
io::IoSlice::new(b"b"),
io::IoSlice::new(&[]),
io::IoSlice::new(b"a"),
io::IoSlice::new(&[]),
io::IoSlice::new(b"c"),
]))
.unwrap(),
3
);
assert_eq!(line_writer.get_ref(), b"\n");
// Flushing delivers the earlier "a" followed by "b", "a", "c".
block_on(line_writer.flush()).unwrap();
assert_eq!(line_writer.get_ref(), b"\nabac");
// An empty slice list writes nothing.
assert_eq!(block_on(line_writer.write_vectored(&[])).unwrap(), 0);

// Only empty slices: nothing is written either.
assert_eq!(
block_on(line_writer.write_vectored(&[
io::IoSlice::new(&[]),
io::IoSlice::new(&[]),
io::IoSlice::new(&[]),
io::IoSlice::new(&[]),
]))
.unwrap(),
0
);

// The vectored path works on whole slices: the slice holding the newline is written
// through in full, so the trailing "b" is expected in the sink as well.
assert_eq!(block_on(line_writer.write_vectored(&[io::IoSlice::new(b"a\nb")])).unwrap(), 3);
assert_eq!(line_writer.get_ref(), b"\nabaca\nb");
}

0 comments on commit 3601bb7

Please sign in to comment.