Merge pull request #657 from teloxide/gc_for_workers
Add "GC" for dispatcher workers
WaffleLapkin committed Jun 26, 2022
2 parents f812093 + 3e35d40 commit 19eca5d
Showing 3 changed files with 94 additions and 17 deletions.
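
In short: the dispatcher keeps one worker per distribution key (a Tokio task plus an mpsc channel), and before this change those workers were never removed, so memory grew with every chat that had ever sent an update. The commit tracks how many workers are simultaneously busy and, once the workers map outgrows that high-water mark, prunes workers that are provably idle. The sketch below is a simplified model of the pruning rule, not the actual teloxide code (the `IdleWorker` struct and `prune_idle` function are illustrative names): a worker is removable when its channel still has its full capacity (nothing queued) and its task has flagged itself as waiting.

```rust
use std::collections::HashMap;
use std::sync::{
    atomic::{AtomicBool, Ordering},
    Arc,
};

// Illustrative stand-in for the dispatcher's per-key worker; `u32` stands in
// for `Update` to keep the sketch self-contained.
struct IdleWorker {
    tx: tokio::sync::mpsc::Sender<u32>,
    handle: tokio::task::JoinHandle<()>,
    is_waiting: Arc<AtomicBool>,
}

// Remove workers that are provably idle: their channel still has its full
// capacity (nothing queued) and the task has flagged itself as waiting on
// `recv()`. Dropping `tx` closes the channel, so `recv()` returns `None` and
// the task exits; awaiting the handle covers the race where the worker took
// an update but has not yet cleared its flag.
async fn prune_idle(workers: &mut HashMap<i64, IdleWorker>, queue_size: usize) {
    let idle_keys: Vec<i64> = workers
        .iter()
        .filter(|(_, w)| {
            w.tx.capacity() == queue_size && w.is_waiting.load(Ordering::Relaxed)
        })
        .map(|(key, _)| *key)
        .collect();

    for key in idle_keys {
        let IdleWorker { tx, handle, .. } = workers.remove(&key).unwrap();
        drop(tx); // close the channel; the worker should stop almost immediately
        let _ = handle.await; // still wait, in case the idle check raced with a `recv()`
    }
}
```
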
8 changes: 8 additions & 0 deletions CHANGELOG.md
@@ -6,6 +6,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## unreleased

### Fixed

- `Dispatcher` no longer "leaks" memory for every inactive user ([PR 657](https://github.com/teloxide/teloxide/pull/657)).

### Changed

- Add the `Key: Clone` requirement for `impl Dispatcher` [**BC**].

## 0.9.2 - 2022-06-07

### Fixed
101 changes: 85 additions & 16 deletions src/dispatching/dispatcher.rs
@@ -11,17 +11,20 @@ use crate::{

use dptree::di::{DependencyMap, DependencySupplier};
use futures::{future::BoxFuture, stream::FuturesUnordered, StreamExt};
use tokio::time::timeout;
use tokio_stream::wrappers::ReceiverStream;

use std::{
collections::HashMap,
fmt::Debug,
future::Future,
hash::Hash,
ops::{ControlFlow, Deref},
sync::Arc,
sync::{
atomic::{AtomicBool, AtomicU32, Ordering},
Arc,
},
};
use tokio::time::timeout;
use tokio_stream::wrappers::ReceiverStream;

use std::future::Future;

/// The builder for [`Dispatcher`].
pub struct DispatcherBuilder<R, Err, Key> {
@@ -137,6 +140,8 @@ where
worker_queue_size,
workers: HashMap::new(),
default_worker: None,
current_number_of_active_workers: Default::default(),
max_number_of_active_workers: Default::default(),
}
}
}
@@ -158,6 +163,8 @@ pub struct Dispatcher<R, Err, Key> {

distribution_f: fn(&Update) -> Option<Key>,
worker_queue_size: usize,
current_number_of_active_workers: Arc<AtomicU32>,
max_number_of_active_workers: Arc<AtomicU32>,
// Tokio TX channel parts associated with chat IDs that consume updates sequentially.
workers: HashMap<Key, Worker>,
// The default TX part that consumes updates concurrently.
@@ -171,6 +178,7 @@ pub struct Dispatcher<R, Err, Key> {
struct Worker {
tx: tokio::sync::mpsc::Sender<Update>,
handle: tokio::task::JoinHandle<()>,
is_waiting: Arc<AtomicBool>,
}

// TODO: it is allowed to return message as response on telegram request in
@@ -214,7 +222,7 @@ impl<R, Err, Key> Dispatcher<R, Err, Key>
where
R: Requester + Clone + Send + Sync + 'static,
Err: Send + Sync + 'static,
Key: Hash + Eq,
Key: Hash + Eq + Clone,
{
/// Starts your bot with the default parameters.
///
@@ -280,6 +288,8 @@ where
tokio::pin!(stream);

loop {
self.remove_inactive_workers_if_needed().await;

// False positive
#[allow(clippy::collapsible_match)]
if let Ok(upd) = timeout(shutdown_check_timeout, stream.next()).await {
@@ -342,6 +352,8 @@ where
handler,
default_handler,
error_handler,
Arc::clone(&self.current_number_of_active_workers),
Arc::clone(&self.max_number_of_active_workers),
self.worker_queue_size,
)
}),
@@ -367,6 +379,48 @@ where
}
}

async fn remove_inactive_workers_if_needed(&mut self) {
let workers = self.workers.len();
let max = self.max_number_of_active_workers.load(Ordering::Relaxed) as usize;

if workers <= max {
return;
}

self.remove_inactive_workers().await;
}

#[inline(never)] // Cold function.
async fn remove_inactive_workers(&mut self) {
let handles = self
.workers
.iter()
.filter(|(_, worker)| {
worker.tx.capacity() == self.worker_queue_size
&& worker.is_waiting.load(Ordering::Relaxed)
})
.map(|(k, _)| k)
.cloned()
.collect::<Vec<_>>()
.into_iter()
.map(|key| {
let Worker { tx, handle, .. } = self.workers.remove(&key).unwrap();

// Close the channel; the worker should stop almost immediately
// (it is supposedly waiting on the channel)
drop(tx);

handle
});

for handle in handles {
// We must wait for the worker to stop anyway, even though it should stop
// immediately. This covers the case where we observed the worker as waiting
// in between it receiving an update and clearing the flag.
let _ = handle.await;
}
}

/// Sets up the `^C` handler that [`shutdown`]s dispatching.
///
/// [`shutdown`]: ShutdownToken::shutdown
@@ -405,25 +459,40 @@ fn spawn_worker<Err>(
handler: Arc<UpdateHandler<Err>>,
default_handler: DefaultHandler,
error_handler: Arc<dyn ErrorHandler<Err> + Send + Sync>,
current_number_of_active_workers: Arc<AtomicU32>,
max_number_of_active_workers: Arc<AtomicU32>,
queue_size: usize,
) -> Worker
where
Err: Send + Sync + 'static,
{
let (tx, rx) = tokio::sync::mpsc::channel(queue_size);
let (tx, mut rx) = tokio::sync::mpsc::channel(queue_size);
let is_waiting = Arc::new(AtomicBool::new(true));
let is_waiting_local = Arc::clone(&is_waiting);

let deps = Arc::new(deps);

let handle = tokio::spawn(ReceiverStream::new(rx).for_each(move |update| {
let deps = Arc::clone(&deps);
let handler = Arc::clone(&handler);
let default_handler = Arc::clone(&default_handler);
let error_handler = Arc::clone(&error_handler);
let handle = tokio::spawn(async move {
while let Some(update) = rx.recv().await {
is_waiting_local.store(false, Ordering::Relaxed);
{
let current = current_number_of_active_workers.fetch_add(1, Ordering::Relaxed) + 1;
max_number_of_active_workers.fetch_max(current, Ordering::Relaxed);
}

handle_update(update, deps, handler, default_handler, error_handler)
}));
let deps = Arc::clone(&deps);
let handler = Arc::clone(&handler);
let default_handler = Arc::clone(&default_handler);
let error_handler = Arc::clone(&error_handler);

handle_update(update, deps, handler, default_handler, error_handler).await;

current_number_of_active_workers.fetch_sub(1, Ordering::Relaxed);
is_waiting_local.store(true, Ordering::Relaxed);
}
});

Worker { tx, handle }
Worker { tx, handle, is_waiting }
}

fn spawn_default_worker<Err>(
Expand All @@ -449,7 +518,7 @@ where
handle_update(update, deps, handler, default_handler, error_handler)
}));

Worker { tx, handle }
Worker { tx, handle, is_waiting: Arc::new(AtomicBool::new(true)) }
}

async fn handle_update<Err>(
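
The two new `AtomicU32` fields implement a simple high-water mark: each worker increments `current_number_of_active_workers` while it handles an update, and `fetch_max` records the largest value ever observed in `max_number_of_active_workers`. The dispatch loop then runs the GC pass only when more workers exist than were ever active at the same time, which keeps the check cheap on the hot path. A standalone sketch of that counter pattern (the `ActivityStats` type is illustrative, not part of the PR):

```rust
use std::sync::atomic::{AtomicU32, Ordering};

#[derive(Default)]
struct ActivityStats {
    current: AtomicU32, // workers handling an update right now
    max: AtomicU32,     // high-water mark of `current`
}

impl ActivityStats {
    // Called by a worker right before it starts handling an update.
    fn enter(&self) {
        let now = self.current.fetch_add(1, Ordering::Relaxed) + 1;
        self.max.fetch_max(now, Ordering::Relaxed);
    }

    // Called by a worker right after it finishes the update.
    fn exit(&self) {
        self.current.fetch_sub(1, Ordering::Relaxed);
    }

    // The dispatcher only prunes when more workers exist than were ever
    // simultaneously active, so the common case is a single cheap load.
    fn should_prune(&self, spawned_workers: usize) -> bool {
        spawned_workers > self.max.load(Ordering::Relaxed) as usize
    }
}
```
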
2 changes: 1 addition & 1 deletion src/dispatching/distribution.rs
@@ -1,7 +1,7 @@
use teloxide_core::types::{ChatId, Update};

/// Default distribution key for dispatching.
#[derive(Debug, Hash, PartialEq, Eq)]
#[derive(Debug, Hash, PartialEq, Eq, Clone)]
pub struct DefaultKey(ChatId);

pub(crate) fn default_distribution_function(update: &Update) -> Option<DefaultKey> {
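
The `Clone` derive on `DefaultKey` (and the matching `Key: Clone` bound noted in the changelog) is needed because the GC pass clones the keys of idle workers into a temporary `Vec` before removing them, since entries cannot be removed from the map while it is being iterated. Any custom distribution key now needs `Clone` as well. A minimal conforming key, mirroring the default one (the `MyKey` type and `my_distribution_function` are hypothetical, and `Update::chat()` is assumed to be available as in the default function):

```rust
use teloxide_core::types::{ChatId, Update};

// Hypothetical custom distribution key: with this PR it must derive `Clone`
// in addition to `Hash`, `PartialEq`, and `Eq`.
#[derive(Debug, Hash, PartialEq, Eq, Clone)]
struct MyKey(ChatId);

fn my_distribution_function(update: &Update) -> Option<MyKey> {
    update.chat().map(|chat| MyKey(chat.id))
}
```
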
