diff --git a/object_store/src/aws/credential.rs b/object_store/src/aws/credential.rs index e6c1bdd74a0..1abf42be910 100644 --- a/object_store/src/aws/credential.rs +++ b/object_store/src/aws/credential.rs @@ -23,11 +23,12 @@ use bytes::Buf; use chrono::{DateTime, Utc}; use futures::TryFutureExt; use reqwest::header::{HeaderMap, HeaderValue}; -use reqwest::{Client, Method, Request, RequestBuilder}; +use reqwest::{Client, Method, Request, RequestBuilder, StatusCode}; use serde::Deserialize; use std::collections::BTreeMap; use std::sync::Arc; use std::time::Instant; +use tracing::warn; type StdError = Box; @@ -284,6 +285,7 @@ pub struct InstanceCredentialProvider { pub cache: TokenCache>, pub client: Client, pub retry_config: RetryConfig, + pub imdsv1_fallback: bool, } impl InstanceCredentialProvider { @@ -291,11 +293,16 @@ impl InstanceCredentialProvider { self.cache .get_or_insert_with(|| { const METADATA_ENDPOINT: &str = "http://169.254.169.254"; - instance_creds(&self.client, &self.retry_config, METADATA_ENDPOINT) - .map_err(|source| crate::Error::Generic { - store: "S3", - source, - }) + instance_creds( + &self.client, + &self.retry_config, + METADATA_ENDPOINT, + self.imdsv1_fallback, + ) + .map_err(|source| crate::Error::Generic { + store: "S3", + source, + }) }) .await } @@ -360,36 +367,47 @@ async fn instance_creds( client: &Client, retry_config: &RetryConfig, endpoint: &str, + imdsv1_fallback: bool, ) -> Result>, StdError> { const CREDENTIALS_PATH: &str = "latest/meta-data/iam/security-credentials"; const AWS_EC2_METADATA_TOKEN_HEADER: &str = "X-aws-ec2-metadata-token"; let token_url = format!("{}/latest/api/token", endpoint); - let token = client + + let token_result = client .request(Method::PUT, token_url) .header("X-aws-ec2-metadata-token-ttl-seconds", "600") // 10 minute TTL .send_retry(retry_config) - .await? - .text() - .await?; + .await; + + let token = match token_result { + Ok(t) => Some(t.text().await?), + Err(e) + if imdsv1_fallback && matches!(e.status(), Some(StatusCode::FORBIDDEN)) => + { + warn!("received 403 from metadata endpoint, falling back to IMDSv1"); + None + } + Err(e) => return Err(e.into()), + }; let role_url = format!("{}/{}/", endpoint, CREDENTIALS_PATH); - let role = client - .request(Method::GET, role_url) - .header(AWS_EC2_METADATA_TOKEN_HEADER, &token) - .send_retry(retry_config) - .await? - .text() - .await?; + let mut role_request = client.request(Method::GET, role_url); + + if let Some(token) = &token { + role_request = role_request.header(AWS_EC2_METADATA_TOKEN_HEADER, token); + } + + let role = role_request.send_retry(retry_config).await?.text().await?; let creds_url = format!("{}/{}/{}", endpoint, CREDENTIALS_PATH, role); - let creds: InstanceCredentials = client - .request(Method::GET, creds_url) - .header(AWS_EC2_METADATA_TOKEN_HEADER, &token) - .send_retry(retry_config) - .await? - .json() - .await?; + let mut creds_request = client.request(Method::GET, creds_url); + if let Some(token) = &token { + creds_request = creds_request.header(AWS_EC2_METADATA_TOKEN_HEADER, token); + } + + let creds: InstanceCredentials = + creds_request.send_retry(retry_config).await?.json().await?; let now = Utc::now(); let ttl = (creds.expiration - now).to_std().unwrap_or_default(); @@ -470,6 +488,8 @@ async fn web_identity( #[cfg(test)] mod tests { use super::*; + use crate::client::mock_server::MockServer; + use hyper::{Body, Response}; use reqwest::{Client, Method}; use std::env; @@ -567,11 +587,11 @@ mod tests { assert_eq!( resp.status(), - reqwest::StatusCode::UNAUTHORIZED, + StatusCode::UNAUTHORIZED, "Ensure metadata endpoint is set to only allow IMDSv2" ); - let creds = instance_creds(&client, &retry_config, &endpoint) + let creds = instance_creds(&client, &retry_config, &endpoint, false) .await .unwrap(); @@ -583,4 +603,97 @@ mod tests { assert!(!secret.is_empty()); assert!(!token.is_empty()) } + + #[tokio::test] + async fn test_mock() { + let server = MockServer::new(); + + const IMDSV2_HEADER: &str = "X-aws-ec2-metadata-token"; + + let secret_access_key = "SECRET"; + let access_key_id = "KEYID"; + let token = "TOKEN"; + + let endpoint = server.url(); + let client = Client::new(); + let retry_config = RetryConfig::default(); + + // Test IMDSv2 + server.push_fn(|req| { + assert_eq!(req.uri().path(), "/latest/api/token"); + assert_eq!(req.method(), &Method::PUT); + Response::new(Body::from("cupcakes")) + }); + server.push_fn(|req| { + assert_eq!( + req.uri().path(), + "/latest/meta-data/iam/security-credentials/" + ); + assert_eq!(req.method(), &Method::GET); + let t = req.headers().get(IMDSV2_HEADER).unwrap().to_str().unwrap(); + assert_eq!(t, "cupcakes"); + Response::new(Body::from("myrole")) + }); + server.push_fn(|req| { + assert_eq!(req.uri().path(), "/latest/meta-data/iam/security-credentials/myrole"); + assert_eq!(req.method(), &Method::GET); + let t = req.headers().get(IMDSV2_HEADER).unwrap().to_str().unwrap(); + assert_eq!(t, "cupcakes"); + Response::new(Body::from(r#"{"AccessKeyId":"KEYID","Code":"Success","Expiration":"2022-08-30T10:51:04Z","LastUpdated":"2022-08-30T10:21:04Z","SecretAccessKey":"SECRET","Token":"TOKEN","Type":"AWS-HMAC"}"#)) + }); + + let creds = instance_creds(&client, &retry_config, endpoint, true) + .await + .unwrap(); + + assert_eq!(creds.token.token.as_deref().unwrap(), token); + assert_eq!(&creds.token.key_id, access_key_id); + assert_eq!(&creds.token.secret_key, secret_access_key); + + // Test IMDSv1 fallback + server.push_fn(|req| { + assert_eq!(req.uri().path(), "/latest/api/token"); + assert_eq!(req.method(), &Method::PUT); + Response::builder() + .status(StatusCode::FORBIDDEN) + .body(Body::empty()) + .unwrap() + }); + server.push_fn(|req| { + assert_eq!( + req.uri().path(), + "/latest/meta-data/iam/security-credentials/" + ); + assert_eq!(req.method(), &Method::GET); + assert!(req.headers().get(IMDSV2_HEADER).is_none()); + Response::new(Body::from("myrole")) + }); + server.push_fn(|req| { + assert_eq!(req.uri().path(), "/latest/meta-data/iam/security-credentials/myrole"); + assert_eq!(req.method(), &Method::GET); + assert!(req.headers().get(IMDSV2_HEADER).is_none()); + Response::new(Body::from(r#"{"AccessKeyId":"KEYID","Code":"Success","Expiration":"2022-08-30T10:51:04Z","LastUpdated":"2022-08-30T10:21:04Z","SecretAccessKey":"SECRET","Token":"TOKEN","Type":"AWS-HMAC"}"#)) + }); + + let creds = instance_creds(&client, &retry_config, endpoint, true) + .await + .unwrap(); + + assert_eq!(creds.token.token.as_deref().unwrap(), token); + assert_eq!(&creds.token.key_id, access_key_id); + assert_eq!(&creds.token.secret_key, secret_access_key); + + // Test IMDSv1 fallback disabled + server.push( + Response::builder() + .status(StatusCode::FORBIDDEN) + .body(Body::empty()) + .unwrap(), + ); + + // Should fail + instance_creds(&client, &retry_config, endpoint, false) + .await + .unwrap_err(); + } } diff --git a/object_store/src/aws/mod.rs b/object_store/src/aws/mod.rs index ab90afa5deb..d1d0a12cdaf 100644 --- a/object_store/src/aws/mod.rs +++ b/object_store/src/aws/mod.rs @@ -339,6 +339,7 @@ pub struct AmazonS3Builder { token: Option, retry_config: RetryConfig, allow_http: bool, + imdsv1_fallback: bool, } impl AmazonS3Builder { @@ -446,6 +447,23 @@ impl AmazonS3Builder { self } + /// By default instance credentials will only be fetched over [IMDSv2], as AWS recommends + /// against having IMDSv1 enabled on EC2 instances as it is vulnerable to [SSRF attack] + /// + /// However, certain deployment environments, such as those running old versions of kube2iam, + /// may not support IMDSv2. This option will enable automatic fallback to using IMDSv1 + /// if the token endpoint returns a 403 error indicating that IMDSv2 is not supported. + /// + /// This option has no effect if not using instance credentials + /// + /// [IMDSv2]: [https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/configuring-instance-metadata-service.html] + /// [SSRF attack]: [https://aws.amazon.com/blogs/security/defense-in-depth-open-firewalls-reverse-proxies-ssrf-vulnerabilities-ec2-instance-metadata-service/] + /// + pub fn with_imdsv1_fallback(mut self) -> Self { + self.imdsv1_fallback = true; + self + } + /// Create a [`AmazonS3`] instance from the provided values, /// consuming `self`. pub fn build(self) -> Result { @@ -503,6 +521,7 @@ impl AmazonS3Builder { cache: Default::default(), client, retry_config: self.retry_config.clone(), + imdsv1_fallback: self.imdsv1_fallback, }) } }, diff --git a/object_store/src/client/mock_server.rs b/object_store/src/client/mock_server.rs new file mode 100644 index 00000000000..adb7e0fff77 --- /dev/null +++ b/object_store/src/client/mock_server.rs @@ -0,0 +1,105 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use hyper::service::{make_service_fn, service_fn}; +use hyper::{Body, Request, Response, Server}; +use parking_lot::Mutex; +use std::collections::VecDeque; +use std::convert::Infallible; +use std::net::SocketAddr; +use std::sync::Arc; +use tokio::sync::oneshot; +use tokio::task::JoinHandle; + +pub type ResponseFn = Box) -> Response + Send>; + +/// A mock server +pub struct MockServer { + responses: Arc>>, + shutdown: oneshot::Sender<()>, + handle: JoinHandle<()>, + url: String, +} + +impl MockServer { + pub fn new() -> Self { + let responses: Arc>> = + Arc::new(Mutex::new(VecDeque::with_capacity(10))); + + let r = Arc::clone(&responses); + let make_service = make_service_fn(move |_conn| { + let r = Arc::clone(&r); + async move { + Ok::<_, Infallible>(service_fn(move |req| { + let r = Arc::clone(&r); + async move { + Ok::<_, Infallible>(match r.lock().pop_front() { + Some(r) => r(req), + None => Response::new(Body::from("Hello World")), + }) + } + })) + } + }); + + let (shutdown, rx) = oneshot::channel::<()>(); + let server = + Server::bind(&SocketAddr::from(([127, 0, 0, 1], 0))).serve(make_service); + + let url = format!("http://{}", server.local_addr()); + + let handle = tokio::spawn(async move { + server + .with_graceful_shutdown(async { + rx.await.ok(); + }) + .await + .unwrap() + }); + + Self { + responses, + shutdown, + handle, + url, + } + } + + /// The url of the mock server + pub fn url(&self) -> &str { + &self.url + } + + /// Add a response + pub fn push(&self, response: Response) { + self.push_fn(|_| response) + } + + /// Add a response function + pub fn push_fn(&self, f: F) + where + F: FnOnce(Request) -> Response + Send + 'static, + { + self.responses.lock().push_back(Box::new(f)) + } + + /// Shutdown the mock server + pub async fn shutdown(self) { + let _ = self.shutdown.send(()); + self.handle.await.unwrap() + } +} diff --git a/object_store/src/client/mod.rs b/object_store/src/client/mod.rs index e6de3e92923..c93c68a1faa 100644 --- a/object_store/src/client/mod.rs +++ b/object_store/src/client/mod.rs @@ -18,6 +18,8 @@ //! Generic utilities reqwest based ObjectStore implementations pub mod backoff; +#[cfg(test)] +pub mod mock_server; pub mod pagination; pub mod retry; pub mod token; diff --git a/object_store/src/client/retry.rs b/object_store/src/client/retry.rs index 44d7835a554..d66628aec45 100644 --- a/object_store/src/client/retry.rs +++ b/object_store/src/client/retry.rs @@ -180,54 +180,17 @@ impl RetryExt for reqwest::RequestBuilder { #[cfg(test)] mod tests { + use crate::client::mock_server::MockServer; use crate::client::retry::RetryExt; use crate::RetryConfig; use hyper::header::LOCATION; - use hyper::service::{make_service_fn, service_fn}; - use hyper::{Body, Response, Server}; - use parking_lot::Mutex; + use hyper::{Body, Response}; use reqwest::{Client, Method, StatusCode}; - use std::collections::VecDeque; - use std::convert::Infallible; - use std::net::SocketAddr; - use std::sync::Arc; use std::time::Duration; #[tokio::test] async fn test_retry() { - let responses: Arc>>> = - Arc::new(Mutex::new(VecDeque::with_capacity(10))); - - let r = Arc::clone(&responses); - let make_service = make_service_fn(move |_conn| { - let r = Arc::clone(&r); - async move { - Ok::<_, Infallible>(service_fn(move |_req| { - let r = Arc::clone(&r); - async move { - Ok::<_, Infallible>(match r.lock().pop_front() { - Some(r) => r, - None => Response::new(Body::from("Hello World")), - }) - } - })) - } - }); - - let (tx, rx) = tokio::sync::oneshot::channel::<()>(); - let server = - Server::bind(&SocketAddr::from(([127, 0, 0, 1], 0))).serve(make_service); - - let url = format!("http://{}", server.local_addr()); - - let server_handle = tokio::spawn(async move { - server - .with_graceful_shutdown(async { - rx.await.ok(); - }) - .await - .unwrap() - }); + let mock = MockServer::new(); let retry = RetryConfig { backoff: Default::default(), @@ -236,14 +199,14 @@ mod tests { }; let client = Client::new(); - let do_request = || client.request(Method::GET, &url).send_retry(&retry); + let do_request = || client.request(Method::GET, mock.url()).send_retry(&retry); // Simple request should work let r = do_request().await.unwrap(); assert_eq!(r.status(), StatusCode::OK); // Returns client errors immediately with status message - responses.lock().push_back( + mock.push( Response::builder() .status(StatusCode::BAD_REQUEST) .body(Body::from("cupcakes")) @@ -256,7 +219,7 @@ mod tests { assert_eq!(&e.message, "cupcakes"); // Handles client errors with no payload - responses.lock().push_back( + mock.push( Response::builder() .status(StatusCode::BAD_REQUEST) .body(Body::empty()) @@ -269,7 +232,7 @@ mod tests { assert_eq!(&e.message, "No Body"); // Should retry server error request - responses.lock().push_back( + mock.push( Response::builder() .status(StatusCode::BAD_GATEWAY) .body(Body::empty()) @@ -280,7 +243,7 @@ mod tests { assert_eq!(r.status(), StatusCode::OK); // Accepts 204 status code - responses.lock().push_back( + mock.push( Response::builder() .status(StatusCode::NO_CONTENT) .body(Body::empty()) @@ -291,7 +254,7 @@ mod tests { assert_eq!(r.status(), StatusCode::NO_CONTENT); // Follows redirects - responses.lock().push_back( + mock.push( Response::builder() .status(StatusCode::FOUND) .header(LOCATION, "/foo") @@ -305,7 +268,7 @@ mod tests { // Gives up after the retrying the specified number of times for _ in 0..=retry.max_retries { - responses.lock().push_back( + mock.push( Response::builder() .status(StatusCode::BAD_GATEWAY) .body(Body::from("ignored")) @@ -318,7 +281,6 @@ mod tests { assert_eq!(e.message, "502 Bad Gateway"); // Shutdown - let _ = tx.send(()); - server_handle.await.unwrap(); + mock.shutdown().await } }