diff --git a/credentials/credentials.go b/credentials/credentials.go index 8ea3d4a1dc2..4076f3716dd 100644 --- a/credentials/credentials.go +++ b/credentials/credentials.go @@ -34,6 +34,7 @@ import ( "github.com/golang/protobuf/proto" "google.golang.org/grpc/credentials/internal" + ginternal "google.golang.org/grpc/internal" ) // PerRPCCredentials defines the common interface for the credentials which need to @@ -334,3 +335,32 @@ func cloneTLSConfig(cfg *tls.Config) *tls.Config { return cfg.Clone() } + +// RequestInfo contains request data to be attached to the context passed along to GetRequestMetadata calls. This API is experimental. +type RequestInfo struct { + // The method passed to Invoke or NewStream for this RPC. (For proto methods, this has the format "/some.Service/Method") + Method string +} + +// requestInfoKey is a struct to be used as the key when attaching a RequestInfo to a context object. +type requestInfoKey struct{} + +// RequestInfoFromContext extracts the RequestInfo from the context. This API is experimental. +func RequestInfoFromContext(ctx context.Context) RequestInfo { + ri, ok := ctx.Value(requestInfoKey{}).(RequestInfo) + if !ok { + return RequestInfo{} + } + return ri +} + +// withRequestInfo adds the supplied RequestInfo to the context. This API is experimental. +func withRequestInfo(ctx context.Context, ri RequestInfo) context.Context { + return context.WithValue(ctx, requestInfoKey{}, ri) +} + +func init() { + ginternal.NewRequestInfoContext = func(ctx context.Context, ri RequestInfo) context.Context { + return context.WithValue(ctx, requestInfoKey{}, ri) + } +} diff --git a/internal/internal.go b/internal/internal.go index bc1f99ac803..7f8a3d0b0fa 100644 --- a/internal/internal.go +++ b/internal/internal.go @@ -46,7 +46,8 @@ var ( // pointer to the wrapped Status proto for a given status.Status without a // call to proto.Clone(). The returned Status proto should not be mutated by // the caller. - StatusRawProto interface{} // func (*status.Status) *spb.Status + StatusRawProto interface{} // func (*status.Status) *spb.Status + NewRequestInfoContext interface{} // (this avoids a circular dependency; we do it for other things like this already) ) // HealthChecker defines the signature of the client-side LB channel health checking function. diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go index 9bd8c27b365..5152d842868 100644 --- a/internal/transport/http2_client.go +++ b/internal/transport/http2_client.go @@ -35,6 +35,7 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" + "google.golang.org/grpc/internal" "google.golang.org/grpc/internal/channelz" "google.golang.org/grpc/internal/syscall" "google.golang.org/grpc/keepalive" @@ -547,6 +548,10 @@ func (t *http2Client) getCallAuthData(ctx context.Context, audience string, call // streams. func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Stream, err error) { ctx = peer.NewContext(ctx, t.getPeer()) + ri := credentials.RequestInfo{ + Method: callHdr.Method, + } + ctx = internal.NewRequestInfoContext.(func(context.Context, credentials.RequestInfo) context.Context)(ctx, ri) headerFields, err := t.createHeaderFields(ctx, callHdr) if err != nil { return nil, err diff --git a/internal/transport/transport_test.go b/internal/transport/transport_test.go index c5ee74848cc..caf364a9821 100644 --- a/internal/transport/transport_test.go +++ b/internal/transport/transport_test.go @@ -37,6 +37,7 @@ import ( "golang.org/x/net/http2" "golang.org/x/net/http2/hpack" "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" "google.golang.org/grpc/internal/leakcheck" "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/status" @@ -1809,3 +1810,28 @@ func TestHeaderTblSize(t *testing.T) { t.Fatalf("expected len(limits) = 2 within 10s, got != 2") } } + +// TestRequestInfoFoundInStreamContext tests whether data passed in as callheader data gets +// propagated into the context object as a credentials.RequestInfo object. +func TestRequestInfoFoundInStreamContext(t *testing.T) { + serverConfig := &ServerConfig{} + server, client, cancel := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{}) + defer cancel() + defer server.stop() + defer client.Close() + + ch := &CallHdr{ + Method: "someService/Method", + } + ctx, ctxCancel := context.WithTimeout(context.Background(), 10*time.Second) + defer ctxCancel() + + s, err := client.NewStream(ctx, ch) + if err != nil { + t.Fatalf("client.NewStream() failed: %v", err) + } + ri := credentials.RequestInfoFromContext(s.ctx) + if ch.Method != ri.Method { + t.Fatalf("RequestInfo.Method == %s; want %s", ri.Method, ch.Method) + } +} diff --git a/test/end2end_test.go b/test/end2end_test.go index 8b38563b9e3..daf49e44688 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -7484,3 +7484,32 @@ func parseCfg(s string) serviceconfig.Config { } return c } + +type methodTestCreds struct { + expectedMethod string +} + +func (m methodTestCreds) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) { + ri := credentials.RequestInfoFromContext(ctx) + return nil, status.Errorf(codes.Unknown, ri.Method) +} + +func (m methodTestCreds) RequireTransportSecurity() bool { + return false +} + +func (s) TestGRPCMethodAccessibleToCredsViaContextRequestInfo(t *testing.T) { + const wantMethod = "/grpc.testing.TestService/EmptyCall" + ss := &stubServer{} + if err := ss.Start(nil, grpc.WithPerRPCCredentials(methodTestCreds{"/grpc.testing.TestService/EmptyCall"})); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + if _, err := ss.client.EmptyCall(ctx, &testpb.Empty{}); status.Convert(err).Message() != wantMethod { + t.Fatalf("ss.client.EmptyCall(_, _) = _, %v; want _, _.Message()=%q", err, wantMethod) + } +}