From 3c24e363c83011ac4dcdb715c6b9148812a3cf5d Mon Sep 17 00:00:00 2001 From: Hamza El-Saawy Date: Thu, 28 Jul 2022 12:27:56 -0400 Subject: [PATCH] pr: asserts, naming, fatal in test Signed-off-by: Hamza El-Saawy --- .gitignore | 3 + hvsock_test.go | 397 +++++++++++++++++++++++++++++++------------------ 2 files changed, 257 insertions(+), 143 deletions(-) diff --git a/.gitignore b/.gitignore index 9d428772..815e2066 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,9 @@ *.exe +# testing +testdata + # go workspaces go.work go.work.sum diff --git a/hvsock_test.go b/hvsock_test.go index 3f7154aa..0c84a2c6 100644 --- a/hvsock_test.go +++ b/hvsock_test.go @@ -18,21 +18,27 @@ import ( "github.com/Microsoft/go-winio/pkg/guid" ) -// TODO: timeouts on listen - const testStr = "test" func randHvsockAddr() *HvsockAddr { - p := uint32(rand.Int31()) + p := rand.Uint32() //nolint:gosec // used for testing return &HvsockAddr{ VMID: HvsockGUIDLoopback(), ServiceID: VsockServiceID(p), } } -func serverListen(u testUtil) (*HvsockListener, *HvsockAddr) { - a := randHvsockAddr() - l, err := ListenHvsock(a) +func serverListen(u testUtil) (l *HvsockListener, a *HvsockAddr) { + var err error + for i := 0; i < 3; i++ { + a = randHvsockAddr() + l, err = ListenHvsock(a) + if errors.Is(err, windows.WSAEADDRINUSE) { + u.T.Logf("address collision %v", a) + continue + } + break + } u.Must(err, "could not listen") u.T.Cleanup(func() { if l != nil { @@ -45,21 +51,18 @@ func serverListen(u testUtil) (*HvsockListener, *HvsockAddr) { func clientServer(u testUtil) (cl, sv *HvsockConn, _ *HvsockAddr) { l, addr := serverListen(u) - ch := make(chan struct{}) - go func() { - defer close(ch) - + ch := u.Go(func() error { conn, err := l.Accept() - u.Must(err, "listener accept") + if err != nil { + return fmt.Errorf("listener accept: %w", err) + } sv = conn.(*HvsockConn) - u.T.Cleanup(func() { - if sv != nil { - u.Must(sv.Close(), "server close") - } - }) - u.Must(l.Close()) + if err := l.Close(); err != nil { + return err + } l = nil - }() + return nil + }) ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() @@ -71,8 +74,12 @@ func clientServer(u testUtil) (cl, sv *HvsockConn, _ *HvsockAddr) { } }) - u.Wait(ch, time.Second) - u.Check() + u.WaitErr(ch, time.Second) + u.T.Cleanup(func() { + if sv != nil { + u.Must(sv.Close(), "server close") + } + }) return cl, sv, addr } @@ -190,53 +197,67 @@ func TestHvSockReadWrite(t *testing.T) { // a sync.WaitGroup doesnt offer a channel to use in a select with a timeout // could use an errgroup.Group, but for now dual channels work fine - svCh := make(chan struct{}) - go func() { - defer close(svCh) + svCh := u.Go(func() error { c, err := l.Accept() - u.Must(err, "listener accept") + if err != nil { + return fmt.Errorf("listener accept: %w", err) + } defer c.Close() b := make([]byte, 64) for _, tt := range tests { n, err := c.Read(b) - u.Must(err, "server rx") + if err != nil { + return fmt.Errorf("server rx: %w", err) + } r := string(b[:n]) - u.Assert(r == tt.req, fmt.Sprintf("server rx error: got %q; wanted %q", r, tt.req)) - - _, err = c.Write([]byte(tt.rsp)) - u.Must(err, "server tx error, could not send "+tt.rsp) + if r != tt.req { + return fmt.Errorf("server rx error: got %q; wanted %q", r, tt.req) + } + if _, err = c.Write([]byte(tt.rsp)); err != nil { + return fmt.Errorf("server tx error, could not send %q: %w", tt.rsp, err) + } } n, err := c.Read(b) - u.Assert(n == 0, "server did not get EOF") - u.Is(err, io.EOF, "server did not get EOF") - }() + if n != 0 { + return errors.New("server did not get EOF") + } + if !errors.Is(err, io.EOF) { + return fmt.Errorf("server did not get EOF: %w", err) + } + return nil + }) - clCh := make(chan struct{}) - go func() { - defer close(clCh) + clCh := u.Go(func() error { cl, err := Dial(context.Background(), addr) - u.Must(err, "client dial") + if err != nil { + return fmt.Errorf("client dial: %w", err) + } defer cl.Close() b := make([]byte, 64) for _, tt := range tests { _, err := cl.Write([]byte(tt.req)) - u.Must(err, "client tx error, could not send "+tt.req) + if err != nil { + return fmt.Errorf("client tx error, could not send %q: %w", tt.req, err) + } n, err := cl.Read(b) - u.Must(err, "client rx") + if err != nil { + return fmt.Errorf("client tx: %w", err) + } r := string(b[:n]) - u.Assert(r == tt.rsp, fmt.Sprintf("client rx error: got %q; wanted %q", b[:n], tt.rsp)) + if r != tt.rsp { + return fmt.Errorf("client rx error: got %q; wanted %q", b[:n], tt.rsp) + } } + return cl.CloseWrite() + }) - u.Must(cl.CloseWrite()) - }() - - u.Wait(svCh, 15*time.Second, "server") - u.Wait(clCh, 15*time.Second, "client") + u.WaitErr(svCh, 15*time.Second, "server") + u.WaitErr(clCh, 15*time.Second, "client") } func TestHvSockReadTooSmall(t *testing.T) { @@ -244,11 +265,11 @@ func TestHvSockReadTooSmall(t *testing.T) { s := "this is a really long string that hopefully takes up more than 16 bytes ..." l, addr := serverListen(u) - svCh := make(chan struct{}) - go func() { - defer close(svCh) + svCh := u.Go(func() error { c, err := l.Accept() - u.Must(err, "listener accept") + if err != nil { + return fmt.Errorf("listener accept: %w", err) + } defer c.Close() b := make([]byte, 16) @@ -258,26 +279,33 @@ func TestHvSockReadTooSmall(t *testing.T) { if errors.Is(err, io.EOF) { break } - u.Must(err, "server rx") + if err != nil { + return fmt.Errorf("server rx: %w", err) + } ss += string(b[:n]) } - u.Assert(ss == s, fmt.Sprintf("got %q, wanted: %q", ss, s)) - }() + if ss != s { + return fmt.Errorf("got %q, wanted: %q", ss, s) + } + return nil + }) - clCh := make(chan struct{}) - go func() { - defer close(clCh) + clCh := u.Go(func() error { cl, err := Dial(context.Background(), addr) - u.Must(err, "client dial") + if err != nil { + return fmt.Errorf("client dial: %w", err) + } defer cl.Close() - _, err = cl.Write([]byte(s)) - u.Must(err, "client tx error, could not send") - }() + if _, err = cl.Write([]byte(s)); err != nil { + return fmt.Errorf("client tx error, could not send: %w", err) + } + return nil + }) - u.Wait(svCh, 15*time.Second, "server") - u.Wait(clCh, 15*time.Second, "client") + u.WaitErr(svCh, 15*time.Second, "server") + u.WaitErr(clCh, 15*time.Second, "client") } func TestHvSockCloseReadWriteListener(t *testing.T) { @@ -285,51 +313,78 @@ func TestHvSockCloseReadWriteListener(t *testing.T) { l, addr := serverListen(u) ch := make(chan struct{}) - go func() { + svCh := u.Go(func() error { defer close(ch) c, err := l.Accept() - u.Must(err, "listener accept") + if err != nil { + return fmt.Errorf("listener accept: %w", err) + } defer c.Close() hv := c.(*HvsockConn) // // test CloseWrite() // - _, err = c.Write([]byte(testStr)) - u.Must(err, "server tx") - - u.Must(hv.CloseWrite(), "server close write") + n, err := c.Write([]byte(testStr)) + if err != nil { + return fmt.Errorf("server tx: %w", err) + } + if n != len(testStr) { + return fmt.Errorf("server wrote %d bytes, wanted %d", n, len(testStr)) + } - _, err = c.Write([]byte(testStr)) - u.Is(err, windows.WSAESHUTDOWN, "server did not shutdown writes") + if err := hv.CloseWrite(); err != nil { + return fmt.Errorf("server close write: %w", err) + } + if _, err = c.Write([]byte(testStr)); !errors.Is(err, windows.WSAESHUTDOWN) { + return fmt.Errorf("server did not shutdown writes: %w", err) + } // safe to call multiple times - u.Must(hv.CloseWrite(), "server second close write") + if err := hv.CloseWrite(); err != nil { + return fmt.Errorf("server second close write: %w", err) + } // // test CloseRead() // b := make([]byte, 256) - n, err := c.Read(b) - u.Must(err, "server read") - u.Assert(string(b[:n]) == testStr, fmt.Sprintf("server got %q; wanted %q", b[:n], testStr)) - - u.Must(hv.CloseRead(), "server close read") + n, err = c.Read(b) + if err != nil { + return fmt.Errorf("server read: %w", err) + } + if n != len(testStr) { + return fmt.Errorf("server read %d bytes, wanted %d", n, len(testStr)) + } + if string(b[:n]) != testStr { + return fmt.Errorf("server got %q; wanted %q", b[:n], testStr) + } + if err := hv.CloseRead(); err != nil { + return fmt.Errorf("server close read: %w", err) + } ch <- struct{}{} // signal the client to send more info // if it was sent before, the read would succeed if the data was buffered prior _, err = c.Read(b) - u.Is(err, windows.WSAESHUTDOWN, "server did not shutdown reads") - + if !errors.Is(err, windows.WSAESHUTDOWN) { + return fmt.Errorf("server did not shutdown reads: %w", err) + } // safe to call multiple times - u.Must(hv.CloseRead(), "server second close read") + if err := hv.CloseRead(); err != nil { + return fmt.Errorf("server second close read: %w", err) + } c.Close() - u.Is(hv.CloseWrite(), socket.ErrSocketClosed, "client close write") - u.Is(hv.CloseRead(), socket.ErrSocketClosed, "client close read") - }() + if err := hv.CloseWrite(); !errors.Is(err, socket.ErrSocketClosed) { + return fmt.Errorf("server close write: %w", err) + } + if err := hv.CloseRead(); !errors.Is(err, socket.ErrSocketClosed) { + return fmt.Errorf("server close read: %w", err) + } + return nil + }) cl, err := Dial(context.Background(), addr) u.Must(err, "could not dial") @@ -338,21 +393,23 @@ func TestHvSockCloseReadWriteListener(t *testing.T) { b := make([]byte, 256) n, err := cl.Read(b) u.Must(err, "client read") + u.Assert(n == len(testStr), fmt.Sprintf("client read %d bytes, wanted %d", n, len(testStr))) u.Assert(string(b[:n]) == testStr, fmt.Sprintf("client got %q; wanted %q", b[:n], testStr)) n, err = cl.Read(b) u.Assert(n == 0, "client did not get EOF") u.Is(err, io.EOF, "client did not get EOF") - _, err = cl.Write([]byte(testStr)) + n, err = cl.Write([]byte(testStr)) u.Must(err, "client write") + u.Assert(n == len(testStr), fmt.Sprintf("client wrote %d bytes, wanted %d", n, len(testStr))) u.Wait(ch, time.Second) - u.Check() // this should succeed _, err = cl.Write([]byte("test2")) u.Must(err, "client write") + u.WaitErr(svCh, time.Second, "server") } func TestHvSockCloseReadWriteDial(t *testing.T) { @@ -360,31 +417,44 @@ func TestHvSockCloseReadWriteDial(t *testing.T) { l, addr := serverListen(u) ch := make(chan struct{}) - go func() { + clCh := u.Go(func() error { defer close(ch) c, err := l.Accept() - u.Must(err, "listener accept") + if err != nil { + return fmt.Errorf("listener accept: %w", err) + } defer c.Close() b := make([]byte, 256) n, err := c.Read(b) - u.Must(err, "server read") - u.Assert(string(b[:n]) == testStr, fmt.Sprintf("server got %q; wanted %q", b[:n], testStr)) + if err != nil { + return fmt.Errorf("server read: %w", err) + } + if string(b[:n]) != testStr { + return fmt.Errorf("server got %q; wanted %q", b[:n], testStr) + } n, err = c.Read(b) - u.Assert(n == 0, "server did not get EOF") - u.Is(err, io.EOF, "server did not get EOF") + if n != 0 { + return fmt.Errorf("server did not get EOF") + } + if !errors.Is(err, io.EOF) { + return errors.New("server did not get EOF") + } _, err = c.Write([]byte(testStr)) - u.Must(err, "server tx") + if err != nil { + return fmt.Errorf("server tx: %w", err) + } ch <- struct{}{} _, err = c.Write([]byte(testStr)) - u.Must(err, "server tx") - - c.Close() - }() + if err != nil { + return fmt.Errorf("server tx: %w", err) + } + return c.Close() + }) cl, err := Dial(context.Background(), addr) u.Must(err, "could not dial") @@ -413,7 +483,6 @@ func TestHvSockCloseReadWriteDial(t *testing.T) { u.Must(cl.CloseRead(), "client close read") u.Wait(ch, time.Millisecond) - u.Check() // signal the client to send more info // if it was sent before, the read would succeed if the data was buffered prior @@ -429,27 +498,27 @@ func TestHvSockCloseReadWriteDial(t *testing.T) { wantErr := socket.ErrSocketClosed u.Is(cl.CloseWrite(), wantErr, "client close write") u.Is(cl.CloseRead(), wantErr, "client close read") + u.WaitErr(clCh, time.Second, "client") } func TestHvSockDialNoTimeout(t *testing.T) { u := newUtil(t) - ch := make(chan struct{}) ctx, cancel := context.WithCancel(context.Background()) defer cancel() - - go func() { - defer close(ch) - + ch := u.Go(func() error { addr := randHvsockAddr() cl, err := Dial(ctx, addr) if err == nil { cl.Close() } - u.Is(err, windows.WSAECONNREFUSED) - }() + if !errors.Is(err, windows.WSAECONNREFUSED) { + return err + } + return nil + }) // connections usually take about ~500µs - u.Wait(ch, 2*time.Millisecond, "dial did not time out") + u.WaitErr(ch, 2*time.Millisecond, "dial did not time out") } func TestHvSockDialDeadline(t *testing.T) { @@ -502,7 +571,7 @@ func TestHvSockAcceptClose(t *testing.T) { u.Is(err, ErrFileClosed) } -func FuzzRxTx(f *testing.F) { +func FuzzHvSockRxTx(f *testing.F) { for _, b := range [][]byte{ []byte("hello?"), []byte("This is a really long string that should be a good example of the really long " + @@ -526,43 +595,57 @@ func FuzzRxTx(f *testing.F) { u := newUtil(t) cl, sv, _ := clientServer(u) - svCh := make(chan struct{}) - go func() { - defer close(svCh) - + svCh := u.Go(func() error { n, err := cl.Write(a) - u.Must(err, "client write") - u.Assert(n == len(a), "client did not send full message") - t.Log("client sent") + if err != nil { + return fmt.Errorf("client write: %w", err) + } + if n != len(a) { + return errors.New("client did not send full message") + } b := make([]byte, len(a)+5) // a little extra to make sure nothing else is sent n, err = cl.Read(b) - u.Must(err, "cl read") - u.Assert(n == len(a), "client did not read full message") + if err != nil { + return fmt.Errorf("client read: %w", err) + } + if n != len(a) { + return errors.New("client did not read full message") + } bn := b[:n] - u.Assert(string(a) == string(bn), fmt.Sprintf("payload mismatch %q != %q", a, bn)) + if string(a) != string(bn) { + return fmt.Errorf("client payload mismatch %q != %q", a, bn) + } t.Log("client received") - }() - - clCh := make(chan struct{}) - go func() { - defer close(clCh) + return nil + }) + clCh := u.Go(func() error { b := make([]byte, len(a)+5) // a little extra to make sure nothing else is sent n, err := sv.Read(b) - u.Must(err, "server read") - u.Assert(n == len(a), "server did not read full message") + if err != nil { + return fmt.Errorf("server read: %w", err) + } + if n != len(a) { + return errors.New("server did not read full message") + } bn := b[:n] - u.Assert(string(a) == string(bn), fmt.Sprintf("payload mismatch %q != %q", a, bn)) - t.Log("server received") + if string(a) != string(bn) { + return fmt.Errorf("server payload mismatch %q != %q", a, bn) + } n, err = sv.Write(bn) - u.Must(err, "server write") - u.Assert(n == len(bn), "server did not send full message") + if err != nil { + return fmt.Errorf("server write: %w", err) + } + if n != len(a) { + return errors.New("server did not send full message") + } t.Log("server sent") - }() - u.Wait(svCh, 250*time.Millisecond) - u.Wait(clCh, 250*time.Millisecond) + return nil + }) + u.WaitErr(svCh, 250*time.Millisecond) + u.WaitErr(clCh, 250*time.Millisecond) }) } @@ -580,10 +663,42 @@ func newUtil(t testing.TB) testUtil { } } -// checks stops execution if testing failed in another go-routine -func (u testUtil) Check() { - if u.T.Failed() { - u.T.FailNow() +// Go launches f in a go routine and returns a channel that can be monitored for the result. +// ch is closed after f completes. +// +// Intended for use with [testUtil.WaitErr]. +func (*testUtil) Go(f func() error) chan error { + ch := make(chan error) + go func() { + defer close(ch) + ch <- f() + }() + return ch +} + +func (u testUtil) Wait(ch <-chan struct{}, d time.Duration, msgs ...string) { + t := time.NewTimer(d) + defer t.Stop() + select { + case <-ch: + case <-t.C: + u.T.Helper() + u.T.Fatalf(msgJoin(msgs, "timed out after %v"), d) + } +} + +func (u testUtil) WaitErr(ch <-chan error, d time.Duration, msgs ...string) { + t := time.NewTimer(d) + defer t.Stop() + select { + case err := <-ch: + if err != nil { + u.T.Helper() + u.T.Fatalf(msgJoin(msgs, "%v"), err) + } + case <-t.C: + u.T.Helper() + u.T.Fatalf(msgJoin(msgs, "timed out after %v"), d) } } @@ -592,7 +707,7 @@ func (u testUtil) Assert(b bool, msgs ...string) { return } u.T.Helper() - u.T.Fatalf(_msgJoin(msgs, "failed assertion")) + u.T.Fatalf(msgJoin(msgs, "failed assertion")) } func (u testUtil) Is(err, target error, msgs ...string) { @@ -600,7 +715,7 @@ func (u testUtil) Is(err, target error, msgs ...string) { return } u.T.Helper() - u.T.Fatalf(_msgJoin(msgs, "got error %q; wanted %q"), err, target) + u.T.Fatalf(msgJoin(msgs, "got error %q; wanted %q"), err, target) } func (u testUtil) Must(err error, msgs ...string) { @@ -608,20 +723,16 @@ func (u testUtil) Must(err error, msgs ...string) { return } u.T.Helper() - u.T.Fatalf(_msgJoin(msgs, "%v"), err) + u.T.Fatalf(msgJoin(msgs, "%v"), err) } -func (u testUtil) Wait(ch <-chan struct{}, d time.Duration, msgs ...string) { - t := time.NewTimer(d) - defer t.Stop() - select { - case <-ch: - case <-t.C: - u.T.Helper() - u.T.Fatalf(_msgJoin(msgs, "timed out after %v"), d) +// Check stops execution if testing failed in another go-routine. +func (u testUtil) Check() { + if u.T.Failed() { + u.T.FailNow() } } -func _msgJoin(pre []string, s string) string { +func msgJoin(pre []string, s string) string { return strings.Join(append(pre, s), ": ") }