diff --git a/sdk/azcore/policy_body_download.go b/sdk/azcore/policy_body_download.go index fbed7ee33c60..7087242dc4b4 100644 --- a/sdk/azcore/policy_body_download.go +++ b/sdk/azcore/policy_body_download.go @@ -6,6 +6,7 @@ package azcore import ( + "errors" "fmt" "io" "io/ioutil" @@ -71,7 +72,7 @@ type bodyDownloadPolicyOpValues struct { skip bool } -// nopClosingBytesReader is an io.ReadCloser around a byte slice. +// nopClosingBytesReader is an io.ReadSeekCloser around a byte slice. // It also provides direct access to the byte slice. type nopClosingBytesReader struct { s []byte @@ -103,3 +104,23 @@ func (r *nopClosingBytesReader) Set(b []byte) { r.s = b r.i = 0 } + +// Seek implements the io.Seeker interface. +func (r *nopClosingBytesReader) Seek(offset int64, whence int) (int64, error) { + var i int64 + switch whence { + case io.SeekStart: + i = offset + case io.SeekCurrent: + i = r.i + offset + case io.SeekEnd: + i = int64(len(r.s)) + offset + default: + return 0, errors.New("nopClosingBytesReader: invalid whence") + } + if i < 0 { + return 0, errors.New("nopClosingBytesReader: negative position") + } + r.i = i + return i, nil +} diff --git a/sdk/azcore/policy_body_download_test.go b/sdk/azcore/policy_body_download_test.go index 22ec6bcc79f1..a3eda8433c85 100644 --- a/sdk/azcore/policy_body_download_test.go +++ b/sdk/azcore/policy_body_download_test.go @@ -7,6 +7,7 @@ package azcore import ( "context" + "io" "net/http" "testing" @@ -28,11 +29,15 @@ func TestDownloadBody(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } - if len(resp.payload()) == 0 { + payload, err := resp.payload() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(payload) == 0 { t.Fatal("missing payload") } - if string(resp.payload()) != message { - t.Fatalf("unexpected response: %s", string(resp.payload())) + if string(payload) != message { + t.Fatalf("unexpected response: %s", string(payload)) } } @@ -52,8 +57,12 @@ func TestSkipBodyDownload(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } - if len(resp.payload()) > 0 { - t.Fatalf("unexpected download: %s", string(resp.payload())) + payload, err := resp.payload() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if string(payload) != message { + t.Fatalf("unexpected body: %s", string(payload)) } } @@ -71,7 +80,11 @@ func TestDownloadBodyFail(t *testing.T) { if err == nil { t.Fatal("unexpected nil error") } - if resp.payload() != nil { + payload, err := resp.payload() + if err == nil { + t.Fatalf("expected an error") + } + if payload != nil { t.Fatal("expected nil payload") } } @@ -93,11 +106,15 @@ func TestDownloadBodyWithRetryGet(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } - if len(resp.payload()) == 0 { + payload, err := resp.payload() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(payload) == 0 { t.Fatal("missing payload") } - if string(resp.payload()) != message { - t.Fatalf("unexpected response: %s", string(resp.payload())) + if string(payload) != message { + t.Fatalf("unexpected response: %s", string(payload)) } if r := srv.Requests(); r != 3 { t.Fatalf("expected %d requests, got %d", 3, r) @@ -121,11 +138,15 @@ func TestDownloadBodyWithRetryDelete(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } - if len(resp.payload()) == 0 { + payload, err := resp.payload() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(payload) == 0 { t.Fatal("missing payload") } - if string(resp.payload()) != message { - t.Fatalf("unexpected response: %s", string(resp.payload())) + if string(payload) != message { + t.Fatalf("unexpected response: %s", string(payload)) } if r := srv.Requests(); r != 3 { t.Fatalf("expected %d requests, got %d", 3, r) @@ -149,11 +170,15 @@ func TestDownloadBodyWithRetryPut(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } - if len(resp.payload()) == 0 { + payload, err := resp.payload() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(payload) == 0 { t.Fatal("missing payload") } - if string(resp.payload()) != message { - t.Fatalf("unexpected response: %s", string(resp.payload())) + if string(payload) != message { + t.Fatalf("unexpected response: %s", string(payload)) } if r := srv.Requests(); r != 3 { t.Fatalf("expected %d requests, got %d", 3, r) @@ -180,7 +205,11 @@ func TestDownloadBodyWithRetryPatch(t *testing.T) { if _, ok := err.(*bodyDownloadError); !ok { t.Fatal("expected *bodyDownloadError type") } - if len(resp.payload()) != 0 { + payload, err := resp.payload() + if err == nil { + t.Fatalf("expected an error") + } + if len(payload) != 0 { t.Fatal("unexpected payload") } // should be only one request, no retires @@ -206,10 +235,11 @@ func TestDownloadBodyWithRetryPost(t *testing.T) { if err == nil { t.Fatal("unexpected nil error") } - if _, ok := err.(*bodyDownloadError); !ok { - t.Fatal("expected *bodyDownloadError type") + payload, err := resp.payload() + if err == nil { + t.Fatalf("expected an error") } - if len(resp.payload()) != 0 { + if len(payload) != 0 { t.Fatal("unexpected payload") } // should be only one request, no retires @@ -234,10 +264,64 @@ func TestSkipBodyDownloadWith400(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } - if len(resp.payload()) == 0 { + payload, err := resp.payload() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(payload) == 0 { t.Fatal("missing payload") } - if string(resp.payload()) != message { - t.Fatalf("unexpected response: %s", string(resp.payload())) + if string(payload) != message { + t.Fatalf("unexpected response: %s", string(payload)) + } +} + +func TestReadBodyAfterSeek(t *testing.T) { + const message = "downloaded" + srv, close := mock.NewServer() + defer close() + srv.AppendResponse(mock.WithBody([]byte(message))) + srv.AppendResponse(mock.WithBody([]byte(message))) + // download policy is automatically added during pipeline construction + pl := NewPipeline(srv, NewRetryPolicy(testRetryOptions())) + req, err := NewRequest(context.Background(), http.MethodGet, srv.URL()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + resp, err := pl.Do(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + payload, err := resp.payload() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if string(payload) != message { + t.Fatal("incorrect payload") + } + nb, ok := resp.Body.(*nopClosingBytesReader) + if !ok { + t.Fatalf("unexpected body type: %t", resp.Body) + } + i, err := nb.Seek(0, io.SeekStart) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if i != 0 { + t.Fatalf("did not seek correctly") + } + i, err = nb.Seek(5, io.SeekCurrent) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if i != 5 { + t.Fatalf("did not seek correctly") + } + i, err = nb.Seek(5, io.SeekCurrent) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if i != 10 { + t.Fatalf("did not seek correctly") } } diff --git a/sdk/azcore/response.go b/sdk/azcore/response.go index ef6f6dfd734b..dcff907e5eb2 100644 --- a/sdk/azcore/response.go +++ b/sdk/azcore/response.go @@ -26,15 +26,18 @@ type Response struct { *http.Response } -func (r *Response) payload() []byte { - if r.Body == nil { - return nil - } +func (r *Response) payload() ([]byte, error) { // r.Body won't be a nopClosingBytesReader if downloading was skipped if buf, ok := r.Body.(*nopClosingBytesReader); ok { - return buf.Bytes() + return buf.Bytes(), nil } - return nil + bytesBody, err := ioutil.ReadAll(r.Body) + r.Body.Close() + if err != nil { + return nil, err + } + r.Body = &nopClosingBytesReader{s: bytesBody, i: 0} + return bytesBody, nil } // HasStatusCode returns true if the Response's status code is one of the specified values. @@ -52,10 +55,14 @@ func (r *Response) HasStatusCode(statusCodes ...int) bool { // UnmarshalAsByteArray will base-64 decode the received payload and place the result into the value pointed to by v. func (r *Response) UnmarshalAsByteArray(v **[]byte, format Base64Encoding) error { - if len(r.payload()) == 0 { + p, err := r.payload() + if err != nil { + return err + } + if len(p) == 0 { return nil } - payload := string(r.payload()) + payload := string(p) if payload[0] == '"' { // remove surrounding quotes payload = payload[1 : len(payload)-1] @@ -84,12 +91,19 @@ func (r *Response) UnmarshalAsByteArray(v **[]byte, format Base64Encoding) error // UnmarshalAsJSON calls json.Unmarshal() to unmarshal the received payload into the value pointed to by v. // If no payload was received a RequestError is returned. If json.Unmarshal fails a UnmarshalError is returned. func (r *Response) UnmarshalAsJSON(v interface{}) error { + payload, err := r.payload() + if err != nil { + return err + } // TODO: verify early exit is correct - if len(r.payload()) == 0 { + if len(payload) == 0 { return nil } - r.removeBOM() - err := json.Unmarshal(r.payload(), v) + err = r.removeBOM() + if err != nil { + return err + } + err = json.Unmarshal(payload, v) if err != nil { err = fmt.Errorf("unmarshalling type %s: %w", reflect.TypeOf(v).Elem().Name(), err) } @@ -99,12 +113,19 @@ func (r *Response) UnmarshalAsJSON(v interface{}) error { // UnmarshalAsXML calls xml.Unmarshal() to unmarshal the received payload into the value pointed to by v. // If no payload was received a RequestError is returned. If xml.Unmarshal fails a UnmarshalError is returned. func (r *Response) UnmarshalAsXML(v interface{}) error { + payload, err := r.payload() + if err != nil { + return err + } // TODO: verify early exit is correct - if len(r.payload()) == 0 { + if len(payload) == 0 { return nil } - r.removeBOM() - err := xml.Unmarshal(r.payload(), v) + err = r.removeBOM() + if err != nil { + return err + } + err = xml.Unmarshal(payload, v) if err != nil { err = fmt.Errorf("unmarshalling type %s: %w", reflect.TypeOf(v).Elem().Name(), err) } @@ -120,12 +141,17 @@ func (r *Response) Drain() { } // removeBOM removes any byte-order mark prefix from the payload if present. -func (r *Response) removeBOM() { +func (r *Response) removeBOM() error { + payload, err := r.payload() + if err != nil { + return err + } // UTF8 - trimmed := bytes.TrimPrefix(r.payload(), []byte("\xef\xbb\xbf")) - if len(trimmed) < len(r.payload()) { + trimmed := bytes.TrimPrefix(payload, []byte("\xef\xbb\xbf")) + if len(trimmed) < len(payload) { r.Body.(*nopClosingBytesReader).Set(trimmed) } + return nil } // helper to reduce nil Response checks diff --git a/sdk/azcore/response_test.go b/sdk/azcore/response_test.go index 254811cf98f4..ff5713e54c24 100644 --- a/sdk/azcore/response_test.go +++ b/sdk/azcore/response_test.go @@ -83,6 +83,32 @@ func TestResponseUnmarshalJSON(t *testing.T) { } } +func TestResponseUnmarshalJSONskipDownload(t *testing.T) { + srv, close := mock.NewServer() + defer close() + srv.SetResponse(mock.WithBody([]byte(`{ "someInt": 1, "someString": "s" }`))) + pl := NewPipeline(srv) + req, err := NewRequest(context.Background(), http.MethodGet, srv.URL()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + req.SkipBodyDownload() + resp, err := pl.Do(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !resp.HasStatusCode(http.StatusOK) { + t.Fatalf("unexpected status code: %d", resp.StatusCode) + } + var tx testJSON + if err := resp.UnmarshalAsJSON(&tx); err != nil { + t.Fatalf("unexpected error unmarshalling: %v", err) + } + if tx.SomeInt != 1 || tx.SomeString != "s" { + t.Fatal("unexpected value") + } +} + func TestResponseUnmarshalJSONNoBody(t *testing.T) { srv, close := mock.NewServer() defer close()