Skip to content

Commit

Permalink
Fix #66. When limiter is configured with header values,
Browse files Browse the repository at this point in the history
we should only limit request when its header is defined in limiter’s header values.
  • Loading branch information
didip committed Jul 15, 2019
1 parent cd91c82 commit be0cf69
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 52 deletions.
24 changes: 17 additions & 7 deletions tollbooth.go
Expand Up @@ -6,10 +6,11 @@ import (
"strings"

"fmt"
"math"

"github.com/didip/tollbooth/errors"
"github.com/didip/tollbooth/libstring"
"github.com/didip/tollbooth/limiter"
"math"
)

// setResponseHeaders configures X-Rate-Limit-Limit and X-Rate-Limit-Duration
Expand Down Expand Up @@ -67,9 +68,12 @@ func BuildKeys(lmt *limiter.Limiter, r *http.Request) [][]string {
} else if len(headerValues) > 0 && r.Header.Get(headerKey) != "" {
// If header values are not empty, rate-limit all request with headerKey and headerValues.
for _, headerValue := range headerValues {
username, _, ok := r.BasicAuth()
if ok && libstring.StringInSlice(lmtBasicAuthUsers, username) {
sliceKeys = append(sliceKeys, []string{remoteIP, path, r.Method, headerKey, headerValue, username})
if r.Header.Get(headerKey) == headerValue {
username, _, ok := r.BasicAuth()
if ok && libstring.StringInSlice(lmtBasicAuthUsers, username) {
sliceKeys = append(sliceKeys, []string{remoteIP, path, r.Method, headerKey, headerValue})
}
break
}
}
}
Expand All @@ -85,9 +89,12 @@ func BuildKeys(lmt *limiter.Limiter, r *http.Request) [][]string {
sliceKeys = append(sliceKeys, []string{remoteIP, path, r.Method, headerKey})

} else if len(headerValues) > 0 && r.Header.Get(headerKey) != "" {
// If header values are not empty, rate-limit all request with headerKey and headerValues.
// We are only limiting if request's header value is defined inside `headerValues`.
for _, headerValue := range headerValues {
sliceKeys = append(sliceKeys, []string{remoteIP, path, r.Method, headerKey, headerValue})
if r.Header.Get(headerKey) == headerValue {
sliceKeys = append(sliceKeys, []string{remoteIP, path, r.Method, headerKey, headerValue})
break
}
}
}
}
Expand Down Expand Up @@ -118,7 +125,10 @@ func BuildKeys(lmt *limiter.Limiter, r *http.Request) [][]string {
} else if len(headerValues) > 0 && r.Header.Get(headerKey) != "" {
// If header values are not empty, rate-limit all request with headerKey and headerValues.
for _, headerValue := range headerValues {
sliceKeys = append(sliceKeys, []string{remoteIP, path, headerKey, headerValue})
if r.Header.Get(headerKey) == headerValue {
sliceKeys = append(sliceKeys, []string{remoteIP, path, headerKey, headerValue})
break
}
}
}
}
Expand Down
103 changes: 58 additions & 45 deletions tollbooth_bug_report_test.go
@@ -1,6 +1,3 @@
// +build slow
// How to run: go test -tags=slow

package tollbooth

import (
Expand Down Expand Up @@ -57,62 +54,78 @@ Top:
}
}

func issue66HeaderKey() string {
return "X-Customer-ID"
}
var issue66HeaderKey = "X-Customer-ID"

func issue66RateLimiter(h http.HandlerFunc) http.HandlerFunc {
allocationLimiter := NewLimiter(1, &limiter.ExpirableOptions{DefaultExpirationTTL: time.Minute}).
SetMethods([]string{"POST"})
func issue66RateLimiter(h http.HandlerFunc, customerIDs []string) (http.HandlerFunc, *limiter.Limiter) {
allocationLimiter := NewLimiter(1, nil).SetMethods([]string{"POST"})

handler := func(w http.ResponseWriter, r *http.Request) {
customerID := r.Header.Get(issue66HeaderKey())
allocationLimiter.SetHeader(issue66HeaderKey(), []string{customerID})

allocationLimiter.SetHeader(issue66HeaderKey, customerIDs)
LimitFuncHandler(allocationLimiter, h).ServeHTTP(w, r)
}

return handler
return handler, allocationLimiter
}

// See: https://github.com/didip/tollbooth/issues/66
func Test_Issue66_CustomRateLimitByHeaderValues(t *testing.T) {
customerID1 := "1234"
customerID2 := "5678"

tests := []struct {
name string
secondRequestStatus int
customerIDs []string
}{
{"different customer id", http.StatusOK, []string{customerID1, customerID2}},
{"same customer id", http.StatusTooManyRequests, []string{customerID1, customerID1}},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})

testServer := httptest.NewServer(issue66RateLimiter(h))
defer testServer.Close()

request1, _ := http.NewRequest("POST", testServer.URL, nil)
request1.Header.Add(issue66HeaderKey(), tt.customerIDs[0])

client := &http.Client{}

response, _ := client.Do(request1)
if response.StatusCode != http.StatusOK {
t.Errorf("Unexpected status code. Got: %v, expected: %v", response.StatusCode, http.StatusOK)
}
h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})

h, allocationLimiter := issue66RateLimiter(h, []string{customerID1, customerID2})
testServer := httptest.NewServer(h)
defer testServer.Close()

client := &http.Client{}

// subtest 1:
// There are 2 different customer ids,
// both should pass rate limiter the 1st time and failed the second time.
request1, _ := http.NewRequest("POST", testServer.URL, nil)
request1.Header.Add(issue66HeaderKey, customerID1)

request2, _ := http.NewRequest("POST", testServer.URL, nil)
request2.Header.Add(issue66HeaderKey, customerID2)

for _, request := range []*http.Request{request1} {
// 1st, 200
response, _ := client.Do(request)
if response.StatusCode != http.StatusOK {
t.Fatalf(`
Customer %v must pass rate limiter the first time.
Expected to receive: %v status code. Got: %v.
limiter.headers: %v`,
request.Header.Get(issue66HeaderKey),
http.StatusOK, response.StatusCode,
allocationLimiter.GetHeaders())
}

request2, _ := http.NewRequest("POST", testServer.URL, nil)
request2.Header.Add(issue66HeaderKey(), tt.customerIDs[1])
// 2nd, 429
response, _ = client.Do(request)
if response.StatusCode != http.StatusTooManyRequests {
t.Fatalf(`Both customer must pass rate limiter.
Expected to receive: %v status code. Got: %v`,
http.StatusTooManyRequests, response.StatusCode)
}
}

response, _ = client.Do(request2)
if response.StatusCode != tt.secondRequestStatus {
t.Errorf("Unexpected status code. Got: %v, expected: %v. Customers: %v", response.StatusCode, tt.secondRequestStatus, tt.customerIDs)
}
})
// subtest 2:
// There is 1 customer not defined in rate limiter.
// S/he should not be rate limited.
request3, _ := http.NewRequest("POST", testServer.URL, nil)
request3.Header.Add(issue66HeaderKey, "777")

for i := 0; i < 2; i++ {
response, _ := client.Do(request3)

if response.StatusCode != http.StatusOK {
t.Fatalf(`
Customer %v must always pass rate limiter.
Expected to receive: %v status code. Got: %v`,
request3.Header.Get(issue66HeaderKey),
http.StatusOK, response.StatusCode)
}
}
}

0 comments on commit be0cf69

Please sign in to comment.