From b9b711ddb5fcb6f49e799e615c68073ef7fdf954 Mon Sep 17 00:00:00 2001 From: pan93412 Date: Mon, 18 Apr 2022 18:54:23 +0800 Subject: [PATCH] feat(rest-api): add CORS and Rate Limit layer Note that CORS layer can't work since https://github.com/tower-rs/tower-http/pull/237 has not been released yet. --- rest-api/Cargo.toml | 2 ++ rest-api/src/main.rs | 42 ++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/rest-api/Cargo.toml b/rest-api/Cargo.toml index 33309137..eb4943d9 100644 --- a/rest-api/Cargo.toml +++ b/rest-api/Cargo.toml @@ -25,6 +25,8 @@ serde_json = "1.0.79" thiserror = "1.0.30" tokio = { version = "1.17.0", features = ["full"] } toml = "0.5.9" +tower = { version = "0.4.12", features = ["buffer", "limit", "load-shed"] } +tower-http = { version = "0.2.5", features = ["cors"] } tracing = "0.1.34" tracing-log = "0.1.2" tracing-subscriber = "0.3.11" diff --git a/rest-api/src/main.rs b/rest-api/src/main.rs index 58fbedb1..c306dbae 100644 --- a/rest-api/src/main.rs +++ b/rest-api/src/main.rs @@ -5,11 +5,17 @@ pub(crate) mod retrieve; pub(crate) mod schema; use axum::{ + error_handling::HandleErrorLayer, routing::{get, post}, Extension, Json, Router, }; +use http::{Method, StatusCode, HeaderMap}; use serde_json::{json, Value}; -use std::{net::SocketAddr, sync::Arc}; +use std::{net::SocketAddr, sync::Arc, time::Duration}; +use tower::ServiceBuilder; +use tower_http::{ + cors::{Any, CorsLayer}, +}; use tracing::{debug, info, warn}; use unm_types::ContextBuilder; @@ -35,6 +41,37 @@ async fn main() { }); info!("Constructing app…"); + + let cors_layer = CorsLayer::new() + .allow_methods(vec![Method::GET, Method::POST]) + .allow_headers(vec![http::header::CONTENT_TYPE]) + .allow_origin(Any); + + let rate_limit_layer = ServiceBuilder::new() + .layer(HandleErrorLayer::new(|_| async { + ( + StatusCode::TOO_MANY_REQUESTS, + { + let mut hm = HeaderMap::new(); + hm.insert( + http::header::CONTENT_TYPE, + http::HeaderValue::from_static("application/json"), + ); + hm + }, + r#"{"error": "You request too fast. Please wait 5 minutes."}"#.to_string() + ) + })) + .buffer(1024) // Let RateLimit clone-able + .load_shed() + .rate_limit(30, Duration::from_secs(300)) // Allow only 30 requests per 5 minutes + .into_inner(); + + let limit_layer = ServiceBuilder::new() + .layer(cors_layer) + .layer(rate_limit_layer) + .into_inner(); + let app = Router::new() // `GET /` goes to `root` .route("/", get(root)) @@ -62,7 +99,8 @@ async fn main() { .route("/index", get(schema::schema_v1_index)) .route("/search", get(schema::schema_v1_search)) .route("/error", get(schema::schema_v1_error)) - }); + }) + .layer(limit_layer); let serve_address = std::env::var("SERVE_ADDRESS").unwrap_or_else(|_| "0.0.0.0:3000".to_string());