Skip to content

Commit

Permalink
Abstract out async runtime
Browse files Browse the repository at this point in the history
Co-authored-by: Yureka <yuka@yuka.dev>
  • Loading branch information
Ralith and yu-re-ka committed Jul 5, 2022
1 parent 9f948d5 commit ffa85f2
Show file tree
Hide file tree
Showing 12 changed files with 392 additions and 82 deletions.
3 changes: 2 additions & 1 deletion perf/src/bin/perf_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use std::{
use anyhow::{Context, Result};
use bytes::Bytes;
use clap::Parser;
use quinn::TokioRuntime;
use tokio::sync::Semaphore;
use tracing::{debug, error, info};

Expand Down Expand Up @@ -103,7 +104,7 @@ async fn run(opt: Opt) -> Result<()> {

let socket = bind_socket(bind_addr, opt.send_buffer_size, opt.recv_buffer_size)?;

let (endpoint, _) = quinn::Endpoint::new(Default::default(), None, socket)?;
let (endpoint, _) = quinn::Endpoint::new(Default::default(), None, socket, TokioRuntime)?;

let mut crypto = rustls::ClientConfig::builder()
.with_cipher_suites(perf::PERF_CIPHER_SUITES)
Expand Down
11 changes: 8 additions & 3 deletions perf/src/bin/perf_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::{fs, net::SocketAddr, path::PathBuf, sync::Arc, time::Duration};
use anyhow::{Context, Result};
use bytes::Bytes;
use clap::Parser;
use quinn::TokioRuntime;
use tracing::{debug, error, info};

use perf::bind_socket;
Expand Down Expand Up @@ -77,9 +78,13 @@ async fn run(opt: Opt) -> Result<()> {

let socket = bind_socket(opt.listen, opt.send_buffer_size, opt.recv_buffer_size)?;

let (endpoint, mut incoming) =
quinn::Endpoint::new(Default::default(), Some(server_config), socket)
.context("creating endpoint")?;
let (endpoint, mut incoming) = quinn::Endpoint::new(
Default::default(),
Some(server_config),
socket,
TokioRuntime,
)
.context("creating endpoint")?;

info!("listening on {}", endpoint.local_addr().unwrap());

Expand Down
8 changes: 6 additions & 2 deletions quinn/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,24 @@ rust-version = "1.57"
all-features = true

[features]
default = ["native-certs", "tls-rustls"]
default = ["native-certs", "tls-rustls", "runtime-tokio"]
# Records how long locks are held, and warns if they are held >= 1ms
lock_tracking = []
# Provides `ClientConfig::with_native_roots()` convenience method
native-certs = ["proto/native-certs"]
tls-rustls = ["rustls", "webpki", "proto/tls-rustls", "ring"]
# Enables `Endpoint::client` and `Endpoint::server` conveniences
ring = ["proto/ring"]
runtime-tokio = ["tokio/time", "tokio/rt", "tokio/net"]
runtime-async-std = ["async-io", "async-std"]

[badges]
codecov = { repository = "djc/quinn" }
maintenance = { status = "experimental" }

[dependencies]
async-io = { version = "1.6", optional = true }
async-std = { version = "1.11", optional = true }
bytes = "1"
# Enables futures::io::{AsyncRead, AsyncWrite} support for streams
futures-io = { version = "0.3.19", optional = true }
Expand All @@ -39,7 +43,7 @@ proto = { package = "quinn-proto", path = "../quinn-proto", version = "0.8", def
rustls = { version = "0.20.3", default-features = false, features = ["quic"], optional = true }
thiserror = "1.0.21"
tracing = "0.1.10"
tokio = { version = "1.0.1", features = ["rt", "time", "sync", "net"] }
tokio = { version = "1.0.1", features = ["sync"] }
udp = { package = "quinn-udp", path = "../quinn-udp", version = "0.2" }
webpki = { version = "0.22", default-features = false, optional = true }

Expand Down
4 changes: 2 additions & 2 deletions quinn/benches/bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use tokio::runtime::{Builder, Runtime};
use tracing::error_span;
use tracing_futures::Instrument as _;

use quinn::Endpoint;
use quinn::{Endpoint, TokioRuntime};

benchmark_group!(
benches,
Expand Down Expand Up @@ -102,7 +102,7 @@ impl Context {
let runtime = rt();
let (_, mut incoming) = {
let _guard = runtime.enter();
Endpoint::new(Default::default(), Some(config), sock).unwrap()
Endpoint::new(Default::default(), Some(config), sock, TokioRuntime).unwrap()
};
let handle = runtime.spawn(
async move {
Expand Down
18 changes: 12 additions & 6 deletions quinn/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@ use std::{
time::{Duration, Instant},
};

use crate::runtime::{AsyncTimer, Runtime};
use bytes::Bytes;
use proto::{ConnectionError, ConnectionHandle, ConnectionStats, Dir, StreamEvent, StreamId};
use rustc_hash::FxHashMap;
use thiserror::Error;
use tokio::sync::{mpsc, oneshot, Notify};
use tokio::time::{sleep_until, Instant as TokioInstant, Sleep};
use tracing::debug_span;
use udp::UdpState;

Expand Down Expand Up @@ -44,6 +44,7 @@ impl Connecting {
endpoint_events: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>,
conn_events: mpsc::UnboundedReceiver<ConnectionEvent>,
udp_state: Arc<UdpState>,
runtime: Arc<Box<dyn Runtime>>,
) -> Connecting {
let (on_handshake_data_send, on_handshake_data_recv) = oneshot::channel();
let (on_connected_send, on_connected_recv) = oneshot::channel();
Expand All @@ -55,9 +56,10 @@ impl Connecting {
on_handshake_data_send,
on_connected_send,
udp_state,
runtime.clone(),
);

tokio::spawn(ConnectionDriver(conn.clone()));
runtime.spawn(Box::pin(ConnectionDriver(conn.clone())));

Connecting {
conn: Some(conn),
Expand Down Expand Up @@ -692,6 +694,7 @@ impl futures_core::Stream for Datagrams {
pub struct ConnectionRef(Arc<Mutex<ConnectionInner>>);

impl ConnectionRef {
#[allow(clippy::too_many_arguments)]
fn new(
handle: ConnectionHandle,
conn: proto::Connection,
Expand All @@ -700,6 +703,7 @@ impl ConnectionRef {
on_handshake_data: oneshot::Sender<()>,
on_connected: oneshot::Sender<bool>,
udp_state: Arc<UdpState>,
runtime: Arc<Box<dyn Runtime>>,
) -> Self {
Self(Arc::new(Mutex::new(ConnectionInner {
inner: conn,
Expand All @@ -723,6 +727,7 @@ impl ConnectionRef {
error: None,
ref_count: 0,
udp_state,
runtime,
})))
}

Expand Down Expand Up @@ -768,8 +773,8 @@ pub struct ConnectionInner {
on_handshake_data: Option<oneshot::Sender<()>>,
on_connected: Option<oneshot::Sender<bool>>,
connected: bool,
timer: Option<Pin<Box<Sleep>>>,
timer_deadline: Option<TokioInstant>,
timer: Option<Pin<Box<dyn AsyncTimer>>>,
timer_deadline: Option<Instant>,
conn_events: mpsc::UnboundedReceiver<ConnectionEvent>,
endpoint_events: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>,
pub(crate) blocked_writers: FxHashMap<StreamId, Waker>,
Expand All @@ -785,6 +790,7 @@ pub struct ConnectionInner {
/// Number of live handles that can be used to initiate or handle I/O; excludes the driver
ref_count: usize,
udp_state: Arc<UdpState>,
runtime: Arc<Box<dyn Runtime>>,
}

impl ConnectionInner {
Expand Down Expand Up @@ -924,7 +930,7 @@ impl ConnectionInner {
// Check whether we need to (re)set the timer. If so, we must poll again to ensure the
// timer is registered with the runtime (and check whether it's already
// expired).
match self.inner.poll_timeout().map(TokioInstant::from_std) {
match self.inner.poll_timeout() {
Some(deadline) => {
if let Some(delay) = &mut self.timer {
// There is no need to reset the tokio timer if the deadline
Expand All @@ -937,7 +943,7 @@ impl ConnectionInner {
delay.as_mut().reset(deadline);
}
} else {
self.timer = Some(Box::pin(sleep_until(deadline)));
self.timer = Some(self.runtime.new_timer(deadline));
}
// Store the actual expiration time of the timer
self.timer_deadline = Some(deadline);
Expand Down

0 comments on commit ffa85f2

Please sign in to comment.