Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fd leaks, tests on windows #103

Merged
merged 28 commits into from Nov 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 4 additions & 0 deletions git/git_test.go
Expand Up @@ -6,6 +6,7 @@ import (
"os"
"os/exec"
"path/filepath"
"runtime"
"sync"
"testing"

Expand Down Expand Up @@ -66,6 +67,9 @@ func TestGitMiddleware(t *testing.T) {
})

t.Run("create repo in subdir", func(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("permission issues")
}
cwd := t.TempDir()
requireNoError(t, runGitHelper(t, pkPath, cwd, "init", "-b", "main"))
requireNoError(t, runGitHelper(t, pkPath, cwd, "remote", "add", "origin", remote+"/abc/repo1"))
Expand Down
28 changes: 15 additions & 13 deletions scp/copy_from_client.go
Expand Up @@ -9,6 +9,7 @@ import (
"path/filepath"
"regexp"
"strconv"
"strings"

"github.com/charmbracelet/ssh"
)
Expand Down Expand Up @@ -39,42 +40,43 @@ func copyFromClient(s ssh.Session, info Info, handler CopyFromClientHandler) err
)

for {
line, _, err := r.ReadLine()
line, err := r.ReadString('\n')
if err != nil {
if errors.Is(err, io.EOF) {
break
}
return fmt.Errorf("failed to read line: %w", err)
}
line = strings.TrimSuffix(line, "\n")

if matches := reTimestamp.FindAllStringSubmatch(string(line), 2); matches != nil {
if matches := reTimestamp.FindAllStringSubmatch(line, 2); matches != nil {
mtime, err = strconv.ParseInt(matches[0][1], 10, 64)
if err != nil {
return parseError{string(line)}
return parseError{line}
}
atime, err = strconv.ParseInt(matches[0][2], 10, 64)
if err != nil {
return parseError{string(line)}
return parseError{line}
}

// accepts the header
_, _ = s.Write(NULL)
continue
}

if matches := reNewFile.FindAllStringSubmatch(string(line), 3); matches != nil {
if matches := reNewFile.FindAllStringSubmatch(line, 3); matches != nil {
if len(matches) != 1 || len(matches[0]) != 4 {
return parseError{string(line)}
return parseError{line}
}

mode, err := strconv.ParseUint(matches[0][1], 8, 32)
if err != nil {
return parseError{string(line)}
return parseError{line}
}

size, err := strconv.ParseInt(matches[0][2], 10, 64)
if err != nil {
return parseError{string(line)}
return parseError{line}
}
name := matches[0][3]

Expand Down Expand Up @@ -107,14 +109,14 @@ func copyFromClient(s ssh.Session, info Info, handler CopyFromClientHandler) err
continue
}

if matches := reNewFolder.FindAllStringSubmatch(string(line), 2); matches != nil {
if matches := reNewFolder.FindAllStringSubmatch(line, 2); matches != nil {
if len(matches) != 1 || len(matches[0]) != 3 {
return parseError{string(line)}
return parseError{line}
}

mode, err := strconv.ParseUint(matches[0][1], 8, 32)
if err != nil {
return parseError{string(line)}
return parseError{line}
}
name := matches[0][2]

Expand All @@ -136,15 +138,15 @@ func copyFromClient(s ssh.Session, info Info, handler CopyFromClientHandler) err
continue
}

if string(line) == "E" {
if line == "E" {
path = filepath.Dir(path)

// says 'hey im done'
_, _ = s.Write(NULL)
continue
}

return fmt.Errorf("unhandled input: %q", string(line))
return fmt.Errorf("unhandled input: %q", line)
}

_, _ = s.Write(NULL)
Expand Down
4 changes: 3 additions & 1 deletion scp/copy_to_client.go
Expand Up @@ -18,7 +18,9 @@ func copyToClient(s ssh.Session, info Info, handler CopyToClientHandler) error {

rootEntry := &RootEntry{}
var closers []func() error
defer closeAll(closers)
defer func() {
closeAll(closers)
}()

for _, match := range matches {
if !info.Recursive {
Expand Down
5 changes: 5 additions & 0 deletions scp/filesystem.go
Expand Up @@ -18,6 +18,7 @@ var _ Handler = &fileSystemHandler{}

// NewFileSystemHandler return a Handler based on the given dir.
func NewFileSystemHandler(root string) Handler {
// FIXME: if you scp -r host:/, it'll copy the root folder too, and it shouldn't.
return &fileSystemHandler{
root: filepath.Clean(root),
}
Expand Down Expand Up @@ -113,9 +114,13 @@ func (h *fileSystemHandler) Write(_ ssh.Session, entry *FileEntry) (int64, error
if err != nil {
return 0, fmt.Errorf("failed to open file: %q: %w", entry.Filepath, err)
}
defer f.Close() //nolint:errcheck
written, err := io.Copy(f, entry.Reader)
if err != nil {
return 0, fmt.Errorf("failed to write file: %q: %w", entry.Filepath, err)
}
if err := f.Close(); err != nil {
return 0, fmt.Errorf("failed to close file: %q: %w", entry.Filepath, err)
}
return written, h.chtimes(entry.Filepath, entry.Mtime, entry.Atime)
}
24 changes: 11 additions & 13 deletions scp/filesystem_test.go
Expand Up @@ -71,8 +71,8 @@ func TestFilesystem(t *testing.T) {

session := setup(t, h, nil)
bts, err := session.CombinedOutput("scp -r -f a")
requireEqualGolden(t, bts)
is.NoErr(err)
requireEqualGolden(t, bts)
})

t.Run("recursive glob", func(t *testing.T) {
Expand All @@ -88,8 +88,8 @@ func TestFilesystem(t *testing.T) {

session := setup(t, h, nil)
bts, err := session.CombinedOutput("scp -r -f a/*")
requireEqualGolden(t, bts)
is.NoErr(err)
requireEqualGolden(t, bts)
})

t.Run("recursive invalid file", func(t *testing.T) {
Expand Down Expand Up @@ -209,7 +209,7 @@ func TestFilesystem(t *testing.T) {
err := h.Mkdir(nil, &DirEntry{
Name: "foo",
Filepath: "foo/bar/baz",
Mode: 0755,
Mode: 0o755,
})
is.True(err != nil) // should err
})
Expand All @@ -222,7 +222,7 @@ func TestFilesystem(t *testing.T) {
_, err := h.Write(nil, &FileEntry{
Name: "foo.txt",
Filepath: "baz/foo.txt",
Mode: 0644,
Mode: 0o644,
Size: 10,
})
is.True(err != nil) // should err
Expand All @@ -234,23 +234,21 @@ func TestFilesystem(t *testing.T) {
_, err := h.Write(nil, &FileEntry{
Name: "foo.txt",
Filepath: "foo.txt",
Mode: 0644,
Mode: 0o644,
Size: 10,
Reader: iotest.ErrReader(fmt.Errorf("fake err")),
})
is.True(err != nil) // should err
})
})

})

}

func chtimesTree(tb testing.TB, dir string, atime, mtime time.Time) {
is := is.New(tb)

filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error {
is.NoErr(os.Chtimes(path, atime, mtime))
return nil
})
is.New(tb).NoErr(filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error {
if err != nil {
return err
}
return os.Chtimes(path, atime, mtime)
}))
}
1 change: 1 addition & 0 deletions scp/fs.go
Expand Up @@ -26,6 +26,7 @@ func (h *fsHandler) WalkDir(_ ssh.Session, path string, fn fs.WalkDirFunc) error
}

func (h *fsHandler) NewDirEntry(_ ssh.Session, path string) (*DirEntry, error) {
path = normalizePath(path)
info, err := fs.Stat(h.fsys, path)
if err != nil {
return nil, fmt.Errorf("failed to open dir: %q: %w", path, err)
Expand Down
4 changes: 2 additions & 2 deletions scp/fs_test.go
Expand Up @@ -67,8 +67,8 @@ func TestFS(t *testing.T) {

session := setup(t, h, nil)
bts, err := session.CombinedOutput("scp -r -f a")
requireEqualGolden(t, bts)
is.NoErr(err)
requireEqualGolden(t, bts)
})

t.Run("recursive glob", func(t *testing.T) {
Expand All @@ -84,8 +84,8 @@ func TestFS(t *testing.T) {

session := setup(t, h, nil)
bts, err := session.CombinedOutput("scp -r -f a/*")
requireEqualGolden(t, bts)
is.NoErr(err)
requireEqualGolden(t, bts)
})

t.Run("recursive invalid file", func(t *testing.T) {
Expand Down
19 changes: 14 additions & 5 deletions scp/scp.go
Expand Up @@ -6,6 +6,7 @@ import (
"io"
"io/fs"
"path/filepath"
"runtime"
"strconv"
"strings"

Expand Down Expand Up @@ -150,7 +151,7 @@ type RootEntry []Entry
// Appennd the given entry to a child directory, or the the itself if
// none matches.
func (e *RootEntry) Append(entry Entry) {
parent := filepath.Dir(entry.path())
parent := normalizePath(filepath.Dir(entry.path()))

for _, child := range *e {
switch dir := child.(type) {
Expand All @@ -159,7 +160,7 @@ func (e *RootEntry) Append(entry Entry) {
dir.Children = append(dir.Children, entry)
return
}
if strings.HasPrefix(parent, dir.Filepath) {
if strings.HasPrefix(parent, normalizePath(dir.Filepath)) {
dir.Append(entry)
return
}
Expand Down Expand Up @@ -220,7 +221,7 @@ func (e *DirEntry) Write(w io.Writer) error {

// Appends an entry to the folder or their children.
func (e *DirEntry) Append(entry Entry) {
parent := filepath.Dir(entry.path())
parent := normalizePath(filepath.Dir(entry.path()))

for _, child := range e.Children {
switch dir := child.(type) {
Expand All @@ -229,7 +230,7 @@ func (e *DirEntry) Append(entry Entry) {
dir.Children = append(dir.Children, entry)
return
}
if strings.HasPrefix(parent, dir.path()) {
if strings.HasPrefix(parent, normalizePath(dir.path())) {
dir.Append(entry)
return
}
Expand Down Expand Up @@ -257,7 +258,7 @@ type Info struct {
// Ok is true if the current session is a SCP.
Ok bool

// Recursice is true if its a recursive SCP.
// Recursive is true if its a recursive SCP.
Recursive bool

// Path is the server path of the scp operation.
Expand Down Expand Up @@ -294,3 +295,11 @@ func GetInfo(cmd []string) Info {
func octalPerms(info fs.FileMode) string {
return "0" + strconv.FormatUint(uint64(info.Perm()), 8)
}

func normalizePath(p string) string {
p = filepath.Clean(p)
if runtime.GOOS == "windows" {
return strings.ReplaceAll(p, "\\", "/")
}
return p
}
21 changes: 17 additions & 4 deletions scp/scp_test.go
Expand Up @@ -4,10 +4,12 @@ import (
"bytes"
"os"
"path/filepath"
"runtime"
"testing"

"github.com/charmbracelet/ssh"
"github.com/charmbracelet/wish/testsession"
"github.com/google/go-cmp/cmp"
"github.com/matryer/is"
gossh "golang.org/x/crypto/ssh"
)
Expand Down Expand Up @@ -115,7 +117,6 @@ func setup(tb testing.TB, rh CopyToClientHandler, wh CopyFromClientHandler) *gos
return testsession.New(tb, &ssh.Server{
Handler: Middleware(rh, wh)(func(s ssh.Session) {
s.Exit(0)
s.Close()
}),
}, nil)
}
Expand All @@ -124,7 +125,17 @@ func requireEqualGolden(tb testing.TB, out []byte) {
tb.Helper()
is := is.New(tb)

out = bytes.ReplaceAll(out, NULL, []byte("<NULL>"))
fixOutput := func(bts []byte) []byte {
bts = bytes.ReplaceAll(bts, []byte("\r"), []byte(""))
if runtime.GOOS == "windows" {
// perms always come different on Windows because, well, its Windows.
bts = bytes.ReplaceAll(bts, []byte("0666"), []byte("0644"))
bts = bytes.ReplaceAll(bts, []byte("0777"), []byte("0755"))
}
return bytes.ReplaceAll(bts, NULL, []byte("<NULL>"))
}

out = fixOutput(out)
golden := "testdata/" + tb.Name() + ".test"
if os.Getenv("UPDATE") != "" {
is.NoErr(os.MkdirAll(filepath.Dir(golden), 0o755))
Expand All @@ -133,7 +144,9 @@ func requireEqualGolden(tb testing.TB, out []byte) {

gbts, err := os.ReadFile(golden)
is.NoErr(err)
gbts = fixOutput(gbts)

gbts = bytes.ReplaceAll(gbts, NULL, []byte("<NULL>"))
is.Equal(string(gbts), string(out))
if diff := cmp.Diff(string(gbts), string(out)); diff != "" {
tb.Fatal("files do not match:", diff)
}
}
2 changes: 2 additions & 0 deletions testsession/testsession.go
Expand Up @@ -45,6 +45,8 @@ func newLocalListener(tb testing.TB) net.Listener {
tb.Fatalf("failed to listen on a port: %v", err)
}
}

tb.Cleanup(func() { _ = l.Close() })
return l
}

Expand Down