From 6d721245f2a3f12e0d280dae802cc8a7c0f3442f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20Kr=C3=B6ning?= Date: Mon, 9 Aug 2021 11:20:30 +0200 Subject: [PATCH] Make NIC not mut anymore This moves the Mutex outside NetworkState, avoiding the `static mut` and properly locking the whole network state on access. --- hermit-sys/src/lib.rs | 1 - hermit-sys/src/net/device.rs | 9 ++- hermit-sys/src/net/mod.rs | 115 +++++++++++++++++------------------ 3 files changed, 60 insertions(+), 65 deletions(-) diff --git a/hermit-sys/src/lib.rs b/hermit-sys/src/lib.rs index 8d9835f76..812fd4952 100644 --- a/hermit-sys/src/lib.rs +++ b/hermit-sys/src/lib.rs @@ -1,4 +1,3 @@ -#![allow(clippy::mut_mutex_lock)] #![allow(clippy::large_enum_variant)] #![allow(clippy::new_ret_no_self)] diff --git a/hermit-sys/src/net/device.rs b/hermit-sys/src/net/device.rs index 5ed92ca27..6edcbd613 100644 --- a/hermit-sys/src/net/device.rs +++ b/hermit-sys/src/net/device.rs @@ -7,7 +7,6 @@ use std::convert::TryInto; #[cfg(not(feature = "dhcpv4"))] use std::net::Ipv4Addr; use std::slice; -use std::sync::Mutex; #[cfg(feature = "dhcpv4")] use smoltcp::dhcp::Dhcpv4Client; @@ -94,13 +93,13 @@ impl NetworkInterface { .routes(routes) .finalize(); - NetworkState::Initialized(Mutex::new(Self { + NetworkState::Initialized(Self { iface, sockets, dhcp, prev_cidr, waker: WakerRegistration::new(), - })) + }) } #[cfg(not(feature = "dhcpv4"))] @@ -168,11 +167,11 @@ impl NetworkInterface { .routes(routes) .finalize(); - NetworkState::Initialized(Mutex::new(Self { + NetworkState::Initialized(Self { iface, sockets: SocketSet::new(vec![]), waker: WakerRegistration::new(), - })) + }) } } diff --git a/hermit-sys/src/net/mod.rs b/hermit-sys/src/net/mod.rs index 3bdce3d17..fa3965e17 100644 --- a/hermit-sys/src/net/mod.rs +++ b/hermit-sys/src/net/mod.rs @@ -7,6 +7,7 @@ use aarch64::regs::*; #[cfg(target_arch = "x86_64")] use std::arch::x86_64::_rdtsc; use std::convert::TryInto; +use std::ops::DerefMut; use std::str::FromStr; use std::sync::atomic::{AtomicU16, Ordering}; use std::sync::Mutex; @@ -35,10 +36,21 @@ use crate::net::waker::WakerRegistration; pub(crate) enum NetworkState { Missing, InitializationFailed, - Initialized(Mutex>), + Initialized(NetworkInterface), } -static mut NIC: NetworkState = NetworkState::Missing; +impl NetworkState { + fn as_nic_mut(&mut self) -> Result<&mut NetworkInterface, &'static str> { + match self { + NetworkState::Initialized(nic) => Ok(nic), + _ => Err("Network is not initialized!"), + } + } +} + +lazy_static! { + static ref NIC: Mutex = Mutex::new(NetworkState::Missing); +} extern "C" { fn sys_yield(); @@ -158,28 +170,24 @@ pub(crate) struct AsyncSocket(Handle); impl AsyncSocket { pub(crate) fn new() -> Self { - match unsafe { &mut NIC } { - NetworkState::Initialized(nic) => { - AsyncSocket(nic.lock().unwrap().create_handle().unwrap()) - } - _ => { - panic!("Network isn't initialized!"); - } - } + let handle = NIC + .lock() + .unwrap() + .as_nic_mut() + .unwrap() + .create_handle() + .unwrap(); + Self(handle) } fn with(&self, f: impl FnOnce(&mut TcpSocket) -> R) -> R { - let mut guard = match unsafe { &mut NIC } { - NetworkState::Initialized(nic) => nic.lock().unwrap(), - _ => { - panic!("Network isn't initialized!"); - } - }; + let mut guard = NIC.lock().unwrap(); + let nic = guard.as_nic_mut().unwrap(); let res = { - let mut s = guard.sockets.get::(self.0); + let mut s = nic.sockets.get::(self.0); f(&mut *s) }; - guard.wake(); + nic.wake(); res } @@ -232,17 +240,13 @@ impl AsyncSocket { }) .await?; - match unsafe { &mut NIC } { - NetworkState::Initialized(nic) => { - let mut guard = nic.lock().unwrap(); - let mut socket = guard.sockets.get::(self.0); - socket.set_keep_alive(Some(Duration::from_millis(DEFAULT_KEEP_ALIVE_INTERVAL))); - let endpoint = socket.remote_endpoint(); + let mut guard = NIC.lock().unwrap(); + let nic = guard.as_nic_mut().map_err(|_| Error::Illegal)?; + let mut socket = nic.sockets.get::(self.0); + socket.set_keep_alive(Some(Duration::from_millis(DEFAULT_KEEP_ALIVE_INTERVAL))); + let endpoint = socket.remote_endpoint(); - Ok((endpoint.addr, endpoint.port)) - } - _ => Err(Error::Illegal), - } + Ok((endpoint.addr, endpoint.port)) } pub(crate) async fn read(&self, buffer: &mut [u8]) -> Result { @@ -347,16 +351,13 @@ fn start_endpoint() -> u16 { } pub(crate) fn network_delay(timestamp: Instant) -> Option { - match unsafe { &mut NIC } { - NetworkState::Initialized(nic) => nic.lock().unwrap().poll_delay(timestamp), - _ => None, - } + NIC.lock().unwrap().as_nic_mut().ok()?.poll_delay(timestamp) } pub(crate) async fn network_run() { - future::poll_fn(|cx| match unsafe { &mut NIC } { + future::poll_fn(|cx| match NIC.lock().unwrap().deref_mut() { NetworkState::Initialized(nic) => { - nic.lock().unwrap().poll(cx, Instant::now()); + nic.poll(cx, Instant::now()); Poll::Pending } _ => Poll::Ready(()), @@ -366,14 +367,12 @@ pub(crate) async fn network_run() { extern "C" fn nic_thread(_: usize) { loop { - unsafe { - sys_netwait(); - } + unsafe { sys_netwait() }; trace!("Network thread checks the devices"); - if let NetworkState::Initialized(nic) = unsafe { &mut NIC } { - nic.lock().unwrap().poll_common(Instant::now()); + if let NetworkState::Initialized(nic) = NIC.lock().unwrap().deref_mut() { + nic.poll_common(Instant::now()); } } } @@ -382,25 +381,25 @@ pub(crate) fn network_init() -> Result<(), ()> { // initialize variable, which contains the next local endpoint LOCAL_ENDPOINT.store(start_endpoint(), Ordering::SeqCst); - unsafe { - NIC = NetworkInterface::::new(); + let mut guard = NIC.lock().unwrap(); - if let NetworkState::Initialized(nic) = &mut NIC { - nic.lock().unwrap().poll_common(Instant::now()); + *guard = NetworkInterface::::new(); - // create thread, which manages the network stack - // use a higher priority to reduce the network latency - let mut tid: Tid = 0; - let ret = sys_spawn(&mut tid, nic_thread, 0, 3, 0); - if ret >= 0 { - debug!("Spawn network thread with id {}", tid); - } - - spawn(network_run()).detach(); + if let NetworkState::Initialized(nic) = guard.deref_mut() { + nic.poll_common(Instant::now()); - // switch to network thread - sys_yield(); + // create thread, which manages the network stack + // use a higher priority to reduce the network latency + let mut tid: Tid = 0; + let ret = unsafe { sys_spawn(&mut tid, nic_thread, 0, 3, 0) }; + if ret >= 0 { + debug!("Spawn network thread with id {}", tid); } + + spawn(network_run()).detach(); + + // switch to network thread + unsafe { sys_yield() }; } Ok(()) @@ -498,11 +497,9 @@ pub fn sys_tcp_stream_get_tll(_handle: Handle) -> Result { #[no_mangle] pub fn sys_tcp_stream_peer_addr(handle: Handle) -> Result<(IpAddress, u16), ()> { - let mut guard = match unsafe { &mut NIC } { - NetworkState::Initialized(nic) => nic.lock().unwrap(), - _ => return Err(()), - }; - let mut socket = guard.sockets.get::(handle); + let mut guard = NIC.lock().unwrap(); + let nic = guard.as_nic_mut().map_err(drop)?; + let mut socket = nic.sockets.get::(handle); socket.set_keep_alive(Some(Duration::from_millis(DEFAULT_KEEP_ALIVE_INTERVAL))); let endpoint = socket.remote_endpoint();