From 587f3ae9df4bb410ef7985469fdfdb771afcbfd8 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Tue, 5 Oct 2021 08:03:20 -0400 Subject: [PATCH] Support for sub fiber's error handlers (#1560) - Mounted fiber and its sub apps error handlers are now saved a new errorHandlers map in App - New public App.ErrorHandler method that wraps the logic for which error handler to user on any given context - Error handler match logic based on request path <=> prefix accuracy - Typo fixes - Tests --- app.go | 102 +++++++++++++++++++++++++++--------- app_test.go | 77 +++++++++++++++++++++++++++ middleware/logger/logger.go | 2 +- router.go | 2 +- 4 files changed, 155 insertions(+), 28 deletions(-) diff --git a/app.go b/app.go index 4ca486caec..dacead2c35 100644 --- a/app.go +++ b/app.go @@ -109,6 +109,8 @@ type App struct { getBytes func(s string) (b []byte) // Converts byte slice to a string getString func(b []byte) string + // mount prefix -> error handler + errorHandlers map[string]ErrorHandler } // Config is a struct holding the server settings. @@ -426,9 +428,10 @@ func New(config ...Config) *App { }, }, // Create config - config: Config{}, - getBytes: utils.UnsafeBytes, - getString: utils.UnsafeString, + config: Config{}, + getBytes: utils.UnsafeBytes, + getString: utils.UnsafeString, + errorHandlers: make(map[string]ErrorHandler), } // Override config if provided if len(config) > 0 { @@ -460,9 +463,11 @@ func New(config ...Config) *App { if app.config.Immutable { app.getBytes, app.getString = getBytesImmutable, getStringImmutable } + if app.config.ErrorHandler == nil { app.config.ErrorHandler = DefaultErrorHandler } + if app.config.JSONEncoder == nil { app.config.JSONEncoder = json.Marshal } @@ -487,7 +492,9 @@ func New(config ...Config) *App { // Mount attaches another app instance as a sub-router along a routing path. // It's very useful to split up a large API as many independent routers and -// compose them as a single service using Mount. +// compose them as a single service using Mount. The fiber's error handler and +// any of the fiber's sub apps are added to the application's error handlers +// to be invoked on errors that happen within the prefix route. func (app *App) Mount(prefix string, fiber *App) Router { stack := fiber.Stack() for m := range stack { @@ -497,6 +504,15 @@ func (app *App) Mount(prefix string, fiber *App) Router { } } + // Save the fiber's error handler and its sub apps + prefix = strings.TrimRight(prefix, "/") + if fiber.config.ErrorHandler != nil { + app.errorHandlers[prefix] = fiber.config.ErrorHandler + } + for mountedPrefixes, errHandler := range fiber.errorHandlers { + app.errorHandlers[prefix+mountedPrefixes] = errHandler + } + atomic.AddUint32(&app.handlerCount, fiber.handlerCount) return app @@ -822,7 +838,7 @@ func (app *App) init() *App { // lock application app.mutex.Lock() - // Only load templates if an view engine is specified + // Only load templates if a view engine is specified if app.config.Views != nil { if err := app.config.Views.Load(); err != nil { fmt.Printf("views: %v\n", err) @@ -833,26 +849,7 @@ func (app *App) init() *App { app.server = &fasthttp.Server{ Logger: &disableLogger{}, LogAllErrors: false, - ErrorHandler: func(fctx *fasthttp.RequestCtx, err error) { - c := app.AcquireCtx(fctx) - if _, ok := err.(*fasthttp.ErrSmallBuffer); ok { - err = ErrRequestHeaderFieldsTooLarge - } else if netErr, ok := err.(*net.OpError); ok && netErr.Timeout() { - err = ErrRequestTimeout - } else if err == fasthttp.ErrBodyTooLarge { - err = ErrRequestEntityTooLarge - } else if err == fasthttp.ErrGetOnly { - err = ErrMethodNotAllowed - } else if strings.Contains(err.Error(), "timeout") { - err = ErrRequestTimeout - } else { - err = ErrBadRequest - } - if catch := app.config.ErrorHandler(c, err); catch != nil { - _ = c.SendStatus(StatusInternalServerError) - } - app.ReleaseCtx(c) - }, + ErrorHandler: app.serverErrorHandler, } // fasthttp server settings @@ -880,6 +877,60 @@ func (app *App) init() *App { return app } +// ErrorHandler is the application's method in charge of finding the +// appropiate handler for the given request. It searches any mounted +// sub fibers by their prefixes and if it finds a match, it uses that +// error handler. Otherwise it uses the configured error handler for +// the app, which if not set is the DefaultErrorHandler. +func (app *App) ErrorHandler(ctx *Ctx, err error) error { + var ( + mountedErrHandler ErrorHandler + mountedPrefixParts int + ) + + for prefix, errHandler := range app.errorHandlers { + if strings.HasPrefix(ctx.path, prefix) { + parts := len(strings.Split(prefix, "/")) + if mountedPrefixParts <= parts { + mountedErrHandler = errHandler + mountedPrefixParts = parts + } + } + } + + if mountedErrHandler != nil { + return mountedErrHandler(ctx, err) + } + + return app.config.ErrorHandler(ctx, err) +} + +// serverErrorHandler is a wrapper around the application's error handler method +// user for the fasthttp server configuration. It maps a set of fasthttp errors to fiber +// errors before calling the application's error handler method. +func (app *App) serverErrorHandler(fctx *fasthttp.RequestCtx, err error) { + c := app.AcquireCtx(fctx) + if _, ok := err.(*fasthttp.ErrSmallBuffer); ok { + err = ErrRequestHeaderFieldsTooLarge + } else if netErr, ok := err.(*net.OpError); ok && netErr.Timeout() { + err = ErrRequestTimeout + } else if err == fasthttp.ErrBodyTooLarge { + err = ErrRequestEntityTooLarge + } else if err == fasthttp.ErrGetOnly { + err = ErrMethodNotAllowed + } else if strings.Contains(err.Error(), "timeout") { + err = ErrRequestTimeout + } else { + err = ErrBadRequest + } + + if catch := app.ErrorHandler(c, err); catch != nil { + _ = c.SendStatus(StatusInternalServerError) + } + + app.ReleaseCtx(c) +} + // startupProcess Is the method which executes all the necessary processes just before the start of the server. func (app *App) startupProcess() *App { app.mutex.Lock() @@ -961,7 +1012,6 @@ func (app *App) startupMessage(addr string, tls bool, pids string) { } } - scheme := "http" if tls { scheme = "https" diff --git a/app_test.go b/app_test.go index c3ec6df361..46304ef265 100644 --- a/app_test.go +++ b/app_test.go @@ -1439,3 +1439,80 @@ func Test_App_DisablePreParseMultipartForm(t *testing.T) { utils.AssertEqual(t, testString, string(body)) } + +func Test_App_UseMountedErrorHandler(t *testing.T) { + app := New() + + fiber := New(Config{ + ErrorHandler: func(ctx *Ctx, err error) error { + return ctx.Status(200).SendString("hi, i'm a custom error") + }, + }) + fiber.Get("/", func(c *Ctx) error { + return errors.New("something happened") + }) + + app.Mount("/api", fiber) + + resp, err := app.Test(httptest.NewRequest(MethodGet, "/api", nil)) + utils.AssertEqual(t, nil, err, "app.Test(req)") + utils.AssertEqual(t, 200, resp.StatusCode, "Status code") + + b, err := ioutil.ReadAll(resp.Body) + utils.AssertEqual(t, nil, err, "iotuil.ReadAll()") + utils.AssertEqual(t, "hi, i'm a custom error", string(b), "Response body") +} + +func Test_App_UseMountedErrorHandlerForBestPrefixMatch(t *testing.T) { + app := New() + + tsf := func(ctx *Ctx, err error) error { + return ctx.Status(200).SendString("hi, i'm a custom sub sub fiber error") + } + tripleSubFiber := New(Config{ + ErrorHandler: tsf, + }) + tripleSubFiber.Get("/", func(c *Ctx) error { + return errors.New("something happened") + }) + + sf := func(ctx *Ctx, err error) error { + return ctx.Status(200).SendString("hi, i'm a custom sub fiber error") + } + subfiber := New(Config{ + ErrorHandler: sf, + }) + subfiber.Get("/", func(c *Ctx) error { + return errors.New("something happened") + }) + subfiber.Mount("/third", tripleSubFiber) + + f := func(ctx *Ctx, err error) error { + return ctx.Status(200).SendString("hi, i'm a custom error") + } + fiber := New(Config{ + ErrorHandler: f, + }) + fiber.Get("/", func(c *Ctx) error { + return errors.New("something happened") + }) + fiber.Mount("/sub", subfiber) + + app.Mount("/api", fiber) + + resp, err := app.Test(httptest.NewRequest(MethodGet, "/api/sub", nil)) + utils.AssertEqual(t, nil, err, "/api/sub req") + utils.AssertEqual(t, 200, resp.StatusCode, "Status code") + + b, err := ioutil.ReadAll(resp.Body) + utils.AssertEqual(t, nil, err, "iotuil.ReadAll()") + utils.AssertEqual(t, "hi, i'm a custom sub fiber error", string(b), "Response body") + + resp2, err := app.Test(httptest.NewRequest(MethodGet, "/api/sub/third", nil)) + utils.AssertEqual(t, nil, err, "/api/sub/third req") + utils.AssertEqual(t, 200, resp.StatusCode, "Status code") + + b, err = ioutil.ReadAll(resp2.Body) + utils.AssertEqual(t, nil, err, "iotuil.ReadAll()") + utils.AssertEqual(t, "hi, i'm a custom sub sub fiber error", string(b), "Third fiber Response body") +} diff --git a/middleware/logger/logger.go b/middleware/logger/logger.go index e9cf7daa92..21fe3eb38f 100644 --- a/middleware/logger/logger.go +++ b/middleware/logger/logger.go @@ -145,7 +145,7 @@ func New(config ...Config) fiber.Handler { } } // override error handler - errHandler = c.App().Config().ErrorHandler + errHandler = c.App().ErrorHandler }) var start, stop time.Time diff --git a/router.go b/router.go index 730b70390d..98f34cb81a 100644 --- a/router.go +++ b/router.go @@ -154,7 +154,7 @@ func (app *App) handler(rctx *fasthttp.RequestCtx) { // Find match in stack match, err := app.next(c) if err != nil { - if catch := c.app.config.ErrorHandler(c, err); catch != nil { + if catch := c.app.ErrorHandler(c, err); catch != nil { _ = c.SendStatus(StatusInternalServerError) } }