Skip to content

Commit

Permalink
add factory for customizing credential providers
Browse files Browse the repository at this point in the history
  • Loading branch information
rittneje committed Nov 12, 2021
1 parent 1369084 commit 649261d
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 10 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG_PENDING.md
Expand Up @@ -2,4 +2,6 @@

### SDK Enhancements

* `aws/session`: Add factory for customizing the construction of default credential providers. Currently only supported for `stscreds.WebIdentityRoleProvider`.

### SDK Bugs
33 changes: 23 additions & 10 deletions aws/session/credentials.go
Expand Up @@ -14,8 +14,17 @@ import (
"github.com/aws/aws-sdk-go/aws/defaults"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/internal/shareddefaults"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/aws/aws-sdk-go/service/sts/stsiface"
)

// A CredentialsProviderFactory specifies constructor functions for credentials.
type CredentialsProviderFactory struct {
// NewWebIdentityRoleProvider will return a new sts.WebIdentityRoleProvider.
// If omitted, then stscreds.NewWebIdentityRoleProvider will be used instead.
NewWebIdentityRoleProvider func(svc stsiface.STSAPI, roleARN, roleSessionName, path string) *stscreds.WebIdentityRoleProvider
}

func resolveCredentials(cfg *aws.Config,
envCfg envConfig, sharedCfg sharedConfig,
handlers request.Handlers,
Expand All @@ -40,6 +49,7 @@ func resolveCredentials(cfg *aws.Config,
envCfg.WebIdentityTokenFilePath,
envCfg.RoleARN,
envCfg.RoleSessionName,
sessOpts.CredentialsProviderFactory,
)

default:
Expand All @@ -59,6 +69,7 @@ var WebIdentityEmptyTokenFilePathErr = awserr.New(stscreds.ErrCodeWebIdentity, "
func assumeWebIdentity(cfg *aws.Config, handlers request.Handlers,
filepath string,
roleARN, sessionName string,
factory *CredentialsProviderFactory,
) (*credentials.Credentials, error) {

if len(filepath) == 0 {
Expand All @@ -69,17 +80,18 @@ func assumeWebIdentity(cfg *aws.Config, handlers request.Handlers,
return nil, WebIdentityEmptyRoleARNErr
}

creds := stscreds.NewWebIdentityCredentials(
&Session{
Config: cfg,
Handlers: handlers.Copy(),
},
roleARN,
sessionName,
filepath,
)
svc := sts.New(&Session{
Config: cfg,
Handlers: handlers.Copy(),
})

return creds, nil
newProviderFunc := stscreds.NewWebIdentityRoleProvider
if factory != nil && factory.NewWebIdentityRoleProvider != nil {
newProviderFunc = factory.NewWebIdentityRoleProvider
}

p := newProviderFunc(svc, roleARN, sessionName, filepath)
return credentials.NewCredentials(p), nil
}

func resolveCredsFromProfile(cfg *aws.Config,
Expand Down Expand Up @@ -114,6 +126,7 @@ func resolveCredsFromProfile(cfg *aws.Config,
sharedCfg.WebIdentityTokenFile,
sharedCfg.RoleARN,
sharedCfg.RoleSessionName,
sessOpts.CredentialsProviderFactory,
)

case sharedCfg.hasSSOConfiguration():
Expand Down
53 changes: 53 additions & 0 deletions aws/session/credentials_test.go
Expand Up @@ -19,13 +19,15 @@ import (

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
"github.com/aws/aws-sdk-go/aws/defaults"
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/internal/sdktesting"
"github.com/aws/aws-sdk-go/internal/shareddefaults"
"github.com/aws/aws-sdk-go/private/protocol"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/aws/aws-sdk-go/service/sts/stsiface"
)

func newEc2MetadataServer(key, secret string, closeAfterGetCreds bool) *httptest.Server {
Expand Down Expand Up @@ -814,6 +816,57 @@ func TestSessionAssumeRole_WithMFA_ExtendedDuration(t *testing.T) {
}
}

func TestSessionAssumeRoleWithWebIdentity_Factory(t *testing.T) {
restoreEnvFn := initSessionTestEnv()
defer restoreEnvFn()

os.Setenv("AWS_REGION", "us-east-1")
os.Setenv("AWS_ROLE_ARN", "web_identity_role_arn")
os.Setenv("AWS_WEB_IDENTITY_TOKEN_FILE", "./testdata/wit.txt")

endpointResolver, teardown := setupCredentialsEndpoints(t)
defer teardown()

customFactoryCalled := false

sess, err := NewSessionWithOptions(Options{
Config: aws.Config{
EndpointResolver: endpointResolver,
},
CredentialsProviderFactory: &CredentialsProviderFactory{
NewWebIdentityRoleProvider: func(svc stsiface.STSAPI, roleARN, roleSessionName, path string) *stscreds.WebIdentityRoleProvider {
customFactoryCalled = true
return stscreds.NewWebIdentityRoleProvider(svc, roleARN, roleSessionName, path)
},
},
})
if err != nil {
t.Fatalf("expect no error, got %v", err)
}

if !customFactoryCalled {
t.Errorf("expect custom factory to be called")
}

creds, err := sess.Config.Credentials.Get()
if err != nil {
t.Fatalf("expect no error, got %v", err)
}

if e, a := "WEB_IDENTITY_AKID", creds.AccessKeyID; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "WEB_IDENTITY_SECRET", creds.SecretAccessKey; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "WEB_IDENTITY_SESSION_TOKEN", creds.SessionToken; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := stscreds.WebIdentityProviderName, creds.ProviderName; e != a {
t.Errorf("expect %v,got %v", e, a)
}
}

func ssoTestSetup() (func(), error) {
dir, err := ioutil.TempDir("", "sso-test")
if err != nil {
Expand Down
9 changes: 9 additions & 0 deletions aws/session/session.go
Expand Up @@ -304,6 +304,15 @@ type Options struct {
//
// AWS_EC2_METADATA_SERVICE_ENDPOINT_MODE=IPv6
EC2IMDSEndpointMode endpoints.EC2IMDSEndpointModeState

// Specifies custom constructor functions for creating credential
// providers. These functions are only used if the aws.Config does
// not already include credentials.
//
// If this field is nil, or if the corresponding constructor
// function is nil, then the default constructor function for
// the provider in question will be used.
CredentialsProviderFactory *CredentialsProviderFactory
}

// NewSessionWithOptions returns a new Session created from SDK defaults, config files,
Expand Down

0 comments on commit 649261d

Please sign in to comment.