diff --git a/server.go b/server.go index d90f3fcd3bf..e72029bf147 100644 --- a/server.go +++ b/server.go @@ -1144,7 +1144,9 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. if sh != nil { beginTime := time.Now() statsBegin = &stats.Begin{ - BeginTime: beginTime, + BeginTime: beginTime, + IsClientStream: false, + IsServerStream: false, } sh.HandleRPC(stream.Context(), statsBegin) } @@ -1424,7 +1426,9 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp if sh != nil { beginTime := time.Now() statsBegin = &stats.Begin{ - BeginTime: beginTime, + BeginTime: beginTime, + IsClientStream: sd.ClientStreams, + IsServerStream: sd.ServerStreams, } sh.HandleRPC(stream.Context(), statsBegin) } diff --git a/stats/stats.go b/stats/stats.go index 63e476ee7ff..a5ebeeb6932 100644 --- a/stats/stats.go +++ b/stats/stats.go @@ -45,6 +45,10 @@ type Begin struct { BeginTime time.Time // FailFast indicates if this RPC is failfast. FailFast bool + // IsClientStream indicates whether the RPC is a client streaming RPC. + IsClientStream bool + // IsServerStream indicates whether the RPC is a server streaming RPC. + IsServerStream bool } // IsClient indicates if the stats information is from client side. diff --git a/stats/stats_test.go b/stats/stats_test.go index 306f2f6b8e9..dfc6edfc3d3 100644 --- a/stats/stats_test.go +++ b/stats/stats_test.go @@ -407,15 +407,17 @@ func (te *test) doServerStreamCall(c *rpcConfig) (*testpb.StreamingOutputCallReq } type expectedData struct { - method string - serverAddr string - compression string - reqIdx int - requests []proto.Message - respIdx int - responses []proto.Message - err error - failfast bool + method string + isClientStream bool + isServerStream bool + serverAddr string + compression string + reqIdx int + requests []proto.Message + respIdx int + responses []proto.Message + err error + failfast bool } type gotData struct { @@ -456,6 +458,12 @@ func checkBegin(t *testing.T, d *gotData, e *expectedData) { t.Fatalf("st.FailFast = %v, want %v", st.FailFast, e.failfast) } } + if st.IsClientStream != e.isClientStream { + t.Fatalf("st.IsClientStream = %v, want %v", st.IsClientStream, e.isClientStream) + } + if st.IsServerStream != e.isServerStream { + t.Fatalf("st.IsServerStream = %v, want %v", st.IsServerStream, e.isServerStream) + } } func checkInHeader(t *testing.T, d *gotData, e *expectedData) { @@ -847,6 +855,9 @@ func testServerStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs []f err error method string + isClientStream bool + isServerStream bool + req proto.Message resp proto.Message e error @@ -864,14 +875,18 @@ func testServerStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs []f reqs, resp, e = te.doClientStreamCall(cc) resps = []proto.Message{resp} err = e + isClientStream = true case serverStreamRPC: method = "/grpc.testing.TestService/StreamingOutputCall" req, resps, e = te.doServerStreamCall(cc) reqs = []proto.Message{req} err = e + isServerStream = true case fullDuplexStreamRPC: method = "/grpc.testing.TestService/FullDuplexCall" reqs, resps, err = te.doFullDuplexCallRoundtrip(cc) + isClientStream = true + isServerStream = true } if cc.success != (err == nil) { t.Fatalf("cc.success: %v, got error: %v", cc.success, err) @@ -900,12 +915,14 @@ func testServerStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs []f } expect := &expectedData{ - serverAddr: te.srvAddr, - compression: tc.compress, - method: method, - requests: reqs, - responses: resps, - err: err, + serverAddr: te.srvAddr, + compression: tc.compress, + method: method, + requests: reqs, + responses: resps, + err: err, + isClientStream: isClientStream, + isServerStream: isServerStream, } h.mu.Lock() @@ -1138,6 +1155,9 @@ func testClientStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs map method string err error + isClientStream bool + isServerStream bool + req proto.Message resp proto.Message e error @@ -1154,14 +1174,18 @@ func testClientStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs map reqs, resp, e = te.doClientStreamCall(cc) resps = []proto.Message{resp} err = e + isClientStream = true case serverStreamRPC: method = "/grpc.testing.TestService/StreamingOutputCall" req, resps, e = te.doServerStreamCall(cc) reqs = []proto.Message{req} err = e + isServerStream = true case fullDuplexStreamRPC: method = "/grpc.testing.TestService/FullDuplexCall" reqs, resps, err = te.doFullDuplexCallRoundtrip(cc) + isClientStream = true + isServerStream = true } if cc.success != (err == nil) { t.Fatalf("cc.success: %v, got error: %v", cc.success, err) @@ -1194,13 +1218,15 @@ func testClientStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs map } expect := &expectedData{ - serverAddr: te.srvAddr, - compression: tc.compress, - method: method, - requests: reqs, - responses: resps, - failfast: cc.failfast, - err: err, + serverAddr: te.srvAddr, + compression: tc.compress, + method: method, + requests: reqs, + responses: resps, + failfast: cc.failfast, + err: err, + isClientStream: isClientStream, + isServerStream: isServerStream, } h.mu.Lock() diff --git a/stream.go b/stream.go index 1f3e70d2c44..ed6af683d20 100644 --- a/stream.go +++ b/stream.go @@ -295,9 +295,11 @@ func newClientStreamWithParams(ctx context.Context, desc *StreamDesc, cc *Client ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method, FailFast: c.failFast}) beginTime = time.Now() begin := &stats.Begin{ - Client: true, - BeginTime: beginTime, - FailFast: c.failFast, + Client: true, + BeginTime: beginTime, + FailFast: c.failFast, + IsClientStream: desc.ClientStreams, + IsServerStream: desc.ServerStreams, } sh.HandleRPC(ctx, begin) }