make copy continue filling the buffer when writer stalls #5066

Merged (7 commits, Oct 3, 2022)
Changes from 4 commits
7 changes: 7 additions & 0 deletions benches/Cargo.toml
@@ -7,6 +7,8 @@ edition = "2018"
[dependencies]
tokio = { version = "1.5.0", path = "../tokio", features = ["full"] }
bencher = "0.1.5"
rand = "0.8"
rand_chacha = "0.3"

[dev-dependencies]
tokio-util = { version = "0.7.0", path = "../tokio-util", features = ["full"] }
@@ -50,3 +52,8 @@ harness = false
name = "fs"
path = "fs.rs"
harness = false

[[bench]]
name = "copy"
path = "copy.rs"
harness = false
239 changes: 239 additions & 0 deletions benches/copy.rs
@@ -0,0 +1,239 @@
use bencher::{benchmark_group, benchmark_main, Bencher};

use rand::{Rng, SeedableRng};
use rand_chacha::ChaCha20Rng;

use tokio::io::{copy, repeat, AsyncRead, AsyncReadExt, AsyncWrite};
use tokio::time::{interval, Interval, MissedTickBehavior};

use std::task::Poll;
use std::time::Duration;

const KILO: usize = 1024;

// Tunable parameters if you want to change this benchmark. If reader and writer
// are matched in kilobytes per second, then this only exposes buffering to the
// benchmark.
const RNG_SEED: u64 = 0;
// How much data to copy in a single benchmark run
const SOURCE_SIZE: u64 = 256 * KILO as u64;
// Read side provides CHUNK_SIZE every READ_SERVICE_PERIOD. If it's not called
// frequently, it'll burst to catch up (representing OS buffers draining)
const CHUNK_SIZE: usize = 2 * KILO;
const READ_SERVICE_PERIOD: Duration = Duration::from_millis(1);
// Write side buffers up to WRITE_BUFFER, and flushes to disk every
// WRITE_SERVICE_PERIOD.
const WRITE_BUFFER: usize = 40 * KILO;
const WRITE_SERVICE_PERIOD: Duration = Duration::from_millis(20);
// How likely you are to have to wait for previously written data to be flushed
// because another writer claimed the buffer space
const PROBABILITY_FLUSH_WAIT: f64 = 0.1;
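// With the defaults above the two sides are rate-matched: the reader offers
// 2 KiB every 1 ms and the writer drains up to 40 KiB every 20 ms, i.e.
// 2048 bytes per millisecond on each side, so the benchmark mainly exercises
// buffering behaviour rather than raw throughput.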

/// A slow writer that aims to simulate HDD behaviour under heavy load.
///
/// There is a limited buffer, which is fully drained on the next write after
/// a time limit is reached. Flush waits for the time limit to be reached
/// and then drains the buffer.
///
/// At random, the HDD will stall writers while it flushes out all buffers. If
/// this happens to you, you will be unable to write until the next time the
/// buffer is drained.
struct SlowHddWriter {
    service_intervals: Interval,
    blocking_rng: ChaCha20Rng,
    buffer_size: usize,
    buffer_used: usize,
}

impl SlowHddWriter {
    fn new(service_interval: Duration, buffer_size: usize) -> Self {
        let blocking_rng = ChaCha20Rng::seed_from_u64(RNG_SEED);
        let mut service_intervals = interval(service_interval);
        service_intervals.set_missed_tick_behavior(MissedTickBehavior::Delay);
        Self {
            service_intervals,
            blocking_rng,
            buffer_size,
            buffer_used: 0,
        }
    }

    fn service_write(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), std::io::Error>> {
        // If we hit a service interval, the buffer can be cleared
        let res = self.service_intervals.poll_tick(cx).map(|_| Ok(()));
        if let Poll::Ready(_) = res {
            self.buffer_used = 0;
        }
        res
    }

    fn write_bytes(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        writeable: usize,
    ) -> std::task::Poll<Result<usize, std::io::Error>> {
        let service_res = self.as_mut().service_write(cx);

        if service_res.is_pending() && self.blocking_rng.gen_bool(PROBABILITY_FLUSH_WAIT) {
            return Poll::Pending;
        }
        let available = self.buffer_size - self.buffer_used;

        if available == 0 {
            assert!(service_res.is_pending());
            Poll::Pending
        } else {
            let written = available.min(writeable);
            self.buffer_used += written;
            Poll::Ready(Ok(written))
        }
    }
}

impl Unpin for SlowHddWriter {}

impl AsyncWrite for SlowHddWriter {
    fn poll_write(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> std::task::Poll<Result<usize, std::io::Error>> {
        self.write_bytes(cx, buf.len())
    }

    fn poll_flush(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), std::io::Error>> {
        self.service_write(cx)
    }

    fn poll_shutdown(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), std::io::Error>> {
        self.service_write(cx)
    }

    fn poll_write_vectored(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        bufs: &[std::io::IoSlice<'_>],
    ) -> std::task::Poll<Result<usize, std::io::Error>> {
        let writeable = bufs.into_iter().fold(0, |acc, buf| acc + buf.len());
        self.write_bytes(cx, writeable)
    }

    fn is_write_vectored(&self) -> bool {
        true
    }
}

/// A reader that limits the maximum chunk it'll give you back
///
/// Simulates something reading from a slow link - you get one chunk per call,
/// and you are offered chunks on a schedule
struct ChunkReader {
    data: Vec<u8>,
    service_intervals: Interval,
}

impl ChunkReader {
    fn new(chunk_size: usize, service_interval: Duration) -> Self {
        let mut service_intervals = interval(service_interval);
        service_intervals.set_missed_tick_behavior(MissedTickBehavior::Burst);
        let data: Vec<u8> = std::iter::repeat(0).take(chunk_size).collect();
        Self {
            data,
            service_intervals,
        }
    }
}

impl AsyncRead for ChunkReader {
    fn poll_read(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        if self.service_intervals.poll_tick(cx).is_pending() {
            return Poll::Pending;
        }
        buf.put_slice(&self.data[..buf.remaining().min(self.data.len())]);
        Poll::Ready(Ok(()))
    }
}

fn rt() -> tokio::runtime::Runtime {
    tokio::runtime::Builder::new_multi_thread()
        .worker_threads(2)
        .enable_time()
        .build()
        .unwrap()
}
Contributor:
A current-thread runtime probably makes more sense for the benchmark. You don't do any spawning.

Contributor (Author):
I copied that from fs.rs, which also doesn't do any spawning - it uses block_in_place, though, which might be the reason.

Contributor:
The block_in_place method involves spawning.
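For reference, the reviewer's suggestion would look roughly like the sketch below (a hypothetical variant of rt(); whether the later commits switched the benchmark over is not visible in this four-commit view):

```rust
// Hypothetical variant of rt() following the review suggestion: a
// current-thread runtime with the timer enabled, since the benchmark
// only ever calls block_on and never spawns tasks.
fn rt() -> tokio::runtime::Runtime {
    tokio::runtime::Builder::new_current_thread()
        .enable_time()
        .build()
        .unwrap()
}
```

Since the benchmark only uses block_on, either runtime flavour drives the copy the same way.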


fn copy_mem_to_mem(b: &mut Bencher) {
    let rt = rt();

    b.iter(|| {
        let task = || async {
            let mut source = repeat(0).take(SOURCE_SIZE);
            let mut dest = Vec::new();
            copy(&mut source, &mut dest).await.unwrap();
        };

        rt.block_on(task());
    })
}

fn copy_mem_to_slow_hdd(b: &mut Bencher) {
    let rt = rt();

    b.iter(|| {
        let task = || async {
            let mut source = repeat(0).take(SOURCE_SIZE);
            let mut dest = SlowHddWriter::new(WRITE_SERVICE_PERIOD, WRITE_BUFFER);
            copy(&mut source, &mut dest).await.unwrap();
        };

        rt.block_on(task());
    })
}

fn copy_chunk_to_mem(b: &mut Bencher) {
    let rt = rt();
    b.iter(|| {
        let task = || async {
            let mut source = ChunkReader::new(CHUNK_SIZE, READ_SERVICE_PERIOD).take(SOURCE_SIZE);
            let mut dest = Vec::new();
            copy(&mut source, &mut dest).await.unwrap();
        };

        rt.block_on(task());
    })
}

fn copy_chunk_to_slow_hdd(b: &mut Bencher) {
    let rt = rt();
    b.iter(|| {
        let task = || async {
            let mut source = ChunkReader::new(CHUNK_SIZE, READ_SERVICE_PERIOD).take(SOURCE_SIZE);
            let mut dest = SlowHddWriter::new(WRITE_SERVICE_PERIOD, WRITE_BUFFER);
            copy(&mut source, &mut dest).await.unwrap();
        };

        rt.block_on(task());
    })
}

benchmark_group!(
    copy_bench,
    copy_mem_to_mem,
    copy_mem_to_slow_hdd,
    copy_chunk_to_mem,
    copy_chunk_to_slow_hdd,
);
benchmark_main!(copy_bench);
55 changes: 45 additions & 10 deletions tokio/src/io/util/copy.rs
@@ -27,6 +27,46 @@ impl CopyBuffer {
        }
    }

    fn poll_fill_buf<R>(
        &mut self,
        cx: &mut Context<'_>,
        reader: Pin<&mut R>,
    ) -> Poll<io::Result<()>>
    where
        R: AsyncRead + ?Sized,
    {
        let me = &mut *self;
        let mut buf = ReadBuf::new(&mut me.buf[me.cap..]);

        let res = reader.poll_read(cx, &mut buf);
        me.cap += buf.filled().len();
        res
    }

    fn poll_write_buf<R, W>(
        &mut self,
        cx: &mut Context<'_>,
        mut reader: Pin<&mut R>,
        mut writer: Pin<&mut W>,
    ) -> Poll<io::Result<usize>>
    where
        R: AsyncRead + ?Sized,
        W: AsyncWrite + ?Sized,
    {
        let me = &mut *self;
        match writer.as_mut().poll_write(cx, &me.buf[me.pos..me.cap]) {
            Poll::Pending => {
                // Top up the buffer towards full if we can read a bit more
                // data - this should improve the chances of a large write
                if !me.read_done && me.cap != me.buf.len() {
                    ready!(me.poll_fill_buf(cx, reader.as_mut()))?;
                }
                Poll::Pending
            }
            res => res,
        }
    }

    pub(super) fn poll_copy<R, W>(
        &mut self,
        cx: &mut Context<'_>,
@@ -41,10 +81,10 @@
            // If our buffer is empty, then we need to read some data to
            // continue.
            if self.pos == self.cap && !self.read_done {
-                let me = &mut *self;
-                let mut buf = ReadBuf::new(&mut me.buf);
+                self.pos = 0;
+                self.cap = 0;

-                match reader.as_mut().poll_read(cx, &mut buf) {
+                match self.poll_fill_buf(cx, reader.as_mut()) {
                    Poll::Ready(Ok(_)) => (),
                    Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
                    Poll::Pending => {
@@ -59,19 +99,14 @@
                    }
                }

-                let n = buf.filled().len();
-                if n == 0 {
+                if self.cap == 0 {
                    self.read_done = true;
-                } else {
-                    self.pos = 0;
-                    self.cap = n;
                }
            }

            // If our buffer has some data, let's write it out!
            while self.pos < self.cap {
-                let me = &mut *self;
-                let i = ready!(writer.as_mut().poll_write(cx, &me.buf[me.pos..me.cap]))?;
+                let i = ready!(self.poll_write_buf(cx, reader.as_mut(), writer.as_mut()))?;
                if i == 0 {
                    return Poll::Ready(Err(io::Error::new(
                        io::ErrorKind::WriteZero,
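For context, here is a minimal, hypothetical usage sketch (not part of this PR) of the public API the change sits behind. `tokio::io::copy` drives the `CopyBuffer` shown above, so after this change a stalled writer no longer stops the reader from topping up the internal buffer:

```rust
use tokio::io::{copy, repeat, sink, AsyncReadExt};

#[tokio::main]
async fn main() -> std::io::Result<()> {
    // 256 KiB source, mirroring SOURCE_SIZE in the new benchmark.
    let mut source = repeat(0).take(256 * 1024);
    // Any AsyncWrite works as the destination; the benefit of this PR shows
    // up when the writer returns Pending (like the benchmark's SlowHddWriter),
    // because copy() now keeps reading into its buffer while it waits.
    let mut dest = sink();
    let copied = copy(&mut source, &mut dest).await?;
    assert_eq!(copied, 256 * 1024);
    Ok(())
}
```

With a fast in-memory destination the behaviour is the same before and after the change; the new copy.rs benchmark exists precisely to exercise the slow-writer case where the extra buffering pays off.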