From 6868f8b459da9c341ae01cda4ca0a8829118b3c5 Mon Sep 17 00:00:00 2001 From: Lidi Zheng Date: Thu, 12 May 2022 10:38:57 -0700 Subject: [PATCH 1/6] stats: Support multiple StatsHandlers --- dialoptions.go | 2 +- internal/transport/handler_server.go | 46 +++++---- internal/transport/http2_client.go | 87 +++++++++------- internal/transport/http2_server.go | 78 +++++++------- internal/transport/transport.go | 4 +- server.go | 146 ++++++++++++++------------- stats/stats_test.go | 114 +++++++++++++++++++-- stream.go | 107 +++++++++++--------- 8 files changed, 362 insertions(+), 222 deletions(-) diff --git a/dialoptions.go b/dialoptions.go index f2f605a17c4..001bff1a6f0 100644 --- a/dialoptions.go +++ b/dialoptions.go @@ -380,7 +380,7 @@ func WithDialer(f func(string, time.Duration) (net.Conn, error)) DialOption { // all the RPCs and underlying network connections in this ClientConn. func WithStatsHandler(h stats.Handler) DialOption { return newFuncDialOption(func(o *dialOptions) { - o.copts.StatsHandler = h + o.copts.StatsHandler = append(o.copts.StatsHandler, h) }) } diff --git a/internal/transport/handler_server.go b/internal/transport/handler_server.go index 1c3459c2b4c..2dd0cd7f20d 100644 --- a/internal/transport/handler_server.go +++ b/internal/transport/handler_server.go @@ -49,7 +49,7 @@ import ( // NewServerHandlerTransport returns a ServerTransport handling gRPC // from inside an http.Handler. It requires that the http Server // supports HTTP/2. -func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats stats.Handler) (ServerTransport, error) { +func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats []stats.Handler) (ServerTransport, error) { if r.ProtoMajor != 2 { return nil, errors.New("gRPC requires HTTP/2") } @@ -138,7 +138,7 @@ type serverHandlerTransport struct { // TODO make sure this is consistent across handler_server and http2_server contentSubtype string - stats stats.Handler + stats []stats.Handler } func (ht *serverHandlerTransport) Close() { @@ -228,12 +228,14 @@ func (ht *serverHandlerTransport) WriteStatus(s *Stream, st *status.Status) erro }) if err == nil { // transport has not been closed - if ht.stats != nil { + if len(ht.stats) != 0 { // Note: The trailer fields are compressed with hpack after this call returns. // No WireLength field is set here. - ht.stats.HandleRPC(s.Context(), &stats.OutTrailer{ - Trailer: s.trailer.Copy(), - }) + for _, sh := range ht.stats { + sh.HandleRPC(s.Context(), &stats.OutTrailer{ + Trailer: s.trailer.Copy(), + }) + } } } ht.Close() @@ -314,13 +316,15 @@ func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error { }) if err == nil { - if ht.stats != nil { - // Note: The header fields are compressed with hpack after this call returns. - // No WireLength field is set here. - ht.stats.HandleRPC(s.Context(), &stats.OutHeader{ - Header: md.Copy(), - Compression: s.sendCompress, - }) + if len(ht.stats) != 0 { + for _, sh := range ht.stats { + // Note: The header fields are compressed with hpack after this call returns. + // No WireLength field is set here. 
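
For reference, each element of the new []stats.Handler slice implements the unchanged stats.Handler interface from google.golang.org/grpc/stats; every loop introduced in this series simply fans the same event out to each registered handler in order:

    // The interface every registered handler satisfies (google.golang.org/grpc/stats).
    type Handler interface {
        // TagRPC may attach per-RPC values to the context; HandleRPC then
        // receives every RPC-level event (Begin, OutHeader, InPayload, End, ...).
        TagRPC(context.Context, *RPCTagInfo) context.Context
        HandleRPC(context.Context, RPCStats)
        // TagConn and HandleConn are the connection-level counterparts.
        TagConn(context.Context, *ConnTagInfo) context.Context
        HandleConn(context.Context, ConnStats)
    }
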
+ sh.HandleRPC(s.Context(), &stats.OutHeader{ + Header: md.Copy(), + Compression: s.sendCompress, + }) + } } } return err @@ -369,14 +373,16 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), trace } ctx = metadata.NewIncomingContext(ctx, ht.headerMD) s.ctx = peer.NewContext(ctx, pr) - if ht.stats != nil { - s.ctx = ht.stats.TagRPC(s.ctx, &stats.RPCTagInfo{FullMethodName: s.method}) - inHeader := &stats.InHeader{ - FullMethod: s.method, - RemoteAddr: ht.RemoteAddr(), - Compression: s.recvCompress, + if len(ht.stats) != 0 { + for _, sh := range ht.stats { + s.ctx = sh.TagRPC(s.ctx, &stats.RPCTagInfo{FullMethodName: s.method}) + inHeader := &stats.InHeader{ + FullMethod: s.method, + RemoteAddr: ht.RemoteAddr(), + Compression: s.recvCompress, + } + sh.HandleRPC(s.ctx, inHeader) } - ht.stats.HandleRPC(s.ctx, inHeader) } s.trReader = &transportReader{ reader: &recvBufferReader{ctx: s.ctx, ctxDone: s.ctx.Done(), recv: s.buf, freeBuffer: func(*bytes.Buffer) {}}, diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go index 24ca59084b4..3b3ccd08f21 100644 --- a/internal/transport/http2_client.go +++ b/internal/transport/http2_client.go @@ -90,7 +90,7 @@ type http2Client struct { kp keepalive.ClientParameters keepaliveEnabled bool - statsHandler stats.Handler + statsHandler []stats.Handler initialWindowSize int32 @@ -341,15 +341,17 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts updateFlowControl: t.updateFlowControl, } } - if t.statsHandler != nil { - t.ctx = t.statsHandler.TagConn(t.ctx, &stats.ConnTagInfo{ - RemoteAddr: t.remoteAddr, - LocalAddr: t.localAddr, - }) - connBegin := &stats.ConnBegin{ - Client: true, + if len(t.statsHandler) != 0 { + for _, sh := range t.statsHandler { + t.ctx = sh.TagConn(t.ctx, &stats.ConnTagInfo{ + RemoteAddr: t.remoteAddr, + LocalAddr: t.localAddr, + }) + connBegin := &stats.ConnBegin{ + Client: true, + } + sh.HandleConn(t.ctx, connBegin) } - t.statsHandler.HandleConn(t.ctx, connBegin) } t.channelzID, err = channelz.RegisterNormalSocket(t, opts.ChannelzParentID, fmt.Sprintf("%s -> %s", t.localAddr, t.remoteAddr)) if err != nil { @@ -773,24 +775,27 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*Stream, return nil, &NewStreamError{Err: ErrConnClosing, AllowTransparentRetry: true} } } - if t.statsHandler != nil { + if len(t.statsHandler) != 0 { header, ok := metadata.FromOutgoingContext(ctx) if ok { header.Set("user-agent", t.userAgent) } else { header = metadata.Pairs("user-agent", t.userAgent) } - // Note: The header fields are compressed with hpack after this call returns. - // No WireLength field is set here. - outHeader := &stats.OutHeader{ - Client: true, - FullMethod: callHdr.Method, - RemoteAddr: t.remoteAddr, - LocalAddr: t.localAddr, - Compression: callHdr.SendCompress, - Header: header, + for _, sh := range t.statsHandler { + // Note: The header fields are compressed with hpack after this call returns. + // No WireLength field is set here. + // Note: Creating a new stats object to prevent pollution. 
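
The per-handler allocation called out in the note above matters because events are delivered by pointer. A sketch of the cross-handler pollution it prevents (redactingHandler is a hypothetical stats.Handler, not part of this change):

    func (h *redactingHandler) HandleRPC(ctx context.Context, s stats.RPCStats) {
        if oh, ok := s.(*stats.OutHeader); ok {
            // In-place mutation: if all handlers shared one event object,
            // every handler invoked after this one would see the redacted
            // metadata (stats.OutHeader.Header is a metadata.MD map).
            delete(oh.Header, "authorization")
        }
    }
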
+ outHeader := &stats.OutHeader{ + Client: true, + FullMethod: callHdr.Method, + RemoteAddr: t.remoteAddr, + LocalAddr: t.localAddr, + Compression: callHdr.SendCompress, + Header: header, + } + sh.HandleRPC(s.ctx, outHeader) } - t.statsHandler.HandleRPC(s.ctx, outHeader) } return s, nil } @@ -916,11 +921,13 @@ func (t *http2Client) Close(err error) { for _, s := range streams { t.closeStream(s, err, false, http2.ErrCodeNo, st, nil, false) } - if t.statsHandler != nil { - connEnd := &stats.ConnEnd{ - Client: true, + if len(t.statsHandler) != 0 { + for _, sh := range t.statsHandler { + connEnd := &stats.ConnEnd{ + Client: true, + } + sh.HandleConn(t.ctx, connEnd) } - t.statsHandler.HandleConn(t.ctx, connEnd) } } @@ -1432,22 +1439,24 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) { close(s.headerChan) } - if t.statsHandler != nil { - if isHeader { - inHeader := &stats.InHeader{ - Client: true, - WireLength: int(frame.Header().Length), - Header: metadata.MD(mdata).Copy(), - Compression: s.recvCompress, - } - t.statsHandler.HandleRPC(s.ctx, inHeader) - } else { - inTrailer := &stats.InTrailer{ - Client: true, - WireLength: int(frame.Header().Length), - Trailer: metadata.MD(mdata).Copy(), + if len(t.statsHandler) != 0 { + for _, sh := range t.statsHandler { + if isHeader { + inHeader := &stats.InHeader{ + Client: true, + WireLength: int(frame.Header().Length), + Header: metadata.MD(mdata).Copy(), + Compression: s.recvCompress, + } + sh.HandleRPC(s.ctx, inHeader) + } else { + inTrailer := &stats.InTrailer{ + Client: true, + WireLength: int(frame.Header().Length), + Trailer: metadata.MD(mdata).Copy(), + } + sh.HandleRPC(s.ctx, inTrailer) } - t.statsHandler.HandleRPC(s.ctx, inTrailer) } } diff --git a/internal/transport/http2_server.go b/internal/transport/http2_server.go index 4969102f4af..cdffc369d3a 100644 --- a/internal/transport/http2_server.go +++ b/internal/transport/http2_server.go @@ -82,7 +82,7 @@ type http2Server struct { // updates, reset streams, and various settings) to the controller. controlBuf *controlBuffer fc *trInFlow - stats stats.Handler + stats []stats.Handler // Keepalive and max-age parameters for the server. kp keepalive.ServerParameters // Keepalive enforcement policy. 
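
Both transports thread the connection context through every handler, so tag values accumulate in registration order. A condensed sketch of the fan-out used on this path (ctx, handlers, remote, and local stand in for the transport's fields):

    for _, sh := range handlers {
        // Handler i receives the context already tagged by handlers 0..i-1.
        ctx = sh.TagConn(ctx, &stats.ConnTagInfo{RemoteAddr: remote, LocalAddr: local})
        sh.HandleConn(ctx, &stats.ConnBegin{})
    }
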
@@ -272,13 +272,15 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport, updateFlowControl: t.updateFlowControl, } } - if t.stats != nil { - t.ctx = t.stats.TagConn(t.ctx, &stats.ConnTagInfo{ - RemoteAddr: t.remoteAddr, - LocalAddr: t.localAddr, - }) - connBegin := &stats.ConnBegin{} - t.stats.HandleConn(t.ctx, connBegin) + if len(t.stats) != 0 { + for _, sh := range t.stats { + t.ctx = sh.TagConn(t.ctx, &stats.ConnTagInfo{ + RemoteAddr: t.remoteAddr, + LocalAddr: t.localAddr, + }) + connBegin := &stats.ConnBegin{} + sh.HandleConn(t.ctx, connBegin) + } } t.channelzID, err = channelz.RegisterNormalSocket(t, config.ChannelzParentID, fmt.Sprintf("%s -> %s", t.remoteAddr, t.localAddr)) if err != nil { @@ -566,17 +568,19 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( t.adjustWindow(s, uint32(n)) } s.ctx = traceCtx(s.ctx, s.method) - if t.stats != nil { - s.ctx = t.stats.TagRPC(s.ctx, &stats.RPCTagInfo{FullMethodName: s.method}) - inHeader := &stats.InHeader{ - FullMethod: s.method, - RemoteAddr: t.remoteAddr, - LocalAddr: t.localAddr, - Compression: s.recvCompress, - WireLength: int(frame.Header().Length), - Header: metadata.MD(mdata).Copy(), + if len(t.stats) != 0 { + for _, sh := range t.stats { + s.ctx = sh.TagRPC(s.ctx, &stats.RPCTagInfo{FullMethodName: s.method}) + inHeader := &stats.InHeader{ + FullMethod: s.method, + RemoteAddr: t.remoteAddr, + LocalAddr: t.localAddr, + Compression: s.recvCompress, + WireLength: int(frame.Header().Length), + Header: metadata.MD(mdata).Copy(), + } + sh.HandleRPC(s.ctx, inHeader) } - t.stats.HandleRPC(s.ctx, inHeader) } s.ctxDone = s.ctx.Done() s.wq = newWriteQuota(defaultWriteQuota, s.ctxDone) @@ -992,14 +996,16 @@ func (t *http2Server) writeHeaderLocked(s *Stream) error { t.closeStream(s, true, http2.ErrCodeInternal, false) return ErrHeaderListSizeLimitViolation } - if t.stats != nil { - // Note: Headers are compressed with hpack after this call returns. - // No WireLength field is set here. - outHeader := &stats.OutHeader{ - Header: s.header.Copy(), - Compression: s.sendCompress, + if len(t.stats) != 0 { + for _, sh := range t.stats { + // Note: Headers are compressed with hpack after this call returns. + // No WireLength field is set here. + outHeader := &stats.OutHeader{ + Header: s.header.Copy(), + Compression: s.sendCompress, + } + sh.HandleRPC(s.Context(), outHeader) } - t.stats.HandleRPC(s.Context(), outHeader) } return nil } @@ -1060,12 +1066,14 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error { // Send a RST_STREAM after the trailers if the client has not already half-closed. rst := s.getState() == streamActive t.finishStream(s, rst, http2.ErrCodeNo, trailingHeader, true) - if t.stats != nil { - // Note: The trailer fields are compressed with hpack after this call returns. - // No WireLength field is set here. - t.stats.HandleRPC(s.Context(), &stats.OutTrailer{ - Trailer: s.trailer.Copy(), - }) + if len(t.stats) != 0 { + for _, sh := range t.stats { + // Note: The trailer fields are compressed with hpack after this call returns. + // No WireLength field is set here. 
+ sh.HandleRPC(s.Context(), &stats.OutTrailer{ + Trailer: s.trailer.Copy(), + }) + } } return nil } @@ -1218,9 +1226,11 @@ func (t *http2Server) Close() { for _, s := range streams { s.cancel() } - if t.stats != nil { - connEnd := &stats.ConnEnd{} - t.stats.HandleConn(t.ctx, connEnd) + if len(t.stats) != 0 { + for _, sh := range t.stats { + connEnd := &stats.ConnEnd{} + sh.HandleConn(t.ctx, connEnd) + } } } diff --git a/internal/transport/transport.go b/internal/transport/transport.go index a9ce717f160..7068ae512d8 100644 --- a/internal/transport/transport.go +++ b/internal/transport/transport.go @@ -523,7 +523,7 @@ type ServerConfig struct { ConnectionTimeout time.Duration Credentials credentials.TransportCredentials InTapHandle tap.ServerInHandle - StatsHandler stats.Handler + StatsHandler []stats.Handler KeepaliveParams keepalive.ServerParameters KeepalivePolicy keepalive.EnforcementPolicy InitialWindowSize int32 @@ -554,7 +554,7 @@ type ConnectOptions struct { // KeepaliveParams stores the keepalive parameters. KeepaliveParams keepalive.ClientParameters // StatsHandler stores the handler for stats. - StatsHandler stats.Handler + StatsHandler []stats.Handler // InitialWindowSize sets the initial window size for a stream. InitialWindowSize int32 // InitialConnWindowSize sets the initial window size for a connection. diff --git a/server.go b/server.go index 65de84b3007..841f77cca7e 100644 --- a/server.go +++ b/server.go @@ -150,7 +150,7 @@ type serverOptions struct { chainUnaryInts []UnaryServerInterceptor chainStreamInts []StreamServerInterceptor inTapHandle tap.ServerInHandle - statsHandler stats.Handler + statsHandler []stats.Handler maxConcurrentStreams uint32 maxReceiveMessageSize int maxSendMessageSize int @@ -435,7 +435,7 @@ func InTapHandle(h tap.ServerInHandle) ServerOption { // StatsHandler returns a ServerOption that sets the stats handler for the server. func StatsHandler(h stats.Handler) ServerOption { return newFuncServerOption(func(o *serverOptions) { - o.statsHandler = h + o.statsHandler = append(o.statsHandler, h) }) } @@ -1076,8 +1076,10 @@ func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Str return status.Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. 
%d)", len(payload), s.opts.maxSendMessageSize) } err = t.Write(stream, hdr, payload, opts) - if err == nil && s.opts.statsHandler != nil { - s.opts.statsHandler.HandleRPC(stream.Context(), outPayload(false, msg, data, payload, time.Now())) + if err == nil && len(s.opts.statsHandler) != 0 { + for _, sh := range s.opts.statsHandler { + sh.HandleRPC(stream.Context(), outPayload(false, msg, data, payload, time.Now())) + } } return err } @@ -1124,62 +1126,63 @@ func chainUnaryInterceptors(interceptors []UnaryServerInterceptor) UnaryServerIn } func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.Stream, info *serviceInfo, md *MethodDesc, trInfo *traceInfo) (err error) { - sh := s.opts.statsHandler - if sh != nil || trInfo != nil || channelz.IsOn() { - if channelz.IsOn() { - s.incrCallsStarted() - } - var statsBegin *stats.Begin - if sh != nil { - beginTime := time.Now() - statsBegin = &stats.Begin{ - BeginTime: beginTime, - IsClientStream: false, - IsServerStream: false, + for _, sh := range s.opts.statsHandler { + if sh != nil || trInfo != nil || channelz.IsOn() { + if channelz.IsOn() { + s.incrCallsStarted() } - sh.HandleRPC(stream.Context(), statsBegin) - } - if trInfo != nil { - trInfo.tr.LazyLog(&trInfo.firstLine, false) - } - // The deferred error handling for tracing, stats handler and channelz are - // combined into one function to reduce stack usage -- a defer takes ~56-64 - // bytes on the stack, so overflowing the stack will require a stack - // re-allocation, which is expensive. - // - // To maintain behavior similar to separate deferred statements, statements - // should be executed in the reverse order. That is, tracing first, stats - // handler second, and channelz last. Note that panics *within* defers will - // lead to different behavior, but that's an acceptable compromise; that - // would be undefined behavior territory anyway. - defer func() { - if trInfo != nil { - if err != nil && err != io.EOF { - trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true) - trInfo.tr.SetError() + var statsBegin *stats.Begin + if sh != nil { + beginTime := time.Now() + statsBegin = &stats.Begin{ + BeginTime: beginTime, + IsClientStream: false, + IsServerStream: false, } - trInfo.tr.Finish() + sh.HandleRPC(stream.Context(), statsBegin) } - - if sh != nil { - end := &stats.End{ - BeginTime: statsBegin.BeginTime, - EndTime: time.Now(), + if trInfo != nil { + trInfo.tr.LazyLog(&trInfo.firstLine, false) + } + // The deferred error handling for tracing, stats handler and channelz are + // combined into one function to reduce stack usage -- a defer takes ~56-64 + // bytes on the stack, so overflowing the stack will require a stack + // re-allocation, which is expensive. + // + // To maintain behavior similar to separate deferred statements, statements + // should be executed in the reverse order. That is, tracing first, stats + // handler second, and channelz last. Note that panics *within* defers will + // lead to different behavior, but that's an acceptable compromise; that + // would be undefined behavior territory anyway. 
+ defer func() { + if trInfo != nil { + if err != nil && err != io.EOF { + trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true) + trInfo.tr.SetError() + } + trInfo.tr.Finish() } - if err != nil && err != io.EOF { - end.Error = toRPCErr(err) + + if sh != nil { + end := &stats.End{ + BeginTime: statsBegin.BeginTime, + EndTime: time.Now(), + } + if err != nil && err != io.EOF { + end.Error = toRPCErr(err) + } + sh.HandleRPC(stream.Context(), end) } - sh.HandleRPC(stream.Context(), end) - } - if channelz.IsOn() { - if err != nil && err != io.EOF { - s.incrCallsFailed() - } else { - s.incrCallsSucceeded() + if channelz.IsOn() { + if err != nil && err != io.EOF { + s.incrCallsFailed() + } else { + s.incrCallsSucceeded() + } } - } - }() + }() + } } binlog := binarylog.GetMethodLogger(stream.Method()) @@ -1243,7 +1246,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. } var payInfo *payloadInfo - if sh != nil || binlog != nil { + if len(s.opts.statsHandler) != 0 || binlog != nil { payInfo = &payloadInfo{} } d, err := recvAndDecompress(&parser{r: stream}, stream, dc, s.opts.maxReceiveMessageSize, payInfo, decomp) @@ -1260,14 +1263,16 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. if err := s.getCodec(stream.ContentSubtype()).Unmarshal(d, v); err != nil { return status.Errorf(codes.Internal, "grpc: error unmarshalling request: %v", err) } - if sh != nil { - sh.HandleRPC(stream.Context(), &stats.InPayload{ - RecvTime: time.Now(), - Payload: v, - WireLength: payInfo.wireLength + headerLen, - Data: d, - Length: len(d), - }) + if len(s.opts.statsHandler) != 0 { + for _, sh := range s.opts.statsHandler { + sh.HandleRPC(stream.Context(), &stats.InPayload{ + RecvTime: time.Now(), + Payload: v, + WireLength: payInfo.wireLength + headerLen, + Data: d, + Length: len(d), + }) + } } if binlog != nil { binlog.Log(&binarylog.ClientMessage{ @@ -1418,16 +1423,17 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp if channelz.IsOn() { s.incrCallsStarted() } - sh := s.opts.statsHandler var statsBegin *stats.Begin - if sh != nil { + if len(s.opts.statsHandler) != 0 { beginTime := time.Now() statsBegin = &stats.Begin{ BeginTime: beginTime, IsClientStream: sd.ClientStreams, IsServerStream: sd.ServerStreams, } - sh.HandleRPC(stream.Context(), statsBegin) + for _, sh := range s.opts.statsHandler { + sh.HandleRPC(stream.Context(), statsBegin) + } } ctx := NewContextWithServerTransportStream(stream.Context(), stream) ss := &serverStream{ @@ -1439,10 +1445,10 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp maxReceiveMessageSize: s.opts.maxReceiveMessageSize, maxSendMessageSize: s.opts.maxSendMessageSize, trInfo: trInfo, - statsHandler: sh, + statsHandler: s.opts.statsHandler, } - if sh != nil || trInfo != nil || channelz.IsOn() { + if len(s.opts.statsHandler) != 0 || trInfo != nil || channelz.IsOn() { // See comment in processUnaryRPC on defers. 
defer func() { if trInfo != nil { @@ -1456,7 +1462,7 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp ss.mu.Unlock() } - if sh != nil { + if len(s.opts.statsHandler) != 0 { end := &stats.End{ BeginTime: statsBegin.BeginTime, EndTime: time.Now(), @@ -1464,7 +1470,9 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp if err != nil && err != io.EOF { end.Error = toRPCErr(err) } - sh.HandleRPC(stream.Context(), end) + for _, sh := range s.opts.statsHandler { + sh.HandleRPC(stream.Context(), end) + } } if channelz.IsOn() { diff --git a/stats/stats_test.go b/stats/stats_test.go index 1b08568b906..8537386f646 100644 --- a/stats/stats_test.go +++ b/stats/stats_test.go @@ -178,8 +178,8 @@ func (s *testServer) StreamingOutputCall(in *testpb.StreamingOutputCallRequest, type test struct { t *testing.T compress string - clientStatsHandler stats.Handler - serverStatsHandler stats.Handler + clientStatsHandler []stats.Handler + serverStatsHandler []stats.Handler testServer testgrpc.TestServiceServer // nil means none // srv and srvAddr are set once startServer is called. @@ -204,7 +204,7 @@ type testConfig struct { // newTest returns a new test using the provided testing.T and // environment. It is returned with default values. Tests should // modify it before calling its startServer and clientConn methods. -func newTest(t *testing.T, tc *testConfig, ch stats.Handler, sh stats.Handler) *test { +func newTest(t *testing.T, tc *testConfig, ch []stats.Handler, sh []stats.Handler) *test { te := &test{ t: t, compress: tc.compress, @@ -229,8 +229,10 @@ func (te *test) startServer(ts testgrpc.TestServiceServer) { grpc.RPCDecompressor(grpc.NewGZIPDecompressor()), ) } - if te.serverStatsHandler != nil { - opts = append(opts, grpc.StatsHandler(te.serverStatsHandler)) + if len(te.serverStatsHandler) != 0 { + for _, sh := range te.serverStatsHandler { + opts = append(opts, grpc.StatsHandler(sh)) + } } s := grpc.NewServer(opts...) 
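
This append loop is what turns the repeated option into multiple active handlers. The user-facing equivalent looks like the following sketch, where latencyHandler and metricsHandler are placeholders for any two stats.Handler implementations:

    s := grpc.NewServer(
        grpc.StatsHandler(latencyHandler), // each option appends to the slice
        grpc.StatsHandler(metricsHandler), // both now receive every conn/RPC event
    )
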
te.srv = s @@ -257,8 +259,10 @@ func (te *test) clientConn() *grpc.ClientConn { grpc.WithDecompressor(grpc.NewGZIPDecompressor()), ) } - if te.clientStatsHandler != nil { - opts = append(opts, grpc.WithStatsHandler(te.clientStatsHandler)) + if len(te.clientStatsHandler) != 0 { + for _, sh := range te.clientStatsHandler { + opts = append(opts, grpc.WithStatsHandler(sh)) + } } var err error @@ -846,7 +850,7 @@ func checkServerStats(t *testing.T, got []*gotData, expect *expectedData, checkF func testServerStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs []func(t *testing.T, d *gotData, e *expectedData)) { h := &statshandler{} - te := newTest(t, tc, nil, h) + te := newTest(t, tc, nil, []stats.Handler{h}) te.startServer(&testServer{}) defer te.tearDown() @@ -1146,7 +1150,7 @@ func checkClientStats(t *testing.T, got []*gotData, expect *expectedData, checkF func testClientStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs map[int]*checkFuncWithCount) { h := &statshandler{} - te := newTest(t, tc, h, nil) + te := newTest(t, tc, []stats.Handler{h}, nil) te.startServer(&testServer{}) defer te.tearDown() @@ -1375,3 +1379,95 @@ func (s) TestTrace(t *testing.T) { t.Errorf("OutgoingTrace(%v) = %v; want nil", ctx, tr) } } + +func (s) TestMultipleClientStatsHandler(t *testing.T) { + h := &statshandler{} + tc := &testConfig{compress: ""} + te := newTest(t, tc, []stats.Handler{h, h}, nil) + te.startServer(&testServer{}) + defer te.tearDown() + + cc := &rpcConfig{success: false, failfast: false, callType: unaryRPC} + _, _, err := te.doUnaryCall(cc) + if cc.success != (err == nil) { + t.Fatalf("cc.success: %v, got error: %v", cc.success, err) + } + te.cc.Close() + te.srv.GracefulStop() // Wait for the server to stop. + + for start := time.Now(); time.Since(start) < defaultTestTimeout; { + h.mu.Lock() + if _, ok := h.gotRPC[len(h.gotRPC)-1].s.(*stats.End); ok { + h.mu.Unlock() + break + } + h.mu.Unlock() + time.Sleep(10 * time.Millisecond) + } + + for start := time.Now(); time.Since(start) < defaultTestTimeout; { + h.mu.Lock() + if _, ok := h.gotConn[len(h.gotConn)-1].s.(*stats.ConnEnd); ok { + h.mu.Unlock() + break + } + h.mu.Unlock() + time.Sleep(10 * time.Millisecond) + } + + // Each RPC generates 6 stats events on the client-side, times 2 StatsHandler + if len(h.gotRPC) != 12 { + t.Fatalf("h.gotRPC: unexpected amount of RPCStats: %v != %v", len(h.gotRPC), 12) + } + + // Each connection generates 4 conn events on the client-side, times 2 StatsHandler + if len(h.gotConn) != 4 { + t.Fatalf("h.gotConn: unexpected amount of ConnStats: %v != %v", len(h.gotConn), 4) + } +} + +func (s) TestMultipleServerStatsHandler(t *testing.T) { + h := &statshandler{} + tc := &testConfig{compress: ""} + te := newTest(t, tc, nil, []stats.Handler{h, h}) + te.startServer(&testServer{}) + defer te.tearDown() + + cc := &rpcConfig{success: false, failfast: false, callType: unaryRPC} + _, _, err := te.doUnaryCall(cc) + if cc.success != (err == nil) { + t.Fatalf("cc.success: %v, got error: %v", cc.success, err) + } + te.cc.Close() + te.srv.GracefulStop() // Wait for the server to stop. 
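
Note that the same statshandler instance h is registered twice ([]stats.Handler{h, h}), so every event is recorded in h.gotRPC and h.gotConn once per registration; the polling below waits until the final *stats.End and *stats.ConnEnd have been delivered before asserting on the doubled counts.
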
+ + for start := time.Now(); time.Since(start) < defaultTestTimeout; { + h.mu.Lock() + if _, ok := h.gotRPC[len(h.gotRPC)-1].s.(*stats.End); ok { + h.mu.Unlock() + break + } + h.mu.Unlock() + time.Sleep(10 * time.Millisecond) + } + + for start := time.Now(); time.Since(start) < defaultTestTimeout; { + h.mu.Lock() + if _, ok := h.gotConn[len(h.gotConn)-1].s.(*stats.ConnEnd); ok { + h.mu.Unlock() + break + } + h.mu.Unlock() + time.Sleep(10 * time.Millisecond) + } + + // Each RPC generates 6 stats events on the server-side, times 2 StatsHandler + if len(h.gotRPC) != 12 { + t.Fatalf("h.gotRPC: unexpected amount of RPCStats: %v != %v", len(h.gotRPC), 12) + } + + // Each connection generates 4 conn events on the server-side, times 2 StatsHandler + if len(h.gotConn) != 4 { + t.Fatalf("h.gotConn: unexpected amount of ConnStats: %v != %v", len(h.gotConn), 4) + } +} diff --git a/stream.go b/stream.go index 236fc17ec3c..5fef7f2e332 100644 --- a/stream.go +++ b/stream.go @@ -374,20 +374,21 @@ func (cs *clientStream) newAttemptLocked(isTransparent bool) (*csAttempt, error) ctx := newContextWithRPCInfo(cs.ctx, cs.callInfo.failFast, cs.callInfo.codec, cs.cp, cs.comp) method := cs.callHdr.Method - sh := cs.cc.dopts.copts.StatsHandler var beginTime time.Time - if sh != nil { - ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method, FailFast: cs.callInfo.failFast}) - beginTime = time.Now() - begin := &stats.Begin{ - Client: true, - BeginTime: beginTime, - FailFast: cs.callInfo.failFast, - IsClientStream: cs.desc.ClientStreams, - IsServerStream: cs.desc.ServerStreams, - IsTransparentRetryAttempt: isTransparent, + if len(cs.cc.dopts.copts.StatsHandler) != 0 { + for _, sh := range cs.cc.dopts.copts.StatsHandler { + ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method, FailFast: cs.callInfo.failFast}) + beginTime = time.Now() + begin := &stats.Begin{ + Client: true, + BeginTime: beginTime, + FailFast: cs.callInfo.failFast, + IsClientStream: cs.desc.ClientStreams, + IsServerStream: cs.desc.ServerStreams, + IsTransparentRetryAttempt: isTransparent, + } + sh.HandleRPC(ctx, begin) } - sh.HandleRPC(ctx, begin) } var trInfo *traceInfo @@ -418,7 +419,7 @@ func (cs *clientStream) newAttemptLocked(isTransparent bool) (*csAttempt, error) beginTime: beginTime, cs: cs, dc: cs.cc.dopts.dc, - statsHandler: sh, + statsHandler: cs.cc.dopts.copts.StatsHandler, trInfo: trInfo, }, nil } @@ -536,7 +537,7 @@ type csAttempt struct { // and cleared when the finish method is called. 
trInfo *traceInfo - statsHandler stats.Handler + statsHandler []stats.Handler beginTime time.Time // set for newStream errors that may be transparently retried @@ -960,8 +961,10 @@ func (a *csAttempt) sendMsg(m interface{}, hdr, payld, data []byte) error { } return io.EOF } - if a.statsHandler != nil { - a.statsHandler.HandleRPC(a.ctx, outPayload(true, m, data, payld, time.Now())) + if len(a.statsHandler) != 0 { + for _, sh := range a.statsHandler { + sh.HandleRPC(a.ctx, outPayload(true, m, data, payld, time.Now())) + } } if channelz.IsOn() { a.t.IncrMsgSent() @@ -971,7 +974,7 @@ func (a *csAttempt) sendMsg(m interface{}, hdr, payld, data []byte) error { func (a *csAttempt) recvMsg(m interface{}, payInfo *payloadInfo) (err error) { cs := a.cs - if a.statsHandler != nil && payInfo == nil { + if len(a.statsHandler) != 0 && payInfo == nil { payInfo = &payloadInfo{} } @@ -1008,16 +1011,18 @@ func (a *csAttempt) recvMsg(m interface{}, payInfo *payloadInfo) (err error) { } a.mu.Unlock() } - if a.statsHandler != nil { - a.statsHandler.HandleRPC(a.ctx, &stats.InPayload{ - Client: true, - RecvTime: time.Now(), - Payload: m, - // TODO truncate large payload. - Data: payInfo.uncompressedBytes, - WireLength: payInfo.wireLength + headerLen, - Length: len(payInfo.uncompressedBytes), - }) + if len(a.statsHandler) != 0 { + for _, sh := range a.statsHandler { + sh.HandleRPC(a.ctx, &stats.InPayload{ + Client: true, + RecvTime: time.Now(), + Payload: m, + // TODO truncate large payload. + Data: payInfo.uncompressedBytes, + WireLength: payInfo.wireLength + headerLen, + Length: len(payInfo.uncompressedBytes), + }) + } } if channelz.IsOn() { a.t.IncrMsgRecv() @@ -1068,15 +1073,17 @@ func (a *csAttempt) finish(err error) { ServerLoad: balancerload.Parse(tr), }) } - if a.statsHandler != nil { - end := &stats.End{ - Client: true, - BeginTime: a.beginTime, - EndTime: time.Now(), - Trailer: tr, - Error: err, + if len(a.statsHandler) != 0 { + for _, sh := range a.statsHandler { + end := &stats.End{ + Client: true, + BeginTime: a.beginTime, + EndTime: time.Now(), + Trailer: tr, + Error: err, + } + sh.HandleRPC(a.ctx, end) } - a.statsHandler.HandleRPC(a.ctx, end) } if a.trInfo != nil && a.trInfo.tr != nil { if err == nil { @@ -1445,7 +1452,7 @@ type serverStream struct { maxSendMessageSize int trInfo *traceInfo - statsHandler stats.Handler + statsHandler []stats.Handler binlog binarylog.MethodLogger // serverHeaderBinlogged indicates whether server header has been logged. It @@ -1555,8 +1562,10 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) { Message: data, }) } - if ss.statsHandler != nil { - ss.statsHandler.HandleRPC(ss.s.Context(), outPayload(false, m, data, payload, time.Now())) + if len(ss.statsHandler) != 0 { + for _, sh := range ss.statsHandler { + sh.HandleRPC(ss.s.Context(), outPayload(false, m, data, payload, time.Now())) + } } return nil } @@ -1590,7 +1599,7 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) { } }() var payInfo *payloadInfo - if ss.statsHandler != nil || ss.binlog != nil { + if len(ss.statsHandler) != 0 || ss.binlog != nil { payInfo = &payloadInfo{} } if err := recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxReceiveMessageSize, payInfo, ss.decomp); err != nil { @@ -1605,15 +1614,17 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) { } return toRPCErr(err) } - if ss.statsHandler != nil { - ss.statsHandler.HandleRPC(ss.s.Context(), &stats.InPayload{ - RecvTime: time.Now(), - Payload: m, - // TODO truncate large payload. 
- Data: payInfo.uncompressedBytes, - WireLength: payInfo.wireLength + headerLen, - Length: len(payInfo.uncompressedBytes), - }) + if len(ss.statsHandler) != 0 { + for _, sh := range ss.statsHandler { + sh.HandleRPC(ss.s.Context(), &stats.InPayload{ + RecvTime: time.Now(), + Payload: m, + // TODO truncate large payload. + Data: payInfo.uncompressedBytes, + WireLength: payInfo.wireLength + headerLen, + Length: len(payInfo.uncompressedBytes), + }) + } } if ss.binlog != nil { ss.binlog.Log(&binarylog.ClientMessage{ From 423d741688a1d72a1570a7df9a9717384afd47ce Mon Sep 17 00:00:00 2001 From: Lidi Zheng Date: Thu, 12 May 2022 11:33:41 -0700 Subject: [PATCH 2/6] Fix a side-impact to channelz --- server.go | 73 +++++++++++++++++++++++++++++-------------------------- 1 file changed, 38 insertions(+), 35 deletions(-) diff --git a/server.go b/server.go index 841f77cca7e..e9023b4bb0b 100644 --- a/server.go +++ b/server.go @@ -1126,13 +1126,14 @@ func chainUnaryInterceptors(interceptors []UnaryServerInterceptor) UnaryServerIn } func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.Stream, info *serviceInfo, md *MethodDesc, trInfo *traceInfo) (err error) { - for _, sh := range s.opts.statsHandler { - if sh != nil || trInfo != nil || channelz.IsOn() { - if channelz.IsOn() { - s.incrCallsStarted() - } - var statsBegin *stats.Begin - if sh != nil { + + if len(s.opts.statsHandler) != 0 || trInfo != nil || channelz.IsOn() { + if channelz.IsOn() { + s.incrCallsStarted() + } + var statsBegin *stats.Begin + if len(s.opts.statsHandler) != 0 { + for _, sh := range s.opts.statsHandler { beginTime := time.Now() statsBegin = &stats.Begin{ BeginTime: beginTime, @@ -1141,29 +1142,31 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. } sh.HandleRPC(stream.Context(), statsBegin) } + } + if trInfo != nil { + trInfo.tr.LazyLog(&trInfo.firstLine, false) + } + // The deferred error handling for tracing, stats handler and channelz are + // combined into one function to reduce stack usage -- a defer takes ~56-64 + // bytes on the stack, so overflowing the stack will require a stack + // re-allocation, which is expensive. + // + // To maintain behavior similar to separate deferred statements, statements + // should be executed in the reverse order. That is, tracing first, stats + // handler second, and channelz last. Note that panics *within* defers will + // lead to different behavior, but that's an acceptable compromise; that + // would be undefined behavior territory anyway. + defer func() { if trInfo != nil { - trInfo.tr.LazyLog(&trInfo.firstLine, false) - } - // The deferred error handling for tracing, stats handler and channelz are - // combined into one function to reduce stack usage -- a defer takes ~56-64 - // bytes on the stack, so overflowing the stack will require a stack - // re-allocation, which is expensive. - // - // To maintain behavior similar to separate deferred statements, statements - // should be executed in the reverse order. That is, tracing first, stats - // handler second, and channelz last. Note that panics *within* defers will - // lead to different behavior, but that's an acceptable compromise; that - // would be undefined behavior territory anyway. 
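
The restructuring below undoes a subtle regression from patch 1, which moved the whole prologue inside the handler loop: with zero handlers registered the loop body never executed, silently skipping the channelz call counters and trace cleanup as well. Schematically:

    for _, sh := range s.opts.statsHandler { // empty slice → zero iterations
        if sh != nil || trInfo != nil || channelz.IsOn() {
            // channelz/trace/stats setup and the combined defer lived here,
            // so none of it ran without at least one stats handler.
        }
    }
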
- defer func() { - if trInfo != nil { - if err != nil && err != io.EOF { - trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true) - trInfo.tr.SetError() - } - trInfo.tr.Finish() + if err != nil && err != io.EOF { + trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true) + trInfo.tr.SetError() } + trInfo.tr.Finish() + } - if sh != nil { + if len(s.opts.statsHandler) != 0 { + for _, sh := range s.opts.statsHandler { end := &stats.End{ BeginTime: statsBegin.BeginTime, EndTime: time.Now(), @@ -1173,16 +1176,16 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. } sh.HandleRPC(stream.Context(), end) } + } - if channelz.IsOn() { - if err != nil && err != io.EOF { - s.incrCallsFailed() - } else { - s.incrCallsSucceeded() - } + if channelz.IsOn() { + if err != nil && err != io.EOF { + s.incrCallsFailed() + } else { + s.incrCallsSucceeded() } - }() - } + } + }() } binlog := binarylog.GetMethodLogger(stream.Method()) From e78e2113ead3d71b1dc6006196d7ed602e2e7184 Mon Sep 17 00:00:00 2001 From: Lidi Zheng Date: Wed, 1 Jun 2022 15:58:13 -0700 Subject: [PATCH 3/6] Adopt reviewer's suggestions: 1. Pluralize the variable names for slices 2. Remove the unnecessary ifs 3. Simplify the unit test --- dialoptions.go | 2 +- internal/transport/handler_server.go | 46 +++++++------- internal/transport/http2_client.go | 68 ++++++++++----------- internal/transport/http2_server.go | 78 +++++++++++------------- internal/transport/transport.go | 6 +- server.go | 80 ++++++++++++------------- stats/stats_test.go | 70 ++++------------------ stream.go | 89 +++++++++++++--------------- 8 files changed, 180 insertions(+), 259 deletions(-) diff --git a/dialoptions.go b/dialoptions.go index 001bff1a6f0..2c7d6b05c19 100644 --- a/dialoptions.go +++ b/dialoptions.go @@ -380,7 +380,7 @@ func WithDialer(f func(string, time.Duration) (net.Conn, error)) DialOption { // all the RPCs and underlying network connections in this ClientConn. func WithStatsHandler(h stats.Handler) DialOption { return newFuncDialOption(func(o *dialOptions) { - o.copts.StatsHandler = append(o.copts.StatsHandler, h) + o.copts.StatsHandlers = append(o.copts.StatsHandlers, h) }) } diff --git a/internal/transport/handler_server.go b/internal/transport/handler_server.go index 2dd0cd7f20d..090120925bb 100644 --- a/internal/transport/handler_server.go +++ b/internal/transport/handler_server.go @@ -228,14 +228,12 @@ func (ht *serverHandlerTransport) WriteStatus(s *Stream, st *status.Status) erro }) if err == nil { // transport has not been closed - if len(ht.stats) != 0 { - // Note: The trailer fields are compressed with hpack after this call returns. - // No WireLength field is set here. - for _, sh := range ht.stats { - sh.HandleRPC(s.Context(), &stats.OutTrailer{ - Trailer: s.trailer.Copy(), - }) - } + // Note: The trailer fields are compressed with hpack after this call returns. + // No WireLength field is set here. + for _, sh := range ht.stats { + sh.HandleRPC(s.Context(), &stats.OutTrailer{ + Trailer: s.trailer.Copy(), + }) } } ht.Close() @@ -316,15 +314,13 @@ func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error { }) if err == nil { - if len(ht.stats) != 0 { - for _, sh := range ht.stats { - // Note: The header fields are compressed with hpack after this call returns. - // No WireLength field is set here. 
- sh.HandleRPC(s.Context(), &stats.OutHeader{ - Header: md.Copy(), - Compression: s.sendCompress, - }) - } + for _, sh := range ht.stats { + // Note: The header fields are compressed with hpack after this call returns. + // No WireLength field is set here. + sh.HandleRPC(s.Context(), &stats.OutHeader{ + Header: md.Copy(), + Compression: s.sendCompress, + }) } } return err @@ -373,16 +369,14 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), trace } ctx = metadata.NewIncomingContext(ctx, ht.headerMD) s.ctx = peer.NewContext(ctx, pr) - if len(ht.stats) != 0 { - for _, sh := range ht.stats { - s.ctx = sh.TagRPC(s.ctx, &stats.RPCTagInfo{FullMethodName: s.method}) - inHeader := &stats.InHeader{ - FullMethod: s.method, - RemoteAddr: ht.RemoteAddr(), - Compression: s.recvCompress, - } - sh.HandleRPC(s.ctx, inHeader) + for _, sh := range ht.stats { + s.ctx = sh.TagRPC(s.ctx, &stats.RPCTagInfo{FullMethodName: s.method}) + inHeader := &stats.InHeader{ + FullMethod: s.method, + RemoteAddr: ht.RemoteAddr(), + Compression: s.recvCompress, } + sh.HandleRPC(s.ctx, inHeader) } s.trReader = &transportReader{ reader: &recvBufferReader{ctx: s.ctx, ctxDone: s.ctx.Done(), recv: s.buf, freeBuffer: func(*bytes.Buffer) {}}, diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go index 3b3ccd08f21..be371c6e0f7 100644 --- a/internal/transport/http2_client.go +++ b/internal/transport/http2_client.go @@ -90,7 +90,7 @@ type http2Client struct { kp keepalive.ClientParameters keepaliveEnabled bool - statsHandler []stats.Handler + statsHandlers []stats.Handler initialWindowSize int32 @@ -311,7 +311,7 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts isSecure: isSecure, perRPCCreds: perRPCCreds, kp: kp, - statsHandler: opts.StatsHandler, + statsHandlers: opts.StatsHandlers, initialWindowSize: initialWindowSize, onPrefaceReceipt: onPrefaceReceipt, nextID: 1, @@ -341,17 +341,15 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts updateFlowControl: t.updateFlowControl, } } - if len(t.statsHandler) != 0 { - for _, sh := range t.statsHandler { - t.ctx = sh.TagConn(t.ctx, &stats.ConnTagInfo{ - RemoteAddr: t.remoteAddr, - LocalAddr: t.localAddr, - }) - connBegin := &stats.ConnBegin{ - Client: true, - } - sh.HandleConn(t.ctx, connBegin) + for _, sh := range t.statsHandlers { + t.ctx = sh.TagConn(t.ctx, &stats.ConnTagInfo{ + RemoteAddr: t.remoteAddr, + LocalAddr: t.localAddr, + }) + connBegin := &stats.ConnBegin{ + Client: true, } + sh.HandleConn(t.ctx, connBegin) } t.channelzID, err = channelz.RegisterNormalSocket(t, opts.ChannelzParentID, fmt.Sprintf("%s -> %s", t.localAddr, t.remoteAddr)) if err != nil { @@ -775,14 +773,14 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*Stream, return nil, &NewStreamError{Err: ErrConnClosing, AllowTransparentRetry: true} } } - if len(t.statsHandler) != 0 { + if len(t.statsHandlers) != 0 { header, ok := metadata.FromOutgoingContext(ctx) if ok { header.Set("user-agent", t.userAgent) } else { header = metadata.Pairs("user-agent", t.userAgent) } - for _, sh := range t.statsHandler { + for _, sh := range t.statsHandlers { // Note: The header fields are compressed with hpack after this call returns. // No WireLength field is set here. // Note: Creating a new stats object to prevent pollution. 
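
A minimal consumer for the events emitted on this path, as a sketch (the obs handler type and its log output are illustrative only):

    func (h *obs) HandleRPC(ctx context.Context, s stats.RPCStats) {
        switch ev := s.(type) {
        case *stats.OutHeader:
            log.Printf("sent headers: %v", ev.Header)
        case *stats.InTrailer:
            log.Printf("received trailers, %d bytes on the wire", ev.WireLength)
        }
    }
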
@@ -921,13 +919,11 @@ func (t *http2Client) Close(err error) { for _, s := range streams { t.closeStream(s, err, false, http2.ErrCodeNo, st, nil, false) } - if len(t.statsHandler) != 0 { - for _, sh := range t.statsHandler { - connEnd := &stats.ConnEnd{ - Client: true, - } - sh.HandleConn(t.ctx, connEnd) + for _, sh := range t.statsHandlers { + connEnd := &stats.ConnEnd{ + Client: true, } + sh.HandleConn(t.ctx, connEnd) } } @@ -1439,24 +1435,22 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) { close(s.headerChan) } - if len(t.statsHandler) != 0 { - for _, sh := range t.statsHandler { - if isHeader { - inHeader := &stats.InHeader{ - Client: true, - WireLength: int(frame.Header().Length), - Header: metadata.MD(mdata).Copy(), - Compression: s.recvCompress, - } - sh.HandleRPC(s.ctx, inHeader) - } else { - inTrailer := &stats.InTrailer{ - Client: true, - WireLength: int(frame.Header().Length), - Trailer: metadata.MD(mdata).Copy(), - } - sh.HandleRPC(s.ctx, inTrailer) + for _, sh := range t.statsHandlers { + if isHeader { + inHeader := &stats.InHeader{ + Client: true, + WireLength: int(frame.Header().Length), + Header: metadata.MD(mdata).Copy(), + Compression: s.recvCompress, + } + sh.HandleRPC(s.ctx, inHeader) + } else { + inTrailer := &stats.InTrailer{ + Client: true, + WireLength: int(frame.Header().Length), + Trailer: metadata.MD(mdata).Copy(), } + sh.HandleRPC(s.ctx, inTrailer) } } diff --git a/internal/transport/http2_server.go b/internal/transport/http2_server.go index cdffc369d3a..d6d3cb01b8c 100644 --- a/internal/transport/http2_server.go +++ b/internal/transport/http2_server.go @@ -257,7 +257,7 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport, fc: &trInFlow{limit: uint32(icwz)}, state: reachable, activeStreams: make(map[uint32]*Stream), - stats: config.StatsHandler, + stats: config.StatsHandlers, kp: kp, idle: time.Now(), kep: kep, @@ -272,15 +272,13 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport, updateFlowControl: t.updateFlowControl, } } - if len(t.stats) != 0 { - for _, sh := range t.stats { - t.ctx = sh.TagConn(t.ctx, &stats.ConnTagInfo{ - RemoteAddr: t.remoteAddr, - LocalAddr: t.localAddr, - }) - connBegin := &stats.ConnBegin{} - sh.HandleConn(t.ctx, connBegin) - } + for _, sh := range t.stats { + t.ctx = sh.TagConn(t.ctx, &stats.ConnTagInfo{ + RemoteAddr: t.remoteAddr, + LocalAddr: t.localAddr, + }) + connBegin := &stats.ConnBegin{} + sh.HandleConn(t.ctx, connBegin) } t.channelzID, err = channelz.RegisterNormalSocket(t, config.ChannelzParentID, fmt.Sprintf("%s -> %s", t.remoteAddr, t.localAddr)) if err != nil { @@ -568,19 +566,17 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( t.adjustWindow(s, uint32(n)) } s.ctx = traceCtx(s.ctx, s.method) - if len(t.stats) != 0 { - for _, sh := range t.stats { - s.ctx = sh.TagRPC(s.ctx, &stats.RPCTagInfo{FullMethodName: s.method}) - inHeader := &stats.InHeader{ - FullMethod: s.method, - RemoteAddr: t.remoteAddr, - LocalAddr: t.localAddr, - Compression: s.recvCompress, - WireLength: int(frame.Header().Length), - Header: metadata.MD(mdata).Copy(), - } - sh.HandleRPC(s.ctx, inHeader) + for _, sh := range t.stats { + s.ctx = sh.TagRPC(s.ctx, &stats.RPCTagInfo{FullMethodName: s.method}) + inHeader := &stats.InHeader{ + FullMethod: s.method, + RemoteAddr: t.remoteAddr, + LocalAddr: t.localAddr, + Compression: s.recvCompress, + WireLength: int(frame.Header().Length), + Header: metadata.MD(mdata).Copy(), } + 
sh.HandleRPC(s.ctx, inHeader) } s.ctxDone = s.ctx.Done() s.wq = newWriteQuota(defaultWriteQuota, s.ctxDone) @@ -996,16 +992,14 @@ func (t *http2Server) writeHeaderLocked(s *Stream) error { t.closeStream(s, true, http2.ErrCodeInternal, false) return ErrHeaderListSizeLimitViolation } - if len(t.stats) != 0 { - for _, sh := range t.stats { - // Note: Headers are compressed with hpack after this call returns. - // No WireLength field is set here. - outHeader := &stats.OutHeader{ - Header: s.header.Copy(), - Compression: s.sendCompress, - } - sh.HandleRPC(s.Context(), outHeader) + for _, sh := range t.stats { + // Note: Headers are compressed with hpack after this call returns. + // No WireLength field is set here. + outHeader := &stats.OutHeader{ + Header: s.header.Copy(), + Compression: s.sendCompress, } + sh.HandleRPC(s.Context(), outHeader) } return nil } @@ -1066,14 +1060,12 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error { // Send a RST_STREAM after the trailers if the client has not already half-closed. rst := s.getState() == streamActive t.finishStream(s, rst, http2.ErrCodeNo, trailingHeader, true) - if len(t.stats) != 0 { - for _, sh := range t.stats { - // Note: The trailer fields are compressed with hpack after this call returns. - // No WireLength field is set here. - sh.HandleRPC(s.Context(), &stats.OutTrailer{ - Trailer: s.trailer.Copy(), - }) - } + for _, sh := range t.stats { + // Note: The trailer fields are compressed with hpack after this call returns. + // No WireLength field is set here. + sh.HandleRPC(s.Context(), &stats.OutTrailer{ + Trailer: s.trailer.Copy(), + }) } return nil } @@ -1226,11 +1218,9 @@ func (t *http2Server) Close() { for _, s := range streams { s.cancel() } - if len(t.stats) != 0 { - for _, sh := range t.stats { - connEnd := &stats.ConnEnd{} - sh.HandleConn(t.ctx, connEnd) - } + for _, sh := range t.stats { + connEnd := &stats.ConnEnd{} + sh.HandleConn(t.ctx, connEnd) } } diff --git a/internal/transport/transport.go b/internal/transport/transport.go index 7068ae512d8..6c3ba851594 100644 --- a/internal/transport/transport.go +++ b/internal/transport/transport.go @@ -523,7 +523,7 @@ type ServerConfig struct { ConnectionTimeout time.Duration Credentials credentials.TransportCredentials InTapHandle tap.ServerInHandle - StatsHandler []stats.Handler + StatsHandlers []stats.Handler KeepaliveParams keepalive.ServerParameters KeepalivePolicy keepalive.EnforcementPolicy InitialWindowSize int32 @@ -553,8 +553,8 @@ type ConnectOptions struct { CredsBundle credentials.Bundle // KeepaliveParams stores the keepalive parameters. KeepaliveParams keepalive.ClientParameters - // StatsHandler stores the handler for stats. - StatsHandler []stats.Handler + // StatsHandlers stores the handler for stats. + StatsHandlers []stats.Handler // InitialWindowSize sets the initial window size for a stream. InitialWindowSize int32 // InitialConnWindowSize sets the initial window size for a connection. diff --git a/server.go b/server.go index e9023b4bb0b..fa0c0910adf 100644 --- a/server.go +++ b/server.go @@ -150,7 +150,7 @@ type serverOptions struct { chainUnaryInts []UnaryServerInterceptor chainStreamInts []StreamServerInterceptor inTapHandle tap.ServerInHandle - statsHandler []stats.Handler + statsHandlers []stats.Handler maxConcurrentStreams uint32 maxReceiveMessageSize int maxSendMessageSize int @@ -435,7 +435,7 @@ func InTapHandle(h tap.ServerInHandle) ServerOption { // StatsHandler returns a ServerOption that sets the stats handler for the server. 
func StatsHandler(h stats.Handler) ServerOption { return newFuncServerOption(func(o *serverOptions) { - o.statsHandler = append(o.statsHandler, h) + o.statsHandlers = append(o.statsHandlers, h) }) } @@ -867,7 +867,7 @@ func (s *Server) newHTTP2Transport(c net.Conn) transport.ServerTransport { ConnectionTimeout: s.opts.connectionTimeout, Credentials: s.opts.creds, InTapHandle: s.opts.inTapHandle, - StatsHandler: s.opts.statsHandler, + StatsHandlers: s.opts.statsHandlers, KeepaliveParams: s.opts.keepaliveParams, KeepalivePolicy: s.opts.keepalivePolicy, InitialWindowSize: s.opts.initialWindowSize, @@ -963,7 +963,7 @@ var _ http.Handler = (*Server)(nil) // Notice: This API is EXPERIMENTAL and may be changed or removed in a // later release. func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { - st, err := transport.NewServerHandlerTransport(w, r, s.opts.statsHandler) + st, err := transport.NewServerHandlerTransport(w, r, s.opts.statsHandlers) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return @@ -1076,8 +1076,8 @@ func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Str return status.Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", len(payload), s.opts.maxSendMessageSize) } err = t.Write(stream, hdr, payload, opts) - if err == nil && len(s.opts.statsHandler) != 0 { - for _, sh := range s.opts.statsHandler { + if err == nil { + for _, sh := range s.opts.statsHandlers { sh.HandleRPC(stream.Context(), outPayload(false, msg, data, payload, time.Now())) } } @@ -1127,21 +1127,19 @@ func chainUnaryInterceptors(interceptors []UnaryServerInterceptor) UnaryServerIn func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.Stream, info *serviceInfo, md *MethodDesc, trInfo *traceInfo) (err error) { - if len(s.opts.statsHandler) != 0 || trInfo != nil || channelz.IsOn() { + if len(s.opts.statsHandlers) != 0 || trInfo != nil || channelz.IsOn() { if channelz.IsOn() { s.incrCallsStarted() } var statsBegin *stats.Begin - if len(s.opts.statsHandler) != 0 { - for _, sh := range s.opts.statsHandler { - beginTime := time.Now() - statsBegin = &stats.Begin{ - BeginTime: beginTime, - IsClientStream: false, - IsServerStream: false, - } - sh.HandleRPC(stream.Context(), statsBegin) + for _, sh := range s.opts.statsHandlers { + beginTime := time.Now() + statsBegin = &stats.Begin{ + BeginTime: beginTime, + IsClientStream: false, + IsServerStream: false, } + sh.HandleRPC(stream.Context(), statsBegin) } if trInfo != nil { trInfo.tr.LazyLog(&trInfo.firstLine, false) @@ -1165,17 +1163,15 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. trInfo.tr.Finish() } - if len(s.opts.statsHandler) != 0 { - for _, sh := range s.opts.statsHandler { - end := &stats.End{ - BeginTime: statsBegin.BeginTime, - EndTime: time.Now(), - } - if err != nil && err != io.EOF { - end.Error = toRPCErr(err) - } - sh.HandleRPC(stream.Context(), end) + for _, sh := range s.opts.statsHandlers { + end := &stats.End{ + BeginTime: statsBegin.BeginTime, + EndTime: time.Now(), + } + if err != nil && err != io.EOF { + end.Error = toRPCErr(err) } + sh.HandleRPC(stream.Context(), end) } if channelz.IsOn() { @@ -1249,7 +1245,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. 
} var payInfo *payloadInfo - if len(s.opts.statsHandler) != 0 || binlog != nil { + if len(s.opts.statsHandlers) != 0 || binlog != nil { payInfo = &payloadInfo{} } d, err := recvAndDecompress(&parser{r: stream}, stream, dc, s.opts.maxReceiveMessageSize, payInfo, decomp) @@ -1266,16 +1262,14 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. if err := s.getCodec(stream.ContentSubtype()).Unmarshal(d, v); err != nil { return status.Errorf(codes.Internal, "grpc: error unmarshalling request: %v", err) } - if len(s.opts.statsHandler) != 0 { - for _, sh := range s.opts.statsHandler { - sh.HandleRPC(stream.Context(), &stats.InPayload{ - RecvTime: time.Now(), - Payload: v, - WireLength: payInfo.wireLength + headerLen, - Data: d, - Length: len(d), - }) - } + for _, sh := range s.opts.statsHandlers { + sh.HandleRPC(stream.Context(), &stats.InPayload{ + RecvTime: time.Now(), + Payload: v, + WireLength: payInfo.wireLength + headerLen, + Data: d, + Length: len(d), + }) } if binlog != nil { binlog.Log(&binarylog.ClientMessage{ @@ -1427,14 +1421,14 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp s.incrCallsStarted() } var statsBegin *stats.Begin - if len(s.opts.statsHandler) != 0 { + if len(s.opts.statsHandlers) != 0 { beginTime := time.Now() statsBegin = &stats.Begin{ BeginTime: beginTime, IsClientStream: sd.ClientStreams, IsServerStream: sd.ServerStreams, } - for _, sh := range s.opts.statsHandler { + for _, sh := range s.opts.statsHandlers { sh.HandleRPC(stream.Context(), statsBegin) } } @@ -1448,10 +1442,10 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp maxReceiveMessageSize: s.opts.maxReceiveMessageSize, maxSendMessageSize: s.opts.maxSendMessageSize, trInfo: trInfo, - statsHandler: s.opts.statsHandler, + statsHandler: s.opts.statsHandlers, } - if len(s.opts.statsHandler) != 0 || trInfo != nil || channelz.IsOn() { + if len(s.opts.statsHandlers) != 0 || trInfo != nil || channelz.IsOn() { // See comment in processUnaryRPC on defers. defer func() { if trInfo != nil { @@ -1465,7 +1459,7 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp ss.mu.Unlock() } - if len(s.opts.statsHandler) != 0 { + if len(s.opts.statsHandlers) != 0 { end := &stats.End{ BeginTime: statsBegin.BeginTime, EndTime: time.Now(), @@ -1473,7 +1467,7 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp if err != nil && err != io.EOF { end.Error = toRPCErr(err) } - for _, sh := range s.opts.statsHandler { + for _, sh := range s.opts.statsHandlers { sh.HandleRPC(stream.Context(), end) } } diff --git a/stats/stats_test.go b/stats/stats_test.go index 8537386f646..03fcb102d84 100644 --- a/stats/stats_test.go +++ b/stats/stats_test.go @@ -176,10 +176,10 @@ func (s *testServer) StreamingOutputCall(in *testpb.StreamingOutputCallRequest, // func, modified as needed, and then started with its startServer method. // It should be cleaned up with the tearDown method. type test struct { - t *testing.T - compress string - clientStatsHandler []stats.Handler - serverStatsHandler []stats.Handler + t *testing.T + compress string + clientStatsHandlers []stats.Handler + serverStatsHandlers []stats.Handler testServer testgrpc.TestServiceServer // nil means none // srv and srvAddr are set once startServer is called. @@ -204,12 +204,12 @@ type testConfig struct { // newTest returns a new test using the provided testing.T and // environment. It is returned with default values. 
Tests should // modify it before calling its startServer and clientConn methods. -func newTest(t *testing.T, tc *testConfig, ch []stats.Handler, sh []stats.Handler) *test { +func newTest(t *testing.T, tc *testConfig, chs []stats.Handler, shs []stats.Handler) *test { te := &test{ - t: t, - compress: tc.compress, - clientStatsHandler: ch, - serverStatsHandler: sh, + t: t, + compress: tc.compress, + clientStatsHandlers: chs, + serverStatsHandlers: shs, } return te } @@ -229,10 +229,8 @@ func (te *test) startServer(ts testgrpc.TestServiceServer) { grpc.RPCDecompressor(grpc.NewGZIPDecompressor()), ) } - if len(te.serverStatsHandler) != 0 { - for _, sh := range te.serverStatsHandler { - opts = append(opts, grpc.StatsHandler(sh)) - } + for _, sh := range te.serverStatsHandlers { + opts = append(opts, grpc.StatsHandler(sh)) } s := grpc.NewServer(opts...) te.srv = s @@ -259,10 +257,8 @@ func (te *test) clientConn() *grpc.ClientConn { grpc.WithDecompressor(grpc.NewGZIPDecompressor()), ) } - if len(te.clientStatsHandler) != 0 { - for _, sh := range te.clientStatsHandler { - opts = append(opts, grpc.WithStatsHandler(sh)) - } + for _, sh := range te.clientStatsHandlers { + opts = append(opts, grpc.WithStatsHandler(sh)) } var err error @@ -1395,26 +1391,6 @@ func (s) TestMultipleClientStatsHandler(t *testing.T) { te.cc.Close() te.srv.GracefulStop() // Wait for the server to stop. - for start := time.Now(); time.Since(start) < defaultTestTimeout; { - h.mu.Lock() - if _, ok := h.gotRPC[len(h.gotRPC)-1].s.(*stats.End); ok { - h.mu.Unlock() - break - } - h.mu.Unlock() - time.Sleep(10 * time.Millisecond) - } - - for start := time.Now(); time.Since(start) < defaultTestTimeout; { - h.mu.Lock() - if _, ok := h.gotConn[len(h.gotConn)-1].s.(*stats.ConnEnd); ok { - h.mu.Unlock() - break - } - h.mu.Unlock() - time.Sleep(10 * time.Millisecond) - } - // Each RPC generates 6 stats events on the client-side, times 2 StatsHandler if len(h.gotRPC) != 12 { t.Fatalf("h.gotRPC: unexpected amount of RPCStats: %v != %v", len(h.gotRPC), 12) @@ -1441,26 +1417,6 @@ func (s) TestMultipleServerStatsHandler(t *testing.T) { te.cc.Close() te.srv.GracefulStop() // Wait for the server to stop. 
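
The polling loops removed below were patch 1's way of waiting for the trailing stats events; (*grpc.Server).GracefulStop blocks until all pending RPCs have finished, which this simplification relies on instead.
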
-	for start := time.Now(); time.Since(start) < defaultTestTimeout; {
-		h.mu.Lock()
-		if _, ok := h.gotRPC[len(h.gotRPC)-1].s.(*stats.End); ok {
-			h.mu.Unlock()
-			break
-		}
-		h.mu.Unlock()
-		time.Sleep(10 * time.Millisecond)
-	}
-
-	for start := time.Now(); time.Since(start) < defaultTestTimeout; {
-		h.mu.Lock()
-		if _, ok := h.gotConn[len(h.gotConn)-1].s.(*stats.ConnEnd); ok {
-			h.mu.Unlock()
-			break
-		}
-		h.mu.Unlock()
-		time.Sleep(10 * time.Millisecond)
-	}
-
 	// Each RPC generates 6 stats events on the server-side, times 2 StatsHandler
 	if len(h.gotRPC) != 12 {
 		t.Fatalf("h.gotRPC: unexpected amount of RPCStats: %v != %v", len(h.gotRPC), 12)
diff --git a/stream.go b/stream.go
index 5fef7f2e332..6d82e0d7cca 100644
--- a/stream.go
+++ b/stream.go
@@ -375,20 +375,19 @@ func (cs *clientStream) newAttemptLocked(isTransparent bool) (*csAttempt, error)
 	ctx := newContextWithRPCInfo(cs.ctx, cs.callInfo.failFast, cs.callInfo.codec, cs.cp, cs.comp)
 	method := cs.callHdr.Method
 	var beginTime time.Time
-	if len(cs.cc.dopts.copts.StatsHandler) != 0 {
-		for _, sh := range cs.cc.dopts.copts.StatsHandler {
-			ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method, FailFast: cs.callInfo.failFast})
-			beginTime = time.Now()
-			begin := &stats.Begin{
-				Client:                    true,
-				BeginTime:                 beginTime,
-				FailFast:                  cs.callInfo.failFast,
-				IsClientStream:            cs.desc.ClientStreams,
-				IsServerStream:            cs.desc.ServerStreams,
-				IsTransparentRetryAttempt: isTransparent,
-			}
-			sh.HandleRPC(ctx, begin)
+	shs := cs.cc.dopts.copts.StatsHandlers
+	for _, sh := range shs {
+		ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method, FailFast: cs.callInfo.failFast})
+		beginTime = time.Now()
+		begin := &stats.Begin{
+			Client:                    true,
+			BeginTime:                 beginTime,
+			FailFast:                  cs.callInfo.failFast,
+			IsClientStream:            cs.desc.ClientStreams,
+			IsServerStream:            cs.desc.ServerStreams,
+			IsTransparentRetryAttempt: isTransparent,
 		}
+		sh.HandleRPC(ctx, begin)
 	}

 	var trInfo *traceInfo
@@ -415,12 +414,12 @@ func (cs *clientStream) newAttemptLocked(isTransparent bool) (*csAttempt, error)
 	}

 	return &csAttempt{
-		ctx:          ctx,
-		beginTime:    beginTime,
-		cs:           cs,
-		dc:           cs.cc.dopts.dc,
-		statsHandler: cs.cc.dopts.copts.StatsHandler,
-		trInfo:       trInfo,
+		ctx:           ctx,
+		beginTime:     beginTime,
+		cs:            cs,
+		dc:            cs.cc.dopts.dc,
+		statsHandlers: shs,
+		trInfo:        trInfo,
 	}, nil
 }

@@ -537,8 +536,8 @@ type csAttempt struct {
 	// and cleared when the finish method is called.
 	trInfo *traceInfo

-	statsHandler []stats.Handler
-	beginTime    time.Time
+	statsHandlers []stats.Handler
+	beginTime     time.Time

 	// set for newStream errors that may be transparently retried
 	allowTransparentRetry bool
@@ -961,10 +960,8 @@ func (a *csAttempt) sendMsg(m interface{}, hdr, payld, data []byte) error {
 		}
 		return io.EOF
 	}
-	if len(a.statsHandler) != 0 {
-		for _, sh := range a.statsHandler {
-			sh.HandleRPC(a.ctx, outPayload(true, m, data, payld, time.Now()))
-		}
+	for _, sh := range a.statsHandlers {
+		sh.HandleRPC(a.ctx, outPayload(true, m, data, payld, time.Now()))
 	}
 	if channelz.IsOn() {
 		a.t.IncrMsgSent()
@@ -974,7 +971,7 @@ func (a *csAttempt) sendMsg(m interface{}, hdr, payld, data []byte) error {

 func (a *csAttempt) recvMsg(m interface{}, payInfo *payloadInfo) (err error) {
 	cs := a.cs
-	if len(a.statsHandler) != 0 && payInfo == nil {
+	if len(a.statsHandlers) != 0 && payInfo == nil {
 		payInfo = &payloadInfo{}
 	}
@@ -1011,18 +1008,16 @@ func (a *csAttempt) recvMsg(m interface{}, payInfo *payloadInfo) (err error) {
 		}
 		a.mu.Unlock()
 	}
-	if len(a.statsHandler) != 0 {
-		for _, sh := range a.statsHandler {
-			sh.HandleRPC(a.ctx, &stats.InPayload{
-				Client:   true,
-				RecvTime: time.Now(),
-				Payload:  m,
-				// TODO truncate large payload.
-				Data:       payInfo.uncompressedBytes,
-				WireLength: payInfo.wireLength + headerLen,
-				Length:     len(payInfo.uncompressedBytes),
-			})
-		}
+	for _, sh := range a.statsHandlers {
+		sh.HandleRPC(a.ctx, &stats.InPayload{
+			Client:   true,
+			RecvTime: time.Now(),
+			Payload:  m,
+			// TODO truncate large payload.
+			Data:       payInfo.uncompressedBytes,
+			WireLength: payInfo.wireLength + headerLen,
+			Length:     len(payInfo.uncompressedBytes),
+		})
 	}
 	if channelz.IsOn() {
 		a.t.IncrMsgRecv()
@@ -1073,17 +1068,15 @@ func (a *csAttempt) finish(err error) {
 			ServerLoad: balancerload.Parse(tr),
 		})
 	}
-	if len(a.statsHandler) != 0 {
-		for _, sh := range a.statsHandler {
-			end := &stats.End{
-				Client:    true,
-				BeginTime: a.beginTime,
-				EndTime:   time.Now(),
-				Trailer:   tr,
-				Error:     err,
-			}
-			sh.HandleRPC(a.ctx, end)
+	for _, sh := range a.statsHandlers {
+		end := &stats.End{
+			Client:    true,
+			BeginTime: a.beginTime,
+			EndTime:   time.Now(),
+			Trailer:   tr,
+			Error:     err,
 		}
+		sh.HandleRPC(a.ctx, end)
 	}
 	if a.trInfo != nil && a.trInfo.tr != nil {
 		if err == nil {

From b7e4cf22711be333c13f80ddf955e7b8d4df0fc8 Mon Sep 17 00:00:00 2001
From: Lidi Zheng
Date: Thu, 2 Jun 2022 10:22:16 -0700
Subject: [PATCH 4/6] Stabilize stats_test

---
 stats/stats_test.go | 40 ++++++++++++++++++++++++++++++++++++++++
 1 file changed, 40 insertions(+)

diff --git a/stats/stats_test.go b/stats/stats_test.go
index 03fcb102d84..9a1a6c11253 100644
--- a/stats/stats_test.go
+++ b/stats/stats_test.go
@@ -1391,6 +1391,26 @@ func (s) TestMultipleClientStatsHandler(t *testing.T) {
 	te.cc.Close()
 	te.srv.GracefulStop() // Wait for the server to stop.

+	for start := time.Now(); time.Since(start) < defaultTestTimeout; {
+		h.mu.Lock()
+		if _, ok := h.gotRPC[len(h.gotRPC)-1].s.(*stats.End); ok {
+			h.mu.Unlock()
+			break
+		}
+		h.mu.Unlock()
+		time.Sleep(10 * time.Millisecond)
+	}
+
+	for start := time.Now(); time.Since(start) < defaultTestTimeout; {
+		h.mu.Lock()
+		if _, ok := h.gotConn[len(h.gotConn)-1].s.(*stats.ConnEnd); ok {
+			h.mu.Unlock()
+			break
+		}
+		h.mu.Unlock()
+		time.Sleep(10 * time.Millisecond)
+	}
+
 	// Each RPC generates 6 stats events on the client-side, times 2 StatsHandler
 	if len(h.gotRPC) != 12 {
 		t.Fatalf("h.gotRPC: unexpected amount of RPCStats: %v != %v", len(h.gotRPC), 12)
@@ -1417,6 +1437,26 @@ func (s) TestMultipleServerStatsHandler(t *testing.T) {
 	te.cc.Close()
 	te.srv.GracefulStop() // Wait for the server to stop.

+	for start := time.Now(); time.Since(start) < defaultTestTimeout; {
+		h.mu.Lock()
+		if _, ok := h.gotRPC[len(h.gotRPC)-1].s.(*stats.End); ok {
+			h.mu.Unlock()
+			break
+		}
+		h.mu.Unlock()
+		time.Sleep(10 * time.Millisecond)
+	}
+
+	for start := time.Now(); time.Since(start) < defaultTestTimeout; {
+		h.mu.Lock()
+		if _, ok := h.gotConn[len(h.gotConn)-1].s.(*stats.ConnEnd); ok {
+			h.mu.Unlock()
+			break
+		}
+		h.mu.Unlock()
+		time.Sleep(10 * time.Millisecond)
+	}
+
 	// Each RPC generates 6 stats events on the server-side, times 2 StatsHandler
 	if len(h.gotRPC) != 12 {
 		t.Fatalf("h.gotRPC: unexpected amount of RPCStats: %v != %v", len(h.gotRPC), 12)

From 821f1f92bbf444ba3a487e8c5c0bb1f5564d721f Mon Sep 17 00:00:00 2001
From: Lidi Zheng
Date: Thu, 2 Jun 2022 16:11:28 -0700
Subject: [PATCH 5/6] Add the "shs" variable

---
 server.go | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/server.go b/server.go
index fa0c0910adf..6ccd3bf6fa7 100644
--- a/server.go
+++ b/server.go
@@ -1126,13 +1126,13 @@ func chainUnaryInterceptors(interceptors []UnaryServerInterceptor) UnaryServerIn
 }

 func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.Stream, info *serviceInfo, md *MethodDesc, trInfo *traceInfo) (err error) {
-
-	if len(s.opts.statsHandlers) != 0 || trInfo != nil || channelz.IsOn() {
+	shs := s.opts.statsHandlers
+	if len(shs) != 0 || trInfo != nil || channelz.IsOn() {
 		if channelz.IsOn() {
 			s.incrCallsStarted()
 		}
 		var statsBegin *stats.Begin
-		for _, sh := range s.opts.statsHandlers {
+		for _, sh := range shs {
 			beginTime := time.Now()
 			statsBegin = &stats.Begin{
 				BeginTime: beginTime,
@@ -1163,7 +1163,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
 			trInfo.tr.Finish()
 		}

-		for _, sh := range s.opts.statsHandlers {
+		for _, sh := range shs {
 			end := &stats.End{
 				BeginTime: statsBegin.BeginTime,
 				EndTime:   time.Now(),
@@ -1245,7 +1245,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
 	}

 	var payInfo *payloadInfo
-	if len(s.opts.statsHandlers) != 0 || binlog != nil {
+	if len(shs) != 0 || binlog != nil {
 		payInfo = &payloadInfo{}
 	}
 	d, err := recvAndDecompress(&parser{r: stream}, stream, dc, s.opts.maxReceiveMessageSize, payInfo, decomp)
@@ -1262,7 +1262,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
 	if err := s.getCodec(stream.ContentSubtype()).Unmarshal(d, v); err != nil {
 		return status.Errorf(codes.Internal, "grpc: error unmarshalling request: %v", err)
 	}
-	for _, sh := range s.opts.statsHandlers {
+	for _, sh := range shs {
 		sh.HandleRPC(stream.Context(), &stats.InPayload{
 			RecvTime: time.Now(),
 			Payload:  v,

From 3ceb045a7cb554781be589212d25e07b6d00f4c3 Mon Sep 17 00:00:00 2001
From: Lidi Zheng
Date: Thu, 2 Jun 2022 16:20:50 -0700
Subject: [PATCH 6/6] Add one more shs variable

---
 server.go | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/server.go b/server.go
index 6ccd3bf6fa7..ef6ecf31519 100644
--- a/server.go
+++ b/server.go
@@ -1420,15 +1420,16 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
 	if channelz.IsOn() {
 		s.incrCallsStarted()
 	}
+	shs := s.opts.statsHandlers
 	var statsBegin *stats.Begin
-	if len(s.opts.statsHandlers) != 0 {
+	if len(shs) != 0 {
 		beginTime := time.Now()
 		statsBegin = &stats.Begin{
 			BeginTime:      beginTime,
 			IsClientStream: sd.ClientStreams,
 			IsServerStream: sd.ServerStreams,
 		}
-		for _, sh := range s.opts.statsHandlers {
+		for _, sh := range shs {
 			sh.HandleRPC(stream.Context(), statsBegin)
 		}
 	}
@@ -1442,10 +1443,10 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
 		maxReceiveMessageSize: s.opts.maxReceiveMessageSize,
 		maxSendMessageSize:    s.opts.maxSendMessageSize,
 		trInfo:                trInfo,
-		statsHandler:          s.opts.statsHandlers,
+		statsHandler:          shs,
 	}

-	if len(s.opts.statsHandlers) != 0 || trInfo != nil || channelz.IsOn() {
+	if len(shs) != 0 || trInfo != nil || channelz.IsOn() {
 		// See comment in processUnaryRPC on defers.
 		defer func() {
 			if trInfo != nil {
@@ -1459,7 +1460,7 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
 			ss.mu.Unlock()
 		}

-		if len(s.opts.statsHandlers) != 0 {
+		if len(shs) != 0 {
 			end := &stats.End{
 				BeginTime: statsBegin.BeginTime,
 				EndTime:   time.Now(),
@@ -1467,7 +1468,7 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
 			if err != nil && err != io.EOF {
 				end.Error = toRPCErr(err)
 			}
-			for _, sh := range s.opts.statsHandlers {
+			for _, sh := range shs {
 				sh.HandleRPC(stream.Context(), end)
 			}
 		}
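
Usage note: with this series applied, WithStatsHandler (client) and StatsHandler (server) append to a slice instead of replacing a single handler, so either option may be passed more than once. Below is a minimal sketch of wiring two handlers into a client and a server; the countingHandler type and the localhost:50051 target are illustrative only and are not part of this change.

package main

import (
	"context"
	"log"
	"sync"

	"google.golang.org/grpc"
	"google.golang.org/grpc/stats"
)

// countingHandler is a minimal stats.Handler that counts the events it sees.
type countingHandler struct {
	name string

	mu    sync.Mutex
	rpcs  int
	conns int
}

func (h *countingHandler) TagRPC(ctx context.Context, _ *stats.RPCTagInfo) context.Context {
	return ctx
}

func (h *countingHandler) HandleRPC(_ context.Context, s stats.RPCStats) {
	h.mu.Lock()
	defer h.mu.Unlock()
	h.rpcs++
	if _, ok := s.(*stats.End); ok {
		log.Printf("%s: %d RPC events so far", h.name, h.rpcs)
	}
}

func (h *countingHandler) TagConn(ctx context.Context, _ *stats.ConnTagInfo) context.Context {
	return ctx
}

func (h *countingHandler) HandleConn(_ context.Context, _ stats.ConnStats) {
	h.mu.Lock()
	defer h.mu.Unlock()
	h.conns++
}

func main() {
	metrics := &countingHandler{name: "metrics"}
	tracing := &countingHandler{name: "tracing"}

	// Client side: each WithStatsHandler call appends, so both handlers
	// receive every TagRPC/HandleRPC/TagConn/HandleConn event.
	cc, err := grpc.Dial("localhost:50051",
		grpc.WithInsecure(),
		grpc.WithStatsHandler(metrics),
		grpc.WithStatsHandler(tracing),
	)
	if err != nil {
		log.Fatalf("grpc.Dial: %v", err)
	}
	defer cc.Close()

	// Server side composes the same way with the StatsHandler server option.
	srv := grpc.NewServer(
		grpc.StatsHandler(metrics),
		grpc.StatsHandler(tracing),
	)
	_ = srv // register services and call srv.Serve(lis) as usual
}

Handlers are invoked in registration order for every event, matching the loops added throughout this series; that is also why the tests above expect 12 RPC events per side for two handlers (6 events per RPC each).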