diff --git a/tokio/src/sync/watch.rs b/tokio/src/sync/watch.rs index 5673e0fca78..ab0ada7d098 100644 --- a/tokio/src/sync/watch.rs +++ b/tokio/src/sync/watch.rs @@ -60,6 +60,7 @@ use crate::loom::sync::atomic::Ordering::Relaxed; use crate::loom::sync::{Arc, RwLock, RwLockReadGuard}; use std::mem; use std::ops; +use std::panic; /// Receives values from the associated [`Sender`](struct@Sender). /// @@ -530,27 +531,47 @@ impl Sender { Ok(()) } - /// Sends a new value via the channel, notifying all receivers and returning - /// the previous value in the channel. + /// Modifies watched value, notifying all receivers. /// - /// This can be useful for reusing the buffers inside a watched value. - /// Additionally, this method permits sending values even when there are no - /// receivers. + /// This can useful for modifying the watched value, without + /// having to allocate a new instance. Additionally, this + /// method permits sending values even when there are no receivers. + /// + /// # Panics + /// + /// This function panics if calling `func` results in a panic. + /// No receivers are notified if panic occurred, but if the closure has modified + /// the value, that change is still visible to future calls to `borrow`. /// /// # Examples /// /// ``` /// use tokio::sync::watch; /// - /// let (tx, _rx) = watch::channel(1); - /// assert_eq!(tx.send_replace(2), 1); - /// assert_eq!(tx.send_replace(3), 2); + /// struct State { + /// counter: usize, + /// } + /// let (state_tx, state_rx) = watch::channel(State { counter: 0 }); + /// state_tx.send_modify(|state| state.counter += 1); + /// assert_eq!(state_rx.borrow().counter, 1); /// ``` - pub fn send_replace(&self, value: T) -> T { - let old = { + pub fn send_modify(&self, func: F) + where + F: FnOnce(&mut T), + { + { // Acquire the write lock and update the value. let mut lock = self.shared.value.write().unwrap(); - let old = mem::replace(&mut *lock, value); + // Update the value and catch possible panic inside func. + let result = panic::catch_unwind(panic::AssertUnwindSafe(|| { + func(&mut lock); + })); + // If the func panicked return the panic to the caller. + if let Err(error) = result { + // Drop the lock to avoid poisoning it. + drop(lock); + panic::resume_unwind(error); + } self.shared.state.increment_version(); @@ -560,14 +581,32 @@ impl Sender { // that receivers are able to figure out the version number of the // value they are currently looking at. drop(lock); + } - old - }; - - // Notify all watchers self.shared.notify_rx.notify_waiters(); + } + + /// Sends a new value via the channel, notifying all receivers and returning + /// the previous value in the channel. + /// + /// This can be useful for reusing the buffers inside a watched value. + /// Additionally, this method permits sending values even when there are no + /// receivers. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::watch; + /// + /// let (tx, _rx) = watch::channel(1); + /// assert_eq!(tx.send_replace(2), 1); + /// assert_eq!(tx.send_replace(3), 2); + /// ``` + pub fn send_replace(&self, mut value: T) -> T { + // swap old watched value with the new one + self.send_modify(|old| mem::swap(old, &mut value)); - old + value } /// Returns a reference to the most recently sent value diff --git a/tokio/tests/sync_watch.rs b/tokio/tests/sync_watch.rs index 8b9ea81bb89..2097b8bdfdb 100644 --- a/tokio/tests/sync_watch.rs +++ b/tokio/tests/sync_watch.rs @@ -211,3 +211,31 @@ fn reopened_after_subscribe() { drop(rx); assert!(tx.is_closed()); } + +#[test] +fn send_modify_panic() { + let (tx, mut rx) = watch::channel("one"); + + tx.send_modify(|old| *old = "two"); + assert_eq!(*rx.borrow_and_update(), "two"); + + let mut rx2 = rx.clone(); + assert_eq!(*rx2.borrow_and_update(), "two"); + + let mut task = spawn(rx2.changed()); + + let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { + tx.send_modify(|old| { + *old = "panicked"; + panic!(); + }) + })); + assert!(result.is_err()); + + assert_pending!(task.poll()); + assert_eq!(*rx.borrow(), "panicked"); + + tx.send_modify(|old| *old = "three"); + assert_ready_ok!(task.poll()); + assert_eq!(*rx.borrow_and_update(), "three"); +}