Skip to content

Commit

Permalink
Adding a RequestInfo struct for propagating request data to Get… (#3057)
Browse files Browse the repository at this point in the history
Add a RequestInfo struct which initially is used for passing the full request method (though could later be expanded to pass more info) so that things like GetRequestMetadata can be used to apply logic based on that data.

This is a fix for #3019
  • Loading branch information
shanel-at-google authored and dfawley committed Oct 4, 2019
1 parent 31911ed commit 47d3cfe
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 3 deletions.
29 changes: 28 additions & 1 deletion credentials/credentials.go
Expand Up @@ -34,6 +34,7 @@ import (

"github.com/golang/protobuf/proto"
"google.golang.org/grpc/credentials/internal"
ginternal "google.golang.org/grpc/internal"
)

// PerRPCCredentials defines the common interface for the credentials which need to
Expand All @@ -45,7 +46,8 @@ type PerRPCCredentials interface {
// context. If a status code is returned, it will be used as the status
// for the RPC. uri is the URI of the entry point for the request.
// When supported by the underlying implementation, ctx can be used for
// timeout and cancellation.
// timeout and cancellation. Additionally, RequestInfo data will be
// available via ctx to this call.
// TODO(zhaoq): Define the set of the qualified keys instead of leaving
// it as an arbitrary string.
GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error)
Expand Down Expand Up @@ -334,3 +336,28 @@ func cloneTLSConfig(cfg *tls.Config) *tls.Config {

return cfg.Clone()
}

// RequestInfo contains request data attached to the context passed to GetRequestMetadata calls.
//
// This API is experimental.
type RequestInfo struct {
// The method passed to Invoke or NewStream for this RPC. (For proto methods, this has the format "/some.Service/Method")
Method string
}

// requestInfoKey is a struct to be used as the key when attaching a RequestInfo to a context object.
type requestInfoKey struct{}

// RequestInfoFromContext extracts the RequestInfo from the context if it exists.
//
// This API is experimental.
func RequestInfoFromContext(ctx context.Context) (ri RequestInfo, ok bool) {
ri, ok = ctx.Value(requestInfoKey{}).(RequestInfo)
return
}

func init() {
ginternal.NewRequestInfoContext = func(ctx context.Context, ri RequestInfo) context.Context {
return context.WithValue(ctx, requestInfoKey{}, ri)
}
}
3 changes: 3 additions & 0 deletions internal/internal.go
Expand Up @@ -47,6 +47,9 @@ var (
// call to proto.Clone(). The returned Status proto should not be mutated by
// the caller.
StatusRawProto interface{} // func (*status.Status) *spb.Status
// NewRequestInfoContext creates a new context based on the argument context attaching
// the passed in RequestInfo to the new context.
NewRequestInfoContext interface{} // func(context.Context, credentials.RequestInfo) context.Context
)

// HealthChecker defines the signature of the client-side LB channel health checking function.
Expand Down
9 changes: 7 additions & 2 deletions internal/transport/http2_client.go
Expand Up @@ -35,6 +35,7 @@ import (

"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/channelz"
"google.golang.org/grpc/internal/syscall"
"google.golang.org/grpc/keepalive"
Expand Down Expand Up @@ -398,11 +399,15 @@ func (t *http2Client) getPeer() *peer.Peer {

func (t *http2Client) createHeaderFields(ctx context.Context, callHdr *CallHdr) ([]hpack.HeaderField, error) {
aud := t.createAudience(callHdr)
authData, err := t.getTrAuthData(ctx, aud)
ri := credentials.RequestInfo{
Method: callHdr.Method,
}
ctxWithRequestInfo := internal.NewRequestInfoContext.(func(context.Context, credentials.RequestInfo) context.Context)(ctx, ri)
authData, err := t.getTrAuthData(ctxWithRequestInfo, aud)
if err != nil {
return nil, err
}
callAuthData, err := t.getCallAuthData(ctx, aud, callHdr)
callAuthData, err := t.getCallAuthData(ctxWithRequestInfo, aud, callHdr)
if err != nil {
return nil, err
}
Expand Down
27 changes: 27 additions & 0 deletions test/end2end_test.go
Expand Up @@ -7484,3 +7484,30 @@ func parseCfg(s string) serviceconfig.Config {
}
return c
}

type methodTestCreds struct{}

func (m methodTestCreds) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
ri, _ := credentials.RequestInfoFromContext(ctx)
return nil, status.Errorf(codes.Unknown, ri.Method)
}

func (m methodTestCreds) RequireTransportSecurity() bool {
return false
}

func (s) TestGRPCMethodAccessibleToCredsViaContextRequestInfo(t *testing.T) {
const wantMethod = "/grpc.testing.TestService/EmptyCall"
ss := &stubServer{}
if err := ss.Start(nil, grpc.WithPerRPCCredentials(methodTestCreds{})); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}
defer ss.Stop()

ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()

if _, err := ss.client.EmptyCall(ctx, &testpb.Empty{}); status.Convert(err).Message() != wantMethod {
t.Fatalf("ss.client.EmptyCall(_, _) = _, %v; want _, _.Message()=%q", err, wantMethod)
}
}

0 comments on commit 47d3cfe

Please sign in to comment.