diff --git a/middleware/limiter/README.md b/middleware/limiter/README.md index ead65538aa..8682dd11ad 100644 --- a/middleware/limiter/README.md +++ b/middleware/limiter/README.md @@ -109,6 +109,16 @@ type Config struct { // } LimitReached fiber.Handler + // When set to true, requests with StatusCode >= 400 won't be counted. + // + // Default: false + SkipFailedRequests bool + + // When set to true, requests with StatusCode < 400 won't be counted. + // + // Default: false + SkipSuccessfulRequests bool + // Store is used to store the state of the middleware // // Default: an in memory store for this process only @@ -130,5 +140,7 @@ var ConfigDefault = Config{ LimitReached: func(c *fiber.Ctx) error { return c.SendStatus(fiber.StatusTooManyRequests) }, + SkipFailedRequests: false, + SkipSuccessfulRequests: false, } ``` diff --git a/middleware/limiter/config.go b/middleware/limiter/config.go index a08759a400..1634b33a3c 100644 --- a/middleware/limiter/config.go +++ b/middleware/limiter/config.go @@ -38,6 +38,16 @@ type Config struct { // } LimitReached fiber.Handler + // When set to true, requests with StatusCode >= 400 won't be counted. + // + // Default: false + SkipFailedRequests bool + + // When set to true, requests with StatusCode < 400 won't be counted. + // + // Default: false + SkipSuccessfulRequests bool + // Store is used to store the state of the middleware // // Default: an in memory store for this process only @@ -63,6 +73,8 @@ var ConfigDefault = Config{ LimitReached: func(c *fiber.Ctx) error { return c.SendStatus(fiber.StatusTooManyRequests) }, + SkipFailedRequests: false, + SkipSuccessfulRequests: false, } // Helper function to set default values diff --git a/middleware/limiter/limiter.go b/middleware/limiter/limiter.go index de75504be0..59545675ab 100644 --- a/middleware/limiter/limiter.go +++ b/middleware/limiter/limiter.go @@ -50,6 +50,10 @@ func New(config ...Config) fiber.Handler { return c.Next() } + // Continue stack for reaching c.Response().StatusCode() + // Store err for returning + err := c.Next() + // Get key from request key := cfg.KeyGenerator(c) @@ -72,8 +76,12 @@ func New(config ...Config) fiber.Handler { e.exp = ts + expiration } - // Increment hits - e.hits++ + // Check for SkipFailedRequests and SkipSuccessfulRequests + if (!cfg.SkipSuccessfulRequests || c.Response().StatusCode() >= 400) && + (!cfg.SkipFailedRequests || c.Response().StatusCode() < 400) { + // Increment hits + e.hits++ + } // Calculate when it resets in seconds expire := e.exp - ts @@ -102,7 +110,6 @@ func New(config ...Config) fiber.Handler { c.Set(xRateLimitRemaining, strconv.Itoa(remaining)) c.Set(xRateLimitReset, strconv.FormatUint(expire, 10)) - // Continue stack - return c.Next() + return err } } diff --git a/middleware/limiter/limiter_test.go b/middleware/limiter/limiter_test.go index f8e12c479f..af61901806 100644 --- a/middleware/limiter/limiter_test.go +++ b/middleware/limiter/limiter_test.go @@ -14,7 +14,7 @@ import ( "github.com/valyala/fasthttp" ) -// go test -run Test_Limiter_Concurrency -race -v +// go test -run Test_Limiter_Concurrency_Store -race -v func Test_Limiter_Concurrency_Store(t *testing.T) { // Test concurrency using a custom store @@ -107,6 +107,84 @@ func Test_Limiter_Concurrency(t *testing.T) { } +// go test -run Test_Limiter_Skip_Failed_Requests -v +func Test_Limiter_Skip_Failed_Requests(t *testing.T) { + + app := fiber.New() + + app.Use(New(Config{ + Max: 1, + Expiration: 2 * time.Second, + SkipFailedRequests: true, + })) + + app.Get("/:status", func(c *fiber.Ctx) error { + if c.Params("status") == "fail" { + return c.SendStatus(400) + } + return c.SendStatus(200) + }) + + resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/fail", nil)) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, 400, resp.StatusCode) + + resp, err = app.Test(httptest.NewRequest(http.MethodGet, "/success", nil)) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, 200, resp.StatusCode) + + resp, err = app.Test(httptest.NewRequest(http.MethodGet, "/success", nil)) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, 429, resp.StatusCode) + + time.Sleep(3 * time.Second) + + resp, err = app.Test(httptest.NewRequest(http.MethodGet, "/success", nil)) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, 200, resp.StatusCode) + +} + +// go test -run Test_Limiter_Skip_Successful_Requests -v +func Test_Limiter_Skip_Successful_Requests(t *testing.T) { + + // Test concurrency using a default store + + app := fiber.New() + + app.Use(New(Config{ + Max: 1, + Expiration: 2 * time.Second, + SkipSuccessfulRequests: true, + })) + + app.Get("/:status", func(c *fiber.Ctx) error { + if c.Params("status") == "fail" { + return c.SendStatus(400) + } + return c.SendStatus(200) + }) + + resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/success", nil)) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, 200, resp.StatusCode) + + resp, err = app.Test(httptest.NewRequest(http.MethodGet, "/fail", nil)) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, 400, resp.StatusCode) + + resp, err = app.Test(httptest.NewRequest(http.MethodGet, "/fail", nil)) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, 429, resp.StatusCode) + + time.Sleep(3 * time.Second) + + resp, err = app.Test(httptest.NewRequest(http.MethodGet, "/fail", nil)) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, 400, resp.StatusCode) + +} + // go test -v -run=^$ -bench=Benchmark_Limiter_Custom_Store -benchmem -count=4 func Benchmark_Limiter_Custom_Store(b *testing.B) { app := fiber.New()