From 2cc8abd436c68f84a26da1a37ddbd2b0d30f3384 Mon Sep 17 00:00:00 2001 From: Shane Liebling Date: Mon, 30 Sep 2019 10:54:32 -0400 Subject: [PATCH] Add a RequestInfo struct which initially is used for passing the full request method (though could later be expanded to pass more info) so that things like GetRequestMetadata can be used to apply logic based on that data. This is a fix for #3019 --- credentials/credentials.go | 30 ++++++++++++++++++++++++++++ internal/internal.go | 3 ++- internal/transport/http2_client.go | 5 +++++ internal/transport/transport_test.go | 26 ++++++++++++++++++++++++ test/end2end_test.go | 29 +++++++++++++++++++++++++++ 5 files changed, 92 insertions(+), 1 deletion(-) diff --git a/credentials/credentials.go b/credentials/credentials.go index 8ea3d4a1dc2..4076f3716dd 100644 --- a/credentials/credentials.go +++ b/credentials/credentials.go @@ -34,6 +34,7 @@ import ( "github.com/golang/protobuf/proto" "google.golang.org/grpc/credentials/internal" + ginternal "google.golang.org/grpc/internal" ) // PerRPCCredentials defines the common interface for the credentials which need to @@ -334,3 +335,32 @@ func cloneTLSConfig(cfg *tls.Config) *tls.Config { return cfg.Clone() } + +// RequestInfo contains request data to be attached to the context passed along to GetRequestMetadata calls. This API is experimental. +type RequestInfo struct { + // The method passed to Invoke or NewStream for this RPC. (For proto methods, this has the format "/some.Service/Method") + Method string +} + +// requestInfoKey is a struct to be used as the key when attaching a RequestInfo to a context object. +type requestInfoKey struct{} + +// RequestInfoFromContext extracts the RequestInfo from the context. This API is experimental. +func RequestInfoFromContext(ctx context.Context) RequestInfo { + ri, ok := ctx.Value(requestInfoKey{}).(RequestInfo) + if !ok { + return RequestInfo{} + } + return ri +} + +// withRequestInfo adds the supplied RequestInfo to the context. This API is experimental. +func withRequestInfo(ctx context.Context, ri RequestInfo) context.Context { + return context.WithValue(ctx, requestInfoKey{}, ri) +} + +func init() { + ginternal.NewRequestInfoContext = func(ctx context.Context, ri RequestInfo) context.Context { + return context.WithValue(ctx, requestInfoKey{}, ri) + } +} diff --git a/internal/internal.go b/internal/internal.go index bc1f99ac803..7f8a3d0b0fa 100644 --- a/internal/internal.go +++ b/internal/internal.go @@ -46,7 +46,8 @@ var ( // pointer to the wrapped Status proto for a given status.Status without a // call to proto.Clone(). The returned Status proto should not be mutated by // the caller. - StatusRawProto interface{} // func (*status.Status) *spb.Status + StatusRawProto interface{} // func (*status.Status) *spb.Status + NewRequestInfoContext interface{} // (this avoids a circular dependency; we do it for other things like this already) ) // HealthChecker defines the signature of the client-side LB channel health checking function. diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go index 9bd8c27b365..5152d842868 100644 --- a/internal/transport/http2_client.go +++ b/internal/transport/http2_client.go @@ -35,6 +35,7 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" + "google.golang.org/grpc/internal" "google.golang.org/grpc/internal/channelz" "google.golang.org/grpc/internal/syscall" "google.golang.org/grpc/keepalive" @@ -547,6 +548,10 @@ func (t *http2Client) getCallAuthData(ctx context.Context, audience string, call // streams. func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Stream, err error) { ctx = peer.NewContext(ctx, t.getPeer()) + ri := credentials.RequestInfo{ + Method: callHdr.Method, + } + ctx = internal.NewRequestInfoContext.(func(context.Context, credentials.RequestInfo) context.Context)(ctx, ri) headerFields, err := t.createHeaderFields(ctx, callHdr) if err != nil { return nil, err diff --git a/internal/transport/transport_test.go b/internal/transport/transport_test.go index c5ee74848cc..caf364a9821 100644 --- a/internal/transport/transport_test.go +++ b/internal/transport/transport_test.go @@ -37,6 +37,7 @@ import ( "golang.org/x/net/http2" "golang.org/x/net/http2/hpack" "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" "google.golang.org/grpc/internal/leakcheck" "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/status" @@ -1809,3 +1810,28 @@ func TestHeaderTblSize(t *testing.T) { t.Fatalf("expected len(limits) = 2 within 10s, got != 2") } } + +// TestRequestInfoFoundInStreamContext tests whether data passed in as callheader data gets +// propagated into the context object as a credentials.RequestInfo object. +func TestRequestInfoFoundInStreamContext(t *testing.T) { + serverConfig := &ServerConfig{} + server, client, cancel := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{}) + defer cancel() + defer server.stop() + defer client.Close() + + ch := &CallHdr{ + Method: "someService/Method", + } + ctx, ctxCancel := context.WithTimeout(context.Background(), 10*time.Second) + defer ctxCancel() + + s, err := client.NewStream(ctx, ch) + if err != nil { + t.Fatalf("client.NewStream() failed: %v", err) + } + ri := credentials.RequestInfoFromContext(s.ctx) + if ch.Method != ri.Method { + t.Fatalf("RequestInfo.Method == %s; want %s", ri.Method, ch.Method) + } +} diff --git a/test/end2end_test.go b/test/end2end_test.go index 8b38563b9e3..daf49e44688 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -7484,3 +7484,32 @@ func parseCfg(s string) serviceconfig.Config { } return c } + +type methodTestCreds struct { + expectedMethod string +} + +func (m methodTestCreds) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) { + ri := credentials.RequestInfoFromContext(ctx) + return nil, status.Errorf(codes.Unknown, ri.Method) +} + +func (m methodTestCreds) RequireTransportSecurity() bool { + return false +} + +func (s) TestGRPCMethodAccessibleToCredsViaContextRequestInfo(t *testing.T) { + const wantMethod = "/grpc.testing.TestService/EmptyCall" + ss := &stubServer{} + if err := ss.Start(nil, grpc.WithPerRPCCredentials(methodTestCreds{"/grpc.testing.TestService/EmptyCall"})); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + if _, err := ss.client.EmptyCall(ctx, &testpb.Empty{}); status.Convert(err).Message() != wantMethod { + t.Fatalf("ss.client.EmptyCall(_, _) = _, %v; want _, _.Message()=%q", err, wantMethod) + } +}