diff --git a/README.md b/README.md index bf6b9b9..0fbd39a 100644 --- a/README.md +++ b/README.md @@ -338,13 +338,17 @@ terminal applications on Unix support ANSI styling out-of-the-box, on Windows you need to enable ANSI processing in your application first: ```go - mode, err := termenv.EnableWindowsANSIConsole() + restoreConsole, err := termenv.EnableVirtualTerminalProcessing(termenv.DefaultOutput()) if err != nil { panic(err) } - defer termenv.RestoreWindowsConsole(mode) + defer restoreConsole() ``` +The above code is safe to include on non-Windows systems or when os.Stdout does +not refer to a terminal (e.g. in tests). + + ## Color Chart ![ANSI color chart](https://github.com/muesli/termenv/raw/master/examples/color-chart/color-chart.png) diff --git a/examples/color-chart/main.go b/examples/color-chart/main.go index 860dd2b..1ea232c 100644 --- a/examples/color-chart/main.go +++ b/examples/color-chart/main.go @@ -7,6 +7,12 @@ import ( ) func main() { + restoreConsole, err := termenv.EnableVirtualTerminalProcessing(termenv.DefaultOutput()) + if err != nil { + panic(err) + } + defer restoreConsole() + // Basic ANSI colors 0 - 15 fmt.Println(termenv.String("Basic ANSI colors").Bold()) diff --git a/examples/hello-world/main.go b/examples/hello-world/main.go index 48633c8..eaab4ad 100644 --- a/examples/hello-world/main.go +++ b/examples/hello-world/main.go @@ -7,6 +7,12 @@ import ( ) func main() { + restoreConsole, err := termenv.EnableVirtualTerminalProcessing(termenv.DefaultOutput()) + if err != nil { + panic(err) + } + defer restoreConsole() + p := termenv.ColorProfile() fmt.Printf("\n\t%s %s %s %s %s", diff --git a/termenv_test.go b/termenv_test.go index 8b59d98..87f22ad 100644 --- a/termenv_test.go +++ b/termenv_test.go @@ -392,3 +392,16 @@ func TestCache(t *testing.T) { t.Errorf("Expected cache to be active, got %t", o.cache) } } + +func TestEnableVirtualTerminalProcessing(t *testing.T) { + // EnableVirtualTerminalProcessing should always return a non-nil + // restoreFunc, and in tests it should never return an error. + restoreFunc, err := EnableVirtualTerminalProcessing(NewOutput(os.Stdout)) + if restoreFunc == nil || err != nil { + t.Fatalf("expected non-, , got %p, %v", restoreFunc, err) + } + // In tests, restoreFunc should never return an error. + if err := restoreFunc(); err != nil { + t.Fatalf("expected , got %v", err) + } +} diff --git a/termenv_unix.go b/termenv_unix.go index 333a08b..778c354 100644 --- a/termenv_unix.go +++ b/termenv_unix.go @@ -5,6 +5,7 @@ package termenv import ( "fmt" + "io" "strconv" "strings" "time" @@ -273,3 +274,11 @@ func (o Output) termStatusReport(sequence int) (string, error) { // fmt.Println("Rcvd", res[1:]) return res, nil } + +// EnableVirtualTerminalProcessing enables virtual terminal processing on +// Windows for w and returns a function that restores w to its previous state. +// On non-Windows platforms, or if w does not refer to a terminal, then it +// returns a non-nil no-op function and no error. +func EnableVirtualTerminalProcessing(w io.Writer) (func() error, error) { + return func() error { return nil }, nil +} diff --git a/termenv_windows.go b/termenv_windows.go index 84e5c2e..1d9c618 100644 --- a/termenv_windows.go +++ b/termenv_windows.go @@ -4,6 +4,7 @@ package termenv import ( + "fmt" "strconv" "golang.org/x/sys/windows" @@ -90,3 +91,49 @@ func RestoreWindowsConsole(mode uint32) error { return windows.SetConsoleMode(handle, mode) } + +// EnableVirtualTerminalProcessing enables virtual terminal processing on +// Windows for o and returns a function that restores o to its previous state. +// On non-Windows platforms, or if o does not refer to a terminal, then it +// returns a non-nil no-op function and no error. +func EnableVirtualTerminalProcessing(o *Output) (restoreFunc func() error, err error) { + // There is nothing to restore until we set the console mode. + restoreFunc = func() error { + return nil + } + + // If o is not a tty, then there is nothing to do. + tty := o.TTY() + if tty == nil { + return + } + + // Get the current console mode. If there is an error, assume that o is not + // a terminal, discard the error, and return. + var mode uint32 + if err2 := windows.GetConsoleMode(windows.Handle(tty.Fd()), &mode); err2 != nil { + return + } + + // If virtual terminal processing is already set, then there is nothing to + // do and nothing to restore. + if mode&windows.ENABLE_VIRTUAL_TERMINAL_PROCESSING == windows.ENABLE_VIRTUAL_TERMINAL_PROCESSING { + return + } + + // Enable virtual terminal processing. See + // https://docs.microsoft.com/en-us/windows/console/console-virtual-terminal-sequences + if err2 := windows.SetConsoleMode(windows.Handle(tty.Fd()), mode|windows.ENABLE_VIRTUAL_TERMINAL_PROCESSING); err2 != nil { + err = fmt.Errorf("windows.SetConsoleMode: %w", err2) + return + } + + // Set the restore function. We maintain a reference to the tty in the + // closure (rather than just its handle) to ensure that the tty is not + // closed by a finalizer. + restoreFunc = func() error { + return windows.SetConsoleMode(windows.Handle(tty.Fd()), mode) + } + + return +}