diff --git a/.changelog/781710a7ecb24b9290b2642bd90b42c9.json b/.changelog/781710a7ecb24b9290b2642bd90b42c9.json new file mode 100644 index 00000000000..6b48783d7dd --- /dev/null +++ b/.changelog/781710a7ecb24b9290b2642bd90b42c9.json @@ -0,0 +1,8 @@ +{ + "id": "781710a7-ecb2-4b92-90b2-642bd90b42c9", + "type": "bugfix", + "description": "Updates the Retry middleware to release the retry token, on subsequent attempts. This fixes #1413, and is based on PR #1424", + "modules": [ + "." + ] +} \ No newline at end of file diff --git a/aws/retry/middleware.go b/aws/retry/middleware.go index cd7ef0baac3..19ce353cb5f 100644 --- a/aws/retry/middleware.go +++ b/aws/retry/middleware.go @@ -16,8 +16,8 @@ import ( "github.com/aws/smithy-go/transport/http" ) -// RequestCloner is a function that can take an input request type and clone the request -// for use in a subsequent retry attempt +// RequestCloner is a function that can take an input request type and clone +// the request for use in a subsequent retry attempt. type RequestCloner func(interface{}) interface{} type retryMetadata struct { @@ -27,11 +27,12 @@ type retryMetadata struct { AttemptClockSkew time.Duration } -// Attempt is a Smithy FinalizeMiddleware that handles retry attempts using the provided -// Retryer implementation +// Attempt is a Smithy Finalize middleware that handles retry attempts using +// the provided Retryer implementation. type Attempt struct { - // Enable the logging of retry attempts performed by the SDK. - // This will include logging retry attempts, unretryable errors, and when max attempts are reached. + // Enable the logging of retry attempts performed by the SDK. This will + // include logging retry attempts, unretryable errors, and when max + // attempts are reached. LogAttempts bool retryer aws.Retryer @@ -59,8 +60,9 @@ func (r Attempt) logf(logger logging.Logger, classification logging.Classificati logger.Logf(classification, format, v...) 
} -// HandleFinalize utilizes the provider Retryer implementation to attempt retries over the next handler -func (r Attempt) HandleFinalize(ctx context.Context, in smithymiddle.FinalizeInput, next smithymiddle.FinalizeHandler) ( +// HandleFinalize utilizes the provided Retryer implementation to attempt +// retries over the next handler +func (r *Attempt) HandleFinalize(ctx context.Context, in smithymiddle.FinalizeInput, next smithymiddle.FinalizeHandler) ( out smithymiddle.FinalizeOutput, metadata smithymiddle.Metadata, err error, ) { var attemptNum int @@ -68,12 +70,14 @@ func (r Attempt) HandleFinalize(ctx context.Context, in smithymiddle.FinalizeInp var attemptResults AttemptResults maxAttempts := r.retryer.MaxAttempts() + releaseRetryToken := nopRelease for { attemptNum++ attemptInput := in attemptInput.Request = r.requestCloner(attemptInput.Request) + // Record the metadata for the attempt being started. attemptCtx := setRetryMetadata(ctx, retryMetadata{ AttemptNum: attemptNum, AttemptTime: sdk.NowTime().UTC(), @@ -82,23 +86,20 @@ func (r Attempt) HandleFinalize(ctx context.Context, in smithymiddle.FinalizeInp }) var attemptResult AttemptResult + out, attemptResult, releaseRetryToken, err = r.handleAttempt(attemptCtx, attemptInput, releaseRetryToken, next) + attemptClockSkew, _ = awsmiddle.GetAttemptSkew(attemptResult.ResponseMetadata) - out, attemptResult, err = r.handleAttempt(attemptCtx, attemptInput, next) - - var ok bool - attemptClockSkew, ok = awsmiddle.GetAttemptSkew(attemptResult.ResponseMetadata) - if !ok { - attemptClockSkew = 0 - } - + // AttemptResult Retried states that the attempt was not successful, and + // should be retried.
shouldRetry := attemptResult.Retried - // add attempt metadata to list of all attempt metadata + // Add attempt metadata to list of all attempt metadata attemptResults.Results = append(attemptResults.Results, attemptResult) if !shouldRetry { // Ensure the last response's metadata is used as the bases for result - // metadata returned by the stack. + // metadata returned by the stack. The Slice of attempt results + // will be added to this cloned metadata. metadata = attemptResult.ResponseMetadata.Clone() break @@ -110,26 +111,36 @@ func (r Attempt) HandleFinalize(ctx context.Context, in smithymiddle.FinalizeInp } // handleAttempt handles an individual request attempt. -func (r Attempt) handleAttempt(ctx context.Context, in smithymiddle.FinalizeInput, next smithymiddle.FinalizeHandler) ( - out smithymiddle.FinalizeOutput, attemptResult AttemptResult, err error, +func (r *Attempt) handleAttempt( + ctx context.Context, in smithymiddle.FinalizeInput, releaseRetryToken func(error) error, next smithymiddle.FinalizeHandler, +) ( + out smithymiddle.FinalizeOutput, attemptResult AttemptResult, _ func(error) error, err error, ) { defer func() { attemptResult.Err = err }() - relRetryToken := r.retryer.GetInitialToken() + //------------------------------ + // Get Initial (aka Send) Token + //------------------------------ + releaseInitialToken := r.retryer.GetInitialToken() + + //------------------------------ + // Send Attempt + //------------------------------ logger := smithymiddle.GetLogger(ctx) service, operation := awsmiddle.GetServiceID(ctx), awsmiddle.GetOperationName(ctx) - retryMetadata, _ := getRetryMetadata(ctx) attemptNum := retryMetadata.AttemptNum maxAttempts := retryMetadata.MaxAttempts + // Following attempts must ensure the request payload stream starts in a + // rewound state. 
if attemptNum > 1 { if rewindable, ok := in.Request.(interface{ RewindStream() error }); ok { if rewindErr := rewindable.RewindStream(); rewindErr != nil { err = fmt.Errorf("failed to rewind transport stream for retry, %w", rewindErr) - return out, attemptResult, err + return out, attemptResult, nopRelease, err } } @@ -140,51 +151,78 @@ func (r Attempt) handleAttempt(ctx context.Context, in smithymiddle.FinalizeInpu out, metadata, err = next.HandleFinalize(ctx, in) attemptResult.ResponseMetadata = metadata - if releaseError := relRetryToken(err); releaseError != nil && err != nil { - err = fmt.Errorf("failed to release token after request error, %w", err) - return out, attemptResult, err + //------------------------------ + // Bookkeeping + //------------------------------ + // Release the initial send token based on the state of the attempt's error (if any). + if releaseError := releaseInitialToken(err); releaseError != nil && err != nil { + err = fmt.Errorf("failed to release initial token after request error, %w", err) + return out, attemptResult, nopRelease, err } - + // Release the retry token based on the state of the attempt's error (if any). + if releaseError := releaseRetryToken(err); releaseError != nil && err != nil { + err = fmt.Errorf("failed to release retry token after request error, %w", err) + return out, attemptResult, nopRelease, err + } + // If there was no error making the attempt, nothing further to do. There + // will be nothing to retry. if err == nil { - return out, attemptResult, err + return out, attemptResult, nopRelease, err } + //------------------------------ + // Is Retryable and Should Retry + //------------------------------ + // If the attempt failed with an unretryable error, nothing further to do + // but return, and inform the caller about the terminal failure. 
retryable := r.retryer.IsErrorRetryable(err) if !retryable { r.logf(logger, logging.Debug, "request failed with unretryable error %v", err) - return out, attemptResult, err + return out, attemptResult, nopRelease, err } // set retryable to true attemptResult.Retryable = true + // Once the maximum number of attempts has been exhausted there is nothing + // further to do other than inform the caller about the terminal failure. if maxAttempts > 0 && attemptNum >= maxAttempts { r.logf(logger, logging.Debug, "max retry attempts exhausted, max %d", maxAttempts) err = &MaxAttemptsError{ Attempt: attemptNum, Err: err, } - return out, attemptResult, err + return out, attemptResult, nopRelease, err } - relRetryToken, reqErr := r.retryer.GetRetryToken(ctx, err) - if reqErr != nil { - return out, attemptResult, reqErr + //------------------------------ + // Get Retry (aka Retry Quota) Token + //------------------------------ + // Get a retry token that will be released after the next attempt completes. + releaseRetryToken, retryTokenErr := r.retryer.GetRetryToken(ctx, err) + if retryTokenErr != nil { + return out, attemptResult, nopRelease, retryTokenErr } + //------------------------------ + // Retry Delay and Sleep + //------------------------------ + // Get the retry delay before another attempt can be made, and sleep for + // that time. Potentially early exit if the sleep is canceled via the + // context. retryDelay, reqErr := r.retryer.RetryDelay(attemptNum, err) if reqErr != nil { - return out, attemptResult, reqErr + return out, attemptResult, releaseRetryToken, reqErr } - if reqErr = sdk.SleepWithContext(ctx, retryDelay); reqErr != nil { err = &aws.RequestCanceledError{Err: reqErr} - return out, attemptResult, err + return out, attemptResult, releaseRetryToken, err } + // The request should be re-attempted.
attemptResult.Retried = true - return out, attemptResult, err + return out, attemptResult, releaseRetryToken, err } // MetricsHeader attaches SDK request metric header for retries to the transport @@ -195,7 +233,7 @@ func (r *MetricsHeader) ID() string { return "RetryMetricsHeader" } -// HandleFinalize attaches the sdk request metric header to the transport layer +// HandleFinalize attaches the SDK request metric header to the transport layer func (r MetricsHeader) HandleFinalize(ctx context.Context, in smithymiddle.FinalizeInput, next smithymiddle.FinalizeHandler) ( out smithymiddle.FinalizeOutput, metadata smithymiddle.Metadata, err error, ) { @@ -251,13 +289,14 @@ func setRetryMetadata(ctx context.Context, metadata retryMetadata) context.Conte return middleware.WithStackValue(ctx, retryMetadataKey{}, metadata) } -// AddRetryMiddlewaresOptions is the set of options that can be passed to AddRetryMiddlewares for configuring retry -// associated middleware. +// AddRetryMiddlewaresOptions is the set of options that can be passed to +// AddRetryMiddlewares for configuring retry associated middleware. type AddRetryMiddlewaresOptions struct { Retryer aws.Retryer - // Enable the logging of retry attempts performed by the SDK. - // This will include logging retry attempts, unretryable errors, and when max attempts are reached. + // Enable the logging of retry attempts performed by the SDK. This will + // include logging retry attempts, unretryable errors, and when max + // attempts are reached. 
LogRetryAttempts bool } diff --git a/aws/retry/middleware_test.go b/aws/retry/middleware_test.go index 90c7708eb70..737fbcdd249 100644 --- a/aws/retry/middleware_test.go +++ b/aws/retry/middleware_test.go @@ -2,6 +2,7 @@ package retry import ( "context" + "errors" "fmt" "net/http" "reflect" @@ -11,6 +12,7 @@ import ( "time" "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/aws/ratelimit" "github.com/aws/aws-sdk-go-v2/internal/sdk" "github.com/aws/smithy-go/middleware" smithyhttp "github.com/aws/smithy-go/transport/http" @@ -63,17 +65,18 @@ func TestMetricsHeaderMiddleware(t *testing.T) { for i, tt := range cases { t.Run(strconv.Itoa(i), func(t *testing.T) { ctx := tt.ctx - _, _, err := retryMiddleware.HandleFinalize(ctx, tt.input, middleware.FinalizeHandlerFunc(func(ctx context.Context, in middleware.FinalizeInput) ( - out middleware.FinalizeOutput, metadata middleware.Metadata, err error, - ) { - req := in.Request.(*smithyhttp.Request) + _, _, err := retryMiddleware.HandleFinalize(ctx, tt.input, middleware.FinalizeHandlerFunc( + func(ctx context.Context, in middleware.FinalizeInput) ( + out middleware.FinalizeOutput, metadata middleware.Metadata, err error, + ) { + req := in.Request.(*smithyhttp.Request) - if e, a := tt.expectedHeader, req.Header.Get("amz-sdk-request"); e != a { - t.Errorf("expected %v, got %v", e, a) - } + if e, a := tt.expectedHeader, req.Header.Get("amz-sdk-request"); e != a { + t.Errorf("expected %v, got %v", e, a) + } - return out, metadata, err - })) + return out, metadata, err + })) if err != nil && len(tt.expectedErr) == 0 { t.Fatalf("expected no error, got %q", err) } else if err != nil && len(tt.expectedErr) != 0 { @@ -97,7 +100,9 @@ func (t retryProvider) GetRetryer() aws.Retryer { type mockHandler func(context.Context, interface{}) (interface{}, middleware.Metadata, error) -func (m mockHandler) Handle(ctx context.Context, input interface{}) (output interface{}, metadata middleware.Metadata, err error) { +func (m 
mockHandler) Handle(ctx context.Context, input interface{}) ( + output interface{}, metadata middleware.Metadata, err error, +) { return m(ctx, input) } @@ -141,13 +146,16 @@ func TestAttemptMiddleware(t *testing.T) { }{ "no error, no response in a single attempt": { Next: func(retries *[]retryMetadata) middleware.FinalizeHandler { - return middleware.FinalizeHandlerFunc(func(ctx context.Context, in middleware.FinalizeInput) (out middleware.FinalizeOutput, metadata middleware.Metadata, err error) { - m, ok := getRetryMetadata(ctx) - if ok { - *retries = append(*retries, m) - } - return out, metadata, err - }) + return middleware.FinalizeHandlerFunc( + func(ctx context.Context, in middleware.FinalizeInput) ( + out middleware.FinalizeOutput, metadata middleware.Metadata, err error, + ) { + m, ok := getRetryMetadata(ctx) + if ok { + *retries = append(*retries, m) + } + return out, metadata, err + }) }, Expect: []retryMetadata{ { @@ -162,14 +170,17 @@ func TestAttemptMiddleware(t *testing.T) { }, "no error in a single attempt": { Next: func(retries *[]retryMetadata) middleware.FinalizeHandler { - return middleware.FinalizeHandlerFunc(func(ctx context.Context, in middleware.FinalizeInput) (out middleware.FinalizeOutput, metadata middleware.Metadata, err error) { - m, ok := getRetryMetadata(ctx) - if ok { - *retries = append(*retries, m) - } - setMockRawResponse(&metadata, "mockResponse") - return out, metadata, err - }) + return middleware.FinalizeHandlerFunc( + func(ctx context.Context, in middleware.FinalizeInput) ( + out middleware.FinalizeOutput, metadata middleware.Metadata, err error, + ) { + m, ok := getRetryMetadata(ctx) + if ok { + *retries = append(*retries, m) + } + setMockRawResponse(&metadata, "mockResponse") + return out, metadata, err + }) }, Expect: []retryMetadata{ { @@ -196,19 +207,22 @@ func TestAttemptMiddleware(t *testing.T) { mockRetryableError{b: true}, nil, } - return middleware.FinalizeHandlerFunc(func(ctx context.Context, in 
middleware.FinalizeInput) (out middleware.FinalizeOutput, metadata middleware.Metadata, err error) { - m, ok := getRetryMetadata(ctx) - if ok { - *retries = append(*retries, m) - } - if num >= len(reqsErrs) { - err = fmt.Errorf("more requests then expected") - } else { - err = reqsErrs[num] - num++ - } - return out, metadata, err - }) + return middleware.FinalizeHandlerFunc( + func(ctx context.Context, in middleware.FinalizeInput) ( + out middleware.FinalizeOutput, metadata middleware.Metadata, err error, + ) { + m, ok := getRetryMetadata(ctx) + if ok { + *retries = append(*retries, m) + } + if num >= len(reqsErrs) { + err = fmt.Errorf("more requests then expected") + } else { + err = reqsErrs[num] + num++ + } + return out, metadata, err + }) }, Expect: []retryMetadata{ { @@ -249,15 +263,18 @@ func TestAttemptMiddleware(t *testing.T) { mockRetryableError{b: true}, mockRetryableError{b: true}, } - return middleware.FinalizeHandlerFunc(func(ctx context.Context, in middleware.FinalizeInput) (out middleware.FinalizeOutput, metadata middleware.Metadata, err error) { - if num >= len(reqsErrs) { - err = fmt.Errorf("more requests then expected") - } else { - err = reqsErrs[num] - num++ - } - return out, metadata, err - }) + return middleware.FinalizeHandlerFunc( + func(ctx context.Context, in middleware.FinalizeInput) ( + out middleware.FinalizeOutput, metadata middleware.Metadata, err error, + ) { + if num >= len(reqsErrs) { + err = fmt.Errorf("more requests then expected") + } else { + err = reqsErrs[num] + num++ + } + return out, metadata, err + }) }, Err: fmt.Errorf("exceeded maximum number of attempts"), ExpectResults: AttemptResults{Results: []AttemptResult{ @@ -280,13 +297,16 @@ func TestAttemptMiddleware(t *testing.T) { "stops on rewind error": { Request: testRequest{DisableRewind: true}, Next: func(retries *[]retryMetadata) middleware.FinalizeHandler { - return middleware.FinalizeHandlerFunc(func(ctx context.Context, in middleware.FinalizeInput) (out 
middleware.FinalizeOutput, metadata middleware.Metadata, err error) { - m, ok := getRetryMetadata(ctx) - if ok { - *retries = append(*retries, m) - } - return out, metadata, mockRetryableError{b: true} - }) + return middleware.FinalizeHandlerFunc( + func(ctx context.Context, in middleware.FinalizeInput) ( + out middleware.FinalizeOutput, metadata middleware.Metadata, err error, + ) { + m, ok := getRetryMetadata(ctx) + if ok { + *retries = append(*retries, m) + } + return out, metadata, mockRetryableError{b: true} + }) }, Expect: []retryMetadata{ { @@ -312,13 +332,16 @@ func TestAttemptMiddleware(t *testing.T) { }, "stops on non-retryable errors": { Next: func(retries *[]retryMetadata) middleware.FinalizeHandler { - return middleware.FinalizeHandlerFunc(func(ctx context.Context, in middleware.FinalizeInput) (out middleware.FinalizeOutput, metadata middleware.Metadata, err error) { - m, ok := getRetryMetadata(ctx) - if ok { - *retries = append(*retries, m) - } - return out, metadata, fmt.Errorf("some error") - }) + return middleware.FinalizeHandlerFunc( + func(ctx context.Context, in middleware.FinalizeInput) ( + out middleware.FinalizeOutput, metadata middleware.Metadata, err error, + ) { + m, ok := getRetryMetadata(ctx) + if ok { + *retries = append(*retries, m) + } + return out, metadata, fmt.Errorf("some error") + }) }, Expect: []retryMetadata{ { @@ -341,25 +364,28 @@ func TestAttemptMiddleware(t *testing.T) { mockRetryableError{b: true}, nil, } - return middleware.FinalizeHandlerFunc(func(ctx context.Context, in middleware.FinalizeInput) (out middleware.FinalizeOutput, metadata middleware.Metadata, err error) { - m, ok := getRetryMetadata(ctx) - if ok { - *retries = append(*retries, m) - } - if num >= len(reqsErrs) { - err = fmt.Errorf("more requests then expected") - } else { - err = reqsErrs[num] - num++ - } + return middleware.FinalizeHandlerFunc( + func(ctx context.Context, in middleware.FinalizeInput) ( + out middleware.FinalizeOutput, metadata 
middleware.Metadata, err error, + ) { + m, ok := getRetryMetadata(ctx) + if ok { + *retries = append(*retries, m) + } + if num >= len(reqsErrs) { + err = fmt.Errorf("more requests then expected") + } else { + err = reqsErrs[num] + num++ + } - if err != nil { - metadata.Set("testKey", "testValue") - } else { - setMockRawResponse(&metadata, "mockResponse") - } - return out, metadata, err - }) + if err != nil { + metadata.Set("testKey", "testValue") + } else { + setMockRawResponse(&metadata, "mockResponse") + } + return out, metadata, err + }) }, Expect: []retryMetadata{ { @@ -414,7 +440,12 @@ func TestAttemptMiddleware(t *testing.T) { }) var recorded []retryMetadata - _, metadata, err := am.HandleFinalize(context.Background(), middleware.FinalizeInput{Request: tt.Request}, tt.Next(&recorded)) + _, metadata, err := am.HandleFinalize(context.Background(), + middleware.FinalizeInput{ + Request: tt.Request, + }, + tt.Next(&recorded), + ) if err != nil && tt.Err == nil { t.Errorf("expect no error, got %v", err) } else if err == nil && tt.Err != nil { @@ -446,6 +477,43 @@ func TestAttemptMiddleware(t *testing.T) { } } +func TestAttemptReleaseRetryLock(t *testing.T) { + standard := NewStandard(func(s *StandardOptions) { + s.MaxAttempts = 3 + s.RateLimiter = ratelimit.NewTokenRateLimit(10) + s.RetryCost = 10 + }) + am := NewAttemptMiddleware(standard, func(i interface{}) interface{} { + return i + }) + f := func(retries *[]retryMetadata) middleware.FinalizeHandler { + num := 0 + return middleware.FinalizeHandlerFunc( + func(ctx context.Context, in middleware.FinalizeInput) ( + out middleware.FinalizeOutput, metadata middleware.Metadata, err error, + ) { + m, ok := getRetryMetadata(ctx) + if ok { + *retries = append(*retries, m) + } + if num > 0 { + return out, metadata, err + } + num++ + return out, metadata, mockRetryableError{b: true} + }) + } + var recorded []retryMetadata + _, _, err := am.HandleFinalize(context.Background(), middleware.FinalizeInput{}, f(&recorded)) + 
if err != nil { + t.Fatal(err) + } + _, err = standard.GetRetryToken(context.Background(), errors.New("retryme")) + if err != nil { + t.Fatal(err) + } +} + // mockRawResponseKey is used to test the behavior when response metadata is // nested within the attempt request. type mockRawResponseKey struct{} diff --git a/aws/retry/standard.go b/aws/retry/standard.go index 29afa3bdf05..aa9b6b98205 100644 --- a/aws/retry/standard.go +++ b/aws/retry/standard.go @@ -171,8 +171,11 @@ func (s *Standard) RetryDelay(attempt int, err error) (time.Duration, error) { return s.backoff.BackoffDelay(attempt, err) } -// GetInitialToken returns the initial request token that can increment the -// retry token pool if the request is successful. +// GetInitialToken returns a token for adding the NoRetryIncrement to the +// RateLimiter token if the attempt completed successfully without error. +// +// InitialToken applies to the result of each attempt, including the first. +// Whereas the RetryToken applies to the result of subsequent attempts. func (s *Standard) GetInitialToken() func(error) error { return releaseToken(s.incrementTokens).release } @@ -197,6 +200,8 @@ func (s *Standard) GetRetryToken(ctx context.Context, err error) (func(error) er return releaseToken(fn).release, nil } +func nopRelease(error) error { return nil } + type releaseToken func() error func (f releaseToken) release(err error) error { diff --git a/aws/retryer.go b/aws/retryer.go index 0489508ef44..4c64af60907 100644 --- a/aws/retryer.go +++ b/aws/retryer.go @@ -29,7 +29,7 @@ type Retryer interface { // Returning the token release function, or error. GetRetryToken(ctx context.Context, opErr error) (releaseToken func(error) error, err error) - // GetInitalToken returns the initial request token that can increment the + // GetInitialToken returns the initial request token that can increment the // retry token pool if the request is successful. GetInitialToken() (releaseToken func(error) error) }