Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add WrapExecutor #477

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
47 changes: 39 additions & 8 deletions client.go
Expand Up @@ -85,6 +85,12 @@ type (

// ErrorHook type is for reacting to request errors, called after all retries were attempted
ErrorHook func(*Request, error)

// Executor executes a Request
Executor func(req *Request) (*Response, error)

// ExecutorMiddleware type wraps the execution of a request
ExecutorMiddleware func(req *Request, next Executor) (*Response, error)
)

// Client struct is used to create Resty client with client level settings,
Expand Down Expand Up @@ -140,6 +146,7 @@ type Client struct {
requestLog RequestLogCallback
responseLog ResponseLogCallback
errorHooks []ErrorHook
executor Executor
}

// User type is to hold an username and password information
Expand Down Expand Up @@ -444,6 +451,28 @@ func (c *Client) OnError(h ErrorHook) *Client {
return c
}

// WrapExecutor wraps the execution of a request, granting full access to the request, response, and error.
// Runs on every request attempt, before any request hook and after any response or error hook.
// Can be useful to introduce throttling or add hooks that always fire, regardless of success or error.
//
// c.WrapExecutor(func(req *Request, next Executor) (*Response, error) {
// // do something with the Request
// // e.g. Acquire a lock
//
// resp, err := next(req)
// // do something with the Response or error
// // e.g. Release a lock
//
// return resp, err
// })
func (c *Client) WrapExecutor(e ExecutorMiddleware) *Client {
next := c.executor
c.executor = func(req *Request) (*Response, error) {
return e(req, next)
}
return c
}

// SetPreRequestHook method sets the given pre-request function into resty client.
// It is called right before the request is fired.
//
Expand Down Expand Up @@ -900,14 +929,14 @@ func (c *Client) execute(req *Request) (*Response, error) {
// to modify the *resty.Request object
for _, f := range c.udBeforeRequest {
if err = f(c, req); err != nil {
return nil, wrapNoRetryErr(err)
return nil, err
}
}

// resty middlewares
for _, f := range c.beforeRequest {
if err = f(c, req); err != nil {
return nil, wrapNoRetryErr(err)
return nil, err
}
}

Expand All @@ -918,12 +947,12 @@ func (c *Client) execute(req *Request) (*Response, error) {
// call pre-request if defined
if c.preReqHook != nil {
if err = c.preReqHook(c, req.RawRequest); err != nil {
return nil, wrapNoRetryErr(err)
return nil, err
}
}

if err = requestLogger(c, req); err != nil {
return nil, wrapNoRetryErr(err)
return nil, err
}

req.RawRequest.Body = newRequestBodyReleaser(req.RawRequest.Body, req.bodyBuf)
Expand All @@ -938,7 +967,7 @@ func (c *Client) execute(req *Request) (*Response, error) {

if err != nil || req.notParseResponse || c.notParseResponse {
response.setReceivedAt()
return response, err
return response, wrapTemporaryError(err)
}

if !req.isSaveResponse {
Expand All @@ -951,15 +980,15 @@ func (c *Client) execute(req *Request) (*Response, error) {
body, err = gzip.NewReader(body)
if err != nil {
response.setReceivedAt()
return response, err
return response, wrapTemporaryError(err)
}
defer closeq(body)
}
}

if response.body, err = ioutil.ReadAll(body); err != nil {
response.setReceivedAt()
return response, err
return response, wrapTemporaryError(err)
}

response.size = int64(len(response.body))
Expand All @@ -974,7 +1003,7 @@ func (c *Client) execute(req *Request) (*Response, error) {
}
}

return response, wrapNoRetryErr(err)
return response, err
}

// getting TLS client config if not exists then create one
Expand Down Expand Up @@ -1092,6 +1121,8 @@ func createClient(hc *http.Client) *Client {
// Logger
c.SetLogger(createLogger())

c.executor = c.execute

// default before request middlewares
c.beforeRequest = []RequestMiddleware{
parseRequestURL,
Expand Down
39 changes: 39 additions & 0 deletions client_test.go
Expand Up @@ -735,6 +735,45 @@ func TestClientOnResponseError(t *testing.T) {
}
}

func TestWrapExecutor(t *testing.T) {
ts := createGetServer(t)
defer ts.Close()

t.Run("abort", func(t *testing.T) {
c := dc()
c.WrapExecutor(func(req *Request, next Executor) (*Response, error) {
return nil, fmt.Errorf("abort")
})

resp, err := c.R().Get(ts.URL)
assertNil(t, resp)
assertEqual(t, "abort", err.Error())
})

t.Run("noop", func(t *testing.T) {
c := dc()
c.WrapExecutor(func(req *Request, next Executor) (*Response, error) {
return next(req)
})

resp, err := c.R().Get(ts.URL)
assertNil(t, err)
assertEqual(t, 200, resp.StatusCode())
})

t.Run("add error", func(t *testing.T) {
c := dc()
c.WrapExecutor(func(req *Request, next Executor) (*Response, error) {
resp, _ := next(req)
return resp, fmt.Errorf("error")
})

resp, err := c.R().Get(ts.URL)
assertEqual(t, "error", err.Error())
assertEqual(t, 200, resp.StatusCode())
})
}

func TestResponseError(t *testing.T) {
err := errors.New("error message")
re := &ResponseError{
Expand Down
33 changes: 33 additions & 0 deletions example_test.go
Expand Up @@ -12,6 +12,7 @@ import (
"net/http"
"os"
"strconv"
"sync"
"time"

"golang.org/x/net/proxy"
Expand Down Expand Up @@ -241,3 +242,35 @@ func Example_socks5Proxy() {
func printOutput(resp *resty.Response, err error) {
fmt.Println(resp, err)
}

//
// Throttling
//

func ExampleClient_throttling() {
// Consider the use of proper throttler, possibly waiting for resources to free up
// e.g. https://github.com/throttled/throttled or https://pkg.go.dev/golang.org/x/time/rate
var lock sync.Mutex
currentConcurrent := 0
maxConcurrent := 10

resty.New().WrapExecutor(func(req *resty.Request, next resty.Executor) (*resty.Response, error) {
lock.Lock()
current := currentConcurrent
if current == maxConcurrent {
lock.Unlock()
return nil, fmt.Errorf("max concurrency exceeded")
}

current++
lock.Unlock()

defer func() {
lock.Lock()
current--
lock.Unlock()
}()

return next(req)
})
}
14 changes: 7 additions & 7 deletions request.go
Expand Up @@ -745,7 +745,7 @@ func (r *Request) Execute(method, url string) (*Response, error) {
if r.SRV != nil {
_, addrs, err = net.LookupSRV(r.SRV.Service, "tcp", r.SRV.Domain)
if err != nil {
r.client.onErrorHooks(r, nil, err)
r.client.onErrorHooks(r, resp, err)
return nil, err
}
}
Expand All @@ -755,9 +755,9 @@ func (r *Request) Execute(method, url string) (*Response, error) {

if r.client.RetryCount == 0 {
r.Attempt = 1
resp, err = r.client.execute(r)
r.client.onErrorHooks(r, resp, unwrapNoRetryErr(err))
return resp, unwrapNoRetryErr(err)
resp, err = r.client.executor(r)
r.client.onErrorHooks(r, resp, err)
return resp, err
}

err = Backoff(
Expand All @@ -766,7 +766,7 @@ func (r *Request) Execute(method, url string) (*Response, error) {

r.URL = r.selectAddr(addrs, url, r.Attempt)

resp, err = r.client.execute(r)
resp, err = r.client.executor(r)
if err != nil {
r.client.log.Errorf("%v, Attempt %v", err, r.Attempt)
}
Expand All @@ -780,9 +780,9 @@ func (r *Request) Execute(method, url string) (*Response, error) {
RetryHooks(r.client.RetryHooks),
)

r.client.onErrorHooks(r, resp, unwrapNoRetryErr(err))
r.client.onErrorHooks(r, resp, err)

return resp, unwrapNoRetryErr(err)
return resp, err
}

//‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾
Expand Down
5 changes: 2 additions & 3 deletions retry.go
Expand Up @@ -111,11 +111,10 @@ func Backoff(operation func() (*Response, error), options ...Option) error {
return err
}

err1 := unwrapNoRetryErr(err) // raw error, it used for return users callback.
needsRetry := err != nil && err == err1 // retry on a few operation errors by default
needsRetry := isTemporaryError(err) // retry on temporary errors by default

for _, condition := range opts.retryConditions {
needsRetry = condition(resp, err1)
needsRetry = condition(resp, err)
if needsRetry {
break
}
Expand Down
4 changes: 2 additions & 2 deletions retry_test.go
Expand Up @@ -22,7 +22,7 @@ func TestBackoffSuccess(t *testing.T) {
retryErr := Backoff(func() (*Response, error) {
externalCounter++
if externalCounter < attempts {
return nil, errors.New("not yet got the number we're after")
return nil, wrapTemporaryError(errors.New("not yet got the number we're after"))
}

return nil, nil
Expand Down Expand Up @@ -71,7 +71,7 @@ func TestBackoffTenAttemptsSuccess(t *testing.T) {
retryErr := Backoff(func() (*Response, error) {
externalCounter++
if externalCounter < attempts {
return nil, errors.New("not yet got the number we're after")
return nil, wrapTemporaryError(errors.New("not yet got the number we're after"))
}
return nil, nil
}, Retries(attempts), WaitTime(5), MaxWaitTime(500))
Expand Down
35 changes: 25 additions & 10 deletions util.go
Expand Up @@ -6,6 +6,7 @@ package resty

import (
"bytes"
"errors"
"fmt"
"io"
"log"
Expand Down Expand Up @@ -368,24 +369,38 @@ func copyHeaders(hdrs http.Header) http.Header {
return nh
}

type noRetryErr struct {
type temporaryError struct {
err error
}

func (e *noRetryErr) Error() string {
func (e *temporaryError) Error() string {
return e.err.Error()
}

func wrapNoRetryErr(err error) error {
if err != nil {
err = &noRetryErr{err: err}
func (e *temporaryError) Unwrap() error {
return e.err
}

func (e *temporaryError) Temporary() bool {
return true
}

// wrapTemporaryError wraps an error to advertise it should be retryable, if it doesn't specify it already.
func wrapTemporaryError(err error) error {
if err == nil {
return nil
}
return err
var tempError interface{ Temporary() bool }
if errors.As(err, &tempError) {
return err // Already exposes the method, honour it, even if false
}
return &temporaryError{err}
}

func unwrapNoRetryErr(err error) error {
if e, ok := err.(*noRetryErr); ok {
err = e.err
func isTemporaryError(err error) bool {
var tempError interface{ Temporary() bool }
if errors.As(err, &tempError) {
return tempError.Temporary()
}
return err
return false
}
21 changes: 21 additions & 0 deletions util_test.go
Expand Up @@ -6,7 +6,9 @@ package resty

import (
"bytes"
"errors"
"mime/multipart"
"net"
"testing"
)

Expand Down Expand Up @@ -95,3 +97,22 @@ func TestWriteMultipartFormFileReaderError(t *testing.T) {
assertNotNil(t, err)
assertEqual(t, "read error", err.Error())
}

func Test_wrapTemporaryError(t *testing.T) {
tests := []struct {
name string
base error
temp bool
}{
{name: "nil", temp: false},
{name: "dns temp", base: &net.DNSError{Err: "err", IsTemporary: true}, temp: true},
{name: "dns not temp", base: &net.DNSError{Err: "err"}, temp: false},
{name: "other", base: errors.New("foo"), temp: true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := wrapTemporaryError(tt.base)
assertEqual(t, tt.temp, isTemporaryError(err))
})
}
}