Skip to content

Commit

Permalink
added floating point limit support
Browse files Browse the repository at this point in the history
  • Loading branch information
David Kaufman committed Jan 30, 2018
1 parent b65680a commit 4f3653f
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 9 deletions.
6 changes: 3 additions & 3 deletions limiter/limiter.go
Expand Up @@ -54,7 +54,7 @@ func New(generalExpirableOptions *ExpirableOptions) *Limiter {
// Limiter is a config struct to limit a particular request handler.
type Limiter struct {
// Maximum number of requests to limit per second.
max int64
max float64

// Limiter burst size
burst int
Expand Down Expand Up @@ -151,7 +151,7 @@ func (l *Limiter) GetHeaderEntryExpirationTTL() time.Duration {
}

// SetMax is thread-safe way of setting maximum number of requests to limit per duration.
func (l *Limiter) SetMax(max int64) *Limiter {
func (l *Limiter) SetMax(max float64) *Limiter {
l.Lock()
l.max = max
l.Unlock()
Expand All @@ -160,7 +160,7 @@ func (l *Limiter) SetMax(max int64) *Limiter {
}

// GetMax is thread-safe way of getting maximum number of requests to limit per duration.
func (l *Limiter) GetMax() int64 {
func (l *Limiter) GetMax() float64 {
l.RLock()
defer l.RUnlock()
return l.max
Expand Down
27 changes: 25 additions & 2 deletions limiter/limiter_test.go
Expand Up @@ -50,6 +50,29 @@ func TestLimitReached(t *testing.T) {
}
}

func TestFloatingLimitReached(t *testing.T) {
lmt := New(nil).SetMax(0.1).SetBurst(1)
key := "127.0.0.1|/"

if lmt.LimitReached(key) == true {
t.Error("First time count should not reached the limit.")
}

if lmt.LimitReached(key) == false {
t.Error("Second time count should return true because it exceeds 1 request per 10 seconds.")
}

<-time.After(7 * time.Second)
if lmt.LimitReached(key) == false {
t.Error("Third time count should return true because it exceeds 1 request per 10 seconds.")
}

<-time.After(3 * time.Second)
if lmt.LimitReached(key) == true {
t.Error("Fourth time count should not reached the limit because the 10 second window has passed.")
}
}

func TestLimitReachedWithCustomTokenBucketTTL(t *testing.T) {
lmt := New(&ExpirableOptions{DefaultExpirationTTL: time.Second, ExpireJobInterval: 0}).SetMax(1).SetBurst(1)
key := "127.0.0.1|/"
Expand All @@ -71,7 +94,7 @@ func TestLimitReachedWithCustomTokenBucketTTL(t *testing.T) {
func TestMuchHigherMaxRequests(t *testing.T) {
numRequests := 1000
delay := (1 * time.Second) / time.Duration(numRequests)
lmt := New(nil).SetMax(int64(numRequests)).SetBurst(1)
lmt := New(nil).SetMax(float64(numRequests)).SetBurst(1)
key := "127.0.0.1|/"

for i := 0; i < numRequests; i++ {
Expand All @@ -90,7 +113,7 @@ func TestMuchHigherMaxRequests(t *testing.T) {
func TestMuchHigherMaxRequestsWithCustomTokenBucketTTL(t *testing.T) {
numRequests := 1000
delay := (1 * time.Second) / time.Duration(numRequests)
lmt := New(&ExpirableOptions{DefaultExpirationTTL: time.Minute, ExpireJobInterval: time.Minute}).SetMax(int64(numRequests)).SetBurst(1)
lmt := New(&ExpirableOptions{DefaultExpirationTTL: time.Minute, ExpireJobInterval: time.Minute}).SetMax(float64(numRequests)).SetBurst(1)
key := "127.0.0.1|/"

for i := 0; i < numRequests; i++ {
Expand Down
8 changes: 4 additions & 4 deletions tollbooth.go
Expand Up @@ -3,25 +3,25 @@ package tollbooth

import (
"net/http"
"strconv"
"strings"

"fmt"
"github.com/didip/tollbooth/errors"
"github.com/didip/tollbooth/libstring"
"github.com/didip/tollbooth/limiter"
)

// setResponseHeaders configures X-Rate-Limit-Limit and X-Rate-Limit-Duration
func setResponseHeaders(lmt *limiter.Limiter, w http.ResponseWriter, r *http.Request) {
w.Header().Add("X-Rate-Limit-Limit", strconv.FormatInt(lmt.GetMax(), 10))
w.Header().Add("X-Rate-Limit-Limit", fmt.Sprintf("%.2f", lmt.GetMax()))
w.Header().Add("X-Rate-Limit-Duration", "1")
w.Header().Add("X-Rate-Limit-Request-Forwarded-For", r.Header.Get("X-Forwarded-For"))
w.Header().Add("X-Rate-Limit-Request-Remote-Addr", r.RemoteAddr)
}

// NewLimiter is a convenience function to limiter.New.
func NewLimiter(max int64, tbOptions *limiter.ExpirableOptions) *limiter.Limiter {
return limiter.New(tbOptions).SetMax(max).SetBurst(int(max))
func NewLimiter(max float64, burst int, tbOptions *limiter.ExpirableOptions) *limiter.Limiter {
return limiter.New(tbOptions).SetMax(max).SetBurst(burst)
}

// LimitByKeys keeps track number of request made by keys separated by pipe.
Expand Down

0 comments on commit 4f3653f

Please sign in to comment.