From c84a5de06496bf8416cebf9d0058f481e37c165e Mon Sep 17 00:00:00 2001 From: Zach Reyes <39203661+zasweq@users.noreply.github.com> Date: Wed, 15 Sep 2021 17:02:08 -0400 Subject: [PATCH] transport/server: add :method POST to incoming metadata (#4770) * transport/server: add :method POST to incoming metadata --- binarylog/binarylog_end2end_test.go | 16 +++++++++++++ internal/transport/http2_server.go | 1 + test/end2end_test.go | 37 +++++++++++++++++++++++++++++ 3 files changed, 54 insertions(+) diff --git a/binarylog/binarylog_end2end_test.go b/binarylog/binarylog_end2end_test.go index 61eeb68edae..7da91ad1a8d 100644 --- a/binarylog/binarylog_end2end_test.go +++ b/binarylog/binarylog_end2end_test.go @@ -850,11 +850,27 @@ func equalLogEntry(entries ...*pb.GrpcLogEntry) (equal bool) { tmp := append(h.Metadata.Entry[:0], h.Metadata.Entry...) h.Metadata.Entry = tmp sort.Slice(h.Metadata.Entry, func(i, j int) bool { return h.Metadata.Entry[i].Key < h.Metadata.Entry[j].Key }) + // Delete headers that have POST values here since we cannot control + // this. + for i, entry := range h.Metadata.Entry { + if entry.Key == ":method" { + h.Metadata.Entry = append(h.Metadata.Entry[:i], h.Metadata.Entry[i+1:]...) + break + } + } } if h := e.GetServerHeader(); h != nil { tmp := append(h.Metadata.Entry[:0], h.Metadata.Entry...) h.Metadata.Entry = tmp sort.Slice(h.Metadata.Entry, func(i, j int) bool { return h.Metadata.Entry[i].Key < h.Metadata.Entry[j].Key }) + // Delete headers that have POST values here since we cannot control + // this. + for i, entry := range h.Metadata.Entry { + if entry.Key == ":method" { + h.Metadata.Entry = append(h.Metadata.Entry[:i], h.Metadata.Entry[i+1:]...) + break + } + } } if h := e.GetTrailer(); h != nil { sort.Slice(h.Metadata.Entry, func(i, j int) bool { return h.Metadata.Entry[i].Key < h.Metadata.Entry[j].Key }) diff --git a/internal/transport/http2_server.go b/internal/transport/http2_server.go index cd0ebed9884..0ecfe09ceec 100644 --- a/internal/transport/http2_server.go +++ b/internal/transport/http2_server.go @@ -380,6 +380,7 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( s.recvCompress = hf.Value case ":method": httpMethod = hf.Value + mdata[":method"] = append(mdata[":method"], hf.Value) case ":path": s.method = hf.Value case "grpc-timeout": diff --git a/test/end2end_test.go b/test/end2end_test.go index bce752701da..4c7e2f1fc75 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -7847,3 +7847,40 @@ func (s) TestStreamingServerInterceptorGetsConnection(t *testing.T) { t.Fatalf("ss.Client.StreamingInputCall(_) = _, %v, want _, %v", err, io.EOF) } } + +func unaryInterceptorVerifyPost(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return nil, status.Error(codes.NotFound, "metadata was not in context") + } + method := md.Get(":method") + if len(method) != 1 { + return nil, status.Error(codes.InvalidArgument, ":method value had more than one value") + } + if method[0] != "POST" { + return nil, status.Error(codes.InvalidArgument, ":method value was not post") + } + return handler(ctx, req) +} + +// TestUnaryInterceptorGetsPost verifies that the server transport adds a +// :method POST header to metadata, and that that added Header is visibile at +// the grpc layer. +func (s) TestUnaryInterceptorGetsPost(t *testing.T) { + ss := &stubserver.StubServer{ + EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { + return &testpb.Empty{}, nil + }, + } + if err := ss.Start([]grpc.ServerOption{grpc.UnaryInterceptor(unaryInterceptorVerifyPost)}); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) != codes.OK { + t.Fatalf("ss.Client.EmptyCall(_, _) = _, %v, want _, error code %s", err, codes.OK) + } +}