From 649261d06b67fa07283d352afb2fc2b399d666b1 Mon Sep 17 00:00:00 2001 From: Jesse Rittner Date: Fri, 12 Nov 2021 18:23:15 -0500 Subject: [PATCH] add factory for customizing credential providers --- CHANGELOG_PENDING.md | 2 ++ aws/session/credentials.go | 33 +++++++++++++------- aws/session/credentials_test.go | 53 +++++++++++++++++++++++++++++++++ aws/session/session.go | 9 ++++++ 4 files changed, 87 insertions(+), 10 deletions(-) diff --git a/CHANGELOG_PENDING.md b/CHANGELOG_PENDING.md index 8a1927a39ca..a2a03ea8b16 100644 --- a/CHANGELOG_PENDING.md +++ b/CHANGELOG_PENDING.md @@ -2,4 +2,6 @@ ### SDK Enhancements +* `aws/session`: Add factory for customizing the construction of default credential providers. Currently only supported for `stscreds.WebIdentityRoleProvider`. + ### SDK Bugs diff --git a/aws/session/credentials.go b/aws/session/credentials.go index 3efdac29ff4..c8d3d4cedf9 100644 --- a/aws/session/credentials.go +++ b/aws/session/credentials.go @@ -14,8 +14,17 @@ import ( "github.com/aws/aws-sdk-go/aws/defaults" "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/internal/shareddefaults" + "github.com/aws/aws-sdk-go/service/sts" + "github.com/aws/aws-sdk-go/service/sts/stsiface" ) +// A CredentialsProviderFactory specifies constructor functions for credentials. +type CredentialsProviderFactory struct { + // NewWebIdentityRoleProvider will return a new sts.WebIdentityRoleProvider. + // If omitted, then stscreds.NewWebIdentityRoleProvider will be used instead. + NewWebIdentityRoleProvider func(svc stsiface.STSAPI, roleARN, roleSessionName, path string) *stscreds.WebIdentityRoleProvider +} + func resolveCredentials(cfg *aws.Config, envCfg envConfig, sharedCfg sharedConfig, handlers request.Handlers, @@ -40,6 +49,7 @@ func resolveCredentials(cfg *aws.Config, envCfg.WebIdentityTokenFilePath, envCfg.RoleARN, envCfg.RoleSessionName, + sessOpts.CredentialsProviderFactory, ) default: @@ -59,6 +69,7 @@ var WebIdentityEmptyTokenFilePathErr = awserr.New(stscreds.ErrCodeWebIdentity, " func assumeWebIdentity(cfg *aws.Config, handlers request.Handlers, filepath string, roleARN, sessionName string, + factory *CredentialsProviderFactory, ) (*credentials.Credentials, error) { if len(filepath) == 0 { @@ -69,17 +80,18 @@ func assumeWebIdentity(cfg *aws.Config, handlers request.Handlers, return nil, WebIdentityEmptyRoleARNErr } - creds := stscreds.NewWebIdentityCredentials( - &Session{ - Config: cfg, - Handlers: handlers.Copy(), - }, - roleARN, - sessionName, - filepath, - ) + svc := sts.New(&Session{ + Config: cfg, + Handlers: handlers.Copy(), + }) - return creds, nil + newProviderFunc := stscreds.NewWebIdentityRoleProvider + if factory != nil && factory.NewWebIdentityRoleProvider != nil { + newProviderFunc = factory.NewWebIdentityRoleProvider + } + + p := newProviderFunc(svc, roleARN, sessionName, filepath) + return credentials.NewCredentials(p), nil } func resolveCredsFromProfile(cfg *aws.Config, @@ -114,6 +126,7 @@ func resolveCredsFromProfile(cfg *aws.Config, sharedCfg.WebIdentityTokenFile, sharedCfg.RoleARN, sharedCfg.RoleSessionName, + sessOpts.CredentialsProviderFactory, ) case sharedCfg.hasSSOConfiguration(): diff --git a/aws/session/credentials_test.go b/aws/session/credentials_test.go index 914ee8a986e..dd3aa0c68cf 100644 --- a/aws/session/credentials_test.go +++ b/aws/session/credentials_test.go @@ -19,6 +19,7 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/credentials/stscreds" "github.com/aws/aws-sdk-go/aws/defaults" "github.com/aws/aws-sdk-go/aws/endpoints" "github.com/aws/aws-sdk-go/aws/request" @@ -26,6 +27,7 @@ import ( "github.com/aws/aws-sdk-go/internal/shareddefaults" "github.com/aws/aws-sdk-go/private/protocol" "github.com/aws/aws-sdk-go/service/sts" + "github.com/aws/aws-sdk-go/service/sts/stsiface" ) func newEc2MetadataServer(key, secret string, closeAfterGetCreds bool) *httptest.Server { @@ -814,6 +816,57 @@ func TestSessionAssumeRole_WithMFA_ExtendedDuration(t *testing.T) { } } +func TestSessionAssumeRoleWithWebIdentity_Factory(t *testing.T) { + restoreEnvFn := initSessionTestEnv() + defer restoreEnvFn() + + os.Setenv("AWS_REGION", "us-east-1") + os.Setenv("AWS_ROLE_ARN", "web_identity_role_arn") + os.Setenv("AWS_WEB_IDENTITY_TOKEN_FILE", "./testdata/wit.txt") + + endpointResolver, teardown := setupCredentialsEndpoints(t) + defer teardown() + + customFactoryCalled := false + + sess, err := NewSessionWithOptions(Options{ + Config: aws.Config{ + EndpointResolver: endpointResolver, + }, + CredentialsProviderFactory: &CredentialsProviderFactory{ + NewWebIdentityRoleProvider: func(svc stsiface.STSAPI, roleARN, roleSessionName, path string) *stscreds.WebIdentityRoleProvider { + customFactoryCalled = true + return stscreds.NewWebIdentityRoleProvider(svc, roleARN, roleSessionName, path) + }, + }, + }) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + + if !customFactoryCalled { + t.Errorf("expect custom factory to be called") + } + + creds, err := sess.Config.Credentials.Get() + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + + if e, a := "WEB_IDENTITY_AKID", creds.AccessKeyID; e != a { + t.Errorf("expect %v, got %v", e, a) + } + if e, a := "WEB_IDENTITY_SECRET", creds.SecretAccessKey; e != a { + t.Errorf("expect %v, got %v", e, a) + } + if e, a := "WEB_IDENTITY_SESSION_TOKEN", creds.SessionToken; e != a { + t.Errorf("expect %v, got %v", e, a) + } + if e, a := stscreds.WebIdentityProviderName, creds.ProviderName; e != a { + t.Errorf("expect %v,got %v", e, a) + } +} + func ssoTestSetup() (func(), error) { dir, err := ioutil.TempDir("", "sso-test") if err != nil { diff --git a/aws/session/session.go b/aws/session/session.go index ebace4bb79d..4c4998647bd 100644 --- a/aws/session/session.go +++ b/aws/session/session.go @@ -304,6 +304,15 @@ type Options struct { // // AWS_EC2_METADATA_SERVICE_ENDPOINT_MODE=IPv6 EC2IMDSEndpointMode endpoints.EC2IMDSEndpointModeState + + // Specifies custom constructor functions for creating credential + // providers. These functions are only used if the aws.Config does + // not already include credentials. + // + // If this field is nil, or if the corresponding constructor + // function is nil, then the default constructor function for + // the provider in question will be used. + CredentialsProviderFactory *CredentialsProviderFactory } // NewSessionWithOptions returns a new Session created from SDK defaults, config files,