Skip to content

Commit

Permalink
Adding Seek() method and updating payload() (#13703)
Browse files Browse the repository at this point in the history
* Adding Seek() method and updating payload()

* updates based on comments

* add tests
  • Loading branch information
catalinaperalta committed Nov 20, 2020
1 parent 8d7d6ee commit c1bff18
Show file tree
Hide file tree
Showing 4 changed files with 197 additions and 40 deletions.
23 changes: 22 additions & 1 deletion sdk/azcore/policy_body_download.go
Expand Up @@ -6,6 +6,7 @@
package azcore

import (
"errors"
"fmt"
"io"
"io/ioutil"
Expand Down Expand Up @@ -71,7 +72,7 @@ type bodyDownloadPolicyOpValues struct {
skip bool
}

// nopClosingBytesReader is an io.ReadCloser around a byte slice.
// nopClosingBytesReader is an io.ReadSeekCloser around a byte slice.
// It also provides direct access to the byte slice.
type nopClosingBytesReader struct {
s []byte
Expand Down Expand Up @@ -103,3 +104,23 @@ func (r *nopClosingBytesReader) Set(b []byte) {
r.s = b
r.i = 0
}

// Seek implements the io.Seeker interface.
func (r *nopClosingBytesReader) Seek(offset int64, whence int) (int64, error) {
var i int64
switch whence {
case io.SeekStart:
i = offset
case io.SeekCurrent:
i = r.i + offset
case io.SeekEnd:
i = int64(len(r.s)) + offset
default:
return 0, errors.New("nopClosingBytesReader: invalid whence")
}
if i < 0 {
return 0, errors.New("nopClosingBytesReader: negative position")
}
r.i = i
return i, nil
}
128 changes: 106 additions & 22 deletions sdk/azcore/policy_body_download_test.go
Expand Up @@ -7,6 +7,7 @@ package azcore

import (
"context"
"io"
"net/http"
"testing"

Expand All @@ -28,11 +29,15 @@ func TestDownloadBody(t *testing.T) {
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(resp.payload()) == 0 {
payload, err := resp.payload()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(payload) == 0 {
t.Fatal("missing payload")
}
if string(resp.payload()) != message {
t.Fatalf("unexpected response: %s", string(resp.payload()))
if string(payload) != message {
t.Fatalf("unexpected response: %s", string(payload))
}
}

Expand All @@ -52,8 +57,12 @@ func TestSkipBodyDownload(t *testing.T) {
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(resp.payload()) > 0 {
t.Fatalf("unexpected download: %s", string(resp.payload()))
payload, err := resp.payload()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if string(payload) != message {
t.Fatalf("unexpected body: %s", string(payload))
}
}

Expand All @@ -71,7 +80,11 @@ func TestDownloadBodyFail(t *testing.T) {
if err == nil {
t.Fatal("unexpected nil error")
}
if resp.payload() != nil {
payload, err := resp.payload()
if err == nil {
t.Fatalf("expected an error")
}
if payload != nil {
t.Fatal("expected nil payload")
}
}
Expand All @@ -93,11 +106,15 @@ func TestDownloadBodyWithRetryGet(t *testing.T) {
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(resp.payload()) == 0 {
payload, err := resp.payload()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(payload) == 0 {
t.Fatal("missing payload")
}
if string(resp.payload()) != message {
t.Fatalf("unexpected response: %s", string(resp.payload()))
if string(payload) != message {
t.Fatalf("unexpected response: %s", string(payload))
}
if r := srv.Requests(); r != 3 {
t.Fatalf("expected %d requests, got %d", 3, r)
Expand All @@ -121,11 +138,15 @@ func TestDownloadBodyWithRetryDelete(t *testing.T) {
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(resp.payload()) == 0 {
payload, err := resp.payload()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(payload) == 0 {
t.Fatal("missing payload")
}
if string(resp.payload()) != message {
t.Fatalf("unexpected response: %s", string(resp.payload()))
if string(payload) != message {
t.Fatalf("unexpected response: %s", string(payload))
}
if r := srv.Requests(); r != 3 {
t.Fatalf("expected %d requests, got %d", 3, r)
Expand All @@ -149,11 +170,15 @@ func TestDownloadBodyWithRetryPut(t *testing.T) {
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(resp.payload()) == 0 {
payload, err := resp.payload()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(payload) == 0 {
t.Fatal("missing payload")
}
if string(resp.payload()) != message {
t.Fatalf("unexpected response: %s", string(resp.payload()))
if string(payload) != message {
t.Fatalf("unexpected response: %s", string(payload))
}
if r := srv.Requests(); r != 3 {
t.Fatalf("expected %d requests, got %d", 3, r)
Expand All @@ -180,7 +205,11 @@ func TestDownloadBodyWithRetryPatch(t *testing.T) {
if _, ok := err.(*bodyDownloadError); !ok {
t.Fatal("expected *bodyDownloadError type")
}
if len(resp.payload()) != 0 {
payload, err := resp.payload()
if err == nil {
t.Fatalf("expected an error")
}
if len(payload) != 0 {
t.Fatal("unexpected payload")
}
// should be only one request, no retires
Expand All @@ -206,10 +235,11 @@ func TestDownloadBodyWithRetryPost(t *testing.T) {
if err == nil {
t.Fatal("unexpected nil error")
}
if _, ok := err.(*bodyDownloadError); !ok {
t.Fatal("expected *bodyDownloadError type")
payload, err := resp.payload()
if err == nil {
t.Fatalf("expected an error")
}
if len(resp.payload()) != 0 {
if len(payload) != 0 {
t.Fatal("unexpected payload")
}
// should be only one request, no retires
Expand All @@ -234,10 +264,64 @@ func TestSkipBodyDownloadWith400(t *testing.T) {
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(resp.payload()) == 0 {
payload, err := resp.payload()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(payload) == 0 {
t.Fatal("missing payload")
}
if string(resp.payload()) != message {
t.Fatalf("unexpected response: %s", string(resp.payload()))
if string(payload) != message {
t.Fatalf("unexpected response: %s", string(payload))
}
}

func TestReadBodyAfterSeek(t *testing.T) {
const message = "downloaded"
srv, close := mock.NewServer()
defer close()
srv.AppendResponse(mock.WithBody([]byte(message)))
srv.AppendResponse(mock.WithBody([]byte(message)))
// download policy is automatically added during pipeline construction
pl := NewPipeline(srv, NewRetryPolicy(testRetryOptions()))
req, err := NewRequest(context.Background(), http.MethodGet, srv.URL())
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
resp, err := pl.Do(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
payload, err := resp.payload()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if string(payload) != message {
t.Fatal("incorrect payload")
}
nb, ok := resp.Body.(*nopClosingBytesReader)
if !ok {
t.Fatalf("unexpected body type: %t", resp.Body)
}
i, err := nb.Seek(0, io.SeekStart)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if i != 0 {
t.Fatalf("did not seek correctly")
}
i, err = nb.Seek(5, io.SeekCurrent)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if i != 5 {
t.Fatalf("did not seek correctly")
}
i, err = nb.Seek(5, io.SeekCurrent)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if i != 10 {
t.Fatalf("did not seek correctly")
}
}
60 changes: 43 additions & 17 deletions sdk/azcore/response.go
Expand Up @@ -26,15 +26,18 @@ type Response struct {
*http.Response
}

func (r *Response) payload() []byte {
if r.Body == nil {
return nil
}
func (r *Response) payload() ([]byte, error) {
// r.Body won't be a nopClosingBytesReader if downloading was skipped
if buf, ok := r.Body.(*nopClosingBytesReader); ok {
return buf.Bytes()
return buf.Bytes(), nil
}
return nil
bytesBody, err := ioutil.ReadAll(r.Body)
r.Body.Close()
if err != nil {
return nil, err
}
r.Body = &nopClosingBytesReader{s: bytesBody, i: 0}
return bytesBody, nil
}

// HasStatusCode returns true if the Response's status code is one of the specified values.
Expand All @@ -52,10 +55,14 @@ func (r *Response) HasStatusCode(statusCodes ...int) bool {

// UnmarshalAsByteArray will base-64 decode the received payload and place the result into the value pointed to by v.
func (r *Response) UnmarshalAsByteArray(v **[]byte, format Base64Encoding) error {
if len(r.payload()) == 0 {
p, err := r.payload()
if err != nil {
return err
}
if len(p) == 0 {
return nil
}
payload := string(r.payload())
payload := string(p)
if payload[0] == '"' {
// remove surrounding quotes
payload = payload[1 : len(payload)-1]
Expand Down Expand Up @@ -84,12 +91,19 @@ func (r *Response) UnmarshalAsByteArray(v **[]byte, format Base64Encoding) error
// UnmarshalAsJSON calls json.Unmarshal() to unmarshal the received payload into the value pointed to by v.
// If no payload was received a RequestError is returned. If json.Unmarshal fails a UnmarshalError is returned.
func (r *Response) UnmarshalAsJSON(v interface{}) error {
payload, err := r.payload()
if err != nil {
return err
}
// TODO: verify early exit is correct
if len(r.payload()) == 0 {
if len(payload) == 0 {
return nil
}
r.removeBOM()
err := json.Unmarshal(r.payload(), v)
err = r.removeBOM()
if err != nil {
return err
}
err = json.Unmarshal(payload, v)
if err != nil {
err = fmt.Errorf("unmarshalling type %s: %w", reflect.TypeOf(v).Elem().Name(), err)
}
Expand All @@ -99,12 +113,19 @@ func (r *Response) UnmarshalAsJSON(v interface{}) error {
// UnmarshalAsXML calls xml.Unmarshal() to unmarshal the received payload into the value pointed to by v.
// If no payload was received a RequestError is returned. If xml.Unmarshal fails a UnmarshalError is returned.
func (r *Response) UnmarshalAsXML(v interface{}) error {
payload, err := r.payload()
if err != nil {
return err
}
// TODO: verify early exit is correct
if len(r.payload()) == 0 {
if len(payload) == 0 {
return nil
}
r.removeBOM()
err := xml.Unmarshal(r.payload(), v)
err = r.removeBOM()
if err != nil {
return err
}
err = xml.Unmarshal(payload, v)
if err != nil {
err = fmt.Errorf("unmarshalling type %s: %w", reflect.TypeOf(v).Elem().Name(), err)
}
Expand All @@ -120,12 +141,17 @@ func (r *Response) Drain() {
}

// removeBOM removes any byte-order mark prefix from the payload if present.
func (r *Response) removeBOM() {
func (r *Response) removeBOM() error {
payload, err := r.payload()
if err != nil {
return err
}
// UTF8
trimmed := bytes.TrimPrefix(r.payload(), []byte("\xef\xbb\xbf"))
if len(trimmed) < len(r.payload()) {
trimmed := bytes.TrimPrefix(payload, []byte("\xef\xbb\xbf"))
if len(trimmed) < len(payload) {
r.Body.(*nopClosingBytesReader).Set(trimmed)
}
return nil
}

// helper to reduce nil Response checks
Expand Down
26 changes: 26 additions & 0 deletions sdk/azcore/response_test.go
Expand Up @@ -83,6 +83,32 @@ func TestResponseUnmarshalJSON(t *testing.T) {
}
}

func TestResponseUnmarshalJSONskipDownload(t *testing.T) {
srv, close := mock.NewServer()
defer close()
srv.SetResponse(mock.WithBody([]byte(`{ "someInt": 1, "someString": "s" }`)))
pl := NewPipeline(srv)
req, err := NewRequest(context.Background(), http.MethodGet, srv.URL())
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
req.SkipBodyDownload()
resp, err := pl.Do(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !resp.HasStatusCode(http.StatusOK) {
t.Fatalf("unexpected status code: %d", resp.StatusCode)
}
var tx testJSON
if err := resp.UnmarshalAsJSON(&tx); err != nil {
t.Fatalf("unexpected error unmarshalling: %v", err)
}
if tx.SomeInt != 1 || tx.SomeString != "s" {
t.Fatal("unexpected value")
}
}

func TestResponseUnmarshalJSONNoBody(t *testing.T) {
srv, close := mock.NewServer()
defer close()
Expand Down

0 comments on commit c1bff18

Please sign in to comment.