From 6c516b9331808b1e1b37f5e6cca5b4548ce455f4 Mon Sep 17 00:00:00 2001 From: Hamza El-Saawy Date: Thu, 21 Jul 2022 19:40:33 -0400 Subject: [PATCH] Added HV Socket tests Added tests for core Hyper-V socket functionality, including testing CloseRead and CloseWrite, as well as checking addresses are appropriate and timeouts work. Signed-off-by: Hamza El-Saawy --- hvsock_test.go | 581 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 581 insertions(+) create mode 100644 hvsock_test.go diff --git a/hvsock_test.go b/hvsock_test.go new file mode 100644 index 00000000..ca682d58 --- /dev/null +++ b/hvsock_test.go @@ -0,0 +1,581 @@ +//go:build windows + +package winio + +import ( + "context" + "errors" + "fmt" + "io" + "math/rand" + "strings" + "sync" + "testing" + "time" + + "golang.org/x/sys/windows" + + "github.com/Microsoft/go-winio/internal/socket" + "github.com/Microsoft/go-winio/pkg/guid" +) + +// TODO: timeouts on listen + +const testStr = "test" + +func randHvsockAddr() *HvsockAddr { + p := uint32(rand.Int31()) + return &HvsockAddr{ + VMID: HvsockGUIDLoopback(), + ServiceID: VsockServiceID(p), + } +} + +func serverListen(u *testUtil) (*HvsockListener, *HvsockAddr) { + a := randHvsockAddr() + l, err := ListenHvsock(a) + u.must(err, "could not listen") + u.T.Cleanup(func() { + if err := l.Close(); err != nil { + u.T.Logf("could not close Hyper-V socket listener: %v", err) + } + }) + + return l, a +} + +func TestHvSockConstants(t *testing.T) { + // not really constants ... + tests := []struct { + name string + want string + give guid.GUID + }{ + {"wildcard", "00000000-0000-0000-0000-000000000000", HvsockGUIDWildcard()}, + {"broadcast", "ffffffff-ffff-ffff-ffff-ffffffffffff", HvsockGUIDBroadcast()}, + {"loopback", "e0e16197-dd56-4a10-9195-5ee7a155a838", HvsockGUIDLoopback()}, + {"children", "90db8b89-0d35-4f79-8ce9-49ea0ac8b7cd", HvsockGUIDChildren()}, + {"parent", "a42e7cda-d03f-480c-9cc2-a4de20abb878", HvsockGUIDParent()}, + {"silohost", "36bd0c5c-7276-4223-88ba-7d03b654c568", HvsockGUIDSiloHost()}, + {"vsock template", "00000000-facb-11e6-bd58-64006a7986d3", hvsockVsockServiceTemplate()}, + } + for _, tt := range tests { + if tt.give.String() != tt.want { + t.Errorf("%s give: %v; want: %s", tt.name, tt.give, tt.want) + } + } +} +func TestHvSockAddresses(t *testing.T) { + u := newUtil(t) + ch := make(chan struct{}) + + l, addr := serverListen(u) + var sv *HvsockConn + u.Go(func(u *testUtil) { + defer close(ch) + + c, err := l.Accept() + u.must(err, "listener accept error") + sv = c.(*HvsockConn) + }) + + cl, err := Dial(context.Background(), addr) + u.must(err, "could not dial") + defer cl.Close() + + u.wait(ch, time.Second) + check(t) + defer sv.Close() + + la := (l.Addr()).(*HvsockAddr) + sra := (sv.RemoteAddr()).(*HvsockAddr) + sla := (sv.LocalAddr()).(*HvsockAddr) + cra := (cl.RemoteAddr()).(*HvsockAddr) + cla := (cl.LocalAddr()).(*HvsockAddr) + + t.Run("Info", func(t *testing.T) { + tests := []struct { + name string + give *HvsockAddr + want HvsockAddr + }{ + {"listener", la, *addr}, + {"client local", cla, HvsockAddr{HvsockGUIDChildren(), sra.ServiceID}}, + {"client remote", cra, *addr}, + {"server local", sla, HvsockAddr{HvsockGUIDChildren(), addr.ServiceID}}, + {"server remote", sra, HvsockAddr{HvsockGUIDLoopback(), cla.ServiceID}}, + } + for _, tt := range tests { + if *tt.give != tt.want { + t.Errorf("%s address give: %v; want: %v", tt.name, tt.give, tt.want) + } + } + }) + + t.Run("OSinfo", func(t *testing.T) { + u := newUtil(t) + ra := rawHvsockAddr{} + sa := HvsockAddr{} + + localTests := []struct { + name string + giveSock *win32File + wantAddr HvsockAddr + }{ + {"listener", l.sock, *addr}, + {"client", cl.sock, HvsockAddr{HvsockGUIDChildren(), cla.ServiceID}}, + // The server sockets local address seems arbitrary, so skip this test + // see comment in `(*HvsockListener) Accept()` for more info + // {"server", sv.sock, _sla}, + } + for _, tt := range localTests { + u.must(socket.GetSockName(windows.Handle(tt.giveSock.handle), &ra)) + sa.fromRaw(&ra) + if sa != tt.wantAddr { + t.Errorf("%s local addr give: %v; want: %v", tt.name, sa, tt.wantAddr) + } + } + + remoteTests := []struct { + name string + giveConn *HvsockConn + }{ + {"client", cl}, + {"server", sv}, + } + for _, tt := range remoteTests { + u.must(socket.GetPeerName(windows.Handle(tt.giveConn.sock.handle), &ra)) + sa.fromRaw(&ra) + if sa != tt.giveConn.remote { + t.Errorf("%s remote addr give: %v; want: %v", tt.name, sa, tt.giveConn.remote) + } + } + }) +} + +func TestHvSockReadWrite(t *testing.T) { + u := newUtil(t) + l, addr := serverListen(u) + tests := []struct { + req, rsp string + }{ + {"hello ", "world!"}, + {"ping", "pong"}, + } + + svch := make(chan struct{}) + u.Go(func(u *testUtil) { + defer close(svch) + c, err := l.Accept() + u.must(err, "listener accept") + defer c.Close() + + b := make([]byte, 64) + for _, tt := range tests { + n, err := c.Read(b) + u.must(err, "server rx") + + r := string(b[:n]) + u.assert(r == tt.req, fmt.Sprintf("server rx error: got %q; wanted %q", r, tt.req)) + + _, err = c.Write([]byte(tt.rsp)) + u.must(err, "server tx error, could not send "+tt.rsp) + } + n, err := c.Read(b) + u.assert(n == 0, "server did not get EOF") + u.is(err, io.EOF, "server did not get EOF") + }) + + clch := make(chan error) + u.Go(func(u *testUtil) { + defer close(clch) + cl, err := Dial(context.Background(), addr) + if err != nil { + clch <- fmt.Errorf( + "client dial error") + return + } + defer cl.Close() + + b := make([]byte, 64) + for _, tt := range tests { + _, err := cl.Write([]byte(tt.req)) + u.must(err, "client tx error, could not send "+tt.req) + + n, err := cl.Read(b) + u.must(err, "client rx") + + r := string(b[:n]) + u.assert(r == tt.rsp, fmt.Sprintf("client rx error: got %q; wanted %q", b[:n], tt.rsp)) + } + + u.must(cl.CloseWrite()) + clch <- nil + }) + + select { + case <-time.After(time.Minute): + t.Fatalf("timed out") + case <-svch: + case <-clch: + } +} + +func TestHvSockReadTooSmall(t *testing.T) { + u := newUtil(t) + s := "this is a really long string that hopefully takes up more than 16 bytes ..." + l, addr := serverListen(u) + + var wg sync.WaitGroup + wg.Add(1) + u.Go(func(u *testUtil) { + defer wg.Done() + c, err := l.Accept() + u.must(err, "listener accept error") + defer c.Close() + + b := make([]byte, 16) + ss := "" + for { + n, err := c.Read(b) + if errors.Is(err, io.EOF) { + break + } + u.must(err, "server rx error") + ss += string(b[:n]) + } + + u.assert(ss == s, fmt.Sprintf("got %q, wanted: %q", ss, s)) + }) + + wg.Add(1) + u.Go(func(u *testUtil) { + defer wg.Done() + cl, err := Dial(context.Background(), addr) + u.must(err, "client dial error") + defer cl.Close() + + _, err = cl.Write([]byte(s)) + u.must(err, "client tx error, could not send") + }) + + ch := make(chan struct{}) + go func() { + wg.Wait() + close(ch) + }() + u.wait(ch, time.Minute) +} + +func TestHvSockCloseReadWriteListener(t *testing.T) { + u := newUtil(t) + l, addr := serverListen(u) + + ch := make(chan struct{}) + u.Go(func(u *testUtil) { + defer close(ch) + c, err := l.Accept() + u.must(err, "listener accept error") + defer c.Close() + + hv := c.(*HvsockConn) + // + // test CloseWrite() + // + _, err = c.Write([]byte(testStr)) + u.must(err, "server tx error") + + u.must(hv.CloseWrite(), "server close write") + + _, err = c.Write([]byte(testStr)) + u.is(err, windows.WSAESHUTDOWN, "server did not shutdown writes") + + // safe to call multiple times + u.must(hv.CloseWrite(), "server second close write") + + // + // test CloseRead() + // + b := make([]byte, 256) + n, err := c.Read(b) + u.must(err, "server read") + u.assert(string(b[:n]) == testStr, fmt.Sprintf("server got %q; wanted %q", b[:n], testStr)) + + u.must(hv.CloseRead(), "server close read") + + ch <- struct{}{} + + // signal the client to send more info + // if it was sent before, the read would succeed if the data was buffered prior + _, err = c.Read(b) + u.is(err, windows.WSAESHUTDOWN, "server did not shutdown reads") + + // safe to call multiple times + u.must(hv.CloseRead(), "server second close read") + + c.Close() + u.is(hv.CloseWrite(), socket.ErrSocketClosed, "client close write") + u.is(hv.CloseRead(), socket.ErrSocketClosed, "client close read") + }) + + cl, err := Dial(context.Background(), addr) + u.must(err, "could not dial") + defer cl.Close() + + b := make([]byte, 256) + n, err := cl.Read(b) + u.must(err, "client read") + u.assert(string(b[:n]) == testStr, fmt.Sprintf("client got %q; wanted %q", b[:n], testStr)) + + n, err = cl.Read(b) + u.assert(n == 0, "client did not get EOF") + u.is(err, io.EOF, "client did not get EOF") + + _, err = cl.Write([]byte(testStr)) + u.must(err, "client write") + + u.wait(ch, time.Second) + check(t) + + // this should succeed + _, err = cl.Write([]byte("test2")) + u.must(err, "client write") +} + +func TestHvSockCloseReadWriteDial(t *testing.T) { + u := newUtil(t) + l, addr := serverListen(u) + + ch := make(chan struct{}) + u.Go(func(u *testUtil) { + defer close(ch) + c, err := l.Accept() + u.must(err, "listener accept") + defer c.Close() + + b := make([]byte, 256) + n, err := c.Read(b) + u.must(err, "server read") + u.assert(string(b[:n]) == testStr, fmt.Sprintf("server got %q; wanted %q", b[:n], testStr)) + + n, err = c.Read(b) + u.assert(n == 0, "server did not get EOF") + u.is(err, io.EOF, "server did not get EOF") + + _, err = c.Write([]byte(testStr)) + u.must(err, "server tx") + + ch <- struct{}{} + + _, err = c.Write([]byte(testStr)) + u.must(err, "server tx") + + c.Close() + }) + + cl, err := Dial(context.Background(), addr) + u.must(err, "could not dial") + defer cl.Close() + + // + // test CloseWrite() + // + _, err = cl.Write([]byte(testStr)) + u.must(err, "client write") + u.must(cl.CloseWrite(), "client close write") + + _, err = cl.Write([]byte(testStr)) + u.is(err, windows.WSAESHUTDOWN, "client did not shutdown writes") + + // safe to call multiple times + u.must(cl.CloseWrite(), "client second close write") + + // + // test CloseRead() + // + b := make([]byte, 256) + n, err := cl.Read(b) + u.must(err, "client read") + u.assert(string(b[:n]) == testStr, fmt.Sprintf("client got %q; wanted %q", b[:n], testStr)) + u.must(cl.CloseRead(), "client close read") + + u.wait(ch, time.Millisecond) + check(t) + + // signal the client to send more info + // if it was sent before, the read would succeed if the data was buffered prior + _, err = cl.Read(b) + u.is(err, windows.WSAESHUTDOWN, "client did not shutdown reads") + + // safe to call multiple times + u.must(cl.CloseRead(), "client second close write") + + l.Close() + cl.Close() + + wantErr := socket.ErrSocketClosed + u.is(cl.CloseWrite(), wantErr, "client close write") + u.is(cl.CloseRead(), wantErr, "client close read") +} + +func TestHvSockDialNoTimeout(t *testing.T) { + u := newUtil(t) + ch := make(chan struct{}) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + u.Go(func(u *testUtil) { + defer close(ch) + + addr := randHvsockAddr() + cl, err := Dial(ctx, addr) + if err == nil { + cl.Close() + } + u.is(err, windows.WSAECONNREFUSED) + }) + + // connections usually take about ~500µs + u.wait(ch, 2*time.Millisecond, "dial did not time out") +} + +func TestHvSockDialDeadline(t *testing.T) { + u := newUtil(t) + d := &HvsockDialer{} + d.Deadline = time.Now().Add(50 * time.Microsecond) + d.Retries = 1 + // we need the wait time to be long enough for the deadline goroutine to run first and signal + // timeout + d.RetryWait = 100 * time.Millisecond + addr := randHvsockAddr() + cl, err := d.Dial(context.Background(), addr) + if err == nil { + cl.Close() + t.Fatalf("dial should not have finished") + } + u.is(err, context.DeadlineExceeded, "dial did not exceed deadline") +} + +func TestHvSockDialContext(t *testing.T) { + u := newUtil(t) + ctx, cancel := context.WithCancel(context.Background()) + time.AfterFunc(50*time.Microsecond, cancel) + + d := &HvsockDialer{} + d.Retries = 1 + d.RetryWait = 100 * time.Millisecond + addr := randHvsockAddr() + cl, err := d.Dial(ctx, addr) + if err == nil { + cl.Close() + t.Fatalf("dial should not have finished") + } + u.is(err, context.Canceled, "dial was not canceled") +} + +func TestHvSockAcceptClose(t *testing.T) { + u := newUtil(t) + l, _ := serverListen(u) + go func() { + time.Sleep(50 * time.Millisecond) + l.Close() + }() + + c, err := l.Accept() + if err == nil { + c.Close() + t.Fatal("listener should not have accepted anything") + } + u.is(err, ErrFileClosed) +} + +// checks stops execution if testing failed in another go-routine +func check(t testing.TB) { + if t.Failed() { + t.FailNow() + } +} + +// calling FailNow() (or Fatal()) will freeze the tests, since runtime.Goexit should only +// be called from the main testing functions. +// create a testing.TB wrappers that remembers if it is in a goroutine or not, and then either +// calls FailNow or panics with a known key that can be recovered. + +type testUtilKey struct{} + +type testUtil struct { + T testing.TB + // whether `FailNow` can safely be called + canFail bool +} + +func newUtil(t testing.TB) *testUtil { + return &testUtil{ + T: t, + canFail: true, + } +} + +func (u *testUtil) assert(b bool, msgs ...string) { + if b { + return + } + u.T.Helper() + u.T.Errorf(_msgJoin(msgs, "failed assertion")) + u.fail() +} + +func (u *testUtil) is(err, target error, msgs ...string) { + if errors.Is(err, target) { + return + } + u.T.Helper() + u.T.Errorf(_msgJoin(msgs, "got error %q; wanted %q"), err, target) +} + +func (u *testUtil) must(err error, msgs ...string) { + if err == nil { + return + } + u.T.Helper() + u.T.Errorf(_msgJoin(msgs, "%v"), err) + u.fail() +} + +func (u *testUtil) fail() { + if u.canFail { + u.T.FailNow() + } + panic(testUtilKey{}) +} + +func (u *testUtil) Go(f func(u *testUtil)) { + _u := *u + _u.canFail = false + go func() { + defer func() { + r := recover() + switch r.(type) { + case nil, testUtilKey: + default: + panic(r) + } + }() + f(&_u) + }() +} + +func (u *testUtil) wait(ch <-chan struct{}, d time.Duration, msgs ...string) { + t := time.NewTimer(d) + defer t.Stop() + select { + case <-ch: + case <-t.C: + u.T.Helper() + u.T.Errorf(_msgJoin(msgs, "timed out after waiting %v"), d) + } +} + +func _msgJoin(pre []string, s string) string { + return strings.Join(append(pre, s), ": ") +}