Skip to content

Commit

Permalink
PR: clean up RawSockaddr validation
Browse files Browse the repository at this point in the history
Signed-off-by: Hamza El-Saawy <hamzaelsaawy@microsoft.com>
  • Loading branch information
helsaawy committed Jul 6, 2022
1 parent a615ab2 commit 9919dfc
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 63 deletions.
1 change: 1 addition & 0 deletions file.go
Expand Up @@ -223,6 +223,7 @@ func (f *win32File) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, er
// runtime.KeepAlive is needed, as c is passed via native
// code to ioCompletionProcessor, c must remain alive
// until the channel read is complete.
// todo: (de)allocate *ioOperation via win32 heap functions, instead of needing to KeepAlive?
runtime.KeepAlive(c)
return int(r.bytes), err
}
Expand Down
5 changes: 4 additions & 1 deletion hvsock.go
Expand Up @@ -140,7 +140,10 @@ func (addr *HvsockAddr) fromRaw(raw *rawHvsockAddr) {
addr.ServiceID = raw.ServiceID
}

// Sockaddr interface allows use with `sockets.Bind()` and `.ConnectEx()`
// Sockaddr returns a pointer to and the size of this struct.
//
// Implements the [socket.RawSockaddr] interface, and allows use in
// [socket.Bind()] and [socket.ConnectEx()]
func (r *rawHvsockAddr) Sockaddr() (ptr unsafe.Pointer, len int32, err error) {
return unsafe.Pointer(r), int32(unsafe.Sizeof(rawHvsockAddr{})), nil
}
Expand Down
39 changes: 9 additions & 30 deletions internal/socket/rawaddr.go
@@ -1,41 +1,20 @@
package socket

import (
"errors"
"fmt"
"unsafe"
)

// todo: should these be custom types to store the desired/actual size and addr family?

var (
ErrBufferSize = errors.New("buffer size")
ErrInvalidPointer = errors.New("invalid pointer")
ErrAddrFamily = errors.New("address family")
)

// todo: helsaawy - replace this with generics, along with GetSockName and co.

// RawSockaddr allows structs to be used with Bind and ConnectEx. The
// RawSockaddr allows structs to be used with [Bind] and [ConnectEx]. The
// struct must meet the Win32 sockaddr requirements specified here:
// https://docs.microsoft.com/en-us/windows/win32/winsock/sockaddr-2
//
// Specifically, the struct size must be least larger than an int16 (unsigned short)
// for the address family.
type RawSockaddr interface {
// Sockaddr returns a pointer to the RawSockaddr and the length of the struct.
// Sockaddr returns a pointer to the RawSockaddr and its struct size, allowing
// for the RawSockaddr's data to be overwritten by syscalls (if necessary).
//
// It is the callers responsibility to validate that the values are valid; invalid
// pointers or size can cause a panic.
Sockaddr() (unsafe.Pointer, int32, error)

// FromBytes populates the RawsockAddr with the data in the byte array.
// Implementers should check the buffer is correctly sized and the address family
// is appropriate.
// Receivers should be pointers.
FromBytes([]byte) error
}

func validateSockAddr(ptr unsafe.Pointer, n int32) error {
if ptr == nil {
return fmt.Errorf("pointer is %p: %w", ptr, ErrInvalidPointer)
}
if n < 1 {
return fmt.Errorf("buffer size %d < 1: %w", n, ErrBufferSize)
}
return nil
}
54 changes: 22 additions & 32 deletions internal/socket/socket.go
Expand Up @@ -3,6 +3,7 @@
package socket

import (
"errors"
"fmt"
"net"
"sync"
Expand All @@ -21,57 +22,46 @@ import (

const socketError = uintptr(^uint32(0))

var ErrSocketClosed = fmt.Errorf("socket closed: %w", net.ErrClosed)
var (
// todo(helsaawy): create custom error types to store the desired vs actual size and addr family?

ErrBufferSize = errors.New("buffer size")
ErrAddrFamily = errors.New("address family")
ErrInvalidPointer = errors.New("invalid pointer")
ErrSocketClosed = fmt.Errorf("socket closed: %w", net.ErrClosed)
)

// todo(helsaawy): replace these with generics, ie: GetSockName[S RawSockaddr](s windows.Handle) (S, error)

// GetSockName returns the socket's local address. It will call the `rsa.FromBytes()` on the
// buffer returned by the getsockname syscall. The buffer is allocated to the size specified
// by `rsa.Sockaddr()`.
// GetSockName writes the local address of socket s to the [RawSockaddr] rsa.
// If rsa is not large enough, the [windows.WSAEFAULT] is returned.
func GetSockName(s windows.Handle, rsa RawSockaddr) error {
ptr, l, err := rsa.Sockaddr()
if err != nil {
return fmt.Errorf("could not find socket size to allocate buffer: %w", err)
}
if err = validateSockAddr(ptr, l); err != nil {
return err
return fmt.Errorf("could not retrieve socket pointer and size: %w", err)
}

b := make([]byte, l)
err = getsockname(s, unsafe.Pointer(&b[0]), &l)
if err != nil {
// although getsockname returns WSAEFAULT if the buffer is too small, it does not set
// &l to the correct size, so--apart from doubling the buffer repeatedly--there is no remedy
return err
}
return rsa.FromBytes(b[:l])
// although getsockname returns WSAEFAULT if the buffer is too small, it does not set
// &l to the correct size, so--apart from doubling the buffer repeatedly--there is no remedy
return getsockname(s, ptr, &l)
}

// GetPeerName returns the remote address the socket is connected to.
//
// See GetSockName for more information.
// See [GetSockName] for more information.
func GetPeerName(s windows.Handle, rsa RawSockaddr) error {
ptr, l, err := rsa.Sockaddr()
if err != nil {
return fmt.Errorf("could not find socket size to allocate buffer: %w", err)
}
if err = validateSockAddr(ptr, l); err != nil {
return err
return fmt.Errorf("could not retrieve socket pointer and size: %w", err)
}

b := make([]byte, l)
err = getpeername(s, unsafe.Pointer(&b[0]), &l)
if err != nil {
return err
}
return rsa.FromBytes(b[:l])
return getpeername(s, ptr, &l)
}

func Bind(s windows.Handle, rsa RawSockaddr) (err error) {
ptr, l, err := rsa.Sockaddr()
if err != nil {
return fmt.Errorf("could not find socket pointer and size: %w", err)
}
if err = validateSockAddr(ptr, l); err != nil {
return err
return fmt.Errorf("could not retrieve socket pointer and size: %w", err)
}

return bind(s, ptr, l)
Expand Down Expand Up @@ -150,7 +140,7 @@ func ConnectEx(fd windows.Handle, rsa RawSockaddr, sendBuf *byte, sendDataLen ui
// [in] LPOVERLAPPED lpOverlapped
// )
func connectEx(s windows.Handle, name unsafe.Pointer, namelen int32, sendBuf *byte, sendDataLen uint32, bytesSent *uint32, overlapped *windows.Overlapped) (err error) {
// todo: after upgrading to 1.18, switch to syscall.SyscallN from syscall.Syscall9
// todo: after upgrading to 1.18, switch from syscall.Syscall9 to syscall.SyscallN
r1, _, e1 := syscall.Syscall9(connectExFunc.addr, 7, uintptr(s), uintptr(name), uintptr(namelen), uintptr(unsafe.Pointer(sendBuf)), uintptr(sendDataLen), uintptr(unsafe.Pointer(bytesSent)), uintptr(unsafe.Pointer(overlapped)), 0, 0)
if r1 == 0 {
if e1 != 0 {
Expand Down

0 comments on commit 9919dfc

Please sign in to comment.