diff --git a/tokio-util/Cargo.toml b/tokio-util/Cargo.toml index 1221a5b6474..6406af635ae 100644 --- a/tokio-util/Cargo.toml +++ b/tokio-util/Cargo.toml @@ -55,6 +55,7 @@ tokio-stream = { version = "0.1", path = "../tokio-stream" } async-stream = "0.3.0" futures = "0.3.0" futures-test = "0.3.5" +parking_lot = "0.12.0" [package.metadata.docs.rs] all-features = true diff --git a/tokio-util/src/io/sync_bridge.rs b/tokio-util/src/io/sync_bridge.rs index 9be9446a7de..5587175b947 100644 --- a/tokio-util/src/io/sync_bridge.rs +++ b/tokio-util/src/io/sync_bridge.rs @@ -85,9 +85,10 @@ impl SyncIoBridge { /// /// Use e.g. `SyncIoBridge::new(Box::pin(src))`. /// - /// # Panic + /// # Panics /// /// This will panic if called outside the context of a Tokio runtime. + #[track_caller] pub fn new(src: T) -> Self { Self::new_with_handle(src, tokio::runtime::Handle::current()) } diff --git a/tokio-util/src/sync/mpsc.rs b/tokio-util/src/sync/mpsc.rs index 34a47c18911..55ed5c4dee2 100644 --- a/tokio-util/src/sync/mpsc.rs +++ b/tokio-util/src/sync/mpsc.rs @@ -136,6 +136,7 @@ impl PollSender { /// /// If `poll_reserve` was not successfully called prior to calling `send_item`, then this method /// will panic. + #[track_caller] pub fn send_item(&mut self, value: T) -> Result<(), PollSendError> { let (result, next_state) = match self.take_state() { State::Idle(_) | State::Acquiring => { diff --git a/tokio-util/src/task/spawn_pinned.rs b/tokio-util/src/task/spawn_pinned.rs index 567f9db0d8c..3811e6a4881 100644 --- a/tokio-util/src/task/spawn_pinned.rs +++ b/tokio-util/src/task/spawn_pinned.rs @@ -57,7 +57,9 @@ impl LocalPoolHandle { /// pool via [`LocalPoolHandle::spawn_pinned`]. /// /// # Panics + /// /// Panics if the pool size is less than one. + #[track_caller] pub fn new(pool_size: usize) -> LocalPoolHandle { assert!(pool_size > 0); @@ -167,6 +169,7 @@ impl LocalPoolHandle { /// } /// ``` /// + #[track_caller] pub fn spawn_pinned_by_idx(&self, create_task: F, idx: usize) -> JoinHandle where F: FnOnce() -> Fut, @@ -196,6 +199,7 @@ struct LocalPool { impl LocalPool { /// Spawn a `?Send` future onto a worker + #[track_caller] fn spawn_pinned( &self, create_task: F, @@ -324,6 +328,7 @@ impl LocalPool { } } + #[track_caller] fn find_worker_by_idx(&self, idx: usize) -> (&LocalWorkerHandle, JobCountGuard) { let worker = &self.workers[idx]; worker.task_count.fetch_add(1, Ordering::SeqCst); diff --git a/tokio-util/src/time/delay_queue.rs b/tokio-util/src/time/delay_queue.rs index a0c5e5c5b06..07082b97793 100644 --- a/tokio-util/src/time/delay_queue.rs +++ b/tokio-util/src/time/delay_queue.rs @@ -531,6 +531,7 @@ impl DelayQueue { /// [`reset`]: method@Self::reset /// [`Key`]: struct@Key /// [type]: # + #[track_caller] pub fn insert_at(&mut self, value: T, when: Instant) -> Key { assert!(self.slab.len() < MAX_ENTRIES, "max entries exceeded"); @@ -649,10 +650,12 @@ impl DelayQueue { /// [`reset`]: method@Self::reset /// [`Key`]: struct@Key /// [type]: # + #[track_caller] pub fn insert(&mut self, value: T, timeout: Duration) -> Key { self.insert_at(value, Instant::now() + timeout) } + #[track_caller] fn insert_idx(&mut self, when: u64, key: Key) { use self::wheel::{InsertError, Stack}; @@ -674,6 +677,7 @@ impl DelayQueue { /// # Panics /// /// Panics if the key is not contained in the expired queue or the wheel. + #[track_caller] fn remove_key(&mut self, key: &Key) { use crate::time::wheel::Stack; @@ -713,6 +717,7 @@ impl DelayQueue { /// assert_eq!(*item.get_ref(), "foo"); /// # } /// ``` + #[track_caller] pub fn remove(&mut self, key: &Key) -> Expired { let prev_deadline = self.next_deadline(); @@ -769,6 +774,7 @@ impl DelayQueue { /// // "foo" is now scheduled to be returned in 10 seconds /// # } /// ``` + #[track_caller] pub fn reset_at(&mut self, key: &Key, when: Instant) { self.remove_key(key); @@ -873,6 +879,7 @@ impl DelayQueue { /// // "foo"is now scheduled to be returned in 10 seconds /// # } /// ``` + #[track_caller] pub fn reset(&mut self, key: &Key, timeout: Duration) { self.reset_at(key, Instant::now() + timeout); } @@ -978,7 +985,12 @@ impl DelayQueue { /// assert!(delay_queue.capacity() >= 11); /// # } /// ``` + #[track_caller] pub fn reserve(&mut self, additional: usize) { + assert!( + self.slab.capacity() + additional <= MAX_ENTRIES, + "max queue capacity exceeded" + ); self.slab.reserve(additional); } @@ -1117,6 +1129,7 @@ impl wheel::Stack for Stack { } } + #[track_caller] fn remove(&mut self, item: &Self::Borrowed, store: &mut Self::Store) { let key = *item; assert!(store.contains(item)); diff --git a/tokio-util/src/time/wheel/mod.rs b/tokio-util/src/time/wheel/mod.rs index 4191e401df4..ffa05ab71bf 100644 --- a/tokio-util/src/time/wheel/mod.rs +++ b/tokio-util/src/time/wheel/mod.rs @@ -118,6 +118,7 @@ where } /// Remove `item` from the timing wheel. + #[track_caller] pub(crate) fn remove(&mut self, item: &T::Borrowed, store: &mut T::Store) { let when = T::when(item, store); diff --git a/tokio-util/tests/panic.rs b/tokio-util/tests/panic.rs new file mode 100644 index 00000000000..7a85e47c3bb --- /dev/null +++ b/tokio-util/tests/panic.rs @@ -0,0 +1,232 @@ +#![warn(rust_2018_idioms)] +#![cfg(feature = "full")] + +use parking_lot::{const_mutex, Mutex}; +use std::error::Error; +use std::panic; +use std::sync::Arc; +use tokio::runtime::Runtime; +use tokio::sync::mpsc::channel; +use tokio::time::{Duration, Instant}; +use tokio_test::task; +use tokio_util::io::SyncIoBridge; +use tokio_util::sync::PollSender; +use tokio_util::task::LocalPoolHandle; +use tokio_util::time::DelayQueue; + +fn test_panic(func: Func) -> Option { + static PANIC_MUTEX: Mutex<()> = const_mutex(()); + + { + let _guard = PANIC_MUTEX.lock(); + let panic_file: Arc>> = Arc::new(Mutex::new(None)); + + let prev_hook = panic::take_hook(); + { + let panic_file = panic_file.clone(); + panic::set_hook(Box::new(move |panic_info| { + let panic_location = panic_info.location().unwrap(); + panic_file + .lock() + .clone_from(&Some(panic_location.file().to_string())); + })); + } + + let result = panic::catch_unwind(func); + // Return to the previously set panic hook (maybe default) so that we get nice error + // messages in the tests. + panic::set_hook(prev_hook); + + if result.is_err() { + panic_file.lock().clone() + } else { + None + } + } +} + +#[test] +fn sync_bridge_new_panic_caller() -> Result<(), Box> { + let panic_location_file = test_panic(|| { + let _ = SyncIoBridge::new(tokio::io::empty()); + }); + + // The panic location should be in this file + assert_eq!(&panic_location_file.unwrap(), file!()); + + Ok(()) +} + +#[test] +fn poll_sender_send_item_panic_caller() -> Result<(), Box> { + let panic_location_file = test_panic(|| { + let (send, _) = channel::(3); + let mut send = PollSender::new(send); + + let _ = send.send_item(42); + }); + + // The panic location should be in this file + assert_eq!(&panic_location_file.unwrap(), file!()); + + Ok(()) +} + +#[test] + +fn local_pool_handle_new_panic_caller() -> Result<(), Box> { + let panic_location_file = test_panic(|| { + let _ = LocalPoolHandle::new(0); + }); + + // The panic location should be in this file + assert_eq!(&panic_location_file.unwrap(), file!()); + + Ok(()) +} + +#[test] + +fn local_pool_handle_spawn_pinned_by_idx_panic_caller() -> Result<(), Box> { + let panic_location_file = test_panic(|| { + let rt = basic(); + + rt.block_on(async { + let handle = LocalPoolHandle::new(2); + handle.spawn_pinned_by_idx(|| async { "test" }, 3); + }); + }); + + // The panic location should be in this file + assert_eq!(&panic_location_file.unwrap(), file!()); + + Ok(()) +} +#[test] +fn delay_queue_insert_at_panic_caller() -> Result<(), Box> { + let panic_location_file = test_panic(|| { + let rt = basic(); + rt.block_on(async { + let mut queue = task::spawn(DelayQueue::with_capacity(3)); + + let _k = queue.insert_at( + "1", + // ~24,855 days in the future + Instant::now() + Duration::from_secs(2_u64.pow(31)), + ); + }); + }); + + // The panic location should be in this file + assert_eq!(&panic_location_file.unwrap(), file!()); + + Ok(()) +} + +#[test] +fn delay_queue_insert_panic_caller() -> Result<(), Box> { + let panic_location_file = test_panic(|| { + let rt = basic(); + rt.block_on(async { + let mut queue = task::spawn(DelayQueue::with_capacity(3)); + + let _k = queue.insert( + "1", + // ~24,855 days + Duration::from_secs(2_u64.pow(31)), + ); + }); + }); + + // The panic location should be in this file + assert_eq!(&panic_location_file.unwrap(), file!()); + + Ok(()) +} + +#[test] +fn delay_queue_remove_panic_caller() -> Result<(), Box> { + let panic_location_file = test_panic(|| { + let rt = basic(); + rt.block_on(async { + let mut queue = task::spawn(DelayQueue::with_capacity(3)); + + let key = queue.insert_at("1", Instant::now()); + queue.remove(&key); + queue.remove(&key); + }); + }); + + // The panic location should be in this file + assert_eq!(&panic_location_file.unwrap(), file!()); + + Ok(()) +} + +#[test] +fn delay_queue_reset_at_panic_caller() -> Result<(), Box> { + let panic_location_file = test_panic(|| { + let rt = basic(); + rt.block_on(async { + let mut queue = task::spawn(DelayQueue::with_capacity(3)); + + let key = queue.insert_at("1", Instant::now()); + queue.reset_at( + &key, + // ~24,855 days in the future + Instant::now() + Duration::from_secs(2_u64.pow(31)), + ); + }); + }); + + // The panic location should be in this file + assert_eq!(&panic_location_file.unwrap(), file!()); + + Ok(()) +} + +#[test] +fn delay_queue_reset_panic_caller() -> Result<(), Box> { + let panic_location_file = test_panic(|| { + let rt = basic(); + rt.block_on(async { + let mut queue = task::spawn(DelayQueue::with_capacity(3)); + + let key = queue.insert_at("1", Instant::now()); + queue.reset( + &key, + // ~24,855 days + Duration::from_secs(2_u64.pow(31)), + ); + }); + }); + + // The panic location should be in this file + assert_eq!(&panic_location_file.unwrap(), file!()); + + Ok(()) +} + +#[test] +fn delay_queue_reserve_panic_caller() -> Result<(), Box> { + let panic_location_file = test_panic(|| { + let rt = basic(); + rt.block_on(async { + let mut queue = task::spawn(DelayQueue::::with_capacity(3)); + + queue.reserve((1 << 30) as usize); + }); + }); + + // The panic location should be in this file + assert_eq!(&panic_location_file.unwrap(), file!()); + + Ok(()) +} + +fn basic() -> Runtime { + tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap() +}