-
-
Notifications
You must be signed in to change notification settings - Fork 1.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
🔥 Update: add timeout context middleware (#2090)
* 🔥 Feature: add timeoutcontext middleware * move timeoutconext to timeout package * remove timeoutcontext readme.md * replace timeout mware with timeout context mware * Update README.md * Update README.md * update timeout middleware readme * test curl commands fixed * rename sample code title on timeout middleware Co-authored-by: RW <rene@gofiber.io>
- Loading branch information
1 parent
e829caf
commit 7c83e38
Showing
3 changed files
with
167 additions
and
86 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,43 +1,30 @@ | ||
package timeout | ||
|
||
import ( | ||
"fmt" | ||
"sync" | ||
"context" | ||
"errors" | ||
"time" | ||
|
||
"github.com/gofiber/fiber/v2" | ||
) | ||
|
||
var once sync.Once | ||
|
||
// New wraps a handler and aborts the process of the handler if the timeout is reached | ||
func New(handler fiber.Handler, timeout time.Duration) fiber.Handler { | ||
once.Do(func() { | ||
fmt.Println("[Warning] timeout contains data race issues, not ready for production!") | ||
}) | ||
|
||
if timeout <= 0 { | ||
return handler | ||
} | ||
|
||
// logic is from fasthttp.TimeoutWithCodeHandler https://github.com/valyala/fasthttp/blob/master/server.go#L418 | ||
// New implementation of timeout middleware. Set custom errors(context.DeadlineExceeded vs) for get fiber.ErrRequestTimeout response. | ||
func New(h fiber.Handler, t time.Duration, tErrs ...error) fiber.Handler { | ||
return func(ctx *fiber.Ctx) error { | ||
ch := make(chan struct{}, 1) | ||
|
||
go func() { | ||
defer func() { | ||
_ = recover() | ||
}() | ||
_ = handler(ctx) | ||
ch <- struct{}{} | ||
}() | ||
|
||
select { | ||
case <-ch: | ||
case <-time.After(timeout): | ||
return fiber.ErrRequestTimeout | ||
timeoutContext, cancel := context.WithTimeout(ctx.UserContext(), t) | ||
defer cancel() | ||
ctx.SetUserContext(timeoutContext) | ||
if err := h(ctx); err != nil { | ||
if errors.Is(err, context.DeadlineExceeded) { | ||
return fiber.ErrRequestTimeout | ||
} | ||
for i := range tErrs { | ||
if errors.Is(err, tErrs[i]) { | ||
return fiber.ErrRequestTimeout | ||
} | ||
} | ||
return err | ||
} | ||
|
||
return nil | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,55 +1,84 @@ | ||
package timeout | ||
|
||
// // go test -run Test_Middleware_Timeout | ||
// func Test_Middleware_Timeout(t *testing.T) { | ||
// app := fiber.New(fiber.Config{DisableStartupMessage: true}) | ||
import ( | ||
"context" | ||
"errors" | ||
"fmt" | ||
"net/http/httptest" | ||
"testing" | ||
"time" | ||
|
||
// h := New(func(c *fiber.Ctx) error { | ||
// sleepTime, _ := time.ParseDuration(c.Params("sleepTime") + "ms") | ||
// time.Sleep(sleepTime) | ||
// return c.SendString("After " + c.Params("sleepTime") + "ms sleeping") | ||
// }, 5*time.Millisecond) | ||
// app.Get("/test/:sleepTime", h) | ||
"github.com/gofiber/fiber/v2" | ||
"github.com/gofiber/fiber/v2/utils" | ||
) | ||
|
||
// testTimeout := func(timeoutStr string) { | ||
// resp, err := app.Test(httptest.NewRequest("GET", "/test/"+timeoutStr, nil)) | ||
// utils.AssertEqual(t, nil, err, "app.Test(req)") | ||
// utils.AssertEqual(t, fiber.StatusRequestTimeout, resp.StatusCode, "Status code") | ||
// go test -run Test_Timeout | ||
func Test_Timeout(t *testing.T) { | ||
// fiber instance | ||
app := fiber.New() | ||
h := New(func(c *fiber.Ctx) error { | ||
sleepTime, _ := time.ParseDuration(c.Params("sleepTime") + "ms") | ||
if err := sleepWithContext(c.UserContext(), sleepTime, context.DeadlineExceeded); err != nil { | ||
return fmt.Errorf("%w: l2 wrap", fmt.Errorf("%w: l1 wrap ", err)) | ||
} | ||
return nil | ||
}, 100*time.Millisecond) | ||
app.Get("/test/:sleepTime", h) | ||
testTimeout := func(timeoutStr string) { | ||
resp, err := app.Test(httptest.NewRequest("GET", "/test/"+timeoutStr, nil)) | ||
utils.AssertEqual(t, nil, err, "app.Test(req)") | ||
utils.AssertEqual(t, fiber.StatusRequestTimeout, resp.StatusCode, "Status code") | ||
} | ||
testSucces := func(timeoutStr string) { | ||
resp, err := app.Test(httptest.NewRequest("GET", "/test/"+timeoutStr, nil)) | ||
utils.AssertEqual(t, nil, err, "app.Test(req)") | ||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code") | ||
} | ||
testTimeout("300") | ||
testTimeout("500") | ||
testSucces("50") | ||
testSucces("30") | ||
} | ||
|
||
// body, err := ioutil.ReadAll(resp.Body) | ||
// utils.AssertEqual(t, nil, err) | ||
// utils.AssertEqual(t, "Request Timeout", string(body)) | ||
// } | ||
// testSucces := func(timeoutStr string) { | ||
// resp, err := app.Test(httptest.NewRequest("GET", "/test/"+timeoutStr, nil)) | ||
// utils.AssertEqual(t, nil, err, "app.Test(req)") | ||
// utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code") | ||
var ErrFooTimeOut = errors.New("foo context canceled") | ||
|
||
// body, err := ioutil.ReadAll(resp.Body) | ||
// utils.AssertEqual(t, nil, err) | ||
// utils.AssertEqual(t, "After "+timeoutStr+"ms sleeping", string(body)) | ||
// } | ||
// go test -run Test_TimeoutWithCustomError | ||
func Test_TimeoutWithCustomError(t *testing.T) { | ||
// fiber instance | ||
app := fiber.New() | ||
h := New(func(c *fiber.Ctx) error { | ||
sleepTime, _ := time.ParseDuration(c.Params("sleepTime") + "ms") | ||
if err := sleepWithContext(c.UserContext(), sleepTime, ErrFooTimeOut); err != nil { | ||
return fmt.Errorf("%w: execution error", err) | ||
} | ||
return nil | ||
}, 100*time.Millisecond, ErrFooTimeOut) | ||
app.Get("/test/:sleepTime", h) | ||
testTimeout := func(timeoutStr string) { | ||
resp, err := app.Test(httptest.NewRequest("GET", "/test/"+timeoutStr, nil)) | ||
utils.AssertEqual(t, nil, err, "app.Test(req)") | ||
utils.AssertEqual(t, fiber.StatusRequestTimeout, resp.StatusCode, "Status code") | ||
} | ||
testSucces := func(timeoutStr string) { | ||
resp, err := app.Test(httptest.NewRequest("GET", "/test/"+timeoutStr, nil)) | ||
utils.AssertEqual(t, nil, err, "app.Test(req)") | ||
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code") | ||
} | ||
testTimeout("300") | ||
testTimeout("500") | ||
testSucces("50") | ||
testSucces("30") | ||
} | ||
|
||
// testTimeout("15") | ||
// testSucces("2") | ||
// testTimeout("30") | ||
// testSucces("3") | ||
// } | ||
|
||
// // go test -run -v Test_Timeout_Panic | ||
// func Test_Timeout_Panic(t *testing.T) { | ||
// app := fiber.New(fiber.Config{DisableStartupMessage: true}) | ||
|
||
// app.Get("/panic", recover.New(), New(func(c *fiber.Ctx) error { | ||
// c.Set("dummy", "this should not be here") | ||
// panic("panic in timeout handler") | ||
// }, 5*time.Millisecond)) | ||
|
||
// resp, err := app.Test(httptest.NewRequest("GET", "/panic", nil)) | ||
// utils.AssertEqual(t, nil, err, "app.Test(req)") | ||
// utils.AssertEqual(t, fiber.StatusRequestTimeout, resp.StatusCode, "Status code") | ||
|
||
// body, err := ioutil.ReadAll(resp.Body) | ||
// utils.AssertEqual(t, nil, err) | ||
// utils.AssertEqual(t, "Request Timeout", string(body)) | ||
// } | ||
func sleepWithContext(ctx context.Context, d time.Duration, te error) error { | ||
timer := time.NewTimer(d) | ||
select { | ||
case <-ctx.Done(): | ||
if !timer.Stop() { | ||
<-timer.C | ||
} | ||
return te | ||
case <-timer.C: | ||
} | ||
return nil | ||
} |