diff --git a/file.go b/file.go index 1b870350..f05f1ef5 100644 --- a/file.go +++ b/file.go @@ -223,6 +223,7 @@ func (f *win32File) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, er // runtime.KeepAlive is needed, as c is passed via native // code to ioCompletionProcessor, c must remain alive // until the channel read is complete. + // todo: (de)allocate *ioOperation via win32 heap functions, instead of needing to KeepAlive? runtime.KeepAlive(c) return int(r.bytes), err } diff --git a/hvsock.go b/hvsock.go index 8872f471..320a1f2f 100644 --- a/hvsock.go +++ b/hvsock.go @@ -140,8 +140,11 @@ func (addr *HvsockAddr) fromRaw(raw *rawHvsockAddr) { addr.ServiceID = raw.ServiceID } -// Sockaddr interface allows use with `sockets.Bind()` and `.ConnectEx()` -func (r *rawHvsockAddr) Sockaddr() (ptr unsafe.Pointer, len int32, err error) { +// Sockaddr returns a pointer to and the size of this struct. +// +// Implements the [socket.RawSockaddr] interface, and allows use in +// [socket.Bind()] and [socket.ConnectEx()] +func (r *rawHvsockAddr) Sockaddr() (unsafe.Pointer, int32, error) { return unsafe.Pointer(r), int32(unsafe.Sizeof(rawHvsockAddr{})), nil } diff --git a/internal/socket/rawaddr.go b/internal/socket/rawaddr.go index 180cc4c3..7e82f9af 100644 --- a/internal/socket/rawaddr.go +++ b/internal/socket/rawaddr.go @@ -1,41 +1,20 @@ package socket import ( - "errors" - "fmt" "unsafe" ) -// todo: should these be custom types to store the desired/actual size and addr family? - -var ( - ErrBufferSize = errors.New("buffer size") - ErrInvalidPointer = errors.New("invalid pointer") - ErrAddrFamily = errors.New("address family") -) - -// todo: helsaawy - replace this with generics, along with GetSockName and co. - -// RawSockaddr allows structs to be used with Bind and ConnectEx. The +// RawSockaddr allows structs to be used with [Bind] and [ConnectEx]. The // struct must meet the Win32 sockaddr requirements specified here: // https://docs.microsoft.com/en-us/windows/win32/winsock/sockaddr-2 +// +// Specifically, the struct size must be least larger than an int16 (unsigned short) +// for the address family. type RawSockaddr interface { - // Sockaddr returns a pointer to the RawSockaddr and the length of the struct. + // Sockaddr returns a pointer to the RawSockaddr and its struct size, allowing + // for the RawSockaddr's data to be overwritten by syscalls (if necessary). + // + // It is the callers responsibility to validate that the values are valid; invalid + // pointers or size can cause a panic. Sockaddr() (unsafe.Pointer, int32, error) - - // FromBytes populates the RawsockAddr with the data in the byte array. - // Implementers should check the buffer is correctly sized and the address family - // is appropriate. - // Receivers should be pointers. - FromBytes([]byte) error -} - -func validateSockAddr(ptr unsafe.Pointer, n int32) error { - if ptr == nil { - return fmt.Errorf("pointer is %p: %w", ptr, ErrInvalidPointer) - } - if n < 1 { - return fmt.Errorf("buffer size %d < 1: %w", n, ErrBufferSize) - } - return nil } diff --git a/internal/socket/socket.go b/internal/socket/socket.go index b93fe332..f4134df4 100644 --- a/internal/socket/socket.go +++ b/internal/socket/socket.go @@ -3,6 +3,7 @@ package socket import ( + "errors" "fmt" "net" "sync" @@ -21,57 +22,46 @@ import ( const socketError = uintptr(^uint32(0)) -var ErrSocketClosed = fmt.Errorf("socket closed: %w", net.ErrClosed) +var ( + // todo(helsaawy): create custom error types to store the desired vs actual size and addr family? + + ErrBufferSize = errors.New("buffer size") + ErrAddrFamily = errors.New("address family") + ErrInvalidPointer = errors.New("invalid pointer") + ErrSocketClosed = fmt.Errorf("socket closed: %w", net.ErrClosed) +) + +// todo(helsaawy): replace these with generics, ie: GetSockName[S RawSockaddr](s windows.Handle) (S, error) -// GetSockName returns the socket's local address. It will call the `rsa.FromBytes()` on the -// buffer returned by the getsockname syscall. The buffer is allocated to the size specified -// by `rsa.Sockaddr()`. +// GetSockName writes the local address of socket s to the [RawSockaddr] rsa. +// If rsa is not large enough, the [windows.WSAEFAULT] is returned. func GetSockName(s windows.Handle, rsa RawSockaddr) error { ptr, l, err := rsa.Sockaddr() if err != nil { - return fmt.Errorf("could not find socket size to allocate buffer: %w", err) - } - if err = validateSockAddr(ptr, l); err != nil { - return err + return fmt.Errorf("could not retrieve socket pointer and size: %w", err) } - b := make([]byte, l) - err = getsockname(s, unsafe.Pointer(&b[0]), &l) - if err != nil { - // although getsockname returns WSAEFAULT if the buffer is too small, it does not set - // &l to the correct size, so--apart from doubling the buffer repeatedly--there is no remedy - return err - } - return rsa.FromBytes(b[:l]) + // although getsockname returns WSAEFAULT if the buffer is too small, it does not set + // &l to the correct size, so--apart from doubling the buffer repeatedly--there is no remedy + return getsockname(s, ptr, &l) } // GetPeerName returns the remote address the socket is connected to. // -// See GetSockName for more information. +// See [GetSockName] for more information. func GetPeerName(s windows.Handle, rsa RawSockaddr) error { ptr, l, err := rsa.Sockaddr() if err != nil { - return fmt.Errorf("could not find socket size to allocate buffer: %w", err) - } - if err = validateSockAddr(ptr, l); err != nil { - return err + return fmt.Errorf("could not retrieve socket pointer and size: %w", err) } - b := make([]byte, l) - err = getpeername(s, unsafe.Pointer(&b[0]), &l) - if err != nil { - return err - } - return rsa.FromBytes(b[:l]) + return getpeername(s, ptr, &l) } func Bind(s windows.Handle, rsa RawSockaddr) (err error) { ptr, l, err := rsa.Sockaddr() if err != nil { - return fmt.Errorf("could not find socket pointer and size: %w", err) - } - if err = validateSockAddr(ptr, l); err != nil { - return err + return fmt.Errorf("could not retrieve socket pointer and size: %w", err) } return bind(s, ptr, l) @@ -150,7 +140,7 @@ func ConnectEx(fd windows.Handle, rsa RawSockaddr, sendBuf *byte, sendDataLen ui // [in] LPOVERLAPPED lpOverlapped // ) func connectEx(s windows.Handle, name unsafe.Pointer, namelen int32, sendBuf *byte, sendDataLen uint32, bytesSent *uint32, overlapped *windows.Overlapped) (err error) { - // todo: after upgrading to 1.18, switch to syscall.SyscallN from syscall.Syscall9 + // todo: after upgrading to 1.18, switch from syscall.Syscall9 to syscall.SyscallN r1, _, e1 := syscall.Syscall9(connectExFunc.addr, 7, uintptr(s), uintptr(name), uintptr(namelen), uintptr(unsafe.Pointer(sendBuf)), uintptr(sendDataLen), uintptr(unsafe.Pointer(bytesSent)), uintptr(unsafe.Pointer(overlapped)), 0, 0) if r1 == 0 { if e1 != 0 {