diff --git a/dialoptions.go b/dialoptions.go index f2f605a17c4..2c7d6b05c19 100644 --- a/dialoptions.go +++ b/dialoptions.go @@ -380,7 +380,7 @@ func WithDialer(f func(string, time.Duration) (net.Conn, error)) DialOption { // all the RPCs and underlying network connections in this ClientConn. func WithStatsHandler(h stats.Handler) DialOption { return newFuncDialOption(func(o *dialOptions) { - o.copts.StatsHandler = h + o.copts.StatsHandlers = append(o.copts.StatsHandlers, h) }) } diff --git a/internal/transport/handler_server.go b/internal/transport/handler_server.go index 1c3459c2b4c..090120925bb 100644 --- a/internal/transport/handler_server.go +++ b/internal/transport/handler_server.go @@ -49,7 +49,7 @@ import ( // NewServerHandlerTransport returns a ServerTransport handling gRPC // from inside an http.Handler. It requires that the http Server // supports HTTP/2. -func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats stats.Handler) (ServerTransport, error) { +func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats []stats.Handler) (ServerTransport, error) { if r.ProtoMajor != 2 { return nil, errors.New("gRPC requires HTTP/2") } @@ -138,7 +138,7 @@ type serverHandlerTransport struct { // TODO make sure this is consistent across handler_server and http2_server contentSubtype string - stats stats.Handler + stats []stats.Handler } func (ht *serverHandlerTransport) Close() { @@ -228,10 +228,10 @@ func (ht *serverHandlerTransport) WriteStatus(s *Stream, st *status.Status) erro }) if err == nil { // transport has not been closed - if ht.stats != nil { - // Note: The trailer fields are compressed with hpack after this call returns. - // No WireLength field is set here. - ht.stats.HandleRPC(s.Context(), &stats.OutTrailer{ + // Note: The trailer fields are compressed with hpack after this call returns. + // No WireLength field is set here. + for _, sh := range ht.stats { + sh.HandleRPC(s.Context(), &stats.OutTrailer{ Trailer: s.trailer.Copy(), }) } @@ -314,10 +314,10 @@ func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error { }) if err == nil { - if ht.stats != nil { + for _, sh := range ht.stats { // Note: The header fields are compressed with hpack after this call returns. // No WireLength field is set here. - ht.stats.HandleRPC(s.Context(), &stats.OutHeader{ + sh.HandleRPC(s.Context(), &stats.OutHeader{ Header: md.Copy(), Compression: s.sendCompress, }) @@ -369,14 +369,14 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), trace } ctx = metadata.NewIncomingContext(ctx, ht.headerMD) s.ctx = peer.NewContext(ctx, pr) - if ht.stats != nil { - s.ctx = ht.stats.TagRPC(s.ctx, &stats.RPCTagInfo{FullMethodName: s.method}) + for _, sh := range ht.stats { + s.ctx = sh.TagRPC(s.ctx, &stats.RPCTagInfo{FullMethodName: s.method}) inHeader := &stats.InHeader{ FullMethod: s.method, RemoteAddr: ht.RemoteAddr(), Compression: s.recvCompress, } - ht.stats.HandleRPC(s.ctx, inHeader) + sh.HandleRPC(s.ctx, inHeader) } s.trReader = &transportReader{ reader: &recvBufferReader{ctx: s.ctx, ctxDone: s.ctx.Done(), recv: s.buf, freeBuffer: func(*bytes.Buffer) {}}, diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go index 24ca59084b4..be371c6e0f7 100644 --- a/internal/transport/http2_client.go +++ b/internal/transport/http2_client.go @@ -90,7 +90,7 @@ type http2Client struct { kp keepalive.ClientParameters keepaliveEnabled bool - statsHandler stats.Handler + statsHandlers []stats.Handler initialWindowSize int32 @@ -311,7 +311,7 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts isSecure: isSecure, perRPCCreds: perRPCCreds, kp: kp, - statsHandler: opts.StatsHandler, + statsHandlers: opts.StatsHandlers, initialWindowSize: initialWindowSize, onPrefaceReceipt: onPrefaceReceipt, nextID: 1, @@ -341,15 +341,15 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts updateFlowControl: t.updateFlowControl, } } - if t.statsHandler != nil { - t.ctx = t.statsHandler.TagConn(t.ctx, &stats.ConnTagInfo{ + for _, sh := range t.statsHandlers { + t.ctx = sh.TagConn(t.ctx, &stats.ConnTagInfo{ RemoteAddr: t.remoteAddr, LocalAddr: t.localAddr, }) connBegin := &stats.ConnBegin{ Client: true, } - t.statsHandler.HandleConn(t.ctx, connBegin) + sh.HandleConn(t.ctx, connBegin) } t.channelzID, err = channelz.RegisterNormalSocket(t, opts.ChannelzParentID, fmt.Sprintf("%s -> %s", t.localAddr, t.remoteAddr)) if err != nil { @@ -773,24 +773,27 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*Stream, return nil, &NewStreamError{Err: ErrConnClosing, AllowTransparentRetry: true} } } - if t.statsHandler != nil { + if len(t.statsHandlers) != 0 { header, ok := metadata.FromOutgoingContext(ctx) if ok { header.Set("user-agent", t.userAgent) } else { header = metadata.Pairs("user-agent", t.userAgent) } - // Note: The header fields are compressed with hpack after this call returns. - // No WireLength field is set here. - outHeader := &stats.OutHeader{ - Client: true, - FullMethod: callHdr.Method, - RemoteAddr: t.remoteAddr, - LocalAddr: t.localAddr, - Compression: callHdr.SendCompress, - Header: header, + for _, sh := range t.statsHandlers { + // Note: The header fields are compressed with hpack after this call returns. + // No WireLength field is set here. + // Note: Creating a new stats object to prevent pollution. + outHeader := &stats.OutHeader{ + Client: true, + FullMethod: callHdr.Method, + RemoteAddr: t.remoteAddr, + LocalAddr: t.localAddr, + Compression: callHdr.SendCompress, + Header: header, + } + sh.HandleRPC(s.ctx, outHeader) } - t.statsHandler.HandleRPC(s.ctx, outHeader) } return s, nil } @@ -916,11 +919,11 @@ func (t *http2Client) Close(err error) { for _, s := range streams { t.closeStream(s, err, false, http2.ErrCodeNo, st, nil, false) } - if t.statsHandler != nil { + for _, sh := range t.statsHandlers { connEnd := &stats.ConnEnd{ Client: true, } - t.statsHandler.HandleConn(t.ctx, connEnd) + sh.HandleConn(t.ctx, connEnd) } } @@ -1432,7 +1435,7 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) { close(s.headerChan) } - if t.statsHandler != nil { + for _, sh := range t.statsHandlers { if isHeader { inHeader := &stats.InHeader{ Client: true, @@ -1440,14 +1443,14 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) { Header: metadata.MD(mdata).Copy(), Compression: s.recvCompress, } - t.statsHandler.HandleRPC(s.ctx, inHeader) + sh.HandleRPC(s.ctx, inHeader) } else { inTrailer := &stats.InTrailer{ Client: true, WireLength: int(frame.Header().Length), Trailer: metadata.MD(mdata).Copy(), } - t.statsHandler.HandleRPC(s.ctx, inTrailer) + sh.HandleRPC(s.ctx, inTrailer) } } diff --git a/internal/transport/http2_server.go b/internal/transport/http2_server.go index 4969102f4af..d6d3cb01b8c 100644 --- a/internal/transport/http2_server.go +++ b/internal/transport/http2_server.go @@ -82,7 +82,7 @@ type http2Server struct { // updates, reset streams, and various settings) to the controller. controlBuf *controlBuffer fc *trInFlow - stats stats.Handler + stats []stats.Handler // Keepalive and max-age parameters for the server. kp keepalive.ServerParameters // Keepalive enforcement policy. @@ -257,7 +257,7 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport, fc: &trInFlow{limit: uint32(icwz)}, state: reachable, activeStreams: make(map[uint32]*Stream), - stats: config.StatsHandler, + stats: config.StatsHandlers, kp: kp, idle: time.Now(), kep: kep, @@ -272,13 +272,13 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport, updateFlowControl: t.updateFlowControl, } } - if t.stats != nil { - t.ctx = t.stats.TagConn(t.ctx, &stats.ConnTagInfo{ + for _, sh := range t.stats { + t.ctx = sh.TagConn(t.ctx, &stats.ConnTagInfo{ RemoteAddr: t.remoteAddr, LocalAddr: t.localAddr, }) connBegin := &stats.ConnBegin{} - t.stats.HandleConn(t.ctx, connBegin) + sh.HandleConn(t.ctx, connBegin) } t.channelzID, err = channelz.RegisterNormalSocket(t, config.ChannelzParentID, fmt.Sprintf("%s -> %s", t.remoteAddr, t.localAddr)) if err != nil { @@ -566,8 +566,8 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( t.adjustWindow(s, uint32(n)) } s.ctx = traceCtx(s.ctx, s.method) - if t.stats != nil { - s.ctx = t.stats.TagRPC(s.ctx, &stats.RPCTagInfo{FullMethodName: s.method}) + for _, sh := range t.stats { + s.ctx = sh.TagRPC(s.ctx, &stats.RPCTagInfo{FullMethodName: s.method}) inHeader := &stats.InHeader{ FullMethod: s.method, RemoteAddr: t.remoteAddr, @@ -576,7 +576,7 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( WireLength: int(frame.Header().Length), Header: metadata.MD(mdata).Copy(), } - t.stats.HandleRPC(s.ctx, inHeader) + sh.HandleRPC(s.ctx, inHeader) } s.ctxDone = s.ctx.Done() s.wq = newWriteQuota(defaultWriteQuota, s.ctxDone) @@ -992,14 +992,14 @@ func (t *http2Server) writeHeaderLocked(s *Stream) error { t.closeStream(s, true, http2.ErrCodeInternal, false) return ErrHeaderListSizeLimitViolation } - if t.stats != nil { + for _, sh := range t.stats { // Note: Headers are compressed with hpack after this call returns. // No WireLength field is set here. outHeader := &stats.OutHeader{ Header: s.header.Copy(), Compression: s.sendCompress, } - t.stats.HandleRPC(s.Context(), outHeader) + sh.HandleRPC(s.Context(), outHeader) } return nil } @@ -1060,10 +1060,10 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error { // Send a RST_STREAM after the trailers if the client has not already half-closed. rst := s.getState() == streamActive t.finishStream(s, rst, http2.ErrCodeNo, trailingHeader, true) - if t.stats != nil { + for _, sh := range t.stats { // Note: The trailer fields are compressed with hpack after this call returns. // No WireLength field is set here. - t.stats.HandleRPC(s.Context(), &stats.OutTrailer{ + sh.HandleRPC(s.Context(), &stats.OutTrailer{ Trailer: s.trailer.Copy(), }) } @@ -1218,9 +1218,9 @@ func (t *http2Server) Close() { for _, s := range streams { s.cancel() } - if t.stats != nil { + for _, sh := range t.stats { connEnd := &stats.ConnEnd{} - t.stats.HandleConn(t.ctx, connEnd) + sh.HandleConn(t.ctx, connEnd) } } diff --git a/internal/transport/transport.go b/internal/transport/transport.go index a9ce717f160..6c3ba851594 100644 --- a/internal/transport/transport.go +++ b/internal/transport/transport.go @@ -523,7 +523,7 @@ type ServerConfig struct { ConnectionTimeout time.Duration Credentials credentials.TransportCredentials InTapHandle tap.ServerInHandle - StatsHandler stats.Handler + StatsHandlers []stats.Handler KeepaliveParams keepalive.ServerParameters KeepalivePolicy keepalive.EnforcementPolicy InitialWindowSize int32 @@ -553,8 +553,8 @@ type ConnectOptions struct { CredsBundle credentials.Bundle // KeepaliveParams stores the keepalive parameters. KeepaliveParams keepalive.ClientParameters - // StatsHandler stores the handler for stats. - StatsHandler stats.Handler + // StatsHandlers stores the handler for stats. + StatsHandlers []stats.Handler // InitialWindowSize sets the initial window size for a stream. InitialWindowSize int32 // InitialConnWindowSize sets the initial window size for a connection. diff --git a/server.go b/server.go index 65de84b3007..ef6ecf31519 100644 --- a/server.go +++ b/server.go @@ -150,7 +150,7 @@ type serverOptions struct { chainUnaryInts []UnaryServerInterceptor chainStreamInts []StreamServerInterceptor inTapHandle tap.ServerInHandle - statsHandler stats.Handler + statsHandlers []stats.Handler maxConcurrentStreams uint32 maxReceiveMessageSize int maxSendMessageSize int @@ -435,7 +435,7 @@ func InTapHandle(h tap.ServerInHandle) ServerOption { // StatsHandler returns a ServerOption that sets the stats handler for the server. func StatsHandler(h stats.Handler) ServerOption { return newFuncServerOption(func(o *serverOptions) { - o.statsHandler = h + o.statsHandlers = append(o.statsHandlers, h) }) } @@ -867,7 +867,7 @@ func (s *Server) newHTTP2Transport(c net.Conn) transport.ServerTransport { ConnectionTimeout: s.opts.connectionTimeout, Credentials: s.opts.creds, InTapHandle: s.opts.inTapHandle, - StatsHandler: s.opts.statsHandler, + StatsHandlers: s.opts.statsHandlers, KeepaliveParams: s.opts.keepaliveParams, KeepalivePolicy: s.opts.keepalivePolicy, InitialWindowSize: s.opts.initialWindowSize, @@ -963,7 +963,7 @@ var _ http.Handler = (*Server)(nil) // Notice: This API is EXPERIMENTAL and may be changed or removed in a // later release. func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { - st, err := transport.NewServerHandlerTransport(w, r, s.opts.statsHandler) + st, err := transport.NewServerHandlerTransport(w, r, s.opts.statsHandlers) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return @@ -1076,8 +1076,10 @@ func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Str return status.Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", len(payload), s.opts.maxSendMessageSize) } err = t.Write(stream, hdr, payload, opts) - if err == nil && s.opts.statsHandler != nil { - s.opts.statsHandler.HandleRPC(stream.Context(), outPayload(false, msg, data, payload, time.Now())) + if err == nil { + for _, sh := range s.opts.statsHandlers { + sh.HandleRPC(stream.Context(), outPayload(false, msg, data, payload, time.Now())) + } } return err } @@ -1124,13 +1126,13 @@ func chainUnaryInterceptors(interceptors []UnaryServerInterceptor) UnaryServerIn } func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.Stream, info *serviceInfo, md *MethodDesc, trInfo *traceInfo) (err error) { - sh := s.opts.statsHandler - if sh != nil || trInfo != nil || channelz.IsOn() { + shs := s.opts.statsHandlers + if len(shs) != 0 || trInfo != nil || channelz.IsOn() { if channelz.IsOn() { s.incrCallsStarted() } var statsBegin *stats.Begin - if sh != nil { + for _, sh := range shs { beginTime := time.Now() statsBegin = &stats.Begin{ BeginTime: beginTime, @@ -1161,7 +1163,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. trInfo.tr.Finish() } - if sh != nil { + for _, sh := range shs { end := &stats.End{ BeginTime: statsBegin.BeginTime, EndTime: time.Now(), @@ -1243,7 +1245,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. } var payInfo *payloadInfo - if sh != nil || binlog != nil { + if len(shs) != 0 || binlog != nil { payInfo = &payloadInfo{} } d, err := recvAndDecompress(&parser{r: stream}, stream, dc, s.opts.maxReceiveMessageSize, payInfo, decomp) @@ -1260,7 +1262,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. if err := s.getCodec(stream.ContentSubtype()).Unmarshal(d, v); err != nil { return status.Errorf(codes.Internal, "grpc: error unmarshalling request: %v", err) } - if sh != nil { + for _, sh := range shs { sh.HandleRPC(stream.Context(), &stats.InPayload{ RecvTime: time.Now(), Payload: v, @@ -1418,16 +1420,18 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp if channelz.IsOn() { s.incrCallsStarted() } - sh := s.opts.statsHandler + shs := s.opts.statsHandlers var statsBegin *stats.Begin - if sh != nil { + if len(shs) != 0 { beginTime := time.Now() statsBegin = &stats.Begin{ BeginTime: beginTime, IsClientStream: sd.ClientStreams, IsServerStream: sd.ServerStreams, } - sh.HandleRPC(stream.Context(), statsBegin) + for _, sh := range shs { + sh.HandleRPC(stream.Context(), statsBegin) + } } ctx := NewContextWithServerTransportStream(stream.Context(), stream) ss := &serverStream{ @@ -1439,10 +1443,10 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp maxReceiveMessageSize: s.opts.maxReceiveMessageSize, maxSendMessageSize: s.opts.maxSendMessageSize, trInfo: trInfo, - statsHandler: sh, + statsHandler: shs, } - if sh != nil || trInfo != nil || channelz.IsOn() { + if len(shs) != 0 || trInfo != nil || channelz.IsOn() { // See comment in processUnaryRPC on defers. defer func() { if trInfo != nil { @@ -1456,7 +1460,7 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp ss.mu.Unlock() } - if sh != nil { + if len(shs) != 0 { end := &stats.End{ BeginTime: statsBegin.BeginTime, EndTime: time.Now(), @@ -1464,7 +1468,9 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp if err != nil && err != io.EOF { end.Error = toRPCErr(err) } - sh.HandleRPC(stream.Context(), end) + for _, sh := range shs { + sh.HandleRPC(stream.Context(), end) + } } if channelz.IsOn() { diff --git a/stats/stats_test.go b/stats/stats_test.go index 1b08568b906..9a1a6c11253 100644 --- a/stats/stats_test.go +++ b/stats/stats_test.go @@ -176,10 +176,10 @@ func (s *testServer) StreamingOutputCall(in *testpb.StreamingOutputCallRequest, // func, modified as needed, and then started with its startServer method. // It should be cleaned up with the tearDown method. type test struct { - t *testing.T - compress string - clientStatsHandler stats.Handler - serverStatsHandler stats.Handler + t *testing.T + compress string + clientStatsHandlers []stats.Handler + serverStatsHandlers []stats.Handler testServer testgrpc.TestServiceServer // nil means none // srv and srvAddr are set once startServer is called. @@ -204,12 +204,12 @@ type testConfig struct { // newTest returns a new test using the provided testing.T and // environment. It is returned with default values. Tests should // modify it before calling its startServer and clientConn methods. -func newTest(t *testing.T, tc *testConfig, ch stats.Handler, sh stats.Handler) *test { +func newTest(t *testing.T, tc *testConfig, chs []stats.Handler, shs []stats.Handler) *test { te := &test{ - t: t, - compress: tc.compress, - clientStatsHandler: ch, - serverStatsHandler: sh, + t: t, + compress: tc.compress, + clientStatsHandlers: chs, + serverStatsHandlers: shs, } return te } @@ -229,8 +229,8 @@ func (te *test) startServer(ts testgrpc.TestServiceServer) { grpc.RPCDecompressor(grpc.NewGZIPDecompressor()), ) } - if te.serverStatsHandler != nil { - opts = append(opts, grpc.StatsHandler(te.serverStatsHandler)) + for _, sh := range te.serverStatsHandlers { + opts = append(opts, grpc.StatsHandler(sh)) } s := grpc.NewServer(opts...) te.srv = s @@ -257,8 +257,8 @@ func (te *test) clientConn() *grpc.ClientConn { grpc.WithDecompressor(grpc.NewGZIPDecompressor()), ) } - if te.clientStatsHandler != nil { - opts = append(opts, grpc.WithStatsHandler(te.clientStatsHandler)) + for _, sh := range te.clientStatsHandlers { + opts = append(opts, grpc.WithStatsHandler(sh)) } var err error @@ -846,7 +846,7 @@ func checkServerStats(t *testing.T, got []*gotData, expect *expectedData, checkF func testServerStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs []func(t *testing.T, d *gotData, e *expectedData)) { h := &statshandler{} - te := newTest(t, tc, nil, h) + te := newTest(t, tc, nil, []stats.Handler{h}) te.startServer(&testServer{}) defer te.tearDown() @@ -1146,7 +1146,7 @@ func checkClientStats(t *testing.T, got []*gotData, expect *expectedData, checkF func testClientStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs map[int]*checkFuncWithCount) { h := &statshandler{} - te := newTest(t, tc, h, nil) + te := newTest(t, tc, []stats.Handler{h}, nil) te.startServer(&testServer{}) defer te.tearDown() @@ -1375,3 +1375,95 @@ func (s) TestTrace(t *testing.T) { t.Errorf("OutgoingTrace(%v) = %v; want nil", ctx, tr) } } + +func (s) TestMultipleClientStatsHandler(t *testing.T) { + h := &statshandler{} + tc := &testConfig{compress: ""} + te := newTest(t, tc, []stats.Handler{h, h}, nil) + te.startServer(&testServer{}) + defer te.tearDown() + + cc := &rpcConfig{success: false, failfast: false, callType: unaryRPC} + _, _, err := te.doUnaryCall(cc) + if cc.success != (err == nil) { + t.Fatalf("cc.success: %v, got error: %v", cc.success, err) + } + te.cc.Close() + te.srv.GracefulStop() // Wait for the server to stop. + + for start := time.Now(); time.Since(start) < defaultTestTimeout; { + h.mu.Lock() + if _, ok := h.gotRPC[len(h.gotRPC)-1].s.(*stats.End); ok { + h.mu.Unlock() + break + } + h.mu.Unlock() + time.Sleep(10 * time.Millisecond) + } + + for start := time.Now(); time.Since(start) < defaultTestTimeout; { + h.mu.Lock() + if _, ok := h.gotConn[len(h.gotConn)-1].s.(*stats.ConnEnd); ok { + h.mu.Unlock() + break + } + h.mu.Unlock() + time.Sleep(10 * time.Millisecond) + } + + // Each RPC generates 6 stats events on the client-side, times 2 StatsHandler + if len(h.gotRPC) != 12 { + t.Fatalf("h.gotRPC: unexpected amount of RPCStats: %v != %v", len(h.gotRPC), 12) + } + + // Each connection generates 4 conn events on the client-side, times 2 StatsHandler + if len(h.gotConn) != 4 { + t.Fatalf("h.gotConn: unexpected amount of ConnStats: %v != %v", len(h.gotConn), 4) + } +} + +func (s) TestMultipleServerStatsHandler(t *testing.T) { + h := &statshandler{} + tc := &testConfig{compress: ""} + te := newTest(t, tc, nil, []stats.Handler{h, h}) + te.startServer(&testServer{}) + defer te.tearDown() + + cc := &rpcConfig{success: false, failfast: false, callType: unaryRPC} + _, _, err := te.doUnaryCall(cc) + if cc.success != (err == nil) { + t.Fatalf("cc.success: %v, got error: %v", cc.success, err) + } + te.cc.Close() + te.srv.GracefulStop() // Wait for the server to stop. + + for start := time.Now(); time.Since(start) < defaultTestTimeout; { + h.mu.Lock() + if _, ok := h.gotRPC[len(h.gotRPC)-1].s.(*stats.End); ok { + h.mu.Unlock() + break + } + h.mu.Unlock() + time.Sleep(10 * time.Millisecond) + } + + for start := time.Now(); time.Since(start) < defaultTestTimeout; { + h.mu.Lock() + if _, ok := h.gotConn[len(h.gotConn)-1].s.(*stats.ConnEnd); ok { + h.mu.Unlock() + break + } + h.mu.Unlock() + time.Sleep(10 * time.Millisecond) + } + + // Each RPC generates 6 stats events on the server-side, times 2 StatsHandler + if len(h.gotRPC) != 12 { + t.Fatalf("h.gotRPC: unexpected amount of RPCStats: %v != %v", len(h.gotRPC), 12) + } + + // Each connection generates 4 conn events on the server-side, times 2 StatsHandler + if len(h.gotConn) != 4 { + t.Fatalf("h.gotConn: unexpected amount of ConnStats: %v != %v", len(h.gotConn), 4) + } +} diff --git a/stream.go b/stream.go index 236fc17ec3c..6d82e0d7cca 100644 --- a/stream.go +++ b/stream.go @@ -374,9 +374,9 @@ func (cs *clientStream) newAttemptLocked(isTransparent bool) (*csAttempt, error) ctx := newContextWithRPCInfo(cs.ctx, cs.callInfo.failFast, cs.callInfo.codec, cs.cp, cs.comp) method := cs.callHdr.Method - sh := cs.cc.dopts.copts.StatsHandler var beginTime time.Time - if sh != nil { + shs := cs.cc.dopts.copts.StatsHandlers + for _, sh := range shs { ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method, FailFast: cs.callInfo.failFast}) beginTime = time.Now() begin := &stats.Begin{ @@ -414,12 +414,12 @@ func (cs *clientStream) newAttemptLocked(isTransparent bool) (*csAttempt, error) } return &csAttempt{ - ctx: ctx, - beginTime: beginTime, - cs: cs, - dc: cs.cc.dopts.dc, - statsHandler: sh, - trInfo: trInfo, + ctx: ctx, + beginTime: beginTime, + cs: cs, + dc: cs.cc.dopts.dc, + statsHandlers: shs, + trInfo: trInfo, }, nil } @@ -536,8 +536,8 @@ type csAttempt struct { // and cleared when the finish method is called. trInfo *traceInfo - statsHandler stats.Handler - beginTime time.Time + statsHandlers []stats.Handler + beginTime time.Time // set for newStream errors that may be transparently retried allowTransparentRetry bool @@ -960,8 +960,8 @@ func (a *csAttempt) sendMsg(m interface{}, hdr, payld, data []byte) error { } return io.EOF } - if a.statsHandler != nil { - a.statsHandler.HandleRPC(a.ctx, outPayload(true, m, data, payld, time.Now())) + for _, sh := range a.statsHandlers { + sh.HandleRPC(a.ctx, outPayload(true, m, data, payld, time.Now())) } if channelz.IsOn() { a.t.IncrMsgSent() @@ -971,7 +971,7 @@ func (a *csAttempt) sendMsg(m interface{}, hdr, payld, data []byte) error { func (a *csAttempt) recvMsg(m interface{}, payInfo *payloadInfo) (err error) { cs := a.cs - if a.statsHandler != nil && payInfo == nil { + if len(a.statsHandlers) != 0 && payInfo == nil { payInfo = &payloadInfo{} } @@ -1008,8 +1008,8 @@ func (a *csAttempt) recvMsg(m interface{}, payInfo *payloadInfo) (err error) { } a.mu.Unlock() } - if a.statsHandler != nil { - a.statsHandler.HandleRPC(a.ctx, &stats.InPayload{ + for _, sh := range a.statsHandlers { + sh.HandleRPC(a.ctx, &stats.InPayload{ Client: true, RecvTime: time.Now(), Payload: m, @@ -1068,7 +1068,7 @@ func (a *csAttempt) finish(err error) { ServerLoad: balancerload.Parse(tr), }) } - if a.statsHandler != nil { + for _, sh := range a.statsHandlers { end := &stats.End{ Client: true, BeginTime: a.beginTime, @@ -1076,7 +1076,7 @@ func (a *csAttempt) finish(err error) { Trailer: tr, Error: err, } - a.statsHandler.HandleRPC(a.ctx, end) + sh.HandleRPC(a.ctx, end) } if a.trInfo != nil && a.trInfo.tr != nil { if err == nil { @@ -1445,7 +1445,7 @@ type serverStream struct { maxSendMessageSize int trInfo *traceInfo - statsHandler stats.Handler + statsHandler []stats.Handler binlog binarylog.MethodLogger // serverHeaderBinlogged indicates whether server header has been logged. It @@ -1555,8 +1555,10 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) { Message: data, }) } - if ss.statsHandler != nil { - ss.statsHandler.HandleRPC(ss.s.Context(), outPayload(false, m, data, payload, time.Now())) + if len(ss.statsHandler) != 0 { + for _, sh := range ss.statsHandler { + sh.HandleRPC(ss.s.Context(), outPayload(false, m, data, payload, time.Now())) + } } return nil } @@ -1590,7 +1592,7 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) { } }() var payInfo *payloadInfo - if ss.statsHandler != nil || ss.binlog != nil { + if len(ss.statsHandler) != 0 || ss.binlog != nil { payInfo = &payloadInfo{} } if err := recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxReceiveMessageSize, payInfo, ss.decomp); err != nil { @@ -1605,15 +1607,17 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) { } return toRPCErr(err) } - if ss.statsHandler != nil { - ss.statsHandler.HandleRPC(ss.s.Context(), &stats.InPayload{ - RecvTime: time.Now(), - Payload: m, - // TODO truncate large payload. - Data: payInfo.uncompressedBytes, - WireLength: payInfo.wireLength + headerLen, - Length: len(payInfo.uncompressedBytes), - }) + if len(ss.statsHandler) != 0 { + for _, sh := range ss.statsHandler { + sh.HandleRPC(ss.s.Context(), &stats.InPayload{ + RecvTime: time.Now(), + Payload: m, + // TODO truncate large payload. + Data: payInfo.uncompressedBytes, + WireLength: payInfo.wireLength + headerLen, + Length: len(payInfo.uncompressedBytes), + }) + } } if ss.binlog != nil { ss.binlog.Log(&binarylog.ClientMessage{