From 1b9edfe5ee3cb7a44d17e64742653457607b8ee3 Mon Sep 17 00:00:00 2001 From: Ayman Bagabas Date: Fri, 28 Jan 2022 16:15:48 -0500 Subject: [PATCH] feat: add recover middleware recover middleware wraps other middlewares and recover from panics --- recover/recover.go | 35 ++++++++++++++++++++++++++++++++++ recover/recover_test.go | 42 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 77 insertions(+) create mode 100644 recover/recover.go create mode 100644 recover/recover_test.go diff --git a/recover/recover.go b/recover/recover.go new file mode 100644 index 0000000..b2014fa --- /dev/null +++ b/recover/recover.go @@ -0,0 +1,35 @@ +package recover + +import ( + "log" + "runtime/debug" + + "github.com/charmbracelet/wish" + "github.com/gliderlabs/ssh" +) + +// Middleware is a wish middleware that recover from panics. The default logger +// is used to log panics. +func Middleware(mw ...wish.Middleware) wish.Middleware { + return MiddlewareWithLogger(log.Default(), mw...) +} + +// MiddlewareWithLogger is a wish middleware that recover from panics and log to +// the provided logger. +func MiddlewareWithLogger(logger *log.Logger, mw ...wish.Middleware) wish.Middleware { + return func(sh ssh.Handler) ssh.Handler { + return func(s ssh.Session) { + func() { + defer func() { + if r := recover(); r != nil { + logger.Printf("panic: %v\n%s", r, string(debug.Stack())) + } + }() + for _, m := range mw { + m(sh)(s) + } + sh(s) + }() + } + } +} diff --git a/recover/recover_test.go b/recover/recover_test.go new file mode 100644 index 0000000..5e251be --- /dev/null +++ b/recover/recover_test.go @@ -0,0 +1,42 @@ +package recover + +import ( + "fmt" + "strings" + "testing" + + "github.com/charmbracelet/wish/testsession" + "github.com/gliderlabs/ssh" + gossh "golang.org/x/crypto/ssh" +) + +func TestMiddleware(t *testing.T) { + t.Run("recover session", func(t *testing.T) { + _, err := setup(t).Output("") + defer func() { + if r := recover(); r != nil { + if strings.HasPrefix(fmt.Sprint(r), "panic: hello\n") { + t.Errorf("session should be recovered") + } + } + }() + requireNoError(t, err) + }) +} + +func setup(tb testing.TB) *gossh.Session { + tb.Helper() + return testsession.New(tb, &ssh.Server{ + Handler: Middleware()(func(s ssh.Session) { + panic("hello") + }), + }, nil) +} + +func requireNoError(t *testing.T, err error) { + t.Helper() + + if err != nil { + t.Fatalf("expected no error, got %q", err.Error()) + } +}