Skip to content

Commit

Permalink
Some refactoring in azcore (#6982)
Browse files Browse the repository at this point in the history
Added DefaultRetryOptions() to create initialized default options.
Removed Response.CheckStatusCode() as it can't create custom errors.
  • Loading branch information
jhendrixMSFT committed Jan 16, 2020
1 parent 68963ee commit 7a7120f
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 67 deletions.
44 changes: 16 additions & 28 deletions sdk/azcore/policy_retry.go
Expand Up @@ -59,33 +59,15 @@ var (
}
)

func (o RetryOptions) defaults() RetryOptions {
// We assume the following:
// 1. o.MaxTries >= 0
// 2. o.TryTimeout, o.RetryDelay, and o.MaxRetryDelay >=0
// 3. o.RetryDelay <= o.MaxRetryDelay
// 4. Both o.RetryDelay and o.MaxRetryDelay must be 0 or neither can be 0

if len(o.StatusCodes) == 0 {
o.StatusCodes = StatusCodesForRetry[:]
// DefaultRetryOptions returns an instance of RetryOptions initialized with default values.
func DefaultRetryOptions() RetryOptions {
return RetryOptions{
StatusCodes: StatusCodesForRetry[:],
MaxTries: defaultMaxTries,
TryTimeout: 1 * time.Minute,
RetryDelay: 4 * time.Second,
MaxRetryDelay: 120 * time.Second,
}

IfDefault := func(current *time.Duration, desired time.Duration) {
if *current == time.Duration(0) {
*current = desired
}
}

// Set defaults if unspecified
if o.MaxTries == 0 {
o.MaxTries = defaultMaxTries
}

IfDefault(&o.TryTimeout, 1*time.Minute)
IfDefault(&o.RetryDelay, 4*time.Second)
IfDefault(&o.MaxRetryDelay, 120*time.Second)

return o
}

func (o RetryOptions) calcDelay(try int32) time.Duration { // try is >=1; never 0
Expand All @@ -108,8 +90,14 @@ func (o RetryOptions) calcDelay(try int32) time.Duration { // try is >=1; never
}

// NewRetryPolicy creates a policy object configured using the specified options.
func NewRetryPolicy(o RetryOptions) Policy {
return &retryPolicy{options: o.defaults()} // Force defaults to be calculated
// Pass nil to accept the default values; this is the same as passing the result
// from a call to DefaultRetryOptions().
func NewRetryPolicy(o *RetryOptions) Policy {
if o == nil {
def := DefaultRetryOptions()
o = &def
}
return &retryPolicy{options: *o}
}

type retryPolicy struct {
Expand Down
22 changes: 12 additions & 10 deletions sdk/azcore/policy_retry_test.go
Expand Up @@ -17,13 +17,17 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/internal/mock"
)

const retryDelay = 20 * time.Millisecond
func testRetryOptions() *RetryOptions {
def := DefaultRetryOptions()
def.RetryDelay = 20 * time.Millisecond
return &def
}

func TestRetryPolicySuccess(t *testing.T) {
srv, close := mock.NewServer()
defer close()
srv.SetResponse(mock.WithStatusCode(http.StatusOK))
pl := NewPipeline(srv, NewRetryPolicy(RetryOptions{}))
pl := NewPipeline(srv, NewRetryPolicy(nil))
req := NewRequest(http.MethodGet, srv.URL())
body := newRewindTrackingBody("stuff")
req.SetBody(body)
Expand All @@ -46,7 +50,7 @@ func TestRetryPolicyFailOnStatusCode(t *testing.T) {
srv, close := mock.NewServer()
defer close()
srv.SetResponse(mock.WithStatusCode(http.StatusInternalServerError))
pl := NewPipeline(srv, NewRetryPolicy(RetryOptions{RetryDelay: retryDelay}))
pl := NewPipeline(srv, NewRetryPolicy(testRetryOptions()))
req := NewRequest(http.MethodGet, srv.URL())
body := newRewindTrackingBody("stuff")
req.SetBody(body)
Expand Down Expand Up @@ -74,7 +78,7 @@ func TestRetryPolicySuccessWithRetry(t *testing.T) {
srv.AppendResponse(mock.WithStatusCode(http.StatusRequestTimeout))
srv.AppendResponse(mock.WithStatusCode(http.StatusInternalServerError))
srv.AppendResponse()
pl := NewPipeline(srv, NewRetryPolicy(RetryOptions{RetryDelay: retryDelay}))
pl := NewPipeline(srv, NewRetryPolicy(testRetryOptions()))
req := NewRequest(http.MethodGet, srv.URL())
body := newRewindTrackingBody("stuff")
req.SetBody(body)
Expand All @@ -101,7 +105,7 @@ func TestRetryPolicyFailOnError(t *testing.T) {
defer close()
fakeErr := errors.New("bogus error")
srv.SetError(fakeErr)
pl := NewPipeline(srv, NewRetryPolicy(RetryOptions{RetryDelay: retryDelay}))
pl := NewPipeline(srv, NewRetryPolicy(testRetryOptions()))
req := NewRequest(http.MethodPost, srv.URL())
body := newRewindTrackingBody("stuff")
req.SetBody(body)
Expand Down Expand Up @@ -130,7 +134,7 @@ func TestRetryPolicySuccessWithRetryComplex(t *testing.T) {
srv.AppendError(errors.New("bogus error"))
srv.AppendResponse(mock.WithStatusCode(http.StatusInternalServerError))
srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted))
pl := NewPipeline(srv, NewRetryPolicy(RetryOptions{RetryDelay: retryDelay}))
pl := NewPipeline(srv, NewRetryPolicy(testRetryOptions()))
req := NewRequest(http.MethodGet, srv.URL())
body := newRewindTrackingBody("stuff")
req.SetBody(body)
Expand All @@ -156,7 +160,7 @@ func TestRetryPolicyRequestTimedOut(t *testing.T) {
srv, close := mock.NewServer()
defer close()
srv.SetError(errors.New("bogus error"))
pl := NewPipeline(srv, NewRetryPolicy(RetryOptions{}))
pl := NewPipeline(srv, NewRetryPolicy(nil))
req := NewRequest(http.MethodPost, srv.URL())
body := newRewindTrackingBody("stuff")
req.SetBody(body)
Expand Down Expand Up @@ -195,9 +199,7 @@ func TestRetryPolicyIsNotRetriable(t *testing.T) {
defer close()
srv.AppendResponse(mock.WithStatusCode(http.StatusRequestTimeout))
srv.AppendError(theErr)
pl := NewPipeline(srv, NewRetryPolicy(RetryOptions{
RetryDelay: retryDelay,
}))
pl := NewPipeline(srv, NewRetryPolicy(testRetryOptions()))
_, err := pl.Do(context.Background(), NewRequest(http.MethodGet, srv.URL()))
if err == nil {
t.Fatal("unexpected nil error")
Expand Down
12 changes: 0 additions & 12 deletions sdk/azcore/response.go
Expand Up @@ -36,18 +36,6 @@ func (r *Response) payload() []byte {
return nil
}

// CheckStatusCode returns a RequestError if the Response's status code isn't one of the specified values.
func (r *Response) CheckStatusCode(statusCodes ...int) error {
if !r.HasStatusCode(statusCodes...) {
msg := r.Status
if len(r.payload()) > 0 {
msg = string(r.payload())
}
return newRequestError(msg, r)
}
return nil
}

// HasStatusCode returns true if the Response's status code is one of the specified values.
func (r *Response) HasStatusCode(statusCodes ...int) bool {
if r == nil {
Expand Down
27 changes: 10 additions & 17 deletions sdk/azcore/response_test.go
Expand Up @@ -23,8 +23,8 @@ func TestResponseUnmarshalXML(t *testing.T) {
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if err := resp.CheckStatusCode(http.StatusOK); err != nil {
t.Fatalf("unexpected status code error: %v", err)
if !resp.HasStatusCode(http.StatusOK) {
t.Fatalf("unexpected status code: %d", resp.StatusCode)
}
var tx testXML
if err := resp.UnmarshalAsXML(&tx); err != nil {
Expand All @@ -44,15 +44,8 @@ func TestResponseFailureStatusCode(t *testing.T) {
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if err = resp.CheckStatusCode(http.StatusOK); err == nil {
t.Fatal("unexpected nil status code error")
}
re, ok := err.(RequestError)
if !ok {
t.Fatal("expected RequestError type")
}
if re.Response().StatusCode != http.StatusForbidden {
t.Fatal("unexpected response")
if resp.HasStatusCode(http.StatusOK) {
t.Fatalf("unexpected status code: %d", resp.StatusCode)
}
}

Expand All @@ -65,8 +58,8 @@ func TestResponseUnmarshalJSON(t *testing.T) {
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if err := resp.CheckStatusCode(http.StatusOK); err != nil {
t.Fatalf("unexpected status code error: %v", err)
if !resp.HasStatusCode(http.StatusOK) {
t.Fatalf("unexpected status code: %d", resp.StatusCode)
}
var tx testJSON
if err := resp.UnmarshalAsJSON(&tx); err != nil {
Expand All @@ -86,8 +79,8 @@ func TestResponseUnmarshalJSONNoBody(t *testing.T) {
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if err := resp.CheckStatusCode(http.StatusOK); err != nil {
t.Fatalf("unexpected status code error: %v", err)
if !resp.HasStatusCode(http.StatusOK) {
t.Fatalf("unexpected status code: %d", resp.StatusCode)
}
if err := resp.UnmarshalAsJSON(nil); err != nil {
t.Fatalf("unexpected error unmarshalling: %v", err)
Expand All @@ -103,8 +96,8 @@ func TestResponseUnmarshalXMLNoBody(t *testing.T) {
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if err := resp.CheckStatusCode(http.StatusOK); err != nil {
t.Fatalf("unexpected status code error: %v", err)
if !resp.HasStatusCode(http.StatusOK) {
t.Fatalf("unexpected status code: %d", resp.StatusCode)
}
if err := resp.UnmarshalAsXML(nil); err != nil {
t.Fatalf("unexpected error unmarshalling: %v", err)
Expand Down

0 comments on commit 7a7120f

Please sign in to comment.