Skip to content

Commit

Permalink
Allow overriding streams
Browse files Browse the repository at this point in the history
Signed-off-by: Ryan Levick <ryan.levick@fermyon.com>
  • Loading branch information
rylev committed Jan 22, 2024
1 parent 2ac3638 commit cc1e62f
Show file tree
Hide file tree
Showing 4 changed files with 171 additions and 152 deletions.
109 changes: 69 additions & 40 deletions crates/wasi/src/preview2/host/udp.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use crate::preview2::udp::UdpSocket;
use crate::preview2::udp::{IncomingDatagramStream, OutgoingDatagramStream, UdpSocket};
use crate::preview2::{
bindings::{
sockets::network::{ErrorCode, IpAddressFamily, IpSocketAddress, Network},
sockets::udp,
sockets::udp_create_socket,
},
udp::{IncomingDatagramStream, OutgoingDatagramStream, SendState, UdpState},
udp::UdpState,
Subscribe,
};
use crate::preview2::{Pollable, SocketError, SocketResult, WasiView};
Expand Down Expand Up @@ -106,7 +106,7 @@ impl<T: WasiView> udp::HostUdpSocket for T {

let has_active_streams = table
.iter_children(&this)?
.any(|c| c.is::<IncomingDatagramStream>() || c.is::<OutgoingDatagramStream>());
.any(|c| c.is::<IncomingStreamResource>() || c.is::<OutgoingStreamResource>());

if has_active_streams {
return Err(SocketError::trap(anyhow!("UDP streams not dropped yet")));
Expand Down Expand Up @@ -138,15 +138,9 @@ impl<T: WasiView> udp::HostUdpSocket for T {
socket.udp_state = UdpState::Connected;
}

let incoming_stream = IncomingDatagramStream {
inner: socket.inner.clone(),
is_connected: remote_address.is_some(),
};
let outgoing_stream = OutgoingDatagramStream {
inner: socket.inner.clone(),
is_connected: remote_address.is_some(),
send_state: SendState::Idle,
};
let (incoming_stream, outgoing_stream) = socket.inner.streams();
let incoming_stream = IncomingStreamResource::new(incoming_stream, remote_address);
let outgoing_stream = OutgoingStreamResource::new(outgoing_stream, remote_address);

Ok((
self.table_mut().push_child(incoming_stream, &this)?,
Expand Down Expand Up @@ -288,13 +282,14 @@ impl<T: WasiView> udp::HostIncomingDatagramStream for T {
) -> SocketResult<Vec<udp::IncomingDatagram>> {
// Returns Ok(None) when the message was dropped.
fn recv_one(
stream: &IncomingDatagramStream,
stream: &IncomingStreamResource,
) -> SocketResult<Option<udp::IncomingDatagram>> {
let mut buf = [0; MAX_UDP_DATAGRAM_SIZE];
let (size, received_addr) = stream.inner.receive_data(&mut buf)?;
let (size, received_addr) = stream.inner.recv(&mut buf)?;
debug_assert!(size <= buf.len());

if stream.is_connected && stream.inner.remote_address()? != received_addr {
if matches!(stream.remote_address, Some(remote_address) if remote_address != received_addr)
{
// Normally, this should have already been checked for us by the OS.
return Ok(None);
}
Expand Down Expand Up @@ -357,13 +352,41 @@ impl<T: WasiView> udp::HostIncomingDatagramStream for T {
}
}

pub struct OutgoingStreamResource {
pub(crate) inner: OutgoingDatagramStream,

/// If this has a value, the stream is "connected".
pub(crate) remote_address: Option<SocketAddr>,

pub(crate) send_state: SendState,
}

pub(crate) enum SendState {
/// Waiting for the API consumer to call `check-send`.
Idle,
/// Ready to send up to x datagrams.
Permitted(usize),
/// Waiting for the OS.
Waiting,
}

impl OutgoingStreamResource {
fn new(inner: OutgoingDatagramStream, remote_address: Option<SocketAddr>) -> Self {
Self {
inner,
remote_address,
send_state: SendState::Idle,
}
}
}

#[async_trait]
impl Subscribe for IncomingDatagramStream {
impl Subscribe for OutgoingStreamResource {
async fn ready(&mut self) {
self.inner
.await_readable()
.await
.expect("failed to await UDP socket readiness");
if let SendState::Waiting = self.send_state {
self.inner.ready().await;
self.send_state = SendState::Idle;
}
}
}

Expand Down Expand Up @@ -391,26 +414,26 @@ impl<T: WasiView> udp::HostOutgoingDatagramStream for T {
datagrams: Vec<udp::OutgoingDatagram>,
) -> SocketResult<u64> {
fn send_one(
stream: &OutgoingDatagramStream,
stream: &OutgoingStreamResource,
datagram: &udp::OutgoingDatagram,
) -> SocketResult<()> {
if datagram.data.len() > MAX_UDP_DATAGRAM_SIZE {
return Err(ErrorCode::DatagramTooLarge.into());
}

let provided_addr = datagram.remote_address.map(SocketAddr::from);
match (stream.is_connected, provided_addr) {
(false, Some(addr)) => {
stream.inner.send_data_to(addr, &datagram.data)?;
match (stream.remote_address, provided_addr) {
(None, Some(target)) => {
stream.inner.send(&datagram.data, target)?;
}
(true, None) => {
stream.inner.send_data(&datagram.data)?;
(Some(target), None) => {
stream.inner.send(&datagram.data, target)?;
}
(true, Some(provided_addr)) if stream.inner.remote_address()? == provided_addr => {
stream.inner.send_data(&datagram.data)?;
(Some(connected_addr), Some(provided_addr)) if connected_addr == provided_addr => {
stream.inner.send(&datagram.data, provided_addr)?;
}
_ => return Err(ErrorCode::InvalidArgument.into()),
};
}

Ok(())
}
Expand Down Expand Up @@ -479,18 +502,24 @@ impl<T: WasiView> udp::HostOutgoingDatagramStream for T {
}
}

pub struct IncomingStreamResource {
inner: IncomingDatagramStream,
/// If this has a value, the stream is "connected".
pub(crate) remote_address: Option<SocketAddr>,
}

impl IncomingStreamResource {
fn new(inner: IncomingDatagramStream, remote_address: Option<SocketAddr>) -> Self {
Self {
inner,
remote_address,
}
}
}

#[async_trait]
impl Subscribe for OutgoingDatagramStream {
impl Subscribe for IncomingStreamResource {
async fn ready(&mut self) {
match self.send_state {
SendState::Idle | SendState::Permitted(_) => {}
SendState::Waiting => {
self.inner
.await_writable()
.await
.expect("failed to await UDP socket readiness");
self.send_state = SendState::Idle;
}
}
self.inner.ready().await;
}
}
4 changes: 2 additions & 2 deletions crates/wasi/src/preview2/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,8 @@ pub mod bindings {
"wasi:sockets/network/network": super::network::NetworkResource,
"wasi:sockets/tcp/tcp-socket": super::host::tcp::TcpSocketResource,
"wasi:sockets/udp/udp-socket": super::host::udp::UdpSocketResource,
"wasi:sockets/udp/incoming-datagram-stream": super::udp::IncomingDatagramStream,
"wasi:sockets/udp/outgoing-datagram-stream": super::udp::OutgoingDatagramStream,
"wasi:sockets/udp/incoming-datagram-stream": super::host::udp::IncomingStreamResource,
"wasi:sockets/udp/outgoing-datagram-stream": super::host::udp::OutgoingStreamResource,
"wasi:sockets/ip-name-lookup/resolve-address-stream": super::host::ip_name_lookup::ResolveAddressStreamResource,
"wasi:filesystem/types/directory-entry-stream": super::filesystem::ReaddirIterator,
"wasi:filesystem/types/descriptor": super::filesystem::Descriptor,
Expand Down
4 changes: 0 additions & 4 deletions crates/wasi/src/preview2/network.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,10 +173,6 @@ impl SocketAddrCheck {
Self(Arc::new(|_, _| false))
}

pub fn allow() -> Self {
Self(Arc::new(|_, _| true))
}

pub fn check(&self, addr: &SocketAddr, reason: SocketAddrUse) -> std::io::Result<()> {
if (self.0)(addr, reason) {
Ok(())
Expand Down

0 comments on commit cc1e62f

Please sign in to comment.