Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added floating point limit support #60

Merged
merged 1 commit into from Feb 5, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 3 additions & 3 deletions limiter/limiter.go
Expand Up @@ -54,7 +54,7 @@ func New(generalExpirableOptions *ExpirableOptions) *Limiter {
// Limiter is a config struct to limit a particular request handler.
type Limiter struct {
// Maximum number of requests to limit per second.
max int64
max float64

// Limiter burst size
burst int
Expand Down Expand Up @@ -151,7 +151,7 @@ func (l *Limiter) GetHeaderEntryExpirationTTL() time.Duration {
}

// SetMax is thread-safe way of setting maximum number of requests to limit per duration.
func (l *Limiter) SetMax(max int64) *Limiter {
func (l *Limiter) SetMax(max float64) *Limiter {
l.Lock()
l.max = max
l.Unlock()
Expand All @@ -160,7 +160,7 @@ func (l *Limiter) SetMax(max int64) *Limiter {
}

// GetMax is thread-safe way of getting maximum number of requests to limit per duration.
func (l *Limiter) GetMax() int64 {
func (l *Limiter) GetMax() float64 {
l.RLock()
defer l.RUnlock()
return l.max
Expand Down
27 changes: 25 additions & 2 deletions limiter/limiter_test.go
Expand Up @@ -50,6 +50,29 @@ func TestLimitReached(t *testing.T) {
}
}

func TestFloatingLimitReached(t *testing.T) {
lmt := New(nil).SetMax(0.1).SetBurst(1)
key := "127.0.0.1|/"

if lmt.LimitReached(key) == true {
t.Error("First time count should not reached the limit.")
}

if lmt.LimitReached(key) == false {
t.Error("Second time count should return true because it exceeds 1 request per 10 seconds.")
}

<-time.After(7 * time.Second)
if lmt.LimitReached(key) == false {
t.Error("Third time count should return true because it exceeds 1 request per 10 seconds.")
}

<-time.After(3 * time.Second)
if lmt.LimitReached(key) == true {
t.Error("Fourth time count should not reached the limit because the 10 second window has passed.")
}
}

func TestLimitReachedWithCustomTokenBucketTTL(t *testing.T) {
lmt := New(&ExpirableOptions{DefaultExpirationTTL: time.Second, ExpireJobInterval: 0}).SetMax(1).SetBurst(1)
key := "127.0.0.1|/"
Expand All @@ -71,7 +94,7 @@ func TestLimitReachedWithCustomTokenBucketTTL(t *testing.T) {
func TestMuchHigherMaxRequests(t *testing.T) {
numRequests := 1000
delay := (1 * time.Second) / time.Duration(numRequests)
lmt := New(nil).SetMax(int64(numRequests)).SetBurst(1)
lmt := New(nil).SetMax(float64(numRequests)).SetBurst(1)
key := "127.0.0.1|/"

for i := 0; i < numRequests; i++ {
Expand All @@ -90,7 +113,7 @@ func TestMuchHigherMaxRequests(t *testing.T) {
func TestMuchHigherMaxRequestsWithCustomTokenBucketTTL(t *testing.T) {
numRequests := 1000
delay := (1 * time.Second) / time.Duration(numRequests)
lmt := New(&ExpirableOptions{DefaultExpirationTTL: time.Minute, ExpireJobInterval: time.Minute}).SetMax(int64(numRequests)).SetBurst(1)
lmt := New(&ExpirableOptions{DefaultExpirationTTL: time.Minute, ExpireJobInterval: time.Minute}).SetMax(float64(numRequests)).SetBurst(1)
key := "127.0.0.1|/"

for i := 0; i < numRequests; i++ {
Expand Down
9 changes: 5 additions & 4 deletions tollbooth.go
Expand Up @@ -3,25 +3,26 @@ package tollbooth

import (
"net/http"
"strconv"
"strings"

"fmt"
"github.com/didip/tollbooth/errors"
"github.com/didip/tollbooth/libstring"
"github.com/didip/tollbooth/limiter"
"math"
)

// setResponseHeaders configures X-Rate-Limit-Limit and X-Rate-Limit-Duration
func setResponseHeaders(lmt *limiter.Limiter, w http.ResponseWriter, r *http.Request) {
w.Header().Add("X-Rate-Limit-Limit", strconv.FormatInt(lmt.GetMax(), 10))
w.Header().Add("X-Rate-Limit-Limit", fmt.Sprintf("%.2f", lmt.GetMax()))
w.Header().Add("X-Rate-Limit-Duration", "1")
w.Header().Add("X-Rate-Limit-Request-Forwarded-For", r.Header.Get("X-Forwarded-For"))
w.Header().Add("X-Rate-Limit-Request-Remote-Addr", r.RemoteAddr)
}

// NewLimiter is a convenience function to limiter.New.
func NewLimiter(max int64, tbOptions *limiter.ExpirableOptions) *limiter.Limiter {
return limiter.New(tbOptions).SetMax(max).SetBurst(int(max))
func NewLimiter(max float64, tbOptions *limiter.ExpirableOptions) *limiter.Limiter {
return limiter.New(tbOptions).SetMax(max).SetBurst(int(math.Max(1, max)))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this had to be done to allow rates below 1, otherwise SetBurst(0) would always result in exceeded rate limits.

}

// LimitByKeys keeps track number of request made by keys separated by pipe.
Expand Down