Skip to content

Commit

Permalink
Provide a custom retryablehttp.ErrorHandler for more consistent retur…
Browse files Browse the repository at this point in the history
…ns using retries. (#629)
  • Loading branch information
andrewsomething committed Aug 17, 2023
1 parent f97a523 commit 21863bf
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 26 deletions.
49 changes: 38 additions & 11 deletions godo.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"encoding/json"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/url"
"reflect"
Expand All @@ -27,9 +26,11 @@ const (
userAgent = "godo/" + libraryVersion
mediaType = "application/json"

headerRateLimit = "RateLimit-Limit"
headerRateRemaining = "RateLimit-Remaining"
headerRateReset = "RateLimit-Reset"
headerRateLimit = "RateLimit-Limit"
headerRateRemaining = "RateLimit-Remaining"
headerRateReset = "RateLimit-Reset"
headerRequestID = "x-request-id"
internalHeaderRetryAttempts = "X-Godo-Retry-Attempts"
)

// Client manages communication with DigitalOcean V2 API.
Expand Down Expand Up @@ -178,6 +179,9 @@ type ErrorResponse struct {

// RequestID returned from the API, useful to contact support.
RequestID string `json:"request_id"`

// Attempts is the number of times the request was attempted when retries are enabled.
Attempts int
}

// Rate contains the rate limit for the current client.
Expand Down Expand Up @@ -314,6 +318,19 @@ func New(httpClient *http.Client, opts ...ClientOpt) (*Client, error) {
// if timeout is set, it is maintained before overwriting client with StandardClient()
retryableClient.HTTPClient.Timeout = c.HTTPClient.Timeout

// This custom ErrorHandler is required to provide errors that are consistent
// with a *godo.ErrorResponse and a non-nil *godo.Response while providing
// insight into retries using an internal header.
retryableClient.ErrorHandler = func(resp *http.Response, err error, numTries int) (*http.Response, error) {
if resp != nil {
resp.Header.Add(internalHeaderRetryAttempts, strconv.Itoa(numTries))

return resp, err
}

return resp, err
}

var source *oauth2.Transport
if _, ok := c.HTTPClient.Transport.(*oauth2.Transport); ok {
source = c.HTTPClient.Transport.(*oauth2.Transport)
Expand Down Expand Up @@ -489,7 +506,7 @@ func (c *Client) Do(ctx context.Context, req *http.Request, v interface{}) (*Res
// won't reuse it anyway.
const maxBodySlurpSize = 2 << 10
if resp.ContentLength == -1 || resp.ContentLength <= maxBodySlurpSize {
io.CopyN(ioutil.Discard, resp.Body, maxBodySlurpSize)
io.CopyN(io.Discard, resp.Body, maxBodySlurpSize)
}

if rerr := resp.Body.Close(); err == nil {
Expand Down Expand Up @@ -539,12 +556,17 @@ func DoRequestWithClient(
}

func (r *ErrorResponse) Error() string {
var attempted string
if r.Attempts > 0 {
attempted = fmt.Sprintf("; giving up after %d attempt(s)", r.Attempts)
}

if r.RequestID != "" {
return fmt.Sprintf("%v %v: %d (request %q) %v",
r.Response.Request.Method, r.Response.Request.URL, r.Response.StatusCode, r.RequestID, r.Message)
return fmt.Sprintf("%v %v: %d (request %q) %v%s",
r.Response.Request.Method, r.Response.Request.URL, r.Response.StatusCode, r.RequestID, r.Message, attempted)
}
return fmt.Sprintf("%v %v: %d %v",
r.Response.Request.Method, r.Response.Request.URL, r.Response.StatusCode, r.Message)
return fmt.Sprintf("%v %v: %d %v%s",
r.Response.Request.Method, r.Response.Request.URL, r.Response.StatusCode, r.Message, attempted)
}

// CheckResponse checks the API response for errors, and returns them if present. A response is considered an
Expand All @@ -557,7 +579,7 @@ func CheckResponse(r *http.Response) error {
}

errorResponse := &ErrorResponse{Response: r}
data, err := ioutil.ReadAll(r.Body)
data, err := io.ReadAll(r.Body)
if err == nil && len(data) > 0 {
err := json.Unmarshal(data, errorResponse)
if err != nil {
Expand All @@ -566,7 +588,12 @@ func CheckResponse(r *http.Response) error {
}

if errorResponse.RequestID == "" {
errorResponse.RequestID = r.Header.Get("x-request-id")
errorResponse.RequestID = r.Header.Get(headerRequestID)
}

attempts, strconvErr := strconv.Atoi(r.Header.Get(internalHeaderRetryAttempts))
if strconvErr == nil {
errorResponse.Attempts = attempts
}

return errorResponse
Expand Down
94 changes: 79 additions & 15 deletions godo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import (
"bytes"
"context"
"fmt"
"io/ioutil"
"io"
"log"
"net/http"
"net/http/httptest"
Expand Down Expand Up @@ -191,7 +191,7 @@ func TestNewRequest(t *testing.T) {
}

// test body was JSON encoded
body, _ := ioutil.ReadAll(req.Body)
body, _ := io.ReadAll(req.Body)
if string(body) != outBody {
t.Errorf("NewRequest(%v)Body = %v, expected %v", inBody, string(body), outBody)
}
Expand Down Expand Up @@ -242,7 +242,7 @@ func TestNewRequest_withUserData(t *testing.T) {
}

// test body was JSON encoded
body, _ := ioutil.ReadAll(req.Body)
body, _ := io.ReadAll(req.Body)
if string(body) != outBody {
t.Errorf("NewRequest(%v)Body = %v, expected %v", inBody, string(body), outBody)
}
Expand Down Expand Up @@ -271,7 +271,7 @@ func TestNewRequest_withDropletAgent(t *testing.T) {
}

// test body was JSON encoded
body, _ := ioutil.ReadAll(req.Body)
body, _ := io.ReadAll(req.Body)
if string(body) != outBody {
t.Errorf("NewRequest(%v)Body = %v, expected %v", inBody, string(body), outBody)
}
Expand Down Expand Up @@ -406,7 +406,7 @@ func TestCheckResponse(t *testing.T) {
input: &http.Response{
Request: &http.Request{},
StatusCode: http.StatusBadRequest,
Body: ioutil.NopCloser(strings.NewReader(`{"message":"m",
Body: io.NopCloser(strings.NewReader(`{"message":"m",
"errors": [{"resource": "r", "field": "f", "code": "c"}]}`)),
},
expected: &ErrorResponse{
Expand All @@ -418,7 +418,7 @@ func TestCheckResponse(t *testing.T) {
input: &http.Response{
Request: &http.Request{},
StatusCode: http.StatusBadRequest,
Body: ioutil.NopCloser(strings.NewReader(`{"message":"m", "request_id": "dead-beef",
Body: io.NopCloser(strings.NewReader(`{"message":"m", "request_id": "dead-beef",
"errors": [{"resource": "r", "field": "f", "code": "c"}]}`)),
},
expected: &ErrorResponse{
Expand All @@ -432,7 +432,7 @@ func TestCheckResponse(t *testing.T) {
Request: &http.Request{},
StatusCode: http.StatusBadRequest,
Header: testHeaders,
Body: ioutil.NopCloser(strings.NewReader(`{"message":"m",
Body: io.NopCloser(strings.NewReader(`{"message":"m",
"errors": [{"resource": "r", "field": "f", "code": "c"}]}`)),
},
expected: &ErrorResponse{
Expand All @@ -448,7 +448,7 @@ func TestCheckResponse(t *testing.T) {
Request: &http.Request{},
StatusCode: http.StatusBadRequest,
Header: testHeaders,
Body: ioutil.NopCloser(strings.NewReader(`{"message":"m", "request_id": "dead-beef-body",
Body: io.NopCloser(strings.NewReader(`{"message":"m", "request_id": "dead-beef-body",
"errors": [{"resource": "r", "field": "f", "code": "c"}]}`)),
},
expected: &ErrorResponse{
Expand All @@ -463,7 +463,7 @@ func TestCheckResponse(t *testing.T) {
input: &http.Response{
Request: &http.Request{},
StatusCode: http.StatusBadRequest,
Body: ioutil.NopCloser(strings.NewReader("")),
Body: io.NopCloser(strings.NewReader("")),
},
expected: &ErrorResponse{},
},
Expand Down Expand Up @@ -614,14 +614,15 @@ func TestWithRetryAndBackoffs(t *testing.T) {

url, _ := url.Parse(server.URL)
mux.HandleFunc("/foo", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(500)
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(`{"id": "bad_request", "message": "broken"}`))
})

tokenSrc := oauth2.StaticTokenSource(&oauth2.Token{
AccessToken: "new_token",
})

oauth_client := oauth2.NewClient(oauth2.NoContext, tokenSrc)
oauthClient := oauth2.NewClient(oauth2.NoContext, tokenSrc)

waitMax := PtrTo(6.0)
waitMin := PtrTo(3.0)
Expand All @@ -633,7 +634,7 @@ func TestWithRetryAndBackoffs(t *testing.T) {
}

// Create the client. Use short retry windows so we fail faster.
client, err := New(oauth_client, WithRetryAndBackoffs(retryConfig))
client, err := New(oauthClient, WithRetryAndBackoffs(retryConfig))
client.BaseURL = url
if err != nil {
t.Fatalf("err: %v", err)
Expand All @@ -645,13 +646,12 @@ func TestWithRetryAndBackoffs(t *testing.T) {
t.Fatalf("err: %v", err)
}

expectingErr := "giving up after 4 attempt(s)"
expectingErr := fmt.Sprintf("GET %s/foo: 500 broken; giving up after 4 attempt(s)", url)
// Send the request.
_, err = client.Do(context.Background(), req, nil)
if err == nil || !strings.HasSuffix(err.Error(), expectingErr) {
if err == nil || (err.Error() != expectingErr) {
t.Fatalf("expected giving up error, got: %#v", err)
}

}

func TestWithRetryAndBackoffsLogger(t *testing.T) {
Expand Down Expand Up @@ -701,6 +701,70 @@ func TestWithRetryAndBackoffsLogger(t *testing.T) {
}
}

func TestWithRetryAndBackoffsForResourceMethods(t *testing.T) {
// Mock server which always responds 500.
setup()
defer teardown()

url, _ := url.Parse(server.URL)
mux.HandleFunc("/v2/account", func(w http.ResponseWriter, r *http.Request) {
w.Header().Add(headerRateLimit, "500")
w.Header().Add(headerRateRemaining, "42")
w.Header().Add(headerRateReset, "1372700873")
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(`{"id": "bad_request", "message": "broken"}`))
})

tokenSrc := oauth2.StaticTokenSource(&oauth2.Token{
AccessToken: "new_token",
})

oauthClient := oauth2.NewClient(context.TODO(), tokenSrc)

waitMax := PtrTo(6.0)
waitMin := PtrTo(3.0)

retryConfig := RetryConfig{
RetryMax: 3,
RetryWaitMin: waitMin,
RetryWaitMax: waitMax,
}

// Create the client. Use short retry windows so we fail faster.
client, err := New(oauthClient, WithRetryAndBackoffs(retryConfig))
client.BaseURL = url
if err != nil {
t.Fatalf("err: %v", err)
}

expectingErr := fmt.Sprintf("GET %s/v2/account: 500 broken; giving up after 4 attempt(s)", url)
_, resp, err := client.Account.Get(context.Background())
if err == nil || (err.Error() != expectingErr) {
t.Fatalf("expected giving up error, got: %s", err.Error())
}
if _, ok := err.(*ErrorResponse); !ok {
t.Fatalf("expected error to be *godo.ErrorResponse, got: %#v", err)
}

// Ensure that the *Response is properly populated
if resp == nil {
t.Fatal("expected non-nil *godo.Response")
}
if resp.StatusCode != http.StatusInternalServerError {
t.Errorf("expected %d, got: %d", http.StatusInternalServerError, resp.StatusCode)
}
if expected := 500; resp.Rate.Limit != expected {
t.Errorf("expected rate limit to be populate: got: %v, expected: %v", resp.Rate.Limit, expected)
}
if expected := 42; resp.Rate.Remaining != expected {
t.Errorf("expected rate limit remaining to be populate: got: %v, expected: %v", resp.Rate.Remaining, expected)
}
reset := time.Date(2013, 7, 1, 17, 47, 53, 0, time.UTC)
if client.Rate.Reset.UTC() != reset {
t.Errorf("expected rate limit reset to be populate: got: %v, expected: %v", resp.Rate.Reset, reset)
}
}

func checkCurrentPage(t *testing.T, resp *Response, expectedPage int) {
links := resp.Links
p, err := links.CurrentPage()
Expand Down

0 comments on commit 21863bf

Please sign in to comment.