diff --git a/client.go b/client.go
index ce62286f..8894053f 100644
--- a/client.go
+++ b/client.go
@@ -1461,11 +1461,20 @@ func (f *File) writeAtConcurrent(b []byte, off int64) (int, error) {
 	cancel := make(chan struct{})
 
 	type work struct {
-		b   []byte
+		id  uint32
+		res chan result
+
 		off int64
 	}
 	workCh := make(chan work)
 
+	concurrency := len(b)/f.c.maxPacket + 1
+	if concurrency > f.c.maxConcurrentRequests || concurrency < 1 {
+		concurrency = f.c.maxConcurrentRequests
+	}
+
+	pool := newResChanPool(concurrency)
+
 	// Slice: cut up the Read into any number of buffers of length <= f.c.maxPacket, and at appropriate offsets.
 	go func() {
 		defer close(workCh)
@@ -1479,8 +1488,20 @@ func (f *File) writeAtConcurrent(b []byte, off int64) (int, error) {
 				wb = wb[:chunkSize]
 			}
 
+			id := f.c.nextID()
+			res := pool.Get()
+			off := off + int64(read)
+
+			f.c.dispatchRequest(res, &sshFxpWritePacket{
+				ID:     id,
+				Handle: f.handle,
+				Offset: uint64(off),
+				Length: uint32(len(wb)),
+				Data:   wb,
+			})
+
 			select {
-			case workCh <- work{wb, off + int64(read)}:
+			case workCh <- work{id, res, off}:
 			case <-cancel:
 				return
 			}
@@ -1495,11 +1516,6 @@ func (f *File) writeAtConcurrent(b []byte, off int64) (int, error) {
 	}
 	errCh := make(chan wErr)
 
-	concurrency := len(b)/f.c.maxPacket + 1
-	if concurrency > f.c.maxConcurrentRequests || concurrency < 1 {
-		concurrency = f.c.maxConcurrentRequests
-	}
-
 	var wg sync.WaitGroup
 	wg.Add(concurrency)
 	for i := 0; i < concurrency; i++ {
@@ -1507,13 +1523,22 @@ func (f *File) writeAtConcurrent(b []byte, off int64) (int, error) {
 		go func() {
 			defer wg.Done()
 
-			ch := make(chan result, 1) // reusable channel per mapper.
+			for work := range workCh {
+				s := <-work.res
+				pool.Put(work.res)
+
+				err := s.err
+				if err == nil {
+					switch s.typ {
+					case sshFxpStatus:
+						err = normaliseError(unmarshalStatus(work.id, s.data))
+					default:
+						err = unimplementedPacketErr(s.typ)
+					}
+				}
 
-			for packet := range workCh {
-				n, err := f.writeChunkAt(ch, packet.b, packet.off)
 				if err != nil {
-					// return the offset as the start + how much we wrote before the error.
-					errCh <- wErr{packet.off + int64(n), err}
+					errCh <- wErr{work.off, err}
 				}
 			}
 		}()
@@ -1598,8 +1623,9 @@ func (f *File) ReadFromWithConcurrency(r io.Reader, concurrency int) (read int64
 	cancel := make(chan struct{})
 
 	type work struct {
-		b   []byte
-		n   int
+		id  uint32
+		res chan result
+
 		off int64
 	}
 	workCh := make(chan work)
@@ -1614,24 +1640,34 @@ func (f *File) ReadFromWithConcurrency(r io.Reader, concurrency int) (read int64
 		concurrency = f.c.maxConcurrentRequests
 	}
 
-	pool := newBufPool(concurrency, f.c.maxPacket)
+	pool := newResChanPool(concurrency)
 
 	// Slice: cut up the Read into any number of buffers of length <= f.c.maxPacket, and at appropriate offsets.
 	go func() {
 		defer close(workCh)
 
+		b := make([]byte, f.c.maxPacket)
 		off := f.offset
 
 		for {
-			b := pool.Get()
-
 			n, err := r.Read(b)
+
 			if n > 0 {
 				read += int64(n)
 
+				id := f.c.nextID()
+				res := pool.Get()
+
+				f.c.dispatchRequest(res, &sshFxpWritePacket{
+					ID:     id,
+					Handle: f.handle,
+					Offset: uint64(off),
+					Length: uint32(n),
+					Data:   b,
+				})
+
 				select {
-				case workCh <- work{b, n, off}:
-					// We need the pool.Put(b) to put the whole slice, not just trunced.
+				case workCh <- work{id, res, off}:
 				case <-cancel:
 					return
 				}
@@ -1655,15 +1691,23 @@ func (f *File) ReadFromWithConcurrency(r io.Reader, concurrency int) (read int64
 		go func() {
 			defer wg.Done()
 
-			ch := make(chan result, 1) // reusable channel per mapper.
+			for work := range workCh {
+				s := <-work.res
+				pool.Put(work.res)
+
+				err := s.err
+				if err == nil {
+					switch s.typ {
+					case sshFxpStatus:
+						err = normaliseError(unmarshalStatus(work.id, s.data))
+					default:
+						err = unimplementedPacketErr(s.typ)
+					}
+				}
 
-			for packet := range workCh {
-				n, err := f.writeChunkAt(ch, packet.b[:packet.n], packet.off)
 				if err != nil {
-					// return the offset as the start + how much we wrote before the error.
-					errCh <- rwErr{packet.off + int64(n), err}
+					errCh <- rwErr{work.off, err}
 				}
-				pool.Put(packet.b)
 			}
 		}()
 	}
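
Note: the hunks above replace the per-worker writeChunkAt call (one reusable chan result per goroutine) with requests that are dispatched up front via f.c.dispatchRequest and collected through a shared pool of result channels. As a rough, self-contained illustration of that pattern only: the pool below is an assumption about what newResChanPool provides (a bounded stash of reusable, buffered chan result), and the names echo the diff but are not the library's code.

package main

import (
	"fmt"
	"sync"
)

// result stands in for the per-request response; field names are illustrative.
type result struct {
	id  uint32
	err error
}

// resChanPool sketches a bounded pool of reusable result channels.
type resChanPool chan chan result

func newResChanPool(depth int) resChanPool { return make(resChanPool, depth) }

func (p resChanPool) Get() chan result {
	select {
	case ch := <-p:
		return ch
	default:
		return make(chan result, 1) // buffered so the responder never blocks
	}
}

func (p resChanPool) Put(ch chan result) {
	select {
	case p <- ch:
	default: // pool full; let the channel be garbage collected
	}
}

func main() {
	pool := newResChanPool(4)

	type work struct {
		id  uint32
		res chan result
	}
	workCh := make(chan work)

	// Dispatcher: fire each request immediately with a pooled response
	// channel, then queue (id, res) so a collector can pick up the reply.
	go func() {
		defer close(workCh)
		for id := uint32(0); id < 8; id++ {
			res := pool.Get()
			go func(id uint32, res chan result) { // stand-in for dispatchRequest
				res <- result{id: id}
			}(id, res)
			workCh <- work{id, res}
		}
	}()

	// Collectors: wait on each request's own channel, recycle it, report.
	var wg sync.WaitGroup
	for i := 0; i < 2; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for w := range workCh {
				s := <-w.res
				pool.Put(w.res)
				fmt.Printf("request %d done, err=%v\n", s.id, s.err)
			}
		}()
	}
	wg.Wait()
}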