diff --git a/limiter/limiter.go b/limiter/limiter.go index 5fdb6fd..1d6cfdd 100644 --- a/limiter/limiter.go +++ b/limiter/limiter.go @@ -54,7 +54,7 @@ func New(generalExpirableOptions *ExpirableOptions) *Limiter { // Limiter is a config struct to limit a particular request handler. type Limiter struct { // Maximum number of requests to limit per second. - max int64 + max float64 // Limiter burst size burst int @@ -151,7 +151,7 @@ func (l *Limiter) GetHeaderEntryExpirationTTL() time.Duration { } // SetMax is thread-safe way of setting maximum number of requests to limit per duration. -func (l *Limiter) SetMax(max int64) *Limiter { +func (l *Limiter) SetMax(max float64) *Limiter { l.Lock() l.max = max l.Unlock() @@ -160,7 +160,7 @@ func (l *Limiter) SetMax(max int64) *Limiter { } // GetMax is thread-safe way of getting maximum number of requests to limit per duration. -func (l *Limiter) GetMax() int64 { +func (l *Limiter) GetMax() float64 { l.RLock() defer l.RUnlock() return l.max diff --git a/limiter/limiter_test.go b/limiter/limiter_test.go index 664db8a..5db973f 100644 --- a/limiter/limiter_test.go +++ b/limiter/limiter_test.go @@ -50,6 +50,29 @@ func TestLimitReached(t *testing.T) { } } +func TestFloatingLimitReached(t *testing.T) { + lmt := New(nil).SetMax(0.1).SetBurst(1) + key := "127.0.0.1|/" + + if lmt.LimitReached(key) == true { + t.Error("First time count should not reached the limit.") + } + + if lmt.LimitReached(key) == false { + t.Error("Second time count should return true because it exceeds 1 request per 10 seconds.") + } + + <-time.After(7 * time.Second) + if lmt.LimitReached(key) == false { + t.Error("Third time count should return true because it exceeds 1 request per 10 seconds.") + } + + <-time.After(3 * time.Second) + if lmt.LimitReached(key) == true { + t.Error("Fourth time count should not reached the limit because the 10 second window has passed.") + } +} + func TestLimitReachedWithCustomTokenBucketTTL(t *testing.T) { lmt := New(&ExpirableOptions{DefaultExpirationTTL: time.Second, ExpireJobInterval: 0}).SetMax(1).SetBurst(1) key := "127.0.0.1|/" @@ -71,7 +94,7 @@ func TestLimitReachedWithCustomTokenBucketTTL(t *testing.T) { func TestMuchHigherMaxRequests(t *testing.T) { numRequests := 1000 delay := (1 * time.Second) / time.Duration(numRequests) - lmt := New(nil).SetMax(int64(numRequests)).SetBurst(1) + lmt := New(nil).SetMax(float64(numRequests)).SetBurst(1) key := "127.0.0.1|/" for i := 0; i < numRequests; i++ { @@ -90,7 +113,7 @@ func TestMuchHigherMaxRequests(t *testing.T) { func TestMuchHigherMaxRequestsWithCustomTokenBucketTTL(t *testing.T) { numRequests := 1000 delay := (1 * time.Second) / time.Duration(numRequests) - lmt := New(&ExpirableOptions{DefaultExpirationTTL: time.Minute, ExpireJobInterval: time.Minute}).SetMax(int64(numRequests)).SetBurst(1) + lmt := New(&ExpirableOptions{DefaultExpirationTTL: time.Minute, ExpireJobInterval: time.Minute}).SetMax(float64(numRequests)).SetBurst(1) key := "127.0.0.1|/" for i := 0; i < numRequests; i++ { diff --git a/tollbooth.go b/tollbooth.go index a557b2b..c023db2 100644 --- a/tollbooth.go +++ b/tollbooth.go @@ -3,25 +3,26 @@ package tollbooth import ( "net/http" - "strconv" "strings" + "fmt" "github.com/didip/tollbooth/errors" "github.com/didip/tollbooth/libstring" "github.com/didip/tollbooth/limiter" + "math" ) // setResponseHeaders configures X-Rate-Limit-Limit and X-Rate-Limit-Duration func setResponseHeaders(lmt *limiter.Limiter, w http.ResponseWriter, r *http.Request) { - w.Header().Add("X-Rate-Limit-Limit", strconv.FormatInt(lmt.GetMax(), 10)) + w.Header().Add("X-Rate-Limit-Limit", fmt.Sprintf("%.2f", lmt.GetMax())) w.Header().Add("X-Rate-Limit-Duration", "1") w.Header().Add("X-Rate-Limit-Request-Forwarded-For", r.Header.Get("X-Forwarded-For")) w.Header().Add("X-Rate-Limit-Request-Remote-Addr", r.RemoteAddr) } // NewLimiter is a convenience function to limiter.New. -func NewLimiter(max int64, tbOptions *limiter.ExpirableOptions) *limiter.Limiter { - return limiter.New(tbOptions).SetMax(max).SetBurst(int(max)) +func NewLimiter(max float64, tbOptions *limiter.ExpirableOptions) *limiter.Limiter { + return limiter.New(tbOptions).SetMax(max).SetBurst(int(math.Max(1, max))) } // LimitByKeys keeps track number of request made by keys separated by pipe.