From d4ad69dbbabaa1024642e08376a8d66a7a344050 Mon Sep 17 00:00:00 2001 From: ant1k9 Date: Mon, 17 Jan 2022 17:54:22 +0300 Subject: [PATCH 1/2] Add LogLevelSetter to recover middleware --- middleware/recover.go | 14 +++++++++++++- middleware/recover_test.go | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 1 deletion(-) diff --git a/middleware/recover.go b/middleware/recover.go index 0dbe740da..003acc35c 100644 --- a/middleware/recover.go +++ b/middleware/recover.go @@ -9,6 +9,9 @@ import ( ) type ( + // LogLevelSetter defines a function to get log level for the recovered value. + LogLevelSetter func(value interface{}) log.Lvl + // RecoverConfig defines the config for Recover middleware. RecoverConfig struct { // Skipper defines a function to skip middleware. @@ -30,6 +33,10 @@ type ( // LogLevel is log level to printing stack trace. // Optional. Default value 0 (Print). LogLevel log.Lvl + + // LogLevelSetter defines a function to get log level for the recovered value. + // LogLevelSetter has higher priority than LogLevel when it's set. + LogLevelSetter LogLevelSetter } ) @@ -41,6 +48,7 @@ var ( DisableStackAll: false, DisablePrintStack: false, LogLevel: 0, + LogLevelSetter: nil, } ) @@ -73,11 +81,15 @@ func RecoverWithConfig(config RecoverConfig) echo.MiddlewareFunc { if !ok { err = fmt.Errorf("%v", r) } + logLevel := config.LogLevel + if config.LogLevelSetter != nil { + logLevel = config.LogLevelSetter(r) + } stack := make([]byte, config.StackSize) length := runtime.Stack(stack, !config.DisableStackAll) if !config.DisablePrintStack { msg := fmt.Sprintf("[PANIC RECOVER] %v %s\n", err, stack[:length]) - switch config.LogLevel { + switch logLevel { case log.DEBUG: c.Logger().Debug(msg) case log.INFO: diff --git a/middleware/recover_test.go b/middleware/recover_test.go index 644332972..ce7a5428c 100644 --- a/middleware/recover_test.go +++ b/middleware/recover_test.go @@ -81,3 +81,36 @@ func TestRecoverWithConfig_LogLevel(t *testing.T) { }) } } + +func TestRecoverWithConfig_LogLevelSetter(t *testing.T) { + e := echo.New() + e.Logger.SetLevel(log.DEBUG) + + buf := new(bytes.Buffer) + e.Logger.SetOutput(buf) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + config := DefaultRecoverConfig + config.LogLevelSetter = func(value interface{}) log.Lvl { + if s, ok := value.(string); ok { + if s == "test" { + return log.DEBUG + } + } + return log.ERROR + } + h := RecoverWithConfig(config)(echo.HandlerFunc(func(c echo.Context) error { + panic("test") + })) + + h(c) + + assert.Equal(t, http.StatusInternalServerError, rec.Code) + + output := buf.String() + assert.Contains(t, output, "PANIC RECOVER") + assert.Contains(t, output, `"level":"DEBUG"`) +} From 4730debea4d18d52efaee9c74ff82fffeb10952c Mon Sep 17 00:00:00 2001 From: ant1k9 Date: Mon, 17 Jan 2022 21:14:03 +0300 Subject: [PATCH 2/2] Change LogLevelSetter to LogErrorFunc LogErrorFunc provides more general interface to handle errors in the recover middleware. --- middleware/recover.go | 31 +++++++++++++--------- middleware/recover_test.go | 54 ++++++++++++++++++++++++++------------ 2 files changed, 55 insertions(+), 30 deletions(-) diff --git a/middleware/recover.go b/middleware/recover.go index 003acc35c..a621a9efe 100644 --- a/middleware/recover.go +++ b/middleware/recover.go @@ -9,8 +9,8 @@ import ( ) type ( - // LogLevelSetter defines a function to get log level for the recovered value. - LogLevelSetter func(value interface{}) log.Lvl + // LogErrorFunc defines a function for custom logging in the middleware. + LogErrorFunc func(c echo.Context, err error, stack []byte) error // RecoverConfig defines the config for Recover middleware. RecoverConfig struct { @@ -34,9 +34,9 @@ type ( // Optional. Default value 0 (Print). LogLevel log.Lvl - // LogLevelSetter defines a function to get log level for the recovered value. - // LogLevelSetter has higher priority than LogLevel when it's set. - LogLevelSetter LogLevelSetter + // LogErrorFunc defines a function for custom logging in the middleware. + // If it's set you don't need to provide LogLevel for config. + LogErrorFunc LogErrorFunc } ) @@ -48,7 +48,7 @@ var ( DisableStackAll: false, DisablePrintStack: false, LogLevel: 0, - LogLevelSetter: nil, + LogErrorFunc: nil, } ) @@ -81,15 +81,20 @@ func RecoverWithConfig(config RecoverConfig) echo.MiddlewareFunc { if !ok { err = fmt.Errorf("%v", r) } - logLevel := config.LogLevel - if config.LogLevelSetter != nil { - logLevel = config.LogLevelSetter(r) - } - stack := make([]byte, config.StackSize) - length := runtime.Stack(stack, !config.DisableStackAll) + var stack []byte + var length int + if !config.DisablePrintStack { + stack = make([]byte, config.StackSize) + length = runtime.Stack(stack, !config.DisableStackAll) + stack = stack[:length] + } + + if config.LogErrorFunc != nil { + err = config.LogErrorFunc(c, err, stack) + } else if !config.DisablePrintStack { msg := fmt.Sprintf("[PANIC RECOVER] %v %s\n", err, stack[:length]) - switch logLevel { + switch config.LogLevel { case log.DEBUG: c.Logger().Debug(msg) case log.INFO: diff --git a/middleware/recover_test.go b/middleware/recover_test.go index ce7a5428c..9ac4feedc 100644 --- a/middleware/recover_test.go +++ b/middleware/recover_test.go @@ -2,6 +2,7 @@ package middleware import ( "bytes" + "errors" "fmt" "net/http" "net/http/httptest" @@ -82,7 +83,7 @@ func TestRecoverWithConfig_LogLevel(t *testing.T) { } } -func TestRecoverWithConfig_LogLevelSetter(t *testing.T) { +func TestRecoverWithConfig_LogErrorFunc(t *testing.T) { e := echo.New() e.Logger.SetLevel(log.DEBUG) @@ -93,24 +94,43 @@ func TestRecoverWithConfig_LogLevelSetter(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) + testError := errors.New("test") config := DefaultRecoverConfig - config.LogLevelSetter = func(value interface{}) log.Lvl { - if s, ok := value.(string); ok { - if s == "test" { - return log.DEBUG - } + config.LogErrorFunc = func(c echo.Context, err error, stack []byte) error { + msg := fmt.Sprintf("[PANIC RECOVER] %v %s\n", err, stack) + if errors.Is(err, testError) { + c.Logger().Debug(msg) + } else { + c.Logger().Error(msg) } - return log.ERROR + return err } - h := RecoverWithConfig(config)(echo.HandlerFunc(func(c echo.Context) error { - panic("test") - })) - - h(c) - - assert.Equal(t, http.StatusInternalServerError, rec.Code) - output := buf.String() - assert.Contains(t, output, "PANIC RECOVER") - assert.Contains(t, output, `"level":"DEBUG"`) + t.Run("first branch case for LogErrorFunc", func(t *testing.T) { + buf.Reset() + h := RecoverWithConfig(config)(echo.HandlerFunc(func(c echo.Context) error { + panic(testError) + })) + + h(c) + assert.Equal(t, http.StatusInternalServerError, rec.Code) + + output := buf.String() + assert.Contains(t, output, "PANIC RECOVER") + assert.Contains(t, output, `"level":"DEBUG"`) + }) + + t.Run("else branch case for LogErrorFunc", func(t *testing.T) { + buf.Reset() + h := RecoverWithConfig(config)(echo.HandlerFunc(func(c echo.Context) error { + panic("other") + })) + + h(c) + assert.Equal(t, http.StatusInternalServerError, rec.Code) + + output := buf.String() + assert.Contains(t, output, "PANIC RECOVER") + assert.Contains(t, output, `"level":"ERROR"`) + }) }