Skip to content

Commit

Permalink
🔥 Update: add timeout context middleware (#2090)
Browse files Browse the repository at this point in the history
* 🔥 Feature: add timeoutcontext middleware

* move timeoutconext to timeout package

* remove timeoutcontext readme.md

* replace timeout mware with timeout context mware

* Update README.md

* Update README.md

* update timeout middleware readme

* test curl commands fixed

* rename sample code title on timeout middleware

Co-authored-by: RW <rene@gofiber.io>
  • Loading branch information
hakankutluay and ReneWerner87 committed Sep 16, 2022
1 parent e829caf commit 7c83e38
Show file tree
Hide file tree
Showing 3 changed files with 167 additions and 86 deletions.
81 changes: 73 additions & 8 deletions middleware/timeout/README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Timeout
Timeout middleware for [Fiber](https://github.com/gofiber/fiber) wraps a `fiber.Handler` with a timeout. If the handler takes longer than the given duration to return, the timeout error is set and forwarded to the centralized [ErrorHandler](https://docs.gofiber.io/error-handling).
Timeout middleware for Fiber. As a `fiber.Handler` wrapper, it creates a context with `context.WithTimeout` and pass it in `UserContext`.

If the context passed executions (eg. DB ops, Http calls) takes longer than the given duration to return, the timeout error is set and forwarded to the centralized `ErrorHandler`.

It has no race conditions, ready to use on production.

### Table of Contents
- [Signatures](#signatures)
Expand All @@ -8,7 +12,7 @@ Timeout middleware for [Fiber](https://github.com/gofiber/fiber) wraps a `fiber.

### Signatures
```go
func New(h fiber.Handler, t time.Duration) fiber.Handler
func New(handler fiber.Handler, timeout time.Duration, timeoutErrors ...error) fiber.Handler
```

### Examples
Expand All @@ -20,15 +24,76 @@ import (
)
```

After you initiate your Fiber app, you can use the following possibilities:
Sample timeout middleware usage
```go
handler := func(ctx *fiber.Ctx) error {
err := ctx.SendString("Hello, World 👋!")
if err != nil {
return err
func main() {
app := fiber.New()
h := func(c *fiber.Ctx) error {
sleepTime, _ := time.ParseDuration(c.Params("sleepTime") + "ms")
if err := sleepWithContext(c.UserContext(), sleepTime); err != nil {
return fmt.Errorf("%w: execution error", err)
}
return nil
}

app.Get("/foo/:sleepTime", timeout.New(h, 2*time.Second))
_ = app.Listen(":3000")
}

func sleepWithContext(ctx context.Context, d time.Duration) error {
timer := time.NewTimer(d)

select {
case <-ctx.Done():
if !timer.Stop() {
<-timer.C
}
return context.DeadlineExceeded
case <-timer.C:
}
return nil
}
```

Test http 200 with curl:
```bash
curl --location -I --request GET 'http://localhost:3000/foo/1000'
```

Test http 408 with curl:
```bash
curl --location -I --request GET 'http://localhost:3000/foo/3000'
```


When using with custom error:
```go
var ErrFooTimeOut = errors.New("foo context canceled")

func main() {
app := fiber.New()
h := func(c *fiber.Ctx) error {
sleepTime, _ := time.ParseDuration(c.Params("sleepTime") + "ms")
if err := sleepWithContextWithCustomError(c.UserContext(), sleepTime); err != nil {
return fmt.Errorf("%w: execution error", err)
}
return nil
}

app.Get("/foo/:sleepTime", timeout.New(h, 2*time.Second), ErrFooTimeOut)
_ = app.Listen(":3000")
}

app.Get("/foo", timeout.New(handler, 5 * time.Second))
func sleepWithContext(ctx context.Context, d time.Duration) error {
timer := time.NewTimer(d)
select {
case <-ctx.Done():
if !timer.Stop() {
<-timer.C
}
return ErrFooTimeOut
case <-timer.C:
}
return nil
}
```
47 changes: 17 additions & 30 deletions middleware/timeout/timeout.go
Original file line number Diff line number Diff line change
@@ -1,43 +1,30 @@
package timeout

import (
"fmt"
"sync"
"context"
"errors"
"time"

"github.com/gofiber/fiber/v2"
)

var once sync.Once

// New wraps a handler and aborts the process of the handler if the timeout is reached
func New(handler fiber.Handler, timeout time.Duration) fiber.Handler {
once.Do(func() {
fmt.Println("[Warning] timeout contains data race issues, not ready for production!")
})

if timeout <= 0 {
return handler
}

// logic is from fasthttp.TimeoutWithCodeHandler https://github.com/valyala/fasthttp/blob/master/server.go#L418
// New implementation of timeout middleware. Set custom errors(context.DeadlineExceeded vs) for get fiber.ErrRequestTimeout response.
func New(h fiber.Handler, t time.Duration, tErrs ...error) fiber.Handler {
return func(ctx *fiber.Ctx) error {
ch := make(chan struct{}, 1)

go func() {
defer func() {
_ = recover()
}()
_ = handler(ctx)
ch <- struct{}{}
}()

select {
case <-ch:
case <-time.After(timeout):
return fiber.ErrRequestTimeout
timeoutContext, cancel := context.WithTimeout(ctx.UserContext(), t)
defer cancel()
ctx.SetUserContext(timeoutContext)
if err := h(ctx); err != nil {
if errors.Is(err, context.DeadlineExceeded) {
return fiber.ErrRequestTimeout
}
for i := range tErrs {
if errors.Is(err, tErrs[i]) {
return fiber.ErrRequestTimeout
}
}
return err
}

return nil
}
}
125 changes: 77 additions & 48 deletions middleware/timeout/timeout_test.go
Original file line number Diff line number Diff line change
@@ -1,55 +1,84 @@
package timeout

// // go test -run Test_Middleware_Timeout
// func Test_Middleware_Timeout(t *testing.T) {
// app := fiber.New(fiber.Config{DisableStartupMessage: true})
import (
"context"
"errors"
"fmt"
"net/http/httptest"
"testing"
"time"

// h := New(func(c *fiber.Ctx) error {
// sleepTime, _ := time.ParseDuration(c.Params("sleepTime") + "ms")
// time.Sleep(sleepTime)
// return c.SendString("After " + c.Params("sleepTime") + "ms sleeping")
// }, 5*time.Millisecond)
// app.Get("/test/:sleepTime", h)
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils"
)

// testTimeout := func(timeoutStr string) {
// resp, err := app.Test(httptest.NewRequest("GET", "/test/"+timeoutStr, nil))
// utils.AssertEqual(t, nil, err, "app.Test(req)")
// utils.AssertEqual(t, fiber.StatusRequestTimeout, resp.StatusCode, "Status code")
// go test -run Test_Timeout
func Test_Timeout(t *testing.T) {
// fiber instance
app := fiber.New()
h := New(func(c *fiber.Ctx) error {
sleepTime, _ := time.ParseDuration(c.Params("sleepTime") + "ms")
if err := sleepWithContext(c.UserContext(), sleepTime, context.DeadlineExceeded); err != nil {
return fmt.Errorf("%w: l2 wrap", fmt.Errorf("%w: l1 wrap ", err))
}
return nil
}, 100*time.Millisecond)
app.Get("/test/:sleepTime", h)
testTimeout := func(timeoutStr string) {
resp, err := app.Test(httptest.NewRequest("GET", "/test/"+timeoutStr, nil))
utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, fiber.StatusRequestTimeout, resp.StatusCode, "Status code")
}
testSucces := func(timeoutStr string) {
resp, err := app.Test(httptest.NewRequest("GET", "/test/"+timeoutStr, nil))
utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code")
}
testTimeout("300")
testTimeout("500")
testSucces("50")
testSucces("30")
}

// body, err := ioutil.ReadAll(resp.Body)
// utils.AssertEqual(t, nil, err)
// utils.AssertEqual(t, "Request Timeout", string(body))
// }
// testSucces := func(timeoutStr string) {
// resp, err := app.Test(httptest.NewRequest("GET", "/test/"+timeoutStr, nil))
// utils.AssertEqual(t, nil, err, "app.Test(req)")
// utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code")
var ErrFooTimeOut = errors.New("foo context canceled")

// body, err := ioutil.ReadAll(resp.Body)
// utils.AssertEqual(t, nil, err)
// utils.AssertEqual(t, "After "+timeoutStr+"ms sleeping", string(body))
// }
// go test -run Test_TimeoutWithCustomError
func Test_TimeoutWithCustomError(t *testing.T) {
// fiber instance
app := fiber.New()
h := New(func(c *fiber.Ctx) error {
sleepTime, _ := time.ParseDuration(c.Params("sleepTime") + "ms")
if err := sleepWithContext(c.UserContext(), sleepTime, ErrFooTimeOut); err != nil {
return fmt.Errorf("%w: execution error", err)
}
return nil
}, 100*time.Millisecond, ErrFooTimeOut)
app.Get("/test/:sleepTime", h)
testTimeout := func(timeoutStr string) {
resp, err := app.Test(httptest.NewRequest("GET", "/test/"+timeoutStr, nil))
utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, fiber.StatusRequestTimeout, resp.StatusCode, "Status code")
}
testSucces := func(timeoutStr string) {
resp, err := app.Test(httptest.NewRequest("GET", "/test/"+timeoutStr, nil))
utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code")
}
testTimeout("300")
testTimeout("500")
testSucces("50")
testSucces("30")
}

// testTimeout("15")
// testSucces("2")
// testTimeout("30")
// testSucces("3")
// }

// // go test -run -v Test_Timeout_Panic
// func Test_Timeout_Panic(t *testing.T) {
// app := fiber.New(fiber.Config{DisableStartupMessage: true})

// app.Get("/panic", recover.New(), New(func(c *fiber.Ctx) error {
// c.Set("dummy", "this should not be here")
// panic("panic in timeout handler")
// }, 5*time.Millisecond))

// resp, err := app.Test(httptest.NewRequest("GET", "/panic", nil))
// utils.AssertEqual(t, nil, err, "app.Test(req)")
// utils.AssertEqual(t, fiber.StatusRequestTimeout, resp.StatusCode, "Status code")

// body, err := ioutil.ReadAll(resp.Body)
// utils.AssertEqual(t, nil, err)
// utils.AssertEqual(t, "Request Timeout", string(body))
// }
func sleepWithContext(ctx context.Context, d time.Duration, te error) error {
timer := time.NewTimer(d)
select {
case <-ctx.Done():
if !timer.Stop() {
<-timer.C
}
return te
case <-timer.C:
}
return nil
}

0 comments on commit 7c83e38

Please sign in to comment.