Skip to content

Commit

Permalink
Add WrapExecutor
Browse files Browse the repository at this point in the history
  • Loading branch information
lavoiesl committed Oct 18, 2021
1 parent 74e505c commit 2372bf0
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 2 deletions.
31 changes: 31 additions & 0 deletions client.go
Expand Up @@ -84,6 +84,12 @@ type (

// ErrorHook type is for reacting to request errors, called after all retries were attempted
ErrorHook func(*Request, error)

// Executor executes a Request
Executor func(req *Request) (*Response, error)

// ExecutorMiddleware type wraps the execution of a request
ExecutorMiddleware func(req *Request, next Executor) (*Response, error)
)

// Client struct is used to create Resty client with client level settings,
Expand Down Expand Up @@ -136,6 +142,7 @@ type Client struct {
requestLog RequestLogCallback
responseLog ResponseLogCallback
errorHooks []ErrorHook
executor Executor
}

// User type is to hold an username and password information
Expand Down Expand Up @@ -423,6 +430,28 @@ func (c *Client) OnError(h ErrorHook) *Client {
return c
}

// WrapExecutor wraps the execution of a request, granting full access to the request, response, and error.
// Runs on every request attempt, before any request hook and after any response or error hook.
// Can be useful to introduce throttling or add hooks that always fire, regardless of success or error.
//
// c.WrapExecutor(func(req *Request, next Executor) (*Response, error) {
// // do something with the Request
// // e.g. Acquire a lock
//
// resp, err := next(req)
// // do something with the Response or error
// // e.g. Release a lock
//
// return resp, err
// })
func (c *Client) WrapExecutor(e ExecutorMiddleware) *Client {
next := c.executor
c.executor = func(req *Request) (*Response, error) {
return e(req, next)
}
return c
}

// SetPreRequestHook method sets the given pre-request function into resty client.
// It is called right before the request is fired.
//
Expand Down Expand Up @@ -1068,6 +1097,8 @@ func createClient(hc *http.Client) *Client {
// Logger
c.SetLogger(createLogger())

c.executor = c.execute

// default before request middlewares
c.beforeRequest = []RequestMiddleware{
parseRequestURL,
Expand Down
39 changes: 39 additions & 0 deletions client_test.go
Expand Up @@ -735,6 +735,45 @@ func TestClientOnResponseError(t *testing.T) {
}
}

func TestWrapExecutor(t *testing.T) {
ts := createGetServer(t)
defer ts.Close()

t.Run("abort", func(t *testing.T) {
c := dc()
c.WrapExecutor(func(req *Request, next Executor) (*Response, error) {
return nil, fmt.Errorf("abort")
})

resp, err := c.R().Get(ts.URL)
assertNil(t, resp)
assertEqual(t, "abort", err.Error())
})

t.Run("noop", func(t *testing.T) {
c := dc()
c.WrapExecutor(func(req *Request, next Executor) (*Response, error) {
return next(req)
})

resp, err := c.R().Get(ts.URL)
assertNil(t, err)
assertEqual(t, 200, resp.StatusCode())
})

t.Run("add error", func(t *testing.T) {
c := dc()
c.WrapExecutor(func(req *Request, next Executor) (*Response, error) {
resp, _ := next(req)
return resp, fmt.Errorf("error")
})

resp, err := c.R().Get(ts.URL)
assertEqual(t, "error", err.Error())
assertEqual(t, 200, resp.StatusCode())
})
}

func TestResponseError(t *testing.T) {
err := errors.New("error message")
re := &ResponseError{
Expand Down
33 changes: 33 additions & 0 deletions example_test.go
Expand Up @@ -12,6 +12,7 @@ import (
"net/http"
"os"
"strconv"
"sync"
"time"

"golang.org/x/net/proxy"
Expand Down Expand Up @@ -241,3 +242,35 @@ func Example_socks5Proxy() {
func printOutput(resp *resty.Response, err error) {
fmt.Println(resp, err)
}

//
// Throttling
//

func ExampleClient_throttling() {
// Consider the use of proper throttler, possibly waiting for resources to free up
// e.g. https://github.com/throttled/throttled or https://pkg.go.dev/golang.org/x/time/rate
var lock sync.Mutex
currentConcurrent := 0
maxConcurrent := 10

resty.New().WrapExecutor(func(req *resty.Request, next resty.Executor) (*resty.Response, error) {
lock.Lock()
current := currentConcurrent
if current == maxConcurrent {
lock.Unlock()
return nil, fmt.Errorf("max concurrency exceeded")
}

current++
lock.Unlock()

defer func() {
lock.Lock()
current--
lock.Unlock()
}()

return next(req)
})
}
4 changes: 2 additions & 2 deletions request.go
Expand Up @@ -754,7 +754,7 @@ func (r *Request) Execute(method, url string) (*Response, error) {
if r.client.RetryCount == 0 {
r.Attempt = 1
r.client.onErrorHooks(r, resp, err)
return r.client.execute(r)
return r.client.executor(r)
}

err = Backoff(
Expand All @@ -763,7 +763,7 @@ func (r *Request) Execute(method, url string) (*Response, error) {

r.URL = r.selectAddr(addrs, url, r.Attempt)

resp, err = r.client.execute(r)
resp, err = r.client.executor(r)
if err != nil {
r.client.log.Errorf("%v, Attempt %v", err, r.Attempt)
}
Expand Down

0 comments on commit 2372bf0

Please sign in to comment.