Skip to content

Commit

Permalink
test write_vectored on BufWriter
Browse files Browse the repository at this point in the history
  • Loading branch information
mzabaluev committed Dec 16, 2020
1 parent 5e0a93c commit c2aee65
Show file tree
Hide file tree
Showing 2 changed files with 335 additions and 0 deletions.
290 changes: 290 additions & 0 deletions tokio/tests/io_buf_writer.rs
@@ -0,0 +1,290 @@
use tokio::io::BufWriter;
use tokio::prelude::*;

use futures::future;
use tokio_test::assert_ok;

use std::cmp;
use std::io::IoSlice;
use std::pin::Pin;
use std::task::{Context, Poll};

mod support {
pub(crate) mod io_vec;
}
use support::io_vec::IoBufs;

async fn write_vectored<W>(writer: &mut W, bufs: &[IoSlice<'_>]) -> io::Result<usize>
where
W: AsyncWrite + Unpin,
{
let mut writer = Pin::new(writer);
future::poll_fn(|cx| writer.as_mut().poll_write_vectored(cx, bufs)).await
}

struct MockWriter {
data: Vec<u8>,
write_len: usize,
vectored: bool,
}

impl MockWriter {
fn new(write_len: usize) -> Self {
MockWriter {
data: Vec::new(),
write_len,
vectored: false,
}
}

fn vectored(write_len: usize) -> Self {
MockWriter {
data: Vec::new(),
write_len,
vectored: true,
}
}

fn write_up_to(&mut self, buf: &[u8], limit: usize) -> usize {
let len = cmp::min(buf.len(), limit);
self.data.extend_from_slice(&buf[..len]);
len
}
}

impl AsyncWrite for MockWriter {
fn poll_write(
self: Pin<&mut Self>,
_: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, io::Error>> {
let this = self.get_mut();
let n = this.write_up_to(buf, this.write_len);
Ok(n).into()
}

fn poll_write_vectored(
self: Pin<&mut Self>,
_: &mut Context<'_>,
bufs: &[IoSlice<'_>],
) -> Poll<Result<usize, io::Error>> {
let this = self.get_mut();
let mut total_written = 0;
for buf in bufs {
let n = this.write_up_to(buf, this.write_len - total_written);
total_written += n;
if total_written == this.write_len {
break;
}
}
Ok(total_written).into()
}

fn is_write_vectored(&self) -> bool {
self.vectored
}

fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
Ok(()).into()
}

fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
Ok(()).into()
}
}

#[tokio::test]
async fn write_vectored_empty_on_non_vectored() {
let mut w = BufWriter::new(MockWriter::new(4));
let n = assert_ok!(write_vectored(&mut w, &[]).await);
assert_eq!(n, 0);

let io_vec = [IoSlice::new(&[]); 3];
let n = assert_ok!(write_vectored(&mut w, &io_vec).await);
assert_eq!(n, 0);

assert_ok!(w.flush().await);
assert!(w.get_ref().data.is_empty());
}

#[tokio::test]
async fn write_vectored_empty_on_vectored() {
let mut w = BufWriter::new(MockWriter::vectored(4));
let n = assert_ok!(write_vectored(&mut w, &[]).await);
assert_eq!(n, 0);

let io_vec = [IoSlice::new(&[]); 3];
let n = assert_ok!(write_vectored(&mut w, &io_vec).await);
assert_eq!(n, 0);

assert_ok!(w.flush().await);
assert!(w.get_ref().data.is_empty());
}

#[tokio::test]
async fn write_vectored_basic_on_non_vectored() {
let msg = b"foo bar baz";
let bufs = [
IoSlice::new(&msg[0..4]),
IoSlice::new(&msg[4..8]),
IoSlice::new(&msg[8..]),
];
let mut w = BufWriter::new(MockWriter::new(4));
let n = assert_ok!(write_vectored(&mut w, &bufs).await);
assert_eq!(n, msg.len());
assert!(w.buffer() == &msg[..]);
assert_ok!(w.flush().await);
assert_eq!(w.get_ref().data, msg);
}

#[tokio::test]
async fn write_vectored_basic_on_vectored() {
let msg = b"foo bar baz";
let bufs = [
IoSlice::new(&msg[0..4]),
IoSlice::new(&msg[4..8]),
IoSlice::new(&msg[8..]),
];
let mut w = BufWriter::new(MockWriter::vectored(4));
let n = assert_ok!(write_vectored(&mut w, &bufs).await);
assert_eq!(n, msg.len());
assert!(w.buffer() == &msg[..]);
assert_ok!(w.flush().await);
assert_eq!(w.get_ref().data, msg);
}

#[tokio::test]
async fn write_vectored_large_total_on_non_vectored() {
let msg = b"foo bar baz";
let mut bufs = [
IoSlice::new(&msg[0..4]),
IoSlice::new(&msg[4..8]),
IoSlice::new(&msg[8..]),
];
let io_vec = IoBufs::new(&mut bufs);
let mut w = BufWriter::with_capacity(8, MockWriter::new(4));
let n = assert_ok!(write_vectored(&mut w, &io_vec).await);
assert_eq!(n, 8);
assert!(w.buffer() == &msg[..8]);
let io_vec = io_vec.advance(n);
let n = assert_ok!(write_vectored(&mut w, &io_vec).await);
assert_eq!(n, 3);
assert!(w.get_ref().data.as_slice() == &msg[..8]);
assert!(w.buffer() == &msg[8..]);
}

#[tokio::test]
async fn write_vectored_large_total_on_vectored() {
let msg = b"foo bar baz";
let mut bufs = [
IoSlice::new(&msg[0..4]),
IoSlice::new(&msg[4..8]),
IoSlice::new(&msg[8..]),
];
let io_vec = IoBufs::new(&mut bufs);
let mut w = BufWriter::with_capacity(8, MockWriter::vectored(10));
let n = assert_ok!(write_vectored(&mut w, &io_vec).await);
assert_eq!(n, 10);
assert!(w.buffer().is_empty());
let io_vec = io_vec.advance(n);
let n = assert_ok!(write_vectored(&mut w, &io_vec).await);
assert_eq!(n, 1);
assert!(w.get_ref().data.as_slice() == &msg[..10]);
assert!(w.buffer() == &msg[10..]);
}

struct VectoredWriteHarness {
writer: BufWriter<MockWriter>,
buf_capacity: usize,
}

impl VectoredWriteHarness {
fn new(buf_capacity: usize) -> Self {
VectoredWriteHarness {
writer: BufWriter::with_capacity(buf_capacity, MockWriter::new(4)),
buf_capacity,
}
}

fn with_vectored_backend(buf_capacity: usize) -> Self {
VectoredWriteHarness {
writer: BufWriter::with_capacity(buf_capacity, MockWriter::vectored(4)),
buf_capacity,
}
}

async fn write_all<'a, 'b>(&mut self, mut io_vec: IoBufs<'a, 'b>) -> usize {
let mut total_written = 0;
while !io_vec.is_empty() {
let n = assert_ok!(write_vectored(&mut self.writer, &io_vec).await);
assert!(n != 0);
assert!(self.writer.buffer().len() <= self.buf_capacity);
total_written += n;
io_vec = io_vec.advance(n);
}
total_written
}

async fn flush(&mut self) -> &[u8] {
assert_ok!(self.writer.flush().await);
&self.writer.get_ref().data
}
}

#[tokio::test]
async fn write_vectored_odd_on_non_vectored() {
let msg = b"foo bar baz";
let mut bufs = [
IoSlice::new(&msg[0..4]),
IoSlice::new(&[]),
IoSlice::new(&msg[4..9]),
IoSlice::new(&msg[9..]),
];
let mut h = VectoredWriteHarness::new(8);
let bytes_written = h.write_all(IoBufs::new(&mut bufs)).await;
assert_eq!(bytes_written, msg.len());
assert_eq!(h.flush().await, msg);
}

#[tokio::test]
async fn write_vectored_odd_on_vectored() {
let msg = b"foo bar baz";
let mut bufs = [
IoSlice::new(&msg[0..4]),
IoSlice::new(&[]),
IoSlice::new(&msg[4..9]),
IoSlice::new(&msg[9..]),
];
let mut h = VectoredWriteHarness::with_vectored_backend(8);
let bytes_written = h.write_all(IoBufs::new(&mut bufs)).await;
assert_eq!(bytes_written, msg.len());
assert_eq!(h.flush().await, msg);
}

#[tokio::test]
async fn write_vectored_large_slice_on_non_vectored() {
let msg = b"foo bar baz";
let mut bufs = [
IoSlice::new(&[]),
IoSlice::new(&msg[..9]),
IoSlice::new(&msg[9..]),
];
let mut h = VectoredWriteHarness::new(8);
let bytes_written = h.write_all(IoBufs::new(&mut bufs)).await;
assert_eq!(bytes_written, msg.len());
assert_eq!(h.flush().await, msg);
}

#[tokio::test]
async fn write_vectored_large_slice_on_vectored() {
let msg = b"foo bar baz";
let mut bufs = [
IoSlice::new(&[]),
IoSlice::new(&msg[..9]),
IoSlice::new(&msg[9..]),
];
let mut h = VectoredWriteHarness::with_vectored_backend(8);
let bytes_written = h.write_all(IoBufs::new(&mut bufs)).await;
assert_eq!(bytes_written, msg.len());
assert_eq!(h.flush().await, msg);
}
45 changes: 45 additions & 0 deletions tokio/tests/support/io_vec.rs
@@ -0,0 +1,45 @@
use std::io::IoSlice;
use std::ops::Deref;
use std::slice;

pub struct IoBufs<'a, 'b>(&'b mut [IoSlice<'a>]);

impl<'a, 'b> IoBufs<'a, 'b> {
pub fn new(slices: &'b mut [IoSlice<'a>]) -> Self {
IoBufs(slices)
}

pub fn is_empty(&self) -> bool {
self.0.is_empty()
}

pub fn advance(mut self, n: usize) -> IoBufs<'a, 'b> {
let mut to_remove = 0;
let mut remaining_len = n;
for slice in self.0.iter() {
if remaining_len < slice.len() {
break;
} else {
remaining_len -= slice.len();
to_remove += 1;
}
}
self.0 = self.0.split_at_mut(to_remove).1;
if let Some(slice) = self.0.first_mut() {
let tail = &slice[remaining_len..];
// Safety: recasts slice to the original lifetime
let tail = unsafe { slice::from_raw_parts(tail.as_ptr(), tail.len()) };
*slice = IoSlice::new(tail);
} else if remaining_len != 0 {
panic!("advance past the end of the slice vector");
}
self
}
}

impl<'a, 'b> Deref for IoBufs<'a, 'b> {
type Target = [IoSlice<'a>];
fn deref(&self) -> &[IoSlice<'a>] {
self.0
}
}

0 comments on commit c2aee65

Please sign in to comment.