Skip to content

Commit

Permalink
Add ability to specify response HTTP status code for Throttle middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
vasayxtx committed Dec 28, 2023
1 parent ff1d3c6 commit bccb994
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 9 deletions.
18 changes: 13 additions & 5 deletions middleware/throttle.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ type ThrottleOpts struct {
Limit int
BacklogLimit int
BacklogTimeout time.Duration
StatusCode int
}

// Throttle is a middleware that limits number of currently processed requests
Expand Down Expand Up @@ -49,10 +50,16 @@ func ThrottleWithOpts(opts ThrottleOpts) func(http.Handler) http.Handler {
panic("chi/middleware: Throttle expects backlogLimit to be positive")
}

statusCode := opts.StatusCode
if statusCode == 0 {
statusCode = http.StatusTooManyRequests
}

t := throttler{
tokens: make(chan token, opts.Limit),
backlogTokens: make(chan token, opts.Limit+opts.BacklogLimit),
backlogTimeout: opts.BacklogTimeout,
statusCode: statusCode,
retryAfterFn: opts.RetryAfterFn,
}

Expand All @@ -72,7 +79,7 @@ func ThrottleWithOpts(opts ThrottleOpts) func(http.Handler) http.Handler {

case <-ctx.Done():
t.setRetryAfterHeaderIfNeeded(w, true)
http.Error(w, errContextCanceled, http.StatusTooManyRequests)
http.Error(w, errContextCanceled, t.statusCode)
return

case btok := <-t.backlogTokens:
Expand All @@ -85,12 +92,12 @@ func ThrottleWithOpts(opts ThrottleOpts) func(http.Handler) http.Handler {
select {
case <-timer.C:
t.setRetryAfterHeaderIfNeeded(w, false)
http.Error(w, errTimedOut, http.StatusTooManyRequests)
http.Error(w, errTimedOut, t.statusCode)
return
case <-ctx.Done():
timer.Stop()
t.setRetryAfterHeaderIfNeeded(w, true)
http.Error(w, errContextCanceled, http.StatusTooManyRequests)
http.Error(w, errContextCanceled, t.statusCode)
return
case tok := <-t.tokens:
defer func() {
Expand All @@ -103,7 +110,7 @@ func ThrottleWithOpts(opts ThrottleOpts) func(http.Handler) http.Handler {

default:
t.setRetryAfterHeaderIfNeeded(w, false)
http.Error(w, errCapacityExceeded, http.StatusTooManyRequests)
http.Error(w, errCapacityExceeded, t.statusCode)
return
}
}
Expand All @@ -119,8 +126,9 @@ type token struct{}
type throttler struct {
tokens chan token
backlogTokens chan token
retryAfterFn func(ctxDone bool) time.Duration
backlogTimeout time.Duration
statusCode int
retryAfterFn func(ctxDone bool) time.Duration
}

// setRetryAfterHeaderIfNeeded sets Retry-After HTTP header if corresponding retryAfterFn option of throttler is initialized.
Expand Down
56 changes: 52 additions & 4 deletions middleware/throttle_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@ func TestThrottleTriggerGatewayTimeout(t *testing.T) {
res, err := client.Get(server.URL)
assertNoError(t, err)
assertEqual(t, http.StatusOK, res.StatusCode)

}(i)
}

Expand All @@ -136,7 +135,6 @@ func TestThrottleTriggerGatewayTimeout(t *testing.T) {
assertNoError(t, err)
assertEqual(t, http.StatusTooManyRequests, res.StatusCode)
assertEqual(t, errTimedOut, strings.TrimSpace(string(buf)))

}(i)
}

Expand Down Expand Up @@ -175,7 +173,6 @@ func TestThrottleMaximum(t *testing.T) {
buf, err := ioutil.ReadAll(res.Body)
assertNoError(t, err)
assertEqual(t, testContent, buf)

}(i)
}

Expand All @@ -196,7 +193,6 @@ func TestThrottleMaximum(t *testing.T) {
assertNoError(t, err)
assertEqual(t, http.StatusTooManyRequests, res.StatusCode)
assertEqual(t, errCapacityExceeded, strings.TrimSpace(string(buf)))

}(i)
}

Expand Down Expand Up @@ -252,3 +248,55 @@ func TestThrottleMaximum(t *testing.T) {
wg.Wait()
}*/

func TestThrottleCustomStatusCode(t *testing.T) {
const timeout = time.Second * 3

wait := make(chan struct{})

r := chi.NewRouter()
r.Use(ThrottleWithOpts(ThrottleOpts{Limit: 1, StatusCode: http.StatusServiceUnavailable}))
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
select {
case <-wait:
case <-time.After(timeout):
}
w.WriteHeader(http.StatusOK)
})
server := httptest.NewServer(r)
defer server.Close()

const totalRequestCount = 5

client := http.Client{Timeout: timeout}

codes := make(chan int, totalRequestCount)
errs := make(chan error, totalRequestCount)
for i := 0; i < totalRequestCount; i++ {
go func() {
resp, err := client.Get(server.URL)
if err != nil {
errs <- err
return
}
codes <- resp.StatusCode
}()
}

waitResponse := func(wantCode int) {
select {
case err := <-errs:
t.Fatal(err)
case code := <-codes:
assertEqual(t, wantCode, code)
case <-time.After(timeout):
t.Fatalf("waiting %d code, timeout exceeded", wantCode)
}
}

for i := 0; i < totalRequestCount-1; i++ {
waitResponse(http.StatusServiceUnavailable)
}
close(wait)
waitResponse(http.StatusOK)
}

0 comments on commit bccb994

Please sign in to comment.