Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding Seek() method and updating payload() #13703

Merged
merged 3 commits into from Nov 20, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
23 changes: 22 additions & 1 deletion sdk/azcore/policy_body_download.go
Expand Up @@ -6,6 +6,7 @@
package azcore

import (
"errors"
"fmt"
"io"
"io/ioutil"
Expand Down Expand Up @@ -71,7 +72,7 @@ type bodyDownloadPolicyOpValues struct {
skip bool
}

// nopClosingBytesReader is an io.ReadCloser around a byte slice.
// nopClosingBytesReader is an io.ReadSeekCloser around a byte slice.
// It also provides direct access to the byte slice.
type nopClosingBytesReader struct {
s []byte
Expand Down Expand Up @@ -103,3 +104,23 @@ func (r *nopClosingBytesReader) Set(b []byte) {
r.s = b
r.i = 0
}

// Seek implements the io.Seeker interface.
func (r *nopClosingBytesReader) Seek(offset int64, whence int) (int64, error) {
var i int64
switch whence {
case io.SeekStart:
i = offset
case io.SeekCurrent:
i = r.i + offset
case io.SeekEnd:
i = int64(len(r.s)) + offset
default:
return 0, errors.New("nopClosingBytesReader: invalid whence")
}
if i < 0 {
return 0, errors.New("nopClosingBytesReader: negative position")
}
r.i = i
return i, nil
}
128 changes: 106 additions & 22 deletions sdk/azcore/policy_body_download_test.go
Expand Up @@ -7,6 +7,7 @@ package azcore

import (
"context"
"io"
"net/http"
"testing"

Expand All @@ -28,11 +29,15 @@ func TestDownloadBody(t *testing.T) {
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(resp.payload()) == 0 {
payload, err := resp.payload()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(payload) == 0 {
t.Fatal("missing payload")
}
if string(resp.payload()) != message {
t.Fatalf("unexpected response: %s", string(resp.payload()))
if string(payload) != message {
t.Fatalf("unexpected response: %s", string(payload))
}
}

Expand All @@ -52,8 +57,12 @@ func TestSkipBodyDownload(t *testing.T) {
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(resp.payload()) > 0 {
t.Fatalf("unexpected download: %s", string(resp.payload()))
payload, err := resp.payload()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if string(payload) != message {
t.Fatalf("unexpected body: %s", string(payload))
}
}

Expand All @@ -71,7 +80,11 @@ func TestDownloadBodyFail(t *testing.T) {
if err == nil {
t.Fatal("unexpected nil error")
}
if resp.payload() != nil {
payload, err := resp.payload()
if err == nil {
t.Fatalf("expected an error")
}
if payload != nil {
t.Fatal("expected nil payload")
}
}
Expand All @@ -93,11 +106,15 @@ func TestDownloadBodyWithRetryGet(t *testing.T) {
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(resp.payload()) == 0 {
payload, err := resp.payload()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(payload) == 0 {
t.Fatal("missing payload")
}
if string(resp.payload()) != message {
t.Fatalf("unexpected response: %s", string(resp.payload()))
if string(payload) != message {
t.Fatalf("unexpected response: %s", string(payload))
}
if r := srv.Requests(); r != 3 {
t.Fatalf("expected %d requests, got %d", 3, r)
Expand All @@ -121,11 +138,15 @@ func TestDownloadBodyWithRetryDelete(t *testing.T) {
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(resp.payload()) == 0 {
payload, err := resp.payload()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(payload) == 0 {
t.Fatal("missing payload")
}
if string(resp.payload()) != message {
t.Fatalf("unexpected response: %s", string(resp.payload()))
if string(payload) != message {
t.Fatalf("unexpected response: %s", string(payload))
}
if r := srv.Requests(); r != 3 {
t.Fatalf("expected %d requests, got %d", 3, r)
Expand All @@ -149,11 +170,15 @@ func TestDownloadBodyWithRetryPut(t *testing.T) {
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(resp.payload()) == 0 {
payload, err := resp.payload()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(payload) == 0 {
t.Fatal("missing payload")
}
if string(resp.payload()) != message {
t.Fatalf("unexpected response: %s", string(resp.payload()))
if string(payload) != message {
t.Fatalf("unexpected response: %s", string(payload))
}
if r := srv.Requests(); r != 3 {
t.Fatalf("expected %d requests, got %d", 3, r)
Expand All @@ -180,7 +205,11 @@ func TestDownloadBodyWithRetryPatch(t *testing.T) {
if _, ok := err.(*bodyDownloadError); !ok {
t.Fatal("expected *bodyDownloadError type")
}
if len(resp.payload()) != 0 {
payload, err := resp.payload()
if err == nil {
t.Fatalf("expected an error")
}
if len(payload) != 0 {
t.Fatal("unexpected payload")
}
// should be only one request, no retires
Expand All @@ -206,10 +235,11 @@ func TestDownloadBodyWithRetryPost(t *testing.T) {
if err == nil {
t.Fatal("unexpected nil error")
}
if _, ok := err.(*bodyDownloadError); !ok {
t.Fatal("expected *bodyDownloadError type")
payload, err := resp.payload()
if err == nil {
t.Fatalf("expected an error")
}
if len(resp.payload()) != 0 {
if len(payload) != 0 {
t.Fatal("unexpected payload")
}
// should be only one request, no retires
Expand All @@ -234,10 +264,64 @@ func TestSkipBodyDownloadWith400(t *testing.T) {
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(resp.payload()) == 0 {
payload, err := resp.payload()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(payload) == 0 {
t.Fatal("missing payload")
}
if string(resp.payload()) != message {
t.Fatalf("unexpected response: %s", string(resp.payload()))
if string(payload) != message {
t.Fatalf("unexpected response: %s", string(payload))
}
}

func TestReadBodyAfterSeek(t *testing.T) {
const message = "downloaded"
srv, close := mock.NewServer()
defer close()
srv.AppendResponse(mock.WithBody([]byte(message)))
srv.AppendResponse(mock.WithBody([]byte(message)))
// download policy is automatically added during pipeline construction
pl := NewPipeline(srv, NewRetryPolicy(testRetryOptions()))
req, err := NewRequest(context.Background(), http.MethodGet, srv.URL())
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
resp, err := pl.Do(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
payload, err := resp.payload()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if string(payload) != message {
t.Fatal("incorrect payload")
}
nb, ok := resp.Body.(*nopClosingBytesReader)
if !ok {
t.Fatalf("unexpected body type: %t", resp.Body)
}
i, err := nb.Seek(0, io.SeekStart)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if i != 0 {
t.Fatalf("did not seek correctly")
}
i, err = nb.Seek(5, io.SeekCurrent)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if i != 5 {
t.Fatalf("did not seek correctly")
}
i, err = nb.Seek(5, io.SeekCurrent)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if i != 10 {
t.Fatalf("did not seek correctly")
}
}
60 changes: 43 additions & 17 deletions sdk/azcore/response.go
Expand Up @@ -26,15 +26,18 @@ type Response struct {
*http.Response
}

func (r *Response) payload() []byte {
if r.Body == nil {
return nil
}
func (r *Response) payload() ([]byte, error) {
// r.Body won't be a nopClosingBytesReader if downloading was skipped
if buf, ok := r.Body.(*nopClosingBytesReader); ok {
return buf.Bytes()
return buf.Bytes(), nil
}
return nil
bytesBody, err := ioutil.ReadAll(r.Body)
jhendrixMSFT marked this conversation as resolved.
Show resolved Hide resolved
r.Body.Close()
if err != nil {
return nil, err
}
r.Body = &nopClosingBytesReader{s: bytesBody, i: 0}
return bytesBody, nil
jhendrixMSFT marked this conversation as resolved.
Show resolved Hide resolved
}

// HasStatusCode returns true if the Response's status code is one of the specified values.
Expand All @@ -52,10 +55,14 @@ func (r *Response) HasStatusCode(statusCodes ...int) bool {

// UnmarshalAsByteArray will base-64 decode the received payload and place the result into the value pointed to by v.
func (r *Response) UnmarshalAsByteArray(v **[]byte, format Base64Encoding) error {
if len(r.payload()) == 0 {
p, err := r.payload()
if err != nil {
return err
}
if len(p) == 0 {
return nil
}
payload := string(r.payload())
payload := string(p)
if payload[0] == '"' {
// remove surrounding quotes
payload = payload[1 : len(payload)-1]
Expand Down Expand Up @@ -84,12 +91,19 @@ func (r *Response) UnmarshalAsByteArray(v **[]byte, format Base64Encoding) error
// UnmarshalAsJSON calls json.Unmarshal() to unmarshal the received payload into the value pointed to by v.
// If no payload was received a RequestError is returned. If json.Unmarshal fails a UnmarshalError is returned.
func (r *Response) UnmarshalAsJSON(v interface{}) error {
jhendrixMSFT marked this conversation as resolved.
Show resolved Hide resolved
payload, err := r.payload()
if err != nil {
return err
}
// TODO: verify early exit is correct
if len(r.payload()) == 0 {
if len(payload) == 0 {
return nil
}
r.removeBOM()
err := json.Unmarshal(r.payload(), v)
err = r.removeBOM()
if err != nil {
return err
}
err = json.Unmarshal(payload, v)
if err != nil {
err = fmt.Errorf("unmarshalling type %s: %w", reflect.TypeOf(v).Elem().Name(), err)
}
Expand All @@ -99,12 +113,19 @@ func (r *Response) UnmarshalAsJSON(v interface{}) error {
// UnmarshalAsXML calls xml.Unmarshal() to unmarshal the received payload into the value pointed to by v.
// If no payload was received a RequestError is returned. If xml.Unmarshal fails a UnmarshalError is returned.
func (r *Response) UnmarshalAsXML(v interface{}) error {
payload, err := r.payload()
if err != nil {
return err
}
// TODO: verify early exit is correct
if len(r.payload()) == 0 {
if len(payload) == 0 {
return nil
}
r.removeBOM()
err := xml.Unmarshal(r.payload(), v)
err = r.removeBOM()
if err != nil {
return err
}
err = xml.Unmarshal(payload, v)
if err != nil {
err = fmt.Errorf("unmarshalling type %s: %w", reflect.TypeOf(v).Elem().Name(), err)
}
Expand All @@ -120,12 +141,17 @@ func (r *Response) Drain() {
}

// removeBOM removes any byte-order mark prefix from the payload if present.
func (r *Response) removeBOM() {
func (r *Response) removeBOM() error {
jhendrixMSFT marked this conversation as resolved.
Show resolved Hide resolved
payload, err := r.payload()
if err != nil {
return err
}
// UTF8
trimmed := bytes.TrimPrefix(r.payload(), []byte("\xef\xbb\xbf"))
if len(trimmed) < len(r.payload()) {
trimmed := bytes.TrimPrefix(payload, []byte("\xef\xbb\xbf"))
if len(trimmed) < len(payload) {
r.Body.(*nopClosingBytesReader).Set(trimmed)
}
return nil
}

// helper to reduce nil Response checks
Expand Down
26 changes: 26 additions & 0 deletions sdk/azcore/response_test.go
Expand Up @@ -83,6 +83,32 @@ func TestResponseUnmarshalJSON(t *testing.T) {
}
}

func TestResponseUnmarshalJSONskipDownload(t *testing.T) {
srv, close := mock.NewServer()
defer close()
srv.SetResponse(mock.WithBody([]byte(`{ "someInt": 1, "someString": "s" }`)))
pl := NewPipeline(srv)
req, err := NewRequest(context.Background(), http.MethodGet, srv.URL())
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
req.SkipBodyDownload()
resp, err := pl.Do(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !resp.HasStatusCode(http.StatusOK) {
t.Fatalf("unexpected status code: %d", resp.StatusCode)
}
var tx testJSON
if err := resp.UnmarshalAsJSON(&tx); err != nil {
t.Fatalf("unexpected error unmarshalling: %v", err)
}
if tx.SomeInt != 1 || tx.SomeString != "s" {
t.Fatal("unexpected value")
}
}

func TestResponseUnmarshalJSONNoBody(t *testing.T) {
srv, close := mock.NewServer()
defer close()
Expand Down