From 9a4667230d9cdfe45a0c4e06a1586cc68948b597 Mon Sep 17 00:00:00 2001 From: Hamza El-Saawy Date: Thu, 17 Feb 2022 11:50:31 -0500 Subject: [PATCH 1/3] Added HV Socket known IDs, Dial, bug fixes Added: * Well-know Hyper-V VMIDs for parents, children, and loopback. * VSock interop service GUID. * `Dial()` and `DialContext()` to dial a specific Hyper-V socket at a known address (along with a corresponding `HvsockDialer` struct. Bug fixes: * Dial (and Listen) now properly initialize and set properties of their sockets after ConnectEx (and AcceptEx). * The `socketError` used by `bind` was incorrect, it should be `int32(-1)`, not `uintptr(^0)` * Return errors for `(*HvsockConn) SetDeadline` Created a `sockets` package, currently only with syscalls to `Bind`, `ConnectEx` and `GetSockName`, bypassing `syscall/windows` restrictions on the types that can do so. Signed-off-by: Hamza El-Saawy --- hvsock.go | 291 +++++++++++++++++++++++++++++--- pkg/sockets/rawaddr.go | 67 ++++++++ pkg/sockets/sockets.go | 175 +++++++++++++++++++ pkg/sockets/zsyscall_windows.go | 70 ++++++++ zsyscall_windows.go | 9 - 5 files changed, 579 insertions(+), 33 deletions(-) create mode 100644 pkg/sockets/rawaddr.go create mode 100644 pkg/sockets/sockets.go create mode 100644 pkg/sockets/zsyscall_windows.go diff --git a/hvsock.go b/hvsock.go index b2b644d0..b40add5c 100644 --- a/hvsock.go +++ b/hvsock.go @@ -4,6 +4,7 @@ package winio import ( + "context" "fmt" "io" "net" @@ -12,15 +13,73 @@ import ( "time" "unsafe" + "golang.org/x/sys/windows" + "github.com/Microsoft/go-winio/pkg/guid" + "github.com/Microsoft/go-winio/pkg/sockets" ) -//sys bind(s syscall.Handle, name unsafe.Pointer, namelen int32) (err error) [failretval==socketError] = ws2_32.bind +const afHvSock = 34 // AF_HYPERV + +var ( + // https://docs.microsoft.com/en-us/virtualization/hyper-v-on-windows/user-guide/make-integration-service#vmid-wildcards + + // HVguidWildcard is the wildcard VmId for accepting connections from all partitions. + HVguidWildcard = guid.GUID{} // 00000000-0000-0000-0000-000000000000 + + // HVguidBroadcast is the wildcard VmId for broadcasting sends to all partitions + HVguidBroadcast = guid.GUID{ //ffffffff-ffff-ffff-ffff-ffffffffffff + Data1: 0xffffffff, + Data2: 0xffff, + Data3: 0xffff, + Data4: [8]uint8{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, + } + + // HVGUIDLoopback is the Loopback VmId for accepting connections to the same partition as the connector. + HVguidLoopback = guid.GUID{ // e0e16197-dd56-4a10-9195-5ee7a155a838 + Data1: 0xe0e16197, + Data2: 0xdd56, + Data3: 0x4a10, + Data4: [8]uint8{0x91, 0x95, 0x5e, 0xe7, 0xa1, 0x55, 0xa8, 0x38}, + } + + // HVguidSiloHost is the address of a silo's host partition: + // - The silo host of a hosted silo is the utility VM. + // - The silo host of a silo on a physical host is the physical host. + HVguidSiloHost = guid.GUID{ // 36bd0c5c-7276-4223-88ba-7d03b654c568 + Data1: 0x36bd0c5c, + Data2: 0x7276, + Data3: 0x4223, + Data4: [8]byte{0x88, 0xba, 0x7d, 0x03, 0xb6, 0x54, 0xc5, 0x68}, + } + + // HVguidChildren is the wildcard VmId for accepting connections from the connector's child partitions. + HVguidChildren = guid.GUID{ // 90db8b89-0d35-4f79-8ce9-49ea0ac8b7cd + Data1: 0x90db8b89, + Data2: 0xd35, + Data3: 0x4f79, + Data4: [8]uint8{0x8c, 0xe9, 0x49, 0xea, 0xa, 0xc8, 0xb7, 0xcd}, + } -const ( - afHvSock = 34 // AF_HYPERV + // HVguidParent is the wildcard VmId for accepting connections from the connector's parent partition. + // Listening on this VmId accepts connection from: + // - Inside silos: silo host partition. + // - Inside hosted silo: host of the VM. + // - Inside VM: VM host. + // - Physical host: Not supported. + HVguidParent = guid.GUID{ // a42e7cda-d03f-480c-9cc2-a4de20abb878 + Data1: 0xa42e7cda, + Data2: 0xd03f, + Data3: 0x480c, + Data4: [8]uint8{0x9c, 0xc2, 0xa4, 0xde, 0x20, 0xab, 0xb8, 0x78}, + } - socketError = ^uintptr(0) + // HVguidVSockServiceGUIDTemplate is the Service GUID used for the VSOCK protocol + hvguidVSockServiceTemplate = guid.GUID{ // 00000000-facb-11e6-bd58-64006a7986d3 + Data2: 0xfacb, + Data3: 0x11e6, + Data4: [8]uint8{0xbd, 0x58, 0x64, 0x00, 0x6a, 0x79, 0x86, 0xd3}, + } ) // An HvsockAddr is an address for a AF_HYPERV socket. @@ -36,6 +95,8 @@ type rawHvsockAddr struct { ServiceID guid.GUID } +var _ sockets.RawSockaddr = &rawHvsockAddr{} + // Network returns the address's network name, "hvsock". func (addr *HvsockAddr) Network() string { return "hvsock" @@ -47,7 +108,7 @@ func (addr *HvsockAddr) String() string { // VsockServiceID returns an hvsock service ID corresponding to the specified AF_VSOCK port. func VsockServiceID(port uint32) guid.GUID { - g, _ := guid.FromString("00000000-facb-11e6-bd58-64006a7986d3") + g := hvguidVSockServiceTemplate // make a copy g.Data1 = port return g } @@ -65,18 +126,43 @@ func (addr *HvsockAddr) fromRaw(raw *rawHvsockAddr) { addr.ServiceID = raw.ServiceID } +// Sockaddr interface allows use with `sockets.Bind()` and `.ConnectEx()` +func (r *rawHvsockAddr) Sockaddr() (ptr unsafe.Pointer, len int32, err error) { + return unsafe.Pointer(r), int32(unsafe.Sizeof(rawHvsockAddr{})), nil +} + +// Sockaddr interface allows use with `sockets.Bind()` and `.ConnectEx()` +func (r *rawHvsockAddr) FromBytes(b []byte) error { + n := int(unsafe.Sizeof(rawHvsockAddr{})) + + if len(b) < n { + return fmt.Errorf("got %d, want %d: %w", len(b), n, sockets.ErrBufferSize) + } + + copy(unsafe.Slice((*byte)(unsafe.Pointer(r)), n), b[:n]) + if r.Family != afHvSock { + return fmt.Errorf("got %d, want %d: %w", r.Family, afHvSock, sockets.ErrAddrFamily) + } + + return nil +} + // HvsockListener is a socket listener for the AF_HYPERV address family. type HvsockListener struct { sock *win32File addr HvsockAddr } +var _ net.Listener = &HvsockListener{} + // HvsockConn is a connected socket of the AF_HYPERV address family. type HvsockConn struct { sock *win32File local, remote HvsockAddr } +var _ net.Conn = &HvsockConn{} + func newHvSocket() (*win32File, error) { fd, err := syscall.Socket(afHvSock, syscall.SOCK_STREAM, 1) if err != nil { @@ -99,7 +185,7 @@ func ListenHvsock(addr *HvsockAddr) (_ *HvsockListener, err error) { return nil, l.opErr("listen", err) } sa := addr.raw() - err = bind(sock.handle, unsafe.Pointer(&sa), int32(unsafe.Sizeof(sa))) + err = sockets.Bind(windows.Handle(sock.handle), &sa) if err != nil { return nil, l.opErr("listen", os.NewSyscallError("socket", err)) } @@ -136,21 +222,54 @@ func (l *HvsockListener) Accept() (_ net.Conn, err error) { } defer l.sock.wg.Done() - // AcceptEx, per documentation, requires an extra 16 bytes per address. + // AcceptEx, per documentation, requires an extra 16 bytes per address: + // https://docs.microsoft.com/en-us/windows/win32/api/mswsock/nf-mswsock-acceptex const addrlen = uint32(16 + unsafe.Sizeof(rawHvsockAddr{})) var addrbuf [addrlen * 2]byte var bytes uint32 - err = syscall.AcceptEx(l.sock.handle, sock.handle, &addrbuf[0], 0, addrlen, addrlen, &bytes, &c.o) - _, err = l.sock.asyncIo(c, nil, bytes, err) + err = syscall.AcceptEx(l.sock.handle, + sock.handle, + &addrbuf[0], + 0, // rxdatalen + addrlen, + addrlen, + &bytes, + &c.o) + _, err = l.sock.asyncIo(c, + nil, // deadlineHandler + bytes, + err) if err != nil { return nil, l.opErr("accept", os.NewSyscallError("acceptex", err)) } + conn := &HvsockConn{ sock: sock, } conn.local.fromRaw((*rawHvsockAddr)(unsafe.Pointer(&addrbuf[0]))) conn.remote.fromRaw((*rawHvsockAddr)(unsafe.Pointer(&addrbuf[addrlen]))) + + // initialize the accepted socket and update its properties with those of the listening socket + if err = windows.Setsockopt(windows.Handle(sock.handle), + windows.SOL_SOCKET, windows.SO_UPDATE_ACCEPT_CONTEXT, + (*byte)(unsafe.Pointer(&l.sock.handle)), int32(unsafe.Sizeof(l.sock.handle))); err != nil { + return nil, conn.opErr("accept", os.NewSyscallError("setsockopt", err)) + } + + // The local address returned in the AcceptEx buffer is the same as the Listener socket's + // address. However, the service GUID reported by GetSockName is different from the Listeners + // socket, and is sometimes the same as the local address of the socket that dialed the + // address, with the service GUID.Data1 incremented, but othertimes is different. + // todo: does the local address matter? is the listener's address or the actual address appropriate? + + // var ra rawHvsockAddr + // err = sockets.GetSockName(windows.Handle(sock.handle), &ra) + // if err != nil { + // return nil, conn.opErr("accept", os.NewSyscallError("getsockname", err)) + // } + // conn.local.fromRaw(&ra) + sock = nil return conn, nil } @@ -160,36 +279,154 @@ func (l *HvsockListener) Close() error { return l.sock.Close() } -/* Need to finish ConnectEx handling -func DialHvsock(ctx context.Context, addr *HvsockAddr) (*HvsockConn, error) { +type HvsockDialer struct { + // Deadline is the time the Dial operation must connect before erroring. + Deadline time.Time + + // Retries is the number of additional connects to try if the connection times out, is refused, + // or the host is unreachable + Retries uint + + // RetryWait is the time to wait after a connection error to retry + RetryWait time.Duration + + rt *time.Timer // redial wait timer +} + +func (d *HvsockDialer) Dial(addr *HvsockAddr) (*HvsockConn, error) { + return d.DialContext(context.Background(), addr) +} + +func (d *HvsockDialer) DialContext(ctx context.Context, addr *HvsockAddr) (conn *HvsockConn, err error) { + op := "dial" + // create the conn early to use opErr() + conn = &HvsockConn{ + remote: *addr, + } + + if !d.Deadline.IsZero() { + var cancel context.CancelFunc + ctx, cancel = context.WithDeadline(ctx, d.Deadline) + defer cancel() + } + + // preemptive timeout/cancellation check + if err = ctx.Err(); err != nil { + return nil, conn.opErr(op, err) + } + sock, err := newHvSocket() if err != nil { - return nil, err + return nil, conn.opErr(op, err) } defer func() { if sock != nil { sock.Close() } }() + + sa := addr.raw() + err = sockets.Bind(windows.Handle(sock.handle), &sa) + if err != nil { + return nil, conn.opErr(op, os.NewSyscallError("bind", err)) + } + c, err := sock.prepareIo() if err != nil { - return nil, err + return nil, conn.opErr(op, err) } defer sock.wg.Done() var bytes uint32 - err = windows.ConnectEx(windows.Handle(sock.handle), sa, nil, 0, &bytes, &c.o) - _, err = sock.asyncIo(ctx, c, nil, bytes, err) + n := 1 + int(d.Retries) + for i := 1; i <= n; i++ { + err = sockets.ConnectEx( + windows.Handle(sock.handle), + &sa, + nil, // sendBuf + 0, // sendDataLen + &bytes, + (*windows.Overlapped)(unsafe.Pointer(&c.o))) + // todo: create an asyncIO version that takes a context + // could create a deadlineHandler triggered by context cancelation, but that seems inefficient ... + _, err = sock.asyncIo(c, nil, bytes, err) + if i < n && canRedial(err) { + if err = d.redialWait(ctx); err != nil { + break + } + continue + } + break + } if err != nil { - return nil, err + return nil, conn.opErr(op, os.NewSyscallError("connectex", err)) } - conn := &HvsockConn{ - sock: sock, - remote: *addr, + + // update the connection properties, so shutdown can be used + if err = windows.Setsockopt( + windows.Handle(sock.handle), + windows.SOL_SOCKET, + windows.SO_UPDATE_CONNECT_CONTEXT, + nil, // optvalue + 0, // optlen + ); err != nil { + return nil, conn.opErr(op, os.NewSyscallError("setsockopt", err)) + } + + // get the local name + var sal rawHvsockAddr + err = sockets.GetSockName(windows.Handle(sock.handle), &sal) + if err != nil { + return nil, conn.opErr(op, os.NewSyscallError("getsockname", err)) } + conn.local.fromRaw(&sal) + + // one last check for timeout, since asyncIO doesnt check the context + if err = ctx.Err(); err != nil { + return nil, conn.opErr(op, err) + } + + conn.sock = sock sock = nil + return conn, nil } -*/ + +func (d *HvsockDialer) redialWait(ctx context.Context) (err error) { + if d.RetryWait == 0 { + return nil + } + + if d.rt == nil { + d.rt = time.NewTimer(d.RetryWait) + } else { + // should already be stopped and drained + d.rt.Reset(d.RetryWait) + } + + select { + case <-ctx.Done(): + case <-d.rt.C: + return nil + } + + // stop and drain the timer + if !d.rt.Stop() { + <-d.rt.C + } + return ctx.Err() +} + +// assumes error is a plain, unwrapped syscall.Errno provided by direct syscall +func canRedial(err error) bool { + // nolint:errorlint + switch err { + case windows.WSAECONNREFUSED, windows.WSAENETUNREACH, windows.WSAETIMEDOUT, + windows.ERROR_CONNECTION_REFUSED, windows.ERROR_CONNECTION_UNAVAIL: + return true + default: + return false + } +} func (conn *HvsockConn) opErr(op string, err error) error { return &net.OpError{Op: op, Net: "hvsock", Source: &conn.local, Addr: &conn.remote, Err: err} @@ -257,6 +494,7 @@ func (conn *HvsockConn) IsClosed() bool { return conn.sock.IsClosed() } +// shutdown disables sending or receiving on a socket func (conn *HvsockConn) shutdown(how int) error { if conn.IsClosed() { return ErrFileClosed @@ -273,7 +511,7 @@ func (conn *HvsockConn) shutdown(how int) error { func (conn *HvsockConn) CloseRead() error { err := conn.shutdown(syscall.SHUT_RD) if err != nil { - return conn.opErr("close", err) + return conn.opErr("closeread", err) } return nil } @@ -283,7 +521,7 @@ func (conn *HvsockConn) CloseRead() error { func (conn *HvsockConn) CloseWrite() error { err := conn.shutdown(syscall.SHUT_WR) if err != nil { - return conn.opErr("close", err) + return conn.opErr("closewrite", err) } return nil } @@ -300,8 +538,13 @@ func (conn *HvsockConn) RemoteAddr() net.Addr { // SetDeadline implements the net.Conn SetDeadline method. func (conn *HvsockConn) SetDeadline(t time.Time) error { - conn.SetReadDeadline(t) - conn.SetWriteDeadline(t) + // todo: implement `SetDeadline` for `win32File` + if err := conn.SetReadDeadline(t); err != nil { + return fmt.Errorf("set read deadline: %w", err) + } + if err := conn.SetWriteDeadline(t); err != nil { + return fmt.Errorf("set write deadline: %w", err) + } return nil } diff --git a/pkg/sockets/rawaddr.go b/pkg/sockets/rawaddr.go new file mode 100644 index 00000000..03b8fc4a --- /dev/null +++ b/pkg/sockets/rawaddr.go @@ -0,0 +1,67 @@ +package sockets + +import ( + "errors" + "fmt" + "unsafe" +) + +// todo: should these be custom types to store the desired/actual size and addr family? + +var ( + ErrBufferSize = errors.New("buffer size") + ErrInvalidPointer = errors.New("invalid pointer") + ErrAddrFamily = errors.New("address family") +) + +// todo: replace this with generics +// The function calls should be: +// +// type RawSockaddrHeader { +// Family uint16 +// } +// +// func ConnectEx[T ~RawSockaddrHeader] (s Handle, a *T, ...) error { +// n := unsafe.SizeOf(*a) +// r1, _, e1 := syscall.Syscall9(connectExFunc.addr, 7, uintptr(s), +// uintptr(unsafe.Pointer(a)), uintptr(n), /* ... */) +// /* ... */ +// } +// +// Similarly, `GetAcceptExSockaddrs` requires a `**sockaddr`, so the syscall can change the pointer +// to data it allocates. Currently, the options are (1) dealing with pointers to the interface +// `* RawSockaddr`, use reflection or pull the pointer from the internal interface representation, +// and change where the interface points to; or (2) allocate dedicate, presized buffers based on +// `(r RawSockaddr).Sockaddr()`'s return, and pass that to `(r RawSockaddr).FromBytes()`. +// It would be safer and more readable to have: +// +// func GetAcceptExSockaddrs[L ~RawSockaddrHeader, R ~RawSockaddrHeader]( +// b *byte, +// rxlen uint32, +// local **L, +// remote **R, +// ) error { /*...*/ } + +// RawSockaddr allows structs to be used with Bind and ConnectEx. The +// struct must meet the Wind32 sockaddr requirements specified here: +// https://docs.microsoft.com/en-us/windows/win32/winsock/sockaddr-2 +type RawSockaddr interface { + // Sockaddr returns a pointer to the RawSockaddr and the length of the struct. + Sockaddr() (ptr unsafe.Pointer, len int32, err error) + + // FromBytes populates the RawsockAddr with the data in the byte array. + // Implementers should check the buffer is correctly sized and the address family + // is appropriate. + // Receivers should be pointers. + FromBytes([]byte) error +} + +func validateSockAddr(ptr unsafe.Pointer, len int32) error { + if ptr == nil { + return fmt.Errorf("pointer is %p: %w", ptr, ErrInvalidPointer) + } + if len < 1 { + return fmt.Errorf("buffer size %d < 1: %w", len, ErrBufferSize) + } + return nil +} diff --git a/pkg/sockets/sockets.go b/pkg/sockets/sockets.go new file mode 100644 index 00000000..74360f6b --- /dev/null +++ b/pkg/sockets/sockets.go @@ -0,0 +1,175 @@ +//go:build windows + +package sockets + +import ( + "fmt" + "net" + "sync" + "syscall" + "unsafe" + + "github.com/Microsoft/go-winio/pkg/guid" + "golang.org/x/sys/windows" +) + +//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go sockets.go + +//sys getsockname(s windows.Handle, name unsafe.Pointer, namelen *int32) (err error) [failretval==socketError] = ws2_32.getsockname +//sys getpeername(s windows.Handle, name unsafe.Pointer, namelen *int32) (err error) [failretval==socketError] = ws2_32.getpeername +//sys bind(s windows.Handle, name unsafe.Pointer, namelen int32) (err error) [failretval==socketError] = ws2_32.bind + +const socketError = uintptr(^uint32(0)) + +// CloseWriter is a connection that can disable writing to itself. +type CloseWriter interface { + net.Conn + CloseWrite() error +} + +// CloseReader is a connection that can disable reading from itself. +type CloseReader interface { + net.Conn + CloseRead() error +} + +// GetSockName returns the socket's local address. It will call the `rsa.FromBytes()` on the +// buffer returned by the getsockname syscall. The buffer is allocated to the size specified +// by `rsa.Sockaddr()`. +func GetSockName(s windows.Handle, rsa RawSockaddr) error { + // todo: replace this (and RawSockaddr) with generics + ptr, l, err := rsa.Sockaddr() + if err != nil { + return fmt.Errorf("could not find socket size to allocate buffer: %w", err) + } + if err = validateSockAddr(ptr, l); err != nil { + return err + } + + b := make([]byte, l) + err = getsockname(s, unsafe.Pointer(&b[0]), &l) + if err != nil { + // although getsockname returns WSAEFAULT if the buffer is too small, it does not set + // &l to the correct size, so--apart from doubling the buffer repeatedly--there is no remedy + return err + } + return rsa.FromBytes(b[:l]) +} + +// GetPeerName returns the remote address the socket is connected to. +// +// See GetSockName for more information. +func GetPeerName(s windows.Handle, rsa RawSockaddr) error { + ptr, l, err := rsa.Sockaddr() + if err != nil { + return fmt.Errorf("could not find socket size to allocate buffer: %w", err) + } + if err = validateSockAddr(ptr, l); err != nil { + return err + } + + b := make([]byte, l) + err = getpeername(s, unsafe.Pointer(&b[0]), &l) + if err != nil { + return err + } + return rsa.FromBytes(b[:l]) +} + +func Bind(s windows.Handle, rsa RawSockaddr) (err error) { + ptr, l, err := rsa.Sockaddr() + if err != nil { + return fmt.Errorf("could not find socket pointer and size: %w", err) + } + if err = validateSockAddr(ptr, l); err != nil { + return err + } + + return bind(s, ptr, l) +} + +// "golang.org/x/sys/windows".ConnectEx and .Bind only accept internal implementations of the +// their sockaddr interface, so they cannot be used with HvsockAddr +// Replicate functionality here from +// https://cs.opensource.google/go/x/sys/+/master:windows/syscall_windows.go + +// The function pointers to `AcceptEx`, `ConnectEx` and `GetAcceptExSockaddrs` must be loaded at +// runtime via a WSAIoctl call: +// https://docs.microsoft.com/en-us/windows/win32/api/Mswsock/nc-mswsock-lpfn_connectex#remarks + +type runtimeFunc struct { + id guid.GUID + once sync.Once + addr uintptr + err error +} + +func (f *runtimeFunc) Load() error { + f.once.Do(func() { + var s windows.Handle + s, f.err = windows.Socket(windows.AF_INET, windows.SOCK_STREAM, windows.IPPROTO_TCP) + if f.err != nil { + return + } + defer windows.CloseHandle(s) + + var n uint32 + f.err = windows.WSAIoctl(s, + windows.SIO_GET_EXTENSION_FUNCTION_POINTER, + (*byte)(unsafe.Pointer(&f.id)), + uint32(unsafe.Sizeof(f.id)), + (*byte)(unsafe.Pointer(&f.addr)), + uint32(unsafe.Sizeof(f.addr)), + &n, + nil, //overlapped + 0, //completionRoutine + ) + }) + return f.err + +} + +var ( + // todo: add `AcceptEx` and `GetAcceptExSockaddrs` + WSAID_CONNECTEX = guid.GUID{ + Data1: 0x25a207b9, + Data2: 0xddf3, + Data3: 0x4660, + Data4: [8]byte{0x8e, 0xe9, 0x76, 0xe5, 0x8c, 0x74, 0x06, 0x3e}, + } + + connectExFunc = runtimeFunc{id: WSAID_CONNECTEX} +) + +func ConnectEx(fd windows.Handle, rsa RawSockaddr, sendBuf *byte, sendDataLen uint32, bytesSent *uint32, overlapped *windows.Overlapped) error { + err := connectExFunc.Load() + if err != nil { + return fmt.Errorf("failed to load ConnectEx function pointer: %e", err) + } + ptr, n, err := rsa.Sockaddr() + if err != nil { + return err + } + return connectEx(fd, ptr, n, sendBuf, sendDataLen, bytesSent, overlapped) +} + +// BOOL LpfnConnectex( +// [in] SOCKET s, +// [in] const sockaddr *name, +// [in] int namelen, +// [in, optional] PVOID lpSendBuffer, +// [in] DWORD dwSendDataLength, +// [out] LPDWORD lpdwBytesSent, +// [in] LPOVERLAPPED lpOverlapped +// ) +func connectEx(s windows.Handle, name unsafe.Pointer, namelen int32, sendBuf *byte, sendDataLen uint32, bytesSent *uint32, overlapped *windows.Overlapped) (err error) { + r1, _, e1 := syscall.Syscall9(connectExFunc.addr, 7, uintptr(s), uintptr(name), uintptr(namelen), uintptr(unsafe.Pointer(sendBuf)), uintptr(sendDataLen), uintptr(unsafe.Pointer(bytesSent)), uintptr(unsafe.Pointer(overlapped)), 0, 0) + if r1 == 0 { + if e1 != 0 { + err = error(e1) + } else { + err = syscall.EINVAL + } + } + return +} diff --git a/pkg/sockets/zsyscall_windows.go b/pkg/sockets/zsyscall_windows.go new file mode 100644 index 00000000..152ec7f6 --- /dev/null +++ b/pkg/sockets/zsyscall_windows.go @@ -0,0 +1,70 @@ +// Code generated by 'go generate'; DO NOT EDIT. + +package sockets + +import ( + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +var _ unsafe.Pointer + +// Do the interface allocations only once for common +// Errno values. +const ( + errnoERROR_IO_PENDING = 997 +) + +var ( + errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING) + errERROR_EINVAL error = syscall.EINVAL +) + +// errnoErr returns common boxed Errno values, to prevent +// allocations at runtime. +func errnoErr(e syscall.Errno) error { + switch e { + case 0: + return errERROR_EINVAL + case errnoERROR_IO_PENDING: + return errERROR_IO_PENDING + } + // TODO: add more here, after collecting data on the common + // error values see on Windows. (perhaps when running + // all.bat?) + return e +} + +var ( + modws2_32 = windows.NewLazySystemDLL("ws2_32.dll") + + procbind = modws2_32.NewProc("bind") + procgetpeername = modws2_32.NewProc("getpeername") + procgetsockname = modws2_32.NewProc("getsockname") +) + +func bind(s windows.Handle, name unsafe.Pointer, namelen int32) (err error) { + r1, _, e1 := syscall.Syscall(procbind.Addr(), 3, uintptr(s), uintptr(name), uintptr(namelen)) + if r1 == socketError { + err = errnoErr(e1) + } + return +} + +func getpeername(s windows.Handle, name unsafe.Pointer, namelen *int32) (err error) { + r1, _, e1 := syscall.Syscall(procgetpeername.Addr(), 3, uintptr(s), uintptr(name), uintptr(unsafe.Pointer(namelen))) + if r1 == socketError { + err = errnoErr(e1) + } + return +} + +func getsockname(s windows.Handle, name unsafe.Pointer, namelen *int32) (err error) { + r1, _, e1 := syscall.Syscall(procgetsockname.Addr(), 3, uintptr(s), uintptr(name), uintptr(unsafe.Pointer(namelen))) + if r1 == socketError { + err = errnoErr(e1) + } + return +} diff --git a/zsyscall_windows.go b/zsyscall_windows.go index 176ff75e..038ee3d5 100644 --- a/zsyscall_windows.go +++ b/zsyscall_windows.go @@ -74,7 +74,6 @@ var ( procRtlDosPathNameToNtPathName_U = modntdll.NewProc("RtlDosPathNameToNtPathName_U") procRtlNtStatusToDosErrorNoTeb = modntdll.NewProc("RtlNtStatusToDosErrorNoTeb") procWSAGetOverlappedResult = modws2_32.NewProc("WSAGetOverlappedResult") - procbind = modws2_32.NewProc("bind") ) func adjustTokenPrivileges(token windows.Token, releaseAll bool, input *byte, outputSize uint32, output *byte, requiredSize *uint32) (success bool, err error) { @@ -417,11 +416,3 @@ func wsaGetOverlappedResult(h syscall.Handle, o *syscall.Overlapped, bytes *uint } return } - -func bind(s syscall.Handle, name unsafe.Pointer, namelen int32) (err error) { - r1, _, e1 := syscall.Syscall(procbind.Addr(), 3, uintptr(s), uintptr(name), uintptr(namelen)) - if r1 == socketError { - err = errnoErr(e1) - } - return -} From a615ab277b1f9d3ffb72a05bd60bb031713b1e79 Mon Sep 17 00:00:00 2001 From: Hamza El-Saawy Date: Thu, 24 Mar 2022 11:05:38 -0400 Subject: [PATCH 2/3] PR: error handling bug, rebase, unexport, naming * comments and todos statements * spelling * removed dead code * changed names to be more conventional * unexported socket code * made `(*HvsockDialer) Dial` take `Context`, removed `DialContext` * added default `Dial` function * rebased onto main * cleaned up `Dial(` retry loop Signed-off-by: Hamza El-Saawy --- file.go | 2 + hvsock.go | 159 +++++++++--------- internal/socket/rawaddr.go | 41 +++++ .../sockets.go => internal/socket/socket.go | 28 +-- .../socket}/zsyscall_windows.go | 2 +- pkg/sockets/rawaddr.go | 67 -------- 6 files changed, 136 insertions(+), 163 deletions(-) create mode 100644 internal/socket/rawaddr.go rename pkg/sockets/sockets.go => internal/socket/socket.go (90%) rename {pkg/sockets => internal/socket}/zsyscall_windows.go (99%) delete mode 100644 pkg/sockets/rawaddr.go diff --git a/file.go b/file.go index 293ab54c..1b870350 100644 --- a/file.go +++ b/file.go @@ -178,6 +178,8 @@ func ioCompletionProcessor(h syscall.Handle) { } } +// todo: helsaawy - create an asyncIO version that takes a context + // asyncIo processes the return value from ReadFile or WriteFile, blocking until // the operation has actually completed. func (f *win32File) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, err error) (int, error) { diff --git a/hvsock.go b/hvsock.go index b40add5c..8872f471 100644 --- a/hvsock.go +++ b/hvsock.go @@ -5,6 +5,7 @@ package winio import ( "context" + "errors" "fmt" "io" "net" @@ -15,72 +16,85 @@ import ( "golang.org/x/sys/windows" + "github.com/Microsoft/go-winio/internal/socket" "github.com/Microsoft/go-winio/pkg/guid" - "github.com/Microsoft/go-winio/pkg/sockets" ) const afHvSock = 34 // AF_HYPERV -var ( - // https://docs.microsoft.com/en-us/virtualization/hyper-v-on-windows/user-guide/make-integration-service#vmid-wildcards +// Well known Service and VM IDs +//https://docs.microsoft.com/en-us/virtualization/hyper-v-on-windows/user-guide/make-integration-service#vmid-wildcards - // HVguidWildcard is the wildcard VmId for accepting connections from all partitions. - HVguidWildcard = guid.GUID{} // 00000000-0000-0000-0000-000000000000 +// HvsockGUIDWildcard is the wildcard VmId for accepting connections from all partitions. +func HvsockGUIDWildcard() guid.GUID { // 00000000-0000-0000-0000-000000000000 + return guid.GUID{} +} - // HVguidBroadcast is the wildcard VmId for broadcasting sends to all partitions - HVguidBroadcast = guid.GUID{ //ffffffff-ffff-ffff-ffff-ffffffffffff +// HvsockGUIDBroadcast is the wildcard VmId for broadcasting sends to all partitions +func HvsockGUIDBroadcast() guid.GUID { //ffffffff-ffff-ffff-ffff-ffffffffffff + return guid.GUID{ Data1: 0xffffffff, Data2: 0xffff, Data3: 0xffff, Data4: [8]uint8{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, } +} - // HVGUIDLoopback is the Loopback VmId for accepting connections to the same partition as the connector. - HVguidLoopback = guid.GUID{ // e0e16197-dd56-4a10-9195-5ee7a155a838 +// HvsockGUIDLoopback is the Loopback VmId for accepting connections to the same partition as the connector. +func HvsockGUIDLoopback() guid.GUID { // e0e16197-dd56-4a10-9195-5ee7a155a838 + return guid.GUID{ Data1: 0xe0e16197, Data2: 0xdd56, Data3: 0x4a10, Data4: [8]uint8{0x91, 0x95, 0x5e, 0xe7, 0xa1, 0x55, 0xa8, 0x38}, } +} - // HVguidSiloHost is the address of a silo's host partition: - // - The silo host of a hosted silo is the utility VM. - // - The silo host of a silo on a physical host is the physical host. - HVguidSiloHost = guid.GUID{ // 36bd0c5c-7276-4223-88ba-7d03b654c568 +// HvsockGUIDSiloHost is the address of a silo's host partition: +// - The silo host of a hosted silo is the utility VM. +// - The silo host of a silo on a physical host is the physical host. +func HvsockGUIDSiloHost() guid.GUID { // 36bd0c5c-7276-4223-88ba-7d03b654c568 + return guid.GUID{ Data1: 0x36bd0c5c, Data2: 0x7276, Data3: 0x4223, Data4: [8]byte{0x88, 0xba, 0x7d, 0x03, 0xb6, 0x54, 0xc5, 0x68}, } +} - // HVguidChildren is the wildcard VmId for accepting connections from the connector's child partitions. - HVguidChildren = guid.GUID{ // 90db8b89-0d35-4f79-8ce9-49ea0ac8b7cd +// HvsockGUIDChildren is the wildcard VmId for accepting connections from the connector's child partitions. +func HvsockGUIDChildren() guid.GUID { // 90db8b89-0d35-4f79-8ce9-49ea0ac8b7cd + return guid.GUID{ Data1: 0x90db8b89, Data2: 0xd35, Data3: 0x4f79, Data4: [8]uint8{0x8c, 0xe9, 0x49, 0xea, 0xa, 0xc8, 0xb7, 0xcd}, } +} - // HVguidParent is the wildcard VmId for accepting connections from the connector's parent partition. - // Listening on this VmId accepts connection from: - // - Inside silos: silo host partition. - // - Inside hosted silo: host of the VM. - // - Inside VM: VM host. - // - Physical host: Not supported. - HVguidParent = guid.GUID{ // a42e7cda-d03f-480c-9cc2-a4de20abb878 +// HvsockGUIDParent is the wildcard VmId for accepting connections from the connector's parent partition. +// Listening on this VmId accepts connection from: +// - Inside silos: silo host partition. +// - Inside hosted silo: host of the VM. +// - Inside VM: VM host. +// - Physical host: Not supported. +func HvsockGUIDParent() guid.GUID { // a42e7cda-d03f-480c-9cc2-a4de20abb878 + return guid.GUID{ Data1: 0xa42e7cda, Data2: 0xd03f, Data3: 0x480c, Data4: [8]uint8{0x9c, 0xc2, 0xa4, 0xde, 0x20, 0xab, 0xb8, 0x78}, } +} - // HVguidVSockServiceGUIDTemplate is the Service GUID used for the VSOCK protocol - hvguidVSockServiceTemplate = guid.GUID{ // 00000000-facb-11e6-bd58-64006a7986d3 +// hvsockVsockServiceTemplate is the Service GUID used for the VSOCK protocol +func hvsockVsockServiceTemplate() guid.GUID { // 00000000-facb-11e6-bd58-64006a7986d3 + return guid.GUID{ Data2: 0xfacb, Data3: 0x11e6, Data4: [8]uint8{0xbd, 0x58, 0x64, 0x00, 0x6a, 0x79, 0x86, 0xd3}, } -) +} // An HvsockAddr is an address for a AF_HYPERV socket. type HvsockAddr struct { @@ -95,7 +109,7 @@ type rawHvsockAddr struct { ServiceID guid.GUID } -var _ sockets.RawSockaddr = &rawHvsockAddr{} +var _ socket.RawSockaddr = &rawHvsockAddr{} // Network returns the address's network name, "hvsock". func (addr *HvsockAddr) Network() string { @@ -108,7 +122,7 @@ func (addr *HvsockAddr) String() string { // VsockServiceID returns an hvsock service ID corresponding to the specified AF_VSOCK port. func VsockServiceID(port uint32) guid.GUID { - g := hvguidVSockServiceTemplate // make a copy + g := hvsockVsockServiceTemplate() // make a copy g.Data1 = port return g } @@ -136,12 +150,12 @@ func (r *rawHvsockAddr) FromBytes(b []byte) error { n := int(unsafe.Sizeof(rawHvsockAddr{})) if len(b) < n { - return fmt.Errorf("got %d, want %d: %w", len(b), n, sockets.ErrBufferSize) + return fmt.Errorf("got %d, want %d: %w", len(b), n, socket.ErrBufferSize) } copy(unsafe.Slice((*byte)(unsafe.Pointer(r)), n), b[:n]) if r.Family != afHvSock { - return fmt.Errorf("got %d, want %d: %w", r.Family, afHvSock, sockets.ErrAddrFamily) + return fmt.Errorf("got %d, want %d: %w", r.Family, afHvSock, socket.ErrAddrFamily) } return nil @@ -185,7 +199,7 @@ func ListenHvsock(addr *HvsockAddr) (_ *HvsockListener, err error) { return nil, l.opErr("listen", err) } sa := addr.raw() - err = sockets.Bind(windows.Handle(sock.handle), &sa) + err = socket.Bind(windows.Handle(sock.handle), &sa) if err != nil { return nil, l.opErr("listen", os.NewSyscallError("socket", err)) } @@ -228,25 +242,19 @@ func (l *HvsockListener) Accept() (_ net.Conn, err error) { var addrbuf [addrlen * 2]byte var bytes uint32 - err = syscall.AcceptEx(l.sock.handle, - sock.handle, - &addrbuf[0], - 0, // rxdatalen - addrlen, - addrlen, - &bytes, - &c.o) - _, err = l.sock.asyncIo(c, - nil, // deadlineHandler - bytes, - err) - if err != nil { + err = syscall.AcceptEx(l.sock.handle, sock.handle, &addrbuf[0], 0 /*rxdatalen*/, addrlen, addrlen, &bytes, &c.o) + if _, err = l.sock.asyncIo(c, nil, bytes, err); err != nil { return nil, l.opErr("accept", os.NewSyscallError("acceptex", err)) } conn := &HvsockConn{ sock: sock, } + // The local address returned in the AcceptEx buffer is the same as the Listener socket's + // address. However, the service GUID reported by GetSockName is different from the Listeners + // socket, and is sometimes the same as the local address of the socket that dialed the + // address, with the service GUID.Data1 incremented, but othertimes is different. + // todo: does the local address matter? is the listener's address or the actual address appropriate? conn.local.fromRaw((*rawHvsockAddr)(unsafe.Pointer(&addrbuf[0]))) conn.remote.fromRaw((*rawHvsockAddr)(unsafe.Pointer(&addrbuf[addrlen]))) @@ -257,19 +265,6 @@ func (l *HvsockListener) Accept() (_ net.Conn, err error) { return nil, conn.opErr("accept", os.NewSyscallError("setsockopt", err)) } - // The local address returned in the AcceptEx buffer is the same as the Listener socket's - // address. However, the service GUID reported by GetSockName is different from the Listeners - // socket, and is sometimes the same as the local address of the socket that dialed the - // address, with the service GUID.Data1 incremented, but othertimes is different. - // todo: does the local address matter? is the listener's address or the actual address appropriate? - - // var ra rawHvsockAddr - // err = sockets.GetSockName(windows.Handle(sock.handle), &ra) - // if err != nil { - // return nil, conn.opErr("accept", os.NewSyscallError("getsockname", err)) - // } - // conn.local.fromRaw(&ra) - sock = nil return conn, nil } @@ -279,6 +274,7 @@ func (l *HvsockListener) Close() error { return l.sock.Close() } +// HvsockDialer configures and dials a Hyper-V Socket (ie, [HvsockConn]). type HvsockDialer struct { // Deadline is the time the Dial operation must connect before erroring. Deadline time.Time @@ -293,11 +289,18 @@ type HvsockDialer struct { rt *time.Timer // redial wait timer } -func (d *HvsockDialer) Dial(addr *HvsockAddr) (*HvsockConn, error) { - return d.DialContext(context.Background(), addr) +// Dial the Hyper-V socket at addr. +// +// See (*HvsockDialer).Dial for more information. +func Dial(ctx context.Context, addr *HvsockAddr) (conn *HvsockConn, err error) { + return (&HvsockDialer{}).Dial(ctx, addr) } -func (d *HvsockDialer) DialContext(ctx context.Context, addr *HvsockAddr) (conn *HvsockConn, err error) { +// Dial attempts to connect to the Hyper-V socket at addr, and returns a connection if successful. +// Will attempt (HvsockDialer).Retries if dialing fails, waiting (HvsockDialer).RetryWait between +// retries. +// Dialing can be cancelled either by providing (HvsockDialer).Deadline, or cancelling ctx. +func (d *HvsockDialer) Dial(ctx context.Context, addr *HvsockAddr) (conn *HvsockConn, err error) { op := "dial" // create the conn early to use opErr() conn = &HvsockConn{ @@ -326,7 +329,7 @@ func (d *HvsockDialer) DialContext(ctx context.Context, addr *HvsockAddr) (conn }() sa := addr.raw() - err = sockets.Bind(windows.Handle(sock.handle), &sa) + err = socket.Bind(windows.Handle(sock.handle), &sa) if err != nil { return nil, conn.opErr(op, os.NewSyscallError("bind", err)) } @@ -337,23 +340,19 @@ func (d *HvsockDialer) DialContext(ctx context.Context, addr *HvsockAddr) (conn } defer sock.wg.Done() var bytes uint32 - n := 1 + int(d.Retries) - for i := 1; i <= n; i++ { - err = sockets.ConnectEx( + for i := uint(0); i <= d.Retries; i++ { + err = socket.ConnectEx( windows.Handle(sock.handle), &sa, nil, // sendBuf 0, // sendDataLen &bytes, (*windows.Overlapped)(unsafe.Pointer(&c.o))) - // todo: create an asyncIO version that takes a context - // could create a deadlineHandler triggered by context cancelation, but that seems inefficient ... _, err = sock.asyncIo(c, nil, bytes, err) - if i < n && canRedial(err) { - if err = d.redialWait(ctx); err != nil { - break + if i < d.Retries && canRedial(err) { + if err = d.redialWait(ctx); err == nil { + continue } - continue } break } @@ -374,7 +373,7 @@ func (d *HvsockDialer) DialContext(ctx context.Context, addr *HvsockAddr) (conn // get the local name var sal rawHvsockAddr - err = sockets.GetSockName(windows.Handle(sock.handle), &sal) + err = socket.GetSockName(windows.Handle(sock.handle), &sal) if err != nil { return nil, conn.opErr(op, os.NewSyscallError("getsockname", err)) } @@ -391,6 +390,7 @@ func (d *HvsockDialer) DialContext(ctx context.Context, addr *HvsockAddr) (conn return conn, nil } +// redialWait waits before attempting to redial, resetting the timer as appropriate func (d *HvsockDialer) redialWait(ctx context.Context) (err error) { if d.RetryWait == 0 { return nil @@ -429,6 +429,10 @@ func canRedial(err error) bool { } func (conn *HvsockConn) opErr(op string, err error) error { + // translate from "file closed" to "socket closed" + if errors.Is(err, ErrFileClosed) { + err = socket.ErrSocketClosed + } return &net.OpError{Op: op, Net: "hvsock", Source: &conn.local, Addr: &conn.remote, Err: err} } @@ -443,8 +447,8 @@ func (conn *HvsockConn) Read(b []byte) (int, error) { err = syscall.WSARecv(conn.sock.handle, &buf, 1, &bytes, &flags, &c.o, nil) n, err := conn.sock.asyncIo(c, &conn.sock.readDeadline, bytes, err) if err != nil { - if _, ok := err.(syscall.Errno); ok { - err = os.NewSyscallError("wsarecv", err) + if eno := windows.Errno(0); errors.As(err, &eno) { + err = os.NewSyscallError("wsarecv", eno) } return 0, conn.opErr("read", err) } else if n == 0 { @@ -477,8 +481,8 @@ func (conn *HvsockConn) write(b []byte) (int, error) { err = syscall.WSASend(conn.sock.handle, &buf, 1, &bytes, 0, &c.o, nil) n, err := conn.sock.asyncIo(c, &conn.sock.writeDeadline, bytes, err) if err != nil { - if _, ok := err.(syscall.Errno); ok { - err = os.NewSyscallError("wsasend", err) + if eno := windows.Errno(0); errors.As(err, &eno) { + err = os.NewSyscallError("wsasend", eno) } return 0, conn.opErr("write", err) } @@ -497,11 +501,16 @@ func (conn *HvsockConn) IsClosed() bool { // shutdown disables sending or receiving on a socket func (conn *HvsockConn) shutdown(how int) error { if conn.IsClosed() { - return ErrFileClosed + return socket.ErrSocketClosed } err := syscall.Shutdown(conn.sock.handle, how) if err != nil { + // If the connection was closed, shutdowns fail with "not connected" + if errors.Is(err, windows.WSAENOTCONN) || + errors.Is(err, windows.WSAESHUTDOWN) { + err = socket.ErrSocketClosed + } return os.NewSyscallError("shutdown", err) } return nil diff --git a/internal/socket/rawaddr.go b/internal/socket/rawaddr.go new file mode 100644 index 00000000..180cc4c3 --- /dev/null +++ b/internal/socket/rawaddr.go @@ -0,0 +1,41 @@ +package socket + +import ( + "errors" + "fmt" + "unsafe" +) + +// todo: should these be custom types to store the desired/actual size and addr family? + +var ( + ErrBufferSize = errors.New("buffer size") + ErrInvalidPointer = errors.New("invalid pointer") + ErrAddrFamily = errors.New("address family") +) + +// todo: helsaawy - replace this with generics, along with GetSockName and co. + +// RawSockaddr allows structs to be used with Bind and ConnectEx. The +// struct must meet the Win32 sockaddr requirements specified here: +// https://docs.microsoft.com/en-us/windows/win32/winsock/sockaddr-2 +type RawSockaddr interface { + // Sockaddr returns a pointer to the RawSockaddr and the length of the struct. + Sockaddr() (unsafe.Pointer, int32, error) + + // FromBytes populates the RawsockAddr with the data in the byte array. + // Implementers should check the buffer is correctly sized and the address family + // is appropriate. + // Receivers should be pointers. + FromBytes([]byte) error +} + +func validateSockAddr(ptr unsafe.Pointer, n int32) error { + if ptr == nil { + return fmt.Errorf("pointer is %p: %w", ptr, ErrInvalidPointer) + } + if n < 1 { + return fmt.Errorf("buffer size %d < 1: %w", n, ErrBufferSize) + } + return nil +} diff --git a/pkg/sockets/sockets.go b/internal/socket/socket.go similarity index 90% rename from pkg/sockets/sockets.go rename to internal/socket/socket.go index 74360f6b..b93fe332 100644 --- a/pkg/sockets/sockets.go +++ b/internal/socket/socket.go @@ -1,6 +1,6 @@ //go:build windows -package sockets +package socket import ( "fmt" @@ -13,7 +13,7 @@ import ( "golang.org/x/sys/windows" ) -//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go sockets.go +//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go socket.go //sys getsockname(s windows.Handle, name unsafe.Pointer, namelen *int32) (err error) [failretval==socketError] = ws2_32.getsockname //sys getpeername(s windows.Handle, name unsafe.Pointer, namelen *int32) (err error) [failretval==socketError] = ws2_32.getpeername @@ -21,23 +21,12 @@ import ( const socketError = uintptr(^uint32(0)) -// CloseWriter is a connection that can disable writing to itself. -type CloseWriter interface { - net.Conn - CloseWrite() error -} - -// CloseReader is a connection that can disable reading from itself. -type CloseReader interface { - net.Conn - CloseRead() error -} +var ErrSocketClosed = fmt.Errorf("socket closed: %w", net.ErrClosed) // GetSockName returns the socket's local address. It will call the `rsa.FromBytes()` on the // buffer returned by the getsockname syscall. The buffer is allocated to the size specified // by `rsa.Sockaddr()`. func GetSockName(s windows.Handle, rsa RawSockaddr) error { - // todo: replace this (and RawSockaddr) with generics ptr, l, err := rsa.Sockaddr() if err != nil { return fmt.Errorf("could not find socket size to allocate buffer: %w", err) @@ -111,7 +100,7 @@ func (f *runtimeFunc) Load() error { if f.err != nil { return } - defer windows.CloseHandle(s) + defer windows.CloseHandle(s) //nolint:errcheck var n uint32 f.err = windows.WSAIoctl(s, @@ -126,12 +115,11 @@ func (f *runtimeFunc) Load() error { ) }) return f.err - } var ( // todo: add `AcceptEx` and `GetAcceptExSockaddrs` - WSAID_CONNECTEX = guid.GUID{ + WSAID_CONNECTEX = guid.GUID{ //nolint:revive,stylecheck Data1: 0x25a207b9, Data2: 0xddf3, Data3: 0x4660, @@ -142,9 +130,8 @@ var ( ) func ConnectEx(fd windows.Handle, rsa RawSockaddr, sendBuf *byte, sendDataLen uint32, bytesSent *uint32, overlapped *windows.Overlapped) error { - err := connectExFunc.Load() - if err != nil { - return fmt.Errorf("failed to load ConnectEx function pointer: %e", err) + if err := connectExFunc.Load(); err != nil { + return fmt.Errorf("failed to load ConnectEx function pointer: %w", err) } ptr, n, err := rsa.Sockaddr() if err != nil { @@ -163,6 +150,7 @@ func ConnectEx(fd windows.Handle, rsa RawSockaddr, sendBuf *byte, sendDataLen ui // [in] LPOVERLAPPED lpOverlapped // ) func connectEx(s windows.Handle, name unsafe.Pointer, namelen int32, sendBuf *byte, sendDataLen uint32, bytesSent *uint32, overlapped *windows.Overlapped) (err error) { + // todo: after upgrading to 1.18, switch to syscall.SyscallN from syscall.Syscall9 r1, _, e1 := syscall.Syscall9(connectExFunc.addr, 7, uintptr(s), uintptr(name), uintptr(namelen), uintptr(unsafe.Pointer(sendBuf)), uintptr(sendDataLen), uintptr(unsafe.Pointer(bytesSent)), uintptr(unsafe.Pointer(overlapped)), 0, 0) if r1 == 0 { if e1 != 0 { diff --git a/pkg/sockets/zsyscall_windows.go b/internal/socket/zsyscall_windows.go similarity index 99% rename from pkg/sockets/zsyscall_windows.go rename to internal/socket/zsyscall_windows.go index 152ec7f6..d1868e29 100644 --- a/pkg/sockets/zsyscall_windows.go +++ b/internal/socket/zsyscall_windows.go @@ -1,6 +1,6 @@ // Code generated by 'go generate'; DO NOT EDIT. -package sockets +package socket import ( "syscall" diff --git a/pkg/sockets/rawaddr.go b/pkg/sockets/rawaddr.go deleted file mode 100644 index 03b8fc4a..00000000 --- a/pkg/sockets/rawaddr.go +++ /dev/null @@ -1,67 +0,0 @@ -package sockets - -import ( - "errors" - "fmt" - "unsafe" -) - -// todo: should these be custom types to store the desired/actual size and addr family? - -var ( - ErrBufferSize = errors.New("buffer size") - ErrInvalidPointer = errors.New("invalid pointer") - ErrAddrFamily = errors.New("address family") -) - -// todo: replace this with generics -// The function calls should be: -// -// type RawSockaddrHeader { -// Family uint16 -// } -// -// func ConnectEx[T ~RawSockaddrHeader] (s Handle, a *T, ...) error { -// n := unsafe.SizeOf(*a) -// r1, _, e1 := syscall.Syscall9(connectExFunc.addr, 7, uintptr(s), -// uintptr(unsafe.Pointer(a)), uintptr(n), /* ... */) -// /* ... */ -// } -// -// Similarly, `GetAcceptExSockaddrs` requires a `**sockaddr`, so the syscall can change the pointer -// to data it allocates. Currently, the options are (1) dealing with pointers to the interface -// `* RawSockaddr`, use reflection or pull the pointer from the internal interface representation, -// and change where the interface points to; or (2) allocate dedicate, presized buffers based on -// `(r RawSockaddr).Sockaddr()`'s return, and pass that to `(r RawSockaddr).FromBytes()`. -// It would be safer and more readable to have: -// -// func GetAcceptExSockaddrs[L ~RawSockaddrHeader, R ~RawSockaddrHeader]( -// b *byte, -// rxlen uint32, -// local **L, -// remote **R, -// ) error { /*...*/ } - -// RawSockaddr allows structs to be used with Bind and ConnectEx. The -// struct must meet the Wind32 sockaddr requirements specified here: -// https://docs.microsoft.com/en-us/windows/win32/winsock/sockaddr-2 -type RawSockaddr interface { - // Sockaddr returns a pointer to the RawSockaddr and the length of the struct. - Sockaddr() (ptr unsafe.Pointer, len int32, err error) - - // FromBytes populates the RawsockAddr with the data in the byte array. - // Implementers should check the buffer is correctly sized and the address family - // is appropriate. - // Receivers should be pointers. - FromBytes([]byte) error -} - -func validateSockAddr(ptr unsafe.Pointer, len int32) error { - if ptr == nil { - return fmt.Errorf("pointer is %p: %w", ptr, ErrInvalidPointer) - } - if len < 1 { - return fmt.Errorf("buffer size %d < 1: %w", len, ErrBufferSize) - } - return nil -} From cdc1225fb0b63be38ac7e5023339b3b6bbc6c25a Mon Sep 17 00:00:00 2001 From: Hamza El-Saawy Date: Wed, 6 Jul 2022 11:21:42 -0400 Subject: [PATCH 3/3] PR: RawSockaddr validation, `.As(` style Signed-off-by: Hamza El-Saawy --- file.go | 1 + hvsock.go | 13 ++++++--- internal/socket/rawaddr.go | 39 +++++++-------------------- internal/socket/socket.go | 54 ++++++++++++++++---------------------- 4 files changed, 41 insertions(+), 66 deletions(-) diff --git a/file.go b/file.go index 1b870350..f05f1ef5 100644 --- a/file.go +++ b/file.go @@ -223,6 +223,7 @@ func (f *win32File) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, er // runtime.KeepAlive is needed, as c is passed via native // code to ioCompletionProcessor, c must remain alive // until the channel read is complete. + // todo: (de)allocate *ioOperation via win32 heap functions, instead of needing to KeepAlive? runtime.KeepAlive(c) return int(r.bytes), err } diff --git a/hvsock.go b/hvsock.go index 8872f471..8b56f96b 100644 --- a/hvsock.go +++ b/hvsock.go @@ -140,8 +140,11 @@ func (addr *HvsockAddr) fromRaw(raw *rawHvsockAddr) { addr.ServiceID = raw.ServiceID } -// Sockaddr interface allows use with `sockets.Bind()` and `.ConnectEx()` -func (r *rawHvsockAddr) Sockaddr() (ptr unsafe.Pointer, len int32, err error) { +// Sockaddr returns a pointer to and the size of this struct. +// +// Implements the [socket.RawSockaddr] interface, and allows use in +// [socket.Bind()] and [socket.ConnectEx()] +func (r *rawHvsockAddr) Sockaddr() (unsafe.Pointer, int32, error) { return unsafe.Pointer(r), int32(unsafe.Sizeof(rawHvsockAddr{})), nil } @@ -447,7 +450,8 @@ func (conn *HvsockConn) Read(b []byte) (int, error) { err = syscall.WSARecv(conn.sock.handle, &buf, 1, &bytes, &flags, &c.o, nil) n, err := conn.sock.asyncIo(c, &conn.sock.readDeadline, bytes, err) if err != nil { - if eno := windows.Errno(0); errors.As(err, &eno) { + var eno windows.Errno + if errors.As(err, &eno) { err = os.NewSyscallError("wsarecv", eno) } return 0, conn.opErr("read", err) @@ -481,7 +485,8 @@ func (conn *HvsockConn) write(b []byte) (int, error) { err = syscall.WSASend(conn.sock.handle, &buf, 1, &bytes, 0, &c.o, nil) n, err := conn.sock.asyncIo(c, &conn.sock.writeDeadline, bytes, err) if err != nil { - if eno := windows.Errno(0); errors.As(err, &eno) { + var eno windows.Errno + if errors.As(err, &eno) { err = os.NewSyscallError("wsasend", eno) } return 0, conn.opErr("write", err) diff --git a/internal/socket/rawaddr.go b/internal/socket/rawaddr.go index 180cc4c3..7e82f9af 100644 --- a/internal/socket/rawaddr.go +++ b/internal/socket/rawaddr.go @@ -1,41 +1,20 @@ package socket import ( - "errors" - "fmt" "unsafe" ) -// todo: should these be custom types to store the desired/actual size and addr family? - -var ( - ErrBufferSize = errors.New("buffer size") - ErrInvalidPointer = errors.New("invalid pointer") - ErrAddrFamily = errors.New("address family") -) - -// todo: helsaawy - replace this with generics, along with GetSockName and co. - -// RawSockaddr allows structs to be used with Bind and ConnectEx. The +// RawSockaddr allows structs to be used with [Bind] and [ConnectEx]. The // struct must meet the Win32 sockaddr requirements specified here: // https://docs.microsoft.com/en-us/windows/win32/winsock/sockaddr-2 +// +// Specifically, the struct size must be least larger than an int16 (unsigned short) +// for the address family. type RawSockaddr interface { - // Sockaddr returns a pointer to the RawSockaddr and the length of the struct. + // Sockaddr returns a pointer to the RawSockaddr and its struct size, allowing + // for the RawSockaddr's data to be overwritten by syscalls (if necessary). + // + // It is the callers responsibility to validate that the values are valid; invalid + // pointers or size can cause a panic. Sockaddr() (unsafe.Pointer, int32, error) - - // FromBytes populates the RawsockAddr with the data in the byte array. - // Implementers should check the buffer is correctly sized and the address family - // is appropriate. - // Receivers should be pointers. - FromBytes([]byte) error -} - -func validateSockAddr(ptr unsafe.Pointer, n int32) error { - if ptr == nil { - return fmt.Errorf("pointer is %p: %w", ptr, ErrInvalidPointer) - } - if n < 1 { - return fmt.Errorf("buffer size %d < 1: %w", n, ErrBufferSize) - } - return nil } diff --git a/internal/socket/socket.go b/internal/socket/socket.go index b93fe332..f4134df4 100644 --- a/internal/socket/socket.go +++ b/internal/socket/socket.go @@ -3,6 +3,7 @@ package socket import ( + "errors" "fmt" "net" "sync" @@ -21,57 +22,46 @@ import ( const socketError = uintptr(^uint32(0)) -var ErrSocketClosed = fmt.Errorf("socket closed: %w", net.ErrClosed) +var ( + // todo(helsaawy): create custom error types to store the desired vs actual size and addr family? + + ErrBufferSize = errors.New("buffer size") + ErrAddrFamily = errors.New("address family") + ErrInvalidPointer = errors.New("invalid pointer") + ErrSocketClosed = fmt.Errorf("socket closed: %w", net.ErrClosed) +) + +// todo(helsaawy): replace these with generics, ie: GetSockName[S RawSockaddr](s windows.Handle) (S, error) -// GetSockName returns the socket's local address. It will call the `rsa.FromBytes()` on the -// buffer returned by the getsockname syscall. The buffer is allocated to the size specified -// by `rsa.Sockaddr()`. +// GetSockName writes the local address of socket s to the [RawSockaddr] rsa. +// If rsa is not large enough, the [windows.WSAEFAULT] is returned. func GetSockName(s windows.Handle, rsa RawSockaddr) error { ptr, l, err := rsa.Sockaddr() if err != nil { - return fmt.Errorf("could not find socket size to allocate buffer: %w", err) - } - if err = validateSockAddr(ptr, l); err != nil { - return err + return fmt.Errorf("could not retrieve socket pointer and size: %w", err) } - b := make([]byte, l) - err = getsockname(s, unsafe.Pointer(&b[0]), &l) - if err != nil { - // although getsockname returns WSAEFAULT if the buffer is too small, it does not set - // &l to the correct size, so--apart from doubling the buffer repeatedly--there is no remedy - return err - } - return rsa.FromBytes(b[:l]) + // although getsockname returns WSAEFAULT if the buffer is too small, it does not set + // &l to the correct size, so--apart from doubling the buffer repeatedly--there is no remedy + return getsockname(s, ptr, &l) } // GetPeerName returns the remote address the socket is connected to. // -// See GetSockName for more information. +// See [GetSockName] for more information. func GetPeerName(s windows.Handle, rsa RawSockaddr) error { ptr, l, err := rsa.Sockaddr() if err != nil { - return fmt.Errorf("could not find socket size to allocate buffer: %w", err) - } - if err = validateSockAddr(ptr, l); err != nil { - return err + return fmt.Errorf("could not retrieve socket pointer and size: %w", err) } - b := make([]byte, l) - err = getpeername(s, unsafe.Pointer(&b[0]), &l) - if err != nil { - return err - } - return rsa.FromBytes(b[:l]) + return getpeername(s, ptr, &l) } func Bind(s windows.Handle, rsa RawSockaddr) (err error) { ptr, l, err := rsa.Sockaddr() if err != nil { - return fmt.Errorf("could not find socket pointer and size: %w", err) - } - if err = validateSockAddr(ptr, l); err != nil { - return err + return fmt.Errorf("could not retrieve socket pointer and size: %w", err) } return bind(s, ptr, l) @@ -150,7 +140,7 @@ func ConnectEx(fd windows.Handle, rsa RawSockaddr, sendBuf *byte, sendDataLen ui // [in] LPOVERLAPPED lpOverlapped // ) func connectEx(s windows.Handle, name unsafe.Pointer, namelen int32, sendBuf *byte, sendDataLen uint32, bytesSent *uint32, overlapped *windows.Overlapped) (err error) { - // todo: after upgrading to 1.18, switch to syscall.SyscallN from syscall.Syscall9 + // todo: after upgrading to 1.18, switch from syscall.Syscall9 to syscall.SyscallN r1, _, e1 := syscall.Syscall9(connectExFunc.addr, 7, uintptr(s), uintptr(name), uintptr(namelen), uintptr(unsafe.Pointer(sendBuf)), uintptr(sendDataLen), uintptr(unsafe.Pointer(bytesSent)), uintptr(unsafe.Pointer(overlapped)), 0, 0) if r1 == 0 { if e1 != 0 {