Add "GC" for dispatcher workers #657

Merged · 5 commits · Jun 26, 2022
Showing changes from 2 commits.
101 changes: 85 additions & 16 deletions src/dispatching/dispatcher.rs
@@ -11,17 +11,20 @@ use crate::{

use dptree::di::{DependencyMap, DependencySupplier};
use futures::{future::BoxFuture, stream::FuturesUnordered, StreamExt};
use tokio::time::timeout;
use tokio_stream::wrappers::ReceiverStream;

use std::{
    collections::HashMap,
    fmt::Debug,
    future::Future,
    hash::Hash,
    ops::{ControlFlow, Deref},
    sync::Arc,
    sync::{
        atomic::{AtomicBool, AtomicU32, Ordering},
        Arc,
    },
};
use tokio::time::timeout;
use tokio_stream::wrappers::ReceiverStream;

use std::future::Future;
/// The builder for [`Dispatcher`].
pub struct DispatcherBuilder<R, Err, Key> {
@@ -137,6 +140,8 @@ where
            worker_queue_size,
            workers: HashMap::new(),
            default_worker: None,
            current_number_of_active_workers: Default::default(),
            max_number_of_active_workers: Default::default(),
        }
    }
}
@@ -158,6 +163,8 @@ pub struct Dispatcher<R, Err, Key> {

    distribution_f: fn(&Update) -> Option<Key>,
    worker_queue_size: usize,
    current_number_of_active_workers: Arc<AtomicU32>,
    max_number_of_active_workers: Arc<AtomicU32>,
    // Tokio TX channel parts associated with chat IDs that consume updates sequentially.
    workers: HashMap<Key, Worker>,
    // The default TX part that consumes updates concurrently.
@@ -171,6 +178,7 @@ pub struct Dispatcher<R, Err, Key> {
struct Worker {
    tx: tokio::sync::mpsc::Sender<Update>,
    handle: tokio::task::JoinHandle<()>,
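    // `true` while the worker is parked on `rx.recv()` with no update in
    // flight; the GC uses this flag to detect idle workers.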
    is_waiting: Arc<AtomicBool>,
}

// TODO: it is allowed to return message as response on telegram request in
@@ -214,7 +222,7 @@ impl<R, Err, Key> Dispatcher<R, Err, Key>
where
    R: Requester + Clone + Send + Sync + 'static,
    Err: Send + Sync + 'static,
    Key: Hash + Eq,
    Key: Hash + Eq + Clone,
{
    /// Starts your bot with the default parameters.
    ///
@@ -280,6 +288,8 @@ where
        tokio::pin!(stream);

        loop {
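            // Reclaim workers that have gone idle since the last iteration.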
            self.gc_workers_if_needed().await;

            // False positive
            #[allow(clippy::collapsible_match)]
            if let Ok(upd) = timeout(shutdown_check_timeout, stream.next()).await {
@@ -342,6 +352,8 @@ where
                    handler,
                    default_handler,
                    error_handler,
                    Arc::clone(&self.current_number_of_active_workers),
                    Arc::clone(&self.max_number_of_active_workers),
                    self.worker_queue_size,
                )
            }),
@@ -367,6 +379,48 @@ where
        }
    }

    async fn gc_workers_if_needed(&mut self) {
        let workers = self.workers.len();
        let max = self.max_number_of_active_workers.load(Ordering::Relaxed) as usize;

        if workers <= max {
            return;
        }

        self.gc_workers().await;
    }

    #[inline(never)]
    async fn gc_workers(&mut self) {
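        // A worker is deemed idle when its queue is empty (its sender reports
        // full spare capacity) and it has flagged itself as waiting on `recv()`.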
        let handles = self
            .workers
            .iter()
            .filter(|(_, worker)| {
                worker.tx.capacity() == self.worker_queue_size
                    && worker.is_waiting.load(Ordering::Relaxed)
            })
            .map(|(k, _)| k)
            .cloned()
            .collect::<Vec<_>>()
            .into_iter()
            .map(|key| {
                let Worker { tx, handle, .. } = self.workers.remove(&key).unwrap();

                // Close the channel; the worker should stop almost immediately,
                // since it is supposedly waiting on the channel.
                drop(tx);

                handle
            });

        for handle in handles {
            // We must wait for the worker to stop anyway, even though it should
            // stop immediately. This covers the race where we observed the
            // worker as "waiting" between the moment it received an update and
            // the moment it cleared the flag.
            let _ = handle.await;
        }
    }

    /// Sets up the `^C` handler that [`shutdown`]s dispatching.
    ///
    /// [`shutdown`]: ShutdownToken::shutdown
@@ -405,25 +459,40 @@ fn spawn_worker<Err>(
    handler: Arc<UpdateHandler<Err>>,
    default_handler: DefaultHandler,
    error_handler: Arc<dyn ErrorHandler<Err> + Send + Sync>,
    current_number_of_active_workers: Arc<AtomicU32>,
    max_number_of_active_workers: Arc<AtomicU32>,
    queue_size: usize,
) -> Worker
where
    Err: Send + Sync + 'static,
{
    let (tx, rx) = tokio::sync::mpsc::channel(queue_size);
    let (tx, mut rx) = tokio::sync::mpsc::channel(queue_size);
    let is_waiting = Arc::new(AtomicBool::new(true));
    let is_waiting_local = Arc::clone(&is_waiting);

    let deps = Arc::new(deps);

    let handle = tokio::spawn(ReceiverStream::new(rx).for_each(move |update| {
        let deps = Arc::clone(&deps);
        let handler = Arc::clone(&handler);
        let default_handler = Arc::clone(&default_handler);
        let error_handler = Arc::clone(&error_handler);
    let handle = tokio::spawn(async move {
        while let Some(update) = rx.recv().await {
            is_waiting_local.store(false, Ordering::Relaxed);
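            // Record a high-water mark of concurrently busy workers; the GC
            // later reclaims workers once the map grows beyond this number.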
            {
                let current = current_number_of_active_workers.fetch_add(1, Ordering::Relaxed) + 1;
                max_number_of_active_workers.fetch_max(current, Ordering::Relaxed);
            }

        handle_update(update, deps, handler, default_handler, error_handler)
    }));
            let deps = Arc::clone(&deps);
            let handler = Arc::clone(&handler);
            let default_handler = Arc::clone(&default_handler);
            let error_handler = Arc::clone(&error_handler);

            handle_update(update, deps, handler, default_handler, error_handler).await;

            current_number_of_active_workers.fetch_sub(1, Ordering::Relaxed);
            is_waiting_local.store(true, Ordering::Relaxed);
        }
    });

    Worker { tx, handle }
    Worker { tx, handle, is_waiting }
}

fn spawn_default_worker<Err>(
@@ -449,7 +518,7 @@ where
        handle_update(update, deps, handler, default_handler, error_handler)
    }));

    Worker { tx, handle }
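    // The default worker lives outside the `workers` map and is never GC'd,
    // so its `is_waiting` flag is only a placeholder.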
    Worker { tx, handle, is_waiting: Arc::new(AtomicBool::new(true)) }
}

async fn handle_update<Err>(
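Aside: the idle check in `gc_workers` leans on a property of tokio's bounded channels: `Sender::capacity()` reports the number of free permits, so it equals the configured queue size exactly when no updates are buffered. A minimal, self-contained sketch of that invariant (not part of the diff; the `queue_size` value is arbitrary):

use tokio::sync::mpsc;

#[tokio::main]
async fn main() {
    let queue_size = 8;
    let (tx, mut rx) = mpsc::channel::<u32>(queue_size);

    // An empty queue reports all permits free, the state that
    // `gc_workers` treats as "no pending work".
    assert_eq!(tx.capacity(), queue_size);

    // Sending an update consumes a permit...
    tx.send(1).await.unwrap();
    assert_eq!(tx.capacity(), queue_size - 1);

    // ...and receiving it gives the permit back.
    let _ = rx.recv().await;
    assert_eq!(tx.capacity(), queue_size);
}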
2 changes: 1 addition & 1 deletion src/dispatching/distribution.rs
@@ -1,7 +1,7 @@
use teloxide_core::types::{ChatId, Update};

/// Default distribution key for dispatching.
#[derive(Debug, Hash, PartialEq, Eq)]
#[derive(Debug, Hash, PartialEq, Eq, Clone)]
pub struct DefaultKey(ChatId);

pub(crate) fn default_distribution_function(update: &Update) -> Option<DefaultKey> {
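With the `Clone` derive added above, any custom distribution key must now implement `Hash + Eq + Clone`, since `gc_workers` clones keys out of the worker map before removing entries. A rough sketch of a conforming custom key (`ShardKey` and `shard_distribution` are hypothetical names, not part of teloxide; `Update::chat()` is the same accessor the default function uses):

use teloxide_core::types::Update;

#[derive(Debug, Hash, PartialEq, Eq, Clone)]
struct ShardKey(i64);

// Spread chats across 16 shards, bounding how many sequential
// workers can accumulate between GC passes.
fn shard_distribution(update: &Update) -> Option<ShardKey> {
    update.chat().map(|chat| ShardKey(chat.id.0.rem_euclid(16)))
}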