Skip to content

Commit

Permalink
writeToSequential: improve tests for write errors
Browse files Browse the repository at this point in the history
  • Loading branch information
drakkan committed Mar 3, 2022
1 parent 65f24bc commit fd3f204
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 29 deletions.
6 changes: 3 additions & 3 deletions client.go
Expand Up @@ -1176,11 +1176,11 @@ func (f *File) writeToSequential(w io.Writer) (written int64, err error) {
if n > 0 {
f.offset += int64(n)

m, wErr := w.Write(b[:n])
m, err := w.Write(b[:n])
written += int64(m)

if wErr != nil {
return written, wErr
if err != nil {
return written, err
}
}

Expand Down
47 changes: 21 additions & 26 deletions client_integration_test.go
Expand Up @@ -1249,55 +1249,50 @@ func TestClientReadSequential(t *testing.T) {
}
}

// this writer requires maxPacket = 3 and always returns an error for the second write call
type lastChunkErrSequentialWriter struct {
expected int
written int
writtenReturn int
writeCounter int
}

func (w *lastChunkErrSequentialWriter) Write(b []byte) (int, error) {
chunkSize := len(b)
w.written += chunkSize
if w.written == w.expected {
return w.writtenReturn, errors.New("test error")
w.writeCounter++
if w.writeCounter == 1 {
if len(b) != 3 {
return 0, errors.New("this writer require maxPacket = 3, please set MaxPacketChecked(3)")
}
return len(b), nil
}
return chunkSize, nil
return 1, errors.New("this writer fails after the first write")
}

func TestClientWriteSequential_WriterErr(t *testing.T) {
sftp, cmd := testClient(t, READONLY, NODELAY)
func TestClientWriteSequentialWriterErr(t *testing.T) {
client, cmd := testClient(t, READONLY, NODELAY, MaxPacketChecked(3))
defer cmd.Wait()
defer sftp.Close()
defer client.Close()

d, err := ioutil.TempDir("", "sftptest-writesequential-writeerr")
require.NoError(t, err)

defer os.RemoveAll(d)

var (
content = []byte("hello world")
shortWrite = 2
)
w := lastChunkErrSequentialWriter{
expected: len(content),
writtenReturn: shortWrite,
}

f, err := ioutil.TempFile(d, "write-sequential-writeerr-test")
require.NoError(t, err)
fname := f.Name()
n, err := f.Write(content)
_, err = f.Write([]byte("12345"))
require.NoError(t, err)
require.Equal(t, n, len(content))
require.NoError(t, f.Close())

sftpFile, err := sftp.Open(fname)
sftpFile, err := client.Open(fname)
require.NoError(t, err)
defer sftpFile.Close()

gotWritten, gotErr := sftpFile.writeToSequential(&w)
require.NotErrorIs(t, io.EOF, gotErr)
require.Equal(t, int64(shortWrite), gotWritten)
w := &lastChunkErrSequentialWriter{}
written, err := sftpFile.writeToSequential(w)
assert.Error(t, err)
if written != 4 {
t.Errorf("sftpFile.Write() = %d, but expected 4", written)
}
assert.Equal(t, 2, w.writeCounter)
}

func TestClientReadDir(t *testing.T) {
Expand Down

0 comments on commit fd3f204

Please sign in to comment.