diff --git a/client.go b/client.go index d8c979a0..9e0b6164 100644 --- a/client.go +++ b/client.go @@ -1176,11 +1176,11 @@ func (f *File) writeToSequential(w io.Writer) (written int64, err error) { if n > 0 { f.offset += int64(n) - m, wErr := w.Write(b[:n]) + m, err := w.Write(b[:n]) written += int64(m) - if wErr != nil { - return written, wErr + if err != nil { + return written, err } } diff --git a/client_integration_test.go b/client_integration_test.go index 1a855d42..39902f9a 100644 --- a/client_integration_test.go +++ b/client_integration_test.go @@ -1249,55 +1249,51 @@ func TestClientReadSequential(t *testing.T) { } } +// this writer requires maxPacket = 3 and always returns an error for the second write call type lastChunkErrSequentialWriter struct { - expected int - written int - writtenReturn int + counter int } func (w *lastChunkErrSequentialWriter) Write(b []byte) (int, error) { - chunkSize := len(b) - w.written += chunkSize - if w.written == w.expected { - return w.writtenReturn, errors.New("test error") + w.counter++ + if w.counter == 1 { + if len(b) != 3 { + return 0, errors.New("this writer require maxPacket = 3, please set MaxPacketChecked(3)") + } + return len(b), nil } - return chunkSize, nil + return 1, errors.New("this writer fails after the first write") } -func TestClientWriteSequential_WriterErr(t *testing.T) { - sftp, cmd := testClient(t, READONLY, NODELAY) +func TestClientWriteSequentialWriterErr(t *testing.T) { + client, cmd := testClient(t, READONLY, NODELAY, MaxPacketChecked(3)) defer cmd.Wait() - defer sftp.Close() + defer client.Close() d, err := ioutil.TempDir("", "sftptest-writesequential-writeerr") require.NoError(t, err) defer os.RemoveAll(d) - var ( - content = []byte("hello world") - shortWrite = 2 - ) - w := lastChunkErrSequentialWriter{ - expected: len(content), - writtenReturn: shortWrite, - } - f, err := ioutil.TempFile(d, "write-sequential-writeerr-test") require.NoError(t, err) fname := f.Name() - n, err := f.Write(content) + _, err = f.Write([]byte("12345")) require.NoError(t, err) - require.Equal(t, n, len(content)) require.NoError(t, f.Close()) - sftpFile, err := sftp.Open(fname) + sftpFile, err := client.Open(fname) require.NoError(t, err) defer sftpFile.Close() - gotWritten, gotErr := sftpFile.writeToSequential(&w) - require.NotErrorIs(t, io.EOF, gotErr) - require.Equal(t, int64(shortWrite), gotWritten) + w := &lastChunkErrSequentialWriter{} + written, err := sftpFile.writeToSequential(w) + assert.Error(t, err) + expected := int64(4) + if written != expected { + t.Errorf("sftpFile.Write() = %d, but expected %d", written, expected) + } + assert.Equal(t, 2, w.counter) } func TestClientReadDir(t *testing.T) {