Skip to content

Commit

Permalink
Add a RequestInfo struct which initially is used for passing the full…
Browse files Browse the repository at this point in the history
… request method (though could later be expanded to pass more info) so that things like GetRequestMetadata can be used to apply logic based on that data.

This is a fix for grpc#3019
  • Loading branch information
shanel-at-google committed Oct 1, 2019
1 parent 861d8e7 commit 2cc8abd
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 1 deletion.
30 changes: 30 additions & 0 deletions credentials/credentials.go
Expand Up @@ -34,6 +34,7 @@ import (

"github.com/golang/protobuf/proto"
"google.golang.org/grpc/credentials/internal"
ginternal "google.golang.org/grpc/internal"
)

// PerRPCCredentials defines the common interface for the credentials which need to
Expand Down Expand Up @@ -334,3 +335,32 @@ func cloneTLSConfig(cfg *tls.Config) *tls.Config {

return cfg.Clone()
}

// RequestInfo contains request data to be attached to the context passed along to GetRequestMetadata calls. This API is experimental.
type RequestInfo struct {
// The method passed to Invoke or NewStream for this RPC. (For proto methods, this has the format "/some.Service/Method")
Method string
}

// requestInfoKey is a struct to be used as the key when attaching a RequestInfo to a context object.
type requestInfoKey struct{}

// RequestInfoFromContext extracts the RequestInfo from the context. This API is experimental.
func RequestInfoFromContext(ctx context.Context) RequestInfo {
ri, ok := ctx.Value(requestInfoKey{}).(RequestInfo)
if !ok {
return RequestInfo{}
}
return ri
}

// withRequestInfo adds the supplied RequestInfo to the context. This API is experimental.
func withRequestInfo(ctx context.Context, ri RequestInfo) context.Context {
return context.WithValue(ctx, requestInfoKey{}, ri)
}

func init() {
ginternal.NewRequestInfoContext = func(ctx context.Context, ri RequestInfo) context.Context {
return context.WithValue(ctx, requestInfoKey{}, ri)
}
}
3 changes: 2 additions & 1 deletion internal/internal.go
Expand Up @@ -46,7 +46,8 @@ var (
// pointer to the wrapped Status proto for a given status.Status without a
// call to proto.Clone(). The returned Status proto should not be mutated by
// the caller.
StatusRawProto interface{} // func (*status.Status) *spb.Status
StatusRawProto interface{} // func (*status.Status) *spb.Status
NewRequestInfoContext interface{} // (this avoids a circular dependency; we do it for other things like this already)
)

// HealthChecker defines the signature of the client-side LB channel health checking function.
Expand Down
5 changes: 5 additions & 0 deletions internal/transport/http2_client.go
Expand Up @@ -35,6 +35,7 @@ import (

"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/channelz"
"google.golang.org/grpc/internal/syscall"
"google.golang.org/grpc/keepalive"
Expand Down Expand Up @@ -547,6 +548,10 @@ func (t *http2Client) getCallAuthData(ctx context.Context, audience string, call
// streams.
func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Stream, err error) {
ctx = peer.NewContext(ctx, t.getPeer())
ri := credentials.RequestInfo{
Method: callHdr.Method,
}
ctx = internal.NewRequestInfoContext.(func(context.Context, credentials.RequestInfo) context.Context)(ctx, ri)
headerFields, err := t.createHeaderFields(ctx, callHdr)
if err != nil {
return nil, err
Expand Down
26 changes: 26 additions & 0 deletions internal/transport/transport_test.go
Expand Up @@ -37,6 +37,7 @@ import (
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/internal/leakcheck"
"google.golang.org/grpc/internal/testutils"
"google.golang.org/grpc/status"
Expand Down Expand Up @@ -1809,3 +1810,28 @@ func TestHeaderTblSize(t *testing.T) {
t.Fatalf("expected len(limits) = 2 within 10s, got != 2")
}
}

// TestRequestInfoFoundInStreamContext tests whether data passed in as callheader data gets
// propagated into the context object as a credentials.RequestInfo object.
func TestRequestInfoFoundInStreamContext(t *testing.T) {
serverConfig := &ServerConfig{}
server, client, cancel := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{})
defer cancel()
defer server.stop()
defer client.Close()

ch := &CallHdr{
Method: "someService/Method",
}
ctx, ctxCancel := context.WithTimeout(context.Background(), 10*time.Second)
defer ctxCancel()

s, err := client.NewStream(ctx, ch)
if err != nil {
t.Fatalf("client.NewStream() failed: %v", err)
}
ri := credentials.RequestInfoFromContext(s.ctx)
if ch.Method != ri.Method {
t.Fatalf("RequestInfo.Method == %s; want %s", ri.Method, ch.Method)
}
}
29 changes: 29 additions & 0 deletions test/end2end_test.go
Expand Up @@ -7484,3 +7484,32 @@ func parseCfg(s string) serviceconfig.Config {
}
return c
}

type methodTestCreds struct {
expectedMethod string
}

func (m methodTestCreds) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
ri := credentials.RequestInfoFromContext(ctx)
return nil, status.Errorf(codes.Unknown, ri.Method)
}

func (m methodTestCreds) RequireTransportSecurity() bool {
return false
}

func (s) TestGRPCMethodAccessibleToCredsViaContextRequestInfo(t *testing.T) {
const wantMethod = "/grpc.testing.TestService/EmptyCall"
ss := &stubServer{}
if err := ss.Start(nil, grpc.WithPerRPCCredentials(methodTestCreds{"/grpc.testing.TestService/EmptyCall"})); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}
defer ss.Stop()

ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()

if _, err := ss.client.EmptyCall(ctx, &testpb.Empty{}); status.Convert(err).Message() != wantMethod {
t.Fatalf("ss.client.EmptyCall(_, _) = _, %v; want _, _.Message()=%q", err, wantMethod)
}
}

0 comments on commit 2cc8abd

Please sign in to comment.