diff --git a/object_store/src/aws/client.rs b/object_store/src/aws/client.rs index f800fec3dc5..5ec9390ec89 100644 --- a/object_store/src/aws/client.rs +++ b/object_store/src/aws/client.rs @@ -16,6 +16,7 @@ // under the License. use crate::aws::credential::{AwsCredential, CredentialExt, CredentialProvider}; +use crate::aws::STRICT_PATH_ENCODE_SET; use crate::client::pagination::stream_paginated; use crate::client::retry::RetryExt; use crate::multipart::UploadPart; @@ -26,26 +27,13 @@ use crate::{ }; use bytes::{Buf, Bytes}; use chrono::{DateTime, Utc}; -use percent_encoding::{utf8_percent_encode, AsciiSet, PercentEncode, NON_ALPHANUMERIC}; +use percent_encoding::{utf8_percent_encode, PercentEncode}; use reqwest::{Client as ReqwestClient, Method, Response, StatusCode}; use serde::{Deserialize, Serialize}; use snafu::{ResultExt, Snafu}; use std::ops::Range; use std::sync::Arc; -// http://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html -// -// Do not URI-encode any of the unreserved characters that RFC 3986 defines: -// A-Z, a-z, 0-9, hyphen ( - ), underscore ( _ ), period ( . ), and tilde ( ~ ). -const STRICT_ENCODE_SET: AsciiSet = NON_ALPHANUMERIC - .remove(b'-') - .remove(b'.') - .remove(b'_') - .remove(b'~'); - -/// This struct is used to maintain the URI path encoding -const STRICT_PATH_ENCODE_SET: AsciiSet = STRICT_ENCODE_SET.remove(b'/'); - /// A specialized `Error` for object store-related errors #[derive(Debug, Snafu)] #[allow(missing_docs)] diff --git a/object_store/src/aws/credential.rs b/object_store/src/aws/credential.rs index 1abf42be910..d4461645f3c 100644 --- a/object_store/src/aws/credential.rs +++ b/object_store/src/aws/credential.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use crate::aws::STRICT_ENCODE_SET; use crate::client::retry::RetryExt; use crate::client::token::{TemporaryToken, TokenCache}; use crate::util::hmac_sha256; @@ -22,6 +23,7 @@ use crate::{Result, RetryConfig}; use bytes::Buf; use chrono::{DateTime, Utc}; use futures::TryFutureExt; +use percent_encoding::utf8_percent_encode; use reqwest::header::{HeaderMap, HeaderValue}; use reqwest::{Client, Method, Request, RequestBuilder, StatusCode}; use serde::Deserialize; @@ -29,6 +31,7 @@ use std::collections::BTreeMap; use std::sync::Arc; use std::time::Instant; use tracing::warn; +use url::Url; type StdError = Box; @@ -103,13 +106,14 @@ impl<'a> RequestSigner<'a> { request.headers_mut().insert(HASH_HEADER, header_digest); let (signed_headers, canonical_headers) = canonicalize_headers(request.headers()); + let canonical_query = canonicalize_query(request.url()); // https://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html let canonical_request = format!( "{}\n{}\n{}\n{}\n{}\n{}", request.method().as_str(), request.url().path(), // S3 doesn't percent encode this like other services - request.url().query().unwrap_or(""), // This assumes the query pairs are in order + canonical_query, canonical_headers, signed_headers, digest @@ -207,6 +211,37 @@ fn hex_encode(bytes: &[u8]) -> String { out } +/// Canonicalizes query parameters into the AWS canonical form +/// +/// +fn canonicalize_query(url: &Url) -> String { + use std::fmt::Write; + + let capacity = match url.query() { + Some(q) if !q.is_empty() => q.len(), + _ => return String::new(), + }; + let mut encoded = String::with_capacity(capacity + 1); + + let mut headers = url.query_pairs().collect::>(); + headers.sort_unstable_by(|(a, _), (b, _)| a.cmp(b)); + + let mut first = true; + for (k, v) in headers { + if !first { + encoded.push('&'); + } + first = false; + let _ = write!( + encoded, + "{}={}", + utf8_percent_encode(k.as_ref(), &STRICT_ENCODE_SET), + utf8_percent_encode(v.as_ref(), &STRICT_ENCODE_SET) + ); + } + encoded +} + /// Canonicalizes headers into the AWS Canonical Form. /// /// diff --git a/object_store/src/aws/mod.rs b/object_store/src/aws/mod.rs index d1d0a12cdaf..89c1a4c2af0 100644 --- a/object_store/src/aws/mod.rs +++ b/object_store/src/aws/mod.rs @@ -58,6 +58,20 @@ use crate::{ mod client; mod credential; +// http://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html +// +// Do not URI-encode any of the unreserved characters that RFC 3986 defines: +// A-Z, a-z, 0-9, hyphen ( - ), underscore ( _ ), period ( . ), and tilde ( ~ ). +pub(crate) const STRICT_ENCODE_SET: percent_encoding::AsciiSet = + percent_encoding::NON_ALPHANUMERIC + .remove(b'-') + .remove(b'.') + .remove(b'_') + .remove(b'~'); + +/// This struct is used to maintain the URI path encoding +const STRICT_PATH_ENCODE_SET: percent_encoding::AsciiSet = STRICT_ENCODE_SET.remove(b'/'); + /// A specialized `Error` for object store-related errors #[derive(Debug, Snafu)] #[allow(missing_docs)] diff --git a/object_store/src/lib.rs b/object_store/src/lib.rs index 16f0c6f3a2a..4c9527a24fd 100644 --- a/object_store/src/lib.rs +++ b/object_store/src/lib.rs @@ -701,6 +701,20 @@ mod tests { assert_eq!(files, vec![path.clone()]); storage.delete(&path).await.unwrap(); + + let path = Path::parse("foo bar/I contain spaces.parquet").unwrap(); + storage.put(&path, Bytes::from(vec![0, 1])).await.unwrap(); + storage.head(&path).await.unwrap(); + let files = flatten_list_stream(storage, Some(&Path::from("foo bar"))) + .await + .unwrap(); + assert_eq!(files, vec![path.clone()]); + storage.delete(&path).await.unwrap(); + + let files = flatten_list_stream(storage, Some(&Path::from("foo bar"))) + .await + .unwrap(); + assert!(files.is_empty(), "{:?}", files); } fn get_vec_of_bytes(chunk_length: usize, num_chunks: usize) -> Vec { diff --git a/object_store/src/path/mod.rs b/object_store/src/path/mod.rs index e5a7b6443bb..80e0f792aa5 100644 --- a/object_store/src/path/mod.rs +++ b/object_store/src/path/mod.rs @@ -534,4 +534,15 @@ mod tests { needle ); } + + #[test] + fn path_containing_spaces() { + let a = Path::from_iter(["foo bar", "baz"]); + let b = Path::from("foo bar/baz"); + let c = Path::parse("foo bar/baz").unwrap(); + + assert_eq!(a.raw, "foo bar/baz"); + assert_eq!(a.raw, b.raw); + assert_eq!(b.raw, c.raw); + } }