From ae00b32792f07a9378adbc433bbdfcd751c78b96 Mon Sep 17 00:00:00 2001 From: Nicola Murino Date: Thu, 3 Mar 2022 10:30:30 +0100 Subject: [PATCH] writeToSequential: improve tests for write errors --- client.go | 6 ++--- client_integration_test.go | 48 ++++++++++++++++++-------------------- 2 files changed, 26 insertions(+), 28 deletions(-) diff --git a/client.go b/client.go index d8c979a0..9e0b6164 100644 --- a/client.go +++ b/client.go @@ -1176,11 +1176,11 @@ func (f *File) writeToSequential(w io.Writer) (written int64, err error) { if n > 0 { f.offset += int64(n) - m, wErr := w.Write(b[:n]) + m, err := w.Write(b[:n]) written += int64(m) - if wErr != nil { - return written, wErr + if err != nil { + return written, err } } diff --git a/client_integration_test.go b/client_integration_test.go index 1a855d42..b4bedb8c 100644 --- a/client_integration_test.go +++ b/client_integration_test.go @@ -1249,55 +1249,53 @@ func TestClientReadSequential(t *testing.T) { } } +// this writer requires maxPacket = 2 and returns a short write error for the second write call type lastChunkErrSequentialWriter struct { - expected int - written int - writtenReturn int + writeCounter int } func (w *lastChunkErrSequentialWriter) Write(b []byte) (int, error) { - chunkSize := len(b) - w.written += chunkSize - if w.written == w.expected { - return w.writtenReturn, errors.New("test error") + if len(b) != 2 { + return 0, errors.New("this writer require maxPacket = 2, pleae set MaxPacketChecked(2)") + } + w.writeCounter++ + switch w.writeCounter { + case 1: + return len(b), nil + default: + return 1, io.ErrShortWrite } - return chunkSize, nil } -func TestClientWriteSequential_WriterErr(t *testing.T) { - sftp, cmd := testClient(t, READONLY, NODELAY) +func TestClientWriteSequentialWriterErr(t *testing.T) { + client, cmd := testClient(t, READONLY, NODELAY, MaxPacketChecked(2)) defer cmd.Wait() - defer sftp.Close() + defer client.Close() d, err := ioutil.TempDir("", "sftptest-writesequential-writeerr") require.NoError(t, err) defer os.RemoveAll(d) - var ( - content = []byte("hello world") - shortWrite = 2 - ) - w := lastChunkErrSequentialWriter{ - expected: len(content), - writtenReturn: shortWrite, - } + w := &lastChunkErrSequentialWriter{} f, err := ioutil.TempFile(d, "write-sequential-writeerr-test") require.NoError(t, err) fname := f.Name() - n, err := f.Write(content) + _, err = f.Write([]byte("hello world")) require.NoError(t, err) - require.Equal(t, n, len(content)) require.NoError(t, f.Close()) - sftpFile, err := sftp.Open(fname) + sftpFile, err := client.Open(fname) require.NoError(t, err) defer sftpFile.Close() - gotWritten, gotErr := sftpFile.writeToSequential(&w) - require.NotErrorIs(t, io.EOF, gotErr) - require.Equal(t, int64(shortWrite), gotWritten) + written, err := sftpFile.writeToSequential(w) + assert.NotNil(t, err) + if written != 3 { + t.Errorf("sftpFile.Write() = %d, but expected 3", written) + } + assert.Equal(t, 2, w.writeCounter) } func TestClientReadDir(t *testing.T) {