From a8e85e0d5704da1f5bd858a7b47621e77fe5035b Mon Sep 17 00:00:00 2001 From: Ehsan Afzali Date: Sat, 22 May 2021 01:54:24 +0300 Subject: [PATCH] server: allow PreparedMsgs to work for server streams (#3480) --- server.go | 2 ++ test/end2end_test.go | 55 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+) diff --git a/server.go b/server.go index 0a151dee4fc..2d8e005cd6f 100644 --- a/server.go +++ b/server.go @@ -1521,6 +1521,8 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp } } + ss.ctx = newContextWithRPCInfo(ss.ctx, false, ss.codec, ss.cp, ss.comp) + if trInfo != nil { trInfo.tr.LazyLog(&trInfo.firstLine, false) } diff --git a/test/end2end_test.go b/test/end2end_test.go index 76ff07a27c2..9a8099ca1e1 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -2249,6 +2249,61 @@ func testPreloaderClientSend(t *testing.T, e env) { } } +func (s) TestPreloaderSenderSend(t *testing.T) { + ss := &stubserver.StubServer{ + FullDuplexCallF: func(stream testpb.TestService_FullDuplexCallServer) error { + for i := 0; i < 10; i++ { + preparedMsg := &grpc.PreparedMsg{} + err := preparedMsg.Encode(stream, &testpb.StreamingOutputCallResponse{ + Payload: &testpb.Payload{ + Body: []byte{'0' + uint8(i)}, + }, + }) + if err != nil { + return err + } + stream.SendMsg(preparedMsg) + } + return nil + }, + } + if err := ss.Start(nil); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + stream, err := ss.Client.FullDuplexCall(ctx) + if err != nil { + t.Fatalf("ss.Client.EmptyCall(_, _) = _, %v; want _, nil", err) + } + + var ngot int + var buf bytes.Buffer + for { + reply, err := stream.Recv() + if err == io.EOF { + break + } + if err != nil { + t.Fatal(err) + } + ngot++ + if buf.Len() > 0 { + buf.WriteByte(',') + } + buf.Write(reply.GetPayload().GetBody()) + } + if want := 10; ngot != want { + t.Errorf("Got %d replies, want %d", ngot, want) + } + if got, want := buf.String(), "0,1,2,3,4,5,6,7,8,9"; got != want { + t.Errorf("Got replies %q; want %q", got, want) + } +} + func (s) TestMaxMsgSizeClientDefault(t *testing.T) { for _, e := range listTestEnv() { testMaxMsgSizeClientDefault(t, e)