Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add WithMaxTxPacket server option #584

Merged
merged 1 commit into from Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion packet.go
Expand Up @@ -823,7 +823,7 @@ func (p *sshFxpReadPacket) UnmarshalBinary(b []byte) error {
// So, we need: uint32(length) + byte(type) + uint32(id) + uint32(data_length)
const dataHeaderLen = 4 + 1 + 4 + 4

func (p *sshFxpReadPacket) getDataSlice(alloc *allocator, orderID uint32) []byte {
func (p *sshFxpReadPacket) getDataSlice(alloc *allocator, orderID uint32, maxTxPacket uint32) []byte {
dataLen := p.Len
if dataLen > maxTxPacket {
dataLen = maxTxPacket
Expand Down
32 changes: 25 additions & 7 deletions request-server.go
Expand Up @@ -10,7 +10,7 @@ import (
"sync"
)

var maxTxPacket uint32 = 1 << 15
const defaultMaxTxPacket uint32 = 1 << 15

// Handlers contains the 4 SFTP server request handlers.
type Handlers struct {
Expand All @@ -28,6 +28,7 @@ type RequestServer struct {
pktMgr *packetManager

startDirectory string
maxTxPacket uint32

mu sync.RWMutex
handleCount int
Expand Down Expand Up @@ -57,6 +58,22 @@ func WithStartDirectory(startDirectory string) RequestServerOption {
}
}

// WithRSMaxTxPacket sets the maximum size of the payload returned to the client,
// measured in bytes. The default value is 32768 bytes, and this option
// can only be used to increase it. Setting this option to a larger value
// should be safe, because the client decides the size of the requested payload.
//
// The default maximum packet size is 32768 bytes.
func WithRSMaxTxPacket(size uint32) RequestServerOption {
return func(rs *RequestServer) {
if size < defaultMaxTxPacket {
return
}

rs.maxTxPacket = size
}
}

// NewRequestServer creates/allocates/returns new RequestServer.
// Normally there will be one server per user-session.
func NewRequestServer(rwc io.ReadWriteCloser, h Handlers, options ...RequestServerOption) *RequestServer {
Expand All @@ -73,6 +90,7 @@ func NewRequestServer(rwc io.ReadWriteCloser, h Handlers, options ...RequestServ
pktMgr: newPktMgr(svrConn),

startDirectory: "/",
maxTxPacket: defaultMaxTxPacket,

openRequests: make(map[string]*Request),
}
Expand Down Expand Up @@ -260,7 +278,7 @@ func (rs *RequestServer) packetWorker(ctx context.Context, pktChan chan orderedR
Method: "Stat",
Filepath: cleanPathWithBase(rs.startDirectory, request.Filepath),
}
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID, rs.maxTxPacket)
}
case *sshFxpFsetstatPacket:
handle := pkt.getHandle()
Expand All @@ -272,32 +290,32 @@ func (rs *RequestServer) packetWorker(ctx context.Context, pktChan chan orderedR
Method: "Setstat",
Filepath: cleanPathWithBase(rs.startDirectory, request.Filepath),
}
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID, rs.maxTxPacket)
}
case *sshFxpExtendedPacketPosixRename:
request := &Request{
Method: "PosixRename",
Filepath: cleanPathWithBase(rs.startDirectory, pkt.Oldpath),
Target: cleanPathWithBase(rs.startDirectory, pkt.Newpath),
}
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID, rs.maxTxPacket)
case *sshFxpExtendedPacketStatVFS:
request := &Request{
Method: "StatVFS",
Filepath: cleanPathWithBase(rs.startDirectory, pkt.Path),
}
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID, rs.maxTxPacket)
case hasHandle:
handle := pkt.getHandle()
request, ok := rs.getRequest(handle)
if !ok {
rpkt = statusFromError(pkt.id(), EBADF)
} else {
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID, rs.maxTxPacket)
}
case hasPath:
request := requestFromPacket(ctx, pkt, rs.startDirectory)
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID, rs.maxTxPacket)
request.close()
default:
rpkt = statusFromError(pkt.id(), ErrSSHFxOpUnsupported)
Expand Down
24 changes: 12 additions & 12 deletions request.go
Expand Up @@ -300,14 +300,14 @@ func (r *Request) transferError(err error) {
}

// called from worker to handle packet/request
func (r *Request) call(handlers Handlers, pkt requestPacket, alloc *allocator, orderID uint32) responsePacket {
func (r *Request) call(handlers Handlers, pkt requestPacket, alloc *allocator, orderID uint32, maxTxPacket uint32) responsePacket {
switch r.Method {
case "Get":
return fileget(handlers.FileGet, r, pkt, alloc, orderID)
return fileget(handlers.FileGet, r, pkt, alloc, orderID, maxTxPacket)
case "Put":
return fileput(handlers.FilePut, r, pkt, alloc, orderID)
return fileput(handlers.FilePut, r, pkt, alloc, orderID, maxTxPacket)
case "Open":
return fileputget(handlers.FilePut, r, pkt, alloc, orderID)
return fileputget(handlers.FilePut, r, pkt, alloc, orderID, maxTxPacket)
case "Setstat", "Rename", "Rmdir", "Mkdir", "Link", "Symlink", "Remove", "PosixRename", "StatVFS":
return filecmd(handlers.FileCmd, r, pkt)
case "List":
Expand Down Expand Up @@ -392,13 +392,13 @@ func (r *Request) opendir(h Handlers, pkt requestPacket) responsePacket {
}

// wrap FileReader handler
func fileget(h FileReader, r *Request, pkt requestPacket, alloc *allocator, orderID uint32) responsePacket {
func fileget(h FileReader, r *Request, pkt requestPacket, alloc *allocator, orderID uint32, maxTxPacket uint32) responsePacket {
rd := r.getReaderAt()
if rd == nil {
return statusFromError(pkt.id(), errors.New("unexpected read packet"))
}

data, offset, _ := packetData(pkt, alloc, orderID)
data, offset, _ := packetData(pkt, alloc, orderID, maxTxPacket)

n, err := rd.ReadAt(data, offset)
// only return EOF error if no data left to read
Expand All @@ -414,28 +414,28 @@ func fileget(h FileReader, r *Request, pkt requestPacket, alloc *allocator, orde
}

// wrap FileWriter handler
func fileput(h FileWriter, r *Request, pkt requestPacket, alloc *allocator, orderID uint32) responsePacket {
func fileput(h FileWriter, r *Request, pkt requestPacket, alloc *allocator, orderID uint32, maxTxPacket uint32) responsePacket {
wr := r.getWriterAt()
if wr == nil {
return statusFromError(pkt.id(), errors.New("unexpected write packet"))
}

data, offset, _ := packetData(pkt, alloc, orderID)
data, offset, _ := packetData(pkt, alloc, orderID, maxTxPacket)

_, err := wr.WriteAt(data, offset)
return statusFromError(pkt.id(), err)
}

// wrap OpenFileWriter handler
func fileputget(h FileWriter, r *Request, pkt requestPacket, alloc *allocator, orderID uint32) responsePacket {
func fileputget(h FileWriter, r *Request, pkt requestPacket, alloc *allocator, orderID uint32, maxTxPacket uint32) responsePacket {
rw := r.getWriterAtReaderAt()
if rw == nil {
return statusFromError(pkt.id(), errors.New("unexpected write and read packet"))
}

switch p := pkt.(type) {
case *sshFxpReadPacket:
data, offset := p.getDataSlice(alloc, orderID), int64(p.Offset)
data, offset := p.getDataSlice(alloc, orderID, maxTxPacket), int64(p.Offset)

n, err := rw.ReadAt(data, offset)
// only return EOF error if no data left to read
Expand All @@ -461,10 +461,10 @@ func fileputget(h FileWriter, r *Request, pkt requestPacket, alloc *allocator, o
}

// file data for additional read/write packets
func packetData(p requestPacket, alloc *allocator, orderID uint32) (data []byte, offset int64, length uint32) {
func packetData(p requestPacket, alloc *allocator, orderID uint32, maxTxPacket uint32) (data []byte, offset int64, length uint32) {
switch p := p.(type) {
case *sshFxpReadPacket:
return p.getDataSlice(alloc, orderID), int64(p.Offset), p.Len
return p.getDataSlice(alloc, orderID, maxTxPacket), int64(p.Offset), p.Len
case *sshFxpWritePacket:
return p.Data, int64(p.Offset), p.Length
}
Expand Down
22 changes: 11 additions & 11 deletions request_test.go
Expand Up @@ -149,7 +149,7 @@ func TestRequestGet(t *testing.T) {
for i, txt := range []string{"file-", "data."} {
pkt := &sshFxpReadPacket{ID: uint32(i), Handle: "a",
Offset: uint64(i * 5), Len: 5}
rpkt := request.call(handlers, pkt, nil, 0)
rpkt := request.call(handlers, pkt, nil, 0, defaultMaxTxPacket)
dpkt := rpkt.(*sshFxpDataPacket)
assert.Equal(t, dpkt.id(), uint32(i))
assert.Equal(t, string(dpkt.Data), txt)
Expand All @@ -162,7 +162,7 @@ func TestRequestCustomError(t *testing.T) {
pkt := fakePacket{myid: 1}
cmdErr := errors.New("stat not supported")
handlers.returnError(cmdErr)
rpkt := request.call(handlers, pkt, nil, 0)
rpkt := request.call(handlers, pkt, nil, 0, defaultMaxTxPacket)
assert.Equal(t, rpkt, statusFromError(pkt.myid, cmdErr))
}

Expand All @@ -173,11 +173,11 @@ func TestRequestPut(t *testing.T) {
request.state.writerAt, _ = handlers.FilePut.Filewrite(request)
pkt := &sshFxpWritePacket{ID: 0, Handle: "a", Offset: 0, Length: 5,
Data: []byte("file-")}
rpkt := request.call(handlers, pkt, nil, 0)
rpkt := request.call(handlers, pkt, nil, 0, defaultMaxTxPacket)
checkOkStatus(t, rpkt)
pkt = &sshFxpWritePacket{ID: 1, Handle: "a", Offset: 5, Length: 5,
Data: []byte("data.")}
rpkt = request.call(handlers, pkt, nil, 0)
rpkt = request.call(handlers, pkt, nil, 0, defaultMaxTxPacket)
checkOkStatus(t, rpkt)
assert.Equal(t, "file-data.", handlers.getOutString())
}
Expand All @@ -186,19 +186,19 @@ func TestRequestCmdr(t *testing.T) {
handlers := newTestHandlers()
request := testRequest("Mkdir")
pkt := fakePacket{myid: 1}
rpkt := request.call(handlers, pkt, nil, 0)
rpkt := request.call(handlers, pkt, nil, 0, defaultMaxTxPacket)
checkOkStatus(t, rpkt)

handlers.returnError(errTest)
rpkt = request.call(handlers, pkt, nil, 0)
rpkt = request.call(handlers, pkt, nil, 0, defaultMaxTxPacket)
assert.Equal(t, rpkt, statusFromError(pkt.myid, errTest))
}

func TestRequestInfoStat(t *testing.T) {
handlers := newTestHandlers()
request := testRequest("Stat")
pkt := fakePacket{myid: 1}
rpkt := request.call(handlers, pkt, nil, 0)
rpkt := request.call(handlers, pkt, nil, 0, defaultMaxTxPacket)
spkt, ok := rpkt.(*sshFxpStatResponse)
assert.True(t, ok)
assert.Equal(t, spkt.info.Name(), "request_test.go")
Expand All @@ -215,13 +215,13 @@ func TestRequestInfoList(t *testing.T) {
assert.Equal(t, hpkt.Handle, "1")
}
pkt = fakePacket{myid: 2}
request.call(handlers, pkt, nil, 0)
request.call(handlers, pkt, nil, 0, defaultMaxTxPacket)
}
func TestRequestInfoReadlink(t *testing.T) {
handlers := newTestHandlers()
request := testRequest("Readlink")
pkt := fakePacket{myid: 1}
rpkt := request.call(handlers, pkt, nil, 0)
rpkt := request.call(handlers, pkt, nil, 0, defaultMaxTxPacket)
npkt, ok := rpkt.(*sshFxpNamePacket)
if assert.True(t, ok) {
assert.IsType(t, &sshFxpNameAttr{}, npkt.NameAttrs[0])
Expand All @@ -234,7 +234,7 @@ func TestOpendirHandleReuse(t *testing.T) {
request := testRequest("Stat")
request.handle = "1"
pkt := fakePacket{myid: 1}
rpkt := request.call(handlers, pkt, nil, 0)
rpkt := request.call(handlers, pkt, nil, 0, defaultMaxTxPacket)
assert.IsType(t, &sshFxpStatResponse{}, rpkt)

request.Method = "List"
Expand All @@ -244,6 +244,6 @@ func TestOpendirHandleReuse(t *testing.T) {
hpkt := rpkt.(*sshFxpHandlePacket)
assert.Equal(t, hpkt.Handle, "1")
}
rpkt = request.call(handlers, pkt, nil, 0)
rpkt = request.call(handlers, pkt, nil, 0, defaultMaxTxPacket)
assert.IsType(t, &sshFxpNamePacket{}, rpkt)
}
38 changes: 29 additions & 9 deletions server.go
Expand Up @@ -34,6 +34,7 @@ type Server struct {
openFilesLock sync.RWMutex
handleCount int
workDir string
maxTxPacket uint32
}

func (svr *Server) nextHandle(f *os.File) string {
Expand Down Expand Up @@ -86,6 +87,7 @@ func NewServer(rwc io.ReadWriteCloser, options ...ServerOption) (*Server, error)
debugStream: ioutil.Discard,
pktMgr: newPktMgr(svrConn),
openFiles: make(map[string]*os.File),
maxTxPacket: defaultMaxTxPacket,
}

for _, o := range options {
Expand Down Expand Up @@ -139,6 +141,24 @@ func WithServerWorkingDirectory(workDir string) ServerOption {
}
}

// WithMaxTxPacket sets the maximum size of the payload returned to the client,
// measured in bytes. The default value is 32768 bytes, and this option
// can only be used to increase it. Setting this option to a larger value
// should be safe, because the client decides the size of the requested payload.
//
// The default maximum packet size is 32768 bytes.
func WithMaxTxPacket(size uint32) ServerOption {
return func(s *Server) error {
if size < defaultMaxTxPacket {
return errors.New("size must be greater than or equal to 32768")
}

s.maxTxPacket = size

return nil
}
}

type rxPacket struct {
pktType fxp
pktBytes []byte
Expand Down Expand Up @@ -287,7 +307,7 @@ func handlePacket(s *Server, p orderedRequest) error {
f, ok := s.getHandle(p.Handle)
if ok {
err = nil
data := p.getDataSlice(s.pktMgr.alloc, orderID)
data := p.getDataSlice(s.pktMgr.alloc, orderID, s.maxTxPacket)
n, _err := f.ReadAt(data, int64(p.Offset))
if _err != nil && (_err != io.EOF || n == 0) {
err = _err
Expand Down Expand Up @@ -513,16 +533,16 @@ func (p *sshFxpSetstatPacket) respond(svr *Server) responsePacket {

fs, err := p.unmarshalFileStat(p.Flags)

if err == nil && (p.Flags & sshFileXferAttrSize) != 0 {
if err == nil && (p.Flags&sshFileXferAttrSize) != 0 {
puellanivis marked this conversation as resolved.
Show resolved Hide resolved
err = os.Truncate(path, int64(fs.Size))
}
if err == nil && (p.Flags & sshFileXferAttrPermissions) != 0 {
if err == nil && (p.Flags&sshFileXferAttrPermissions) != 0 {
err = os.Chmod(path, fs.FileMode())
}
if err == nil && (p.Flags & sshFileXferAttrUIDGID) != 0 {
if err == nil && (p.Flags&sshFileXferAttrUIDGID) != 0 {
err = os.Chown(path, int(fs.UID), int(fs.GID))
}
if err == nil && (p.Flags & sshFileXferAttrACmodTime) != 0 {
if err == nil && (p.Flags&sshFileXferAttrACmodTime) != 0 {
err = os.Chtimes(path, fs.AccessTime(), fs.ModTime())
}

Expand All @@ -541,16 +561,16 @@ func (p *sshFxpFsetstatPacket) respond(svr *Server) responsePacket {

fs, err := p.unmarshalFileStat(p.Flags)

if err == nil && (p.Flags & sshFileXferAttrSize) != 0 {
if err == nil && (p.Flags&sshFileXferAttrSize) != 0 {
err = f.Truncate(int64(fs.Size))
}
if err == nil && (p.Flags & sshFileXferAttrPermissions) != 0 {
if err == nil && (p.Flags&sshFileXferAttrPermissions) != 0 {
err = f.Chmod(fs.FileMode())
}
if err == nil && (p.Flags & sshFileXferAttrUIDGID) != 0 {
if err == nil && (p.Flags&sshFileXferAttrUIDGID) != 0 {
err = f.Chown(int(fs.UID), int(fs.GID))
}
if err == nil && (p.Flags & sshFileXferAttrACmodTime) != 0 {
if err == nil && (p.Flags&sshFileXferAttrACmodTime) != 0 {
type chtimer interface {
Chtimes(atime, mtime time.Time) error
}
Expand Down