diff --git a/test/stats_test.go b/test/stats_test.go index f4aa4d39877f..8d6587ec5ad6 100644 --- a/test/stats_test.go +++ b/test/stats_test.go @@ -20,7 +20,9 @@ package test import ( "context" + "fmt" "net" + "sync" "testing" "google.golang.org/grpc" @@ -32,15 +34,15 @@ import ( ) // TestPeerForClientStatsHandler configures a stats handler that -// verifies that peer is sent for OutPayload, InPayload, End -// stats handlers. +// verifies that peer is sent all stats handler callouts instead +// of Begin and PickerUpdated. func (s) TestPeerForClientStatsHandler(t *testing.T) { - statsHandler := &peerStatsHandler{} + psh := &peerStatsHandler{} - // Define expected stats callouts and whether a peer object should be populated. + // Stats callouts & peer object population. // Note: - // * Begin stats don't have peer information as the RPC begins before peer resolution. - // * PickerUpdated stats don't have peer information as the picker operates without transport-level knowledge. + // * Begin stats lack peer info (RPC starts pre-resolution). + // * PickerUpdated: no peer info (picker lacks transport details). expectedCallouts := map[stats.RPCStats]bool{ &stats.OutPayload{}: true, &stats.InHeader{}: true, @@ -74,26 +76,27 @@ func (s) TestPeerForClientStatsHandler(t *testing.T) { cc, err := grpc.NewClient( l.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials()), - grpc.WithStatsHandler(statsHandler)) + grpc.WithStatsHandler(psh)) if err != nil { t.Fatal(err) } - t.Cleanup(func() { - if err := cc.Close(); err != nil { - t.Error(err) - } - }) - + defer cc.Close() ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() client := testgrpc.NewTestServiceClient(cc) - interop.DoClientStreaming(ctx, client) + interop.DoEmptyUnaryCall(ctx, client) - if len(getUniqueRPCStats(statsHandler.Args)) < len(expectedCallouts) { - t.Errorf("Unexpected number of stats handler callouts.") + psh.mu.Lock() + pshArgs := psh.args + psh.mu.Unlock() + + // Fetch the total unique number of stats handlers having peer not nil + sc := getUniqueRPCStatsCount(pshArgs) + if sc != len(expectedCallouts) { + t.Errorf("Unexpected number of stats handler callouts. Got %v, want %v", sc, len(expectedCallouts)) } - for _, callbackArgs := range statsHandler.Args { + for _, callbackArgs := range pshArgs { expectedPeer, found := expectedCallouts[callbackArgs.rpcStats] // In case expectation is set to false and still we got the peer, // then it's good to have it. So no need to assert those conditions. @@ -106,19 +109,17 @@ func (s) TestPeerForClientStatsHandler(t *testing.T) { } // getUniqueRPCStats extracts a list of unique stats.RPCStats types from peer list of RPC callback. -func getUniqueRPCStats(args []peerStats) []stats.RPCStats { - uniqueStatsTypes := make(map[stats.RPCStats]struct{}) - +func getUniqueRPCStatsCount(args []peerStats) int { + uniqueStatsTypes := make(map[string]struct{}) for _, callbackArgs := range args { - uniqueStatsTypes[callbackArgs.rpcStats] = struct{}{} - } - - var uniqueStatsList []stats.RPCStats - for statsType := range uniqueStatsTypes { - uniqueStatsList = append(uniqueStatsList, statsType) + key := fmt.Sprintf("%T", callbackArgs.rpcStats) + if _, exists := uniqueStatsTypes[key]; exists { + continue + } + uniqueStatsTypes[fmt.Sprintf("%T", callbackArgs.rpcStats)] = struct{}{} } - return uniqueStatsList + return len(uniqueStatsTypes) } type peerStats struct { @@ -127,7 +128,8 @@ type peerStats struct { } type peerStatsHandler struct { - Args []peerStats + args []peerStats + mu sync.Mutex } func (h *peerStatsHandler) TagRPC(ctx context.Context, info *stats.RPCTagInfo) context.Context { @@ -136,7 +138,9 @@ func (h *peerStatsHandler) TagRPC(ctx context.Context, info *stats.RPCTagInfo) c func (h *peerStatsHandler) HandleRPC(ctx context.Context, rs stats.RPCStats) { p, _ := peer.FromContext(ctx) - h.Args = append(h.Args, peerStats{rs, p}) + h.mu.Lock() + defer h.mu.Unlock() + h.args = append(h.args, peerStats{rs, p}) } func (h *peerStatsHandler) TagConn(ctx context.Context, info *stats.ConnTagInfo) context.Context {