Skip to content

Commit

Permalink
fix(output): report terminal status when writer is not a file
Browse files Browse the repository at this point in the history
The underlying writer doesn't have to be a *os.File for it to be a TTY.
For example, a PTY ssh session is a TTY. However, the std library
returns a io.ReadWriter for the ssh session.

Combined with the WithUnsafe() option, we can query the terminal of an
ssh session using Termenv.
  • Loading branch information
aymanbagabas committed Apr 14, 2023
1 parent 0f89b9f commit 8368526
Showing 1 changed file with 21 additions and 16 deletions.
37 changes: 21 additions & 16 deletions termenv_unix.go
Expand Up @@ -6,6 +6,7 @@ package termenv
import (
"fmt"
"io"
"os"
"strconv"
"strings"
"time"
Expand Down Expand Up @@ -113,8 +114,8 @@ func (o Output) backgroundColor() Color {
return ANSIColor(0)
}

func (o *Output) waitForData(timeout time.Duration) error {
fd := o.TTY().Fd()
func (o *Output) waitForData(f *os.File, timeout time.Duration) error {
fd := f.Fd()
tv := unix.NsecToTimeval(int64(timeout))
var readfds unix.FdSet
readfds.Set(int(fd))
Expand All @@ -137,15 +138,15 @@ func (o *Output) waitForData(timeout time.Duration) error {
return nil
}

func (o *Output) readNextByte() (byte, error) {
if !o.unsafe {
if err := o.waitForData(OSCTimeout); err != nil {
func (o *Output) readNextByte(rw io.ReadWriter) (byte, error) {
if f, ok := rw.(*os.File); ok && !o.unsafe {
if err := o.waitForData(f, OSCTimeout); err != nil {
return 0, err
}
}

var b [1]byte
n, err := o.TTY().Read(b[:])
n, err := rw.Read(b[:])
if err != nil {
return 0, err
}
Expand All @@ -160,15 +161,15 @@ func (o *Output) readNextByte() (byte, error) {
// readNextResponse reads either an OSC response or a cursor position response:
// - OSC response: "\x1b]11;rgb:1111/1111/1111\x1b\\"
// - cursor position response: "\x1b[42;1R"
func (o *Output) readNextResponse() (response string, isOSC bool, err error) {
start, err := o.readNextByte()
func (o *Output) readNextResponse(rw io.ReadWriter) (response string, isOSC bool, err error) {
start, err := o.readNextByte(rw)
if err != nil {
return "", false, err
}

// first byte must be ESC
for start != ESC {
start, err = o.readNextByte()
start, err = o.readNextByte(rw)
if err != nil {
return "", false, err
}
Expand All @@ -177,7 +178,7 @@ func (o *Output) readNextResponse() (response string, isOSC bool, err error) {
response += string(start)

// next byte is either '[' (cursor position response) or ']' (OSC response)
tpe, err := o.readNextByte()
tpe, err := o.readNextByte(rw)
if err != nil {
return "", false, err
}
Expand All @@ -195,7 +196,7 @@ func (o *Output) readNextResponse() (response string, isOSC bool, err error) {
}

for {
b, err := o.readNextByte()
b, err := o.readNextByte(rw)
if err != nil {
return "", false, err
}
Expand Down Expand Up @@ -231,13 +232,17 @@ func (o Output) termStatusReport(sequence int) (string, error) {
return "", ErrStatusReport
}

tty := o.TTY()
if tty == nil {
tty, ok := o.Writer().(io.ReadWriter)
if tty == nil || !ok {
return "", ErrStatusReport
}

if !o.unsafe {
fd := int(tty.Fd())
f, ok := tty.(*os.File)
if !ok {
return "", ErrStatusReport
}
fd := int(f.Fd())
// if in background, we can't control the terminal
if !isForeground(fd) {
return "", ErrStatusReport
Expand All @@ -264,7 +269,7 @@ func (o Output) termStatusReport(sequence int) (string, error) {
fmt.Fprintf(tty, CSI+"6n")

// read the next response
res, isOSC, err := o.readNextResponse()
res, isOSC, err := o.readNextResponse(tty)
if err != nil {
return "", fmt.Errorf("%s: %s", ErrStatusReport, err)
}
Expand All @@ -275,7 +280,7 @@ func (o Output) termStatusReport(sequence int) (string, error) {
}

// read the cursor query response next and discard the result
_, _, err = o.readNextResponse()
_, _, err = o.readNextResponse(tty)
if err != nil {
return "", err
}
Expand Down

0 comments on commit 8368526

Please sign in to comment.