Skip to content

Commit

Permalink
Add put_multipart_opts (#5435) (#5652)
Browse files Browse the repository at this point in the history
* Add put_multipart_opts (#5435)
  • Loading branch information
tustvold committed Apr 17, 2024
1 parent f276528 commit 4b49c34
Show file tree
Hide file tree
Showing 15 changed files with 461 additions and 292 deletions.
203 changes: 97 additions & 106 deletions object_store/src/aws/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ use crate::aws::builder::S3EncryptionHeaders;
use crate::aws::checksum::Checksum;
use crate::aws::credential::{AwsCredential, CredentialExt};
use crate::aws::{
AwsAuthorizer, AwsCredentialProvider, S3ConditionalPut, S3CopyIfNotExists, STORE,
STRICT_PATH_ENCODE_SET,
AwsAuthorizer, AwsCredentialProvider, S3ConditionalPut, S3CopyIfNotExists, COPY_SOURCE_HEADER,
STORE, STRICT_PATH_ENCODE_SET, TAGS_HEADER,
};
use crate::client::get::GetClient;
use crate::client::header::{get_etag, HeaderConfig};
Expand All @@ -35,16 +35,16 @@ use crate::client::GetOptionsExt;
use crate::multipart::PartId;
use crate::path::DELIMITER;
use crate::{
Attribute, Attributes, ClientOptions, GetOptions, ListResult, MultipartId, Path, PutPayload,
PutResult, Result, RetryConfig,
Attribute, Attributes, ClientOptions, GetOptions, ListResult, MultipartId, Path,
PutMultipartOpts, PutPayload, PutResult, Result, RetryConfig, TagSet,
};
use async_trait::async_trait;
use base64::prelude::BASE64_STANDARD;
use base64::Engine;
use bytes::{Buf, Bytes};
use hyper::header::{CACHE_CONTROL, CONTENT_LENGTH};
use hyper::http;
use hyper::http::HeaderName;
use hyper::{http, HeaderMap};
use itertools::Itertools;
use md5::{Digest, Md5};
use percent_encoding::{utf8_percent_encode, PercentEncode};
Expand Down Expand Up @@ -98,9 +98,6 @@ pub(crate) enum Error {
#[snafu(display("Error getting list response body: {}", source))]
ListResponseBody { source: reqwest::Error },

#[snafu(display("Error performing create multipart request: {}", source))]
CreateMultipartRequest { source: crate::client::retry::Error },

#[snafu(display("Error getting create multipart response body: {}", source))]
CreateMultipartResponseBody { source: reqwest::Error },

Expand Down Expand Up @@ -289,8 +286,75 @@ impl<'a> Request<'a> {
Self { builder, ..self }
}

pub fn idempotent(mut self, idempotent: bool) -> Self {
self.idempotent = idempotent;
pub fn headers(self, headers: HeaderMap) -> Self {
let builder = self.builder.headers(headers);
Self { builder, ..self }
}

pub fn idempotent(self, idempotent: bool) -> Self {
Self { idempotent, ..self }
}

pub fn with_encryption_headers(self) -> Self {
let headers = self.config.encryption_headers.clone().into();
let builder = self.builder.headers(headers);
Self { builder, ..self }
}

pub fn with_session_creds(self, use_session_creds: bool) -> Self {
Self {
use_session_creds,
..self
}
}

pub fn with_tags(mut self, tags: TagSet) -> Self {
let tags = tags.encoded();
if !tags.is_empty() && !self.config.disable_tagging {
self.builder = self.builder.header(&TAGS_HEADER, tags);
}
self
}

pub fn with_attributes(self, attributes: Attributes) -> Self {
let mut has_content_type = false;
let mut builder = self.builder;
for (k, v) in &attributes {
builder = match k {
Attribute::CacheControl => builder.header(CACHE_CONTROL, v.as_ref()),
Attribute::ContentType => {
has_content_type = true;
builder.header(CONTENT_TYPE, v.as_ref())
}
};
}

if !has_content_type {
if let Some(value) = self.config.client_options.get_content_type(self.path) {
builder = builder.header(CONTENT_TYPE, value);
}
}
Self { builder, ..self }
}

pub fn with_payload(mut self, payload: PutPayload) -> Self {
if !self.config.skip_signature || self.config.checksum.is_some() {
let mut sha256 = Context::new(&digest::SHA256);
payload.iter().for_each(|x| sha256.update(x));
let payload_sha256 = sha256.finish();

if let Some(Checksum::SHA256) = self.config.checksum {
self.builder = self.builder.header(
"x-amz-checksum-sha256",
BASE64_STANDARD.encode(payload_sha256),
);
}
self.payload_sha256 = Some(payload_sha256);
}

let content_length = payload.content_length();
self.builder = self.builder.header(CONTENT_LENGTH, content_length);
self.payload = Some(payload);
self
}

Expand Down Expand Up @@ -335,81 +399,19 @@ impl S3Client {
Ok(Self { config, client })
}

/// Make an S3 PUT request <https://docs.aws.amazon.com/AmazonS3/latest/API/API_PutObject.html>
///
/// Returns the ETag
pub fn put_request<'a>(
&'a self,
path: &'a Path,
payload: PutPayload,
attributes: Attributes,
with_encryption_headers: bool,
) -> Request<'a> {
pub fn request<'a>(&'a self, method: Method, path: &'a Path) -> Request<'a> {
let url = self.config.path_url(path);
let mut builder = self.client.request(Method::PUT, url);
if with_encryption_headers {
builder = builder.headers(self.config.encryption_headers.clone().into());
}

let mut sha256 = Context::new(&digest::SHA256);
payload.iter().for_each(|x| sha256.update(x));
let payload_sha256 = sha256.finish();

if let Some(Checksum::SHA256) = self.config.checksum {
builder = builder.header(
"x-amz-checksum-sha256",
BASE64_STANDARD.encode(payload_sha256),
)
}

let mut has_content_type = false;
for (k, v) in &attributes {
builder = match k {
Attribute::CacheControl => builder.header(CACHE_CONTROL, v.as_ref()),
Attribute::ContentType => {
has_content_type = true;
builder.header(CONTENT_TYPE, v.as_ref())
}
};
}

if !has_content_type {
if let Some(value) = self.config.client_options.get_content_type(path) {
builder = builder.header(CONTENT_TYPE, value);
}
}

Request {
path,
builder: builder.header(CONTENT_LENGTH, payload.content_length()),
payload: Some(payload),
payload_sha256: Some(payload_sha256),
builder: self.client.request(method, url),
payload: None,
payload_sha256: None,
config: &self.config,
use_session_creds: true,
idempotent: false,
}
}

/// Make an S3 Delete request <https://docs.aws.amazon.com/AmazonS3/latest/API/API_DeleteObject.html>
pub async fn delete_request<T: Serialize + ?Sized + Sync>(
&self,
path: &Path,
query: &T,
) -> Result<()> {
let credential = self.config.get_session_credential().await?;
let url = self.config.path_url(path);

self.client
.request(Method::DELETE, url)
.query(query)
.with_aws_sigv4(credential.authorizer(), None)
.send_retry(&self.config.retry_config)
.await
.map_err(|e| e.error(STORE, path.to_string()))?;

Ok(())
}

/// Make an S3 Delete Objects request <https://docs.aws.amazon.com/AmazonS3/latest/API/API_DeleteObjects.html>
///
/// Produces a vector of results, one for each path in the input vector. If
Expand Down Expand Up @@ -513,41 +515,29 @@ impl S3Client {
}

/// Make an S3 Copy request <https://docs.aws.amazon.com/AmazonS3/latest/API/API_CopyObject.html>
pub fn copy_request<'a>(&'a self, from: &'a Path, to: &Path) -> Request<'a> {
let url = self.config.path_url(to);
pub fn copy_request<'a>(&'a self, from: &Path, to: &'a Path) -> Request<'a> {
let source = format!("{}/{}", self.config.bucket, encode_path(from));

let builder = self
.client
.request(Method::PUT, url)
.header("x-amz-copy-source", source)
.headers(self.config.encryption_headers.clone().into());

Request {
builder,
path: from,
config: &self.config,
payload: None,
payload_sha256: None,
use_session_creds: false,
idempotent: false,
}
self.request(Method::PUT, to)
.idempotent(true)
.header(&COPY_SOURCE_HEADER, &source)
.headers(self.config.encryption_headers.clone().into())
.with_session_creds(false)
}

pub async fn create_multipart(&self, location: &Path) -> Result<MultipartId> {
let credential = self.config.get_session_credential().await?;
let url = format!("{}?uploads=", self.config.path_url(location),);

pub async fn create_multipart(
&self,
location: &Path,
opts: PutMultipartOpts,
) -> Result<MultipartId> {
let response = self
.client
.request(Method::POST, url)
.headers(self.config.encryption_headers.clone().into())
.with_aws_sigv4(credential.authorizer(), None)
.retryable(&self.config.retry_config)
.request(Method::POST, location)
.query(&[("uploads", "")])
.with_encryption_headers()
.with_attributes(opts.attributes)
.with_tags(opts.tags)
.idempotent(true)
.send()
.await
.context(CreateMultipartRequestSnafu)?
.await?
.bytes()
.await
.context(CreateMultipartResponseBodySnafu)?;
Expand All @@ -568,7 +558,8 @@ impl S3Client {
let part = (part_idx + 1).to_string();

let response = self
.put_request(path, data, Attributes::default(), false)
.request(Method::PUT, path)
.with_payload(data)
.query(&[("partNumber", &part), ("uploadId", upload_id)])
.idempotent(true)
.send()
Expand Down
48 changes: 34 additions & 14 deletions object_store/src/aws/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,12 @@ use crate::signer::Signer;
use crate::util::STRICT_ENCODE_SET;
use crate::{
Error, GetOptions, GetResult, ListResult, MultipartId, MultipartUpload, ObjectMeta,
ObjectStore, Path, PutMode, PutOptions, PutPayload, PutResult, Result, UploadPart,
ObjectStore, Path, PutMode, PutMultipartOpts, PutOptions, PutPayload, PutResult, Result,
UploadPart,
};

static TAGS_HEADER: HeaderName = HeaderName::from_static("x-amz-tagging");
static COPY_SOURCE_HEADER: HeaderName = HeaderName::from_static("x-amz-copy-source");

mod builder;
mod checksum;
Expand Down Expand Up @@ -156,12 +158,13 @@ impl ObjectStore for AmazonS3 {
payload: PutPayload,
opts: PutOptions,
) -> Result<PutResult> {
let attrs = opts.attributes;
let mut request = self.client.put_request(location, payload, attrs, true);
let tags = opts.tags.encoded();
if !tags.is_empty() && !self.client.config.disable_tagging {
request = request.header(&TAGS_HEADER, tags);
}
let request = self
.client
.request(Method::PUT, location)
.with_payload(payload)
.with_attributes(opts.attributes)
.with_tags(opts.tags)
.with_encryption_headers();

match (opts.mode, &self.client.config.conditional_put) {
(PutMode::Overwrite, _) => request.idempotent(true).do_put().await,
Expand Down Expand Up @@ -204,8 +207,12 @@ impl ObjectStore for AmazonS3 {
}
}

async fn put_multipart(&self, location: &Path) -> Result<Box<dyn MultipartUpload>> {
let upload_id = self.client.create_multipart(location).await?;
async fn put_multipart_opts(
&self,
location: &Path,
opts: PutMultipartOpts,
) -> Result<Box<dyn MultipartUpload>> {
let upload_id = self.client.create_multipart(location, opts).await?;

Ok(Box::new(S3MultiPartUpload {
part_idx: 0,
Expand All @@ -223,7 +230,8 @@ impl ObjectStore for AmazonS3 {
}

async fn delete(&self, location: &Path) -> Result<()> {
self.client.delete_request(location, &()).await
self.client.request(Method::DELETE, location).send().await?;
Ok(())
}

fn delete_stream<'a>(
Expand Down Expand Up @@ -351,15 +359,22 @@ impl MultipartUpload for S3MultiPartUpload {
async fn abort(&mut self) -> Result<()> {
self.state
.client
.delete_request(&self.state.location, &[("uploadId", &self.state.upload_id)])
.await
.request(Method::DELETE, &self.state.location)
.query(&[("uploadId", &self.state.upload_id)])
.idempotent(true)
.send()
.await?;

Ok(())
}
}

#[async_trait]
impl MultipartStore for AmazonS3 {
async fn create_multipart(&self, path: &Path) -> Result<MultipartId> {
self.client.create_multipart(path).await
self.client
.create_multipart(path, PutMultipartOpts::default())
.await
}

async fn put_part(
Expand All @@ -382,7 +397,12 @@ impl MultipartStore for AmazonS3 {
}

async fn abort_multipart(&self, path: &Path, id: &MultipartId) -> Result<()> {
self.client.delete_request(path, &[("uploadId", id)]).await
self.client
.request(Method::DELETE, path)
.query(&[("uploadId", id)])
.send()
.await?;
Ok(())
}
}

Expand Down

0 comments on commit 4b49c34

Please sign in to comment.