Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
335 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,290 @@ | ||
use tokio::io::BufWriter; | ||
use tokio::prelude::*; | ||
|
||
use futures::future; | ||
use tokio_test::assert_ok; | ||
|
||
use std::cmp; | ||
use std::io::IoSlice; | ||
use std::pin::Pin; | ||
use std::task::{Context, Poll}; | ||
|
||
mod support { | ||
pub(crate) mod io_vec; | ||
} | ||
use support::io_vec::IoBufs; | ||
|
||
async fn write_vectored<W>(writer: &mut W, bufs: &[IoSlice<'_>]) -> io::Result<usize> | ||
where | ||
W: AsyncWrite + Unpin, | ||
{ | ||
let mut writer = Pin::new(writer); | ||
future::poll_fn(|cx| writer.as_mut().poll_write_vectored(cx, bufs)).await | ||
} | ||
|
||
struct MockWriter { | ||
data: Vec<u8>, | ||
write_len: usize, | ||
vectored: bool, | ||
} | ||
|
||
impl MockWriter { | ||
fn new(write_len: usize) -> Self { | ||
MockWriter { | ||
data: Vec::new(), | ||
write_len, | ||
vectored: false, | ||
} | ||
} | ||
|
||
fn vectored(write_len: usize) -> Self { | ||
MockWriter { | ||
data: Vec::new(), | ||
write_len, | ||
vectored: true, | ||
} | ||
} | ||
|
||
fn write_up_to(&mut self, buf: &[u8], limit: usize) -> usize { | ||
let len = cmp::min(buf.len(), limit); | ||
self.data.extend_from_slice(&buf[..len]); | ||
len | ||
} | ||
} | ||
|
||
impl AsyncWrite for MockWriter { | ||
fn poll_write( | ||
self: Pin<&mut Self>, | ||
_: &mut Context<'_>, | ||
buf: &[u8], | ||
) -> Poll<Result<usize, io::Error>> { | ||
let this = self.get_mut(); | ||
let n = this.write_up_to(buf, this.write_len); | ||
Ok(n).into() | ||
} | ||
|
||
fn poll_write_vectored( | ||
self: Pin<&mut Self>, | ||
_: &mut Context<'_>, | ||
bufs: &[IoSlice<'_>], | ||
) -> Poll<Result<usize, io::Error>> { | ||
let this = self.get_mut(); | ||
let mut total_written = 0; | ||
for buf in bufs { | ||
let n = this.write_up_to(buf, this.write_len - total_written); | ||
total_written += n; | ||
if total_written == this.write_len { | ||
break; | ||
} | ||
} | ||
Ok(total_written).into() | ||
} | ||
|
||
fn is_write_vectored(&self) -> bool { | ||
self.vectored | ||
} | ||
|
||
fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), io::Error>> { | ||
Ok(()).into() | ||
} | ||
|
||
fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), io::Error>> { | ||
Ok(()).into() | ||
} | ||
} | ||
|
||
#[tokio::test] | ||
async fn write_vectored_empty_on_non_vectored() { | ||
let mut w = BufWriter::new(MockWriter::new(4)); | ||
let n = assert_ok!(write_vectored(&mut w, &[]).await); | ||
assert_eq!(n, 0); | ||
|
||
let io_vec = [IoSlice::new(&[]); 3]; | ||
let n = assert_ok!(write_vectored(&mut w, &io_vec).await); | ||
assert_eq!(n, 0); | ||
|
||
assert_ok!(w.flush().await); | ||
assert!(w.get_ref().data.is_empty()); | ||
} | ||
|
||
#[tokio::test] | ||
async fn write_vectored_empty_on_vectored() { | ||
let mut w = BufWriter::new(MockWriter::vectored(4)); | ||
let n = assert_ok!(write_vectored(&mut w, &[]).await); | ||
assert_eq!(n, 0); | ||
|
||
let io_vec = [IoSlice::new(&[]); 3]; | ||
let n = assert_ok!(write_vectored(&mut w, &io_vec).await); | ||
assert_eq!(n, 0); | ||
|
||
assert_ok!(w.flush().await); | ||
assert!(w.get_ref().data.is_empty()); | ||
} | ||
|
||
#[tokio::test] | ||
async fn write_vectored_basic_on_non_vectored() { | ||
let msg = b"foo bar baz"; | ||
let bufs = [ | ||
IoSlice::new(&msg[0..4]), | ||
IoSlice::new(&msg[4..8]), | ||
IoSlice::new(&msg[8..]), | ||
]; | ||
let mut w = BufWriter::new(MockWriter::new(4)); | ||
let n = assert_ok!(write_vectored(&mut w, &bufs).await); | ||
assert_eq!(n, msg.len()); | ||
assert!(w.buffer() == &msg[..]); | ||
assert_ok!(w.flush().await); | ||
assert_eq!(w.get_ref().data, msg); | ||
} | ||
|
||
#[tokio::test] | ||
async fn write_vectored_basic_on_vectored() { | ||
let msg = b"foo bar baz"; | ||
let bufs = [ | ||
IoSlice::new(&msg[0..4]), | ||
IoSlice::new(&msg[4..8]), | ||
IoSlice::new(&msg[8..]), | ||
]; | ||
let mut w = BufWriter::new(MockWriter::vectored(4)); | ||
let n = assert_ok!(write_vectored(&mut w, &bufs).await); | ||
assert_eq!(n, msg.len()); | ||
assert!(w.buffer() == &msg[..]); | ||
assert_ok!(w.flush().await); | ||
assert_eq!(w.get_ref().data, msg); | ||
} | ||
|
||
#[tokio::test] | ||
async fn write_vectored_large_total_on_non_vectored() { | ||
let msg = b"foo bar baz"; | ||
let mut bufs = [ | ||
IoSlice::new(&msg[0..4]), | ||
IoSlice::new(&msg[4..8]), | ||
IoSlice::new(&msg[8..]), | ||
]; | ||
let io_vec = IoBufs::new(&mut bufs); | ||
let mut w = BufWriter::with_capacity(8, MockWriter::new(4)); | ||
let n = assert_ok!(write_vectored(&mut w, &io_vec).await); | ||
assert_eq!(n, 8); | ||
assert!(w.buffer() == &msg[..8]); | ||
let io_vec = io_vec.advance(n); | ||
let n = assert_ok!(write_vectored(&mut w, &io_vec).await); | ||
assert_eq!(n, 3); | ||
assert!(w.get_ref().data.as_slice() == &msg[..8]); | ||
assert!(w.buffer() == &msg[8..]); | ||
} | ||
|
||
#[tokio::test] | ||
async fn write_vectored_large_total_on_vectored() { | ||
let msg = b"foo bar baz"; | ||
let mut bufs = [ | ||
IoSlice::new(&msg[0..4]), | ||
IoSlice::new(&msg[4..8]), | ||
IoSlice::new(&msg[8..]), | ||
]; | ||
let io_vec = IoBufs::new(&mut bufs); | ||
let mut w = BufWriter::with_capacity(8, MockWriter::vectored(10)); | ||
let n = assert_ok!(write_vectored(&mut w, &io_vec).await); | ||
assert_eq!(n, 10); | ||
assert!(w.buffer().is_empty()); | ||
let io_vec = io_vec.advance(n); | ||
let n = assert_ok!(write_vectored(&mut w, &io_vec).await); | ||
assert_eq!(n, 1); | ||
assert!(w.get_ref().data.as_slice() == &msg[..10]); | ||
assert!(w.buffer() == &msg[10..]); | ||
} | ||
|
||
struct VectoredWriteHarness { | ||
writer: BufWriter<MockWriter>, | ||
buf_capacity: usize, | ||
} | ||
|
||
impl VectoredWriteHarness { | ||
fn new(buf_capacity: usize) -> Self { | ||
VectoredWriteHarness { | ||
writer: BufWriter::with_capacity(buf_capacity, MockWriter::new(4)), | ||
buf_capacity, | ||
} | ||
} | ||
|
||
fn with_vectored_backend(buf_capacity: usize) -> Self { | ||
VectoredWriteHarness { | ||
writer: BufWriter::with_capacity(buf_capacity, MockWriter::vectored(4)), | ||
buf_capacity, | ||
} | ||
} | ||
|
||
async fn write_all<'a, 'b>(&mut self, mut io_vec: IoBufs<'a, 'b>) -> usize { | ||
let mut total_written = 0; | ||
while !io_vec.is_empty() { | ||
let n = assert_ok!(write_vectored(&mut self.writer, &io_vec).await); | ||
assert!(n != 0); | ||
assert!(self.writer.buffer().len() <= self.buf_capacity); | ||
total_written += n; | ||
io_vec = io_vec.advance(n); | ||
} | ||
total_written | ||
} | ||
|
||
async fn flush(&mut self) -> &[u8] { | ||
assert_ok!(self.writer.flush().await); | ||
&self.writer.get_ref().data | ||
} | ||
} | ||
|
||
#[tokio::test] | ||
async fn write_vectored_odd_on_non_vectored() { | ||
let msg = b"foo bar baz"; | ||
let mut bufs = [ | ||
IoSlice::new(&msg[0..4]), | ||
IoSlice::new(&[]), | ||
IoSlice::new(&msg[4..9]), | ||
IoSlice::new(&msg[9..]), | ||
]; | ||
let mut h = VectoredWriteHarness::new(8); | ||
let bytes_written = h.write_all(IoBufs::new(&mut bufs)).await; | ||
assert_eq!(bytes_written, msg.len()); | ||
assert_eq!(h.flush().await, msg); | ||
} | ||
|
||
#[tokio::test] | ||
async fn write_vectored_odd_on_vectored() { | ||
let msg = b"foo bar baz"; | ||
let mut bufs = [ | ||
IoSlice::new(&msg[0..4]), | ||
IoSlice::new(&[]), | ||
IoSlice::new(&msg[4..9]), | ||
IoSlice::new(&msg[9..]), | ||
]; | ||
let mut h = VectoredWriteHarness::with_vectored_backend(8); | ||
let bytes_written = h.write_all(IoBufs::new(&mut bufs)).await; | ||
assert_eq!(bytes_written, msg.len()); | ||
assert_eq!(h.flush().await, msg); | ||
} | ||
|
||
#[tokio::test] | ||
async fn write_vectored_large_slice_on_non_vectored() { | ||
let msg = b"foo bar baz"; | ||
let mut bufs = [ | ||
IoSlice::new(&[]), | ||
IoSlice::new(&msg[..9]), | ||
IoSlice::new(&msg[9..]), | ||
]; | ||
let mut h = VectoredWriteHarness::new(8); | ||
let bytes_written = h.write_all(IoBufs::new(&mut bufs)).await; | ||
assert_eq!(bytes_written, msg.len()); | ||
assert_eq!(h.flush().await, msg); | ||
} | ||
|
||
#[tokio::test] | ||
async fn write_vectored_large_slice_on_vectored() { | ||
let msg = b"foo bar baz"; | ||
let mut bufs = [ | ||
IoSlice::new(&[]), | ||
IoSlice::new(&msg[..9]), | ||
IoSlice::new(&msg[9..]), | ||
]; | ||
let mut h = VectoredWriteHarness::with_vectored_backend(8); | ||
let bytes_written = h.write_all(IoBufs::new(&mut bufs)).await; | ||
assert_eq!(bytes_written, msg.len()); | ||
assert_eq!(h.flush().await, msg); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
use std::io::IoSlice; | ||
use std::ops::Deref; | ||
use std::slice; | ||
|
||
pub struct IoBufs<'a, 'b>(&'b mut [IoSlice<'a>]); | ||
|
||
impl<'a, 'b> IoBufs<'a, 'b> { | ||
pub fn new(slices: &'b mut [IoSlice<'a>]) -> Self { | ||
IoBufs(slices) | ||
} | ||
|
||
pub fn is_empty(&self) -> bool { | ||
self.0.is_empty() | ||
} | ||
|
||
pub fn advance(mut self, n: usize) -> IoBufs<'a, 'b> { | ||
let mut to_remove = 0; | ||
let mut remaining_len = n; | ||
for slice in self.0.iter() { | ||
if remaining_len < slice.len() { | ||
break; | ||
} else { | ||
remaining_len -= slice.len(); | ||
to_remove += 1; | ||
} | ||
} | ||
self.0 = self.0.split_at_mut(to_remove).1; | ||
if let Some(slice) = self.0.first_mut() { | ||
let tail = &slice[remaining_len..]; | ||
// Safety: recasts slice to the original lifetime | ||
let tail = unsafe { slice::from_raw_parts(tail.as_ptr(), tail.len()) }; | ||
*slice = IoSlice::new(tail); | ||
} else if remaining_len != 0 { | ||
panic!("advance past the end of the slice vector"); | ||
} | ||
self | ||
} | ||
} | ||
|
||
impl<'a, 'b> Deref for IoBufs<'a, 'b> { | ||
type Target = [IoSlice<'a>]; | ||
fn deref(&self) -> &[IoSlice<'a>] { | ||
self.0 | ||
} | ||
} |