From dc6e857dd67f988b96ff54f54130f2e13f2a9456 Mon Sep 17 00:00:00 2001 From: Shane Liebling Date: Mon, 30 Sep 2019 10:54:32 -0400 Subject: [PATCH 01/10] Add a RequestInfo struct which initially is used for passing the full request method (though could later be expanded to pass more info) so that things like GetRequestMetadata can be used to apply logic based on that data. This is a fix for #3019 --- credentials/credentials.go | 30 ++++++++++++++++++++++++++++ internal/internal.go | 3 ++- internal/transport/http2_client.go | 5 +++++ internal/transport/transport_test.go | 26 ++++++++++++++++++++++++ test/end2end_test.go | 29 +++++++++++++++++++++++++++ 5 files changed, 92 insertions(+), 1 deletion(-) diff --git a/credentials/credentials.go b/credentials/credentials.go index 8ea3d4a1dc2..4076f3716dd 100644 --- a/credentials/credentials.go +++ b/credentials/credentials.go @@ -34,6 +34,7 @@ import ( "github.com/golang/protobuf/proto" "google.golang.org/grpc/credentials/internal" + ginternal "google.golang.org/grpc/internal" ) // PerRPCCredentials defines the common interface for the credentials which need to @@ -334,3 +335,32 @@ func cloneTLSConfig(cfg *tls.Config) *tls.Config { return cfg.Clone() } + +// RequestInfo contains request data to be attached to the context passed along to GetRequestMetadata calls. This API is experimental. +type RequestInfo struct { + // The method passed to Invoke or NewStream for this RPC. (For proto methods, this has the format "/some.Service/Method") + Method string +} + +// requestInfoKey is a struct to be used as the key when attaching a RequestInfo to a context object. +type requestInfoKey struct{} + +// RequestInfoFromContext extracts the RequestInfo from the context. This API is experimental. +func RequestInfoFromContext(ctx context.Context) RequestInfo { + ri, ok := ctx.Value(requestInfoKey{}).(RequestInfo) + if !ok { + return RequestInfo{} + } + return ri +} + +// withRequestInfo adds the supplied RequestInfo to the context. This API is experimental. +func withRequestInfo(ctx context.Context, ri RequestInfo) context.Context { + return context.WithValue(ctx, requestInfoKey{}, ri) +} + +func init() { + ginternal.NewRequestInfoContext = func(ctx context.Context, ri RequestInfo) context.Context { + return context.WithValue(ctx, requestInfoKey{}, ri) + } +} diff --git a/internal/internal.go b/internal/internal.go index bc1f99ac803..7f8a3d0b0fa 100644 --- a/internal/internal.go +++ b/internal/internal.go @@ -46,7 +46,8 @@ var ( // pointer to the wrapped Status proto for a given status.Status without a // call to proto.Clone(). The returned Status proto should not be mutated by // the caller. - StatusRawProto interface{} // func (*status.Status) *spb.Status + StatusRawProto interface{} // func (*status.Status) *spb.Status + NewRequestInfoContext interface{} // (this avoids a circular dependency; we do it for other things like this already) ) // HealthChecker defines the signature of the client-side LB channel health checking function. diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go index 9bd8c27b365..5152d842868 100644 --- a/internal/transport/http2_client.go +++ b/internal/transport/http2_client.go @@ -35,6 +35,7 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" + "google.golang.org/grpc/internal" "google.golang.org/grpc/internal/channelz" "google.golang.org/grpc/internal/syscall" "google.golang.org/grpc/keepalive" @@ -547,6 +548,10 @@ func (t *http2Client) getCallAuthData(ctx context.Context, audience string, call // streams. func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Stream, err error) { ctx = peer.NewContext(ctx, t.getPeer()) + ri := credentials.RequestInfo{ + Method: callHdr.Method, + } + ctx = internal.NewRequestInfoContext.(func(context.Context, credentials.RequestInfo) context.Context)(ctx, ri) headerFields, err := t.createHeaderFields(ctx, callHdr) if err != nil { return nil, err diff --git a/internal/transport/transport_test.go b/internal/transport/transport_test.go index 067d588c000..1a2328f89a1 100644 --- a/internal/transport/transport_test.go +++ b/internal/transport/transport_test.go @@ -37,6 +37,7 @@ import ( "golang.org/x/net/http2" "golang.org/x/net/http2/hpack" "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" "google.golang.org/grpc/internal/leakcheck" "google.golang.org/grpc/internal/syscall" "google.golang.org/grpc/internal/testutils" @@ -2378,3 +2379,28 @@ func TestTCPUserTimeout(t *testing.T) { } } } + +// TestRequestInfoFoundInStreamContext tests whether data passed in as callheader data gets +// propagated into the context object as a credentials.RequestInfo object. +func TestRequestInfoFoundInStreamContext(t *testing.T) { + serverConfig := &ServerConfig{} + server, client, cancel := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{}) + defer cancel() + defer server.stop() + defer client.Close() + + ch := &CallHdr{ + Method: "someService/Method", + } + ctx, ctxCancel := context.WithTimeout(context.Background(), 10*time.Second) + defer ctxCancel() + + s, err := client.NewStream(ctx, ch) + if err != nil { + t.Fatalf("client.NewStream() failed: %v", err) + } + ri := credentials.RequestInfoFromContext(s.ctx) + if ch.Method != ri.Method { + t.Fatalf("RequestInfo.Method == %s; want %s", ri.Method, ch.Method) + } +} diff --git a/test/end2end_test.go b/test/end2end_test.go index 8b38563b9e3..daf49e44688 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -7484,3 +7484,32 @@ func parseCfg(s string) serviceconfig.Config { } return c } + +type methodTestCreds struct { + expectedMethod string +} + +func (m methodTestCreds) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) { + ri := credentials.RequestInfoFromContext(ctx) + return nil, status.Errorf(codes.Unknown, ri.Method) +} + +func (m methodTestCreds) RequireTransportSecurity() bool { + return false +} + +func (s) TestGRPCMethodAccessibleToCredsViaContextRequestInfo(t *testing.T) { + const wantMethod = "/grpc.testing.TestService/EmptyCall" + ss := &stubServer{} + if err := ss.Start(nil, grpc.WithPerRPCCredentials(methodTestCreds{"/grpc.testing.TestService/EmptyCall"})); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + if _, err := ss.client.EmptyCall(ctx, &testpb.Empty{}); status.Convert(err).Message() != wantMethod { + t.Fatalf("ss.client.EmptyCall(_, _) = _, %v; want _, _.Message()=%q", err, wantMethod) + } +} From 5b75aba1ca058b03bb68d3241465dd1547919640 Mon Sep 17 00:00:00 2001 From: Shane Liebling Date: Mon, 30 Sep 2019 10:54:32 -0400 Subject: [PATCH 02/10] Add a RequestInfo struct which initially is used for passing the full request method (though could later be expanded to pass more info) so that things like GetRequestMetadata can be used to apply logic based on that data. This is a fix for #3019 --- credentials/credentials.go | 30 ++++++++++++++++++++++++++++ internal/internal.go | 3 ++- internal/transport/http2_client.go | 5 +++++ internal/transport/transport_test.go | 26 ++++++++++++++++++++++++ test/end2end_test.go | 29 +++++++++++++++++++++++++++ 5 files changed, 92 insertions(+), 1 deletion(-) diff --git a/credentials/credentials.go b/credentials/credentials.go index 8ea3d4a1dc2..4076f3716dd 100644 --- a/credentials/credentials.go +++ b/credentials/credentials.go @@ -34,6 +34,7 @@ import ( "github.com/golang/protobuf/proto" "google.golang.org/grpc/credentials/internal" + ginternal "google.golang.org/grpc/internal" ) // PerRPCCredentials defines the common interface for the credentials which need to @@ -334,3 +335,32 @@ func cloneTLSConfig(cfg *tls.Config) *tls.Config { return cfg.Clone() } + +// RequestInfo contains request data to be attached to the context passed along to GetRequestMetadata calls. This API is experimental. +type RequestInfo struct { + // The method passed to Invoke or NewStream for this RPC. (For proto methods, this has the format "/some.Service/Method") + Method string +} + +// requestInfoKey is a struct to be used as the key when attaching a RequestInfo to a context object. +type requestInfoKey struct{} + +// RequestInfoFromContext extracts the RequestInfo from the context. This API is experimental. +func RequestInfoFromContext(ctx context.Context) RequestInfo { + ri, ok := ctx.Value(requestInfoKey{}).(RequestInfo) + if !ok { + return RequestInfo{} + } + return ri +} + +// withRequestInfo adds the supplied RequestInfo to the context. This API is experimental. +func withRequestInfo(ctx context.Context, ri RequestInfo) context.Context { + return context.WithValue(ctx, requestInfoKey{}, ri) +} + +func init() { + ginternal.NewRequestInfoContext = func(ctx context.Context, ri RequestInfo) context.Context { + return context.WithValue(ctx, requestInfoKey{}, ri) + } +} diff --git a/internal/internal.go b/internal/internal.go index bc1f99ac803..7f8a3d0b0fa 100644 --- a/internal/internal.go +++ b/internal/internal.go @@ -46,7 +46,8 @@ var ( // pointer to the wrapped Status proto for a given status.Status without a // call to proto.Clone(). The returned Status proto should not be mutated by // the caller. - StatusRawProto interface{} // func (*status.Status) *spb.Status + StatusRawProto interface{} // func (*status.Status) *spb.Status + NewRequestInfoContext interface{} // (this avoids a circular dependency; we do it for other things like this already) ) // HealthChecker defines the signature of the client-side LB channel health checking function. diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go index 9bd8c27b365..5152d842868 100644 --- a/internal/transport/http2_client.go +++ b/internal/transport/http2_client.go @@ -35,6 +35,7 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" + "google.golang.org/grpc/internal" "google.golang.org/grpc/internal/channelz" "google.golang.org/grpc/internal/syscall" "google.golang.org/grpc/keepalive" @@ -547,6 +548,10 @@ func (t *http2Client) getCallAuthData(ctx context.Context, audience string, call // streams. func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Stream, err error) { ctx = peer.NewContext(ctx, t.getPeer()) + ri := credentials.RequestInfo{ + Method: callHdr.Method, + } + ctx = internal.NewRequestInfoContext.(func(context.Context, credentials.RequestInfo) context.Context)(ctx, ri) headerFields, err := t.createHeaderFields(ctx, callHdr) if err != nil { return nil, err diff --git a/internal/transport/transport_test.go b/internal/transport/transport_test.go index c5ee74848cc..caf364a9821 100644 --- a/internal/transport/transport_test.go +++ b/internal/transport/transport_test.go @@ -37,6 +37,7 @@ import ( "golang.org/x/net/http2" "golang.org/x/net/http2/hpack" "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" "google.golang.org/grpc/internal/leakcheck" "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/status" @@ -1809,3 +1810,28 @@ func TestHeaderTblSize(t *testing.T) { t.Fatalf("expected len(limits) = 2 within 10s, got != 2") } } + +// TestRequestInfoFoundInStreamContext tests whether data passed in as callheader data gets +// propagated into the context object as a credentials.RequestInfo object. +func TestRequestInfoFoundInStreamContext(t *testing.T) { + serverConfig := &ServerConfig{} + server, client, cancel := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{}) + defer cancel() + defer server.stop() + defer client.Close() + + ch := &CallHdr{ + Method: "someService/Method", + } + ctx, ctxCancel := context.WithTimeout(context.Background(), 10*time.Second) + defer ctxCancel() + + s, err := client.NewStream(ctx, ch) + if err != nil { + t.Fatalf("client.NewStream() failed: %v", err) + } + ri := credentials.RequestInfoFromContext(s.ctx) + if ch.Method != ri.Method { + t.Fatalf("RequestInfo.Method == %s; want %s", ri.Method, ch.Method) + } +} diff --git a/test/end2end_test.go b/test/end2end_test.go index 8b38563b9e3..daf49e44688 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -7484,3 +7484,32 @@ func parseCfg(s string) serviceconfig.Config { } return c } + +type methodTestCreds struct { + expectedMethod string +} + +func (m methodTestCreds) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) { + ri := credentials.RequestInfoFromContext(ctx) + return nil, status.Errorf(codes.Unknown, ri.Method) +} + +func (m methodTestCreds) RequireTransportSecurity() bool { + return false +} + +func (s) TestGRPCMethodAccessibleToCredsViaContextRequestInfo(t *testing.T) { + const wantMethod = "/grpc.testing.TestService/EmptyCall" + ss := &stubServer{} + if err := ss.Start(nil, grpc.WithPerRPCCredentials(methodTestCreds{"/grpc.testing.TestService/EmptyCall"})); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + if _, err := ss.client.EmptyCall(ctx, &testpb.Empty{}); status.Convert(err).Message() != wantMethod { + t.Fatalf("ss.client.EmptyCall(_, _) = _, %v; want _, _.Message()=%q", err, wantMethod) + } +} From dea8accae175dec95f59b64af4939eb7176c1073 Mon Sep 17 00:00:00 2001 From: Shane Liebling Date: Tue, 1 Oct 2019 12:19:57 -0400 Subject: [PATCH 03/10] a few requested edits --- credentials/credentials.go | 5 ----- internal/internal.go | 2 +- test/end2end_test.go | 6 ++---- 3 files changed, 3 insertions(+), 10 deletions(-) diff --git a/credentials/credentials.go b/credentials/credentials.go index 4076f3716dd..76fc7af61b8 100644 --- a/credentials/credentials.go +++ b/credentials/credentials.go @@ -354,11 +354,6 @@ func RequestInfoFromContext(ctx context.Context) RequestInfo { return ri } -// withRequestInfo adds the supplied RequestInfo to the context. This API is experimental. -func withRequestInfo(ctx context.Context, ri RequestInfo) context.Context { - return context.WithValue(ctx, requestInfoKey{}, ri) -} - func init() { ginternal.NewRequestInfoContext = func(ctx context.Context, ri RequestInfo) context.Context { return context.WithValue(ctx, requestInfoKey{}, ri) diff --git a/internal/internal.go b/internal/internal.go index 7f8a3d0b0fa..55f5f0685db 100644 --- a/internal/internal.go +++ b/internal/internal.go @@ -47,7 +47,7 @@ var ( // call to proto.Clone(). The returned Status proto should not be mutated by // the caller. StatusRawProto interface{} // func (*status.Status) *spb.Status - NewRequestInfoContext interface{} // (this avoids a circular dependency; we do it for other things like this already) + NewRequestInfoContext interface{} // func(context.Context, RequestInfo) context.Context ) // HealthChecker defines the signature of the client-side LB channel health checking function. diff --git a/test/end2end_test.go b/test/end2end_test.go index daf49e44688..2b9c29ef2b6 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -7485,9 +7485,7 @@ func parseCfg(s string) serviceconfig.Config { return c } -type methodTestCreds struct { - expectedMethod string -} +type methodTestCreds struct{} func (m methodTestCreds) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) { ri := credentials.RequestInfoFromContext(ctx) @@ -7501,7 +7499,7 @@ func (m methodTestCreds) RequireTransportSecurity() bool { func (s) TestGRPCMethodAccessibleToCredsViaContextRequestInfo(t *testing.T) { const wantMethod = "/grpc.testing.TestService/EmptyCall" ss := &stubServer{} - if err := ss.Start(nil, grpc.WithPerRPCCredentials(methodTestCreds{"/grpc.testing.TestService/EmptyCall"})); err != nil { + if err := ss.Start(nil, grpc.WithPerRPCCredentials(methodTestCreds{})); err != nil { t.Fatalf("Error starting endpoint server: %v", err) } defer ss.Stop() From 2cc8abd436c68f84a26da1a37ddbd2b0d30f3384 Mon Sep 17 00:00:00 2001 From: Shane Liebling Date: Mon, 30 Sep 2019 10:54:32 -0400 Subject: [PATCH 04/10] Add a RequestInfo struct which initially is used for passing the full request method (though could later be expanded to pass more info) so that things like GetRequestMetadata can be used to apply logic based on that data. This is a fix for #3019 --- credentials/credentials.go | 30 ++++++++++++++++++++++++++++ internal/internal.go | 3 ++- internal/transport/http2_client.go | 5 +++++ internal/transport/transport_test.go | 26 ++++++++++++++++++++++++ test/end2end_test.go | 29 +++++++++++++++++++++++++++ 5 files changed, 92 insertions(+), 1 deletion(-) diff --git a/credentials/credentials.go b/credentials/credentials.go index 8ea3d4a1dc2..4076f3716dd 100644 --- a/credentials/credentials.go +++ b/credentials/credentials.go @@ -34,6 +34,7 @@ import ( "github.com/golang/protobuf/proto" "google.golang.org/grpc/credentials/internal" + ginternal "google.golang.org/grpc/internal" ) // PerRPCCredentials defines the common interface for the credentials which need to @@ -334,3 +335,32 @@ func cloneTLSConfig(cfg *tls.Config) *tls.Config { return cfg.Clone() } + +// RequestInfo contains request data to be attached to the context passed along to GetRequestMetadata calls. This API is experimental. +type RequestInfo struct { + // The method passed to Invoke or NewStream for this RPC. (For proto methods, this has the format "/some.Service/Method") + Method string +} + +// requestInfoKey is a struct to be used as the key when attaching a RequestInfo to a context object. +type requestInfoKey struct{} + +// RequestInfoFromContext extracts the RequestInfo from the context. This API is experimental. +func RequestInfoFromContext(ctx context.Context) RequestInfo { + ri, ok := ctx.Value(requestInfoKey{}).(RequestInfo) + if !ok { + return RequestInfo{} + } + return ri +} + +// withRequestInfo adds the supplied RequestInfo to the context. This API is experimental. +func withRequestInfo(ctx context.Context, ri RequestInfo) context.Context { + return context.WithValue(ctx, requestInfoKey{}, ri) +} + +func init() { + ginternal.NewRequestInfoContext = func(ctx context.Context, ri RequestInfo) context.Context { + return context.WithValue(ctx, requestInfoKey{}, ri) + } +} diff --git a/internal/internal.go b/internal/internal.go index bc1f99ac803..7f8a3d0b0fa 100644 --- a/internal/internal.go +++ b/internal/internal.go @@ -46,7 +46,8 @@ var ( // pointer to the wrapped Status proto for a given status.Status without a // call to proto.Clone(). The returned Status proto should not be mutated by // the caller. - StatusRawProto interface{} // func (*status.Status) *spb.Status + StatusRawProto interface{} // func (*status.Status) *spb.Status + NewRequestInfoContext interface{} // (this avoids a circular dependency; we do it for other things like this already) ) // HealthChecker defines the signature of the client-side LB channel health checking function. diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go index 9bd8c27b365..5152d842868 100644 --- a/internal/transport/http2_client.go +++ b/internal/transport/http2_client.go @@ -35,6 +35,7 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" + "google.golang.org/grpc/internal" "google.golang.org/grpc/internal/channelz" "google.golang.org/grpc/internal/syscall" "google.golang.org/grpc/keepalive" @@ -547,6 +548,10 @@ func (t *http2Client) getCallAuthData(ctx context.Context, audience string, call // streams. func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Stream, err error) { ctx = peer.NewContext(ctx, t.getPeer()) + ri := credentials.RequestInfo{ + Method: callHdr.Method, + } + ctx = internal.NewRequestInfoContext.(func(context.Context, credentials.RequestInfo) context.Context)(ctx, ri) headerFields, err := t.createHeaderFields(ctx, callHdr) if err != nil { return nil, err diff --git a/internal/transport/transport_test.go b/internal/transport/transport_test.go index c5ee74848cc..caf364a9821 100644 --- a/internal/transport/transport_test.go +++ b/internal/transport/transport_test.go @@ -37,6 +37,7 @@ import ( "golang.org/x/net/http2" "golang.org/x/net/http2/hpack" "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" "google.golang.org/grpc/internal/leakcheck" "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/status" @@ -1809,3 +1810,28 @@ func TestHeaderTblSize(t *testing.T) { t.Fatalf("expected len(limits) = 2 within 10s, got != 2") } } + +// TestRequestInfoFoundInStreamContext tests whether data passed in as callheader data gets +// propagated into the context object as a credentials.RequestInfo object. +func TestRequestInfoFoundInStreamContext(t *testing.T) { + serverConfig := &ServerConfig{} + server, client, cancel := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{}) + defer cancel() + defer server.stop() + defer client.Close() + + ch := &CallHdr{ + Method: "someService/Method", + } + ctx, ctxCancel := context.WithTimeout(context.Background(), 10*time.Second) + defer ctxCancel() + + s, err := client.NewStream(ctx, ch) + if err != nil { + t.Fatalf("client.NewStream() failed: %v", err) + } + ri := credentials.RequestInfoFromContext(s.ctx) + if ch.Method != ri.Method { + t.Fatalf("RequestInfo.Method == %s; want %s", ri.Method, ch.Method) + } +} diff --git a/test/end2end_test.go b/test/end2end_test.go index 8b38563b9e3..daf49e44688 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -7484,3 +7484,32 @@ func parseCfg(s string) serviceconfig.Config { } return c } + +type methodTestCreds struct { + expectedMethod string +} + +func (m methodTestCreds) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) { + ri := credentials.RequestInfoFromContext(ctx) + return nil, status.Errorf(codes.Unknown, ri.Method) +} + +func (m methodTestCreds) RequireTransportSecurity() bool { + return false +} + +func (s) TestGRPCMethodAccessibleToCredsViaContextRequestInfo(t *testing.T) { + const wantMethod = "/grpc.testing.TestService/EmptyCall" + ss := &stubServer{} + if err := ss.Start(nil, grpc.WithPerRPCCredentials(methodTestCreds{"/grpc.testing.TestService/EmptyCall"})); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + if _, err := ss.client.EmptyCall(ctx, &testpb.Empty{}); status.Convert(err).Message() != wantMethod { + t.Fatalf("ss.client.EmptyCall(_, _) = _, %v; want _, _.Message()=%q", err, wantMethod) + } +} From f342e9a88cb61ca0db442683890f7a4fecb1f0ff Mon Sep 17 00:00:00 2001 From: Shane Liebling Date: Tue, 1 Oct 2019 12:19:57 -0400 Subject: [PATCH 05/10] a few requested edits --- credentials/credentials.go | 5 ----- internal/internal.go | 2 +- test/end2end_test.go | 6 ++---- 3 files changed, 3 insertions(+), 10 deletions(-) diff --git a/credentials/credentials.go b/credentials/credentials.go index 4076f3716dd..76fc7af61b8 100644 --- a/credentials/credentials.go +++ b/credentials/credentials.go @@ -354,11 +354,6 @@ func RequestInfoFromContext(ctx context.Context) RequestInfo { return ri } -// withRequestInfo adds the supplied RequestInfo to the context. This API is experimental. -func withRequestInfo(ctx context.Context, ri RequestInfo) context.Context { - return context.WithValue(ctx, requestInfoKey{}, ri) -} - func init() { ginternal.NewRequestInfoContext = func(ctx context.Context, ri RequestInfo) context.Context { return context.WithValue(ctx, requestInfoKey{}, ri) diff --git a/internal/internal.go b/internal/internal.go index 7f8a3d0b0fa..55f5f0685db 100644 --- a/internal/internal.go +++ b/internal/internal.go @@ -47,7 +47,7 @@ var ( // call to proto.Clone(). The returned Status proto should not be mutated by // the caller. StatusRawProto interface{} // func (*status.Status) *spb.Status - NewRequestInfoContext interface{} // (this avoids a circular dependency; we do it for other things like this already) + NewRequestInfoContext interface{} // func(context.Context, RequestInfo) context.Context ) // HealthChecker defines the signature of the client-side LB channel health checking function. diff --git a/test/end2end_test.go b/test/end2end_test.go index daf49e44688..2b9c29ef2b6 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -7485,9 +7485,7 @@ func parseCfg(s string) serviceconfig.Config { return c } -type methodTestCreds struct { - expectedMethod string -} +type methodTestCreds struct{} func (m methodTestCreds) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) { ri := credentials.RequestInfoFromContext(ctx) @@ -7501,7 +7499,7 @@ func (m methodTestCreds) RequireTransportSecurity() bool { func (s) TestGRPCMethodAccessibleToCredsViaContextRequestInfo(t *testing.T) { const wantMethod = "/grpc.testing.TestService/EmptyCall" ss := &stubServer{} - if err := ss.Start(nil, grpc.WithPerRPCCredentials(methodTestCreds{"/grpc.testing.TestService/EmptyCall"})); err != nil { + if err := ss.Start(nil, grpc.WithPerRPCCredentials(methodTestCreds{})); err != nil { t.Fatalf("Error starting endpoint server: %v", err) } defer ss.Stop() From f5b8f37cfdbd5f4c3e43e2572d7349b14a723d55 Mon Sep 17 00:00:00 2001 From: Shane Liebling Date: Mon, 30 Sep 2019 10:54:32 -0400 Subject: [PATCH 06/10] Add a RequestInfo struct which initially is used for passing the full request method (though could later be expanded to pass more info) so that things like GetRequestMetadata can be used to apply logic based on that data. This is a fix for #3019 --- internal/transport/transport_test.go | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/internal/transport/transport_test.go b/internal/transport/transport_test.go index caf364a9821..c339eb83d78 100644 --- a/internal/transport/transport_test.go +++ b/internal/transport/transport_test.go @@ -1835,3 +1835,28 @@ func TestRequestInfoFoundInStreamContext(t *testing.T) { t.Fatalf("RequestInfo.Method == %s; want %s", ri.Method, ch.Method) } } + +// TestRequestInfoFoundInStreamContext tests whether data passed in as callheader data gets +// propagated into the context object as a credentials.RequestInfo object. +func TestRequestInfoFoundInStreamContext(t *testing.T) { + serverConfig := &ServerConfig{} + server, client, cancel := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{}) + defer cancel() + defer server.stop() + defer client.Close() + + ch := &CallHdr{ + Method: "someService/Method", + } + ctx, ctxCancel := context.WithTimeout(context.Background(), 10*time.Second) + defer ctxCancel() + + s, err := client.NewStream(ctx, ch) + if err != nil { + t.Fatalf("client.NewStream() failed: %v", err) + } + ri := credentials.RequestInfoFromContext(s.ctx) + if ch.Method != ri.Method { + t.Fatalf("RequestInfo.Method == %s; want %s", ri.Method, ch.Method) + } +} From 5415fd6cdd2fc8c4d881b8f2fbd73b706f1430ea Mon Sep 17 00:00:00 2001 From: Shane Liebling Date: Tue, 1 Oct 2019 12:46:20 -0400 Subject: [PATCH 07/10] removed duplicate test --- internal/transport/transport_test.go | 25 ------------------------- 1 file changed, 25 deletions(-) diff --git a/internal/transport/transport_test.go b/internal/transport/transport_test.go index c339eb83d78..caf364a9821 100644 --- a/internal/transport/transport_test.go +++ b/internal/transport/transport_test.go @@ -1835,28 +1835,3 @@ func TestRequestInfoFoundInStreamContext(t *testing.T) { t.Fatalf("RequestInfo.Method == %s; want %s", ri.Method, ch.Method) } } - -// TestRequestInfoFoundInStreamContext tests whether data passed in as callheader data gets -// propagated into the context object as a credentials.RequestInfo object. -func TestRequestInfoFoundInStreamContext(t *testing.T) { - serverConfig := &ServerConfig{} - server, client, cancel := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{}) - defer cancel() - defer server.stop() - defer client.Close() - - ch := &CallHdr{ - Method: "someService/Method", - } - ctx, ctxCancel := context.WithTimeout(context.Background(), 10*time.Second) - defer ctxCancel() - - s, err := client.NewStream(ctx, ch) - if err != nil { - t.Fatalf("client.NewStream() failed: %v", err) - } - ri := credentials.RequestInfoFromContext(s.ctx) - if ch.Method != ri.Method { - t.Fatalf("RequestInfo.Method == %s; want %s", ri.Method, ch.Method) - } -} From b6d80143664238584974e551ec24cbcdd5f29e0b Mon Sep 17 00:00:00 2001 From: Shane Liebling Date: Tue, 1 Oct 2019 15:48:25 -0400 Subject: [PATCH 08/10] commenting exported var: --- internal/internal.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/internal/internal.go b/internal/internal.go index 55f5f0685db..e535a5ff91e 100644 --- a/internal/internal.go +++ b/internal/internal.go @@ -46,8 +46,10 @@ var ( // pointer to the wrapped Status proto for a given status.Status without a // call to proto.Clone(). The returned Status proto should not be mutated by // the caller. - StatusRawProto interface{} // func (*status.Status) *spb.Status - NewRequestInfoContext interface{} // func(context.Context, RequestInfo) context.Context + StatusRawProto interface{} // func (*status.Status) *spb.Status + // NewRequestInfoContext creates a new context based on the argument context attaching + // the passed in RequestInfo to the new context. + NewRequestInfoContext interface{} // func(context.Context, credentials.RequestInfo) context.Context ) // HealthChecker defines the signature of the client-side LB channel health checking function. From e220384ef763883bb714da67b037fef05f40fe3c Mon Sep 17 00:00:00 2001 From: Shane Liebling Date: Thu, 3 Oct 2019 10:47:50 -0400 Subject: [PATCH 09/10] a few comment updates --- credentials/credentials.go | 11 ++++++++--- internal/transport/http2_client.go | 12 ++++++------ internal/transport/transport_test.go | 26 -------------------------- 3 files changed, 14 insertions(+), 35 deletions(-) diff --git a/credentials/credentials.go b/credentials/credentials.go index 76fc7af61b8..d0cf06a7dd2 100644 --- a/credentials/credentials.go +++ b/credentials/credentials.go @@ -46,7 +46,8 @@ type PerRPCCredentials interface { // context. If a status code is returned, it will be used as the status // for the RPC. uri is the URI of the entry point for the request. // When supported by the underlying implementation, ctx can be used for - // timeout and cancellation. + // timeout and cancellation. Additionally, RequestInfo data will be + // available via ctx to this call. // TODO(zhaoq): Define the set of the qualified keys instead of leaving // it as an arbitrary string. GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) @@ -336,7 +337,9 @@ func cloneTLSConfig(cfg *tls.Config) *tls.Config { return cfg.Clone() } -// RequestInfo contains request data to be attached to the context passed along to GetRequestMetadata calls. This API is experimental. +// RequestInfo contains request data attached to the context passed to GetRequestMetadata calls. +// +// This API is experimental. type RequestInfo struct { // The method passed to Invoke or NewStream for this RPC. (For proto methods, this has the format "/some.Service/Method") Method string @@ -345,7 +348,9 @@ type RequestInfo struct { // requestInfoKey is a struct to be used as the key when attaching a RequestInfo to a context object. type requestInfoKey struct{} -// RequestInfoFromContext extracts the RequestInfo from the context. This API is experimental. +// RequestInfoFromContext extracts the RequestInfo from the context. +// +// This API is experimental. func RequestInfoFromContext(ctx context.Context) RequestInfo { ri, ok := ctx.Value(requestInfoKey{}).(RequestInfo) if !ok { diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go index 5152d842868..ba3054776fd 100644 --- a/internal/transport/http2_client.go +++ b/internal/transport/http2_client.go @@ -398,11 +398,15 @@ func (t *http2Client) getPeer() *peer.Peer { func (t *http2Client) createHeaderFields(ctx context.Context, callHdr *CallHdr) ([]hpack.HeaderField, error) { aud := t.createAudience(callHdr) - authData, err := t.getTrAuthData(ctx, aud) + ri := credentials.RequestInfo{ + Method: callHdr.Method, + } + ctxWithRequestInfo := internal.NewRequestInfoContext.(func(context.Context, credentials.RequestInfo) context.Context)(ctx, ri) + authData, err := t.getTrAuthData(ctxWithRequestInfo, aud) if err != nil { return nil, err } - callAuthData, err := t.getCallAuthData(ctx, aud, callHdr) + callAuthData, err := t.getCallAuthData(ctxWithRequestInfo, aud, callHdr) if err != nil { return nil, err } @@ -548,10 +552,6 @@ func (t *http2Client) getCallAuthData(ctx context.Context, audience string, call // streams. func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Stream, err error) { ctx = peer.NewContext(ctx, t.getPeer()) - ri := credentials.RequestInfo{ - Method: callHdr.Method, - } - ctx = internal.NewRequestInfoContext.(func(context.Context, credentials.RequestInfo) context.Context)(ctx, ri) headerFields, err := t.createHeaderFields(ctx, callHdr) if err != nil { return nil, err diff --git a/internal/transport/transport_test.go b/internal/transport/transport_test.go index caf364a9821..c5ee74848cc 100644 --- a/internal/transport/transport_test.go +++ b/internal/transport/transport_test.go @@ -37,7 +37,6 @@ import ( "golang.org/x/net/http2" "golang.org/x/net/http2/hpack" "google.golang.org/grpc/codes" - "google.golang.org/grpc/credentials" "google.golang.org/grpc/internal/leakcheck" "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/status" @@ -1810,28 +1809,3 @@ func TestHeaderTblSize(t *testing.T) { t.Fatalf("expected len(limits) = 2 within 10s, got != 2") } } - -// TestRequestInfoFoundInStreamContext tests whether data passed in as callheader data gets -// propagated into the context object as a credentials.RequestInfo object. -func TestRequestInfoFoundInStreamContext(t *testing.T) { - serverConfig := &ServerConfig{} - server, client, cancel := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{}) - defer cancel() - defer server.stop() - defer client.Close() - - ch := &CallHdr{ - Method: "someService/Method", - } - ctx, ctxCancel := context.WithTimeout(context.Background(), 10*time.Second) - defer ctxCancel() - - s, err := client.NewStream(ctx, ch) - if err != nil { - t.Fatalf("client.NewStream() failed: %v", err) - } - ri := credentials.RequestInfoFromContext(s.ctx) - if ch.Method != ri.Method { - t.Fatalf("RequestInfo.Method == %s; want %s", ri.Method, ch.Method) - } -} From 9d4b4181c5653e4f6ff565999782620d9d3b63ec Mon Sep 17 00:00:00 2001 From: Shane Liebling Date: Fri, 4 Oct 2019 10:55:21 -0400 Subject: [PATCH 10/10] modified RequestInfoFromContext to match peer.FromContext --- credentials/credentials.go | 11 ++++------- test/end2end_test.go | 2 +- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/credentials/credentials.go b/credentials/credentials.go index d0cf06a7dd2..486322b80c6 100644 --- a/credentials/credentials.go +++ b/credentials/credentials.go @@ -348,15 +348,12 @@ type RequestInfo struct { // requestInfoKey is a struct to be used as the key when attaching a RequestInfo to a context object. type requestInfoKey struct{} -// RequestInfoFromContext extracts the RequestInfo from the context. +// RequestInfoFromContext extracts the RequestInfo from the context if it exists. // // This API is experimental. -func RequestInfoFromContext(ctx context.Context) RequestInfo { - ri, ok := ctx.Value(requestInfoKey{}).(RequestInfo) - if !ok { - return RequestInfo{} - } - return ri +func RequestInfoFromContext(ctx context.Context) (ri RequestInfo, ok bool) { + ri, ok = ctx.Value(requestInfoKey{}).(RequestInfo) + return } func init() { diff --git a/test/end2end_test.go b/test/end2end_test.go index 2b9c29ef2b6..31e10741768 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -7488,7 +7488,7 @@ func parseCfg(s string) serviceconfig.Config { type methodTestCreds struct{} func (m methodTestCreds) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) { - ri := credentials.RequestInfoFromContext(ctx) + ri, _ := credentials.RequestInfoFromContext(ctx) return nil, status.Errorf(codes.Unknown, ri.Method) }