Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update websockets chat to use tokio Mutex instead of std #438

Merged
merged 2 commits into from Feb 10, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
68 changes: 32 additions & 36 deletions examples/websockets_chat.rs
@@ -1,12 +1,12 @@
#![deny(warnings)]
// #![deny(warnings)]
use std::collections::HashMap;
use std::sync::{
atomic::{AtomicUsize, Ordering},
Arc, Mutex,
Arc,
};

use futures::{future, Future, FutureExt, StreamExt};
use tokio::sync::mpsc;
use futures::{FutureExt, StreamExt};
use tokio::sync::{mpsc, Mutex};
use warp::ws::{Message, WebSocket};
use warp::Filter;

Expand Down Expand Up @@ -36,7 +36,7 @@ async fn main() {
.and(users)
.map(|ws: warp::ws::Ws, users| {
// This will call our function if the handshake succeeds.
ws.on_upgrade(move |socket| user_connected(socket, users).map(|result| result.unwrap()))
ws.on_upgrade(move |socket| user_connected(socket, users))
});

// GET / -> index html
Expand All @@ -47,14 +47,14 @@ async fn main() {
warp::serve(routes).run(([127, 0, 0, 1], 3030)).await;
}

fn user_connected(ws: WebSocket, users: Users) -> impl Future<Output = Result<(), ()>> {
async fn user_connected(ws: WebSocket, users: Users) {
// Use a counter to assign a new unique ID for this user.
let my_id = NEXT_USER_ID.fetch_add(1, Ordering::Relaxed);

eprintln!("new chat user: {}", my_id);

// Split the socket into a sender and receive of messages.
let (user_ws_tx, user_ws_rx) = ws.split();
let (user_ws_tx, mut user_ws_rx) = ws.split();

// Use an unbounded channel to handle buffering and flushing of messages
// to the websocket...
Expand All @@ -66,34 +66,33 @@ fn user_connected(ws: WebSocket, users: Users) -> impl Future<Output = Result<()
}));

// Save the sender in our list of connected users.
users.lock().unwrap().insert(my_id, tx);
users.lock().await.insert(my_id, tx);

// Return a `Future` that is basically a state machine managing
// this specific user's connection.

// Make an extra clone to give to our disconnection handler...
let users2 = users.clone();

user_ws_rx
// Every time the user sends a message, broadcast it to
// all other users...
.for_each(move |msg| {
user_message(my_id, msg.unwrap(), &users);
future::ready(())
})
// for_each will keep processing as long as the user stays
// connected. Once they disconnect, then...
.then(move |result| {
user_disconnected(my_id, &users2);
future::ok(result)
})
// If at any time, there was a websocket error, log here...
// .map_err(move |e| {
// eprintln!("websocket error(uid={}): {}", my_id, e);
// })
// Every time the user sends a message, broadcast it to
// all other users...
while let Some(result) = user_ws_rx.next().await {
let msg = match result {
Ok(msg) => msg,
Err(e) => {
eprintln!("websocket error(uid={}): {}", my_id, e);
continue;
}
};
user_message(my_id, msg, &users).await;
}

// user_ws_rx stream will keep processing as long as the user stays
// connected. Once they disconnect, then...
user_disconnected(my_id, &users2).await;
}

fn user_message(my_id: usize, msg: Message, users: &Users) {
async fn user_message(my_id: usize, msg: Message, users: &Users) {
// Skip any non-Text messages...
let msg = if let Ok(s) = msg.to_str() {
s
Expand All @@ -107,25 +106,22 @@ fn user_message(my_id: usize, msg: Message, users: &Users) {
//
// We use `retain` instead of a for loop so that we can reap any user that
// appears to have disconnected.
for (&uid, tx) in users.lock().unwrap().iter_mut() {
for (&uid, tx) in users.lock().await.iter_mut() {
if my_id != uid {
match tx.send(Ok(Message::text(new_msg.clone()))) {
Ok(()) => (),
Err(_disconnected) => {
// The tx is disconnected, our `user_disconnected` code
// should be happening in another task, nothing more to
// do here.
}
if let Err(_disconnected) = tx.send(Ok(Message::text(new_msg.clone()))) {
// The tx is disconnected, our `user_disconnected` code
// should be happening in another task, nothing more to
// do here.
}
}
}
}

fn user_disconnected(my_id: usize, users: &Users) {
async fn user_disconnected(my_id: usize, users: &Users) {
eprintln!("good bye user: {}", my_id);

// Stream closed up, so remove from the user list
users.lock().unwrap().remove(&my_id);
users.lock().await.remove(&my_id);
}

static INDEX_HTML: &str = r#"
Expand Down