Skip to content

Commit

Permalink
blob/s3blob: Support S3 server side encryption headers for Write and …
Browse files Browse the repository at this point in the history
…Copy (#3340)
  • Loading branch information
tristan-newmann committed Feb 25, 2024
1 parent 83b78ee commit fb4e4b9
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 7 deletions.
87 changes: 80 additions & 7 deletions blob/s3blob/s3blob.go
Expand Up @@ -144,24 +144,58 @@ type URLOpener struct {
Options Options
}

const (
sseTypeParamKey = "ssetype"
kmsKeyIdParamKey = "kmskeyid"
)

func toServerSideEncryptionType(value string) (typesv2.ServerSideEncryption, error) {
for _, sseType := range typesv2.ServerSideEncryptionAes256.Values() {
if strings.ToLower(string(sseType)) == strings.ToLower(value) {
return sseType, nil
}
}
return "", fmt.Errorf("'%s' is not a valid value for '%s'", value, sseTypeParamKey)
}

// OpenBucketURL opens a blob.Bucket based on u.
func (o *URLOpener) OpenBucketURL(ctx context.Context, u *url.URL) (*blob.Bucket, error) {
q := u.Query()

if sseTypeParam := q.Get(sseTypeParamKey); sseTypeParam != "" {
q.Del(sseTypeParamKey)

sseType, err := toServerSideEncryptionType(sseTypeParam)
if err != nil {
return nil, err
}

o.Options.EncryptionType = sseType
}

if kmsKeyID := q.Get(kmsKeyIdParamKey); kmsKeyID != "" {
q.Del(kmsKeyIdParamKey)
o.Options.KMSEncryptionID = kmsKeyID
}

if o.UseV2 {
cfg, err := gcaws.V2ConfigFromURLParams(ctx, u.Query())
cfg, err := gcaws.V2ConfigFromURLParams(ctx, q)
if err != nil {
return nil, fmt.Errorf("open bucket %v: %v", u, err)
}
clientV2 := s3v2.NewFromConfig(cfg)

return OpenBucketV2(ctx, clientV2, u.Host, &o.Options)
}
configProvider := &gcaws.ConfigOverrider{
Base: o.ConfigProvider,
}
overrideCfg, err := gcaws.ConfigFromURLParams(u.Query())
overrideCfg, err := gcaws.ConfigFromURLParams(q)
if err != nil {
return nil, fmt.Errorf("open bucket %v: %v", u, err)
}
configProvider.Configs = append(configProvider.Configs, overrideCfg)

return OpenBucket(ctx, configProvider, u.Host, &o.Options)
}

Expand All @@ -171,6 +205,16 @@ type Options struct {
// Some S3-compatible services (like CEPH) do not currently support
// ListObjectsV2.
UseLegacyList bool

// EncryptionType sets the encryption type headers when making write or
// copy calls. This is required if the bucket has a restrictive bucket
// policy that enforces a specific encryption type
EncryptionType typesv2.ServerSideEncryption

// KMSEncryptionID sets the kms key id header for write or copy calls.
// This is required when a bucket policy enforces the use of a specific
// KMS key for uploads
KMSEncryptionID string
}

// openBucket returns an S3 Bucket.
Expand All @@ -193,11 +237,13 @@ func openBucket(ctx context.Context, useV2 bool, sess client.ConfigProvider, cli
client = s3.New(sess)
}
return &bucket{
useV2: useV2,
name: bucketName,
client: client,
clientV2: clientV2,
useLegacyList: opts.UseLegacyList,
useV2: useV2,
name: bucketName,
client: client,
clientV2: clientV2,
useLegacyList: opts.UseLegacyList,
kmsKeyId: opts.KMSEncryptionID,
encryptionType: opts.EncryptionType,
}, nil
}

Expand Down Expand Up @@ -365,6 +411,9 @@ type bucket struct {
client *s3.S3
clientV2 *s3v2.Client
useLegacyList bool

encryptionType typesv2.ServerSideEncryption
kmsKeyId string
}

func (b *bucket) Close() error {
Expand Down Expand Up @@ -973,6 +1022,12 @@ func (b *bucket) NewTypedWriter(ctx context.Context, key string, contentType str
if len(opts.ContentMD5) > 0 {
reqV2.ContentMD5 = aws.String(base64.StdEncoding.EncodeToString(opts.ContentMD5))
}
if b.encryptionType != "" {
reqV2.ServerSideEncryption = b.encryptionType
}
if b.kmsKeyId != "" {
reqV2.SSEKMSKeyId = aws.String(b.kmsKeyId)
}
if opts.BeforeWrite != nil {
asFunc := func(i interface{}) bool {
// Note that since the Go CDK Blob
Expand Down Expand Up @@ -1046,6 +1101,12 @@ func (b *bucket) NewTypedWriter(ctx context.Context, key string, contentType str
if len(opts.ContentMD5) > 0 {
req.ContentMD5 = aws.String(base64.StdEncoding.EncodeToString(opts.ContentMD5))
}
if b.encryptionType != "" {
req.ServerSideEncryption = aws.String(string(b.encryptionType))
}
if b.kmsKeyId != "" {
req.SSEKMSKeyId = aws.String(b.kmsKeyId)
}
if opts.BeforeWrite != nil {
asFunc := func(i interface{}) bool {
pu, ok := i.(**s3manager.Uploader)
Expand Down Expand Up @@ -1083,6 +1144,12 @@ func (b *bucket) Copy(ctx context.Context, dstKey, srcKey string, opts *driver.C
CopySource: aws.String(b.name + "/" + srcKey),
Key: aws.String(dstKey),
}
if b.encryptionType != "" {
input.ServerSideEncryption = b.encryptionType
}
if b.kmsKeyId != "" {
input.SSEKMSKeyId = aws.String(b.kmsKeyId)
}
if opts.BeforeCopy != nil {
asFunc := func(i interface{}) bool {
switch v := i.(type) {
Expand All @@ -1104,6 +1171,12 @@ func (b *bucket) Copy(ctx context.Context, dstKey, srcKey string, opts *driver.C
CopySource: aws.String(b.name + "/" + srcKey),
Key: aws.String(dstKey),
}
if b.encryptionType != "" {
input.ServerSideEncryption = aws.String(string(b.encryptionType))
}
if b.kmsKeyId != "" {
input.SSEKMSKeyId = aws.String(b.kmsKeyId)
}
if opts.BeforeCopy != nil {
asFunc := func(i interface{}) bool {
switch v := i.(type) {
Expand Down
33 changes: 33 additions & 0 deletions blob/s3blob/s3blob_test.go
Expand Up @@ -466,6 +466,10 @@ func TestOpenBucketFromURL(t *testing.T) {
{"s3://mybucket?profile=main&region=us-west-1", false},
// OK, use V2.
{"s3://mybucket?awssdk=v2", false},
// OK, use KMS Server Side Encryption
{"s3://mybucket?ssetype=aws:kms&kmskeyid=arn:aws:us-east-1:12345:key/1-a-2-b", false},
// Invalid ssetype
{"s3://mybucket?ssetype=aws:notkmsoraes&kmskeyid=arn:aws:us-east-1:12345:key/1-a-2-b", true},
// Invalid parameter together with a valid one.
{"s3://mybucket?profile=main&param=value", true},
// Invalid parameter.
Expand All @@ -483,3 +487,32 @@ func TestOpenBucketFromURL(t *testing.T) {
}
}
}

func TestToServerSideEncryptionType(t *testing.T) {
tests := []struct {
value string
sseType typesv2.ServerSideEncryption
expectedError error
}{
// OK.
{"AES256", typesv2.ServerSideEncryptionAes256, nil},
// OK, KMS
{"aws:kms", typesv2.ServerSideEncryptionAwsKms, nil},
// OK, KMS
{"aws:kms:dsse", typesv2.ServerSideEncryptionAwsKmsDsse, nil},
// OK, AES256 mixed case
{"Aes256", typesv2.ServerSideEncryptionAes256, nil},
// Invalid SSE type
{"invalid", "", fmt.Errorf("'invalid' is not a valid value for '%s'", sseTypeParamKey)},
}

for _, test := range tests {
sseType, err := toServerSideEncryptionType(test.value)
if ((err != nil) != (test.expectedError != nil)) && err.Error() != test.expectedError.Error() {
t.Errorf("%s: got error \"%v\", want error \"%v\"", test.value, err, test.expectedError)
}
if sseType != test.sseType {
t.Errorf("%s: got type %v, want type %v", test.value, sseType, test.sseType)
}
}
}

0 comments on commit fb4e4b9

Please sign in to comment.