diff --git a/client_integration_test.go b/client_integration_test.go index f26adb44..b5083845 100644 --- a/client_integration_test.go +++ b/client_integration_test.go @@ -19,6 +19,7 @@ import ( "path/filepath" "reflect" "regexp" + "runtime" "sort" "strconv" "sync" @@ -359,7 +360,7 @@ func TestClientOpenIsNotExist(t *testing.T) { defer cmd.Wait() defer sftp.Close() - if _, err := sftp.Open("/doesnt/exist/"); !os.IsNotExist(err) { + if _, err := sftp.Open("/doesnt/exist"); !os.IsNotExist(err) { t.Errorf("os.IsNotExist(%v) = false, want true", err) } } @@ -369,7 +370,7 @@ func TestClientStatIsNotExist(t *testing.T) { defer cmd.Wait() defer sftp.Close() - if _, err := sftp.Stat("/doesnt/exist/"); !os.IsNotExist(err) { + if _, err := sftp.Stat("/doesnt/exist"); !os.IsNotExist(err) { t.Errorf("os.IsNotExist(%v) = false, want true", err) } } @@ -758,6 +759,11 @@ func TestClientGetwd(t *testing.T) { } func TestClientReadLink(t *testing.T) { + if runtime.GOOS == "windows" && *testServerImpl { + // os.Symlink requires privilege escalation. + t.Skip() + } + sftp, cmd := testClient(t, READWRITE, NODELAY) defer cmd.Wait() defer sftp.Close() @@ -810,6 +816,11 @@ func TestClientLink(t *testing.T) { } func TestClientSymlink(t *testing.T) { + if runtime.GOOS == "windows" && *testServerImpl { + // os.Symlink requires privilege escalation. + t.Skip() + } + sftp, cmd := testClient(t, READWRITE, NODELAY) defer cmd.Wait() defer sftp.Close() @@ -1600,6 +1611,7 @@ func clientWriteDeadlock(t *testing.T, N int, badfunc func(*File)) { if !*testServerImpl { t.Skipf("skipping without -testserver") } + sftp, cmd := testClient(t, READWRITE, NODELAY) defer cmd.Wait() defer sftp.Close() @@ -2241,18 +2253,23 @@ func TestServerRoughDisconnectEOF(t *testing.T) { // sftp/issue/26 writing to a read only file caused client to loop. func TestClientWriteToROFile(t *testing.T) { skipIfWindows(t) + sftp, cmd := testClient(t, READWRITE, NODELAY) defer cmd.Wait() + defer func() { err := sftp.Close() assert.NoError(t, err) }() + // TODO (puellanivis): /dev/zero is not actually a read-only file. + // So, this test works purely by accident. f, err := sftp.Open("/dev/zero") if err != nil { t.Fatal(err) } defer f.Close() + _, err = f.Write([]byte("hello")) if err == nil { t.Fatal("expected error, got", err) diff --git a/packet.go b/packet.go index 50ca069d..4059cf8e 100644 --- a/packet.go +++ b/packet.go @@ -1242,7 +1242,7 @@ func (p *sshFxpExtendedPacketPosixRename) UnmarshalBinary(b []byte) error { } func (p *sshFxpExtendedPacketPosixRename) respond(s *Server) responsePacket { - err := os.Rename(p.Oldpath, p.Newpath) + err := os.Rename(toLocalPath(p.Oldpath), toLocalPath(p.Newpath)) return statusFromError(p.ID, err) } @@ -1271,6 +1271,6 @@ func (p *sshFxpExtendedPacketHardlink) UnmarshalBinary(b []byte) error { } func (p *sshFxpExtendedPacketHardlink) respond(s *Server) responsePacket { - err := os.Link(p.Oldpath, p.Newpath) + err := os.Link(toLocalPath(p.Oldpath), toLocalPath(p.Newpath)) return statusFromError(p.ID, err) } diff --git a/server_integration_test.go b/server_integration_test.go index f72d64b8..407d38a2 100644 --- a/server_integration_test.go +++ b/server_integration_test.go @@ -19,7 +19,6 @@ import ( "net" "os" "os/exec" - "path" "path/filepath" "regexp" "runtime" @@ -363,7 +362,9 @@ func (chsvr *sshSessionChannelServer) handleSubsystem(req *ssh.Request) error { } // starts an ssh server to test. returns: host string and port -func testServer(t *testing.T, useSubsystem bool, readonly bool) (net.Listener, string, int) { +func testServer(t *testing.T, useSubsystem bool, readonly bool) (func(), string, int) { + t.Helper() + if !*testIntegration { t.Skip("skipping integration test") } @@ -382,28 +383,35 @@ func testServer(t *testing.T, useSubsystem bool, readonly bool) (net.Listener, s t.Fatal(err) } + shutdown := make(chan struct{}) + go func() { for { conn, err := listener.Accept() if err != nil { - fmt.Fprintf(sshServerDebugStream, "ssh server socket closed: %v\n", err) - break + select { + case <-shutdown: + default: + t.Error("ssh server socket closed:", err) + } + return } go func() { defer conn.Close() + sshSvr, err := sshServerFromConn(conn, useSubsystem, basicServerConfig()) if err != nil { t.Error(err) return } - err = sshSvr.Wait() - fmt.Fprintf(sshServerDebugStream, "ssh server finished, err: %v\n", err) + + _ = sshSvr.Wait() }() } }() - return listener, host, port + return func() { close(shutdown); listener.Close() }, host, port } func makeDummyKey() (string, error) { @@ -468,35 +476,40 @@ func runSftpClient(t *testing.T, script string, path string, host string, port i } defer os.Remove(dummyKey) - args := []string{ + cmd := exec.Command( + *testSftpClientBin, // "-vvvv", "-b", "-", "-o", "StrictHostKeyChecking=no", "-o", "LogLevel=ERROR", "-o", "UserKnownHostsFile /dev/null", // do not trigger ssh-agent prompting - "-o", "IdentityFile=" + dummyKey, + "-o", "IdentityFile="+dummyKey, "-o", "IdentitiesOnly=yes", "-P", fmt.Sprintf("%d", port), fmt.Sprintf("%s:%s", host, path), - } - cmd := exec.Command(*testSftpClientBin, args...) - var stdout bytes.Buffer - var stderr bytes.Buffer - cmd.Stdin = bytes.NewBufferString(script) - cmd.Stdout = &stdout - cmd.Stderr = &stderr + ) + + cmd.Stdin = strings.NewReader(script) + + stdout := new(bytes.Buffer) + cmd.Stdout = stdout + + stderr := new(bytes.Buffer) + cmd.Stderr = stderr + if err := cmd.Start(); err != nil { return "", err } - err = cmd.Wait() - if err != nil { - err = &execError{ + + if err := cmd.Wait(); err != nil { + return stdout.String(), &execError{ path: cmd.Path, stderr: stderr.String(), err: err, } } - return stdout.String(), err + + return stdout.String(), nil } // assert.Eventually seems to have a data rate on macOS with go 1.14 so replace it with this simpler function @@ -532,10 +545,16 @@ func checkAllocatorAfterServerClose(t *testing.T, alloc *allocator) { } func TestServerCompareSubsystems(t *testing.T) { - listenerGo, hostGo, portGo := testServer(t, GolangSFTP, READONLY) - listenerOp, hostOp, portOp := testServer(t, OpenSSHSFTP, READONLY) - defer listenerGo.Close() - defer listenerOp.Close() + if runtime.GOOS == "windows" { + // TODO (puellanivis): not sure how to fix this, the OpenSSH SFTP implementation closes immediately. + t.Skip() + } + + shutdownGo, hostGo, portGo := testServer(t, GolangSFTP, READONLY) + defer shutdownGo() + + shutdownOp, hostOp, portOp := testServer(t, OpenSSHSFTP, READONLY) + defer shutdownOp() script := ` ls / @@ -561,10 +580,11 @@ ls -l /usr/bin/ outputGoLines := newlineRegex.Split(outputGo, -1) outputOpLines := newlineRegex.Split(outputOp, -1) + if len(outputGoLines) != len(outputOpLines) { + t.Fatalf("output line count differs, go = %d, openssh = %d", len(outputGoLines), len(outputOpLines)) + } + for i, goLine := range outputGoLines { - if i > len(outputOpLines) { - t.Fatalf("output line count differs") - } opLine := outputOpLines[i] bad := false if goLine != opLine { @@ -576,8 +596,9 @@ ls -l /usr/bin/ // during testing as processes are created/destroyed. // words[7] as timestamp on dirs can very for things like /tmp for j, goWord := range goWords { - if j > len(opWords) { + if j >= len(opWords) { bad = true + break } opWord := opWords[j] if goWord != opWord && j != 1 && j != 2 && j != 3 && j != 7 { @@ -587,7 +608,7 @@ ls -l /usr/bin/ } if bad { - t.Errorf("outputs differ, go:\n%v\nopenssh:\n%v\n", goLine, opLine) + t.Errorf("outputs differ\n go: %q\nopenssh: %q\n", goLine, opLine) } } } @@ -607,8 +628,8 @@ func randName() string { } func TestServerMkdirRmdir(t *testing.T) { - listenerGo, hostGo, portGo := testServer(t, GolangSFTP, READONLY) - defer listenerGo.Close() + shutdown, hostGo, portGo := testServer(t, GolangSFTP, READONLY) + defer shutdown() tmpDir := "/tmp/" + randName() defer os.RemoveAll(tmpDir) @@ -635,8 +656,8 @@ func TestServerMkdirRmdir(t *testing.T) { func TestServerLink(t *testing.T) { skipIfWindows(t) // No hard links on windows. - listenerGo, hostGo, portGo := testServer(t, GolangSFTP, READONLY) - defer listenerGo.Close() + shutdown, hostGo, portGo := testServer(t, GolangSFTP, READONLY) + defer shutdown() tmpFileLocalData := randData(999) @@ -664,8 +685,8 @@ func TestServerLink(t *testing.T) { func TestServerSymlink(t *testing.T) { skipIfWindows(t) // No symlinks on windows. - listenerGo, hostGo, portGo := testServer(t, GolangSFTP, READONLY) - defer listenerGo.Close() + shutdown, hostGo, portGo := testServer(t, GolangSFTP, READONLY) + defer shutdown() link := "/tmp/" + randName() defer os.RemoveAll(link) @@ -684,8 +705,8 @@ func TestServerSymlink(t *testing.T) { } func TestServerPut(t *testing.T) { - listenerGo, hostGo, portGo := testServer(t, GolangSFTP, READONLY) - defer listenerGo.Close() + shutdown, hostGo, portGo := testServer(t, GolangSFTP, READONLY) + defer shutdown() tmpFileLocal := "/tmp/" + randName() tmpFileRemote := "/tmp/" + randName() @@ -714,8 +735,8 @@ func TestServerPut(t *testing.T) { } func TestServerResume(t *testing.T) { - listenerGo, hostGo, portGo := testServer(t, GolangSFTP, READONLY) - defer listenerGo.Close() + shutdown, hostGo, portGo := testServer(t, GolangSFTP, READONLY) + defer shutdown() tmpFileLocal := "/tmp/" + randName() tmpFileRemote := "/tmp/" + randName() @@ -761,8 +782,8 @@ func TestServerResume(t *testing.T) { } func TestServerGet(t *testing.T) { - listenerGo, hostGo, portGo := testServer(t, GolangSFTP, READONLY) - defer listenerGo.Close() + shutdown, hostGo, portGo := testServer(t, GolangSFTP, READONLY) + defer shutdown() tmpFileLocal := "/tmp/" + randName() tmpFileRemote := "/tmp/" + randName() @@ -802,7 +823,7 @@ func compareDirectoriesRecursive(t *testing.T, aroot, broot string) { if err != nil { t.Fatalf("could not find relative path for %v: %v", aPath, err) } - bPath := path.Join(broot, aRel) + bPath := filepath.Join(broot, aRel) if aRel == "." { continue @@ -857,8 +878,8 @@ func compareDirectoriesRecursive(t *testing.T, aroot, broot string) { } func TestServerPutRecursive(t *testing.T) { - listenerGo, hostGo, portGo := testServer(t, GolangSFTP, READONLY) - defer listenerGo.Close() + shutdown, hostGo, portGo := testServer(t, GolangSFTP, READONLY) + defer shutdown() dirLocal, err := os.Getwd() if err != nil { @@ -869,17 +890,23 @@ func TestServerPutRecursive(t *testing.T) { t.Logf("put recursive: local %v remote %v", dirLocal, tmpDirRemote) + // On windows, the client copies the contents of the directory, not the directory itself. + winFix := "" + if runtime.GOOS == "windows" { + winFix = "/" + filepath.Base(dirLocal) + } //*/ + // push this directory (source code etc) recursively to the server - if output, err := runSftpClient(t, "mkdir "+tmpDirRemote+"\r\nput -r -P "+dirLocal+"/ "+tmpDirRemote+"/", "/", hostGo, portGo); err != nil { + if output, err := runSftpClient(t, "mkdir "+tmpDirRemote+"\r\nput -R -p "+dirLocal+" "+tmpDirRemote+winFix, "/", hostGo, portGo); err != nil { t.Fatalf("runSftpClient failed: %v, output\n%v\n", err, output) } - compareDirectoriesRecursive(t, dirLocal, path.Join(tmpDirRemote, path.Base(dirLocal))) + compareDirectoriesRecursive(t, dirLocal, filepath.Join(tmpDirRemote, filepath.Base(dirLocal))) } func TestServerGetRecursive(t *testing.T) { - listenerGo, hostGo, portGo := testServer(t, GolangSFTP, READONLY) - defer listenerGo.Close() + shutdown, hostGo, portGo := testServer(t, GolangSFTP, READONLY) + defer shutdown() dirRemote, err := os.Getwd() if err != nil { @@ -890,10 +917,16 @@ func TestServerGetRecursive(t *testing.T) { t.Logf("get recursive: local %v remote %v", tmpDirLocal, dirRemote) + // On windows, the client copies the contents of the directory, not the directory itself. + winFix := "" + if runtime.GOOS == "windows" { + winFix = "/" + filepath.Base(dirRemote) + } + // pull this directory (source code etc) recursively from the server - if output, err := runSftpClient(t, "lmkdir "+tmpDirLocal+"\r\nget -r -P "+dirRemote+"/ "+tmpDirLocal+"/", "/", hostGo, portGo); err != nil { + if output, err := runSftpClient(t, "lmkdir "+tmpDirLocal+"\r\nget -R -p "+dirRemote+" "+tmpDirLocal+winFix, "/", hostGo, portGo); err != nil { t.Fatalf("runSftpClient failed: %v, output\n%v\n", err, output) } - compareDirectoriesRecursive(t, dirRemote, path.Join(tmpDirLocal, path.Base(dirRemote))) + compareDirectoriesRecursive(t, dirRemote, filepath.Join(tmpDirLocal, filepath.Base(dirRemote))) }