Skip to content

Commit

Permalink
fix: fd leaks, tests on windows (#103)
Browse files Browse the repository at this point in the history
* fix: close everything

Signed-off-by: Carlos A Becker <caarlos0@users.noreply.github.com>

* fix: missing close

Signed-off-by: Carlos A Becker <caarlos0@users.noreply.github.com>

* fix: close

* fix: closer

Signed-off-by: Carlos A Becker <caarlos0@users.noreply.github.com>

* revert

* revert

* fix: errors

Signed-off-by: Carlos A Becker <caarlos0@users.noreply.github.com>

* fix: output

Signed-off-by: Carlos A Becker <caarlos0@users.noreply.github.com>

* chore: diff

Signed-off-by: Carlos A Becker <caarlos0@users.noreply.github.com>

* chore: escape all escapes

Signed-off-by: Carlos A Becker <caarlos0@users.noreply.github.com>

* fix: out

Signed-off-by: Carlos A Becker <caarlos0@users.noreply.github.com>

* loosing hope

Signed-off-by: Carlos A Becker <caarlos0@users.noreply.github.com>

* fix: 0o666 because windows

Signed-off-by: Carlos A Becker <caarlos0@users.noreply.github.com>

* fix: perms, hopefully

Signed-off-by: Carlos A Becker <caarlos0@users.noreply.github.com>

* fix: windos

Signed-off-by: Carlos A Becker <caarlos0@users.noreply.github.com>

* fix: E, hopefully

Signed-off-by: Carlos A Becker <caarlos0@users.noreply.github.com>

* fix: close server

Signed-off-by: Carlos A Becker <caarlos0@users.noreply.github.com>

* fix: missing closes

Signed-off-by: Carlos A Becker <caarlos0@users.noreply.github.com>

* fix: fixes

* chore: trying

Signed-off-by: Carlos A Becker <caarlos0@users.noreply.github.com>

* fix: windows paths

Signed-off-by: Carlos A Becker <caarlos0@users.noreply.github.com>

* fix: windows paths

Signed-off-by: Carlos A Becker <caarlos0@users.noreply.github.com>

* fix: recurse windows

* fix: close

* fix: exit already closes

* fix: fd leak

* chore: fmts

* chore: cleanup

Signed-off-by: Carlos A Becker <caarlos0@users.noreply.github.com>
  • Loading branch information
caarlos0 committed Nov 23, 2022
1 parent d65a162 commit 9e9765d
Show file tree
Hide file tree
Showing 10 changed files with 74 additions and 38 deletions.
4 changes: 4 additions & 0 deletions git/git_test.go
Expand Up @@ -6,6 +6,7 @@ import (
"os"
"os/exec"
"path/filepath"
"runtime"
"sync"
"testing"

Expand Down Expand Up @@ -66,6 +67,9 @@ func TestGitMiddleware(t *testing.T) {
})

t.Run("create repo in subdir", func(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("permission issues")
}
cwd := t.TempDir()
requireNoError(t, runGitHelper(t, pkPath, cwd, "init", "-b", "main"))
requireNoError(t, runGitHelper(t, pkPath, cwd, "remote", "add", "origin", remote+"/abc/repo1"))
Expand Down
28 changes: 15 additions & 13 deletions scp/copy_from_client.go
Expand Up @@ -9,6 +9,7 @@ import (
"path/filepath"
"regexp"
"strconv"
"strings"

"github.com/charmbracelet/ssh"
)
Expand Down Expand Up @@ -39,42 +40,43 @@ func copyFromClient(s ssh.Session, info Info, handler CopyFromClientHandler) err
)

for {
line, _, err := r.ReadLine()
line, err := r.ReadString('\n')
if err != nil {
if errors.Is(err, io.EOF) {
break
}
return fmt.Errorf("failed to read line: %w", err)
}
line = strings.TrimSuffix(line, "\n")

if matches := reTimestamp.FindAllStringSubmatch(string(line), 2); matches != nil {
if matches := reTimestamp.FindAllStringSubmatch(line, 2); matches != nil {
mtime, err = strconv.ParseInt(matches[0][1], 10, 64)
if err != nil {
return parseError{string(line)}
return parseError{line}
}
atime, err = strconv.ParseInt(matches[0][2], 10, 64)
if err != nil {
return parseError{string(line)}
return parseError{line}
}

// accepts the header
_, _ = s.Write(NULL)
continue
}

if matches := reNewFile.FindAllStringSubmatch(string(line), 3); matches != nil {
if matches := reNewFile.FindAllStringSubmatch(line, 3); matches != nil {
if len(matches) != 1 || len(matches[0]) != 4 {
return parseError{string(line)}
return parseError{line}
}

mode, err := strconv.ParseUint(matches[0][1], 8, 32)
if err != nil {
return parseError{string(line)}
return parseError{line}
}

size, err := strconv.ParseInt(matches[0][2], 10, 64)
if err != nil {
return parseError{string(line)}
return parseError{line}
}
name := matches[0][3]

Expand Down Expand Up @@ -107,14 +109,14 @@ func copyFromClient(s ssh.Session, info Info, handler CopyFromClientHandler) err
continue
}

if matches := reNewFolder.FindAllStringSubmatch(string(line), 2); matches != nil {
if matches := reNewFolder.FindAllStringSubmatch(line, 2); matches != nil {
if len(matches) != 1 || len(matches[0]) != 3 {
return parseError{string(line)}
return parseError{line}
}

mode, err := strconv.ParseUint(matches[0][1], 8, 32)
if err != nil {
return parseError{string(line)}
return parseError{line}
}
name := matches[0][2]

Expand All @@ -136,15 +138,15 @@ func copyFromClient(s ssh.Session, info Info, handler CopyFromClientHandler) err
continue
}

if string(line) == "E" {
if line == "E" {
path = filepath.Dir(path)

// says 'hey im done'
_, _ = s.Write(NULL)
continue
}

return fmt.Errorf("unhandled input: %q", string(line))
return fmt.Errorf("unhandled input: %q", line)
}

_, _ = s.Write(NULL)
Expand Down
4 changes: 3 additions & 1 deletion scp/copy_to_client.go
Expand Up @@ -18,7 +18,9 @@ func copyToClient(s ssh.Session, info Info, handler CopyToClientHandler) error {

rootEntry := &RootEntry{}
var closers []func() error
defer closeAll(closers)
defer func() {
closeAll(closers)
}()

for _, match := range matches {
if !info.Recursive {
Expand Down
5 changes: 5 additions & 0 deletions scp/filesystem.go
Expand Up @@ -18,6 +18,7 @@ var _ Handler = &fileSystemHandler{}

// NewFileSystemHandler return a Handler based on the given dir.
func NewFileSystemHandler(root string) Handler {
// FIXME: if you scp -r host:/, it'll copy the root folder too, and it shouldn't.
return &fileSystemHandler{
root: filepath.Clean(root),
}
Expand Down Expand Up @@ -113,9 +114,13 @@ func (h *fileSystemHandler) Write(_ ssh.Session, entry *FileEntry) (int64, error
if err != nil {
return 0, fmt.Errorf("failed to open file: %q: %w", entry.Filepath, err)
}
defer f.Close() //nolint:errcheck
written, err := io.Copy(f, entry.Reader)
if err != nil {
return 0, fmt.Errorf("failed to write file: %q: %w", entry.Filepath, err)
}
if err := f.Close(); err != nil {
return 0, fmt.Errorf("failed to close file: %q: %w", entry.Filepath, err)
}
return written, h.chtimes(entry.Filepath, entry.Mtime, entry.Atime)
}
24 changes: 11 additions & 13 deletions scp/filesystem_test.go
Expand Up @@ -71,8 +71,8 @@ func TestFilesystem(t *testing.T) {

session := setup(t, h, nil)
bts, err := session.CombinedOutput("scp -r -f a")
requireEqualGolden(t, bts)
is.NoErr(err)
requireEqualGolden(t, bts)
})

t.Run("recursive glob", func(t *testing.T) {
Expand All @@ -88,8 +88,8 @@ func TestFilesystem(t *testing.T) {

session := setup(t, h, nil)
bts, err := session.CombinedOutput("scp -r -f a/*")
requireEqualGolden(t, bts)
is.NoErr(err)
requireEqualGolden(t, bts)
})

t.Run("recursive invalid file", func(t *testing.T) {
Expand Down Expand Up @@ -209,7 +209,7 @@ func TestFilesystem(t *testing.T) {
err := h.Mkdir(nil, &DirEntry{
Name: "foo",
Filepath: "foo/bar/baz",
Mode: 0755,
Mode: 0o755,
})
is.True(err != nil) // should err
})
Expand All @@ -222,7 +222,7 @@ func TestFilesystem(t *testing.T) {
_, err := h.Write(nil, &FileEntry{
Name: "foo.txt",
Filepath: "baz/foo.txt",
Mode: 0644,
Mode: 0o644,
Size: 10,
})
is.True(err != nil) // should err
Expand All @@ -234,23 +234,21 @@ func TestFilesystem(t *testing.T) {
_, err := h.Write(nil, &FileEntry{
Name: "foo.txt",
Filepath: "foo.txt",
Mode: 0644,
Mode: 0o644,
Size: 10,
Reader: iotest.ErrReader(fmt.Errorf("fake err")),
})
is.True(err != nil) // should err
})
})

})

}

func chtimesTree(tb testing.TB, dir string, atime, mtime time.Time) {
is := is.New(tb)

filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error {
is.NoErr(os.Chtimes(path, atime, mtime))
return nil
})
is.New(tb).NoErr(filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error {
if err != nil {
return err
}
return os.Chtimes(path, atime, mtime)
}))
}
1 change: 1 addition & 0 deletions scp/fs.go
Expand Up @@ -26,6 +26,7 @@ func (h *fsHandler) WalkDir(_ ssh.Session, path string, fn fs.WalkDirFunc) error
}

func (h *fsHandler) NewDirEntry(_ ssh.Session, path string) (*DirEntry, error) {
path = normalizePath(path)
info, err := fs.Stat(h.fsys, path)
if err != nil {
return nil, fmt.Errorf("failed to open dir: %q: %w", path, err)
Expand Down
4 changes: 2 additions & 2 deletions scp/fs_test.go
Expand Up @@ -67,8 +67,8 @@ func TestFS(t *testing.T) {

session := setup(t, h, nil)
bts, err := session.CombinedOutput("scp -r -f a")
requireEqualGolden(t, bts)
is.NoErr(err)
requireEqualGolden(t, bts)
})

t.Run("recursive glob", func(t *testing.T) {
Expand All @@ -84,8 +84,8 @@ func TestFS(t *testing.T) {

session := setup(t, h, nil)
bts, err := session.CombinedOutput("scp -r -f a/*")
requireEqualGolden(t, bts)
is.NoErr(err)
requireEqualGolden(t, bts)
})

t.Run("recursive invalid file", func(t *testing.T) {
Expand Down
19 changes: 14 additions & 5 deletions scp/scp.go
Expand Up @@ -6,6 +6,7 @@ import (
"io"
"io/fs"
"path/filepath"
"runtime"
"strconv"
"strings"

Expand Down Expand Up @@ -150,7 +151,7 @@ type RootEntry []Entry
// Appennd the given entry to a child directory, or the the itself if
// none matches.
func (e *RootEntry) Append(entry Entry) {
parent := filepath.Dir(entry.path())
parent := normalizePath(filepath.Dir(entry.path()))

for _, child := range *e {
switch dir := child.(type) {
Expand All @@ -159,7 +160,7 @@ func (e *RootEntry) Append(entry Entry) {
dir.Children = append(dir.Children, entry)
return
}
if strings.HasPrefix(parent, dir.Filepath) {
if strings.HasPrefix(parent, normalizePath(dir.Filepath)) {
dir.Append(entry)
return
}
Expand Down Expand Up @@ -220,7 +221,7 @@ func (e *DirEntry) Write(w io.Writer) error {

// Appends an entry to the folder or their children.
func (e *DirEntry) Append(entry Entry) {
parent := filepath.Dir(entry.path())
parent := normalizePath(filepath.Dir(entry.path()))

for _, child := range e.Children {
switch dir := child.(type) {
Expand All @@ -229,7 +230,7 @@ func (e *DirEntry) Append(entry Entry) {
dir.Children = append(dir.Children, entry)
return
}
if strings.HasPrefix(parent, dir.path()) {
if strings.HasPrefix(parent, normalizePath(dir.path())) {
dir.Append(entry)
return
}
Expand Down Expand Up @@ -257,7 +258,7 @@ type Info struct {
// Ok is true if the current session is a SCP.
Ok bool

// Recursice is true if its a recursive SCP.
// Recursive is true if its a recursive SCP.
Recursive bool

// Path is the server path of the scp operation.
Expand Down Expand Up @@ -294,3 +295,11 @@ func GetInfo(cmd []string) Info {
func octalPerms(info fs.FileMode) string {
return "0" + strconv.FormatUint(uint64(info.Perm()), 8)
}

func normalizePath(p string) string {
p = filepath.Clean(p)
if runtime.GOOS == "windows" {
return strings.ReplaceAll(p, "\\", "/")
}
return p
}
21 changes: 17 additions & 4 deletions scp/scp_test.go
Expand Up @@ -4,10 +4,12 @@ import (
"bytes"
"os"
"path/filepath"
"runtime"
"testing"

"github.com/charmbracelet/ssh"
"github.com/charmbracelet/wish/testsession"
"github.com/google/go-cmp/cmp"
"github.com/matryer/is"
gossh "golang.org/x/crypto/ssh"
)
Expand Down Expand Up @@ -115,7 +117,6 @@ func setup(tb testing.TB, rh CopyToClientHandler, wh CopyFromClientHandler) *gos
return testsession.New(tb, &ssh.Server{
Handler: Middleware(rh, wh)(func(s ssh.Session) {
s.Exit(0)
s.Close()
}),
}, nil)
}
Expand All @@ -124,7 +125,17 @@ func requireEqualGolden(tb testing.TB, out []byte) {
tb.Helper()
is := is.New(tb)

out = bytes.ReplaceAll(out, NULL, []byte("<NULL>"))
fixOutput := func(bts []byte) []byte {
bts = bytes.ReplaceAll(bts, []byte("\r"), []byte(""))
if runtime.GOOS == "windows" {
// perms always come different on Windows because, well, its Windows.
bts = bytes.ReplaceAll(bts, []byte("0666"), []byte("0644"))
bts = bytes.ReplaceAll(bts, []byte("0777"), []byte("0755"))
}
return bytes.ReplaceAll(bts, NULL, []byte("<NULL>"))
}

out = fixOutput(out)
golden := "testdata/" + tb.Name() + ".test"
if os.Getenv("UPDATE") != "" {
is.NoErr(os.MkdirAll(filepath.Dir(golden), 0o755))
Expand All @@ -133,7 +144,9 @@ func requireEqualGolden(tb testing.TB, out []byte) {

gbts, err := os.ReadFile(golden)
is.NoErr(err)
gbts = fixOutput(gbts)

gbts = bytes.ReplaceAll(gbts, NULL, []byte("<NULL>"))
is.Equal(string(gbts), string(out))
if diff := cmp.Diff(string(gbts), string(out)); diff != "" {
tb.Fatal("files do not match:", diff)
}
}
2 changes: 2 additions & 0 deletions testsession/testsession.go
Expand Up @@ -45,6 +45,8 @@ func newLocalListener(tb testing.TB) net.Listener {
tb.Fatalf("failed to listen on a port: %v", err)
}
}

tb.Cleanup(func() { _ = l.Close() })
return l
}

Expand Down

0 comments on commit 9e9765d

Please sign in to comment.