From 2d34a53a6a407c7566b50a59c8cacff8fb27426c Mon Sep 17 00:00:00 2001 From: Hilari Moragrega Date: Thu, 17 Feb 2022 19:58:49 +0100 Subject: [PATCH 1/2] Return writer error if not nil in writeSequential --- client.go | 8 +++++--- client_integration_test.go | 38 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 3 deletions(-) diff --git a/client.go b/client.go index 1587eb83..f5bd39ee 100644 --- a/client.go +++ b/client.go @@ -1176,11 +1176,13 @@ func (f *File) writeToSequential(w io.Writer) (written int64, err error) { if n > 0 { f.offset += int64(n) - m, err2 := w.Write(b[:n]) + m, wErr := w.Write(b[:n]) written += int64(m) - if err == nil { - err = err2 + if wErr != nil { + if err == nil || err == io.EOF { + err = wErr + } } } diff --git a/client_integration_test.go b/client_integration_test.go index b5083845..d7c10408 100644 --- a/client_integration_test.go +++ b/client_integration_test.go @@ -1249,6 +1249,44 @@ func TestClientReadSequential(t *testing.T) { } } +type writerFunc func(b []byte) (int, error) + +func (f writerFunc) Write(b []byte) (int, error) { + return f(b) +} + +func TestClientWriteSequential_WriterErr(t *testing.T) { + sftp, cmd := testClient(t, READONLY, NODELAY) + defer cmd.Wait() + defer sftp.Close() + + sftp.disableConcurrentReads = true + d, err := ioutil.TempDir("", "sftptest-writesequential") + require.NoError(t, err) + + defer os.RemoveAll(d) + + f, err := ioutil.TempFile(d, "write-sequential-test") + require.NoError(t, err) + fname := f.Name() + content := []byte("hello world") + f.Write(content) + f.Close() + + sftpFile, err := sftp.Open(fname) + require.NoError(t, err) + defer sftpFile.Close() + + want := errors.New("error writing") + n, got := io.Copy(writerFunc(func(b []byte) (int, error) { + return 10, want + }), sftpFile) + + require.Error(t, got) + assert.ErrorIs(t, want, got) + assert.Equal(t, int64(10), n) +} + func TestClientReadDir(t *testing.T) { sftp1, cmd1 := testClient(t, READONLY, NODELAY) sftp2, cmd2 := testClientGoSvr(t, READONLY, NODELAY) From 1c605bf1f54e6fc69bfeb318c96b22f5e5f14f3a Mon Sep 17 00:00:00 2001 From: Hilari Moragrega Date: Mon, 21 Feb 2022 12:19:31 +0100 Subject: [PATCH 2/2] Improved test with CR feedback --- client.go | 4 +--- client_integration_test.go | 47 ++++++++++++++++++++++++-------------- 2 files changed, 31 insertions(+), 20 deletions(-) diff --git a/client.go b/client.go index f5bd39ee..d8c979a0 100644 --- a/client.go +++ b/client.go @@ -1180,9 +1180,7 @@ func (f *File) writeToSequential(w io.Writer) (written int64, err error) { written += int64(m) if wErr != nil { - if err == nil || err == io.EOF { - err = wErr - } + return written, wErr } } diff --git a/client_integration_test.go b/client_integration_test.go index d7c10408..1a855d42 100644 --- a/client_integration_test.go +++ b/client_integration_test.go @@ -1249,10 +1249,19 @@ func TestClientReadSequential(t *testing.T) { } } -type writerFunc func(b []byte) (int, error) +type lastChunkErrSequentialWriter struct { + expected int + written int + writtenReturn int +} -func (f writerFunc) Write(b []byte) (int, error) { - return f(b) +func (w *lastChunkErrSequentialWriter) Write(b []byte) (int, error) { + chunkSize := len(b) + w.written += chunkSize + if w.written == w.expected { + return w.writtenReturn, errors.New("test error") + } + return chunkSize, nil } func TestClientWriteSequential_WriterErr(t *testing.T) { @@ -1260,31 +1269,35 @@ func TestClientWriteSequential_WriterErr(t *testing.T) { defer cmd.Wait() defer sftp.Close() - sftp.disableConcurrentReads = true - d, err := ioutil.TempDir("", "sftptest-writesequential") + d, err := ioutil.TempDir("", "sftptest-writesequential-writeerr") require.NoError(t, err) defer os.RemoveAll(d) - f, err := ioutil.TempFile(d, "write-sequential-test") + var ( + content = []byte("hello world") + shortWrite = 2 + ) + w := lastChunkErrSequentialWriter{ + expected: len(content), + writtenReturn: shortWrite, + } + + f, err := ioutil.TempFile(d, "write-sequential-writeerr-test") require.NoError(t, err) fname := f.Name() - content := []byte("hello world") - f.Write(content) - f.Close() + n, err := f.Write(content) + require.NoError(t, err) + require.Equal(t, n, len(content)) + require.NoError(t, f.Close()) sftpFile, err := sftp.Open(fname) require.NoError(t, err) defer sftpFile.Close() - want := errors.New("error writing") - n, got := io.Copy(writerFunc(func(b []byte) (int, error) { - return 10, want - }), sftpFile) - - require.Error(t, got) - assert.ErrorIs(t, want, got) - assert.Equal(t, int64(10), n) + gotWritten, gotErr := sftpFile.writeToSequential(&w) + require.NotErrorIs(t, io.EOF, gotErr) + require.Equal(t, int64(shortWrite), gotWritten) } func TestClientReadDir(t *testing.T) {