Skip to content

Commit

Permalink
Add ability to specify response HTTP status code for Throttle middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
vasayxtx committed Dec 27, 2023
1 parent ff1d3c6 commit a62d96e
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 9 deletions.
18 changes: 13 additions & 5 deletions middleware/throttle.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ type ThrottleOpts struct {
Limit int
BacklogLimit int
BacklogTimeout time.Duration
StatusCode int
}

// Throttle is a middleware that limits number of currently processed requests
Expand Down Expand Up @@ -49,10 +50,16 @@ func ThrottleWithOpts(opts ThrottleOpts) func(http.Handler) http.Handler {
panic("chi/middleware: Throttle expects backlogLimit to be positive")
}

statusCode := opts.StatusCode
if statusCode == 0 {
statusCode = http.StatusTooManyRequests
}

t := throttler{
tokens: make(chan token, opts.Limit),
backlogTokens: make(chan token, opts.Limit+opts.BacklogLimit),
backlogTimeout: opts.BacklogTimeout,
statusCode: statusCode,
retryAfterFn: opts.RetryAfterFn,
}

Expand All @@ -72,7 +79,7 @@ func ThrottleWithOpts(opts ThrottleOpts) func(http.Handler) http.Handler {

case <-ctx.Done():
t.setRetryAfterHeaderIfNeeded(w, true)
http.Error(w, errContextCanceled, http.StatusTooManyRequests)
http.Error(w, errContextCanceled, t.statusCode)
return

case btok := <-t.backlogTokens:
Expand All @@ -85,12 +92,12 @@ func ThrottleWithOpts(opts ThrottleOpts) func(http.Handler) http.Handler {
select {
case <-timer.C:
t.setRetryAfterHeaderIfNeeded(w, false)
http.Error(w, errTimedOut, http.StatusTooManyRequests)
http.Error(w, errTimedOut, t.statusCode)
return
case <-ctx.Done():
timer.Stop()
t.setRetryAfterHeaderIfNeeded(w, true)
http.Error(w, errContextCanceled, http.StatusTooManyRequests)
http.Error(w, errContextCanceled, t.statusCode)
return
case tok := <-t.tokens:
defer func() {
Expand All @@ -103,7 +110,7 @@ func ThrottleWithOpts(opts ThrottleOpts) func(http.Handler) http.Handler {

default:
t.setRetryAfterHeaderIfNeeded(w, false)
http.Error(w, errCapacityExceeded, http.StatusTooManyRequests)
http.Error(w, errCapacityExceeded, t.statusCode)
return
}
}
Expand All @@ -119,8 +126,9 @@ type token struct{}
type throttler struct {
tokens chan token
backlogTokens chan token
retryAfterFn func(ctxDone bool) time.Duration
backlogTimeout time.Duration
statusCode int
retryAfterFn func(ctxDone bool) time.Duration
}

// setRetryAfterHeaderIfNeeded sets Retry-After HTTP header if corresponding retryAfterFn option of throttler is initialized.
Expand Down
40 changes: 36 additions & 4 deletions middleware/throttle_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@ func TestThrottleTriggerGatewayTimeout(t *testing.T) {
res, err := client.Get(server.URL)
assertNoError(t, err)
assertEqual(t, http.StatusOK, res.StatusCode)

}(i)
}

Expand All @@ -136,7 +135,6 @@ func TestThrottleTriggerGatewayTimeout(t *testing.T) {
assertNoError(t, err)
assertEqual(t, http.StatusTooManyRequests, res.StatusCode)
assertEqual(t, errTimedOut, strings.TrimSpace(string(buf)))

}(i)
}

Expand Down Expand Up @@ -175,7 +173,6 @@ func TestThrottleMaximum(t *testing.T) {
buf, err := ioutil.ReadAll(res.Body)
assertNoError(t, err)
assertEqual(t, testContent, buf)

}(i)
}

Expand All @@ -196,7 +193,6 @@ func TestThrottleMaximum(t *testing.T) {
assertNoError(t, err)
assertEqual(t, http.StatusTooManyRequests, res.StatusCode)
assertEqual(t, errCapacityExceeded, strings.TrimSpace(string(buf)))

}(i)
}

Expand Down Expand Up @@ -252,3 +248,39 @@ func TestThrottleMaximum(t *testing.T) {
wg.Wait()
}*/

func TestThrottleCustomStatusCode(t *testing.T) {
block := make(chan struct{})

r := chi.NewRouter()
r.Use(ThrottleWithOpts(ThrottleOpts{Limit: 1, StatusCode: http.StatusServiceUnavailable}))
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
block <- struct{}{}
block <- struct{}{}
w.Write(testContent)
})
server := httptest.NewServer(r)
defer server.Close()

client := http.Client{
Timeout: time.Second * 60, // Maximum waiting time.
}

done := make(chan struct{})

go func() {
res, err := client.Get(server.URL)
assertNoError(t, err)
assertEqual(t, http.StatusOK, res.StatusCode)
done <- struct{}{}
}()

<-block
res, err := client.Get(server.URL)
assertNoError(t, err)
assertEqual(t, http.StatusServiceUnavailable, res.StatusCode)
<-block

<-done
}

0 comments on commit a62d96e

Please sign in to comment.