diff --git a/.changelog/17ac89419cd94e598fcc93444709cc1a.json b/.changelog/17ac89419cd94e598fcc93444709cc1a.json new file mode 100644 index 00000000000..0d0ab8e2888 --- /dev/null +++ b/.changelog/17ac89419cd94e598fcc93444709cc1a.json @@ -0,0 +1,8 @@ +{ + "id": "17ac8941-9cd9-4e59-8fcc-93444709cc1a", + "type": "feature", + "description": "Respect passed in Context Deadline/Timeout. Updates the IMDS Client operations to not override the passed in Context's Deadline or Timeout options. If an Client operation is called with a Context with a Deadline or Timeout, the client will no longer override it with the client's default timeout.", + "modules": [ + "feature/ec2/imds" + ] +} \ No newline at end of file diff --git a/.changelog/53dad1d685864ddfb0304829c2a45e4c.json b/.changelog/53dad1d685864ddfb0304829c2a45e4c.json new file mode 100644 index 00000000000..c25231e2a2f --- /dev/null +++ b/.changelog/53dad1d685864ddfb0304829c2a45e4c.json @@ -0,0 +1,8 @@ +{ + "id": "53dad1d6-8586-4ddf-b030-4829c2a45e4c", + "type": "bugfix", + "description": "Fix IMDS client's response handling and operation timeout race. Fixes #1253", + "modules": [ + "feature/ec2/imds" + ] +} \ No newline at end of file diff --git a/feature/ec2/imds/doc.go b/feature/ec2/imds/doc.go index 9ae608291c1..bacdb5d21f2 100644 --- a/feature/ec2/imds/doc.go +++ b/feature/ec2/imds/doc.go @@ -1,6 +1,11 @@ // Package imds provides the API client for interacting with the Amazon EC2 // Instance Metadata Service. // +// All Client operation calls have a default timeout. If the operation is not +// completed before this timeout expires, the operation will be canceled. This +// timeout can be overridden by providing Context with a timeout or deadline +// with calling the client's operations. +// // See the EC2 IMDS user guide for more information on using the API. // https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/ec2-instance-metadata.html package imds diff --git a/feature/ec2/imds/request_middleware.go b/feature/ec2/imds/request_middleware.go index 93f02405f99..605cbd13140 100644 --- a/feature/ec2/imds/request_middleware.go +++ b/feature/ec2/imds/request_middleware.go @@ -1,8 +1,10 @@ package imds import ( + "bytes" "context" "fmt" + "io/ioutil" "net/url" "path" "time" @@ -52,7 +54,7 @@ func addRequestMiddleware(stack *middleware.Stack, // Operation timeout err = stack.Initialize.Add(&operationTimeout{ - Timeout: defaultOperationTimeout, + DefaultTimeout: defaultOperationTimeout, }, middleware.Before) if err != nil { return err @@ -142,12 +144,20 @@ func (m *deserializeResponse) HandleDeserialize( resp, ok := out.RawResponse.(*smithyhttp.Response) if !ok { return out, metadata, fmt.Errorf( - "unexpected transport response type, %T", out.RawResponse) + "unexpected transport response type, %T, want %T", out.RawResponse, resp) } + defer resp.Body.Close() - // Anything thats not 200 |< 300 is error + // read the full body so that any operation timeouts cleanup will not race + // the body being read. + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return out, metadata, fmt.Errorf("read response body failed, %w", err) + } + resp.Body = ioutil.NopCloser(bytes.NewReader(body)) + + // Anything that's not 200 |< 300 is error if resp.StatusCode < 200 || resp.StatusCode >= 300 { - resp.Body.Close() return out, metadata, &smithyhttp.ResponseError{ Response: resp, Err: fmt.Errorf("request to EC2 IMDS failed"), @@ -213,8 +223,19 @@ const ( defaultOperationTimeout = 5 * time.Second ) +// operationTimeout adds a timeout on the middleware stack if the Context the +// stack was called with does not have a deadline. The next middleware must +// complete before the timeout, or the context will be canceled. +// +// If DefaultTimeout is zero, no default timeout will be used if the Context +// does not have a timeout. +// +// The next middleware must also ensure that any resources that are also +// canceled by the stack's context are completely consumed before returning. +// Otherwise the timeout cleanup will race the resource being consumed +// upstream. type operationTimeout struct { - Timeout time.Duration + DefaultTimeout time.Duration } func (*operationTimeout) ID() string { return "OperationTimeout" } @@ -224,10 +245,11 @@ func (m *operationTimeout) HandleInitialize( ) ( output middleware.InitializeOutput, metadata middleware.Metadata, err error, ) { - var cancelFn func() - - ctx, cancelFn = context.WithTimeout(ctx, m.Timeout) - defer cancelFn() + if _, ok := ctx.Deadline(); !ok && m.DefaultTimeout != 0 { + var cancelFn func() + ctx, cancelFn = context.WithTimeout(ctx, m.DefaultTimeout) + defer cancelFn() + } return next.HandleInitialize(ctx, input) } diff --git a/feature/ec2/imds/request_middleware_test.go b/feature/ec2/imds/request_middleware_test.go index 629947e096a..8aee6fb93bc 100644 --- a/feature/ec2/imds/request_middleware_test.go +++ b/feature/ec2/imds/request_middleware_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/hex" + "fmt" "io" "io/ioutil" "net/http" @@ -126,7 +127,7 @@ func TestAddRequestMiddleware(t *testing.T) { func TestOperationTimeoutMiddleware(t *testing.T) { m := &operationTimeout{ - Timeout: time.Nanosecond, + DefaultTimeout: time.Nanosecond, } _, _, err := m.HandleInitialize(context.Background(), middleware.InitializeInput{}, @@ -135,6 +136,10 @@ func TestOperationTimeoutMiddleware(t *testing.T) { ) ( out middleware.InitializeOutput, metadata middleware.Metadata, err error, ) { + if _, ok := ctx.Deadline(); !ok { + return out, metadata, fmt.Errorf("expect context deadline to be set") + } + if err := sdk.SleepWithContext(ctx, time.Second); err != nil { return out, metadata, err } @@ -150,6 +155,144 @@ func TestOperationTimeoutMiddleware(t *testing.T) { } } +func TestOperationTimeoutMiddleware_noDefaultTimeout(t *testing.T) { + m := &operationTimeout{} + + _, _, err := m.HandleInitialize(context.Background(), middleware.InitializeInput{}, + middleware.InitializeHandlerFunc(func( + ctx context.Context, input middleware.InitializeInput, + ) ( + out middleware.InitializeOutput, metadata middleware.Metadata, err error, + ) { + if t, ok := ctx.Deadline(); ok { + return out, metadata, fmt.Errorf("expect no context deadline, got %v", t) + } + + return out, metadata, nil + })) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } +} + +func TestOperationTimeoutMiddleware_withCustomDeadline(t *testing.T) { + m := &operationTimeout{ + DefaultTimeout: time.Nanosecond, + } + + expectDeadline := time.Now().Add(time.Hour) + ctx, cancelFn := context.WithDeadline(context.Background(), expectDeadline) + defer cancelFn() + + _, _, err := m.HandleInitialize(ctx, middleware.InitializeInput{}, + middleware.InitializeHandlerFunc(func( + ctx context.Context, input middleware.InitializeInput, + ) ( + out middleware.InitializeOutput, metadata middleware.Metadata, err error, + ) { + t, ok := ctx.Deadline() + if !ok { + return out, metadata, fmt.Errorf("expect context deadline to be set") + } + if e, a := expectDeadline, t; !e.Equal(a) { + return out, metadata, fmt.Errorf("expect %v deadline, got %v", e, a) + } + + return out, metadata, nil + })) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } +} + +// Ensure that the response body is read in the deserialize middleware, +// ensuring that the timeoutOperation middleware won't race canceling the +// context with the upstream reading the response body. +// * https://github.com/aws/aws-sdk-go-v2/issues/1253 +func TestDeserailizeResponse_cacheBody(t *testing.T) { + type Output struct { + Content io.ReadCloser + } + m := &deserializeResponse{ + GetOutput: func(resp *smithyhttp.Response) (interface{}, error) { + return &Output{ + Content: resp.Body, + }, nil + }, + } + + expectBody := "hello world!" + originalBody := &bytesReader{ + reader: strings.NewReader(expectBody), + } + if originalBody.closed { + t.Fatalf("expect original body not to be closed yet") + } + + out, _, err := m.HandleDeserialize(context.Background(), middleware.DeserializeInput{}, + middleware.DeserializeHandlerFunc(func( + ctx context.Context, input middleware.DeserializeInput, + ) ( + out middleware.DeserializeOutput, metadata middleware.Metadata, err error, + ) { + out.RawResponse = &smithyhttp.Response{ + Response: &http.Response{ + StatusCode: 200, + Status: "200 OK", + Header: http.Header{}, + ContentLength: int64(originalBody.Len()), + Body: originalBody, + }, + } + return out, metadata, nil + })) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + + if !originalBody.closed { + t.Errorf("expect original body to be closed, was not") + } + + result, ok := out.Result.(*Output) + if !ok { + t.Fatalf("expect result to be Output, got %T, %v", result, result) + } + + actualBody, err := ioutil.ReadAll(result.Content) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + if e, a := expectBody, string(actualBody); e != a { + t.Errorf("expect %v body, got %v", e, a) + } + if err := result.Content.Close(); err != nil { + t.Fatalf("expect no error, got %v", err) + } +} + +type bytesReader struct { + reader interface { + io.Reader + Len() int + } + closed bool +} + +func (r *bytesReader) Len() int { + return r.reader.Len() +} +func (r *bytesReader) Close() error { + r.closed = true + return nil +} +func (r *bytesReader) Read(p []byte) (int, error) { + if r.closed { + return 0, io.EOF + } + return r.reader.Read(p) +} + type successAPIResponseHandler struct { t *testing.T path string