From d35de32636c8dd21f9bf420c7c1b10658159aa99 Mon Sep 17 00:00:00 2001 From: Robert van Gent Date: Wed, 3 Aug 2022 12:42:40 -0700 Subject: [PATCH] secrets/awskms: Add support for EncryptionContext parameters. (#3154) --- secrets/awskms/kms.go | 81 ++++++++++++++++++++++++++++++++------ secrets/awskms/kms_test.go | 48 ++++++++++++++++++++++ 2 files changed, 118 insertions(+), 11 deletions(-) diff --git a/secrets/awskms/kms.go b/secrets/awskms/kms.go index c0acfca1de..7cbbb899de 100644 --- a/secrets/awskms/kms.go +++ b/secrets/awskms/kms.go @@ -124,6 +124,11 @@ const Scheme = "awskms" // Use "awssdk=v1" to force using AWS SDK v1, "awssdk=v2" to force using AWS SDK v2, // or anything else to accept the default. // +// EncryptionContext key/value pairs can be provided by providing URL parameters prefixed +// with "context_"; e.g., "...&context_abc=foo&context_def=bar" would result in +// an EncryptionContext of {abc=foo, def=bar}. +// See https://docs.aws.amazon.com/kms/latest/developerguide/concepts.html#encrypt_context. +// // For V1, see gocloud.dev/aws/ConfigFromURLParams for supported query parameters // for overriding the aws.Session from the URL. // For V2, see gocloud.dev/aws/V2ConfigFromURLParams. @@ -135,9 +140,29 @@ type URLOpener struct { ConfigProvider client.ConfigProvider // Options specifies the options to pass to OpenKeeper. + // EncryptionContext parameters from the URL are merged in. Options KeeperOptions } +// addEncryptionContextFromURLParams merges any EncryptionContext URL parameters from +// u into opts.EncryptionParameters. +// It removes the processed URL parameters from u. +func addEncryptionContextFromURLParams(opts *KeeperOptions, u url.Values) error { + for k, vs := range u { + if strings.HasPrefix(k, "context_") { + if len(vs) != 1 { + return fmt.Errorf("open keeper: EncryptionContext URL parameters %q must have exactly 1 value", k) + } + u.Del(k) + if opts.EncryptionContext == nil { + opts.EncryptionContext = map[string]string{} + } + opts.EncryptionContext[k[8:]] = vs[0] + } + } + return nil +} + // OpenKeeperURL opens an AWS KMS Keeper based on u. func (o *URLOpener) OpenKeeperURL(ctx context.Context, u *url.URL) (*secrets.Keeper, error) { // A leading "/" means the Host was empty; trim the slash. @@ -145,8 +170,14 @@ func (o *URLOpener) OpenKeeperURL(ctx context.Context, u *url.URL) (*secrets.Kee // "/foo:bar". keyID := strings.TrimPrefix(path.Join(u.Host, u.Path), "/") + queryParams := u.Query() + opts := o.Options + if err := addEncryptionContextFromURLParams(&opts, queryParams); err != nil { + return nil, err + } + if o.UseV2 { - cfg, err := gcaws.V2ConfigFromURLParams(ctx, u.Query()) + cfg, err := gcaws.V2ConfigFromURLParams(ctx, queryParams) if err != nil { return nil, fmt.Errorf("open keeper %v: %v", u, err) } @@ -154,12 +185,12 @@ func (o *URLOpener) OpenKeeperURL(ctx context.Context, u *url.URL) (*secrets.Kee if err != nil { return nil, err } - return OpenKeeperV2(clientV2, keyID, &o.Options), nil + return OpenKeeperV2(clientV2, keyID, &opts), nil } configProvider := &gcaws.ConfigOverrider{ Base: o.ConfigProvider, } - overrideCfg, err := gcaws.ConfigFromURLParams(u.Query()) + overrideCfg, err := gcaws.ConfigFromURLParams(queryParams) if err != nil { return nil, fmt.Errorf("open keeper %v: %v", u, err) } @@ -168,7 +199,7 @@ func (o *URLOpener) OpenKeeperURL(ctx context.Context, u *url.URL) (*secrets.Kee if err != nil { return nil, err } - return OpenKeeper(client, keyID, &o.Options), nil + return OpenKeeper(client, keyID, &opts), nil } // OpenKeeper returns a *secrets.Keeper that uses AWS KMS. @@ -178,10 +209,14 @@ func (o *URLOpener) OpenKeeperURL(ctx context.Context, u *url.URL) (*secrets.Kee // for more details. // See the package documentation for an example. func OpenKeeper(client *kms.KMS, keyID string, opts *KeeperOptions) *secrets.Keeper { + if opts == nil { + opts = &KeeperOptions{} + } return secrets.NewKeeper(&keeper{ useV2: false, keyID: keyID, client: client, + opts: *opts, }) } @@ -192,25 +227,42 @@ func OpenKeeper(client *kms.KMS, keyID string, opts *KeeperOptions) *secrets.Kee // for more details. // See the package documentation for an example. func OpenKeeperV2(client *kmsv2.Client, keyID string, opts *KeeperOptions) *secrets.Keeper { + if opts == nil { + opts = &KeeperOptions{} + } return secrets.NewKeeper(&keeper{ useV2: true, keyID: keyID, clientV2: client, + opts: *opts, }) } type keeper struct { useV2 bool keyID string + opts KeeperOptions client *kms.KMS clientV2 *kmsv2.Client } +func (k *keeper) v1EncryptionContext() map[string]*string { + if len(k.opts.EncryptionContext) == 0 { + return nil + } + ec := map[string]*string{} + for k, v := range k.opts.EncryptionContext { + ec[k] = &v + } + return ec +} + // Decrypt decrypts the ciphertext into a plaintext. func (k *keeper) Decrypt(ctx context.Context, ciphertext []byte) ([]byte, error) { if k.useV2 { result, err := k.clientV2.Decrypt(ctx, &kmsv2.DecryptInput{ - CiphertextBlob: ciphertext, + CiphertextBlob: ciphertext, + EncryptionContext: k.opts.EncryptionContext, }) if err != nil { return nil, err @@ -218,7 +270,8 @@ func (k *keeper) Decrypt(ctx context.Context, ciphertext []byte) ([]byte, error) return result.Plaintext, nil } result, err := k.client.Decrypt(&kms.DecryptInput{ - CiphertextBlob: ciphertext, + CiphertextBlob: ciphertext, + EncryptionContext: k.v1EncryptionContext(), }) if err != nil { return nil, err @@ -230,8 +283,9 @@ func (k *keeper) Decrypt(ctx context.Context, ciphertext []byte) ([]byte, error) func (k *keeper) Encrypt(ctx context.Context, plaintext []byte) ([]byte, error) { if k.useV2 { result, err := k.clientV2.Encrypt(ctx, &kmsv2.EncryptInput{ - KeyId: aws.String(k.keyID), - Plaintext: plaintext, + KeyId: aws.String(k.keyID), + Plaintext: plaintext, + EncryptionContext: k.opts.EncryptionContext, }) if err != nil { return nil, err @@ -239,8 +293,9 @@ func (k *keeper) Encrypt(ctx context.Context, plaintext []byte) ([]byte, error) return result.CiphertextBlob, nil } result, err := k.client.Encrypt(&kms.EncryptInput{ - KeyId: aws.String(k.keyID), - Plaintext: plaintext, + KeyId: aws.String(k.keyID), + Plaintext: plaintext, + EncryptionContext: k.v1EncryptionContext(), }) if err != nil { return nil, err @@ -305,4 +360,8 @@ var errorCodeMap = map[string]gcerrors.ErrorCode{ // KeeperOptions controls Keeper behaviors. // It is provided for future extensibility. -type KeeperOptions struct{} +type KeeperOptions struct { + // EncryptionContext parameters. + // See https://docs.aws.amazon.com/kms/latest/developerguide/concepts.html#encrypt_context. + EncryptionContext map[string]string +} diff --git a/secrets/awskms/kms_test.go b/secrets/awskms/kms_test.go index eb6ad5fcb3..e872625f18 100644 --- a/secrets/awskms/kms_test.go +++ b/secrets/awskms/kms_test.go @@ -18,6 +18,7 @@ import ( "context" "errors" "fmt" + "net/url" "os" "testing" @@ -26,6 +27,7 @@ import ( "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/kms" "github.com/aws/smithy-go" + "github.com/google/go-cmp/cmp" "gocloud.dev/internal/testing/setup" "gocloud.dev/secrets" "gocloud.dev/secrets/driver" @@ -145,6 +147,48 @@ func TestNoConnectionError(t *testing.T) { } } +func TestEncryptionContext(t *testing.T) { + tests := []struct { + Existing map[string]string + URL string + WantErr bool + Want map[string]string + }{ + // None before or after. + {nil, "http://foo", false, nil}, + // New parameter. + {nil, "http://foo?context_foo=bar", false, map[string]string{"foo": "bar"}}, + // 2 new parameters. + {nil, "http://foo?context_foo=bar&context_abc=baz", false, map[string]string{"foo": "bar", "abc": "baz"}}, + // Multiple values. + {nil, "http://foo?context_foo=bar&context_foo=baz", true, nil}, + // Existing, no new. + {map[string]string{"foo": "bar"}, "http://foo", false, map[string]string{"foo": "bar"}}, + // No-conflict merge. + {map[string]string{"foo": "bar"}, "http://foo?context_abc=baz", false, map[string]string{"foo": "bar", "abc": "baz"}}, + // Overwrite merge. + {map[string]string{"foo": "bar"}, "http://foo?context_foo=baz", false, map[string]string{"foo": "baz"}}, + } + for _, test := range tests { + t.Run(fmt.Sprintf("existing %v URL %v", test.Existing, test.URL), func(t *testing.T) { + opts := KeeperOptions{ + EncryptionContext: test.Existing, + } + u, err := url.Parse(test.URL) + if err != nil { + t.Fatal(err) + } + err = addEncryptionContextFromURLParams(&opts, u.Query()) + if (err != nil) != test.WantErr { + t.Fatalf("got err %v, want error? %v", err, test.WantErr) + } + if diff := cmp.Diff(opts.EncryptionContext, test.Want); diff != "" { + t.Errorf("diff %v", diff) + } + }) + } +} + func TestOpenKeeper(t *testing.T) { tests := []struct { URL string @@ -162,6 +206,10 @@ func TestOpenKeeper(t *testing.T) { {"awskms://alias/my-key?awssdk=v1", false}, // OK, using V2. {"awskms://alias/my-key?awssdk=v2", false}, + // OK, adding EncryptionContext. + {"awskms://alias/my-key?context_abc=foo&context_def=bar", false}, + // Multiple values for an EncryptionContext. + {"awskms://alias/my-key?context_abc=foo&context_abc=bar", true}, // Unknown parameter. {"awskms://alias/my-key?param=value", true}, }