-
Notifications
You must be signed in to change notification settings - Fork 224
/
server-custom-accept.rs
198 lines (170 loc) · 6.27 KB
/
server-custom-accept.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
//! A chat server that broadcasts a message to all connections.
//!
//! This is a simple server which accepts WebSocket connections, reads
//! messages from those connections, and broadcasts them to all other
//! connected clients. The WebSocket handshake is accepted manually on top of a
//! `hyper` HTTP server, so the `101 Switching Protocols` response can carry
//! custom headers; only requests for the `/socket` path are upgraded.
//!
//! You can test this out by running:
//!
//! cargo run --example server-custom-accept 127.0.0.1:12345
//!
//! (If no address argument is given, the server listens on 127.0.0.1:8080.)
//! And then in another window run:
//!
//! cargo run --example client ws://127.0.0.1:12345/socket
//!
//! You can run the second command in multiple windows and then chat between
//! them, seeing the messages from the other clients as they're received. All
//! connected clients join the same room and see everyone else's messages.
use std::{
collections::HashMap,
convert::Infallible,
env,
net::SocketAddr,
sync::{Arc, Mutex},
};
use hyper::{
body::Incoming,
header::{
HeaderValue, CONNECTION, SEC_WEBSOCKET_ACCEPT, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_VERSION,
UPGRADE,
},
server::conn::http1,
service::service_fn,
upgrade::Upgraded,
Method, Request, Response, StatusCode, Version,
};
use hyper_util::rt::TokioIo;
use tokio::net::TcpListener;
use futures_channel::mpsc::{unbounded, UnboundedSender};
use futures_util::{future, pin_mut, stream::TryStreamExt, StreamExt};
use tokio_tungstenite::{
tungstenite::{
handshake::derive_accept_key,
protocol::{Message, Role},
},
WebSocketStream,
};
/// Sending half of one peer's outgoing channel. Messages pushed into it are
/// forwarded to that peer's WebSocket by its `handle_connection` task.
type Tx = UnboundedSender<Message>;
/// Registry of all connected peers, keyed by their remote socket address and
/// shared between connection tasks behind a mutex.
type PeerMap = Arc<Mutex<HashMap<SocketAddr, Tx>>>;
/// Constructors for the single response body type this server returns.
mod body {
    use http_body_util::{Either, Empty, Full};
    use hyper::body::Bytes;

    /// Response body: either no content at all or one in-memory chunk.
    pub type Body = Either<Empty<Bytes>, Full<Bytes>>;

    /// Builds a body with no content (used for the 101 handshake response).
    pub fn empty() -> Body {
        Body::Left(Empty::<Bytes>::new())
    }

    /// Builds a body containing exactly `chunk`.
    pub fn bytes<B: Into<Bytes>>(chunk: B) -> Body {
        let data: Bytes = chunk.into();
        Body::Right(Full::new(data))
    }
}
/// Runs one upgraded WebSocket connection as a member of the chat room.
///
/// Registers an unbounded channel for `addr` in `peer_map`, then concurrently
/// (a) relays every text or binary message received from this client to all
/// other registered peers, and (b) forwards whatever other peers push into the
/// channel to this client. As soon as either side finishes (the client's
/// stream ends or errors, or writing to it fails) the peer is removed from
/// `peer_map` again.
async fn handle_connection(
    peer_map: PeerMap,
    ws_stream: WebSocketStream<TokioIo<Upgraded>>,
    addr: SocketAddr,
) {
    println!("WebSocket connection established: {}", addr);
    // Insert the write part of this peer to the peer map.
    let (tx, rx) = unbounded();
    peer_map.lock().unwrap().insert(addr, tx);
    let (outgoing, incoming) = ws_stream.split();
    let broadcast_incoming = incoming.try_for_each(|msg| {
        // Only data messages are chat content. Relaying control messages is
        // wrong: a client's Close in particular would be sent to every other
        // peer and make each of them close its own connection.
        if !(msg.is_text() || msg.is_binary()) {
            return future::ok(());
        }
        // Binary payloads need not be UTF-8, so `to_text` can fail; panicking
        // here would abort this task and skip the peer-map cleanup below.
        match msg.to_text() {
            Ok(text) => println!("Received a message from {}: {}", addr, text),
            Err(_) => println!("Received a {}-byte binary message from {}", msg.len(), addr),
        }
        let peers = peer_map.lock().unwrap();
        // We want to broadcast the message to everyone except ourselves.
        let broadcast_recipients =
            peers.iter().filter(|(peer_addr, _)| peer_addr != &&addr).map(|(_, ws_sink)| ws_sink);
        for recp in broadcast_recipients {
            // A send only fails when that peer's receiver is already dropped or
            // closed, i.e. its connection task is ending. Unwrapping would panic
            // while holding the lock and poison it for every other connection.
            let _ = recp.unbounded_send(msg.clone());
        }
        future::ok(())
    });
    let receive_from_others = rx.map(Ok).forward(outgoing);
    pin_mut!(broadcast_incoming, receive_from_others);
    // Whichever direction completes first ends the session.
    future::select(broadcast_incoming, receive_from_others).await;
    println!("{} disconnected", &addr);
    peer_map.lock().unwrap().remove(&addr);
}
/// Handles one HTTP request, upgrading it to a WebSocket when it is a valid
/// handshake for the `/socket` path.
///
/// The request's path and header names are always logged. A request that is
/// not a WebSocket handshake (`GET`, HTTP/1.1 or later, a `Connection` header
/// containing the `Upgrade` token, `Upgrade: websocket`,
/// `Sec-WebSocket-Version: 13`, and a `Sec-WebSocket-Key`), or one for another
/// path, gets a plain `Hello World!` response.
///
/// For a valid handshake, this spawns a task that waits for hyper to hand over
/// the upgraded connection and runs [`handle_connection`] on it. It then
/// returns the `101 Switching Protocols` response with the derived
/// `Sec-WebSocket-Accept` value and two extra custom headers.
async fn handle_request(
    peer_map: PeerMap,
    mut req: Request<Incoming>,
    addr: SocketAddr,
) -> Result<Response<body::Body>, Infallible> {
    println!("Received a new, potentially ws handshake");
    println!("The request's path is: {}", req.uri().path());
    println!("The request's headers are:");
    for (header, _value) in req.headers() {
        println!("* {}", header);
    }
    let upgrade = HeaderValue::from_static("Upgrade");
    let websocket = HeaderValue::from_static("websocket");
    let headers = req.headers();
    // `Connection` is a token list (e.g. `keep-alive, Upgrade`) and may be sent
    // on several header lines, so search every line for the `Upgrade` token.
    let wants_upgrade = headers
        .get_all(CONNECTION)
        .iter()
        .filter_map(|h| h.to_str().ok())
        .flat_map(|h| h.split(|c: char| c == ',' || c.is_ascii_whitespace()))
        .any(|token| token.eq_ignore_ascii_case("upgrade"));
    let upgrades_to_websocket = headers
        .get(UPGRADE)
        .and_then(|h| h.to_str().ok())
        .map(|h| h.eq_ignore_ascii_case("websocket"))
        .unwrap_or(false);
    let is_version_13 = headers.get(SEC_WEBSOCKET_VERSION).map(|h| h == "13").unwrap_or(false);
    let derived = headers.get(SEC_WEBSOCKET_KEY).map(|k| derive_accept_key(k.as_bytes()));
    let is_handshake = req.method() == Method::GET
        && req.version() >= Version::HTTP_11
        && wants_upgrade
        && upgrades_to_websocket
        && is_version_13;
    // Match on the path alone: comparing the whole URI would reject request
    // targets carrying a query string (`/socket?...`) or in absolute form.
    let derived = match derived {
        Some(accept) if is_handshake && req.uri().path() == "/socket" => accept,
        _ => return Ok(Response::new(body::bytes("Hello World!"))),
    };
    let ver = req.version();
    tokio::task::spawn(async move {
        match hyper::upgrade::on(&mut req).await {
            Ok(upgraded) => {
                let upgraded = TokioIo::new(upgraded);
                handle_connection(
                    peer_map,
                    WebSocketStream::from_raw_socket(upgraded, Role::Server, None).await,
                    addr,
                )
                .await;
            }
            Err(e) => println!("upgrade error: {}", e),
        }
    });
    let mut res = Response::new(body::empty());
    *res.status_mut() = StatusCode::SWITCHING_PROTOCOLS;
    *res.version_mut() = ver;
    res.headers_mut().append(CONNECTION, upgrade);
    res.headers_mut().append(UPGRADE, websocket);
    // The accept value is a base64 token (RFC 6455 §4.2.2), so this parse is
    // not expected to fail.
    res.headers_mut().append(SEC_WEBSOCKET_ACCEPT, derived.parse().unwrap());
    // Let's add an additional header to our response to the client.
    res.headers_mut().append("MyCustomHeader", ":)".parse().unwrap());
    res.headers_mut().append("SOME_TUNGSTENITE_HEADER", "header_value".parse().unwrap());
    Ok(res)
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let state = PeerMap::new(Mutex::new(HashMap::new()));
let addr =
env::args().nth(1).unwrap_or_else(|| "127.0.0.1:8080".to_string()).parse::<SocketAddr>()?;
let listener = TcpListener::bind(addr).await?;
loop {
let (stream, remote_addr) = listener.accept().await?;
let state = state.clone();
tokio::spawn(async move {
let io = TokioIo::new(stream);
let service = service_fn(move |req| handle_request(state.clone(), req, remote_addr));
let conn = http1::Builder::new().serve_connection(io, service).with_upgrades();
if let Err(err) = conn.await {
eprintln!("failed to serve connection: {err:?}");
}
});
}
}