Skip to content

Commit

Permalink
Add rate limiter to client (#715)
Browse files Browse the repository at this point in the history
* add rate limiter to client

* make rate limiter work for retries

* make rate limiter return error instead of blocking

* fix test

* use RateLimiter interface instead of x/time/rate

---------

Co-authored-by: David Linus Briemann <dlb@mailbox.org>
  • Loading branch information
SVilgelm and dbriemann committed Sep 30, 2023
1 parent 41199c3 commit e52a7e0
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 1 deletion.
16 changes: 16 additions & 0 deletions client.go
Expand Up @@ -152,6 +152,7 @@ type Client struct {
errorHooks []ErrorHook
invalidHooks []ErrorHook
panicHooks []ErrorHook
rateLimiter RateLimiter
}

// User type is to hold an username and password information
Expand Down Expand Up @@ -920,6 +921,13 @@ func (c *Client) SetOutputDirectory(dirPath string) *Client {
return c
}

// SetRateLimiter sets an optional `RateLimiter`. If set the rate limiter will control
// all requests made with this client.
func (c *Client) SetRateLimiter(rl RateLimiter) *Client {
c.rateLimiter = rl
return c
}

// SetTransport method sets custom `*http.Transport` or any `http.RoundTripper`
// compatible interface implementation in the resty client.
//
Expand Down Expand Up @@ -1141,6 +1149,14 @@ func (c *Client) execute(req *Request) (*Response, error) {
}
}

// If there is a rate limiter set for this client, the Execute call
// will return an error if the rate limit is exceeded.
if req.client.rateLimiter != nil {
if !req.client.rateLimiter.Allow() {
return nil, wrapNoRetryErr(ErrRateLimitExceeded)
}
}

// resty middlewares
for _, f := range c.beforeRequest {
if err = f(c, req); err != nil {
Expand Down
5 changes: 4 additions & 1 deletion go.mod
Expand Up @@ -2,4 +2,7 @@ module github.com/go-resty/resty/v2

go 1.16

require golang.org/x/net v0.15.0
require (
golang.org/x/net v0.15.0
golang.org/x/time v0.0.0-20211116232009-f0f3c7e86c11
)
2 changes: 2 additions & 0 deletions go.sum
Expand Up @@ -33,6 +33,8 @@ golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/time v0.0.0-20211116232009-f0f3c7e86c11 h1:GZokNIeuVkl3aZHJchRrr13WCsols02MLUcz1U9is6M=
golang.org/x/time v0.0.0-20211116232009-f0f3c7e86c11/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
Expand Down
40 changes: 40 additions & 0 deletions request_test.go
Expand Up @@ -19,6 +19,8 @@ import (
"strings"
"testing"
"time"

"golang.org/x/time/rate"
)

type AuthSuccess struct {
Expand Down Expand Up @@ -66,6 +68,44 @@ func TestGetGH524(t *testing.T) {
assertEqual(t, resp.Request.Header.Get("Content-Type"), "") // unable to reproduce reported issue
}

func TestRateLimiter(t *testing.T) {
ts := createGetServer(t)
defer ts.Close()

// Test a burst with a valid capacity and then a consecutive request that must fail.

// Allow a rate of 1 every 100 ms but also allow bursts of 10 requests.
client := dc().SetRateLimiter(rate.NewLimiter(rate.Every(100*time.Millisecond), 10))

// Execute a burst of 10 requests.
for i := 0; i < 10; i++ {
resp, err := client.R().
SetQueryParam("request_no", strconv.Itoa(i)).Get(ts.URL + "/")
assertError(t, err)
assertEqual(t, http.StatusOK, resp.StatusCode())
}
// Next request issued directly should fail because burst of 10 has been consumed.
{
_, err := client.R().
SetQueryParam("request_no", strconv.Itoa(11)).Get(ts.URL + "/")
assertErrorIs(t, ErrRateLimitExceeded, err)
}

// Test continues request at a valid rate

// Allow a rate of 1 every ms with no burst.
client = dc().SetRateLimiter(rate.NewLimiter(rate.Every(1*time.Millisecond), 1))

// Sending requests every ms+tiny delta must succeed.
for i := 0; i < 100; i++ {
resp, err := client.R().
SetQueryParam("request_no", strconv.Itoa(i)).Get(ts.URL + "/")
assertError(t, err)
assertEqual(t, http.StatusOK, resp.StatusCode())
time.Sleep(1*time.Millisecond + 100*time.Microsecond)
}
}

func TestIllegalRetryCount(t *testing.T) {
ts := createGetServer(t)
defer ts.Close()
Expand Down
11 changes: 11 additions & 0 deletions util.go
Expand Up @@ -6,6 +6,7 @@ package resty

import (
"bytes"
"errors"
"fmt"
"io"
"log"
Expand Down Expand Up @@ -64,6 +65,16 @@ func (l *logger) output(format string, v ...interface{}) {
l.l.Printf(format, v...)
}

//‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾
// Rate Limiter interface
//_______________________________________________________________________

type RateLimiter interface {
Allow() bool
}

var ErrRateLimitExceeded = errors.New("rate limit exceeded")

//‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾
// Package Helper methods
//_______________________________________________________________________
Expand Down

0 comments on commit e52a7e0

Please sign in to comment.