Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for sub fiber's error handlers #1560

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
102 changes: 76 additions & 26 deletions app.go
Expand Up @@ -109,6 +109,8 @@ type App struct {
getBytes func(s string) (b []byte)
// Converts byte slice to a string
getString func(b []byte) string
// mount prefix -> error handler
errorHandlers map[string]ErrorHandler
}

// Config is a struct holding the server settings.
Expand Down Expand Up @@ -426,9 +428,10 @@ func New(config ...Config) *App {
},
},
// Create config
config: Config{},
getBytes: utils.UnsafeBytes,
getString: utils.UnsafeString,
config: Config{},
getBytes: utils.UnsafeBytes,
getString: utils.UnsafeString,
errorHandlers: make(map[string]ErrorHandler),
}
// Override config if provided
if len(config) > 0 {
Expand Down Expand Up @@ -460,9 +463,11 @@ func New(config ...Config) *App {
if app.config.Immutable {
app.getBytes, app.getString = getBytesImmutable, getStringImmutable
}

if app.config.ErrorHandler == nil {
app.config.ErrorHandler = DefaultErrorHandler
}

if app.config.JSONEncoder == nil {
app.config.JSONEncoder = json.Marshal
}
Expand All @@ -487,7 +492,9 @@ func New(config ...Config) *App {

// Mount attaches another app instance as a sub-router along a routing path.
// It's very useful to split up a large API as many independent routers and
// compose them as a single service using Mount.
// compose them as a single service using Mount. The fiber's error handler and
// any of the fiber's sub apps are added to the application's error handlers
// to be invoked on errors that happen within the prefix route.
func (app *App) Mount(prefix string, fiber *App) Router {
stack := fiber.Stack()
for m := range stack {
Expand All @@ -497,6 +504,15 @@ func (app *App) Mount(prefix string, fiber *App) Router {
}
}

// Save the fiber's error handler and its sub apps
prefix = strings.TrimRight(prefix, "/")
if fiber.config.ErrorHandler != nil {
app.errorHandlers[prefix] = fiber.config.ErrorHandler
}
for mountedPrefixes, errHandler := range fiber.errorHandlers {
app.errorHandlers[prefix+mountedPrefixes] = errHandler
}

atomic.AddUint32(&app.handlerCount, fiber.handlerCount)

return app
Expand Down Expand Up @@ -822,7 +838,7 @@ func (app *App) init() *App {
// lock application
app.mutex.Lock()

// Only load templates if an view engine is specified
// Only load templates if a view engine is specified
if app.config.Views != nil {
if err := app.config.Views.Load(); err != nil {
fmt.Printf("views: %v\n", err)
Expand All @@ -833,26 +849,7 @@ func (app *App) init() *App {
app.server = &fasthttp.Server{
Logger: &disableLogger{},
LogAllErrors: false,
ErrorHandler: func(fctx *fasthttp.RequestCtx, err error) {
c := app.AcquireCtx(fctx)
if _, ok := err.(*fasthttp.ErrSmallBuffer); ok {
err = ErrRequestHeaderFieldsTooLarge
} else if netErr, ok := err.(*net.OpError); ok && netErr.Timeout() {
err = ErrRequestTimeout
} else if err == fasthttp.ErrBodyTooLarge {
err = ErrRequestEntityTooLarge
} else if err == fasthttp.ErrGetOnly {
err = ErrMethodNotAllowed
} else if strings.Contains(err.Error(), "timeout") {
err = ErrRequestTimeout
} else {
err = ErrBadRequest
}
if catch := app.config.ErrorHandler(c, err); catch != nil {
_ = c.SendStatus(StatusInternalServerError)
}
app.ReleaseCtx(c)
},
ErrorHandler: app.serverErrorHandler,
}

// fasthttp server settings
Expand Down Expand Up @@ -880,6 +877,60 @@ func (app *App) init() *App {
return app
}

// ErrorHandler is the application's method in charge of finding the
// appropiate handler for the given request. It searches any mounted
// sub fibers by their prefixes and if it finds a match, it uses that
// error handler. Otherwise it uses the configured error handler for
// the app, which if not set is the DefaultErrorHandler.
func (app *App) ErrorHandler(ctx *Ctx, err error) error {
var (
mountedErrHandler ErrorHandler
mountedPrefixParts int
)

for prefix, errHandler := range app.errorHandlers {
if strings.HasPrefix(ctx.path, prefix) {
parts := len(strings.Split(prefix, "/"))
if mountedPrefixParts <= parts {
mountedErrHandler = errHandler
mountedPrefixParts = parts
}
}
}

if mountedErrHandler != nil {
return mountedErrHandler(ctx, err)
}

return app.config.ErrorHandler(ctx, err)
}

// serverErrorHandler is a wrapper around the application's error handler method
// user for the fasthttp server configuration. It maps a set of fasthttp errors to fiber
// errors before calling the application's error handler method.
func (app *App) serverErrorHandler(fctx *fasthttp.RequestCtx, err error) {
c := app.AcquireCtx(fctx)
if _, ok := err.(*fasthttp.ErrSmallBuffer); ok {
err = ErrRequestHeaderFieldsTooLarge
} else if netErr, ok := err.(*net.OpError); ok && netErr.Timeout() {
err = ErrRequestTimeout
} else if err == fasthttp.ErrBodyTooLarge {
err = ErrRequestEntityTooLarge
} else if err == fasthttp.ErrGetOnly {
err = ErrMethodNotAllowed
} else if strings.Contains(err.Error(), "timeout") {
err = ErrRequestTimeout
} else {
err = ErrBadRequest
}

if catch := app.ErrorHandler(c, err); catch != nil {
_ = c.SendStatus(StatusInternalServerError)
}

app.ReleaseCtx(c)
}

// startupProcess Is the method which executes all the necessary processes just before the start of the server.
func (app *App) startupProcess() *App {
app.mutex.Lock()
Expand Down Expand Up @@ -961,7 +1012,6 @@ func (app *App) startupMessage(addr string, tls bool, pids string) {
}
}


scheme := "http"
if tls {
scheme = "https"
Expand Down
77 changes: 77 additions & 0 deletions app_test.go
Expand Up @@ -1414,3 +1414,80 @@ func Test_App_DisablePreParseMultipartForm(t *testing.T) {

utils.AssertEqual(t, testString, string(body))
}

func Test_App_UseMountedErrorHandler(t *testing.T) {
app := New()

fiber := New(Config{
ErrorHandler: func(ctx *Ctx, err error) error {
return ctx.Status(200).SendString("hi, i'm a custom error")
},
})
fiber.Get("/", func(c *Ctx) error {
return errors.New("something happened")
})

app.Mount("/api", fiber)

resp, err := app.Test(httptest.NewRequest(MethodGet, "/api", nil))
utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, 200, resp.StatusCode, "Status code")

b, err := ioutil.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err, "iotuil.ReadAll()")
utils.AssertEqual(t, "hi, i'm a custom error", string(b), "Response body")
}

func Test_App_UseMountedErrorHandlerForBestPrefixMatch(t *testing.T) {
app := New()

tsf := func(ctx *Ctx, err error) error {
return ctx.Status(200).SendString("hi, i'm a custom sub sub fiber error")
}
tripleSubFiber := New(Config{
ErrorHandler: tsf,
})
tripleSubFiber.Get("/", func(c *Ctx) error {
return errors.New("something happened")
})

sf := func(ctx *Ctx, err error) error {
return ctx.Status(200).SendString("hi, i'm a custom sub fiber error")
}
subfiber := New(Config{
ErrorHandler: sf,
})
subfiber.Get("/", func(c *Ctx) error {
return errors.New("something happened")
})
subfiber.Mount("/third", tripleSubFiber)

f := func(ctx *Ctx, err error) error {
return ctx.Status(200).SendString("hi, i'm a custom error")
}
fiber := New(Config{
ErrorHandler: f,
})
fiber.Get("/", func(c *Ctx) error {
return errors.New("something happened")
})
fiber.Mount("/sub", subfiber)

app.Mount("/api", fiber)

resp, err := app.Test(httptest.NewRequest(MethodGet, "/api/sub", nil))
utils.AssertEqual(t, nil, err, "/api/sub req")
utils.AssertEqual(t, 200, resp.StatusCode, "Status code")

b, err := ioutil.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err, "iotuil.ReadAll()")
utils.AssertEqual(t, "hi, i'm a custom sub fiber error", string(b), "Response body")

resp2, err := app.Test(httptest.NewRequest(MethodGet, "/api/sub/third", nil))
utils.AssertEqual(t, nil, err, "/api/sub/third req")
utils.AssertEqual(t, 200, resp.StatusCode, "Status code")

b, err = ioutil.ReadAll(resp2.Body)
utils.AssertEqual(t, nil, err, "iotuil.ReadAll()")
utils.AssertEqual(t, "hi, i'm a custom sub sub fiber error", string(b), "Third fiber Response body")
}
2 changes: 1 addition & 1 deletion middleware/logger/logger.go
Expand Up @@ -145,7 +145,7 @@ func New(config ...Config) fiber.Handler {
}
}
// override error handler
errHandler = c.App().Config().ErrorHandler
errHandler = c.App().ErrorHandler
})

var start, stop time.Time
Expand Down
2 changes: 1 addition & 1 deletion router.go
Expand Up @@ -154,7 +154,7 @@ func (app *App) handler(rctx *fasthttp.RequestCtx) {
// Find match in stack
match, err := app.next(c)
if err != nil {
if catch := c.app.config.ErrorHandler(c, err); catch != nil {
if catch := c.app.ErrorHandler(c, err); catch != nil {
_ = c.SendStatus(StatusInternalServerError)
}
}
Expand Down