diff --git a/examples/websockets_chat.rs b/examples/websockets_chat.rs index 53ac1fcd3..8ccbd6709 100644 --- a/examples/websockets_chat.rs +++ b/examples/websockets_chat.rs @@ -1,12 +1,12 @@ -#![deny(warnings)] +// #![deny(warnings)] use std::collections::HashMap; use std::sync::{ atomic::{AtomicUsize, Ordering}, - Arc, Mutex, + Arc, }; -use futures::{future, Future, FutureExt, StreamExt}; -use tokio::sync::mpsc; +use futures::{FutureExt, StreamExt}; +use tokio::sync::{mpsc, Mutex}; use warp::ws::{Message, WebSocket}; use warp::Filter; @@ -36,7 +36,7 @@ async fn main() { .and(users) .map(|ws: warp::ws::Ws, users| { // This will call our function if the handshake succeeds. - ws.on_upgrade(move |socket| user_connected(socket, users).map(|result| result.unwrap())) + ws.on_upgrade(move |socket| user_connected(socket, users)) }); // GET / -> index html @@ -47,14 +47,14 @@ async fn main() { warp::serve(routes).run(([127, 0, 0, 1], 3030)).await; } -fn user_connected(ws: WebSocket, users: Users) -> impl Future> { +async fn user_connected(ws: WebSocket, users: Users) { // Use a counter to assign a new unique ID for this user. let my_id = NEXT_USER_ID.fetch_add(1, Ordering::Relaxed); eprintln!("new chat user: {}", my_id); // Split the socket into a sender and receive of messages. - let (user_ws_tx, user_ws_rx) = ws.split(); + let (user_ws_tx, mut user_ws_rx) = ws.split(); // Use an unbounded channel to handle buffering and flushing of messages // to the websocket... @@ -66,7 +66,7 @@ fn user_connected(ws: WebSocket, users: Users) -> impl Future impl Future msg, + Err(e) => { + eprintln!("websocket error(uid={}): {}", my_id, e); + continue; + } + }; + user_message(my_id, msg, &users).await; + } + + // user_ws_rx stream will keep processing as long as the user stays + // connected. Once they disconnect, then... + user_disconnected(my_id, &users2).await; } -fn user_message(my_id: usize, msg: Message, users: &Users) { +async fn user_message(my_id: usize, msg: Message, users: &Users) { // Skip any non-Text messages... let msg = if let Ok(s) = msg.to_str() { s @@ -107,25 +106,22 @@ fn user_message(my_id: usize, msg: Message, users: &Users) { // // We use `retain` instead of a for loop so that we can reap any user that // appears to have disconnected. - for (&uid, tx) in users.lock().unwrap().iter_mut() { + for (&uid, tx) in users.lock().await.iter_mut() { if my_id != uid { - match tx.send(Ok(Message::text(new_msg.clone()))) { - Ok(()) => (), - Err(_disconnected) => { - // The tx is disconnected, our `user_disconnected` code - // should be happening in another task, nothing more to - // do here. - } + if let Err(_disconnected) = tx.send(Ok(Message::text(new_msg.clone()))) { + // The tx is disconnected, our `user_disconnected` code + // should be happening in another task, nothing more to + // do here. } } } } -fn user_disconnected(my_id: usize, users: &Users) { +async fn user_disconnected(my_id: usize, users: &Users) { eprintln!("good bye user: {}", my_id); // Stream closed up, so remove from the user list - users.lock().unwrap().remove(&my_id); + users.lock().await.remove(&my_id); } static INDEX_HTML: &str = r#"