diff --git a/default_dial_option_server_option_test.go b/default_dial_option_server_option_test.go index 3dc446f58b5..eecd6b846f2 100644 --- a/default_dial_option_server_option_test.go +++ b/default_dial_option_server_option_test.go @@ -38,7 +38,7 @@ func (s) TestAddExtraDialOptions(t *testing.T) { // Set and check the DialOptions opts := []DialOption{WithTransportCredentials(insecure.NewCredentials()), WithTransportCredentials(insecure.NewCredentials()), WithTransportCredentials(insecure.NewCredentials())} - internal.AddExtraDialOptions.(func(opt ...DialOption))(opts...) + internal.AddGlobalDialOptions.(func(opt ...DialOption))(opts...) for i, opt := range opts { if extraDialOptions[i] != opt { t.Fatalf("Unexpected extra dial option at index %d: %v != %v", i, extraDialOptions[i], opt) @@ -52,7 +52,7 @@ func (s) TestAddExtraDialOptions(t *testing.T) { cc.Close() } - internal.ClearExtraDialOptions() + internal.ClearGlobalDialOptions() if len(extraDialOptions) != 0 { t.Fatalf("Unexpected len of extraDialOptions: %d != 0", len(extraDialOptions)) } @@ -62,7 +62,7 @@ func (s) TestAddExtraServerOptions(t *testing.T) { const maxRecvSize = 998765 // Set and check the ServerOptions opts := []ServerOption{Creds(insecure.NewCredentials()), MaxRecvMsgSize(maxRecvSize)} - internal.AddExtraServerOptions.(func(opt ...ServerOption))(opts...) + internal.AddGlobalServerOptions.(func(opt ...ServerOption))(opts...) for i, opt := range opts { if extraServerOptions[i] != opt { t.Fatalf("Unexpected extra server option at index %d: %v != %v", i, extraServerOptions[i], opt) @@ -75,7 +75,7 @@ func (s) TestAddExtraServerOptions(t *testing.T) { t.Fatalf("Unexpected s.opts.maxReceiveMessageSize: %d != %d", s.opts.maxReceiveMessageSize, maxRecvSize) } - internal.ClearExtraServerOptions() + internal.ClearGlobalServerOptions() if len(extraServerOptions) != 0 { t.Fatalf("Unexpected len of extraServerOptions: %d != 0", len(extraServerOptions)) } diff --git a/dialoptions.go b/dialoptions.go index 60403bc160e..9372dc322e8 100644 --- a/dialoptions.go +++ b/dialoptions.go @@ -29,6 +29,7 @@ import ( "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/internal" internalbackoff "google.golang.org/grpc/internal/backoff" + "google.golang.org/grpc/internal/binarylog" "google.golang.org/grpc/internal/transport" "google.golang.org/grpc/keepalive" "google.golang.org/grpc/resolver" @@ -36,12 +37,13 @@ import ( ) func init() { - internal.AddExtraDialOptions = func(opt ...DialOption) { + internal.AddGlobalDialOptions = func(opt ...DialOption) { extraDialOptions = append(extraDialOptions, opt...) } - internal.ClearExtraDialOptions = func() { + internal.ClearGlobalDialOptions = func() { extraDialOptions = nil } + internal.WithBinaryLogger = withBinaryLogger } // dialOptions configure a Dial call. dialOptions are set by the DialOption @@ -61,6 +63,7 @@ type dialOptions struct { timeout time.Duration scChan <-chan ServiceConfig authority string + binaryLogger binarylog.Logger copts transport.ConnectOptions callOptions []CallOption channelzParentID *channelz.Identifier @@ -401,6 +404,14 @@ func WithStatsHandler(h stats.Handler) DialOption { }) } +// withBinaryLogger returns a DialOption that specifies the binary logger for +// this ClientConn. +func withBinaryLogger(bl binarylog.Logger) DialOption { + return newFuncDialOption(func(o *dialOptions) { + o.binaryLogger = bl + }) +} + // FailOnNonTempDialError returns a DialOption that specifies if gRPC fails on // non-temporary dial errors. If f is true, and dialer returns a non-temporary // error, gRPC will fail the connection to the network address and won't try to diff --git a/gcp/observability/opencensus.go b/gcp/observability/opencensus.go index ccaa6a98a42..1bca322cc52 100644 --- a/gcp/observability/opencensus.go +++ b/gcp/observability/opencensus.go @@ -117,8 +117,8 @@ func startOpenCensus(config *config) error { } // Only register default StatsHandlers if other things are setup correctly. - internal.AddExtraServerOptions.(func(opt ...grpc.ServerOption))(grpc.StatsHandler(&ocgrpc.ServerHandler{StartOptions: so})) - internal.AddExtraDialOptions.(func(opt ...grpc.DialOption))(grpc.WithStatsHandler(&ocgrpc.ClientHandler{StartOptions: so})) + internal.AddGlobalServerOptions.(func(opt ...grpc.ServerOption))(grpc.StatsHandler(&ocgrpc.ServerHandler{StartOptions: so})) + internal.AddGlobalDialOptions.(func(opt ...grpc.DialOption))(grpc.WithStatsHandler(&ocgrpc.ClientHandler{StartOptions: so})) logger.Infof("Enabled OpenCensus StatsHandlers for clients and servers") return nil @@ -128,8 +128,8 @@ func startOpenCensus(config *config) error { // packages if exporter was created. func stopOpenCensus() { if exporter != nil { - internal.ClearExtraDialOptions() - internal.ClearExtraServerOptions() + internal.ClearGlobalDialOptions() + internal.ClearGlobalServerOptions() // Call these unconditionally, doesn't matter if not registered, will be // a noop if not registered. diff --git a/internal/internal.go b/internal/internal.go index 87cacb5b808..fd0ee3dcaf1 100644 --- a/internal/internal.go +++ b/internal/internal.go @@ -63,24 +63,31 @@ var ( // xDS-enabled server invokes this method on a grpc.Server when a particular // listener moves to "not-serving" mode. DrainServerTransports interface{} // func(*grpc.Server, string) - // AddExtraServerOptions adds an array of ServerOption that will be + // AddGlobalServerOptions adds an array of ServerOption that will be // effective globally for newly created servers. The priority will be: 1. // user-provided; 2. this method; 3. default values. - AddExtraServerOptions interface{} // func(opt ...ServerOption) - // ClearExtraServerOptions clears the array of extra ServerOption. This + AddGlobalServerOptions interface{} // func(opt ...ServerOption) + // ClearGlobalServerOptions clears the array of extra ServerOption. This // method is useful in testing and benchmarking. - ClearExtraServerOptions func() - // AddExtraDialOptions adds an array of DialOption that will be effective + ClearGlobalServerOptions func() + // AddGlobalDialOptions adds an array of DialOption that will be effective // globally for newly created client channels. The priority will be: 1. // user-provided; 2. this method; 3. default values. - AddExtraDialOptions interface{} // func(opt ...DialOption) - // ClearExtraDialOptions clears the array of extra DialOption. This + AddGlobalDialOptions interface{} // func(opt ...DialOption) + // ClearGlobalDialOptions clears the array of extra DialOption. This // method is useful in testing and benchmarking. - ClearExtraDialOptions func() + ClearGlobalDialOptions func() // JoinServerOptions combines the server options passed as arguments into a // single server option. JoinServerOptions interface{} // func(...grpc.ServerOption) grpc.ServerOption + // WithBinaryLogger returns a DialOption that specifies the binary logger + // for a ClientConn. + WithBinaryLogger interface{} // func(binarylog.Logger) grpc.DialOption + // BinaryLogger returns a ServerOption that can set the binary logger for a + // server. + BinaryLogger interface{} // func(binarylog.Logger) grpc.ServerOption + // NewXDSResolverWithConfigForTesting creates a new xds resolver builder using // the provided xds bootstrap config instead of the global configuration from // the supported environment variables. The resolver.Builder is meant to be diff --git a/server.go b/server.go index 6ef3df67d9e..e8143492c25 100644 --- a/server.go +++ b/server.go @@ -73,12 +73,13 @@ func init() { internal.DrainServerTransports = func(srv *Server, addr string) { srv.drainServerTransports(addr) } - internal.AddExtraServerOptions = func(opt ...ServerOption) { + internal.AddGlobalServerOptions = func(opt ...ServerOption) { extraServerOptions = opt } - internal.ClearExtraServerOptions = func() { + internal.ClearGlobalServerOptions = func() { extraServerOptions = nil } + internal.BinaryLogger = binaryLogger internal.JoinServerOptions = newJoinServerOption } @@ -156,6 +157,7 @@ type serverOptions struct { streamInt StreamServerInterceptor chainUnaryInts []UnaryServerInterceptor chainStreamInts []StreamServerInterceptor + binaryLogger binarylog.Logger inTapHandle tap.ServerInHandle statsHandlers []stats.Handler maxConcurrentStreams uint32 @@ -469,6 +471,14 @@ func StatsHandler(h stats.Handler) ServerOption { }) } +// binaryLogger returns a ServerOption that can set the binary logger for the +// server. +func binaryLogger(bl binarylog.Logger) ServerOption { + return newFuncServerOption(func(o *serverOptions) { + o.binaryLogger = bl + }) +} + // UnknownServiceHandler returns a ServerOption that allows for adding a custom // unknown service handler. The provided method is a bidi-streaming RPC service // handler that will be invoked instead of returning the "unimplemented" gRPC @@ -1216,9 +1226,16 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. } }() } - - binlog := binarylog.GetMethodLogger(stream.Method()) - if binlog != nil { + var binlogs []binarylog.MethodLogger + if ml := binarylog.GetMethodLogger(stream.Method()); ml != nil { + binlogs = append(binlogs, ml) + } + if s.opts.binaryLogger != nil { + if ml := s.opts.binaryLogger.GetMethodLogger(stream.Method()); ml != nil { + binlogs = append(binlogs, ml) + } + } + if len(binlogs) != 0 { ctx := stream.Context() md, _ := metadata.FromIncomingContext(ctx) logEntry := &binarylog.ClientHeader{ @@ -1238,7 +1255,9 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. if peer, ok := peer.FromContext(ctx); ok { logEntry.PeerAddr = peer.Addr } - binlog.Log(logEntry) + for _, binlog := range binlogs { + binlog.Log(logEntry) + } } // comp and cp are used for compression. decomp and dc are used for @@ -1278,7 +1297,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. } var payInfo *payloadInfo - if len(shs) != 0 || binlog != nil { + if len(shs) != 0 || len(binlogs) != 0 { payInfo = &payloadInfo{} } d, err := recvAndDecompress(&parser{r: stream}, stream, dc, s.opts.maxReceiveMessageSize, payInfo, decomp) @@ -1304,10 +1323,13 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. Length: len(d), }) } - if binlog != nil { - binlog.Log(&binarylog.ClientMessage{ + if len(binlogs) != 0 { + cm := &binarylog.ClientMessage{ Message: d, - }) + } + for _, binlog := range binlogs { + binlog.Log(cm) + } } if trInfo != nil { trInfo.tr.LazyLog(&payload{sent: false, msg: v}, true) @@ -1331,18 +1353,24 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. if e := t.WriteStatus(stream, appStatus); e != nil { channelz.Warningf(logger, s.channelzID, "grpc: Server.processUnaryRPC failed to write status: %v", e) } - if binlog != nil { + if len(binlogs) != 0 { if h, _ := stream.Header(); h.Len() > 0 { // Only log serverHeader if there was header. Otherwise it can // be trailer only. - binlog.Log(&binarylog.ServerHeader{ + sh := &binarylog.ServerHeader{ Header: h, - }) + } + for _, binlog := range binlogs { + binlog.Log(sh) + } } - binlog.Log(&binarylog.ServerTrailer{ + st := &binarylog.ServerTrailer{ Trailer: stream.Trailer(), Err: appErr, - }) + } + for _, binlog := range binlogs { + binlog.Log(st) + } } return appErr } @@ -1368,26 +1396,34 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. panic(fmt.Sprintf("grpc: Unexpected error (%T) from sendResponse: %v", st, st)) } } - if binlog != nil { + if len(binlogs) != 0 { h, _ := stream.Header() - binlog.Log(&binarylog.ServerHeader{ + sh := &binarylog.ServerHeader{ Header: h, - }) - binlog.Log(&binarylog.ServerTrailer{ + } + st := &binarylog.ServerTrailer{ Trailer: stream.Trailer(), Err: appErr, - }) + } + for _, binlog := range binlogs { + binlog.Log(sh) + binlog.Log(st) + } } return err } - if binlog != nil { + if len(binlogs) != 0 { h, _ := stream.Header() - binlog.Log(&binarylog.ServerHeader{ + sh := &binarylog.ServerHeader{ Header: h, - }) - binlog.Log(&binarylog.ServerMessage{ + } + sm := &binarylog.ServerMessage{ Message: reply, - }) + } + for _, binlog := range binlogs { + binlog.Log(sh) + binlog.Log(sm) + } } if channelz.IsOn() { t.IncrMsgSent() @@ -1399,11 +1435,14 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. // Should the logging be in WriteStatus? Should we ignore the WriteStatus // error or allow the stats handler to see it? err = t.WriteStatus(stream, statusOK) - if binlog != nil { - binlog.Log(&binarylog.ServerTrailer{ + if len(binlogs) != 0 { + st := &binarylog.ServerTrailer{ Trailer: stream.Trailer(), Err: appErr, - }) + } + for _, binlog := range binlogs { + binlog.Log(st) + } } return err } @@ -1516,8 +1555,15 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp }() } - ss.binlog = binarylog.GetMethodLogger(stream.Method()) - if ss.binlog != nil { + if ml := binarylog.GetMethodLogger(stream.Method()); ml != nil { + ss.binlogs = append(ss.binlogs, ml) + } + if s.opts.binaryLogger != nil { + if ml := s.opts.binaryLogger.GetMethodLogger(stream.Method()); ml != nil { + ss.binlogs = append(ss.binlogs, ml) + } + } + if len(ss.binlogs) != 0 { md, _ := metadata.FromIncomingContext(ctx) logEntry := &binarylog.ClientHeader{ Header: md, @@ -1536,7 +1582,9 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp if peer, ok := peer.FromContext(ss.Context()); ok { logEntry.PeerAddr = peer.Addr } - ss.binlog.Log(logEntry) + for _, binlog := range ss.binlogs { + binlog.Log(logEntry) + } } // If dc is set and matches the stream's compression, use it. Otherwise, try @@ -1602,11 +1650,14 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp ss.mu.Unlock() } t.WriteStatus(ss.s, appStatus) - if ss.binlog != nil { - ss.binlog.Log(&binarylog.ServerTrailer{ + if len(ss.binlogs) != 0 { + st := &binarylog.ServerTrailer{ Trailer: ss.s.Trailer(), Err: appErr, - }) + } + for _, binlog := range ss.binlogs { + binlog.Log(st) + } } // TODO: Should we log an error from WriteStatus here and below? return appErr @@ -1617,11 +1668,14 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp ss.mu.Unlock() } err = t.WriteStatus(ss.s, statusOK) - if ss.binlog != nil { - ss.binlog.Log(&binarylog.ServerTrailer{ + if len(ss.binlogs) != 0 { + st := &binarylog.ServerTrailer{ Trailer: ss.s.Trailer(), Err: appErr, - }) + } + for _, binlog := range ss.binlogs { + binlog.Log(st) + } } return err } diff --git a/stream.go b/stream.go index b678596949b..b10ab1ab632 100644 --- a/stream.go +++ b/stream.go @@ -309,7 +309,14 @@ func newClientStreamWithParams(ctx context.Context, desc *StreamDesc, cc *Client if !cc.dopts.disableRetry { cs.retryThrottler = cc.retryThrottler.Load().(*retryThrottler) } - cs.binlog = binarylog.GetMethodLogger(method) + if ml := binarylog.GetMethodLogger(method); ml != nil { + cs.binlogs = append(cs.binlogs, ml) + } + if cc.dopts.binaryLogger != nil { + if ml := cc.dopts.binaryLogger.GetMethodLogger(method); ml != nil { + cs.binlogs = append(cs.binlogs, ml) + } + } // Pick the transport to use and create a new stream on the transport. // Assign cs.attempt upon success. @@ -330,7 +337,7 @@ func newClientStreamWithParams(ctx context.Context, desc *StreamDesc, cc *Client return nil, err } - if cs.binlog != nil { + if len(cs.binlogs) != 0 { md, _ := metadata.FromOutgoingContext(ctx) logEntry := &binarylog.ClientHeader{ OnClientSide: true, @@ -344,7 +351,9 @@ func newClientStreamWithParams(ctx context.Context, desc *StreamDesc, cc *Client logEntry.Timeout = 0 } } - cs.binlog.Log(logEntry) + for _, binlog := range cs.binlogs { + binlog.Log(logEntry) + } } if desc != unaryStreamDesc { @@ -488,7 +497,7 @@ type clientStream struct { retryThrottler *retryThrottler // The throttler active when the RPC began. - binlog binarylog.MethodLogger // Binary logger, can be nil. + binlogs []binarylog.MethodLogger // serverHeaderBinlogged is a boolean for whether server header has been // logged. Server header will be logged when the first time one of those // happens: stream.Header(), stream.Recv(). @@ -752,7 +761,7 @@ func (cs *clientStream) Header() (metadata.MD, error) { cs.finish(err) return nil, err } - if cs.binlog != nil && !cs.serverHeaderBinlogged { + if len(cs.binlogs) != 0 && !cs.serverHeaderBinlogged { // Only log if binary log is on and header has not been logged. logEntry := &binarylog.ServerHeader{ OnClientSide: true, @@ -762,8 +771,10 @@ func (cs *clientStream) Header() (metadata.MD, error) { if peer, ok := peer.FromContext(cs.Context()); ok { logEntry.PeerAddr = peer.Addr } - cs.binlog.Log(logEntry) cs.serverHeaderBinlogged = true + for _, binlog := range cs.binlogs { + binlog.Log(logEntry) + } } return m, nil } @@ -837,38 +848,44 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) { return a.sendMsg(m, hdr, payload, data) } err = cs.withRetry(op, func() { cs.bufferForRetryLocked(len(hdr)+len(payload), op) }) - if cs.binlog != nil && err == nil { - cs.binlog.Log(&binarylog.ClientMessage{ + if len(cs.binlogs) != 0 && err == nil { + cm := &binarylog.ClientMessage{ OnClientSide: true, Message: data, - }) + } + for _, binlog := range cs.binlogs { + binlog.Log(cm) + } } return err } func (cs *clientStream) RecvMsg(m interface{}) error { - if cs.binlog != nil && !cs.serverHeaderBinlogged { + if len(cs.binlogs) != 0 && !cs.serverHeaderBinlogged { // Call Header() to binary log header if it's not already logged. cs.Header() } var recvInfo *payloadInfo - if cs.binlog != nil { + if len(cs.binlogs) != 0 { recvInfo = &payloadInfo{} } err := cs.withRetry(func(a *csAttempt) error { return a.recvMsg(m, recvInfo) }, cs.commitAttemptLocked) - if cs.binlog != nil && err == nil { - cs.binlog.Log(&binarylog.ServerMessage{ + if len(cs.binlogs) != 0 && err == nil { + sm := &binarylog.ServerMessage{ OnClientSide: true, Message: recvInfo.uncompressedBytes, - }) + } + for _, binlog := range cs.binlogs { + binlog.Log(sm) + } } if err != nil || !cs.desc.ServerStreams { // err != nil or non-server-streaming indicates end of stream. cs.finish(err) - if cs.binlog != nil { + if len(cs.binlogs) != 0 { // finish will not log Trailer. Log Trailer here. logEntry := &binarylog.ServerTrailer{ OnClientSide: true, @@ -881,7 +898,9 @@ func (cs *clientStream) RecvMsg(m interface{}) error { if peer, ok := peer.FromContext(cs.Context()); ok { logEntry.PeerAddr = peer.Addr } - cs.binlog.Log(logEntry) + for _, binlog := range cs.binlogs { + binlog.Log(logEntry) + } } } return err @@ -902,10 +921,13 @@ func (cs *clientStream) CloseSend() error { return nil } cs.withRetry(op, func() { cs.bufferForRetryLocked(0, op) }) - if cs.binlog != nil { - cs.binlog.Log(&binarylog.ClientHalfClose{ + if len(cs.binlogs) != 0 { + chc := &binarylog.ClientHalfClose{ OnClientSide: true, - }) + } + for _, binlog := range cs.binlogs { + binlog.Log(chc) + } } // We never returned an error here for reasons. return nil @@ -938,10 +960,13 @@ func (cs *clientStream) finish(err error) { // // Only one of cancel or trailer needs to be logged. In the cases where // users don't call RecvMsg, users must have already canceled the RPC. - if cs.binlog != nil && status.Code(err) == codes.Canceled { - cs.binlog.Log(&binarylog.Cancel{ + if len(cs.binlogs) != 0 && status.Code(err) == codes.Canceled { + c := &binarylog.Cancel{ OnClientSide: true, - }) + } + for _, binlog := range cs.binlogs { + binlog.Log(c) + } } if err == nil { cs.retryThrottler.successfulRPC() @@ -1013,6 +1038,7 @@ func (a *csAttempt) recvMsg(m interface{}, payInfo *payloadInfo) (err error) { } return io.EOF // indicates successful end of stream. } + return toRPCErr(err) } if a.trInfo != nil { @@ -1461,7 +1487,7 @@ type serverStream struct { statsHandler []stats.Handler - binlog binarylog.MethodLogger + binlogs []binarylog.MethodLogger // serverHeaderBinlogged indicates whether server header has been logged. It // will happen when one of the following two happens: stream.SendHeader(), // stream.Send(). @@ -1495,12 +1521,15 @@ func (ss *serverStream) SendHeader(md metadata.MD) error { } err = ss.t.WriteHeader(ss.s, md) - if ss.binlog != nil && !ss.serverHeaderBinlogged { + if len(ss.binlogs) != 0 && !ss.serverHeaderBinlogged { h, _ := ss.s.Header() - ss.binlog.Log(&binarylog.ServerHeader{ + sh := &binarylog.ServerHeader{ Header: h, - }) + } ss.serverHeaderBinlogged = true + for _, binlog := range ss.binlogs { + binlog.Log(sh) + } } return err } @@ -1557,17 +1586,23 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) { if err := ss.t.Write(ss.s, hdr, payload, &transport.Options{Last: false}); err != nil { return toRPCErr(err) } - if ss.binlog != nil { + if len(ss.binlogs) != 0 { if !ss.serverHeaderBinlogged { h, _ := ss.s.Header() - ss.binlog.Log(&binarylog.ServerHeader{ + sh := &binarylog.ServerHeader{ Header: h, - }) + } ss.serverHeaderBinlogged = true + for _, binlog := range ss.binlogs { + binlog.Log(sh) + } } - ss.binlog.Log(&binarylog.ServerMessage{ + sm := &binarylog.ServerMessage{ Message: data, - }) + } + for _, binlog := range ss.binlogs { + binlog.Log(sm) + } } if len(ss.statsHandler) != 0 { for _, sh := range ss.statsHandler { @@ -1606,13 +1641,16 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) { } }() var payInfo *payloadInfo - if len(ss.statsHandler) != 0 || ss.binlog != nil { + if len(ss.statsHandler) != 0 || len(ss.binlogs) != 0 { payInfo = &payloadInfo{} } if err := recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxReceiveMessageSize, payInfo, ss.decomp); err != nil { if err == io.EOF { - if ss.binlog != nil { - ss.binlog.Log(&binarylog.ClientHalfClose{}) + if len(ss.binlogs) != 0 { + chc := &binarylog.ClientHalfClose{} + for _, binlog := range ss.binlogs { + binlog.Log(chc) + } } return err } @@ -1633,10 +1671,13 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) { }) } } - if ss.binlog != nil { - ss.binlog.Log(&binarylog.ClientMessage{ + if len(ss.binlogs) != 0 { + cm := &binarylog.ClientMessage{ Message: payInfo.uncompressedBytes, - }) + } + for _, binlog := range ss.binlogs { + binlog.Log(cm) + } } return nil } diff --git a/test/end2end_test.go b/test/end2end_test.go index 725bcdb641e..ecf5b5e303e 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -57,6 +57,7 @@ import ( healthgrpc "google.golang.org/grpc/health/grpc_health_v1" healthpb "google.golang.org/grpc/health/grpc_health_v1" "google.golang.org/grpc/internal" + "google.golang.org/grpc/internal/binarylog" "google.golang.org/grpc/internal/channelz" "google.golang.org/grpc/internal/grpcsync" "google.golang.org/grpc/internal/grpctest" @@ -8184,3 +8185,93 @@ func (s) TestGoAwayStreamIDSmallerThanCreatedStreams(t *testing.T) { ct.writeGoAway(1, http2.ErrCodeNo, []byte{}) goAwayWritten.Fire() } + +type mockBinaryLogger struct { + mml *mockMethodLogger +} + +func newMockBinaryLogger() *mockBinaryLogger { + return &mockBinaryLogger{ + mml: &mockMethodLogger{}, + } +} + +func (mbl *mockBinaryLogger) GetMethodLogger(string) binarylog.MethodLogger { + return mbl.mml +} + +type mockMethodLogger struct { + events uint64 +} + +func (mml *mockMethodLogger) Log(binarylog.LogEntryConfig) { + atomic.AddUint64(&mml.events, 1) +} + +// TestGlobalBinaryLoggingOptions tests the binary logging options for client +// and server side. The test configures a binary logger to be plumbed into every +// created ClientConn and server. It then makes a unary RPC call, and a +// streaming RPC call. A certain amount of logging calls should happen as a +// result of the stream operations on each of these calls. +func (s) TestGlobalBinaryLoggingOptions(t *testing.T) { + csbl := newMockBinaryLogger() + ssbl := newMockBinaryLogger() + + internal.AddGlobalDialOptions.(func(opt ...grpc.DialOption))(internal.WithBinaryLogger.(func(bl binarylog.Logger) grpc.DialOption)(csbl)) + internal.AddGlobalServerOptions.(func(opt ...grpc.ServerOption))(internal.BinaryLogger.(func(bl binarylog.Logger) grpc.ServerOption)(ssbl)) + defer func() { + internal.ClearGlobalDialOptions() + internal.ClearGlobalServerOptions() + }() + ss := &stubserver.StubServer{ + UnaryCallF: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { + return &testpb.SimpleResponse{}, nil + }, + FullDuplexCallF: func(stream testpb.TestService_FullDuplexCallServer) error { + for { + _, err := stream.Recv() + if err == io.EOF { + return nil + } + } + }, + } + + // No client or server options specified, because should pick up configured + // global options. + if err := ss.Start(nil); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + // Make a Unary RPC. This should cause Log calls on the MethodLogger. + if _, err := ss.Client.UnaryCall(ctx, &testpb.SimpleRequest{}); err != nil { + t.Fatalf("Unexpected error from UnaryCall: %v", err) + } + if csbl.mml.events != 5 { + t.Fatalf("want 5 client side binary logging events, got %v", csbl.mml.events) + } + if ssbl.mml.events != 5 { + t.Fatalf("want 5 server side binary logging events, got %v", ssbl.mml.events) + } + + // Make a streaming RPC. This should cause Log calls on the MethodLogger. + stream, err := ss.Client.FullDuplexCall(ctx) + if err != nil { + t.Fatalf("ss.Client.FullDuplexCall failed: %f", err) + } + + stream.CloseSend() + if _, err = stream.Recv(); err != io.EOF { + t.Fatalf("unexpected error: %v, expected an EOF error", err) + } + + if csbl.mml.events != 9 { + t.Fatalf("want 9 client side binary logging events, got %v", csbl.mml.events) + } + if ssbl.mml.events != 8 { + t.Fatalf("want 8 server side binary logging events, got %v", ssbl.mml.events) + } +}