diff --git a/client.go b/client.go index 1587eb83..d8c979a0 100644 --- a/client.go +++ b/client.go @@ -1176,11 +1176,11 @@ func (f *File) writeToSequential(w io.Writer) (written int64, err error) { if n > 0 { f.offset += int64(n) - m, err2 := w.Write(b[:n]) + m, wErr := w.Write(b[:n]) written += int64(m) - if err == nil { - err = err2 + if wErr != nil { + return written, wErr } } diff --git a/client_integration_test.go b/client_integration_test.go index b5083845..1a855d42 100644 --- a/client_integration_test.go +++ b/client_integration_test.go @@ -1249,6 +1249,57 @@ func TestClientReadSequential(t *testing.T) { } } +type lastChunkErrSequentialWriter struct { + expected int + written int + writtenReturn int +} + +func (w *lastChunkErrSequentialWriter) Write(b []byte) (int, error) { + chunkSize := len(b) + w.written += chunkSize + if w.written == w.expected { + return w.writtenReturn, errors.New("test error") + } + return chunkSize, nil +} + +func TestClientWriteSequential_WriterErr(t *testing.T) { + sftp, cmd := testClient(t, READONLY, NODELAY) + defer cmd.Wait() + defer sftp.Close() + + d, err := ioutil.TempDir("", "sftptest-writesequential-writeerr") + require.NoError(t, err) + + defer os.RemoveAll(d) + + var ( + content = []byte("hello world") + shortWrite = 2 + ) + w := lastChunkErrSequentialWriter{ + expected: len(content), + writtenReturn: shortWrite, + } + + f, err := ioutil.TempFile(d, "write-sequential-writeerr-test") + require.NoError(t, err) + fname := f.Name() + n, err := f.Write(content) + require.NoError(t, err) + require.Equal(t, n, len(content)) + require.NoError(t, f.Close()) + + sftpFile, err := sftp.Open(fname) + require.NoError(t, err) + defer sftpFile.Close() + + gotWritten, gotErr := sftpFile.writeToSequential(&w) + require.NotErrorIs(t, io.EOF, gotErr) + require.Equal(t, int64(shortWrite), gotWritten) +} + func TestClientReadDir(t *testing.T) { sftp1, cmd1 := testClient(t, READONLY, NODELAY) sftp2, cmd2 := testClientGoSvr(t, READONLY, NODELAY)