Skip to content

Commit

Permalink
PR: error handling bug, rebase, unexport, naming
Browse files Browse the repository at this point in the history
changed names to be more conventional
unexporting socket code, rebased onto main

Signed-off-by: Hamza El-Saawy <hamzaelsaawy@microsoft.com>
  • Loading branch information
helsaawy committed May 5, 2022
1 parent f0a7033 commit 144d754
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 113 deletions.
99 changes: 61 additions & 38 deletions hvsock.go
Expand Up @@ -5,6 +5,7 @@ package winio

import (
"context"
"errors"
"fmt"
"io"
"net"
Expand All @@ -15,72 +16,85 @@ import (

"golang.org/x/sys/windows"

"github.com/Microsoft/go-winio/internal/socket"
"github.com/Microsoft/go-winio/pkg/guid"
"github.com/Microsoft/go-winio/pkg/sockets"
)

const afHvSock = 34 // AF_HYPERV

var (
// https://docs.microsoft.com/en-us/virtualization/hyper-v-on-windows/user-guide/make-integration-service#vmid-wildcards
// Well know Service and VM IDs
//https://docs.microsoft.com/en-us/virtualization/hyper-v-on-windows/user-guide/make-integration-service#vmid-wildcards

// HVguidWildcard is the wildcard VmId for accepting connections from all partitions.
HVguidWildcard = guid.GUID{} // 00000000-0000-0000-0000-000000000000
// HvsockGUIDWildcard is the wildcard VmId for accepting connections from all partitions.
func HvsockGUIDWildcard() guid.GUID { // 00000000-0000-0000-0000-000000000000
return guid.GUID{}
}

// HVguidBroadcast is the wildcard VmId for broadcasting sends to all partitions
HVguidBroadcast = guid.GUID{ //ffffffff-ffff-ffff-ffff-ffffffffffff
// HvsockGUIDBroadcast is the wildcard VmId for broadcasting sends to all partitions
func HvsockGUIDBroadcast() guid.GUID { //ffffffff-ffff-ffff-ffff-ffffffffffff
return guid.GUID{
Data1: 0xffffffff,
Data2: 0xffff,
Data3: 0xffff,
Data4: [8]uint8{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
}
}

// HVGUIDLoopback is the Loopback VmId for accepting connections to the same partition as the connector.
HVguidLoopback = guid.GUID{ // e0e16197-dd56-4a10-9195-5ee7a155a838
// HvsockGUIDLoopback is the Loopback VmId for accepting connections to the same partition as the connector.
func HvsockGUIDLoopback() guid.GUID { // e0e16197-dd56-4a10-9195-5ee7a155a838
return guid.GUID{
Data1: 0xe0e16197,
Data2: 0xdd56,
Data3: 0x4a10,
Data4: [8]uint8{0x91, 0x95, 0x5e, 0xe7, 0xa1, 0x55, 0xa8, 0x38},
}
}

// HVguidSiloHost is the address of a silo's host partition:
// - The silo host of a hosted silo is the utility VM.
// - The silo host of a silo on a physical host is the physical host.
HVguidSiloHost = guid.GUID{ // 36bd0c5c-7276-4223-88ba-7d03b654c568
// HvsockGUIDSiloHost is the address of a silo's host partition:
// - The silo host of a hosted silo is the utility VM.
// - The silo host of a silo on a physical host is the physical host.
func HvsockGUIDSiloHost() guid.GUID { // 36bd0c5c-7276-4223-88ba-7d03b654c568
return guid.GUID{
Data1: 0x36bd0c5c,
Data2: 0x7276,
Data3: 0x4223,
Data4: [8]byte{0x88, 0xba, 0x7d, 0x03, 0xb6, 0x54, 0xc5, 0x68},
}
}

// HVguidChildren is the wildcard VmId for accepting connections from the connector's child partitions.
HVguidChildren = guid.GUID{ // 90db8b89-0d35-4f79-8ce9-49ea0ac8b7cd
// HvsockGUIDChildren is the wildcard VmId for accepting connections from the connector's child partitions.
func HvsockGUIDChildren() guid.GUID { // 90db8b89-0d35-4f79-8ce9-49ea0ac8b7cd
return guid.GUID{
Data1: 0x90db8b89,
Data2: 0xd35,
Data3: 0x4f79,
Data4: [8]uint8{0x8c, 0xe9, 0x49, 0xea, 0xa, 0xc8, 0xb7, 0xcd},
}
}

// HVguidParent is the wildcard VmId for accepting connections from the connector's parent partition.
// Listening on this VmId accepts connection from:
// - Inside silos: silo host partition.
// - Inside hosted silo: host of the VM.
// - Inside VM: VM host.
// - Physical host: Not supported.
HVguidParent = guid.GUID{ // a42e7cda-d03f-480c-9cc2-a4de20abb878
// HvsockGUIDParent is the wildcard VmId for accepting connections from the connector's parent partition.
// Listening on this VmId accepts connection from:
// - Inside silos: silo host partition.
// - Inside hosted silo: host of the VM.
// - Inside VM: VM host.
// - Physical host: Not supported.
func HvsockGUIDParent() guid.GUID { // a42e7cda-d03f-480c-9cc2-a4de20abb878
return guid.GUID{
Data1: 0xa42e7cda,
Data2: 0xd03f,
Data3: 0x480c,
Data4: [8]uint8{0x9c, 0xc2, 0xa4, 0xde, 0x20, 0xab, 0xb8, 0x78},
}
}

// HVguidVSockServiceGUIDTemplate is the Service GUID used for the VSOCK protocol
hvguidVSockServiceTemplate = guid.GUID{ // 00000000-facb-11e6-bd58-64006a7986d3
// HVguidVSockServiceGUIDTemplate is the Service GUID used for the VSOCK protocol
func hvsockVsockServiceTemplate() guid.GUID { // 00000000-facb-11e6-bd58-64006a7986d3
return guid.GUID{
Data2: 0xfacb,
Data3: 0x11e6,
Data4: [8]uint8{0xbd, 0x58, 0x64, 0x00, 0x6a, 0x79, 0x86, 0xd3},
}
)
}

// An HvsockAddr is an address for a AF_HYPERV socket.
type HvsockAddr struct {
Expand All @@ -95,7 +109,7 @@ type rawHvsockAddr struct {
ServiceID guid.GUID
}

var _ sockets.RawSockaddr = &rawHvsockAddr{}
var _ socket.RawSockaddr = &rawHvsockAddr{}

// Network returns the address's network name, "hvsock".
func (addr *HvsockAddr) Network() string {
Expand All @@ -108,7 +122,7 @@ func (addr *HvsockAddr) String() string {

// VsockServiceID returns an hvsock service ID corresponding to the specified AF_VSOCK port.
func VsockServiceID(port uint32) guid.GUID {
g := hvguidVSockServiceTemplate // make a copy
g := hvsockVsockServiceTemplate() // make a copy
g.Data1 = port
return g
}
Expand Down Expand Up @@ -136,12 +150,12 @@ func (r *rawHvsockAddr) FromBytes(b []byte) error {
n := int(unsafe.Sizeof(rawHvsockAddr{}))

if len(b) < n {
return fmt.Errorf("got %d, want %d: %w", len(b), n, sockets.ErrBufferSize)
return fmt.Errorf("got %d, want %d: %w", len(b), n, socket.ErrBufferSize)
}

copy(unsafe.Slice((*byte)(unsafe.Pointer(r)), n), b[:n])
if r.Family != afHvSock {
return fmt.Errorf("got %d, want %d: %w", r.Family, afHvSock, sockets.ErrAddrFamily)
return fmt.Errorf("got %d, want %d: %w", r.Family, afHvSock, socket.ErrAddrFamily)
}

return nil
Expand Down Expand Up @@ -181,7 +195,7 @@ func ListenHvsock(addr *HvsockAddr) (_ *HvsockListener, err error) {
return nil, l.opErr("listen", err)
}
sa := addr.raw()
err = sockets.Bind(windows.Handle(sock.handle), &sa)
err = socket.Bind(windows.Handle(sock.handle), &sa)
if err != nil {
return nil, l.opErr("listen", os.NewSyscallError("socket", err))
}
Expand Down Expand Up @@ -322,7 +336,7 @@ func (d *HvsockDialer) DialContext(ctx context.Context, addr *HvsockAddr) (conn
}()

sa := addr.raw()
err = sockets.Bind(windows.Handle(sock.handle), &sa)
err = socket.Bind(windows.Handle(sock.handle), &sa)
if err != nil {
return nil, conn.opErr(op, os.NewSyscallError("bind", err))
}
Expand All @@ -335,7 +349,7 @@ func (d *HvsockDialer) DialContext(ctx context.Context, addr *HvsockAddr) (conn
var bytes uint32
n := 1 + int(d.Retries)
for i := 1; i <= n; i++ {
err = sockets.ConnectEx(
err = socket.ConnectEx(
windows.Handle(sock.handle),
&sa,
nil, // sendBuf
Expand Down Expand Up @@ -370,7 +384,7 @@ func (d *HvsockDialer) DialContext(ctx context.Context, addr *HvsockAddr) (conn

// get the local name
var sal rawHvsockAddr
err = sockets.GetSockName(windows.Handle(sock.handle), &sal)
err = socket.GetSockName(windows.Handle(sock.handle), &sal)
if err != nil {
return nil, conn.opErr(op, os.NewSyscallError("getsockname", err))
}
Expand Down Expand Up @@ -424,6 +438,10 @@ func canRedial(err error) bool {
}

func (conn *HvsockConn) opErr(op string, err error) error {
// translate from "file closed" to "socket closed"
if errors.Is(err, ErrFileClosed) {
err = socket.ErrSocketClosed
}
return &net.OpError{Op: op, Net: "hvsock", Source: &conn.local, Addr: &conn.remote, Err: err}
}

Expand All @@ -438,8 +456,8 @@ func (conn *HvsockConn) Read(b []byte) (int, error) {
err = syscall.WSARecv(conn.sock.handle, &buf, 1, &bytes, &flags, &c.o, nil)
n, err := conn.sock.asyncIo(c, &conn.sock.readDeadline, bytes, err)
if err != nil {
if _, ok := err.(syscall.Errno); ok {
err = os.NewSyscallError("wsarecv", err)
if eno := windows.Errno(0); errors.As(err, &eno) {
err = os.NewSyscallError("wsarecv", eno)
}
return 0, conn.opErr("read", err)
} else if n == 0 {
Expand Down Expand Up @@ -472,8 +490,8 @@ func (conn *HvsockConn) write(b []byte) (int, error) {
err = syscall.WSASend(conn.sock.handle, &buf, 1, &bytes, 0, &c.o, nil)
n, err := conn.sock.asyncIo(c, &conn.sock.writeDeadline, bytes, err)
if err != nil {
if _, ok := err.(syscall.Errno); ok {
err = os.NewSyscallError("wsasend", err)
if eno := windows.Errno(0); errors.As(err, &eno) {
err = os.NewSyscallError("wsasend", eno)
}
return 0, conn.opErr("write", err)
}
Expand All @@ -492,11 +510,16 @@ func (conn *HvsockConn) IsClosed() bool {
// shutdown disables sending or receiving on a socket
func (conn *HvsockConn) shutdown(how int) error {
if conn.IsClosed() {
return ErrFileClosed
return socket.ErrSocketClosed
}

err := syscall.Shutdown(conn.sock.handle, how)
if err != nil {
// If the connection was closed, shutdowns fail with "not connected"
if errors.Is(err, windows.WSAENOTCONN) ||
errors.Is(err, windows.WSAESHUTDOWN) {
err = socket.ErrSocketClosed
}
return os.NewSyscallError("shutdown", err)
}
return nil
Expand Down
41 changes: 41 additions & 0 deletions internal/socket/rawaddr.go
@@ -0,0 +1,41 @@
package socket

import (
"errors"
"fmt"
"unsafe"
)

// todo: should these be custom types to store the desired/actual size and addr family?

var (
ErrBufferSize = errors.New("buffer size")
ErrInvalidPointer = errors.New("invalid pointer")
ErrAddrFamily = errors.New("address family")
)

// todo: replace this with generics, along with GetSockName and co.

// RawSockaddr allows structs to be used with Bind and ConnectEx. The
// struct must meet the Wind32 sockaddr requirements specified here:
// https://docs.microsoft.com/en-us/windows/win32/winsock/sockaddr-2
type RawSockaddr interface {
// Sockaddr returns a pointer to the RawSockaddr and the length of the struct.
Sockaddr() (unsafe.Pointer, int32, error)

// FromBytes populates the RawsockAddr with the data in the byte array.
// Implementers should check the buffer is correctly sized and the address family
// is appropriate.
// Receivers should be pointers.
FromBytes([]byte) error
}

func validateSockAddr(ptr unsafe.Pointer, n int32) error {
if ptr == nil {
return fmt.Errorf("pointer is %p: %w", ptr, ErrInvalidPointer)
}
if n < 1 {
return fmt.Errorf("buffer size %d < 1: %w", n, ErrBufferSize)
}
return nil
}
13 changes: 6 additions & 7 deletions pkg/sockets/sockets.go → internal/socket/socket.go
@@ -1,6 +1,6 @@
//go:build windows

package sockets
package socket

import (
"fmt"
Expand All @@ -13,14 +13,16 @@ import (
"golang.org/x/sys/windows"
)

//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go sockets.go
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go socket.go

//sys getsockname(s windows.Handle, name unsafe.Pointer, namelen *int32) (err error) [failretval==socketError] = ws2_32.getsockname
//sys getpeername(s windows.Handle, name unsafe.Pointer, namelen *int32) (err error) [failretval==socketError] = ws2_32.getpeername
//sys bind(s windows.Handle, name unsafe.Pointer, namelen int32) (err error) [failretval==socketError] = ws2_32.bind

const socketError = uintptr(^uint32(0))

var ErrSocketClosed = fmt.Errorf("socket closed: %w", net.ErrClosed)

// CloseWriter is a connection that can disable writing to itself.
type CloseWriter interface {
net.Conn
Expand All @@ -37,7 +39,6 @@ type CloseReader interface {
// buffer returned by the getsockname syscall. The buffer is allocated to the size specified
// by `rsa.Sockaddr()`.
func GetSockName(s windows.Handle, rsa RawSockaddr) error {
// todo: replace this (and RawSockaddr) with generics
ptr, l, err := rsa.Sockaddr()
if err != nil {
return fmt.Errorf("could not find socket size to allocate buffer: %w", err)
Expand Down Expand Up @@ -126,7 +127,6 @@ func (f *runtimeFunc) Load() error {
)
})
return f.err

}

var (
Expand All @@ -142,9 +142,8 @@ var (
)

func ConnectEx(fd windows.Handle, rsa RawSockaddr, sendBuf *byte, sendDataLen uint32, bytesSent *uint32, overlapped *windows.Overlapped) error {
err := connectExFunc.Load()
if err != nil {
return fmt.Errorf("failed to load ConnectEx function pointer: %e", err)
if err := connectExFunc.Load(); err != nil {
return fmt.Errorf("failed to load ConnectEx function pointer: %w", err)
}
ptr, n, err := rsa.Sockaddr()
if err != nil {
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 144d754

Please sign in to comment.