Skip to content

Commit

Permalink
Add IMDSv1 fallback (#2609) (#2610)
Browse files Browse the repository at this point in the history
* Add IMDSv1 fallback (#2609)

* Add config option
  • Loading branch information
tustvold committed Aug 30, 2022
1 parent 171f80b commit 62eeaa5
Show file tree
Hide file tree
Showing 5 changed files with 276 additions and 75 deletions.
165 changes: 139 additions & 26 deletions object_store/src/aws/credential.rs
Expand Up @@ -23,11 +23,12 @@ use bytes::Buf;
use chrono::{DateTime, Utc};
use futures::TryFutureExt;
use reqwest::header::{HeaderMap, HeaderValue};
use reqwest::{Client, Method, Request, RequestBuilder};
use reqwest::{Client, Method, Request, RequestBuilder, StatusCode};
use serde::Deserialize;
use std::collections::BTreeMap;
use std::sync::Arc;
use std::time::Instant;
use tracing::warn;

type StdError = Box<dyn std::error::Error + Send + Sync>;

Expand Down Expand Up @@ -284,18 +285,24 @@ pub struct InstanceCredentialProvider {
pub cache: TokenCache<Arc<AwsCredential>>,
pub client: Client,
pub retry_config: RetryConfig,
pub imdsv1_fallback: bool,
}

impl InstanceCredentialProvider {
async fn get_credential(&self) -> Result<Arc<AwsCredential>> {
self.cache
.get_or_insert_with(|| {
const METADATA_ENDPOINT: &str = "http://169.254.169.254";
instance_creds(&self.client, &self.retry_config, METADATA_ENDPOINT)
.map_err(|source| crate::Error::Generic {
store: "S3",
source,
})
instance_creds(
&self.client,
&self.retry_config,
METADATA_ENDPOINT,
self.imdsv1_fallback,
)
.map_err(|source| crate::Error::Generic {
store: "S3",
source,
})
})
.await
}
Expand Down Expand Up @@ -360,36 +367,47 @@ async fn instance_creds(
client: &Client,
retry_config: &RetryConfig,
endpoint: &str,
imdsv1_fallback: bool,
) -> Result<TemporaryToken<Arc<AwsCredential>>, StdError> {
const CREDENTIALS_PATH: &str = "latest/meta-data/iam/security-credentials";
const AWS_EC2_METADATA_TOKEN_HEADER: &str = "X-aws-ec2-metadata-token";

let token_url = format!("{}/latest/api/token", endpoint);
let token = client

let token_result = client
.request(Method::PUT, token_url)
.header("X-aws-ec2-metadata-token-ttl-seconds", "600") // 10 minute TTL
.send_retry(retry_config)
.await?
.text()
.await?;
.await;

let token = match token_result {
Ok(t) => Some(t.text().await?),
Err(e)
if imdsv1_fallback && matches!(e.status(), Some(StatusCode::FORBIDDEN)) =>
{
warn!("received 403 from metadata endpoint, falling back to IMDSv1");
None
}
Err(e) => return Err(e.into()),
};

let role_url = format!("{}/{}/", endpoint, CREDENTIALS_PATH);
let role = client
.request(Method::GET, role_url)
.header(AWS_EC2_METADATA_TOKEN_HEADER, &token)
.send_retry(retry_config)
.await?
.text()
.await?;
let mut role_request = client.request(Method::GET, role_url);

if let Some(token) = &token {
role_request = role_request.header(AWS_EC2_METADATA_TOKEN_HEADER, token);
}

let role = role_request.send_retry(retry_config).await?.text().await?;

let creds_url = format!("{}/{}/{}", endpoint, CREDENTIALS_PATH, role);
let creds: InstanceCredentials = client
.request(Method::GET, creds_url)
.header(AWS_EC2_METADATA_TOKEN_HEADER, &token)
.send_retry(retry_config)
.await?
.json()
.await?;
let mut creds_request = client.request(Method::GET, creds_url);
if let Some(token) = &token {
creds_request = creds_request.header(AWS_EC2_METADATA_TOKEN_HEADER, token);
}

let creds: InstanceCredentials =
creds_request.send_retry(retry_config).await?.json().await?;

let now = Utc::now();
let ttl = (creds.expiration - now).to_std().unwrap_or_default();
Expand Down Expand Up @@ -470,6 +488,8 @@ async fn web_identity(
#[cfg(test)]
mod tests {
use super::*;
use crate::client::mock_server::MockServer;
use hyper::{Body, Response};
use reqwest::{Client, Method};
use std::env;

Expand Down Expand Up @@ -567,11 +587,11 @@ mod tests {

assert_eq!(
resp.status(),
reqwest::StatusCode::UNAUTHORIZED,
StatusCode::UNAUTHORIZED,
"Ensure metadata endpoint is set to only allow IMDSv2"
);

let creds = instance_creds(&client, &retry_config, &endpoint)
let creds = instance_creds(&client, &retry_config, &endpoint, false)
.await
.unwrap();

Expand All @@ -583,4 +603,97 @@ mod tests {
assert!(!secret.is_empty());
assert!(!token.is_empty())
}

#[tokio::test]
async fn test_mock() {
let server = MockServer::new();

const IMDSV2_HEADER: &str = "X-aws-ec2-metadata-token";

let secret_access_key = "SECRET";
let access_key_id = "KEYID";
let token = "TOKEN";

let endpoint = server.url();
let client = Client::new();
let retry_config = RetryConfig::default();

// Test IMDSv2
server.push_fn(|req| {
assert_eq!(req.uri().path(), "/latest/api/token");
assert_eq!(req.method(), &Method::PUT);
Response::new(Body::from("cupcakes"))
});
server.push_fn(|req| {
assert_eq!(
req.uri().path(),
"/latest/meta-data/iam/security-credentials/"
);
assert_eq!(req.method(), &Method::GET);
let t = req.headers().get(IMDSV2_HEADER).unwrap().to_str().unwrap();
assert_eq!(t, "cupcakes");
Response::new(Body::from("myrole"))
});
server.push_fn(|req| {
assert_eq!(req.uri().path(), "/latest/meta-data/iam/security-credentials/myrole");
assert_eq!(req.method(), &Method::GET);
let t = req.headers().get(IMDSV2_HEADER).unwrap().to_str().unwrap();
assert_eq!(t, "cupcakes");
Response::new(Body::from(r#"{"AccessKeyId":"KEYID","Code":"Success","Expiration":"2022-08-30T10:51:04Z","LastUpdated":"2022-08-30T10:21:04Z","SecretAccessKey":"SECRET","Token":"TOKEN","Type":"AWS-HMAC"}"#))
});

let creds = instance_creds(&client, &retry_config, endpoint, true)
.await
.unwrap();

assert_eq!(creds.token.token.as_deref().unwrap(), token);
assert_eq!(&creds.token.key_id, access_key_id);
assert_eq!(&creds.token.secret_key, secret_access_key);

// Test IMDSv1 fallback
server.push_fn(|req| {
assert_eq!(req.uri().path(), "/latest/api/token");
assert_eq!(req.method(), &Method::PUT);
Response::builder()
.status(StatusCode::FORBIDDEN)
.body(Body::empty())
.unwrap()
});
server.push_fn(|req| {
assert_eq!(
req.uri().path(),
"/latest/meta-data/iam/security-credentials/"
);
assert_eq!(req.method(), &Method::GET);
assert!(req.headers().get(IMDSV2_HEADER).is_none());
Response::new(Body::from("myrole"))
});
server.push_fn(|req| {
assert_eq!(req.uri().path(), "/latest/meta-data/iam/security-credentials/myrole");
assert_eq!(req.method(), &Method::GET);
assert!(req.headers().get(IMDSV2_HEADER).is_none());
Response::new(Body::from(r#"{"AccessKeyId":"KEYID","Code":"Success","Expiration":"2022-08-30T10:51:04Z","LastUpdated":"2022-08-30T10:21:04Z","SecretAccessKey":"SECRET","Token":"TOKEN","Type":"AWS-HMAC"}"#))
});

let creds = instance_creds(&client, &retry_config, endpoint, true)
.await
.unwrap();

assert_eq!(creds.token.token.as_deref().unwrap(), token);
assert_eq!(&creds.token.key_id, access_key_id);
assert_eq!(&creds.token.secret_key, secret_access_key);

// Test IMDSv1 fallback disabled
server.push(
Response::builder()
.status(StatusCode::FORBIDDEN)
.body(Body::empty())
.unwrap(),
);

// Should fail
instance_creds(&client, &retry_config, endpoint, false)
.await
.unwrap_err();
}
}
19 changes: 19 additions & 0 deletions object_store/src/aws/mod.rs
Expand Up @@ -339,6 +339,7 @@ pub struct AmazonS3Builder {
token: Option<String>,
retry_config: RetryConfig,
allow_http: bool,
imdsv1_fallback: bool,
}

impl AmazonS3Builder {
Expand Down Expand Up @@ -446,6 +447,23 @@ impl AmazonS3Builder {
self
}

/// By default instance credentials will only be fetched over [IMDSv2], as AWS recommends
/// against having IMDSv1 enabled on EC2 instances as it is vulnerable to [SSRF attack]
///
/// However, certain deployment environments, such as those running old versions of kube2iam,
/// may not support IMDSv2. This option will enable automatic fallback to using IMDSv1
/// if the token endpoint returns a 403 error indicating that IMDSv2 is not supported.
///
/// This option has no effect if not using instance credentials
///
/// [IMDSv2]: [https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/configuring-instance-metadata-service.html]
/// [SSRF attack]: [https://aws.amazon.com/blogs/security/defense-in-depth-open-firewalls-reverse-proxies-ssrf-vulnerabilities-ec2-instance-metadata-service/]
///
pub fn with_imdsv1_fallback(mut self) -> Self {
self.imdsv1_fallback = true;
self
}

/// Create a [`AmazonS3`] instance from the provided values,
/// consuming `self`.
pub fn build(self) -> Result<AmazonS3> {
Expand Down Expand Up @@ -503,6 +521,7 @@ impl AmazonS3Builder {
cache: Default::default(),
client,
retry_config: self.retry_config.clone(),
imdsv1_fallback: self.imdsv1_fallback,
})
}
},
Expand Down
105 changes: 105 additions & 0 deletions object_store/src/client/mock_server.rs
@@ -0,0 +1,105 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Request, Response, Server};
use parking_lot::Mutex;
use std::collections::VecDeque;
use std::convert::Infallible;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::sync::oneshot;
use tokio::task::JoinHandle;

pub type ResponseFn = Box<dyn FnOnce(Request<Body>) -> Response<Body> + Send>;

/// A mock server
pub struct MockServer {
responses: Arc<Mutex<VecDeque<ResponseFn>>>,
shutdown: oneshot::Sender<()>,
handle: JoinHandle<()>,
url: String,
}

impl MockServer {
pub fn new() -> Self {
let responses: Arc<Mutex<VecDeque<ResponseFn>>> =
Arc::new(Mutex::new(VecDeque::with_capacity(10)));

let r = Arc::clone(&responses);
let make_service = make_service_fn(move |_conn| {
let r = Arc::clone(&r);
async move {
Ok::<_, Infallible>(service_fn(move |req| {
let r = Arc::clone(&r);
async move {
Ok::<_, Infallible>(match r.lock().pop_front() {
Some(r) => r(req),
None => Response::new(Body::from("Hello World")),
})
}
}))
}
});

let (shutdown, rx) = oneshot::channel::<()>();
let server =
Server::bind(&SocketAddr::from(([127, 0, 0, 1], 0))).serve(make_service);

let url = format!("http://{}", server.local_addr());

let handle = tokio::spawn(async move {
server
.with_graceful_shutdown(async {
rx.await.ok();
})
.await
.unwrap()
});

Self {
responses,
shutdown,
handle,
url,
}
}

/// The url of the mock server
pub fn url(&self) -> &str {
&self.url
}

/// Add a response
pub fn push(&self, response: Response<Body>) {
self.push_fn(|_| response)
}

/// Add a response function
pub fn push_fn<F>(&self, f: F)
where
F: FnOnce(Request<Body>) -> Response<Body> + Send + 'static,
{
self.responses.lock().push_back(Box::new(f))
}

/// Shutdown the mock server
pub async fn shutdown(self) {
let _ = self.shutdown.send(());
self.handle.await.unwrap()
}
}
2 changes: 2 additions & 0 deletions object_store/src/client/mod.rs
Expand Up @@ -18,6 +18,8 @@
//! Generic utilities reqwest based ObjectStore implementations

pub mod backoff;
#[cfg(test)]
pub mod mock_server;
pub mod pagination;
pub mod retry;
pub mod token;

0 comments on commit 62eeaa5

Please sign in to comment.