diff --git a/middleware/recover.go b/middleware/recover.go index 0dbe740da..a621a9efe 100644 --- a/middleware/recover.go +++ b/middleware/recover.go @@ -9,6 +9,9 @@ import ( ) type ( + // LogErrorFunc defines a function for custom logging in the middleware. + LogErrorFunc func(c echo.Context, err error, stack []byte) error + // RecoverConfig defines the config for Recover middleware. RecoverConfig struct { // Skipper defines a function to skip middleware. @@ -30,6 +33,10 @@ type ( // LogLevel is log level to printing stack trace. // Optional. Default value 0 (Print). LogLevel log.Lvl + + // LogErrorFunc defines a function for custom logging in the middleware. + // If it's set you don't need to provide LogLevel for config. + LogErrorFunc LogErrorFunc } ) @@ -41,6 +48,7 @@ var ( DisableStackAll: false, DisablePrintStack: false, LogLevel: 0, + LogErrorFunc: nil, } ) @@ -73,9 +81,18 @@ func RecoverWithConfig(config RecoverConfig) echo.MiddlewareFunc { if !ok { err = fmt.Errorf("%v", r) } - stack := make([]byte, config.StackSize) - length := runtime.Stack(stack, !config.DisableStackAll) + var stack []byte + var length int + if !config.DisablePrintStack { + stack = make([]byte, config.StackSize) + length = runtime.Stack(stack, !config.DisableStackAll) + stack = stack[:length] + } + + if config.LogErrorFunc != nil { + err = config.LogErrorFunc(c, err, stack) + } else if !config.DisablePrintStack { msg := fmt.Sprintf("[PANIC RECOVER] %v %s\n", err, stack[:length]) switch config.LogLevel { case log.DEBUG: diff --git a/middleware/recover_test.go b/middleware/recover_test.go index 644332972..9ac4feedc 100644 --- a/middleware/recover_test.go +++ b/middleware/recover_test.go @@ -2,6 +2,7 @@ package middleware import ( "bytes" + "errors" "fmt" "net/http" "net/http/httptest" @@ -81,3 +82,55 @@ func TestRecoverWithConfig_LogLevel(t *testing.T) { }) } } + +func TestRecoverWithConfig_LogErrorFunc(t *testing.T) { + e := echo.New() + e.Logger.SetLevel(log.DEBUG) + + buf := new(bytes.Buffer) + e.Logger.SetOutput(buf) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + testError := errors.New("test") + config := DefaultRecoverConfig + config.LogErrorFunc = func(c echo.Context, err error, stack []byte) error { + msg := fmt.Sprintf("[PANIC RECOVER] %v %s\n", err, stack) + if errors.Is(err, testError) { + c.Logger().Debug(msg) + } else { + c.Logger().Error(msg) + } + return err + } + + t.Run("first branch case for LogErrorFunc", func(t *testing.T) { + buf.Reset() + h := RecoverWithConfig(config)(echo.HandlerFunc(func(c echo.Context) error { + panic(testError) + })) + + h(c) + assert.Equal(t, http.StatusInternalServerError, rec.Code) + + output := buf.String() + assert.Contains(t, output, "PANIC RECOVER") + assert.Contains(t, output, `"level":"DEBUG"`) + }) + + t.Run("else branch case for LogErrorFunc", func(t *testing.T) { + buf.Reset() + h := RecoverWithConfig(config)(echo.HandlerFunc(func(c echo.Context) error { + panic("other") + })) + + h(c) + assert.Equal(t, http.StatusInternalServerError, rec.Code) + + output := buf.String() + assert.Contains(t, output, "PANIC RECOVER") + assert.Contains(t, output, `"level":"ERROR"`) + }) +}