Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ability for server to receive file permissions for Open and Mkdir requests #546

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
34 changes: 26 additions & 8 deletions packet.go
Expand Up @@ -667,12 +667,13 @@ type sshFxpOpenPacket struct {
ID uint32
Path string
Pflags uint32
Flags uint32 // ignored
Flags uint32
Attrs interface{}
}

func (p *sshFxpOpenPacket) id() uint32 { return p.ID }

func (p *sshFxpOpenPacket) MarshalBinary() ([]byte, error) {
func (p *sshFxpOpenPacket) marshalPacket() ([]byte, []byte, error) {
l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id)
4 + len(p.Path) +
4 + 4
Expand All @@ -684,7 +685,14 @@ func (p *sshFxpOpenPacket) MarshalBinary() ([]byte, error) {
b = marshalUint32(b, p.Pflags)
b = marshalUint32(b, p.Flags)

return b, nil
payload := marshal(nil, p.Attrs)

return b, payload, nil
}

func (p *sshFxpOpenPacket) MarshalBinary() ([]byte, error) {
header, payload, err := p.marshalPacket()
return append(header, payload...), err
}

func (p *sshFxpOpenPacket) UnmarshalBinary(b []byte) error {
Expand All @@ -695,9 +703,10 @@ func (p *sshFxpOpenPacket) UnmarshalBinary(b []byte) error {
return err
} else if p.Pflags, b, err = unmarshalUint32Safe(b); err != nil {
return err
} else if p.Flags, _, err = unmarshalUint32Safe(b); err != nil {
} else if p.Flags, b, err = unmarshalUint32Safe(b); err != nil {
return err
}
p.Attrs = b
return nil
}

Expand Down Expand Up @@ -869,13 +878,14 @@ func (p *sshFxpWritePacket) UnmarshalBinary(b []byte) error {

type sshFxpMkdirPacket struct {
ID uint32
Flags uint32 // ignored
Path string
Flags uint32
Attrs interface{}
}

func (p *sshFxpMkdirPacket) id() uint32 { return p.ID }

func (p *sshFxpMkdirPacket) MarshalBinary() ([]byte, error) {
func (p *sshFxpMkdirPacket) marshalPacket() ([]byte, []byte, error) {
l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id)
4 + len(p.Path) +
4 // uint32
Expand All @@ -886,7 +896,14 @@ func (p *sshFxpMkdirPacket) MarshalBinary() ([]byte, error) {
b = marshalString(b, p.Path)
b = marshalUint32(b, p.Flags)

return b, nil
payload := marshal(nil, p.Attrs)

return b, payload, nil
}

func (p *sshFxpMkdirPacket) MarshalBinary() ([]byte, error) {
header, payload, err := p.marshalPacket()
return append(header, payload...), err
}

func (p *sshFxpMkdirPacket) UnmarshalBinary(b []byte) error {
Expand All @@ -895,9 +912,10 @@ func (p *sshFxpMkdirPacket) UnmarshalBinary(b []byte) error {
return err
} else if p.Path, b, err = unmarshalStringSafe(b); err != nil {
return err
} else if p.Flags, _, err = unmarshalUint32Safe(b); err != nil {
} else if p.Flags, b, err = unmarshalUint32Safe(b); err != nil {
return err
}
p.Attrs = b
return nil
}

Expand Down
4 changes: 4 additions & 0 deletions request.go
Expand Up @@ -178,6 +178,10 @@ func requestFromPacket(ctx context.Context, pkt hasPath, baseDir string) *Reques
switch p := pkt.(type) {
case *sshFxpOpenPacket:
request.Flags = p.Pflags
request.Attrs = p.Attrs.([]byte)
case *sshFxpMkdirPacket:
request.Flags = p.Flags
request.Attrs = p.Attrs.([]byte)
case *sshFxpSetstatPacket:
request.Flags = p.Flags
request.Attrs = p.Attrs.([]byte)
Expand Down
24 changes: 21 additions & 3 deletions server.go
Expand Up @@ -19,6 +19,9 @@ import (
const (
// SftpServerWorkerCount defines the number of workers for the SFTP server
SftpServerWorkerCount = 8

defaultFileMode = 0o644
defaultDirMode = 0o755
Comment on lines +23 to +24
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These constants should be grouped individually in their own const so that a block-level documentation can cover them both.

Prefer to also have them be typed constants, so that we can simply use mode := defaultDirMode .

)

// Server is an SSH File Transfer Protocol (sftp) server.
Expand Down Expand Up @@ -218,8 +221,15 @@ func handlePacket(s *Server, p orderedRequest) error {
rpkt = statusFromError(p.ID, err)
}
case *sshFxpMkdirPacket:
// TODO FIXME: ignore flags field
err := os.Mkdir(s.toLocalPath(p.Path), 0o755)
var mode os.FileMode = defaultDirMode
if p.Attrs != nil {
attrs, _ := unmarshalFileStat(p.Flags, p.Attrs.([]byte))
Comment on lines +225 to +226
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Either p.Attrs should be typed []byte or we should properly handle the case where it is not []byte.

It should be sufficient to use:

switch pattrs := p.Attrs.(type) {
case []byte:
  …
case *FileStat:
  … // since this would be a fast toFileMode(pattrs.Mode) might as well support it.
}

(I can’t speak to the legacy code doing differently. If I were redoing the FXP_SET_STAT packets, I’d do it quite a bit differently… I mean, I already did it’s in internal/encoding/ssh/filexfer…)

if p.Flags&sshFileXferAttrPermissions != 0 {
mode = toFileMode(attrs.Mode)
}
}

err := os.Mkdir(s.toLocalPath(p.Path), mode)
rpkt = statusFromError(p.ID, err)
case *sshFxpRmdirPacket:
err := os.Remove(s.toLocalPath(p.Path))
Expand Down Expand Up @@ -458,7 +468,15 @@ func (p *sshFxpOpenPacket) respond(svr *Server) responsePacket {
osFlags |= os.O_EXCL
}

f, err := os.OpenFile(svr.toLocalPath(p.Path), osFlags, 0o644)
var mode os.FileMode = defaultFileMode
if p.Attrs != nil {
attrs, _ := unmarshalFileStat(p.Flags, p.Attrs.([]byte))
if p.Flags&sshFileXferAttrPermissions != 0 {
mode = toFileMode(attrs.Mode)
}
}

f, err := os.OpenFile(svr.toLocalPath(p.Path), osFlags, mode)
if err != nil {
return statusFromError(p.ID, err)
}
Expand Down