Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ability to specify response HTTP status code for Throttle middleware #571

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
20 changes: 14 additions & 6 deletions middleware/throttle.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,12 @@ type ThrottleOpts struct {
Limit int
BacklogLimit int
BacklogTimeout time.Duration
StatusCode int
}

// Throttle is a middleware that limits number of currently processed requests
// at a time across all users. Note: Throttle is not a rate-limiter per user,
// instead it just puts a ceiling on the number of currently in-flight requests
// instead it just puts a ceiling on the number of current in-flight requests
// being processed from the point from where the Throttle middleware is mounted.
func Throttle(limit int) func(http.Handler) http.Handler {
return ThrottleWithOpts(ThrottleOpts{Limit: limit, BacklogTimeout: defaultBacklogTimeout})
Expand All @@ -49,10 +50,16 @@ func ThrottleWithOpts(opts ThrottleOpts) func(http.Handler) http.Handler {
panic("chi/middleware: Throttle expects backlogLimit to be positive")
}

statusCode := opts.StatusCode
if statusCode == 0 {
statusCode = http.StatusTooManyRequests
}

t := throttler{
tokens: make(chan token, opts.Limit),
backlogTokens: make(chan token, opts.Limit+opts.BacklogLimit),
backlogTimeout: opts.BacklogTimeout,
statusCode: statusCode,
retryAfterFn: opts.RetryAfterFn,
}

Expand All @@ -72,7 +79,7 @@ func ThrottleWithOpts(opts ThrottleOpts) func(http.Handler) http.Handler {

case <-ctx.Done():
t.setRetryAfterHeaderIfNeeded(w, true)
http.Error(w, errContextCanceled, http.StatusTooManyRequests)
http.Error(w, errContextCanceled, t.statusCode)
return

case btok := <-t.backlogTokens:
Expand All @@ -85,12 +92,12 @@ func ThrottleWithOpts(opts ThrottleOpts) func(http.Handler) http.Handler {
select {
case <-timer.C:
t.setRetryAfterHeaderIfNeeded(w, false)
http.Error(w, errTimedOut, http.StatusTooManyRequests)
http.Error(w, errTimedOut, t.statusCode)
return
case <-ctx.Done():
timer.Stop()
t.setRetryAfterHeaderIfNeeded(w, true)
http.Error(w, errContextCanceled, http.StatusTooManyRequests)
http.Error(w, errContextCanceled, t.statusCode)
return
case tok := <-t.tokens:
defer func() {
Expand All @@ -103,7 +110,7 @@ func ThrottleWithOpts(opts ThrottleOpts) func(http.Handler) http.Handler {

default:
t.setRetryAfterHeaderIfNeeded(w, false)
http.Error(w, errCapacityExceeded, http.StatusTooManyRequests)
http.Error(w, errCapacityExceeded, t.statusCode)
return
}
}
Expand All @@ -119,8 +126,9 @@ type token struct{}
type throttler struct {
tokens chan token
backlogTokens chan token
retryAfterFn func(ctxDone bool) time.Duration
backlogTimeout time.Duration
statusCode int
retryAfterFn func(ctxDone bool) time.Duration
}

// setRetryAfterHeaderIfNeeded sets Retry-After HTTP header if corresponding retryAfterFn option of throttler is initialized.
Expand Down
55 changes: 51 additions & 4 deletions middleware/throttle_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@ func TestThrottleTriggerGatewayTimeout(t *testing.T) {
res, err := client.Get(server.URL)
assertNoError(t, err)
assertEqual(t, http.StatusOK, res.StatusCode)

}(i)
}

Expand All @@ -136,7 +135,6 @@ func TestThrottleTriggerGatewayTimeout(t *testing.T) {
assertNoError(t, err)
assertEqual(t, http.StatusTooManyRequests, res.StatusCode)
assertEqual(t, errTimedOut, strings.TrimSpace(string(buf)))

}(i)
}

Expand Down Expand Up @@ -175,7 +173,6 @@ func TestThrottleMaximum(t *testing.T) {
buf, err := ioutil.ReadAll(res.Body)
assertNoError(t, err)
assertEqual(t, testContent, buf)

}(i)
}

Expand All @@ -196,7 +193,6 @@ func TestThrottleMaximum(t *testing.T) {
assertNoError(t, err)
assertEqual(t, http.StatusTooManyRequests, res.StatusCode)
assertEqual(t, errCapacityExceeded, strings.TrimSpace(string(buf)))

}(i)
}

Expand Down Expand Up @@ -252,3 +248,54 @@ func TestThrottleMaximum(t *testing.T) {

wg.Wait()
}*/

func TestThrottleCustomStatusCode(t *testing.T) {
const timeout = time.Second * 3

wait := make(chan struct{})

r := chi.NewRouter()
r.Use(ThrottleWithOpts(ThrottleOpts{Limit: 1, StatusCode: http.StatusServiceUnavailable}))
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
select {
case <-wait:
case <-time.After(timeout):
}
w.WriteHeader(http.StatusOK)
})
server := httptest.NewServer(r)
defer server.Close()

const totalRequestCount = 5

codes := make(chan int, totalRequestCount)
errs := make(chan error, totalRequestCount)
client := &http.Client{Timeout: timeout}
for i := 0; i < totalRequestCount; i++ {
go func() {
resp, err := client.Get(server.URL)
if err != nil {
errs <- err
return
}
codes <- resp.StatusCode
}()
}

waitResponse := func(wantCode int) {
select {
case err := <-errs:
t.Fatal(err)
case code := <-codes:
assertEqual(t, wantCode, code)
case <-time.After(timeout):
t.Fatalf("waiting %d code, timeout exceeded", wantCode)
}
}

for i := 0; i < totalRequestCount-1; i++ {
waitResponse(http.StatusServiceUnavailable)
}
close(wait) // Allow the last request to proceed.
waitResponse(http.StatusOK)
}