diff --git a/recover/recover.go b/recover/recover.go new file mode 100644 index 0000000..6dff73d --- /dev/null +++ b/recover/recover.go @@ -0,0 +1,39 @@ +package recover + +import ( + "log" + "runtime/debug" + + "github.com/charmbracelet/wish" + "github.com/gliderlabs/ssh" +) + +// Middleware is a wish middleware that recovers from panics and log to stderr. +func Middleware(mw ...wish.Middleware) wish.Middleware { + return MiddlewareWithLogger(nil, mw...) +} + +// MiddlewareWithLogger is a wish middleware that recovers from panics and log to +// the provided logger. +func MiddlewareWithLogger(logger *log.Logger, mw ...wish.Middleware) wish.Middleware { + if logger == nil { + logger = log.Default() + } + h := func(ssh.Session) {} + for _, m := range mw { + h = m(h) + } + return func(sh ssh.Handler) ssh.Handler { + return func(s ssh.Session) { + func() { + defer func() { + if r := recover(); r != nil { + logger.Printf("panic: %v\n%s", r, string(debug.Stack())) + } + }() + h(s) + }() + sh(s) + } + } +} diff --git a/recover/recover_test.go b/recover/recover_test.go new file mode 100644 index 0000000..bfd941c --- /dev/null +++ b/recover/recover_test.go @@ -0,0 +1,35 @@ +package recover + +import ( + "testing" + + "github.com/charmbracelet/wish/testsession" + "github.com/gliderlabs/ssh" + gossh "golang.org/x/crypto/ssh" +) + +func TestMiddleware(t *testing.T) { + t.Run("recover session", func(t *testing.T) { + _, err := setup(t).Output("") + requireNoError(t, err) + }) +} + +func setup(tb testing.TB) *gossh.Session { + tb.Helper() + return testsession.New(tb, &ssh.Server{ + Handler: Middleware(func(h ssh.Handler) ssh.Handler { + return func(s ssh.Session) { + panic("hello") + } + })(func(s ssh.Session) {}), + }, nil) +} + +func requireNoError(t *testing.T, err error) { + t.Helper() + + if err != nil { + t.Fatalf("expected no error, got %q", err.Error()) + } +}