Skip to content

Commit

Permalink
fix: Use 1024 buffer size for ARD VNC session
Browse files Browse the repository at this point in the history
_ARD_ uses _MVS_ video codec which doesn't like buffering, and we need to have the buffer as minimal as possible.

Also, this commit adds new `copy_bidirectional` transport that is forked [the one from tokio](https://docs.rs/tokio/latest/tokio/io/fn.copy_bidirectional.html).
It's forked because the original function doesn't allow overriding the buffer size (8K is used by default). There is [an issue](tokio-rs/tokio#6454) on tokio side for it. We will be able to replace our fork with the upstream easily when it's ready.

Releated to Devolutions/IronVNC#338.
  • Loading branch information
RRRadicalEdward committed Apr 15, 2024
1 parent 3e7aa30 commit 206fecd
Show file tree
Hide file tree
Showing 5 changed files with 296 additions and 5 deletions.
123 changes: 123 additions & 0 deletions crates/transport/src/copy_bidirectional.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
//! Fork of https://github.com/tokio-rs/tokio/blob/master/tokio/src/io/util/copy.rs to allow us set
//! variable length `CopyBuffer` size instead of default 8k.
//! See <https://github.com/tokio-rs/tokio/issues/6454>.

use super::copy_buffer::CopyBuffer;
use futures_core::ready;
use tokio::io::{AsyncRead, AsyncWrite};

use std::future::Future;
use std::io::{self};
use std::pin::Pin;
use std::task::{Context, Poll};

enum TransferState {
Running(CopyBuffer),
ShuttingDown(u64),
Done(u64),
}

struct CopyBidirectional<'a, A: ?Sized, B: ?Sized> {
a: &'a mut A,
b: &'a mut B,
a_to_b: TransferState,
b_to_a: TransferState,
}

fn transfer_one_direction<A, B>(
cx: &mut Context<'_>,
state: &mut TransferState,
r: &mut A,
w: &mut B,
) -> Poll<io::Result<u64>>
where
A: AsyncRead + AsyncWrite + Unpin + ?Sized,
B: AsyncRead + AsyncWrite + Unpin + ?Sized,
{
let mut r = Pin::new(r);
let mut w = Pin::new(w);

loop {
match state {
TransferState::Running(buf) => {
let count = ready!(buf.poll_copy(cx, r.as_mut(), w.as_mut()))?;
*state = TransferState::ShuttingDown(count);
}
TransferState::ShuttingDown(count) => {
ready!(w.as_mut().poll_shutdown(cx))?;

*state = TransferState::Done(*count);
}
TransferState::Done(count) => return Poll::Ready(Ok(*count)),
}
}
}

impl<'a, A, B> Future for CopyBidirectional<'a, A, B>
where
A: AsyncRead + AsyncWrite + Unpin + ?Sized,
B: AsyncRead + AsyncWrite + Unpin + ?Sized,
{
type Output = io::Result<(u64, u64)>;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
// Unpack self into mut refs to each field to avoid borrow check issues.
let CopyBidirectional {
a,
b,
a_to_b,
b_to_a,
} = &mut *self;

let a_to_b = transfer_one_direction(cx, a_to_b, &mut *a, &mut *b)?;
let b_to_a = transfer_one_direction(cx, b_to_a, &mut *b, &mut *a)?;

// It is not a problem if ready! returns early because transfer_one_direction for the
// other direction will keep returning TransferState::Done(count) in future calls to poll
let a_to_b = ready!(a_to_b);
let b_to_a = ready!(b_to_a);

Poll::Ready(Ok((a_to_b, b_to_a)))
}
}

/// Copies data in both directions between `a` and `b`.
///
/// This function returns a future that will read from both streams,
/// writing any data read to the opposing stream.
/// This happens in both directions concurrently.
///
/// If an EOF is observed on one stream, [`shutdown()`] will be invoked on
/// the other, and reading from that stream will stop. Copying of data in
/// the other direction will continue.
///
/// The future will complete successfully once both directions of communication has been shut down.
/// A direction is shut down when the reader reports EOF,
/// at which point [`shutdown()`] is called on the corresponding writer. When finished,
/// it will return a tuple of the number of bytes copied from a to b
/// and the number of bytes copied from b to a, in that order.
///
/// [`shutdown()`]: crate::io::AsyncWriteExt::shutdown
///
/// # Errors
///
/// The future will immediately return an error if any IO operation on `a`
/// or `b` returns an error. Some data read from either stream may be lost (not
/// written to the other stream) in this case.
///
/// # Return value
///
/// Returns a tuple of bytes copied `a` to `b` and bytes copied `b` to `a`.
pub async fn copy_bidirectional<A, B>(a: &mut A, b: &mut B, send_buffer_size: usize, recv_buffer_size: usize) -> Result<(u64, u64), std::io::Error>
where
A: AsyncRead + AsyncWrite + Unpin + ?Sized,
B: AsyncRead + AsyncWrite + Unpin + ?Sized,
{
CopyBidirectional {
a,
b,
a_to_b: TransferState::Running(CopyBuffer::new(send_buffer_size)),
b_to_a: TransferState::Running(CopyBuffer::new(recv_buffer_size)),
}
.await
}
143 changes: 143 additions & 0 deletions crates/transport/src/copy_buffer.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
//! Fork of https://github.com/tokio-rs/tokio/blob/master/tokio/src/io/util/copy.rs to allow us set
//! variable length `CopyBuffer` size instead of default 8k.
//! See <https://github.com/tokio-rs/tokio/issues/6454>.
use futures_core::ready;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

use std::io::{self};
use std::pin::Pin;
use std::task::{Context, Poll};


#[derive(Debug)]
pub(super) struct CopyBuffer {
read_done: bool,
need_flush: bool,
pos: usize,
cap: usize,
amt: u64,
buf: Box<[u8]>,
}

impl CopyBuffer {
pub(super) fn new(buffer_size: usize) -> Self { // <- This is our change
Self {
read_done: false,
need_flush: false,
pos: 0,
cap: 0,
amt: 0,
buf: vec![0; buffer_size].into_boxed_slice(),
}
}

fn poll_fill_buf<R>(
&mut self,
cx: &mut Context<'_>,
reader: Pin<&mut R>,
) -> Poll<io::Result<()>>
where
R: AsyncRead + ?Sized,
{
let me = &mut *self;
let mut buf = ReadBuf::new(&mut me.buf);
buf.set_filled(me.cap);

let res = reader.poll_read(cx, &mut buf);
if let Poll::Ready(Ok(_)) = res {
let filled_len = buf.filled().len();
me.read_done = me.cap == filled_len;
me.cap = filled_len;
}
res
}

fn poll_write_buf<R, W>(
&mut self,
cx: &mut Context<'_>,
mut reader: Pin<&mut R>,
mut writer: Pin<&mut W>,
) -> Poll<io::Result<usize>>
where
R: AsyncRead + ?Sized,
W: AsyncWrite + ?Sized,
{
let me = &mut *self;
match writer.as_mut().poll_write(cx, &me.buf[me.pos..me.cap]) {
Poll::Pending => {
// Top up the buffer towards full if we can read a bit more
// data - this should improve the chances of a large write
if !me.read_done && me.cap < me.buf.len() {
ready!(me.poll_fill_buf(cx, reader.as_mut()))?;
}
Poll::Pending
}
res => res,
}
}

pub(super) fn poll_copy<R, W>(
&mut self,
cx: &mut Context<'_>,
mut reader: Pin<&mut R>,
mut writer: Pin<&mut W>,
) -> Poll<io::Result<u64>>
where
R: AsyncRead + ?Sized,
W: AsyncWrite + ?Sized,
{
loop {
// If our buffer is empty, then we need to read some data to
// continue.
if self.pos == self.cap && !self.read_done {
self.pos = 0;
self.cap = 0;

match self.poll_fill_buf(cx, reader.as_mut()) {
Poll::Ready(Ok(_)) => (),
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
Poll::Pending => {
// Try flushing when the reader has no progress to avoid deadlock
// when the reader depends on buffered writer.
if self.need_flush {
ready!(writer.as_mut().poll_flush(cx))?;
self.need_flush = false;
}

return Poll::Pending;
}
}
}

// If our buffer has some data, let's write it out!
while self.pos < self.cap {
let i = ready!(self.poll_write_buf(cx, reader.as_mut(), writer.as_mut()))?;
if i == 0 {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::WriteZero,
"write zero byte into writer",
)));
} else {
self.pos += i;
self.amt += i as u64;
self.need_flush = true;
}
}

// If pos larger than cap, this loop will never stop.
// In particular, user's wrong poll_write implementation returning
// incorrect written length may lead to thread blocking.
debug_assert!(
self.pos <= self.cap,
"writer returned length larger than input slice"
);

// If we've written all the data and we've seen EOF, flush out the
// data and finish the transfer.
if self.pos == self.cap && self.read_done {
ready!(writer.as_mut().poll_flush(cx))?;
return Poll::Ready(Ok(self.amt));
}
}
}
}
3 changes: 3 additions & 0 deletions crates/transport/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
mod forward;
mod ws;
mod copy_bidirectional;
mod copy_buffer;

pub use copy_bidirectional::*;
pub use self::forward::*;
pub use self::ws::*;

Expand Down
11 changes: 10 additions & 1 deletion devolutions-gateway/src/api/fwd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use crate::http::HttpError;
use crate::proxy::Proxy;
use crate::session::{ConnectionModeDetails, SessionInfo, SessionMessageSender};
use crate::subscriber::SubscriberSender;
use crate::token::{AssociationTokenClaims, ConnectionMode};
use crate::token::{ApplicationProtocol, AssociationTokenClaims, ConnectionMode, Protocol};
use crate::{utils, DgwState};

pub fn make_router<S>(state: DgwState) -> Router<S> {
Expand Down Expand Up @@ -162,6 +162,13 @@ where
trace!(%selected_target, "Connected");
span.record("target", selected_target.to_string());

// ARD uses MVS codec which doesn't like buffering.
let buffer_size = if claims.jet_ap == ApplicationProtocol::Known(Protocol::Ard) {
Some(1024)
} else {
None
};

if with_tls {
trace!("Establishing TLS connection with server");

Expand Down Expand Up @@ -193,6 +200,7 @@ where
.transport_b(server_stream)
.sessions(sessions)
.subscriber_tx(subscriber_tx)
.buffer_size(buffer_size)
.build()
.select_dissector_and_forward()
.await
Expand Down Expand Up @@ -220,6 +228,7 @@ where
.transport_b(server_stream)
.sessions(sessions)
.subscriber_tx(subscriber_tx)
.buffer_size(buffer_size)
.build()
.select_dissector_and_forward()
.await
Expand Down
21 changes: 17 additions & 4 deletions devolutions-gateway/src/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ pub struct Proxy<A, B> {
address_b: SocketAddr,
sessions: SessionMessageSender,
subscriber_tx: SubscriberSender,
#[builder(default = None)]
buffer_size: Option<usize>,
}

impl<A, B> Proxy<A, B>
Expand Down Expand Up @@ -95,6 +97,7 @@ where
address_b: self.address_b,
sessions: self.sessions,
subscriber_tx: self.subscriber_tx,
buffer_size: self.buffer_size,
}
.forward()
.await
Expand All @@ -121,12 +124,22 @@ where
// NOTE(DGW-86): when recording is required, should we wait for it to start before we forward, or simply spawn
// a timer to check if the recording is started within a few seconds?

let forward_fut = tokio::io::copy_bidirectional(&mut transport_a, &mut transport_b);
let kill_notified = notify_kill.notified();

let res = match futures::future::select(pin!(forward_fut), pin!(kill_notified)).await {
Either::Left((res, _)) => res.map(|_| ()),
Either::Right(_) => Ok(()),
let res = if let Some(buffer_size) = self.buffer_size {
// Use our for of copy_bidirectional because tokio doesn't have an API to set the buffer size.
// See https://github.com/tokio-rs/tokio/issues/6454.
let forward_fut = transport::copy_bidirectional(&mut transport_a, &mut transport_b, buffer_size, buffer_size);
match futures::future::select(pin!(forward_fut), pin!(kill_notified)).await {
Either::Left((res, _)) => res.map(|_| ()),
Either::Right(_) => Ok(()),
}
} else {
let forward_fut = tokio::io::copy_bidirectional(&mut transport_a, &mut transport_b);
match futures::future::select(pin!(forward_fut), pin!(kill_notified)).await {
Either::Left((res, _)) => res.map(|_| ()),
Either::Right(_) => Ok(()),
}
};

// Ensure we close the transports cleanly at the end (ignore errors at this point)
Expand Down

0 comments on commit 206fecd

Please sign in to comment.