diff --git a/.changelog/08d93344ca9c4036aefcf73146edaf55.json b/.changelog/08d93344ca9c4036aefcf73146edaf55.json new file mode 100644 index 00000000000..ee2582939d3 --- /dev/null +++ b/.changelog/08d93344ca9c4036aefcf73146edaf55.json @@ -0,0 +1,8 @@ +{ + "id": "08d93344-ca9c-4036-aefc-f73146edaf55", + "type": "feature", + "description": "Update CredentialsCache to make use of two new optional CredentialsProvider interfaces to give the cache, per provider, behavior how the cache handles credentials that fail to refresh, and adjusting expires time. See [aws.CredentialsCache](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/aws#CredentialsCache) for more details.", + "modules": [ + "." + ] +} \ No newline at end of file diff --git a/.changelog/9913162361ed41fe867f56fb4ee75e8e.json b/.changelog/9913162361ed41fe867f56fb4ee75e8e.json new file mode 100644 index 00000000000..2a1adcaafb7 --- /dev/null +++ b/.changelog/9913162361ed41fe867f56fb4ee75e8e.json @@ -0,0 +1,9 @@ +{ + "id": "99131623-61ed-41fe-867f-56fb4ee75e8e", + "type": "feature", + "description": "Update `ec2rolecreds` package's `Provider` to implememnt support for CredentialsCache new optional caching strategy interfaces, HandleFailRefreshCredentialsCacheStrategy and AdjustExpiresByCredentialsCacheStrategy.", + "modules": [ + ".", + "credentials" + ] +} \ No newline at end of file diff --git a/aws/credential_cache.go b/aws/credential_cache.go index 1411a5c32d6..dfd2b1ddbff 100644 --- a/aws/credential_cache.go +++ b/aws/credential_cache.go @@ -2,6 +2,7 @@ package aws import ( "context" + "fmt" "sync/atomic" "time" @@ -24,11 +25,13 @@ type CredentialsCacheOptions struct { // If ExpiryWindow is 0 or less it will be ignored. ExpiryWindow time.Duration - // ExpiryWindowJitterFrac provides a mechanism for randomizing the expiration of credentials - // within the configured ExpiryWindow by a random percentage. Valid values are between 0.0 and 1.0. + // ExpiryWindowJitterFrac provides a mechanism for randomizing the + // expiration of credentials within the configured ExpiryWindow by a random + // percentage. Valid values are between 0.0 and 1.0. // - // As an example if ExpiryWindow is 60 seconds and ExpiryWindowJitterFrac is 0.5 then credentials will be set to - // expire between 30 to 60 seconds prior to their actual expiration time. + // As an example if ExpiryWindow is 60 seconds and ExpiryWindowJitterFrac + // is 0.5 then credentials will be set to expire between 30 to 60 seconds + // prior to their actual expiration time. // // If ExpiryWindow is 0 or less then ExpiryWindowJitterFrac is ignored. // If ExpiryWindowJitterFrac is 0 then no randomization will be applied to the window. @@ -39,8 +42,19 @@ type CredentialsCacheOptions struct { // CredentialsCache provides caching and concurrency safe credentials retrieval // via the provider's retrieve method. +// +// CredentialsCache will look for optional interfaces on the Provider to adjust +// how the credential cache handles credentials caching. +// +// * HandleFailRefreshCredentialsCacheStrategy - Allows provider to handle +// credential refresh failures. This could return an updated Credentials +// value, or attempt another means of retrieving credentials. +// +// * AdjustExpiresByCredentialsCacheStrategy - Allows provider to adjust how +// credentials Expires is modified. This could modify how the Credentials +// Expires is adjusted based on the CredentialsCache ExpiryWindow option. +// Such as providing a floor not to reduce the Expires below. type CredentialsCache struct { - // provider is the CredentialProvider implementation to be wrapped by the CredentialCache. provider CredentialsProvider options CredentialsCacheOptions @@ -48,8 +62,9 @@ type CredentialsCache struct { sf singleflight.Group } -// NewCredentialsCache returns a CredentialsCache that wraps provider. Provider is expected to not be nil. A variadic -// list of one or more functions can be provided to modify the CredentialsCache configuration. This allows for +// NewCredentialsCache returns a CredentialsCache that wraps provider. Provider +// is expected to not be nil. A variadic list of one or more functions can be +// provided to modify the CredentialsCache configuration. This allows for // configuration of credential expiry window and jitter. func NewCredentialsCache(provider CredentialsProvider, optFns ...func(options *CredentialsCacheOptions)) *CredentialsCache { options := CredentialsCacheOptions{} @@ -81,8 +96,8 @@ func NewCredentialsCache(provider CredentialsProvider, optFns ...func(options *C // // Returns and error if the provider's retrieve method returns an error. func (p *CredentialsCache) Retrieve(ctx context.Context) (Credentials, error) { - if creds := p.getCreds(); creds != nil { - return *creds, nil + if creds, ok := p.getCreds(); ok && !creds.Expired() { + return creds, nil } resCh := p.sf.DoChan("", func() (interface{}, error) { @@ -97,39 +112,64 @@ func (p *CredentialsCache) Retrieve(ctx context.Context) (Credentials, error) { } func (p *CredentialsCache) singleRetrieve(ctx context.Context) (interface{}, error) { - if creds := p.getCreds(); creds != nil { - return *creds, nil + currCreds, ok := p.getCreds() + if ok && !currCreds.Expired() { + return currCreds, nil + } + + newCreds, err := p.provider.Retrieve(ctx) + if err != nil { + handleFailToRefresh := defaultHandleFailToRefresh + if cs, ok := p.provider.(HandleFailRefreshCredentialsCacheStrategy); ok { + handleFailToRefresh = cs.HandleFailToRefresh + } + newCreds, err = handleFailToRefresh(ctx, currCreds, err) + if err != nil { + return Credentials{}, fmt.Errorf("failed to refresh cached credentials, %w", err) + } } - creds, err := p.provider.Retrieve(ctx) - if err == nil { - if creds.CanExpire { - randFloat64, err := sdkrand.CryptoRandFloat64() - if err != nil { - return Credentials{}, err - } - jitter := time.Duration(randFloat64 * p.options.ExpiryWindowJitterFrac * float64(p.options.ExpiryWindow)) - creds.Expires = creds.Expires.Add(-(p.options.ExpiryWindow - jitter)) + if newCreds.CanExpire && p.options.ExpiryWindow > 0 { + adjustExpiresBy := defaultAdjustExpiresBy + if cs, ok := p.provider.(AdjustExpiresByCredentialsCacheStrategy); ok { + adjustExpiresBy = cs.AdjustExpiresBy + } + + randFloat64, err := sdkrand.CryptoRandFloat64() + if err != nil { + return Credentials{}, fmt.Errorf("failed to get random provider, %w", err) } - p.creds.Store(&creds) + var jitter time.Duration + if p.options.ExpiryWindowJitterFrac > 0 { + jitter = time.Duration(randFloat64 * + p.options.ExpiryWindowJitterFrac * float64(p.options.ExpiryWindow)) + } + + newCreds, err = adjustExpiresBy(newCreds, -(p.options.ExpiryWindow - jitter)) + if err != nil { + return Credentials{}, fmt.Errorf("failed to adjust credentials expires, %w", err) + } } - return creds, err + p.creds.Store(&newCreds) + return newCreds, nil } -func (p *CredentialsCache) getCreds() *Credentials { +// getCreds returns the currently stored credentials and true. Returning false +// if no credentials were stored. +func (p *CredentialsCache) getCreds() (Credentials, bool) { v := p.creds.Load() if v == nil { - return nil + return Credentials{}, false } c := v.(*Credentials) - if c != nil && c.HasKeys() && !c.Expired() { - return c + if c == nil || !c.HasKeys() { + return Credentials{}, false } - return nil + return *c, true } // Invalidate will invalidate the cached credentials. The next call to Retrieve @@ -137,3 +177,42 @@ func (p *CredentialsCache) getCreds() *Credentials { func (p *CredentialsCache) Invalidate() { p.creds.Store((*Credentials)(nil)) } + +// HandleFailRefreshCredentialsCacheStrategy is an interface for +// CredentialsCache to allow CredentialsProvider how failed to refresh +// credentials is handled. +type HandleFailRefreshCredentialsCacheStrategy interface { + // Given the previously cached Credentials, if any, and refresh error, may + // returns new or modified set of Credentials, or error. + // + // Credential caches may use default implementation if nil. + HandleFailToRefresh(context.Context, Credentials, error) (Credentials, error) +} + +// defaultHandleFailToRefresh returns the passed in error. +func defaultHandleFailToRefresh(ctx context.Context, _ Credentials, err error) (Credentials, error) { + return Credentials{}, err +} + +// AdjustExpiresByCredentialsCacheStrategy is an interface for CredentialCache +// to allow CredentialsProvider to intercept adjustments to Credentials expiry +// based on expectations and use cases of CredentialsProvider. +// +// Credential caches may use default implementation if nil. +type AdjustExpiresByCredentialsCacheStrategy interface { + // Given a Credentials as input, applying any mutations and + // returning the potentially updated Credentials, or error. + AdjustExpiresBy(Credentials, time.Duration) (Credentials, error) +} + +// defaultAdjustExpiresBy adds the duration to the passed in credentials Expires, +// and returns the updated credentials value. If Credentials value's CanExpire +// is false, the passed in credentials are returned unchanged. +func defaultAdjustExpiresBy(creds Credentials, dur time.Duration) (Credentials, error) { + if !creds.CanExpire { + return creds, nil + } + + creds.Expires = creds.Expires.Add(dur) + return creds, nil +} diff --git a/aws/credential_cache_test.go b/aws/credential_cache_test.go index 314c7c1da81..28f43d3871d 100644 --- a/aws/credential_cache_test.go +++ b/aws/credential_cache_test.go @@ -4,12 +4,15 @@ import ( "context" "fmt" "math/rand" + "strings" "sync" "sync/atomic" "testing" "time" + sdkrand "github.com/aws/aws-sdk-go-v2/internal/rand" "github.com/aws/aws-sdk-go-v2/internal/sdk" + "github.com/google/go-cmp/cmp" ) type stubCredentialsProvider struct { @@ -217,7 +220,7 @@ func TestCredentialsCache_Error(t *testing.T) { if err == nil { t.Fatalf("expect error, not none") } - if e, a := "failed", err.Error(); e != a { + if e, a := "failed", err.Error(); !strings.Contains(a, e) { t.Errorf("expect %q, got %q", e, a) } if e, a := (Credentials{}), creds; e != a { @@ -299,3 +302,318 @@ func TestCredentialsCache_RetrieveConcurrent(t *testing.T) { t.Errorf("expected %v, got %v", e, a) } } + +func TestCredentialsCache_cacheStrategies(t *testing.T) { + origSdkTime := sdk.NowTime + defer func() { sdk.NowTime = origSdkTime }() + sdk.NowTime = func() time.Time { + return time.Date(2015, 4, 8, 0, 0, 0, 0, time.UTC) + } + + origSdkRandReader := sdkrand.Reader + defer func() { sdkrand.Reader = origSdkRandReader }() + sdkrand.Reader = byteReader(0xFF) + + cases := map[string]struct { + options func(*CredentialsCacheOptions) + provider CredentialsProvider + initialCreds Credentials + expectErr string + expectCreds Credentials + }{ + "default": { + provider: struct { + mockProvider + }{ + mockProvider: mockProvider{ + creds: Credentials{ + AccessKeyID: "AKIAIOSFODNN7EXAMPLE", + SecretAccessKey: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + CanExpire: true, + Expires: sdk.NowTime().Add(time.Hour), + }, + }, + }, + expectCreds: Credentials{ + AccessKeyID: "AKIAIOSFODNN7EXAMPLE", + SecretAccessKey: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + CanExpire: true, + Expires: sdk.NowTime().Add(time.Hour), + }, + }, + "default with window": { + options: func(o *CredentialsCacheOptions) { + o.ExpiryWindow = 5 * time.Minute + }, + provider: struct { + mockProvider + }{ + mockProvider: mockProvider{ + creds: Credentials{ + AccessKeyID: "AKIAIOSFODNN7EXAMPLE", + SecretAccessKey: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + CanExpire: true, + Expires: sdk.NowTime().Add(time.Hour), + }, + }, + }, + expectCreds: Credentials{ + AccessKeyID: "AKIAIOSFODNN7EXAMPLE", + SecretAccessKey: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + CanExpire: true, + Expires: sdk.NowTime().Add(55 * time.Minute), + }, + }, + "default with window jitterFrac": { + options: func(o *CredentialsCacheOptions) { + o.ExpiryWindow = 5 * time.Minute + o.ExpiryWindowJitterFrac = 0.5 + }, + provider: struct { + mockProvider + }{ + mockProvider: mockProvider{ + creds: Credentials{ + AccessKeyID: "AKIAIOSFODNN7EXAMPLE", + SecretAccessKey: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + CanExpire: true, + Expires: sdk.NowTime().Add(time.Hour), + }, + }, + }, + expectCreds: Credentials{ + AccessKeyID: "AKIAIOSFODNN7EXAMPLE", + SecretAccessKey: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + CanExpire: true, + Expires: sdk.NowTime().Add(57*time.Minute + 29*time.Second), + }, + }, + "handle refresh": { + initialCreds: Credentials{ + AccessKeyID: "AKIAIOSFODNN7EXAMPLE", + SecretAccessKey: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + CanExpire: true, + Expires: sdk.NowTime().Add(-time.Hour), + }, + provider: struct { + mockProvider + mockHandleFailToRefresh + }{ + mockProvider: mockProvider{ + err: fmt.Errorf("some error"), + }, + mockHandleFailToRefresh: mockHandleFailToRefresh{ + expectInputCreds: Credentials{ + AccessKeyID: "AKIAIOSFODNN7EXAMPLE", + SecretAccessKey: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + CanExpire: true, + Expires: sdk.NowTime().Add(-time.Hour), + }, + creds: Credentials{ + AccessKeyID: "AKIAIOSFODNN7EXAMPLE", + SecretAccessKey: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + CanExpire: true, + Expires: sdk.NowTime().Add(time.Hour), + }, + }, + }, + expectCreds: Credentials{ + AccessKeyID: "AKIAIOSFODNN7EXAMPLE", + SecretAccessKey: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + CanExpire: true, + Expires: sdk.NowTime().Add(time.Hour), + }, + }, + "handle refresh error": { + initialCreds: Credentials{ + AccessKeyID: "AKIAIOSFODNN7EXAMPLE", + SecretAccessKey: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + CanExpire: true, + Expires: sdk.NowTime().Add(-time.Hour), + }, + provider: struct { + mockProvider + mockHandleFailToRefresh + }{ + mockProvider: mockProvider{ + err: fmt.Errorf("some error"), + }, + mockHandleFailToRefresh: mockHandleFailToRefresh{ + expectInputCreds: Credentials{ + AccessKeyID: "AKIAIOSFODNN7EXAMPLE", + SecretAccessKey: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + CanExpire: true, + Expires: sdk.NowTime().Add(-time.Hour), + }, + expectErr: "some error", + err: fmt.Errorf("some other error"), + }, + }, + expectErr: "some other error", + }, + "adjust expires": { + options: func(o *CredentialsCacheOptions) { + o.ExpiryWindow = 5 * time.Minute + }, + provider: struct { + mockProvider + mockAdjustExpiryBy + }{ + mockProvider: mockProvider{ + creds: Credentials{ + AccessKeyID: "AKIAIOSFODNN7EXAMPLE", + SecretAccessKey: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + CanExpire: true, + Expires: sdk.NowTime().Add(time.Hour), + }, + }, + mockAdjustExpiryBy: mockAdjustExpiryBy{ + expectInputCreds: Credentials{ + AccessKeyID: "AKIAIOSFODNN7EXAMPLE", + SecretAccessKey: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + CanExpire: true, + Expires: sdk.NowTime().Add(time.Hour), + }, + expectDur: -5 * time.Minute, + creds: Credentials{ + AccessKeyID: "AKIAIOSFODNN7EXAMPLE", + SecretAccessKey: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + CanExpire: true, + Expires: sdk.NowTime().Add(25 * time.Minute), + }, + }, + }, + expectCreds: Credentials{ + AccessKeyID: "AKIAIOSFODNN7EXAMPLE", + SecretAccessKey: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + CanExpire: true, + Expires: sdk.NowTime().Add(25 * time.Minute), + }, + }, + "adjust expires error": { + options: func(o *CredentialsCacheOptions) { + o.ExpiryWindow = 5 * time.Minute + }, + provider: struct { + mockProvider + mockAdjustExpiryBy + }{ + mockProvider: mockProvider{ + creds: Credentials{ + AccessKeyID: "AKIAIOSFODNN7EXAMPLE", + SecretAccessKey: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + CanExpire: true, + Expires: sdk.NowTime().Add(time.Hour), + }, + }, + mockAdjustExpiryBy: mockAdjustExpiryBy{ + err: fmt.Errorf("some error"), + }, + }, + expectErr: "some error", + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + var optFns []func(*CredentialsCacheOptions) + if c.options != nil { + optFns = append(optFns, c.options) + } + provider := NewCredentialsCache(c.provider, optFns...) + + if c.initialCreds.HasKeys() { + creds := c.initialCreds + provider.creds.Store(&creds) + } + + creds, err := provider.Retrieve(context.Background()) + if err == nil && len(c.expectErr) != 0 { + t.Fatalf("expect error %v, got none", c.expectErr) + } + if err != nil && len(c.expectErr) == 0 { + t.Fatalf("expect no error, got %v", err) + } + if err != nil && !strings.Contains(err.Error(), c.expectErr) { + t.Fatalf("expect error to contain %v, got %v", c.expectErr, err) + } + if c.expectErr != "" { + return + } + // Truncate expires time so its easy to compare + creds.Expires = creds.Expires.Truncate(time.Second) + + if diff := cmp.Diff(c.expectCreds, creds); diff != "" { + t.Errorf("expect creds match\n%s", diff) + } + }) + } +} + +type byteReader byte + +func (b byteReader) Read(p []byte) (int, error) { + for i := 0; i < len(p); i++ { + p[i] = byte(b) + } + return len(p), nil +} + +type mockProvider struct { + creds Credentials + err error +} + +var _ CredentialsProvider = mockProvider{} + +func (m mockProvider) Retrieve(context.Context) (Credentials, error) { + return m.creds, m.err +} + +type mockHandleFailToRefresh struct { + expectInputCreds Credentials + expectErr string + creds Credentials + err error +} + +var _ HandleFailRefreshCredentialsCacheStrategy = mockHandleFailToRefresh{} + +func (m mockHandleFailToRefresh) HandleFailToRefresh(ctx context.Context, prevCreds Credentials, err error) ( + Credentials, error, +) { + if m.expectInputCreds.HasKeys() { + if e, a := m.expectInputCreds, prevCreds; e != a { + return Credentials{}, fmt.Errorf("expect %v creds, got %v", e, a) + } + } + if m.expectErr != "" { + if err == nil { + return Credentials{}, fmt.Errorf("expect input error, got none") + } + if e, a := m.expectErr, err.Error(); !strings.Contains(a, e) { + return Credentials{}, fmt.Errorf("expect %v in error, got %v", e, a) + } + } + return m.creds, m.err +} + +type mockAdjustExpiryBy struct { + expectInputCreds Credentials + expectDur time.Duration + creds Credentials + err error +} + +var _ AdjustExpiresByCredentialsCacheStrategy = mockAdjustExpiryBy{} + +func (m mockAdjustExpiryBy) AdjustExpiresBy(creds Credentials, dur time.Duration) ( + Credentials, error, +) { + if m.expectInputCreds.HasKeys() { + if diff := cmp.Diff(m.expectInputCreds, creds); diff != "" { + return Credentials{}, fmt.Errorf("expect creds match\n%s", diff) + } + } + return m.creds, m.err +} diff --git a/aws/credentials.go b/aws/credentials.go index ce3868a9f01..0fffc53e671 100644 --- a/aws/credentials.go +++ b/aws/credentials.go @@ -83,16 +83,20 @@ type Credentials struct { // Source of the credentials Source string - // Time the credentials will expire. + // States if the credentials can expire or not. CanExpire bool - Expires time.Time + + // The time the credentials will expire at. Should be ignored if CanExpire + // is false. + Expires time.Time } // Expired returns if the credentials have expired. func (v Credentials) Expired() bool { if v.CanExpire { - // Calling Round(0) on the current time will truncate the monotonic reading only. Ensures credential expiry - // time is always based on reported wall-clock time. + // Calling Round(0) on the current time will truncate the monotonic + // reading only. Ensures credential expiry time is always based on + // reported wall-clock time. return !v.Expires.After(sdk.NowTime().Round(0)) } diff --git a/config/resolve_credentials.go b/config/resolve_credentials.go index 6bac0bb4dd8..42904ed740d 100644 --- a/config/resolve_credentials.go +++ b/config/resolve_credentials.go @@ -305,9 +305,7 @@ func resolveEC2RoleCredentials(ctx context.Context, cfg *aws.Config, configs con provider := ec2rolecreds.New(optFns...) - cfg.Credentials, err = wrapWithCredentialsCache(ctx, configs, provider, func(options *aws.CredentialsCacheOptions) { - options.ExpiryWindow = 5 * time.Minute - }) + cfg.Credentials, err = wrapWithCredentialsCache(ctx, configs, provider) if err != nil { return err } diff --git a/credentials/ec2rolecreds/provider.go b/credentials/ec2rolecreds/provider.go index 901132a3253..aeb79ac3c97 100644 --- a/credentials/ec2rolecreds/provider.go +++ b/credentials/ec2rolecreds/provider.go @@ -5,13 +5,18 @@ import ( "context" "encoding/json" "fmt" + "math" "path" "strings" "time" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/feature/ec2/imds" + sdkrand "github.com/aws/aws-sdk-go-v2/internal/rand" + "github.com/aws/aws-sdk-go-v2/internal/sdk" "github.com/aws/smithy-go" + "github.com/aws/smithy-go/logging" + "github.com/aws/smithy-go/middleware" ) // ProviderName provides a name of EC2Role provider @@ -26,14 +31,10 @@ type GetMetadataAPIClient interface { // A Provider retrieves credentials from the EC2 service, and keeps track if // those credentials are expired. // -// The New function must be used to create the Provider. +// The New function must be used to create the with a custom EC2 IMDS client. // -// p := &ec2rolecreds.New(ec2rolecreds.Options{ -// Client: imds.New(imds.Options{}), -// -// // Expire the credentials 10 minutes before IAM states they should. -// // Proactively refreshing the credentials. -// ExpiryWindow: 10 * time.Minute +// p := &ec2rolecreds.New(func(o *ec2rolecreds.Options{ +// o.Client = imds.New(imds.Options{/* custom options */}) // }) type Provider struct { options Options @@ -66,9 +67,8 @@ func New(optFns ...func(*Options)) *Provider { } } -// Retrieve retrieves credentials from the EC2 service. -// Error will be returned if the request fails, or unable to extract -// the desired credentials. +// Retrieve retrieves credentials from the EC2 service. Error will be returned +// if the request fails, or unable to extract the desired credentials. func (p *Provider) Retrieve(ctx context.Context) (aws.Credentials, error) { credsList, err := requestCredList(ctx, p.options.Client) if err != nil { @@ -96,10 +96,65 @@ func (p *Provider) Retrieve(ctx context.Context) (aws.Credentials, error) { Expires: roleCreds.Expiration, } + // Cap role credentials Expires to 1 hour so they can be refreshed more + // often. Jitter will be applied credentials cache if being used. + if anHour := sdk.NowTime().Add(1 * time.Hour); creds.Expires.After(anHour) { + creds.Expires = anHour + } + + return creds, nil +} + +// HandleFailToRefresh will extend the credentials Expires time if it it is +// expired. If the credentials will not expire within the minimum time, they +// will be returned. +// +// If the credentials cannot expire, the original error will be returned. +func (p *Provider) HandleFailToRefresh(ctx context.Context, prevCreds aws.Credentials, err error) ( + aws.Credentials, error, +) { + if !prevCreds.CanExpire { + return aws.Credentials{}, err + } + + if prevCreds.Expires.After(sdk.NowTime().Add(5 * time.Minute)) { + return prevCreds, nil + } + + newCreds := prevCreds + randFloat64, err := sdkrand.CryptoRandFloat64() + if err != nil { + return aws.Credentials{}, fmt.Errorf("failed to get random float, %w", err) + } + + // Random distribution of [5,15) minutes. + expireOffset := time.Duration(randFloat64*float64(10*time.Minute)) + 5*time.Minute + newCreds.Expires = sdk.NowTime().Add(expireOffset) + + logger := middleware.GetLogger(ctx) + logger.Logf(logging.Warn, "Attempting credential expiration extension due to a credential service availability issue. A refresh of these credentials will be attempted again in %v minutes.", math.Floor(expireOffset.Minutes())) + + return newCreds, nil +} + +// AdjustExpiresBy will adds the passed in duration to the passed in +// credential's Expires time, unless the time until Expires is less than 15 +// minutes. Returns the credentials, even if not updated. +func (p *Provider) AdjustExpiresBy(creds aws.Credentials, dur time.Duration) ( + aws.Credentials, error, +) { + if !creds.CanExpire { + return creds, nil + } + if creds.Expires.Before(sdk.NowTime().Add(15 * time.Minute)) { + return creds, nil + } + + creds.Expires = creds.Expires.Add(dur) return creds, nil } -// A ec2RoleCredRespBody provides the shape for unmarshaling credential +// ec2RoleCredRespBody provides the shape for unmarshaling credential // request responses. type ec2RoleCredRespBody struct { // Success State diff --git a/credentials/ec2rolecreds/provider_test.go b/credentials/ec2rolecreds/provider_test.go index 4073506230b..31be7d6e365 100644 --- a/credentials/ec2rolecreds/provider_test.go +++ b/credentials/ec2rolecreds/provider_test.go @@ -1,17 +1,24 @@ package ec2rolecreds import ( + "bytes" "context" "errors" "fmt" + "io" "io/ioutil" "strings" "testing" "time" + "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/feature/ec2/imds" + sdkrand "github.com/aws/aws-sdk-go-v2/internal/rand" "github.com/aws/aws-sdk-go-v2/internal/sdk" "github.com/aws/smithy-go" + "github.com/aws/smithy-go/logging" + "github.com/aws/smithy-go/middleware" + "github.com/google/go-cmp/cmp" ) const credsRespTmpl = `{ @@ -63,6 +70,11 @@ func (c mockClient) GetMetadata( } } +var ( + _ aws.AdjustExpiresByCredentialsCacheStrategy = (*Provider)(nil) + _ aws.HandleFailRefreshCredentialsCacheStrategy = (*Provider)(nil) +) + func TestProvider(t *testing.T) { orig := sdk.NowTime defer func() { sdk.NowTime = orig }() @@ -171,3 +183,185 @@ func TestProvider_IsExpired(t *testing.T) { t.Errorf("expect to be expired") } } + +type byteReader byte + +func (b byteReader) Read(p []byte) (int, error) { + for i := 0; i < len(p); i++ { + p[i] = byte(b) + } + return len(p), nil +} + +func TestProvider_HandleFailToRetrieve(t *testing.T) { + origTime := sdk.NowTime + defer func() { sdk.NowTime = origTime }() + sdk.NowTime = func() time.Time { + return time.Date(2014, 04, 04, 0, 1, 0, 0, time.UTC) + } + + origRand := sdkrand.Reader + defer func() { sdkrand.Reader = origRand }() + sdkrand.Reader = byteReader(0) + + cases := map[string]struct { + creds aws.Credentials + err error + randReader io.Reader + expectCreds aws.Credentials + expectErr string + expectLogged string + }{ + "expired low": { + randReader: byteReader(0), + creds: aws.Credentials{ + CanExpire: true, + Expires: sdk.NowTime().Add(-5 * time.Minute), + }, + err: fmt.Errorf("some error"), + expectCreds: aws.Credentials{ + CanExpire: true, + Expires: sdk.NowTime().Add(5 * time.Minute), + }, + expectLogged: fmt.Sprintf("again in 5 minutes"), + }, + "expired high": { + randReader: byteReader(0xFF), + creds: aws.Credentials{ + CanExpire: true, + Expires: sdk.NowTime().Add(-5 * time.Minute), + }, + err: fmt.Errorf("some error"), + expectCreds: aws.Credentials{ + CanExpire: true, + Expires: sdk.NowTime().Add(14*time.Minute + 59*time.Second), + }, + expectLogged: fmt.Sprintf("again in 14 minutes"), + }, + "not expired": { + randReader: byteReader(0xFF), + creds: aws.Credentials{ + CanExpire: true, + Expires: sdk.NowTime().Add(10 * time.Minute), + }, + err: fmt.Errorf("some error"), + expectCreds: aws.Credentials{ + CanExpire: true, + Expires: sdk.NowTime().Add(10 * time.Minute), + }, + }, + "cannot expire": { + randReader: byteReader(0xFF), + creds: aws.Credentials{ + CanExpire: false, + }, + err: fmt.Errorf("some error"), + expectErr: "some error", + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + sdkrand.Reader = c.randReader + if sdkrand.Reader == nil { + sdkrand.Reader = byteReader(0) + } + + var logBuf bytes.Buffer + logger := logging.LoggerFunc(func(class logging.Classification, format string, args ...interface{}) { + fmt.Fprintf(&logBuf, string(class)+" "+format, args...) + }) + ctx := middleware.SetLogger(context.Background(), logger) + + p := New() + creds, err := p.HandleFailToRefresh(ctx, c.creds, c.err) + if err == nil && len(c.expectErr) != 0 { + t.Fatalf("expect error %v, got none", c.expectErr) + } + if err != nil && len(c.expectErr) == 0 { + t.Fatalf("expect no error, got %v", err) + } + if err != nil && !strings.Contains(err.Error(), c.expectErr) { + t.Fatalf("expect error to contain %v, got %v", c.expectErr, err) + } + if c.expectErr != "" { + return + } + + if len(c.expectLogged) != 0 && logBuf.Len() == 0 { + t.Errorf("expect %v logged, got none", c.expectLogged) + } + if e, a := c.expectLogged, logBuf.String(); !strings.Contains(a, e) { + t.Errorf("expect %v to be logged in %v", e, a) + } + + // Truncate time so it can be easily compared. + creds.Expires = creds.Expires.Truncate(time.Second) + + if diff := cmp.Diff(c.expectCreds, creds); diff != "" { + t.Errorf("expect creds match\n%s", diff) + } + }) + } +} + +func TestProvider_AdjustExpiresBy(t *testing.T) { + origTime := sdk.NowTime + defer func() { sdk.NowTime = origTime }() + sdk.NowTime = func() time.Time { + return time.Date(2014, 04, 04, 0, 1, 0, 0, time.UTC) + } + + cases := map[string]struct { + creds aws.Credentials + dur time.Duration + expectCreds aws.Credentials + }{ + "modify expires": { + creds: aws.Credentials{ + CanExpire: true, + Expires: sdk.NowTime().Add(1 * time.Hour), + }, + dur: -5 * time.Minute, + expectCreds: aws.Credentials{ + CanExpire: true, + Expires: sdk.NowTime().Add(55 * time.Minute), + }, + }, + "expiry too soon": { + creds: aws.Credentials{ + CanExpire: true, + Expires: sdk.NowTime().Add(14*time.Minute + 59*time.Second), + }, + dur: -5 * time.Minute, + expectCreds: aws.Credentials{ + CanExpire: true, + Expires: sdk.NowTime().Add(14*time.Minute + 59*time.Second), + }, + }, + "cannot expire": { + creds: aws.Credentials{ + CanExpire: false, + }, + dur: 10 * time.Minute, + expectCreds: aws.Credentials{ + CanExpire: false, + }, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + p := New() + creds, err := p.AdjustExpiresBy(c.creds, c.dur) + + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + + if diff := cmp.Diff(c.expectCreds, creds); diff != "" { + t.Errorf("expect creds match\n%s", diff) + } + }) + } +} diff --git a/internal/rand/rand.go b/internal/rand/rand.go index 9791ea590b5..c8484dcd759 100644 --- a/internal/rand/rand.go +++ b/internal/rand/rand.go @@ -29,5 +29,5 @@ func Float64(reader io.Reader) (float64, error) { // CryptoRandFloat64 returns a random float64 obtained from the crypto rand // source. func CryptoRandFloat64() (float64, error) { - return Float64(rand.Reader) + return Float64(Reader) }