From 0d3bd7a3ee7b0eda33531db18f4aa404225381e0 Mon Sep 17 00:00:00 2001 From: Jason Del Ponte <961963+jasdel@users.noreply.github.com> Date: Wed, 6 Oct 2021 11:18:51 -0700 Subject: [PATCH] feature/ec2/imds: Fix Client's response handling and operation timeout race (#1448) Fixes #1253 race between reading a IMDS response, and the operationTimeout middleware cleaning up its timeout context. Changes the IMDS client to always buffer the response body received, before the result is deserialized. This ensures that the consumer of the operation's response body will not race with context cleanup within the middleware stack. Updates the IMDS Client operations to not override the passed in Context's Deadline or Timeout options. If an Client operation is called with a Context with a Deadline or Timeout, the client will no longer override it with the client's default timeout. Updates operationTimeout so that if DefaultTimeout is unset (aka zero) operationTimeout will not set a default timeout on the context. --- .../17ac89419cd94e598fcc93444709cc1a.json | 8 + .../53dad1d685864ddfb0304829c2a45e4c.json | 8 + feature/ec2/imds/doc.go | 5 + feature/ec2/imds/request_middleware.go | 40 +++-- feature/ec2/imds/request_middleware_test.go | 145 +++++++++++++++++- 5 files changed, 196 insertions(+), 10 deletions(-) create mode 100644 .changelog/17ac89419cd94e598fcc93444709cc1a.json create mode 100644 .changelog/53dad1d685864ddfb0304829c2a45e4c.json diff --git a/.changelog/17ac89419cd94e598fcc93444709cc1a.json b/.changelog/17ac89419cd94e598fcc93444709cc1a.json new file mode 100644 index 00000000000..0d0ab8e2888 --- /dev/null +++ b/.changelog/17ac89419cd94e598fcc93444709cc1a.json @@ -0,0 +1,8 @@ +{ + "id": "17ac8941-9cd9-4e59-8fcc-93444709cc1a", + "type": "feature", + "description": "Respect passed in Context Deadline/Timeout. Updates the IMDS Client operations to not override the passed in Context's Deadline or Timeout options. If an Client operation is called with a Context with a Deadline or Timeout, the client will no longer override it with the client's default timeout.", + "modules": [ + "feature/ec2/imds" + ] +} \ No newline at end of file diff --git a/.changelog/53dad1d685864ddfb0304829c2a45e4c.json b/.changelog/53dad1d685864ddfb0304829c2a45e4c.json new file mode 100644 index 00000000000..c25231e2a2f --- /dev/null +++ b/.changelog/53dad1d685864ddfb0304829c2a45e4c.json @@ -0,0 +1,8 @@ +{ + "id": "53dad1d6-8586-4ddf-b030-4829c2a45e4c", + "type": "bugfix", + "description": "Fix IMDS client's response handling and operation timeout race. Fixes #1253", + "modules": [ + "feature/ec2/imds" + ] +} \ No newline at end of file diff --git a/feature/ec2/imds/doc.go b/feature/ec2/imds/doc.go index 9ae608291c1..bacdb5d21f2 100644 --- a/feature/ec2/imds/doc.go +++ b/feature/ec2/imds/doc.go @@ -1,6 +1,11 @@ // Package imds provides the API client for interacting with the Amazon EC2 // Instance Metadata Service. // +// All Client operation calls have a default timeout. If the operation is not +// completed before this timeout expires, the operation will be canceled. This +// timeout can be overridden by providing Context with a timeout or deadline +// with calling the client's operations. +// // See the EC2 IMDS user guide for more information on using the API. // https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/ec2-instance-metadata.html package imds diff --git a/feature/ec2/imds/request_middleware.go b/feature/ec2/imds/request_middleware.go index 93f02405f99..605cbd13140 100644 --- a/feature/ec2/imds/request_middleware.go +++ b/feature/ec2/imds/request_middleware.go @@ -1,8 +1,10 @@ package imds import ( + "bytes" "context" "fmt" + "io/ioutil" "net/url" "path" "time" @@ -52,7 +54,7 @@ func addRequestMiddleware(stack *middleware.Stack, // Operation timeout err = stack.Initialize.Add(&operationTimeout{ - Timeout: defaultOperationTimeout, + DefaultTimeout: defaultOperationTimeout, }, middleware.Before) if err != nil { return err @@ -142,12 +144,20 @@ func (m *deserializeResponse) HandleDeserialize( resp, ok := out.RawResponse.(*smithyhttp.Response) if !ok { return out, metadata, fmt.Errorf( - "unexpected transport response type, %T", out.RawResponse) + "unexpected transport response type, %T, want %T", out.RawResponse, resp) } + defer resp.Body.Close() - // Anything thats not 200 |< 300 is error + // read the full body so that any operation timeouts cleanup will not race + // the body being read. + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return out, metadata, fmt.Errorf("read response body failed, %w", err) + } + resp.Body = ioutil.NopCloser(bytes.NewReader(body)) + + // Anything that's not 200 |< 300 is error if resp.StatusCode < 200 || resp.StatusCode >= 300 { - resp.Body.Close() return out, metadata, &smithyhttp.ResponseError{ Response: resp, Err: fmt.Errorf("request to EC2 IMDS failed"), @@ -213,8 +223,19 @@ const ( defaultOperationTimeout = 5 * time.Second ) +// operationTimeout adds a timeout on the middleware stack if the Context the +// stack was called with does not have a deadline. The next middleware must +// complete before the timeout, or the context will be canceled. +// +// If DefaultTimeout is zero, no default timeout will be used if the Context +// does not have a timeout. +// +// The next middleware must also ensure that any resources that are also +// canceled by the stack's context are completely consumed before returning. +// Otherwise the timeout cleanup will race the resource being consumed +// upstream. type operationTimeout struct { - Timeout time.Duration + DefaultTimeout time.Duration } func (*operationTimeout) ID() string { return "OperationTimeout" } @@ -224,10 +245,11 @@ func (m *operationTimeout) HandleInitialize( ) ( output middleware.InitializeOutput, metadata middleware.Metadata, err error, ) { - var cancelFn func() - - ctx, cancelFn = context.WithTimeout(ctx, m.Timeout) - defer cancelFn() + if _, ok := ctx.Deadline(); !ok && m.DefaultTimeout != 0 { + var cancelFn func() + ctx, cancelFn = context.WithTimeout(ctx, m.DefaultTimeout) + defer cancelFn() + } return next.HandleInitialize(ctx, input) } diff --git a/feature/ec2/imds/request_middleware_test.go b/feature/ec2/imds/request_middleware_test.go index 629947e096a..8aee6fb93bc 100644 --- a/feature/ec2/imds/request_middleware_test.go +++ b/feature/ec2/imds/request_middleware_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/hex" + "fmt" "io" "io/ioutil" "net/http" @@ -126,7 +127,7 @@ func TestAddRequestMiddleware(t *testing.T) { func TestOperationTimeoutMiddleware(t *testing.T) { m := &operationTimeout{ - Timeout: time.Nanosecond, + DefaultTimeout: time.Nanosecond, } _, _, err := m.HandleInitialize(context.Background(), middleware.InitializeInput{}, @@ -135,6 +136,10 @@ func TestOperationTimeoutMiddleware(t *testing.T) { ) ( out middleware.InitializeOutput, metadata middleware.Metadata, err error, ) { + if _, ok := ctx.Deadline(); !ok { + return out, metadata, fmt.Errorf("expect context deadline to be set") + } + if err := sdk.SleepWithContext(ctx, time.Second); err != nil { return out, metadata, err } @@ -150,6 +155,144 @@ func TestOperationTimeoutMiddleware(t *testing.T) { } } +func TestOperationTimeoutMiddleware_noDefaultTimeout(t *testing.T) { + m := &operationTimeout{} + + _, _, err := m.HandleInitialize(context.Background(), middleware.InitializeInput{}, + middleware.InitializeHandlerFunc(func( + ctx context.Context, input middleware.InitializeInput, + ) ( + out middleware.InitializeOutput, metadata middleware.Metadata, err error, + ) { + if t, ok := ctx.Deadline(); ok { + return out, metadata, fmt.Errorf("expect no context deadline, got %v", t) + } + + return out, metadata, nil + })) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } +} + +func TestOperationTimeoutMiddleware_withCustomDeadline(t *testing.T) { + m := &operationTimeout{ + DefaultTimeout: time.Nanosecond, + } + + expectDeadline := time.Now().Add(time.Hour) + ctx, cancelFn := context.WithDeadline(context.Background(), expectDeadline) + defer cancelFn() + + _, _, err := m.HandleInitialize(ctx, middleware.InitializeInput{}, + middleware.InitializeHandlerFunc(func( + ctx context.Context, input middleware.InitializeInput, + ) ( + out middleware.InitializeOutput, metadata middleware.Metadata, err error, + ) { + t, ok := ctx.Deadline() + if !ok { + return out, metadata, fmt.Errorf("expect context deadline to be set") + } + if e, a := expectDeadline, t; !e.Equal(a) { + return out, metadata, fmt.Errorf("expect %v deadline, got %v", e, a) + } + + return out, metadata, nil + })) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } +} + +// Ensure that the response body is read in the deserialize middleware, +// ensuring that the timeoutOperation middleware won't race canceling the +// context with the upstream reading the response body. +// * https://github.com/aws/aws-sdk-go-v2/issues/1253 +func TestDeserailizeResponse_cacheBody(t *testing.T) { + type Output struct { + Content io.ReadCloser + } + m := &deserializeResponse{ + GetOutput: func(resp *smithyhttp.Response) (interface{}, error) { + return &Output{ + Content: resp.Body, + }, nil + }, + } + + expectBody := "hello world!" + originalBody := &bytesReader{ + reader: strings.NewReader(expectBody), + } + if originalBody.closed { + t.Fatalf("expect original body not to be closed yet") + } + + out, _, err := m.HandleDeserialize(context.Background(), middleware.DeserializeInput{}, + middleware.DeserializeHandlerFunc(func( + ctx context.Context, input middleware.DeserializeInput, + ) ( + out middleware.DeserializeOutput, metadata middleware.Metadata, err error, + ) { + out.RawResponse = &smithyhttp.Response{ + Response: &http.Response{ + StatusCode: 200, + Status: "200 OK", + Header: http.Header{}, + ContentLength: int64(originalBody.Len()), + Body: originalBody, + }, + } + return out, metadata, nil + })) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + + if !originalBody.closed { + t.Errorf("expect original body to be closed, was not") + } + + result, ok := out.Result.(*Output) + if !ok { + t.Fatalf("expect result to be Output, got %T, %v", result, result) + } + + actualBody, err := ioutil.ReadAll(result.Content) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + if e, a := expectBody, string(actualBody); e != a { + t.Errorf("expect %v body, got %v", e, a) + } + if err := result.Content.Close(); err != nil { + t.Fatalf("expect no error, got %v", err) + } +} + +type bytesReader struct { + reader interface { + io.Reader + Len() int + } + closed bool +} + +func (r *bytesReader) Len() int { + return r.reader.Len() +} +func (r *bytesReader) Close() error { + r.closed = true + return nil +} +func (r *bytesReader) Read(p []byte) (int, error) { + if r.closed { + return 0, io.EOF + } + return r.reader.Read(p) +} + type successAPIResponseHandler struct { t *testing.T path string