diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4b24bde4..24cfd046 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -9,6 +9,8 @@ jobs: steps: - uses: actions/checkout@v2 - uses: actions/setup-go@v2 + with: + go-version: '^1.17.0' - run: go test -gcflags=all=-d=checkptr -v ./... build: @@ -17,7 +19,7 @@ jobs: - uses: actions/checkout@v2 - uses: actions/setup-go@v2 with: - go-version: '^1.15.0' + go-version: '^1.17.0' - run: go build ./pkg/etw/sample/ - run: go build ./tools/etw-provider-gen/ diff --git a/.gitignore b/.gitignore index b883f1fd..d9fbc2e9 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,3 @@ -*.exe +.vscode/ + +*.exe \ No newline at end of file diff --git a/go.mod b/go.mod index f39a608d..a5bca01c 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/Microsoft/go-winio -go 1.13 +go 1.17 require ( github.com/sirupsen/logrus v1.7.0 diff --git a/hvsock.go b/hvsock.go index b2b644d0..9fe6d938 100644 --- a/hvsock.go +++ b/hvsock.go @@ -4,6 +4,7 @@ package winio import ( + "context" "fmt" "io" "net" @@ -12,15 +13,74 @@ import ( "time" "unsafe" + "golang.org/x/sys/windows" + "golang.org/x/sys/windows/registry" + "github.com/Microsoft/go-winio/pkg/guid" + "github.com/Microsoft/go-winio/pkg/sockets" ) -//sys bind(s syscall.Handle, name unsafe.Pointer, namelen int32) (err error) [failretval==socketError] = ws2_32.bind +const afHvSock = 34 // AF_HYPERV + +var ( + // https://docs.microsoft.com/en-us/virtualization/hyper-v-on-windows/user-guide/make-integration-service#vmid-wildcards + + // HVguidWildcard is the wildcard VmId for accepting connections from all partitions. + HVguidWildcard = guid.GUID{} // 00000000-0000-0000-0000-000000000000 + + // HVguidBroadcast is the wildcard VmId for broadcasting sends to all partitions + HVguidBroadcast = guid.GUID{ //ffffffff-ffff-ffff-ffff-ffffffffffff + Data1: 0xffffffff, + Data2: 0xffff, + Data3: 0xffff, + Data4: [8]uint8{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, + } + + // HVGUIDLoopback is the Loopback VmId for accepting connections to the same partition as the connector. + HVguidLoopback = guid.GUID{ // e0e16197-dd56-4a10-9195-5ee7a155a838 + Data1: 0xe0e16197, + Data2: 0xdd56, + Data3: 0x4a10, + Data4: [8]uint8{0x91, 0x95, 0x5e, 0xe7, 0xa1, 0x55, 0xa8, 0x38}, + } + + // HVguidSiloHost is the address of a silo's host partition: + // - The silo host of a hosted silo is the utility VM. + // - The silo host of a silo on a physical host is the physical host. + HVguidSiloHost = guid.GUID{ // 36bd0c5c-7276-4223-88ba-7d03b654c568 + Data1: 0x36bd0c5c, + Data2: 0x7276, + Data3: 0x4223, + Data4: [8]byte{0x88, 0xba, 0x7d, 0x03, 0xb6, 0x54, 0xc5, 0x68}, + } -const ( - afHvSock = 34 // AF_HYPERV + // HVguidChildren is the wildcard VmId for accepting connections from the connector's child partitions. + HVguidChildren = guid.GUID{ // 90db8b89-0d35-4f79-8ce9-49ea0ac8b7cd + Data1: 0x90db8b89, + Data2: 0xd35, + Data3: 0x4f79, + Data4: [8]uint8{0x8c, 0xe9, 0x49, 0xea, 0xa, 0xc8, 0xb7, 0xcd}, + } + + // HVguidParent is the wildcard VmId for accepting connections from the connector's parent partition. + // Listening on this VmId accepts connection from: + // - Inside silos: silo host partition. + // - Inside hosted silo: host of the VM. + // - Inside VM: VM host. + // - Physical host: Not supported. + HVguidParent = guid.GUID{ // a42e7cda-d03f-480c-9cc2-a4de20abb878 + Data1: 0xa42e7cda, + Data2: 0xd03f, + Data3: 0x480c, + Data4: [8]uint8{0x9c, 0xc2, 0xa4, 0xde, 0x20, 0xab, 0xb8, 0x78}, + } - socketError = ^uintptr(0) + // HVguidVSockServiceGUIDTemplate is the Service GUID used for the VSOCK protocol + hvguidVSockServiceTemplate = guid.GUID{ // 00000000-facb-11e6-bd58-64006a7986d3 + Data2: 0xfacb, + Data3: 0x11e6, + Data4: [8]uint8{0xbd, 0x58, 0x64, 0x00, 0x6a, 0x79, 0x86, 0xd3}, + } ) // An HvsockAddr is an address for a AF_HYPERV socket. @@ -36,6 +96,8 @@ type rawHvsockAddr struct { ServiceID guid.GUID } +var _ sockets.RawSockaddr = &rawHvsockAddr{} + // Network returns the address's network name, "hvsock". func (addr *HvsockAddr) Network() string { return "hvsock" @@ -47,7 +109,7 @@ func (addr *HvsockAddr) String() string { // VsockServiceID returns an hvsock service ID corresponding to the specified AF_VSOCK port. func VsockServiceID(port uint32) guid.GUID { - g, _ := guid.FromString("00000000-facb-11e6-bd58-64006a7986d3") + g := hvguidVSockServiceTemplate // make a copy g.Data1 = port return g } @@ -65,6 +127,27 @@ func (addr *HvsockAddr) fromRaw(raw *rawHvsockAddr) { addr.ServiceID = raw.ServiceID } +// Sockaddr interface allows use with `sockets.Bind()` and `.ConnectEx()` +func (r *rawHvsockAddr) Sockaddr() (ptr unsafe.Pointer, len int32, err error) { + return unsafe.Pointer(r), int32(unsafe.Sizeof(rawHvsockAddr{})), nil +} + +// Sockaddr interface allows use with `sockets.Bind()` and `.ConnectEx()` +func (r *rawHvsockAddr) FromBytes(b []byte) error { + n := int(unsafe.Sizeof(rawHvsockAddr{})) + + if len(b) < n { + return fmt.Errorf("got %d, want %d: %w", len(b), n, sockets.ErrBufferSize) + } + + copy(unsafe.Slice((*byte)(unsafe.Pointer(r)), n), b[:n]) + if r.Family != afHvSock { + return fmt.Errorf("got %d, want %d: %w", r.Family, afHvSock, sockets.ErrAddrFamily) + } + + return nil +} + // HvsockListener is a socket listener for the AF_HYPERV address family. type HvsockListener struct { sock *win32File @@ -99,7 +182,7 @@ func ListenHvsock(addr *HvsockAddr) (_ *HvsockListener, err error) { return nil, l.opErr("listen", err) } sa := addr.raw() - err = bind(sock.handle, unsafe.Pointer(&sa), int32(unsafe.Sizeof(sa))) + err = sockets.Bind(windows.Handle(sock.handle), &sa) if err != nil { return nil, l.opErr("listen", os.NewSyscallError("socket", err)) } @@ -136,21 +219,54 @@ func (l *HvsockListener) Accept() (_ net.Conn, err error) { } defer l.sock.wg.Done() - // AcceptEx, per documentation, requires an extra 16 bytes per address. + // AcceptEx, per documentation, requires an extra 16 bytes per address: + // https://docs.microsoft.com/en-us/windows/win32/api/mswsock/nf-mswsock-acceptex const addrlen = uint32(16 + unsafe.Sizeof(rawHvsockAddr{})) var addrbuf [addrlen * 2]byte var bytes uint32 - err = syscall.AcceptEx(l.sock.handle, sock.handle, &addrbuf[0], 0, addrlen, addrlen, &bytes, &c.o) - _, err = l.sock.asyncIo(c, nil, bytes, err) + err = syscall.AcceptEx(l.sock.handle, + sock.handle, + &addrbuf[0], + 0, // rxdatalen + addrlen, + addrlen, + &bytes, + &c.o) + _, err = l.sock.asyncIo(c, + nil, // deadlineHandler + bytes, + err) if err != nil { return nil, l.opErr("accept", os.NewSyscallError("acceptex", err)) } + conn := &HvsockConn{ sock: sock, } conn.local.fromRaw((*rawHvsockAddr)(unsafe.Pointer(&addrbuf[0]))) conn.remote.fromRaw((*rawHvsockAddr)(unsafe.Pointer(&addrbuf[addrlen]))) + + // initialize the accepted socket and update its properties with those of the listening socket + if err = windows.Setsockopt(windows.Handle(sock.handle), + windows.SOL_SOCKET, windows.SO_UPDATE_ACCEPT_CONTEXT, + (*byte)(unsafe.Pointer(&l.sock.handle)), int32(unsafe.Sizeof(l.sock.handle))); err != nil { + return nil, conn.opErr("accept", os.NewSyscallError("setsockopt", err)) + } + + // The local address returned in the AcceptEx buffer is the same as the Listener socket's + // address. However, the service GUID reported by GetSockName is different from the Listeners + // socket, and is sometimes the same as the local address of the socket that dialed the + // address, with the service GUID.Data1 incremented, but othertimes is different. + // todo: does the local address matter? is the listener's address or the actual address appropriate? + + // var ra rawHvsockAddr + // err = sockets.GetSockName(windows.Handle(sock.handle), &ra) + // if err != nil { + // return nil, conn.opErr("accept", os.NewSyscallError("getsockname", err)) + // } + // conn.local.fromRaw(&ra) + sock = nil return conn, nil } @@ -160,36 +276,153 @@ func (l *HvsockListener) Close() error { return l.sock.Close() } -/* Need to finish ConnectEx handling -func DialHvsock(ctx context.Context, addr *HvsockAddr) (*HvsockConn, error) { +type HvsockDialer struct { + // Deadline is the time the Dial operation must connect before erroring. + Deadline time.Time + + // Retries is the number of additional connects to try if the connection times out, is refused, + // or the host is unreachable + Retries uint + + // RetryWait is the time to wait after a connection error to retry + RetryWait time.Duration + + rt *time.Timer // redial wait timer +} + +func (d *HvsockDialer) Dial(addr *HvsockAddr) (*HvsockConn, error) { + return d.DialContext(context.Background(), addr) +} + +func (d *HvsockDialer) DialContext(ctx context.Context, addr *HvsockAddr) (conn *HvsockConn, err error) { + op := "dial" + // create the conn early to use opErr() + conn = &HvsockConn{ + remote: *addr, + } + + if !d.Deadline.IsZero() { + var cancel context.CancelFunc + ctx, cancel = context.WithDeadline(ctx, d.Deadline) + defer cancel() + } + + // preemptive timeout/cancellation check + if err = ctx.Err(); err != nil { + return nil, conn.opErr(op, err) + } + sock, err := newHvSocket() if err != nil { - return nil, err + return nil, conn.opErr(op, err) } defer func() { if sock != nil { sock.Close() } }() + + sa := addr.raw() + err = sockets.Bind(windows.Handle(sock.handle), &sa) + if err != nil { + return nil, conn.opErr(op, os.NewSyscallError("bind", err)) + } + c, err := sock.prepareIo() if err != nil { - return nil, err + return nil, conn.opErr(op, err) } defer sock.wg.Done() var bytes uint32 - err = windows.ConnectEx(windows.Handle(sock.handle), sa, nil, 0, &bytes, &c.o) - _, err = sock.asyncIo(ctx, c, nil, bytes, err) + n := 1 + int(d.Retries) + for i := 1; i <= n; i++ { + err = sockets.ConnectEx( + windows.Handle(sock.handle), + &sa, + nil, // sendBuf + 0, // sendDataLen + &bytes, + (*windows.Overlapped)(unsafe.Pointer(&c.o))) + // todo: create an asyncIO version that takes a context + // could create a deadlineHandler triggered by context cancelation, but that seems inefficient ... + _, err = sock.asyncIo(c, nil, bytes, err) + if i < n && canRedial(err) { + if err = d.redialWait(ctx); err != nil { + break + } + continue + } + break + } if err != nil { - return nil, err + return nil, conn.opErr(op, os.NewSyscallError("connectex", err)) } - conn := &HvsockConn{ - sock: sock, - remote: *addr, + + // update the connection properties, so shutdown can be used + if err = windows.Setsockopt( + windows.Handle(sock.handle), + windows.SOL_SOCKET, + windows.SO_UPDATE_CONNECT_CONTEXT, + nil, // optvalue + 0, // optlen + ); err != nil { + return nil, conn.opErr(op, os.NewSyscallError("setsockopt", err)) + } + + // get the local name + var sal rawHvsockAddr + err = sockets.GetSockName(windows.Handle(sock.handle), &sal) + if err != nil { + return nil, conn.opErr(op, os.NewSyscallError("getsockname", err)) + } + conn.local.fromRaw(&sal) + + // one last check for timeout, since asyncIO doesnt check the context + if err = ctx.Err(); err != nil { + return nil, conn.opErr(op, err) } + + conn.sock = sock sock = nil + return conn, nil } -*/ + +func (d *HvsockDialer) redialWait(ctx context.Context) (err error) { + if d.RetryWait == 0 { + return nil + } + + if d.rt == nil { + d.rt = time.NewTimer(d.RetryWait) + } else { + // should already be stopped and drained + d.rt.Reset(d.RetryWait) + } + + select { + case <-ctx.Done(): + case <-d.rt.C: + return nil + } + + // stop and drain the timer + if !d.rt.Stop() { + <-d.rt.C + } + return ctx.Err() +} + +// assumes error is a plain, unwrapped syscall.Errno +func canRedial(err error) bool { + switch err { + case windows.WSAECONNREFUSED, windows.WSAENETUNREACH, windows.WSAETIMEDOUT, + windows.ERROR_CONNECTION_REFUSED, windows.ERROR_CONNECTION_UNAVAIL: + return true + default: + return false + } +} func (conn *HvsockConn) opErr(op string, err error) error { return &net.OpError{Op: op, Net: "hvsock", Source: &conn.local, Addr: &conn.remote, Err: err} @@ -257,6 +490,7 @@ func (conn *HvsockConn) IsClosed() bool { return conn.sock.IsClosed() } +// shutdown disables sending or receiving on a socket func (conn *HvsockConn) shutdown(how int) error { if conn.IsClosed() { return ErrFileClosed @@ -273,7 +507,7 @@ func (conn *HvsockConn) shutdown(how int) error { func (conn *HvsockConn) CloseRead() error { err := conn.shutdown(syscall.SHUT_RD) if err != nil { - return conn.opErr("close", err) + return conn.opErr("closeread", err) } return nil } @@ -283,7 +517,7 @@ func (conn *HvsockConn) CloseRead() error { func (conn *HvsockConn) CloseWrite() error { err := conn.shutdown(syscall.SHUT_WR) if err != nil { - return conn.opErr("close", err) + return conn.opErr("closewrite", err) } return nil } @@ -314,3 +548,33 @@ func (conn *HvsockConn) SetReadDeadline(t time.Time) error { func (conn *HvsockConn) SetWriteDeadline(t time.Time) error { return conn.sock.SetWriteDeadline(t) } + +// HvSockRegisterService registers the application defined by the guid and name with the +// Hyper-V Host's registry. +// +// See: https://docs.microsoft.com/en-us/virtualization/hyper-v-on-windows/user-guide/make-integration-service#register-a-new-application +func HvSockRegisterService(id guid.GUID, name string) error { + k, exists, err := registry.CreateKey(registry.LOCAL_MACHINE, hvSocketRegKey(id), registry.WRITE) + if err != nil { + return fmt.Errorf("could not create registry for %s: %w", id.String(), err) + } + defer k.Close() + + if exists { + return nil + } + if err = k.SetStringValue("ElementName", name); err != nil { + return fmt.Errorf("could not set service name to %s: %w", name, err) + } + + return nil +} + +// HvSockUnregisterService deleted the registration defined by the guid from the Hyper-V Host's registry. +func HvSockUnegisterService(id guid.GUID) error { + return registry.DeleteKey(registry.LOCAL_MACHINE, hvSocketRegKey(id)) +} + +func hvSocketRegKey(id guid.GUID) string { + return `SOFTWARE\Microsoft\Windows NT\CurrentVersion\Virtualization\GuestCommunicationServices\` + id.String() +} diff --git a/hvsock_test.go b/hvsock_test.go new file mode 100644 index 00000000..7559ddaf --- /dev/null +++ b/hvsock_test.go @@ -0,0 +1,692 @@ +//go:build windows + +package winio + +import ( + "context" + "errors" + "fmt" + "io" + "math/rand" + "testing" + "time" + + "github.com/Microsoft/go-winio/pkg/guid" + "github.com/Microsoft/go-winio/pkg/sockets" + "golang.org/x/sys/windows" +) + +// TODO: timeouts on listen + +// var addr = &HvsockAddr{ +// VMID: HVguidLoopback, +// ServiceID: VsockServiceID(58), +// } + +func randHvsockAddr() *HvsockAddr { + p := uint32(rand.Int31()) + return &HvsockAddr{ + VMID: HVguidLoopback, + ServiceID: VsockServiceID(p), + } + +} + +func serverListen(t *testing.T) (*HvsockListener, *HvsockAddr) { + a := randHvsockAddr() + l, err := ListenHvsock(a) + if err != nil { + t.Fatalf("could not listen: %v", err) + } + t.Cleanup(func() { + if err := l.Close(); err != nil { + t.Logf("could not close Hyper-V socket listener: %v", err) + } + }) + + return l, a +} + +func TestHvSockService(t *testing.T) { + a := hvguidVSockServiceTemplate + b := hvguidVSockServiceTemplate + a.Data1 = 2016 + + fmt.Println("a", a) + fmt.Println("b", b) + +} + +func TestHvSockConstants(t *testing.T) { + // not really constants ... + tests := []struct { + name string + want string + give guid.GUID + }{ + {"wildcard", "00000000-0000-0000-0000-000000000000", HVguidWildcard}, + {"broadcast", "ffffffff-ffff-ffff-ffff-ffffffffffff", HVguidBroadcast}, + {"loopback", "e0e16197-dd56-4a10-9195-5ee7a155a838", HVguidLoopback}, + {"children", "90db8b89-0d35-4f79-8ce9-49ea0ac8b7cd", HVguidChildren}, + {"parent", "a42e7cda-d03f-480c-9cc2-a4de20abb878", HVguidParent}, + {"silohost", "36bd0c5c-7276-4223-88ba-7d03b654c568", HVguidSiloHost}, + {"vsock template", "00000000-facb-11e6-bd58-64006a7986d3", hvguidVSockServiceTemplate}, + } + for _, tt := range tests { + if tt.give.String() != tt.want { + t.Errorf("%s give: %v; want: %s", tt.name, tt.give, tt.want) + } + } +} +func TestHvSockAddresses(t *testing.T) { + errs := make(chan error) + defer close(errs) + + l, addr := serverListen(t) + var sv *HvsockConn + go func() { + ss, err := l.Accept() + sv = ss.(*HvsockConn) + if err != nil { + errs <- fmt.Errorf("listener accept error: %w", err) + return + } + errs <- nil + }() + + cl, err := (&HvsockDialer{}).Dial(addr) + if err != nil { + <-errs // wait on the go routine before closing it + t.Fatalf("could not dial: %s", err) + } + defer cl.Close() + + if err := <-errs; err != nil { + t.Fatalf(err.Error()) + } + defer sv.Close() + + la := (l.Addr()).(*HvsockAddr) + sra := (sv.RemoteAddr()).(*HvsockAddr) + sla := (sv.LocalAddr()).(*HvsockAddr) + cra := (cl.RemoteAddr()).(*HvsockAddr) + cla := (cl.LocalAddr()).(*HvsockAddr) + + t.Run("Info", func(t *testing.T) { + tests := []struct { + name string + give *HvsockAddr + want HvsockAddr + }{ + {"listener", la, *addr}, + {"client local", cla, HvsockAddr{HVguidChildren, sra.ServiceID}}, + {"client remote", cra, *addr}, + {"server local", sla, HvsockAddr{HVguidChildren, addr.ServiceID}}, + {"server remote", sra, HvsockAddr{HVguidLoopback, cla.ServiceID}}, + } + for _, tt := range tests { + if *tt.give != tt.want { + t.Errorf("%s address give: %v; want: %v", tt.name, tt.give, tt.want) + } + } + }) + + t.Run("OSinfo", func(t *testing.T) { + ra := rawHvsockAddr{} + sa := HvsockAddr{} + + localTests := []struct { + name string + giveSock *win32File + wantAddr HvsockAddr + }{ + {"listener", l.sock, *addr}, + {"client", cl.sock, HvsockAddr{HVguidChildren, cla.ServiceID}}, + // The server sockets local address seems arbitrary, so skip this test + // see comment in `(*HvsockListener) Accept()` for more info + // {"server", sv.sock, _sla}, + } + for _, tt := range localTests { + sockets.GetSockName(windows.Handle(tt.giveSock.handle), &ra) + sa.fromRaw(&ra) + if sa != tt.wantAddr { + t.Errorf("%s local addr give: %v; want: %v", tt.name, sa, tt.wantAddr) + } + } + + remoteTests := []struct { + name string + giveConn *HvsockConn + }{ + {"client", cl}, + {"server", sv}, + } + for _, tt := range remoteTests { + sockets.GetPeerName(windows.Handle(tt.giveConn.sock.handle), &ra) + sa.fromRaw(&ra) + if sa != tt.giveConn.remote { + t.Errorf("%s remote addr give: %v; want: %v", tt.name, sa, tt.giveConn.remote) + } + } + }) +} + +func TestHvSockReadWrite(t *testing.T) { + svch := make(chan error) + defer close(svch) + clch := make(chan error) + defer close(clch) + + l, addr := serverListen(t) + + tests := []struct { + req, rsp string + }{ + {"hello ", "world!"}, + {"ping", "pong"}, + } + + go func() { + c, err := l.Accept() + if err != nil { + svch <- fmt.Errorf("listener accept error: %w", err) + return + } + defer c.Close() + + b := make([]byte, 64) + for _, tt := range tests { + n, err := c.Read(b) + if err != nil { + svch <- fmt.Errorf("server rx error: %w", err) + return + } + + r := string(b[:n]) + + if r != tt.req { + svch <- fmt.Errorf("server rx error, actual %q, expected %q", b[:n], tt.req) + return + } + + if _, err = c.Write([]byte(tt.rsp)); err != nil { + svch <- fmt.Errorf("server tx error, could not send %q: %w", tt.rsp, err) + return + } + } + n, err := c.Read(b) + if err != io.EOF && n != 0 { + svch <- fmt.Errorf("expected 0 bytes and EOF, actual %d, %v", n, err) + return + } + + svch <- nil + }() + + var cl *HvsockConn + go func() { + var err error + cl, err = (&HvsockDialer{}).Dial(addr) + if err != nil { + clch <- fmt.Errorf("client dial error: %w", err) + return + } + defer cl.Close() + + b := make([]byte, 64) + for _, tt := range tests { + if _, err := cl.Write([]byte(tt.req)); err != nil { + clch <- fmt.Errorf("client tx error, could not send %q: %w", tt.req, err) + return + } + + n, err := cl.Read(b) + if err != nil { + clch <- fmt.Errorf("client rx error: %w", err) + return + } + + r := string(b[:n]) + if r != tt.rsp { + clch <- fmt.Errorf("client rx error, actual %q, expected %q", b[:n], tt.rsp) + return + } + } + + cl.CloseWrite() + clch <- nil + }() + + var err error + tr := time.NewTimer(time.Minute) + defer tr.Stop() + + select { + case <-tr.C: + err = fmt.Errorf("test timed out") + case err = <-svch: + case err = <-clch: + } + if err != nil { + t.Error(err.Error()) + l.Close() + cl.Close() + } + + // grab the other error too + select { + case err = <-svch: + case err = <-clch: + } + if err != nil { + t.Errorf(err.Error()) + } +} + +func TestHvSockReadTooSmall(t *testing.T) { + errs := make(chan error) + defer close(errs) + + s := "this is a really long string that hopefully takes up more than 16 bytes ..." + l, addr := serverListen(t) + + go func() { + c, err := l.Accept() + if err != nil { + errs <- fmt.Errorf("listener accept error: %w", err) + return + } + defer c.Close() + + b := make([]byte, 16) + ss := "" + for { + n, err := c.Read(b) + if err == io.EOF { + break + } else if err != nil { + errs <- fmt.Errorf("server rx error: %w", err) + return + } + ss += string(b[:n]) + } + + if ss != s { + errs <- fmt.Errorf("got wrong string: %q", ss) + } + errs <- nil + }() + + var cl *HvsockConn + go func() { + var err error + cl, err = (&HvsockDialer{}).Dial(addr) + if err != nil { + errs <- fmt.Errorf("client dial error: %w", err) + return + } + defer cl.Close() + + if _, err := cl.Write([]byte(s)); err != nil { + errs <- fmt.Errorf("client tx error, could not send: %w", err) + return + } + errs <- nil + }() + + var err error + tr := time.NewTimer(time.Minute) + defer tr.Stop() + + select { + case <-tr.C: + err = fmt.Errorf("test timed out") + case err = <-errs: + } + if err != nil { + t.Error(err.Error()) + l.Close() + cl.Close() + } + + // grab the other error too + if err := <-errs; err != nil { + t.Errorf(err.Error()) + } +} + +func TestHvSockCloseReadWriteListener(t *testing.T) { + errs := make(chan error) + defer close(errs) + syn := make(chan struct{}) + defer close(syn) + defer func() { + // make sure the go routine ends before closing the channels + if err := <-errs; err != nil { + t.Error(err.Error()) + } + }() + + l, addr := serverListen(t) + + go func() { + c, err := l.Accept() + if err != nil { + errs <- fmt.Errorf("listener accept error: %w", err) + return + } + defer c.Close() + + // + // test CloseWrite() + // + _, err = c.Write([]byte("test")) + if err != nil { + errs <- fmt.Errorf("server tx error: %w", err) + return + } + + cw := c.(sockets.CloseWriter) + if err = cw.CloseWrite(); err != nil { + errs <- fmt.Errorf("server close write: %w", err) + return + } + + _, err = c.Write([]byte("test")) + if !errors.Is(err, windows.WSAESHUTDOWN) { + errs <- fmt.Errorf("server did not shutdown writes: %w", err) + return + } + + // safe to call multiple times + if err = cw.CloseWrite(); err != nil { + errs <- fmt.Errorf("server second close write: %w", err) + return + } + + // + // test CloseRead() + // + b := make([]byte, 256) + n, err := c.Read(b) + if err != nil { + errs <- fmt.Errorf("server read: %w", err) + return + } + if string(b[:n]) != "test" { + errs <- fmt.Errorf("expected %q, actual %q", "test", b[:n]) + return + } + + cr := c.(sockets.CloseReader) + if err = cr.CloseRead(); err != nil { + errs <- fmt.Errorf("server close read: %w", err) + return + } + syn <- struct{}{} + // signal the client to send more info + // if it was sent before, the read would succeed if the data was buffered prior + _, err = c.Read(b) + if !errors.Is(err, windows.WSAESHUTDOWN) { + errs <- fmt.Errorf("server did not shutdown reads: %w", err) + return + } + + // safe to call multiple times + if err = cr.CloseRead(); err != nil { + errs <- fmt.Errorf("server second close read: %w", err) + return + } + + c.Close() + if err = cw.CloseWrite(); !errors.Is(err, ErrFileClosed) { + errs <- fmt.Errorf("client close write did not return `ErrFileClosed`: %w", err) + return + } + + if err = cr.CloseRead(); !errors.Is(err, ErrFileClosed) { + errs <- fmt.Errorf("client close read did not return `ErrFileClosed`: %w", err) + return + } + + errs <- nil + }() + + cl, err := (&HvsockDialer{}).Dial(addr) + if err != nil { + t.Fatalf("could not dial: %s", err) + } + defer cl.Close() + + b := make([]byte, 256) + n, err := cl.Read(b) + if err != nil { + t.Fatalf("client read: %v", err) + } + if string(b[:n]) != "test" { + t.Fatalf("expected %q, actual %q", "test", b[:n]) + } + + n, err = cl.Read(b) + if n != 0 && err != io.EOF { + t.Fatalf("client did not get EOF: %v", err) + } + + _, err = cl.Write([]byte("test")) + if err != nil { + t.Fatalf("client write: %v", err) + } + <-syn + // this should succeed + _, err = cl.Write([]byte("test2")) + if err != nil { + t.Fatalf("client write: %v", err) + } + +} + +func TestHvSockCloseReadWriteDial(t *testing.T) { + errs := make(chan error) + defer close(errs) + syn := make(chan struct{}) + defer close(syn) + + defer func() { + // make sure the go routine ends before closing the channels + if err := <-errs; err != nil { + t.Errorf(err.Error()) + } + }() + + l, addr := serverListen(t) + + go func() { + c, err := l.Accept() + if err != nil { + errs <- fmt.Errorf("listener accept error: %w", err) + return + } + defer c.Close() + + b := make([]byte, 256) + n, err := c.Read(b) + if err != nil { + errs <- fmt.Errorf("server read: %w", err) + return + } + if string(b[:n]) != "test" { + errs <- fmt.Errorf("expected %q, actual %q", "test", b[:n]) + return + } + + n, err = c.Read(b) + if n != 0 && err != io.EOF { + errs <- fmt.Errorf("server did not get EOF: %w", err) + return + } + + _, err = c.Write([]byte("test")) + if err != nil { + errs <- fmt.Errorf("server tx error: %w", err) + return + } + <-syn + _, err = c.Write([]byte("test")) + if err != nil { + errs <- fmt.Errorf("server tx error: %w", err) + return + } + + c.Close() + errs <- nil + }() + + cl, err := (&HvsockDialer{}).Dial(addr) + if err != nil { + t.Fatalf("could not dial: %s", err) + } + defer cl.Close() + + // + // test CloseWrite() + // + _, err = cl.Write([]byte("test")) + if err != nil { + t.Fatalf("client write: %v", err) + } + + if err = cl.CloseWrite(); err != nil { + t.Fatalf("client close write: %v", err) + } + + _, err = cl.Write([]byte("test")) + if !errors.Is(err, windows.WSAESHUTDOWN) { + t.Fatalf("client did not shutdown writes: %v", err) + } + + // safe to call multiple times + if err = cl.CloseWrite(); err != nil { + t.Fatalf("client second close write: %v", err) + } + + // + // test CloseRead() + // + b := make([]byte, 256) + n, err := cl.Read(b) + if err != nil { + t.Fatalf("client read: %v", err) + } + if string(b[:n]) != "test" { + t.Fatalf("expected %q, actual %q", "test", b[:n]) + } + + if err = cl.CloseRead(); err != nil { + t.Fatalf("client close read: %v", err) + } + + syn <- struct{}{} + // signal the client to send more info + // if it was sent before, the read would succeed if the data was buffered prior + _, err = cl.Read(b) + if !errors.Is(err, windows.WSAESHUTDOWN) { + t.Fatalf("client did not shutdown reads: %v", err) + } + + // safe to call multiple times + if err = cl.CloseRead(); err != nil { + t.Fatalf("client second close write: %v", err) + } + + l.Close() + cl.Close() + + if err = cl.CloseWrite(); !errors.Is(err, ErrFileClosed) { + t.Fatalf("client close write did not return `ErrFileClosed`: %v", err) + } + + if err = cl.CloseRead(); !errors.Is(err, ErrFileClosed) { + t.Fatalf("client close read did not return `ErrFileClosed`: %v", err) + } +} + +func TestHvSockDialNoTimeout(t *testing.T) { + errs := make(chan error) + defer close(errs) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { + addr := randHvsockAddr() + cl, err := (&HvsockDialer{}).DialContext(ctx, addr) + if err != nil { + errs <- err + return + } + defer cl.Close() + errs <- errors.New("should not have gotten here") + }() + + select { + case err := <-errs: + if !errors.Is(err, windows.WSAECONNREFUSED) { + t.Fatalf("expected connection refused error, actual: %v", err) + } + // connections usually take about ~500µs + case <-time.After(2 * time.Millisecond): + t.Fatalf("dial did not time out") + } +} + +func TestHvSockDialDeadline(t *testing.T) { + d := &HvsockDialer{} + d.Deadline = time.Now().Add(50 * time.Microsecond) + d.Retries = 1 + // we need the wait time to be long enough for the deadline goroutine to run first and signal + // timeout + d.RetryWait = 100 * time.Millisecond + addr := randHvsockAddr() + cl, err := d.Dial(addr) + if err == nil { + cl.Close() + t.Fatalf("dial should not have finished") + } else if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("dial did not exceed deadline: %v", err) + } +} + +func TestHvSockDialContext(t *testing.T) { + errs := make(chan error) + defer close(errs) + + ctx, cancel := context.WithCancel(context.Background()) + time.AfterFunc(50*time.Microsecond, cancel) + + d := &HvsockDialer{} + d.Retries = 1 + d.RetryWait = 100 * time.Millisecond + addr := randHvsockAddr() + cl, err := d.DialContext(ctx, addr) + if err == nil { + cl.Close() + t.Fatalf("dial should not have finished") + } else if !errors.Is(err, context.Canceled) { + t.Fatalf("dial was not canceled: %v", err) + } +} + +func TestHvSockAcceptClose(t *testing.T) { + l, _ := serverListen(t) + go func() { + time.Sleep(50 * time.Millisecond) + l.Close() + }() + + c, err := l.Accept() + if err == nil { + c.Close() + t.Fatal("listener should not have accepted anything") + } else if !errors.Is(err, ErrFileClosed) { + t.Fatalf("expected %v, actual %v", ErrFileClosed, err) + } +} diff --git a/pkg/sockets/rawaddr.go b/pkg/sockets/rawaddr.go new file mode 100644 index 00000000..03b8fc4a --- /dev/null +++ b/pkg/sockets/rawaddr.go @@ -0,0 +1,67 @@ +package sockets + +import ( + "errors" + "fmt" + "unsafe" +) + +// todo: should these be custom types to store the desired/actual size and addr family? + +var ( + ErrBufferSize = errors.New("buffer size") + ErrInvalidPointer = errors.New("invalid pointer") + ErrAddrFamily = errors.New("address family") +) + +// todo: replace this with generics +// The function calls should be: +// +// type RawSockaddrHeader { +// Family uint16 +// } +// +// func ConnectEx[T ~RawSockaddrHeader] (s Handle, a *T, ...) error { +// n := unsafe.SizeOf(*a) +// r1, _, e1 := syscall.Syscall9(connectExFunc.addr, 7, uintptr(s), +// uintptr(unsafe.Pointer(a)), uintptr(n), /* ... */) +// /* ... */ +// } +// +// Similarly, `GetAcceptExSockaddrs` requires a `**sockaddr`, so the syscall can change the pointer +// to data it allocates. Currently, the options are (1) dealing with pointers to the interface +// `* RawSockaddr`, use reflection or pull the pointer from the internal interface representation, +// and change where the interface points to; or (2) allocate dedicate, presized buffers based on +// `(r RawSockaddr).Sockaddr()`'s return, and pass that to `(r RawSockaddr).FromBytes()`. +// It would be safer and more readable to have: +// +// func GetAcceptExSockaddrs[L ~RawSockaddrHeader, R ~RawSockaddrHeader]( +// b *byte, +// rxlen uint32, +// local **L, +// remote **R, +// ) error { /*...*/ } + +// RawSockaddr allows structs to be used with Bind and ConnectEx. The +// struct must meet the Wind32 sockaddr requirements specified here: +// https://docs.microsoft.com/en-us/windows/win32/winsock/sockaddr-2 +type RawSockaddr interface { + // Sockaddr returns a pointer to the RawSockaddr and the length of the struct. + Sockaddr() (ptr unsafe.Pointer, len int32, err error) + + // FromBytes populates the RawsockAddr with the data in the byte array. + // Implementers should check the buffer is correctly sized and the address family + // is appropriate. + // Receivers should be pointers. + FromBytes([]byte) error +} + +func validateSockAddr(ptr unsafe.Pointer, len int32) error { + if ptr == nil { + return fmt.Errorf("pointer is %p: %w", ptr, ErrInvalidPointer) + } + if len < 1 { + return fmt.Errorf("buffer size %d < 1: %w", len, ErrBufferSize) + } + return nil +} diff --git a/pkg/sockets/sockets.go b/pkg/sockets/sockets.go new file mode 100644 index 00000000..74360f6b --- /dev/null +++ b/pkg/sockets/sockets.go @@ -0,0 +1,175 @@ +//go:build windows + +package sockets + +import ( + "fmt" + "net" + "sync" + "syscall" + "unsafe" + + "github.com/Microsoft/go-winio/pkg/guid" + "golang.org/x/sys/windows" +) + +//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go sockets.go + +//sys getsockname(s windows.Handle, name unsafe.Pointer, namelen *int32) (err error) [failretval==socketError] = ws2_32.getsockname +//sys getpeername(s windows.Handle, name unsafe.Pointer, namelen *int32) (err error) [failretval==socketError] = ws2_32.getpeername +//sys bind(s windows.Handle, name unsafe.Pointer, namelen int32) (err error) [failretval==socketError] = ws2_32.bind + +const socketError = uintptr(^uint32(0)) + +// CloseWriter is a connection that can disable writing to itself. +type CloseWriter interface { + net.Conn + CloseWrite() error +} + +// CloseReader is a connection that can disable reading from itself. +type CloseReader interface { + net.Conn + CloseRead() error +} + +// GetSockName returns the socket's local address. It will call the `rsa.FromBytes()` on the +// buffer returned by the getsockname syscall. The buffer is allocated to the size specified +// by `rsa.Sockaddr()`. +func GetSockName(s windows.Handle, rsa RawSockaddr) error { + // todo: replace this (and RawSockaddr) with generics + ptr, l, err := rsa.Sockaddr() + if err != nil { + return fmt.Errorf("could not find socket size to allocate buffer: %w", err) + } + if err = validateSockAddr(ptr, l); err != nil { + return err + } + + b := make([]byte, l) + err = getsockname(s, unsafe.Pointer(&b[0]), &l) + if err != nil { + // although getsockname returns WSAEFAULT if the buffer is too small, it does not set + // &l to the correct size, so--apart from doubling the buffer repeatedly--there is no remedy + return err + } + return rsa.FromBytes(b[:l]) +} + +// GetPeerName returns the remote address the socket is connected to. +// +// See GetSockName for more information. +func GetPeerName(s windows.Handle, rsa RawSockaddr) error { + ptr, l, err := rsa.Sockaddr() + if err != nil { + return fmt.Errorf("could not find socket size to allocate buffer: %w", err) + } + if err = validateSockAddr(ptr, l); err != nil { + return err + } + + b := make([]byte, l) + err = getpeername(s, unsafe.Pointer(&b[0]), &l) + if err != nil { + return err + } + return rsa.FromBytes(b[:l]) +} + +func Bind(s windows.Handle, rsa RawSockaddr) (err error) { + ptr, l, err := rsa.Sockaddr() + if err != nil { + return fmt.Errorf("could not find socket pointer and size: %w", err) + } + if err = validateSockAddr(ptr, l); err != nil { + return err + } + + return bind(s, ptr, l) +} + +// "golang.org/x/sys/windows".ConnectEx and .Bind only accept internal implementations of the +// their sockaddr interface, so they cannot be used with HvsockAddr +// Replicate functionality here from +// https://cs.opensource.google/go/x/sys/+/master:windows/syscall_windows.go + +// The function pointers to `AcceptEx`, `ConnectEx` and `GetAcceptExSockaddrs` must be loaded at +// runtime via a WSAIoctl call: +// https://docs.microsoft.com/en-us/windows/win32/api/Mswsock/nc-mswsock-lpfn_connectex#remarks + +type runtimeFunc struct { + id guid.GUID + once sync.Once + addr uintptr + err error +} + +func (f *runtimeFunc) Load() error { + f.once.Do(func() { + var s windows.Handle + s, f.err = windows.Socket(windows.AF_INET, windows.SOCK_STREAM, windows.IPPROTO_TCP) + if f.err != nil { + return + } + defer windows.CloseHandle(s) + + var n uint32 + f.err = windows.WSAIoctl(s, + windows.SIO_GET_EXTENSION_FUNCTION_POINTER, + (*byte)(unsafe.Pointer(&f.id)), + uint32(unsafe.Sizeof(f.id)), + (*byte)(unsafe.Pointer(&f.addr)), + uint32(unsafe.Sizeof(f.addr)), + &n, + nil, //overlapped + 0, //completionRoutine + ) + }) + return f.err + +} + +var ( + // todo: add `AcceptEx` and `GetAcceptExSockaddrs` + WSAID_CONNECTEX = guid.GUID{ + Data1: 0x25a207b9, + Data2: 0xddf3, + Data3: 0x4660, + Data4: [8]byte{0x8e, 0xe9, 0x76, 0xe5, 0x8c, 0x74, 0x06, 0x3e}, + } + + connectExFunc = runtimeFunc{id: WSAID_CONNECTEX} +) + +func ConnectEx(fd windows.Handle, rsa RawSockaddr, sendBuf *byte, sendDataLen uint32, bytesSent *uint32, overlapped *windows.Overlapped) error { + err := connectExFunc.Load() + if err != nil { + return fmt.Errorf("failed to load ConnectEx function pointer: %e", err) + } + ptr, n, err := rsa.Sockaddr() + if err != nil { + return err + } + return connectEx(fd, ptr, n, sendBuf, sendDataLen, bytesSent, overlapped) +} + +// BOOL LpfnConnectex( +// [in] SOCKET s, +// [in] const sockaddr *name, +// [in] int namelen, +// [in, optional] PVOID lpSendBuffer, +// [in] DWORD dwSendDataLength, +// [out] LPDWORD lpdwBytesSent, +// [in] LPOVERLAPPED lpOverlapped +// ) +func connectEx(s windows.Handle, name unsafe.Pointer, namelen int32, sendBuf *byte, sendDataLen uint32, bytesSent *uint32, overlapped *windows.Overlapped) (err error) { + r1, _, e1 := syscall.Syscall9(connectExFunc.addr, 7, uintptr(s), uintptr(name), uintptr(namelen), uintptr(unsafe.Pointer(sendBuf)), uintptr(sendDataLen), uintptr(unsafe.Pointer(bytesSent)), uintptr(unsafe.Pointer(overlapped)), 0, 0) + if r1 == 0 { + if e1 != 0 { + err = error(e1) + } else { + err = syscall.EINVAL + } + } + return +} diff --git a/pkg/sockets/zsyscall_windows.go b/pkg/sockets/zsyscall_windows.go new file mode 100644 index 00000000..152ec7f6 --- /dev/null +++ b/pkg/sockets/zsyscall_windows.go @@ -0,0 +1,70 @@ +// Code generated by 'go generate'; DO NOT EDIT. + +package sockets + +import ( + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +var _ unsafe.Pointer + +// Do the interface allocations only once for common +// Errno values. +const ( + errnoERROR_IO_PENDING = 997 +) + +var ( + errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING) + errERROR_EINVAL error = syscall.EINVAL +) + +// errnoErr returns common boxed Errno values, to prevent +// allocations at runtime. +func errnoErr(e syscall.Errno) error { + switch e { + case 0: + return errERROR_EINVAL + case errnoERROR_IO_PENDING: + return errERROR_IO_PENDING + } + // TODO: add more here, after collecting data on the common + // error values see on Windows. (perhaps when running + // all.bat?) + return e +} + +var ( + modws2_32 = windows.NewLazySystemDLL("ws2_32.dll") + + procbind = modws2_32.NewProc("bind") + procgetpeername = modws2_32.NewProc("getpeername") + procgetsockname = modws2_32.NewProc("getsockname") +) + +func bind(s windows.Handle, name unsafe.Pointer, namelen int32) (err error) { + r1, _, e1 := syscall.Syscall(procbind.Addr(), 3, uintptr(s), uintptr(name), uintptr(namelen)) + if r1 == socketError { + err = errnoErr(e1) + } + return +} + +func getpeername(s windows.Handle, name unsafe.Pointer, namelen *int32) (err error) { + r1, _, e1 := syscall.Syscall(procgetpeername.Addr(), 3, uintptr(s), uintptr(name), uintptr(unsafe.Pointer(namelen))) + if r1 == socketError { + err = errnoErr(e1) + } + return +} + +func getsockname(s windows.Handle, name unsafe.Pointer, namelen *int32) (err error) { + r1, _, e1 := syscall.Syscall(procgetsockname.Addr(), 3, uintptr(s), uintptr(name), uintptr(unsafe.Pointer(namelen))) + if r1 == socketError { + err = errnoErr(e1) + } + return +} diff --git a/zsyscall_windows.go b/zsyscall_windows.go index 176ff75e..038ee3d5 100644 --- a/zsyscall_windows.go +++ b/zsyscall_windows.go @@ -74,7 +74,6 @@ var ( procRtlDosPathNameToNtPathName_U = modntdll.NewProc("RtlDosPathNameToNtPathName_U") procRtlNtStatusToDosErrorNoTeb = modntdll.NewProc("RtlNtStatusToDosErrorNoTeb") procWSAGetOverlappedResult = modws2_32.NewProc("WSAGetOverlappedResult") - procbind = modws2_32.NewProc("bind") ) func adjustTokenPrivileges(token windows.Token, releaseAll bool, input *byte, outputSize uint32, output *byte, requiredSize *uint32) (success bool, err error) { @@ -417,11 +416,3 @@ func wsaGetOverlappedResult(h syscall.Handle, o *syscall.Overlapped, bytes *uint } return } - -func bind(s syscall.Handle, name unsafe.Pointer, namelen int32) (err error) { - r1, _, e1 := syscall.Syscall(procbind.Addr(), 3, uintptr(s), uintptr(name), uintptr(namelen)) - if r1 == socketError { - err = errnoErr(e1) - } - return -}