diff --git a/examples/axum-key-value-store/Cargo.toml b/examples/axum-key-value-store/Cargo.toml index 765dc397..b6e0e6f2 100644 --- a/examples/axum-key-value-store/Cargo.toml +++ b/examples/axum-key-value-store/Cargo.toml @@ -9,9 +9,9 @@ license = "MIT" [dependencies] hyper = { version = "0.14.15", features = ["full"] } tokio = { version = "1.2.0", features = ["full"] } -tower = { version = "0.4.5", features = ["full"] } +tower = { version = "0.4.13", features = ["full"] } tower-http = { path = "../../tower-http", features = ["full"] } tracing = "0.1" tracing-subscriber = { version = "0.3.11", features = ["env-filter"] } -axum = "0.4" +axum = "0.6" clap = { version = "3.1.13", features = ["derive"] } diff --git a/examples/axum-key-value-store/src/main.rs b/examples/axum-key-value-store/src/main.rs index 18b0fc5a..12841189 100644 --- a/examples/axum-key-value-store/src/main.rs +++ b/examples/axum-key-value-store/src/main.rs @@ -1,21 +1,21 @@ use axum::{ body::Bytes, - error_handling::HandleErrorLayer, - extract::{Extension, Path}, + extract::{Path, State}, http::{header, HeaderValue, StatusCode}, response::IntoResponse, routing::get, - BoxError, Router, + Router, }; use clap::Parser; use std::{ collections::HashMap, - net::SocketAddr, + net::{Ipv4Addr, SocketAddr}, sync::{Arc, RwLock}, time::Duration, }; use tower::ServiceBuilder; use tower_http::{ + timeout::TimeoutLayer, trace::{DefaultMakeSpan, DefaultOnResponse, TraceLayer}, LatencyUnit, ServiceBuilderExt, }; @@ -29,7 +29,7 @@ struct Config { } #[derive(Clone, Debug)] -struct State { +struct AppState { db: Arc>>, } @@ -42,7 +42,7 @@ async fn main() { let config = Config::parse(); // Run our service - let addr = SocketAddr::from(([0, 0, 0, 0], config.port)); + let addr = SocketAddr::from((Ipv4Addr::UNSPECIFIED, config.port)); tracing::info!("Listening on {}", addr); axum::Server::bind(&addr) .serve(app().into_make_service()) @@ -50,23 +50,9 @@ async fn main() { .expect("server error"); } -async fn handle_errors(err: BoxError) -> impl IntoResponse { - if err.is::() { - ( - StatusCode::REQUEST_TIMEOUT, - "Request took too long".to_string(), - ) - } else { - ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("Unhandled internal error: {}", err), - ) - } -} - fn app() -> Router { // Build our database for holding the key/value pairs - let state = State { + let state = AppState { db: Arc::new(RwLock::new(HashMap::new())), }; @@ -86,12 +72,10 @@ fn app() -> Router { .on_response(DefaultOnResponse::new().include_headers(true).latency_unit(LatencyUnit::Micros)), ) .sensitive_response_headers(sensitive_headers) - // Handle errors - .layer(HandleErrorLayer::new(handle_errors)) // Set a timeout - .timeout(Duration::from_secs(10)) - // Share the state with each handler via a request extension - .add_extension(state) + .layer(TimeoutLayer::new(Duration::from_secs(10))) + // Box the response body so it implements `Default` which is required by axum + .map_response_body(axum::body::boxed) // Compress responses .compression() // Set a `Content-Type` if there isn't one already. @@ -103,10 +87,11 @@ fn app() -> Router { // Build route service Router::new() .route("/:key", get(get_key).post(set_key)) - .layer(middleware.into_inner()) + .layer(middleware) + .with_state(state) } -async fn get_key(path: Path, state: Extension) -> impl IntoResponse { +async fn get_key(path: Path, state: State) -> impl IntoResponse { let state = state.db.read().unwrap(); if let Some(value) = state.get(&*path).cloned() { @@ -116,7 +101,7 @@ async fn get_key(path: Path, state: Extension) -> impl IntoRespon } } -async fn set_key(Path(path): Path, state: Extension, value: Bytes) { +async fn set_key(Path(path): Path, state: State, value: Bytes) { let mut state = state.db.write().unwrap(); state.insert(path, value); }