Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add options for customizing credential providers #4174

Merged
merged 2 commits into from Jan 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG_PENDING.md
Expand Up @@ -2,4 +2,6 @@

### SDK Enhancements

* `aws/session`: Add options for customizing the construction of credential providers. Currently only supported for `stscreds.WebIdentityRoleProvider`.

### SDK Bugs
40 changes: 34 additions & 6 deletions aws/credentials/stscreds/web_identity_provider.go
Expand Up @@ -28,7 +28,7 @@ const (
// compare test values.
var now = time.Now

// TokenFetcher shuold return WebIdentity token bytes or an error
// TokenFetcher should return WebIdentity token bytes or an error
type TokenFetcher interface {
FetchToken(credentials.Context) ([]byte, error)
}
Expand All @@ -50,6 +50,8 @@ func (f FetchTokenPath) FetchToken(ctx credentials.Context) ([]byte, error) {
// an OIDC token.
type WebIdentityRoleProvider struct {
credentials.Expiry

// The policy ARNs to use with the web identity assumed role.
PolicyArns []*sts.PolicyDescriptorType

// Duration the STS credentials will be valid for. Truncated to seconds.
Expand All @@ -74,6 +76,9 @@ type WebIdentityRoleProvider struct {

// NewWebIdentityCredentials will return a new set of credentials with a given
// configuration, role arn, and token file path.
//
// Deprecated: Use NewWebIdentityRoleProviderWithOptions for flexible
// functional options, and wrap with credentials.NewCredentials helper.
func NewWebIdentityCredentials(c client.ConfigProvider, roleARN, roleSessionName, path string) *credentials.Credentials {
svc := sts.New(c)
p := NewWebIdentityRoleProvider(svc, roleARN, roleSessionName, path)
Expand All @@ -82,19 +87,42 @@ func NewWebIdentityCredentials(c client.ConfigProvider, roleARN, roleSessionName

// NewWebIdentityRoleProvider will return a new WebIdentityRoleProvider with the
// provided stsiface.STSAPI
//
// Deprecated: Use NewWebIdentityRoleProviderWithOptions for flexible
// functional options.
func NewWebIdentityRoleProvider(svc stsiface.STSAPI, roleARN, roleSessionName, path string) *WebIdentityRoleProvider {
return NewWebIdentityRoleProviderWithToken(svc, roleARN, roleSessionName, FetchTokenPath(path))
return NewWebIdentityRoleProviderWithOptions(svc, roleARN, roleSessionName, FetchTokenPath(path))
}

// NewWebIdentityRoleProviderWithToken will return a new WebIdentityRoleProvider with the
// provided stsiface.STSAPI and a TokenFetcher
//
// Deprecated: Use NewWebIdentityRoleProviderWithOptions for flexible
// functional options.
func NewWebIdentityRoleProviderWithToken(svc stsiface.STSAPI, roleARN, roleSessionName string, tokenFetcher TokenFetcher) *WebIdentityRoleProvider {
return &WebIdentityRoleProvider{
return NewWebIdentityRoleProviderWithOptions(svc, roleARN, roleSessionName, tokenFetcher)
}

// NewWebIdentityRoleProviderWithOptions will return an initialize
// WebIdentityRoleProvider with the provided stsiface.STSAPI, role ARN, and a
// TokenFetcher. Additional options can be provided as functional options.
//
// TokenFetcher is the implementation that will retrieve the JWT token from to
// assume the role with. Use the provided FetchTokenPath implementation to
// retrieve the JWT token using a file system path.
func NewWebIdentityRoleProviderWithOptions(svc stsiface.STSAPI, roleARN, roleSessionName string, tokenFetcher TokenFetcher, optFns ...func(*WebIdentityRoleProvider)) *WebIdentityRoleProvider {
p := WebIdentityRoleProvider{
client: svc,
tokenFetcher: tokenFetcher,
roleARN: roleARN,
roleSessionName: roleSessionName,
}

for _, fn := range optFns {
fn(&p)
}

return &p
}

// Retrieve will attempt to assume a role from a token which is located at
Expand All @@ -104,9 +132,9 @@ func (p *WebIdentityRoleProvider) Retrieve() (credentials.Value, error) {
return p.RetrieveWithContext(aws.BackgroundContext())
}

// RetrieveWithContext will attempt to assume a role from a token which is located at
// 'WebIdentityTokenFilePath' specified destination and if that is empty an
// error will be returned.
// RetrieveWithContext will attempt to assume a role from a token which is
// located at 'WebIdentityTokenFilePath' specified destination and if that is
// empty an error will be returned.
func (p *WebIdentityRoleProvider) RetrieveWithContext(ctx credentials.Context) (credentials.Value, error) {
b, err := p.tokenFetcher.FetchToken(ctx)
if err != nil {
Expand Down
244 changes: 146 additions & 98 deletions aws/credentials/stscreds/web_identity_provider_test.go
@@ -1,152 +1,107 @@
//go:build go1.7
// +build go1.7

package stscreds_test
package stscreds

import (
"fmt"
"net/http"
"reflect"
"strings"
"testing"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/corehandlers"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/aws/aws-sdk-go/service/sts/stsiface"
)

func TestWebIdentityProviderRetrieve(t *testing.T) {
var reqCount int
cases := map[string]struct {
onSendReq func(*testing.T, *request.Request)
roleARN string
tokenFilepath string
tokenPath string
sessionName string
newClient func(t *testing.T) stsiface.STSAPI
duration time.Duration
expectedError string
expectedCredValue credentials.Value
}{
"session name case": {
roleARN: "arn01234567890123456789",
tokenFilepath: "testdata/token.jwt",
sessionName: "foo",
onSendReq: func(t *testing.T, r *request.Request) {
input := r.Params.(*sts.AssumeRoleWithWebIdentityInput)
if e, a := "foo", *input.RoleSessionName; e != a {
t.Errorf("expected %v, but received %v", e, a)
}
if input.DurationSeconds != nil {
t.Errorf("expect no duration, got %v", *input.DurationSeconds)
}
roleARN: "arn01234567890123456789",
tokenPath: "testdata/token.jwt",
sessionName: "foo",
newClient: func(t *testing.T) stsiface.STSAPI {
return mockAssumeRoleWithWebIdentityClient{
t: t,
doRequest: func(t *testing.T, input *sts.AssumeRoleWithWebIdentityInput) (
*sts.AssumeRoleWithWebIdentityOutput, error,
) {
if e, a := "foo", *input.RoleSessionName; e != a {
t.Errorf("expected %v, but received %v", e, a)
}
if input.DurationSeconds != nil {
t.Errorf("expect no duration, got %v", *input.DurationSeconds)
}

data := r.Data.(*sts.AssumeRoleWithWebIdentityOutput)
*data = sts.AssumeRoleWithWebIdentityOutput{
Credentials: &sts.Credentials{
Expiration: aws.Time(time.Now()),
AccessKeyId: aws.String("access-key-id"),
SecretAccessKey: aws.String("secret-access-key"),
SessionToken: aws.String("session-token"),
return &sts.AssumeRoleWithWebIdentityOutput{
Credentials: &sts.Credentials{
Expiration: aws.Time(time.Now()),
AccessKeyId: aws.String("access-key-id"),
SecretAccessKey: aws.String("secret-access-key"),
SessionToken: aws.String("session-token"),
},
}, nil
},
}
},
expectedCredValue: credentials.Value{
AccessKeyID: "access-key-id",
SecretAccessKey: "secret-access-key",
SessionToken: "session-token",
ProviderName: stscreds.WebIdentityProviderName,
ProviderName: WebIdentityProviderName,
},
},
"with duration": {
roleARN: "arn01234567890123456789",
tokenFilepath: "testdata/token.jwt",
sessionName: "foo",
duration: 15 * time.Minute,
onSendReq: func(t *testing.T, r *request.Request) {
input := r.Params.(*sts.AssumeRoleWithWebIdentityInput)
if e, a := int64((15*time.Minute)/time.Second), *input.DurationSeconds; e != a {
t.Errorf("expect %v duration, got %v", e, a)
}

data := r.Data.(*sts.AssumeRoleWithWebIdentityOutput)
*data = sts.AssumeRoleWithWebIdentityOutput{
Credentials: &sts.Credentials{
Expiration: aws.Time(time.Now()),
AccessKeyId: aws.String("access-key-id"),
SecretAccessKey: aws.String("secret-access-key"),
SessionToken: aws.String("session-token"),
},
}
},
expectedCredValue: credentials.Value{
AccessKeyID: "access-key-id",
SecretAccessKey: "secret-access-key",
SessionToken: "session-token",
ProviderName: stscreds.WebIdentityProviderName,
},
},
"invalid token retry": {
roleARN: "arn01234567890123456789",
tokenFilepath: "testdata/token.jwt",
sessionName: "foo",
onSendReq: func(t *testing.T, r *request.Request) {
input := r.Params.(*sts.AssumeRoleWithWebIdentityInput)
if e, a := "foo", *input.RoleSessionName; !reflect.DeepEqual(e, a) {
t.Errorf("expected %v, but received %v", e, a)
}

if reqCount == 0 {
r.HTTPResponse.StatusCode = 400
r.Error = awserr.New(sts.ErrCodeInvalidIdentityTokenException,
"some error message", nil)
return
}
roleARN: "arn01234567890123456789",
tokenPath: "testdata/token.jwt",
sessionName: "foo",
duration: 15 * time.Minute,
newClient: func(t *testing.T) stsiface.STSAPI {
return mockAssumeRoleWithWebIdentityClient{
t: t,
doRequest: func(t *testing.T, input *sts.AssumeRoleWithWebIdentityInput) (
*sts.AssumeRoleWithWebIdentityOutput, error,
) {
if e, a := int64((15*time.Minute)/time.Second), *input.DurationSeconds; e != a {
t.Errorf("expect %v duration, got %v", e, a)
}

data := r.Data.(*sts.AssumeRoleWithWebIdentityOutput)
*data = sts.AssumeRoleWithWebIdentityOutput{
Credentials: &sts.Credentials{
Expiration: aws.Time(time.Now()),
AccessKeyId: aws.String("access-key-id"),
SecretAccessKey: aws.String("secret-access-key"),
SessionToken: aws.String("session-token"),
return &sts.AssumeRoleWithWebIdentityOutput{
Credentials: &sts.Credentials{
Expiration: aws.Time(time.Now()),
AccessKeyId: aws.String("access-key-id"),
SecretAccessKey: aws.String("secret-access-key"),
SessionToken: aws.String("session-token"),
},
}, nil
},
}
},
expectedCredValue: credentials.Value{
AccessKeyID: "access-key-id",
SecretAccessKey: "secret-access-key",
SessionToken: "session-token",
ProviderName: stscreds.WebIdentityProviderName,
ProviderName: WebIdentityProviderName,
},
},
}

for name, c := range cases {
t.Run(name, func(t *testing.T) {
reqCount = 0

svc := sts.New(unit.Session, &aws.Config{
Logger: t,
})
svc.Handlers.Send.Swap(corehandlers.SendHandler.Name, request.NamedHandler{
Name: "custom send stub handler",
Fn: func(r *request.Request) {
r.HTTPResponse = &http.Response{
StatusCode: 200, Header: http.Header{},
}
c.onSendReq(t, r)
reqCount++
},
})
svc.Handlers.UnmarshalMeta.Clear()
svc.Handlers.Unmarshal.Clear()
svc.Handlers.UnmarshalError.Clear()

p := stscreds.NewWebIdentityRoleProvider(svc, c.roleARN, c.sessionName, c.tokenFilepath)
p := NewWebIdentityRoleProvider(c.newClient(t), c.roleARN, c.sessionName, c.tokenPath)
p.Duration = c.duration

credValue, err := p.Retrieve()
Expand All @@ -169,3 +124,96 @@ func TestWebIdentityProviderRetrieve(t *testing.T) {
})
}
}

type mockAssumeRoleWithWebIdentityClient struct {
stsiface.STSAPI

t *testing.T
doRequest func(*testing.T, *sts.AssumeRoleWithWebIdentityInput) (*sts.AssumeRoleWithWebIdentityOutput, error)
}

func (c mockAssumeRoleWithWebIdentityClient) AssumeRoleWithWebIdentityRequest(input *sts.AssumeRoleWithWebIdentityInput) (
*request.Request, *sts.AssumeRoleWithWebIdentityOutput,
) {
output, err := c.doRequest(c.t, input)

req := &request.Request{
HTTPRequest: &http.Request{},
Retryer: client.DefaultRetryer{},
}
req.Handlers.Send.PushBack(func(r *request.Request) {
r.HTTPResponse = &http.Response{}
r.Data = output
r.Error = err

var found bool
for _, retryCode := range req.RetryErrorCodes {
if retryCode == sts.ErrCodeInvalidIdentityTokenException {
found = true
break
}
}
if !found {
c.t.Errorf("expect ErrCodeInvalidIdentityTokenException error code to be retry-able")
}
})

return req, output
}

func TestNewWebIdentityRoleProviderWithOptions(t *testing.T) {
const roleARN = "a-role-arn"
const roleSessionName = "a-session-name"

cases := map[string]struct {
options []func(*WebIdentityRoleProvider)
expect WebIdentityRoleProvider
}{
"no options": {
expect: WebIdentityRoleProvider{
client: stubClient{},
tokenFetcher: stubTokenFetcher{},
roleARN: roleARN,
roleSessionName: roleSessionName,
},
},
"with options": {
options: []func(*WebIdentityRoleProvider){
func(o *WebIdentityRoleProvider) {
o.Duration = 10 * time.Minute
o.ExpiryWindow = time.Minute
},
},
expect: WebIdentityRoleProvider{
client: stubClient{},
tokenFetcher: stubTokenFetcher{},
roleARN: roleARN,
roleSessionName: roleSessionName,
Duration: 10 * time.Minute,
ExpiryWindow: time.Minute,
},
},
}

for name, c := range cases {
t.Run(name, func(t *testing.T) {

p := NewWebIdentityRoleProviderWithOptions(
stubClient{}, roleARN, roleSessionName,
stubTokenFetcher{}, c.options...)

if !reflect.DeepEqual(c.expect, *p) {
t.Errorf("expect:\n%v\nactual:\n%v", c.expect, *p)
}
})
}
}

type stubClient struct {
stsiface.STSAPI
}
type stubTokenFetcher struct{}

func (stubTokenFetcher) FetchToken(credentials.Context) ([]byte, error) {
return nil, fmt.Errorf("stubTokenFetcher should not be called")
}