diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5e5fa7404..53ff1e5e9 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,6 +6,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
 ## unreleased
 
+### Fixed
+
+- `Dispatcher` no longer "leaks" memory for every inactive user ([PR 657](https://github.com/teloxide/teloxide/pull/657)).
+
+### Changed
+
+- Add the `Key: Clone` requirement for `impl Dispatcher` [**BC**].
+
 ## 0.9.2 - 2022-06-07
 
 ### Fixed
diff --git a/src/dispatching/dispatcher.rs b/src/dispatching/dispatcher.rs
index f458c23f5..74eace826 100644
--- a/src/dispatching/dispatcher.rs
+++ b/src/dispatching/dispatcher.rs
@@ -11,17 +11,20 @@ use crate::{
 use dptree::di::{DependencyMap, DependencySupplier};
 use futures::{future::BoxFuture, stream::FuturesUnordered, StreamExt};
 
+use tokio::time::timeout;
+use tokio_stream::wrappers::ReceiverStream;
+
 use std::{
     collections::HashMap,
     fmt::Debug,
+    future::Future,
     hash::Hash,
     ops::{ControlFlow, Deref},
-    sync::Arc,
+    sync::{
+        atomic::{AtomicBool, AtomicU32, Ordering},
+        Arc,
+    },
 };
 
-use tokio::time::timeout;
-use tokio_stream::wrappers::ReceiverStream;
-
-use std::future::Future;
 /// The builder for [`Dispatcher`].
 pub struct DispatcherBuilder<R, Err, Key> {
@@ -137,6 +140,8 @@ where
             worker_queue_size,
             workers: HashMap::new(),
             default_worker: None,
+            current_number_of_active_workers: Default::default(),
+            max_number_of_active_workers: Default::default(),
         }
     }
 }
@@ -158,6 +163,8 @@ pub struct Dispatcher<R, Err, Key> {
     distribution_f: fn(&Update) -> Option<Key>,
     worker_queue_size: usize,
+    current_number_of_active_workers: Arc<AtomicU32>,
+    max_number_of_active_workers: Arc<AtomicU32>,
     // Tokio TX channel parts associated with chat IDs that consume updates sequentially.
     workers: HashMap<Key, Worker>,
     // The default TX part that consume updates concurrently.
     default_worker: Option<Worker>,
@@ -171,6 +178,7 @@ pub struct Dispatcher<R, Err, Key> {
 struct Worker {
     tx: tokio::sync::mpsc::Sender<Update>,
     handle: tokio::task::JoinHandle<()>,
+    is_waiting: Arc<AtomicBool>,
 }
 
 // TODO: it is allowed to return message as response on telegram request in
@@ -214,7 +222,7 @@ impl<R, Err, Key> Dispatcher<R, Err, Key>
 where
     R: Requester + Clone + Send + Sync + 'static,
     Err: Send + Sync + 'static,
-    Key: Hash + Eq,
+    Key: Hash + Eq + Clone,
 {
     /// Starts your bot with the default parameters.
     ///
@@ -280,6 +288,8 @@
         tokio::pin!(stream);
 
         loop {
+            self.remove_inactive_workers_if_needed().await;
+
             // False positive
             #[allow(clippy::collapsible_match)]
             if let Ok(upd) = timeout(shutdown_check_timeout, stream.next()).await {
@@ -342,6 +352,8 @@
                     handler,
                     default_handler,
                     error_handler,
+                    Arc::clone(&self.current_number_of_active_workers),
+                    Arc::clone(&self.max_number_of_active_workers),
                     self.worker_queue_size,
                 )
             }),
@@ -367,6 +379,48 @@
         }
     }
 
+    async fn remove_inactive_workers_if_needed(&mut self) {
+        let workers = self.workers.len();
+        let max = self.max_number_of_active_workers.load(Ordering::Relaxed) as usize;
+
+        if workers <= max {
+            return;
+        }
+
+        self.remove_inactive_workers().await;
+    }
+
+    #[inline(never)] // Cold function.
+    async fn remove_inactive_workers(&mut self) {
+        let handles = self
+            .workers
+            .iter()
+            .filter(|(_, worker)| {
+                worker.tx.capacity() == self.worker_queue_size
+                    && worker.is_waiting.load(Ordering::Relaxed)
+            })
+            .map(|(k, _)| k)
+            .cloned()
+            .collect::<Vec<_>>()
+            .into_iter()
+            .map(|key| {
+                let Worker { tx, handle, .. } = self.workers.remove(&key).unwrap();
+
+                // Close the channel; the worker should stop almost immediately,
+                // since it has supposedly been waiting on the channel.
+                drop(tx);
+
+                handle
+            });
+
+        for handle in handles {
+            // We must wait for the worker to stop anyway, even though it should
+            // stop immediately. This covers the race where we read `is_waiting`
+            // between the worker receiving an update and clearing the flag.
+            let _ = handle.await;
+        }
+    }
+
     /// Setups the `^C` handler that [`shutdown`]s dispatching.
     ///
     /// [`shutdown`]: ShutdownToken::shutdown
@@ -405,25 +459,40 @@ fn spawn_worker<Err>(
     deps: DependencyMap,
     handler: Arc<UpdateHandler<Err>>,
     default_handler: DefaultHandler,
     error_handler: Arc<dyn ErrorHandler<Err> + Send + Sync>,
+    current_number_of_active_workers: Arc<AtomicU32>,
+    max_number_of_active_workers: Arc<AtomicU32>,
     queue_size: usize,
 ) -> Worker
 where
     Err: Send + Sync + 'static,
 {
-    let (tx, rx) = tokio::sync::mpsc::channel(queue_size);
+    let (tx, mut rx) = tokio::sync::mpsc::channel(queue_size);
+
+    let is_waiting = Arc::new(AtomicBool::new(true));
+    let is_waiting_local = Arc::clone(&is_waiting);
 
     let deps = Arc::new(deps);
 
-    let handle = tokio::spawn(ReceiverStream::new(rx).for_each(move |update| {
-        let deps = Arc::clone(&deps);
-        let handler = Arc::clone(&handler);
-        let default_handler = Arc::clone(&default_handler);
-        let error_handler = Arc::clone(&error_handler);
+    let handle = tokio::spawn(async move {
+        while let Some(update) = rx.recv().await {
+            is_waiting_local.store(false, Ordering::Relaxed);
+            {
+                let current = current_number_of_active_workers.fetch_add(1, Ordering::Relaxed) + 1;
+                max_number_of_active_workers.fetch_max(current, Ordering::Relaxed);
+            }
 
-        handle_update(update, deps, handler, default_handler, error_handler)
-    }));
+            let deps = Arc::clone(&deps);
+            let handler = Arc::clone(&handler);
+            let default_handler = Arc::clone(&default_handler);
+            let error_handler = Arc::clone(&error_handler);
+
+            handle_update(update, deps, handler, default_handler, error_handler).await;
+
+            current_number_of_active_workers.fetch_sub(1, Ordering::Relaxed);
+            is_waiting_local.store(true, Ordering::Relaxed);
+        }
+    });
 
-    Worker { tx, handle }
+    Worker { tx, handle, is_waiting }
 }
 
 fn spawn_default_worker(
@@ -449,7 +518,7 @@ where
         handle_update(update, deps, handler, default_handler, error_handler)
     }));
 
-    Worker { tx, handle }
+    Worker { tx, handle, is_waiting: Arc::new(AtomicBool::new(true)) }
 }
 
 async fn handle_update(
diff --git a/src/dispatching/distribution.rs b/src/dispatching/distribution.rs
index 208e0018f..2089b7b6d 100644
--- a/src/dispatching/distribution.rs
+++ b/src/dispatching/distribution.rs
@@ -1,7 +1,7 @@
 use teloxide_core::types::{ChatId, Update};
 
 /// Default distribution key for dispatching.
-#[derive(Debug, Hash, PartialEq, Eq)]
+#[derive(Debug, Hash, PartialEq, Eq, Clone)]
 pub struct DefaultKey(ChatId);
 
 pub(crate) fn default_distribution_function(update: &Update) -> Option<DefaultKey> {
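
How the fix works: every iteration of the dispatch loop now calls `remove_inactive_workers_if_needed`, which scans the `workers` map only when the map holds more entries than the peak number of simultaneously active workers ever observed. That peak is tracked by the two new shared counters, updated around every `handle_update` call. A minimal sketch of the counter discipline (plain Rust, outside teloxide; `enter` and `leave` are hypothetical names, not part of the patch):

```rust
use std::sync::atomic::{AtomicU32, Ordering};

/// Called when a worker starts handling an update.
fn enter(current: &AtomicU32, max: &AtomicU32) {
    // `fetch_add` returns the previous value, hence the `+ 1`.
    let now = current.fetch_add(1, Ordering::Relaxed) + 1;
    // Remember the highest concurrency level ever seen.
    max.fetch_max(now, Ordering::Relaxed);
}

/// Called when a worker finishes handling an update.
fn leave(current: &AtomicU32) {
    current.fetch_sub(1, Ordering::Relaxed);
}
```

Gating the scan on the observed peak keeps the hot path cheap: a bot whose whole audience stays active never exceeds the threshold and never pays for a scan, while a bot with many one-off visitors crosses it and gets its map pruned.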
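A worker counts as reapable only when its queue is provably empty (`tx.capacity()` is back to the configured queue size, i.e. every permit has been returned) and the task has parked itself on `recv` again (`is_waiting`). The sketch below reproduces that handshake outside teloxide, assuming tokio 1.x with the `rt-multi-thread`, `macros`, `sync`, and `time` features; `Job`, `IdleWorker`, `spawn_idle_worker`, and `try_reap` are hypothetical names, not teloxide API:

```rust
use std::sync::{
    atomic::{AtomicBool, Ordering},
    Arc,
};

type Job = u32; // Stand-in for teloxide's `Update`.

struct IdleWorker {
    tx: tokio::sync::mpsc::Sender<Job>,
    handle: tokio::task::JoinHandle<()>,
    /// `true` while the worker task is parked on `recv`.
    is_waiting: Arc<AtomicBool>,
}

fn spawn_idle_worker(queue_size: usize) -> IdleWorker {
    let (tx, mut rx) = tokio::sync::mpsc::channel::<Job>(queue_size);
    let is_waiting = Arc::new(AtomicBool::new(true));
    let flag = Arc::clone(&is_waiting);

    let handle = tokio::spawn(async move {
        // `recv` yields `None` once all senders are dropped, so the task
        // exits by itself when the dispatcher lets go of `tx`.
        while let Some(job) = rx.recv().await {
            flag.store(false, Ordering::Relaxed);
            println!("handling job {job}");
            flag.store(true, Ordering::Relaxed);
        }
    });

    IdleWorker { tx, handle, is_waiting }
}

/// Reaps the worker if it looks idle; returns it unchanged otherwise.
async fn try_reap(worker: IdleWorker, queue_size: usize) -> Option<IdleWorker> {
    // Idle = queue empty (all permits returned) AND parked on `recv`.
    let idle = worker.tx.capacity() == queue_size
        && worker.is_waiting.load(Ordering::Relaxed);
    if !idle {
        return Some(worker);
    }

    let IdleWorker { tx, handle, .. } = worker;
    drop(tx); // Close the channel; `recv` now returns `None`.

    // Await the task: this also covers the race where the worker grabbed
    // a job after we read `is_waiting` but before it cleared the flag.
    let _ = handle.await;
    None
}

#[tokio::main]
async fn main() {
    const QUEUE: usize = 8;
    let worker = spawn_idle_worker(QUEUE);
    worker.tx.send(1).await.unwrap();

    // Give the worker time to drain its queue, then reap it.
    tokio::time::sleep(std::time::Duration::from_millis(50)).await;
    assert!(try_reap(worker, QUEUE).await.is_none());
    println!("worker reaped");
}
```

Dropping `tx` closes the channel, so `recv` returns `None` and the task exits on its own; awaiting the `JoinHandle` afterwards is what makes the `is_waiting` check sound, because a worker that slipped past the check while picking up one last update is waited out rather than leaked or cancelled.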