From 14c76f7077b8f34a98b2db491f75072d6b9c250d Mon Sep 17 00:00:00 2001 From: Ashitha Santhosh Date: Tue, 10 Aug 2021 21:28:22 -0700 Subject: [PATCH 01/14] Static Authorization Interceptor --- authz/rbac_translator.go | 44 ++-- authz/rbac_translator_test.go | 209 +++++++++------- authz/sdk_end2end_test.go | 178 ++++++++++++++ authz/sdk_server_interceptors.go | 66 +++++ authz/sdk_server_interceptors_test.go | 339 ++++++++++++++++++++++++++ go.mod | 3 +- go.sum | 6 +- internal/xds/rbac/rbac_engine.go | 8 +- internal/xds/rbac/rbac_engine_test.go | 6 +- 9 files changed, 751 insertions(+), 108 deletions(-) create mode 100644 authz/sdk_end2end_test.go create mode 100644 authz/sdk_server_interceptors.go create mode 100644 authz/sdk_server_interceptors_test.go diff --git a/authz/rbac_translator.go b/authz/rbac_translator.go index 8dc76489605..fa0a5001df3 100644 --- a/authz/rbac_translator.go +++ b/authz/rbac_translator.go @@ -23,6 +23,7 @@ package authz import ( + "bytes" "encoding/json" "fmt" "strings" @@ -93,7 +94,7 @@ func getStringMatcher(value string) *v3matcherpb.StringMatcher { switch { case value == "*": return &v3matcherpb.StringMatcher{ - MatchPattern: &v3matcherpb.StringMatcher_Prefix{}, + MatchPattern: &v3matcherpb.StringMatcher_SafeRegex{}, } case strings.HasSuffix(value, "*"): prefix := strings.TrimSuffix(value, "*") @@ -117,7 +118,7 @@ func getHeaderMatcher(key, value string) *v3routepb.HeaderMatcher { case value == "*": return &v3routepb.HeaderMatcher{ Name: key, - HeaderMatchSpecifier: &v3routepb.HeaderMatcher_PrefixMatch{}, + HeaderMatchSpecifier: &v3routepb.HeaderMatcher_SafeRegexMatch{}, } case strings.HasSuffix(value, "*"): prefix := strings.TrimSuffix(value, "*") @@ -268,34 +269,39 @@ func parseRules(rules []rule, prefixName string) (map[string]*v3rbacpb.Policy, e } // translatePolicy translates SDK authorization policy in JSON format to two -// Envoy RBAC polices (deny and allow policy). If the policy cannot be parsed -// or is invalid, an error will be returned. -func translatePolicy(policyStr string) (*v3rbacpb.RBAC, *v3rbacpb.RBAC, error) { - var policy authorizationPolicy - if err := json.Unmarshal([]byte(policyStr), &policy); err != nil { - return nil, nil, fmt.Errorf("failed to unmarshal policy: %v", err) +// Envoy RBAC polices (deny followed by allow policy) or only one Envoy RBAC +// allow policy. If the input policy cannot be parsed or is invalid, an error +// will be returned. +func translatePolicy(policyStr string) ([]*v3rbacpb.RBAC, error) { + policy := &authorizationPolicy{} + d := json.NewDecoder(bytes.NewReader([]byte(policyStr))) + d.DisallowUnknownFields() + if err := d.Decode(policy); err != nil { + return nil, fmt.Errorf("failed to unmarshal policy: %v", err) } if policy.Name == "" { - return nil, nil, fmt.Errorf(`"name" is not present`) + return nil, fmt.Errorf(`"name" is not present`) } if len(policy.AllowRules) == 0 { - return nil, nil, fmt.Errorf(`"allow_rules" is not present`) + return nil, fmt.Errorf(`"allow_rules" is not present`) } - allowPolicies, err := parseRules(policy.AllowRules, policy.Name) - if err != nil { - return nil, nil, fmt.Errorf(`"allow_rules" %v`, err) - } - allowRBAC := &v3rbacpb.RBAC{Action: v3rbacpb.RBAC_ALLOW, Policies: allowPolicies} - var denyRBAC *v3rbacpb.RBAC + var RBACPolicies []*v3rbacpb.RBAC if len(policy.DenyRules) > 0 { denyPolicies, err := parseRules(policy.DenyRules, policy.Name) if err != nil { - return nil, nil, fmt.Errorf(`"deny_rules" %v`, err) + return nil, fmt.Errorf(`"deny_rules" %v`, err) } - denyRBAC = &v3rbacpb.RBAC{ + denyRBAC := &v3rbacpb.RBAC{ Action: v3rbacpb.RBAC_DENY, Policies: denyPolicies, } + RBACPolicies = append(RBACPolicies, denyRBAC) } - return denyRBAC, allowRBAC, nil + allowPolicies, err := parseRules(policy.AllowRules, policy.Name) + if err != nil { + return nil, fmt.Errorf(`"allow_rules" %v`, err) + } + allowRBAC := &v3rbacpb.RBAC{Action: v3rbacpb.RBAC_ALLOW, Policies: allowPolicies} + RBACPolicies = append(RBACPolicies, allowRBAC) + return RBACPolicies, nil } diff --git a/authz/rbac_translator_test.go b/authz/rbac_translator_test.go index 425cae85b03..9a883e9d78d 100644 --- a/authz/rbac_translator_test.go +++ b/authz/rbac_translator_test.go @@ -32,10 +32,9 @@ import ( func TestTranslatePolicy(t *testing.T) { tests := map[string]struct { - authzPolicy string - wantErr string - wantDenyPolicy *v3rbacpb.RBAC - wantAllowPolicy *v3rbacpb.RBAC + authzPolicy string + wantErr string + wantPolicies []*v3rbacpb.RBAC }{ "valid policy": { authzPolicy: `{ @@ -82,81 +81,133 @@ func TestTranslatePolicy(t *testing.T) { } }] }`, - wantDenyPolicy: &v3rbacpb.RBAC{Action: v3rbacpb.RBAC_DENY, Policies: map[string]*v3rbacpb.Policy{ - "authz_deny_policy_1": { - Principals: []*v3rbacpb.Principal{ - {Identifier: &v3rbacpb.Principal_OrIds{OrIds: &v3rbacpb.Principal_Set{ - Ids: []*v3rbacpb.Principal{ - {Identifier: &v3rbacpb.Principal_Authenticated_{ - Authenticated: &v3rbacpb.Principal_Authenticated{PrincipalName: &v3matcherpb.StringMatcher{ - MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: "spiffe://foo.abc"}}}}}, - {Identifier: &v3rbacpb.Principal_Authenticated_{ - Authenticated: &v3rbacpb.Principal_Authenticated{PrincipalName: &v3matcherpb.StringMatcher{ - MatchPattern: &v3matcherpb.StringMatcher_Prefix{Prefix: "spiffe://bar"}}}}}, - {Identifier: &v3rbacpb.Principal_Authenticated_{ - Authenticated: &v3rbacpb.Principal_Authenticated{PrincipalName: &v3matcherpb.StringMatcher{ - MatchPattern: &v3matcherpb.StringMatcher_Suffix{Suffix: "baz"}}}}}, - {Identifier: &v3rbacpb.Principal_Authenticated_{ - Authenticated: &v3rbacpb.Principal_Authenticated{PrincipalName: &v3matcherpb.StringMatcher{ - MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: "spiffe://abc.*.com"}}}}}, - }}}}}, - Permissions: []*v3rbacpb.Permission{ - {Rule: &v3rbacpb.Permission_Any{Any: true}}}, - }, - }}, - wantAllowPolicy: &v3rbacpb.RBAC{Action: v3rbacpb.RBAC_ALLOW, Policies: map[string]*v3rbacpb.Policy{ - "authz_allow_policy_1": { - Principals: []*v3rbacpb.Principal{ - {Identifier: &v3rbacpb.Principal_OrIds{OrIds: &v3rbacpb.Principal_Set{ - Ids: []*v3rbacpb.Principal{ - {Identifier: &v3rbacpb.Principal_Authenticated_{ - Authenticated: &v3rbacpb.Principal_Authenticated{PrincipalName: &v3matcherpb.StringMatcher{ - MatchPattern: &v3matcherpb.StringMatcher_Prefix{Prefix: ""}}}}}, - }}}}}, - Permissions: []*v3rbacpb.Permission{ - {Rule: &v3rbacpb.Permission_AndRules{AndRules: &v3rbacpb.Permission_Set{ - Rules: []*v3rbacpb.Permission{ - {Rule: &v3rbacpb.Permission_OrRules{OrRules: &v3rbacpb.Permission_Set{ - Rules: []*v3rbacpb.Permission{ - {Rule: &v3rbacpb.Permission_UrlPath{ - UrlPath: &v3matcherpb.PathMatcher{Rule: &v3matcherpb.PathMatcher_Path{Path: &v3matcherpb.StringMatcher{ - MatchPattern: &v3matcherpb.StringMatcher_Prefix{Prefix: "path-foo"}}}}}}, - }}}}}}}}}, + wantPolicies: []*v3rbacpb.RBAC{ + { + Action: v3rbacpb.RBAC_DENY, + Policies: map[string]*v3rbacpb.Policy{ + "authz_deny_policy_1": { + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_OrIds{OrIds: &v3rbacpb.Principal_Set{ + Ids: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Authenticated_{ + Authenticated: &v3rbacpb.Principal_Authenticated{PrincipalName: &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: "spiffe://foo.abc"}, + }}, + }}, + {Identifier: &v3rbacpb.Principal_Authenticated_{ + Authenticated: &v3rbacpb.Principal_Authenticated{PrincipalName: &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Prefix{Prefix: "spiffe://bar"}, + }}, + }}, + {Identifier: &v3rbacpb.Principal_Authenticated_{ + Authenticated: &v3rbacpb.Principal_Authenticated{PrincipalName: &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Suffix{Suffix: "baz"}, + }}, + }}, + {Identifier: &v3rbacpb.Principal_Authenticated_{ + Authenticated: &v3rbacpb.Principal_Authenticated{PrincipalName: &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: "spiffe://abc.*.com"}, + }}, + }}, + }, + }}}, + }, + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Any{Any: true}}, + }, + }, + }, }, - "authz_allow_policy_2": { - Principals: []*v3rbacpb.Principal{ - {Identifier: &v3rbacpb.Principal_Any{Any: true}}}, - Permissions: []*v3rbacpb.Permission{ - {Rule: &v3rbacpb.Permission_AndRules{AndRules: &v3rbacpb.Permission_Set{ - Rules: []*v3rbacpb.Permission{ - {Rule: &v3rbacpb.Permission_OrRules{OrRules: &v3rbacpb.Permission_Set{ - Rules: []*v3rbacpb.Permission{ - {Rule: &v3rbacpb.Permission_UrlPath{ - UrlPath: &v3matcherpb.PathMatcher{Rule: &v3matcherpb.PathMatcher_Path{Path: &v3matcherpb.StringMatcher{ - MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: "path-bar"}}}}}}, - {Rule: &v3rbacpb.Permission_UrlPath{ - UrlPath: &v3matcherpb.PathMatcher{Rule: &v3matcherpb.PathMatcher_Path{Path: &v3matcherpb.StringMatcher{ - MatchPattern: &v3matcherpb.StringMatcher_Suffix{Suffix: "baz"}}}}}}, - }}}}, + { + Action: v3rbacpb.RBAC_ALLOW, + Policies: map[string]*v3rbacpb.Policy{ + "authz_allow_policy_1": { + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_OrIds{OrIds: &v3rbacpb.Principal_Set{ + Ids: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Authenticated_{ + Authenticated: &v3rbacpb.Principal_Authenticated{PrincipalName: &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_SafeRegex{}, + }}, + }}, + }, + }}}, + }, + Permissions: []*v3rbacpb.Permission{ {Rule: &v3rbacpb.Permission_AndRules{AndRules: &v3rbacpb.Permission_Set{ Rules: []*v3rbacpb.Permission{ {Rule: &v3rbacpb.Permission_OrRules{OrRules: &v3rbacpb.Permission_Set{ Rules: []*v3rbacpb.Permission{ - {Rule: &v3rbacpb.Permission_Header{ - Header: &v3routepb.HeaderMatcher{ - Name: "key-1", HeaderMatchSpecifier: &v3routepb.HeaderMatcher_ExactMatch{ExactMatch: "foo"}}}}, - {Rule: &v3rbacpb.Permission_Header{ - Header: &v3routepb.HeaderMatcher{ - Name: "key-1", HeaderMatchSpecifier: &v3routepb.HeaderMatcher_SuffixMatch{SuffixMatch: "bar"}}}}, - }}}}, + {Rule: &v3rbacpb.Permission_UrlPath{ + UrlPath: &v3matcherpb.PathMatcher{Rule: &v3matcherpb.PathMatcher_Path{Path: &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Prefix{Prefix: "path-foo"}, + }}}, + }}, + }, + }}}, + }, + }}}, + }, + }, + "authz_allow_policy_2": { + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_AndRules{AndRules: &v3rbacpb.Permission_Set{ + Rules: []*v3rbacpb.Permission{ {Rule: &v3rbacpb.Permission_OrRules{OrRules: &v3rbacpb.Permission_Set{ Rules: []*v3rbacpb.Permission{ - {Rule: &v3rbacpb.Permission_Header{ - Header: &v3routepb.HeaderMatcher{ - Name: "key-2", HeaderMatchSpecifier: &v3routepb.HeaderMatcher_PrefixMatch{PrefixMatch: "baz"}}}}, - }}}}}}}}}}}}}, + {Rule: &v3rbacpb.Permission_UrlPath{ + UrlPath: &v3matcherpb.PathMatcher{Rule: &v3matcherpb.PathMatcher_Path{Path: &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: "path-bar"}, + }}}, + }}, + {Rule: &v3rbacpb.Permission_UrlPath{ + UrlPath: &v3matcherpb.PathMatcher{Rule: &v3matcherpb.PathMatcher_Path{Path: &v3matcherpb.StringMatcher{ + MatchPattern: &v3matcherpb.StringMatcher_Suffix{Suffix: "baz"}, + }}}, + }}, + }, + }}}, + {Rule: &v3rbacpb.Permission_AndRules{AndRules: &v3rbacpb.Permission_Set{ + Rules: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_OrRules{OrRules: &v3rbacpb.Permission_Set{ + Rules: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Header{ + Header: &v3routepb.HeaderMatcher{ + Name: "key-1", HeaderMatchSpecifier: &v3routepb.HeaderMatcher_ExactMatch{ExactMatch: "foo"}, + }, + }}, + {Rule: &v3rbacpb.Permission_Header{ + Header: &v3routepb.HeaderMatcher{ + Name: "key-1", HeaderMatchSpecifier: &v3routepb.HeaderMatcher_SuffixMatch{SuffixMatch: "bar"}, + }, + }}, + }, + }}}, + {Rule: &v3rbacpb.Permission_OrRules{OrRules: &v3rbacpb.Permission_Set{ + Rules: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Header{ + Header: &v3routepb.HeaderMatcher{ + Name: "key-2", HeaderMatchSpecifier: &v3routepb.HeaderMatcher_PrefixMatch{PrefixMatch: "baz"}, + }, + }}, + }, + }}}, + }, + }}}, + }, + }}}, + }, + }, + }, }, - }}, + }, + }, + "unknown field": { + authzPolicy: `{"random": 123}`, + wantErr: "failed to unmarshal policy", }, "missing name field": { authzPolicy: `{}`, @@ -167,10 +218,8 @@ func TestTranslatePolicy(t *testing.T) { wantErr: "failed to unmarshal policy", }, "missing allow rules field": { - authzPolicy: `{"name": "authz-foo"}`, - wantErr: `"allow_rules" is not present`, - wantDenyPolicy: nil, - wantAllowPolicy: nil, + authzPolicy: `{"name": "authz-foo"}`, + wantErr: `"allow_rules" is not present`, }, "missing rule name field": { authzPolicy: `{ @@ -210,18 +259,14 @@ func TestTranslatePolicy(t *testing.T) { wantErr: `"allow_rules" 0: "headers" 0: unsupported "key" :method`, }, } - for name, test := range tests { t.Run(name, func(t *testing.T) { - gotDenyPolicy, gotAllowPolicy, gotErr := translatePolicy(test.authzPolicy) + gotPolicies, gotErr := translatePolicy(test.authzPolicy) if gotErr != nil && !strings.HasPrefix(gotErr.Error(), test.wantErr) { t.Fatalf("unexpected error\nwant:%v\ngot:%v", test.wantErr, gotErr) } - if diff := cmp.Diff(gotDenyPolicy, test.wantDenyPolicy, protocmp.Transform()); diff != "" { - t.Fatalf("unexpected deny policy\ndiff (-want +got):\n%s", diff) - } - if diff := cmp.Diff(gotAllowPolicy, test.wantAllowPolicy, protocmp.Transform()); diff != "" { - t.Fatalf("unexpected allow policy\ndiff (-want +got):\n%s", diff) + if diff := cmp.Diff(gotPolicies, test.wantPolicies, protocmp.Transform()); diff != "" { + t.Fatalf("unexpected policy\ndiff (-want +got):\n%s", diff) } }) } diff --git a/authz/sdk_end2end_test.go b/authz/sdk_end2end_test.go new file mode 100644 index 00000000000..373c32cc5cc --- /dev/null +++ b/authz/sdk_end2end_test.go @@ -0,0 +1,178 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package authz + +import ( + "context" + "net" + "testing" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + pb "google.golang.org/grpc/examples/features/proto/echo" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" +) + +const ( + message = "Hi" +) + +type server struct { + pb.UnimplementedEchoServer +} + +func (s *server) UnaryEcho(ctx context.Context, req *pb.EchoRequest) (*pb.EchoResponse, error) { + return &pb.EchoResponse{Message: message}, nil +} + +func startServer(t *testing.T, policy string) string { + i, _ := NewStatic(policy) + serverOpts := []grpc.ServerOption{ + grpc.ChainUnaryInterceptor(i.UnaryInterceptor), + grpc.ChainStreamInterceptor(i.StreamInterceptor), + } + lis, err := net.Listen("tcp", ":0") + if err != nil { + t.Fatalf("error listening: %v", err) + } + s := grpc.NewServer(serverOpts...) + pb.RegisterEchoServer(s, &server{}) + go func() { + if err := s.Serve(lis); err != nil { + t.Fatalf("failed to serve %v", err) + } + }() + return lis.Addr().String() +} + +func runClient(ctx context.Context, t *testing.T, serverAddr string) (*pb.EchoResponse, error) { + dialOptions := []grpc.DialOption{ + grpc.WithInsecure(), + grpc.WithBlock(), + } + clientConn, err := grpc.Dial(serverAddr, dialOptions...) + if err != nil { + t.Fatalf("grpc.Dial(%v, %v) failed: %v", serverAddr, dialOptions, err) + } + defer clientConn.Close() + c := pb.NewEchoClient(clientConn) + return c.UnaryEcho(ctx, &pb.EchoRequest{Message: message}, grpc.WaitForReady(true)) +} + +func TestSdkEnd2End(t *testing.T) { + tests := map[string]struct { + authzPolicy string + md metadata.MD + wantStatusCode codes.Code + wantResp string + }{ + "DeniesUnauthorizedRpcRequest": { + authzPolicy: `{ + "name": "authz", + "allow_rules": + [ + { + "name": "allow_all" + } + ], + "deny_rules": + [ + { + "name": "deny_Echo", + "request": { + "paths": + [ + "/grpc.examples.echo.Echo/UnaryEcho" + ], + "headers": + [ + { + "key": "key-abc", + "values": + [ + "val-abc", + "val-def" + ] + } + ] + } + } + ] + }`, + md: metadata.Pairs("key-abc", "val-abc"), + wantStatusCode: codes.PermissionDenied, + }, + "AllowsAuthorizedRpcRequest": { + authzPolicy: `{ + "name": "authz", + "allow_rules": + [ + { + "name": "allow_Echo", + "request": + { + "paths": + [ + "/grpc.examples.echo.Echo/UnaryEcho" + ] + } + } + ], + "deny_rules": + [ + { + "name": "deny_all", + "request": + { + "headers": + [ + { + "key": "key-abc", + "values": + [ + "val-abc", + "val-def" + ] + } + ] + } + } + ] + }`, + md: metadata.Pairs("key-xyz", "val-xyz"), + wantStatusCode: codes.OK, + wantResp: message, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + serverAddr := startServer(t, test.authzPolicy) + ctx := metadata.NewOutgoingContext(context.Background(), test.md) + + resp, err := runClient(ctx, t, serverAddr) + if gotStatusCode := status.Code(err); gotStatusCode != test.wantStatusCode { + t.Fatalf("unexpected authorization decision. status code want:%v got:%v", test.wantStatusCode, gotStatusCode) + } + if resp.GetMessage() != test.wantResp { + t.Fatalf("unexpected response message want:%v got:%v", test.wantResp, resp.GetMessage()) + } + }) + } +} diff --git a/authz/sdk_server_interceptors.go b/authz/sdk_server_interceptors.go new file mode 100644 index 00000000000..c1a232990b0 --- /dev/null +++ b/authz/sdk_server_interceptors.go @@ -0,0 +1,66 @@ +/* + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package authz + +import ( + "context" + "fmt" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/internal/xds/rbac" + "google.golang.org/grpc/status" +) + +type StaticInterceptor struct { + // Either contains two engines deny engine followed by allow engine or only one allow engine. + engines rbac.ChainEngine +} + +// NewStatic returns a new StaticInterceptor from a static authorization policy JSON string. +func NewStatic(authzPolicy string) (*StaticInterceptor, error) { + RBACPolicies, err := translatePolicy(authzPolicy) + if err != nil { + return nil, err + } + chainEngine, err := rbac.NewChainEngine(RBACPolicies) + if err != nil { + return nil, err + } + if chainEngine.IsEmpty() { + return nil, fmt.Errorf("failed to initialize RBAC engines") + } + return &StaticInterceptor{*chainEngine}, nil +} + +// UnaryInterceptor intercepts incoming Unary RPC request. +// Only authorized requests are allowed to pass. Otherwise, unauthorized error is returned to client. +func (i *StaticInterceptor) UnaryInterceptor(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { + if err := i.engines.IsAuthorized(ctx); status.Code(err) != codes.OK { + return nil, err + } + return handler(ctx, req) +} + +// StreamInterceptor intercepts incoming Stream RPC request. +// Only authorized requests are allowed to pass. Otherwise, unauthorized error is returned to client. +func (i *StaticInterceptor) StreamInterceptor(srv interface{}, ss grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) (err error) { + if err := i.engines.IsAuthorized(ss.Context()); status.Code(err) != codes.OK { + return err + } + return handler(srv, ss) +} diff --git a/authz/sdk_server_interceptors_test.go b/authz/sdk_server_interceptors_test.go new file mode 100644 index 00000000000..9ad8a19cb4b --- /dev/null +++ b/authz/sdk_server_interceptors_test.go @@ -0,0 +1,339 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package authz + +import ( + "context" + "net" + "testing" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/internal/xds/rbac" + "google.golang.org/grpc/metadata" + p "google.golang.org/grpc/peer" + "google.golang.org/grpc/status" +) + +type addr struct { + ipAddress string +} + +func (addr) Network() string { return "" } +func (a *addr) String() string { return a.ipAddress } + +type ServerTransportStreamWithMethod struct { + method string +} + +func (sts *ServerTransportStreamWithMethod) Method() string { + return sts.method +} +func (sts *ServerTransportStreamWithMethod) SetHeader(md metadata.MD) error { + return nil +} +func (sts *ServerTransportStreamWithMethod) SendHeader(md metadata.MD) error { + return nil +} +func (sts *ServerTransportStreamWithMethod) SetTrailer(md metadata.MD) error { + return nil +} + +type fakeStream struct { + ctx context.Context +} + +func (f *fakeStream) SetHeader(metadata.MD) error { return nil } +func (f *fakeStream) SendHeader(metadata.MD) error { return nil } +func (f *fakeStream) SetTrailer(metadata.MD) {} +func (f *fakeStream) Context() context.Context { return f.ctx } +func (f *fakeStream) SendMsg(m interface{}) error { return nil } +func (f *fakeStream) RecvMsg(m interface{}) error { return nil } + +func TestNewStatic(t *testing.T) { + tests := map[string]struct { + authzPolicy string + wantErr bool + }{ + "InvalidPolicyFailsToCreateInterceptor": { + authzPolicy: `{}`, + wantErr: true, + }, + "ValidPolicyCreatesInterceptor": { + authzPolicy: `{ + "name": "authz", + "allow_rules": + [ + { + "name": "allow_all" + } + ] + }`, + wantErr: false, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + if _, err := NewStatic(test.authzPolicy); (err != nil) != test.wantErr { + t.Fatalf("NewStatic(%v) returned err: %v, want err: %v", test.authzPolicy, err, test.wantErr) + } + }) + } +} + +func TestStaticInterceptors(t *testing.T) { + tests := map[string]struct { + authzPolicy string + md metadata.MD + fullMethod string + wantStatusCode codes.Code + }{ + "DeniesRpcRequestMatchInDenyNoMatchInAllow": { + authzPolicy: `{ + "name": "authz", + "allow_rules": [ + { + "name": "allow_bar", + "request": { + "paths": [ + "*/bar" + ] + } + } + ], + "deny_rules": [ + { + "name": "deny_foo", + "request": { + "paths": [ + "*/foo" + ], + "headers": [ + { + "key": "key-abc", + "values": [ + "val-abc", + "val-def" + ] + } + ] + } + } + ] + }`, + md: metadata.Pairs("key-abc", "val-abc"), + fullMethod: "/package.service/foo", + wantStatusCode: codes.PermissionDenied, + }, + "DeniesRpcRequestMatchInDenyAndAllow": { + authzPolicy: `{ + "name": "authz", + "allow_rules": [ + { + "name": "allow_foo", + "request": { + "paths": [ + "*/foo" + ] + } + } + ], + "deny_rules": [ + { + "name": "deny_foo", + "request": { + "paths": [ + "*/foo" + ] + } + } + ] + }`, + fullMethod: "/package.service/foo", + wantStatusCode: codes.PermissionDenied, + }, + "AllowsRpcRequestNoMatchInDenyMatchInAllow": { + authzPolicy: `{ + "name": "authz", + "allow_rules": [ + { + "name": "allow_foo", + "request": { + "paths": [ + "*/foo" + ] + } + } + ], + "deny_rules": [ + { + "name": "deny_foo", + "request": { + "paths": [ + "*/foo" + ], + "headers": [ + { + "key": "key-abc", + "values": [ + "val-abc", + "val-def" + ] + } + ] + } + } + ] + }`, + md: metadata.Pairs("key-xyz", "val-xyz"), + fullMethod: "/package.service/foo", + wantStatusCode: codes.OK, + }, + "DeniesRpcRequestNoMatchInDenyAndAllow": { + authzPolicy: `{ + "name": "authz", + "allow_rules": [ + { + "name": "allow_baz_user", + "source": { + "principals": [ + "baz" + ] + } + } + ], + "deny_rules": [ + { + "name": "deny_bar", + "request": { + "paths": [ + "*/bar" + ] + } + } + ] + }`, + fullMethod: "/package.service/foo", + wantStatusCode: codes.PermissionDenied, + }, + "AllowsRpcRequestEmptyDenyMatchInAllow": { + authzPolicy: `{ + "name": "authz", + "allow_rules": [ + { + "name": "allow_foo", + "request": { + "paths": [ + "*/foo" + ], + "headers": [ + { + "key": "key-abc", + "values": [ + "val-abc", + "val-def" + ] + } + ] + } + } + ] + }`, + md: metadata.Pairs("key-abc", "val-abc", "key-xyz", "val-xyz"), + fullMethod: "/package.service/foo", + wantStatusCode: codes.OK, + }, + "DeniesRpcRequestEmptyDenyNoMatchInAllow": { + authzPolicy: `{ + "name": "authz", + "allow_rules": [ + { + "name": "allow_bar", + "request": { + "paths": [ + "*/bar" + ] + } + } + ] + }`, + fullMethod: "/package.service/foo", + wantStatusCode: codes.PermissionDenied, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + i, err := NewStatic(test.authzPolicy) + if err != nil { + t.Fatalf("NewStatic(%v) failed to create interceptor. err: %v", test.authzPolicy, err) + } + + ctx := metadata.NewIncomingContext(context.Background(), test.md) + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("Error listening: %v", err) + } + defer lis.Close() + connCh := make(chan net.Conn, 1) + go func() { + conn, err := lis.Accept() + if err != nil { + t.Errorf("Error accepting connection: %v", err) + return + } + connCh <- conn + }() + _, err = net.Dial("tcp", lis.Addr().String()) + if err != nil { + t.Fatalf("Error dialing: %v", err) + } + conn := <-connCh + defer conn.Close() + rbac.GetConnection = func(context.Context) net.Conn { + return conn + } + ctx = p.NewContext(ctx, &p.Peer{Addr: &addr{}}) + stream := &ServerTransportStreamWithMethod{ + method: test.fullMethod, + } + ctx = grpc.NewContextWithServerTransportStream(ctx, stream) + + // Testing UnaryInterceptor + unaryHandler := func(_ context.Context, _ interface{}) (interface{}, error) { + return message, nil + } + resp, err := i.UnaryInterceptor(ctx, nil, &grpc.UnaryServerInfo{}, unaryHandler) + if gotStatusCode := status.Code(err); gotStatusCode != test.wantStatusCode { + t.Fatalf("UnaryInterceptor returned unexpected error code want:%v got:%v", test.wantStatusCode, gotStatusCode) + } + if resp != nil && resp != message { + t.Fatalf("UnaryInterceptor returned unexpected response want:%v got:%v", message, resp) + } + + // Testing StreamInterceptor + streamHandler := func(_ interface{}, _ grpc.ServerStream) error { + return nil + } + err = i.StreamInterceptor(nil, &fakeStream{ctx: ctx}, &grpc.StreamServerInfo{}, streamHandler) + if gotStatusCode := status.Code(err); gotStatusCode != test.wantStatusCode { + t.Fatalf("StreamInterceptor returned unexpected error code want:%v got:%v", test.wantStatusCode, gotStatusCode) + } + }) + } +} diff --git a/go.mod b/go.mod index 2f2cf1eb766..ab463bd0aeb 100644 --- a/go.mod +++ b/go.mod @@ -13,6 +13,7 @@ require ( golang.org/x/net v0.0.0-20200822124328-c89045814202 golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd - google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013 + google.golang.org/genproto v0.0.0-20200806141610-86f49bd18e98 + google.golang.org/grpc/examples v0.0.0-20210812181202-a42567fe92f0 google.golang.org/protobuf v1.25.0 ) diff --git a/go.sum b/go.sum index 372b4ea3d20..b24db2f0aee 100644 --- a/go.sum +++ b/go.sum @@ -101,14 +101,17 @@ google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7 google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= google.golang.org/genproto v0.0.0-20200513103714-09dca8ec2884/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= -google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013 h1:+kGHl1aib/qcwaRi1CbqBZ1rk19r85MNUf8HaBghugY= google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= +google.golang.org/genproto v0.0.0-20200806141610-86f49bd18e98 h1:LCO0fg4kb6WwkXQXRQQgUYsFeFb5taTX5WAx5O/Vt28= +google.golang.org/genproto v0.0.0-20200806141610-86f49bd18e98/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= google.golang.org/grpc v1.33.1/go.mod h1:fr5YgcSWrqhRRxogOsw7RzIpsmvOZ6IcH4kBYTpR3n0= google.golang.org/grpc v1.36.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU= +google.golang.org/grpc/examples v0.0.0-20210812181202-a42567fe92f0 h1:c8JeMIKJoZiKyyHwcCJK2+++7GCk3D1smA3Fi8c0i6w= +google.golang.org/grpc/examples v0.0.0-20210812181202-a42567fe92f0/go.mod h1:bF8wuZSAZTcbF7ZPKrDI/qY52toTP/yxLpRRY4Eu9Js= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= @@ -117,6 +120,7 @@ google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzi google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.24.0/go.mod h1:r/3tXBNzIEhYS9I1OUVjXDlt8tc493IdKGjtUeSXeh4= google.golang.org/protobuf v1.25.0 h1:Ejskq+SyPohKW+1uil0JJMtmHCgJPJ/qWTxr8qp+R4c= google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= diff --git a/internal/xds/rbac/rbac_engine.go b/internal/xds/rbac/rbac_engine.go index f08228f7791..49c8e5b46f5 100644 --- a/internal/xds/rbac/rbac_engine.go +++ b/internal/xds/rbac/rbac_engine.go @@ -38,7 +38,7 @@ import ( "google.golang.org/grpc/status" ) -var getConnection = transport.GetConnection +var GetConnection = transport.GetConnection // ChainEngine represents a chain of RBAC Engines, used to make authorization // decisions on incoming RPCs. @@ -89,6 +89,10 @@ func (cre *ChainEngine) IsAuthorized(ctx context.Context) error { return status.Error(codes.OK, "") } +func (cre *ChainEngine) IsEmpty() bool { + return len(cre.chainedEngines) == 0 +} + // engine is used for matching incoming RPCs to policies. type engine struct { policies map[string]*policyMatcher @@ -160,7 +164,7 @@ func newRPCData(ctx context.Context) (*rpcData, error) { // The connection is needed in order to find the destination address and // port of the incoming RPC Call. - conn := getConnection(ctx) + conn := GetConnection(ctx) if conn == nil { return nil, errors.New("missing connection in incoming context") } diff --git a/internal/xds/rbac/rbac_engine_test.go b/internal/xds/rbac/rbac_engine_test.go index 807df9b87a8..91c5bc15c8f 100644 --- a/internal/xds/rbac/rbac_engine_test.go +++ b/internal/xds/rbac/rbac_engine_test.go @@ -441,8 +441,8 @@ func (s) TestNewChainEngine(t *testing.T) { // and verifies that it works as expected. func (s) TestChainEngine(t *testing.T) { defer func(gc func(ctx context.Context) net.Conn) { - getConnection = gc - }(getConnection) + GetConnection = gc + }(GetConnection) tests := []struct { name string rbacConfigs []*v3rbacpb.RBAC @@ -885,7 +885,7 @@ func (s) TestChainEngine(t *testing.T) { } conn := <-connCh defer conn.Close() - getConnection = func(context.Context) net.Conn { + GetConnection = func(context.Context) net.Conn { return conn } ctx = peer.NewContext(ctx, data.rpcData.peerInfo) From a5ee344b39ec82986267a465dac498b9d2f102e5 Mon Sep 17 00:00:00 2001 From: Ashitha Santhosh Date: Tue, 17 Aug 2021 16:27:36 -0700 Subject: [PATCH 02/14] Use TestService in end2end tests. --- authz/sdk_end2end_test.go | 31 +++++++++++---------------- authz/sdk_server_interceptors_test.go | 4 ++++ go.mod | 3 +-- go.sum | 6 +----- 4 files changed, 18 insertions(+), 26 deletions(-) diff --git a/authz/sdk_end2end_test.go b/authz/sdk_end2end_test.go index 373c32cc5cc..82575387f26 100644 --- a/authz/sdk_end2end_test.go +++ b/authz/sdk_end2end_test.go @@ -25,21 +25,17 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/codes" - pb "google.golang.org/grpc/examples/features/proto/echo" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" + pb "google.golang.org/grpc/test/grpc_testing" ) -const ( - message = "Hi" -) - -type server struct { - pb.UnimplementedEchoServer +type testServer struct { + pb.UnimplementedTestServiceServer } -func (s *server) UnaryEcho(ctx context.Context, req *pb.EchoRequest) (*pb.EchoResponse, error) { - return &pb.EchoResponse{Message: message}, nil +func (s *testServer) UnaryCall(ctx context.Context, req *pb.SimpleRequest) (*pb.SimpleResponse, error) { + return &pb.SimpleResponse{}, nil } func startServer(t *testing.T, policy string) string { @@ -53,7 +49,7 @@ func startServer(t *testing.T, policy string) string { t.Fatalf("error listening: %v", err) } s := grpc.NewServer(serverOpts...) - pb.RegisterEchoServer(s, &server{}) + pb.RegisterTestServiceServer(s, &testServer{}) go func() { if err := s.Serve(lis); err != nil { t.Fatalf("failed to serve %v", err) @@ -62,7 +58,7 @@ func startServer(t *testing.T, policy string) string { return lis.Addr().String() } -func runClient(ctx context.Context, t *testing.T, serverAddr string) (*pb.EchoResponse, error) { +func runClient(ctx context.Context, t *testing.T, serverAddr string) (*pb.SimpleResponse, error) { dialOptions := []grpc.DialOption{ grpc.WithInsecure(), grpc.WithBlock(), @@ -72,8 +68,8 @@ func runClient(ctx context.Context, t *testing.T, serverAddr string) (*pb.EchoRe t.Fatalf("grpc.Dial(%v, %v) failed: %v", serverAddr, dialOptions, err) } defer clientConn.Close() - c := pb.NewEchoClient(clientConn) - return c.UnaryEcho(ctx, &pb.EchoRequest{Message: message}, grpc.WaitForReady(true)) + c := pb.NewTestServiceClient(clientConn) + return c.UnaryCall(ctx, &pb.SimpleRequest{}, grpc.WaitForReady(true)) } func TestSdkEnd2End(t *testing.T) { @@ -99,7 +95,7 @@ func TestSdkEnd2End(t *testing.T) { "request": { "paths": [ - "/grpc.examples.echo.Echo/UnaryEcho" + "/grpc.testing.TestService/UnaryCall" ], "headers": [ @@ -130,7 +126,7 @@ func TestSdkEnd2End(t *testing.T) { { "paths": [ - "/grpc.examples.echo.Echo/UnaryEcho" + "/grpc.testing.TestService/UnaryCall" ] } } @@ -166,13 +162,10 @@ func TestSdkEnd2End(t *testing.T) { serverAddr := startServer(t, test.authzPolicy) ctx := metadata.NewOutgoingContext(context.Background(), test.md) - resp, err := runClient(ctx, t, serverAddr) + _, err := runClient(ctx, t, serverAddr) if gotStatusCode := status.Code(err); gotStatusCode != test.wantStatusCode { t.Fatalf("unexpected authorization decision. status code want:%v got:%v", test.wantStatusCode, gotStatusCode) } - if resp.GetMessage() != test.wantResp { - t.Fatalf("unexpected response message want:%v got:%v", test.wantResp, resp.GetMessage()) - } }) } } diff --git a/authz/sdk_server_interceptors_test.go b/authz/sdk_server_interceptors_test.go index 9ad8a19cb4b..2e11792cf66 100644 --- a/authz/sdk_server_interceptors_test.go +++ b/authz/sdk_server_interceptors_test.go @@ -31,6 +31,10 @@ import ( "google.golang.org/grpc/status" ) +const ( + message = "Hi" +) + type addr struct { ipAddress string } diff --git a/go.mod b/go.mod index ab463bd0aeb..2f2cf1eb766 100644 --- a/go.mod +++ b/go.mod @@ -13,7 +13,6 @@ require ( golang.org/x/net v0.0.0-20200822124328-c89045814202 golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd - google.golang.org/genproto v0.0.0-20200806141610-86f49bd18e98 - google.golang.org/grpc/examples v0.0.0-20210812181202-a42567fe92f0 + google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013 google.golang.org/protobuf v1.25.0 ) diff --git a/go.sum b/go.sum index b24db2f0aee..372b4ea3d20 100644 --- a/go.sum +++ b/go.sum @@ -101,17 +101,14 @@ google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7 google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= google.golang.org/genproto v0.0.0-20200513103714-09dca8ec2884/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= +google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013 h1:+kGHl1aib/qcwaRi1CbqBZ1rk19r85MNUf8HaBghugY= google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= -google.golang.org/genproto v0.0.0-20200806141610-86f49bd18e98 h1:LCO0fg4kb6WwkXQXRQQgUYsFeFb5taTX5WAx5O/Vt28= -google.golang.org/genproto v0.0.0-20200806141610-86f49bd18e98/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= google.golang.org/grpc v1.33.1/go.mod h1:fr5YgcSWrqhRRxogOsw7RzIpsmvOZ6IcH4kBYTpR3n0= google.golang.org/grpc v1.36.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU= -google.golang.org/grpc/examples v0.0.0-20210812181202-a42567fe92f0 h1:c8JeMIKJoZiKyyHwcCJK2+++7GCk3D1smA3Fi8c0i6w= -google.golang.org/grpc/examples v0.0.0-20210812181202-a42567fe92f0/go.mod h1:bF8wuZSAZTcbF7ZPKrDI/qY52toTP/yxLpRRY4Eu9Js= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= @@ -120,7 +117,6 @@ google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzi google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= -google.golang.org/protobuf v1.24.0/go.mod h1:r/3tXBNzIEhYS9I1OUVjXDlt8tc493IdKGjtUeSXeh4= google.golang.org/protobuf v1.25.0 h1:Ejskq+SyPohKW+1uil0JJMtmHCgJPJ/qWTxr8qp+R4c= google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= From 6fd1ead3ac5ed7910a650ba919de785fe61db2aa Mon Sep 17 00:00:00 2001 From: Ashitha Santhosh Date: Tue, 17 Aug 2021 21:50:46 -0700 Subject: [PATCH 03/14] test fixes --- authz/sdk_end2end_test.go | 9 +-------- authz/sdk_server_interceptors.go | 13 +++++++++---- internal/xds/rbac/rbac_engine.go | 10 ++++++---- 3 files changed, 16 insertions(+), 16 deletions(-) diff --git a/authz/sdk_end2end_test.go b/authz/sdk_end2end_test.go index 82575387f26..f2b740655b2 100644 --- a/authz/sdk_end2end_test.go +++ b/authz/sdk_end2end_test.go @@ -50,11 +50,7 @@ func startServer(t *testing.T, policy string) string { } s := grpc.NewServer(serverOpts...) pb.RegisterTestServiceServer(s, &testServer{}) - go func() { - if err := s.Serve(lis); err != nil { - t.Fatalf("failed to serve %v", err) - } - }() + go s.Serve(lis) return lis.Addr().String() } @@ -77,7 +73,6 @@ func TestSdkEnd2End(t *testing.T) { authzPolicy string md metadata.MD wantStatusCode codes.Code - wantResp string }{ "DeniesUnauthorizedRpcRequest": { authzPolicy: `{ @@ -154,14 +149,12 @@ func TestSdkEnd2End(t *testing.T) { }`, md: metadata.Pairs("key-xyz", "val-xyz"), wantStatusCode: codes.OK, - wantResp: message, }, } for name, test := range tests { t.Run(name, func(t *testing.T) { serverAddr := startServer(t, test.authzPolicy) ctx := metadata.NewOutgoingContext(context.Background(), test.md) - _, err := runClient(ctx, t, serverAddr) if gotStatusCode := status.Code(err); gotStatusCode != test.wantStatusCode { t.Fatalf("unexpected authorization decision. status code want:%v got:%v", test.wantStatusCode, gotStatusCode) diff --git a/authz/sdk_server_interceptors.go b/authz/sdk_server_interceptors.go index c1a232990b0..8eef3557d90 100644 --- a/authz/sdk_server_interceptors.go +++ b/authz/sdk_server_interceptors.go @@ -26,12 +26,15 @@ import ( "google.golang.org/grpc/status" ) +// StaticInterceptor contains engines used to make authorization decisions. It +// either contains two engines deny engine followed by an allow engine or only +// one allow engine. type StaticInterceptor struct { - // Either contains two engines deny engine followed by allow engine or only one allow engine. engines rbac.ChainEngine } -// NewStatic returns a new StaticInterceptor from a static authorization policy JSON string. +// NewStatic returns a new StaticInterceptor from a static authorization policy +// JSON string. func NewStatic(authzPolicy string) (*StaticInterceptor, error) { RBACPolicies, err := translatePolicy(authzPolicy) if err != nil { @@ -48,7 +51,8 @@ func NewStatic(authzPolicy string) (*StaticInterceptor, error) { } // UnaryInterceptor intercepts incoming Unary RPC request. -// Only authorized requests are allowed to pass. Otherwise, unauthorized error is returned to client. +// Only authorized requests are allowed to pass. Otherwise, unauthorized error +// is returned to client. func (i *StaticInterceptor) UnaryInterceptor(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { if err := i.engines.IsAuthorized(ctx); status.Code(err) != codes.OK { return nil, err @@ -57,7 +61,8 @@ func (i *StaticInterceptor) UnaryInterceptor(ctx context.Context, req interface{ } // StreamInterceptor intercepts incoming Stream RPC request. -// Only authorized requests are allowed to pass. Otherwise, unauthorized error is returned to client. +// Only authorized requests are allowed to pass. Otherwise, unauthorized error +// is returned to client. func (i *StaticInterceptor) StreamInterceptor(srv interface{}, ss grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) (err error) { if err := i.engines.IsAuthorized(ss.Context()); status.Code(err) != codes.OK { return err diff --git a/internal/xds/rbac/rbac_engine.go b/internal/xds/rbac/rbac_engine.go index 49c8e5b46f5..48a760f5ed0 100644 --- a/internal/xds/rbac/rbac_engine.go +++ b/internal/xds/rbac/rbac_engine.go @@ -38,6 +38,7 @@ import ( "google.golang.org/grpc/status" ) +// GetConnection allows tests to inject fake connection. var GetConnection = transport.GetConnection // ChainEngine represents a chain of RBAC Engines, used to make authorization @@ -60,6 +61,11 @@ func NewChainEngine(policies []*v3rbacpb.RBAC) (*ChainEngine, error) { return &ChainEngine{chainedEngines: engines}, nil } +// IsEmpty returns true if ChainEngine contains zero engine. +func (cre *ChainEngine) IsEmpty() bool { + return len(cre.chainedEngines) == 0 +} + // IsAuthorized determines if an incoming RPC is authorized based on the chain of RBAC // engines and their associated actions. // @@ -89,10 +95,6 @@ func (cre *ChainEngine) IsAuthorized(ctx context.Context) error { return status.Error(codes.OK, "") } -func (cre *ChainEngine) IsEmpty() bool { - return len(cre.chainedEngines) == 0 -} - // engine is used for matching incoming RPCs to policies. type engine struct { policies map[string]*policyMatcher From 8feabc115e06ef4d21b85c621033f7133f5d15ce Mon Sep 17 00:00:00 2001 From: Ashitha Santhosh Date: Tue, 24 Aug 2021 12:03:27 -0700 Subject: [PATCH 04/14] Resolving comments. --- authz/rbac_translator.go | 8 +- authz/sdk_end2end_test.go | 263 ++++++++++++++++++----- authz/sdk_server_interceptors.go | 36 ++-- authz/sdk_server_interceptors_test.go | 293 +------------------------- internal/xds/rbac/rbac_engine.go | 10 +- internal/xds/rbac/rbac_engine_test.go | 6 +- 6 files changed, 242 insertions(+), 374 deletions(-) diff --git a/authz/rbac_translator.go b/authz/rbac_translator.go index fa0a5001df3..821a28d0262 100644 --- a/authz/rbac_translator.go +++ b/authz/rbac_translator.go @@ -285,7 +285,7 @@ func translatePolicy(policyStr string) ([]*v3rbacpb.RBAC, error) { if len(policy.AllowRules) == 0 { return nil, fmt.Errorf(`"allow_rules" is not present`) } - var RBACPolicies []*v3rbacpb.RBAC + var rbacs []*v3rbacpb.RBAC if len(policy.DenyRules) > 0 { denyPolicies, err := parseRules(policy.DenyRules, policy.Name) if err != nil { @@ -295,13 +295,13 @@ func translatePolicy(policyStr string) ([]*v3rbacpb.RBAC, error) { Action: v3rbacpb.RBAC_DENY, Policies: denyPolicies, } - RBACPolicies = append(RBACPolicies, denyRBAC) + rbacs = append(rbacs, denyRBAC) } allowPolicies, err := parseRules(policy.AllowRules, policy.Name) if err != nil { return nil, fmt.Errorf(`"allow_rules" %v`, err) } allowRBAC := &v3rbacpb.RBAC{Action: v3rbacpb.RBAC_ALLOW, Policies: allowPolicies} - RBACPolicies = append(RBACPolicies, allowRBAC) - return RBACPolicies, nil + rbacs = append(rbacs, allowRBAC) + return rbacs, nil } diff --git a/authz/sdk_end2end_test.go b/authz/sdk_end2end_test.go index f2b740655b2..ad312ccc0a6 100644 --- a/authz/sdk_end2end_test.go +++ b/authz/sdk_end2end_test.go @@ -16,14 +16,22 @@ * */ -package authz +// Package authz_test contains tests for authz. +// +// Experimental +// +// Notice: This package is EXPERIMENTAL and may be changed or removed +// in a later release. +package authz_test import ( "context" + "io" "net" "testing" "google.golang.org/grpc" + "google.golang.org/grpc/authz" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" @@ -38,65 +46,56 @@ func (s *testServer) UnaryCall(ctx context.Context, req *pb.SimpleRequest) (*pb. return &pb.SimpleResponse{}, nil } -func startServer(t *testing.T, policy string) string { - i, _ := NewStatic(policy) - serverOpts := []grpc.ServerOption{ - grpc.ChainUnaryInterceptor(i.UnaryInterceptor), - grpc.ChainStreamInterceptor(i.StreamInterceptor), +func (s *testServer) StreamingInputCall(stream pb.TestService_StreamingInputCallServer) error { + for { + _, err := stream.Recv() + if err == io.EOF { + return stream.SendAndClose(&pb.StreamingInputCallResponse{}) + } + if err != nil { + return err + } } - lis, err := net.Listen("tcp", ":0") - if err != nil { - t.Fatalf("error listening: %v", err) - } - s := grpc.NewServer(serverOpts...) - pb.RegisterTestServiceServer(s, &testServer{}) - go s.Serve(lis) - return lis.Addr().String() } -func runClient(ctx context.Context, t *testing.T, serverAddr string) (*pb.SimpleResponse, error) { - dialOptions := []grpc.DialOption{ - grpc.WithInsecure(), - grpc.WithBlock(), - } - clientConn, err := grpc.Dial(serverAddr, dialOptions...) - if err != nil { - t.Fatalf("grpc.Dial(%v, %v) failed: %v", serverAddr, dialOptions, err) - } - defer clientConn.Close() - c := pb.NewTestServiceClient(clientConn) - return c.UnaryCall(ctx, &pb.SimpleRequest{}, grpc.WaitForReady(true)) -} - -func TestSdkEnd2End(t *testing.T) { +func TestSDKEnd2End(t *testing.T) { tests := map[string]struct { authzPolicy string md metadata.MD wantStatusCode codes.Code + wantErr string }{ - "DeniesUnauthorizedRpcRequest": { + "DeniesRpcRequestMatchInDenyNoMatchInAllow": { authzPolicy: `{ "name": "authz", "allow_rules": [ { - "name": "allow_all" + "name": "allow_TestServiceCalls", + "request": { + "paths": + [ + "/grpc.testing.TestService/UnaryCall", + "/grpc.testing.TestService/StreamingOutputCall" + ] + } } ], - "deny_rules": + "deny_rules": [ { - "name": "deny_Echo", + "name": "deny_TestServiceCalls", "request": { - "paths": + "paths": [ - "/grpc.testing.TestService/UnaryCall" + "/grpc.testing.TestService/UnaryCall", + "/grpc.testing.TestService/StreamingInputCall" ], - "headers": + "headers": [ { "key": "key-abc", - "values": + "values": [ "val-abc", "val-def" @@ -109,29 +108,58 @@ func TestSdkEnd2End(t *testing.T) { }`, md: metadata.Pairs("key-abc", "val-abc"), wantStatusCode: codes.PermissionDenied, + wantErr: "Unauthorized RPC request rejected.", }, - "AllowsAuthorizedRpcRequest": { + "DeniesRpcRequestMatchInDenyAndAllow": { authzPolicy: `{ "name": "authz", - "allow_rules": + "allow_rules": [ { - "name": "allow_Echo", - "request": - { - "paths": + "name": "allow_TestServiceCalls", + "request": { + "paths": [ - "/grpc.testing.TestService/UnaryCall" + "/grpc.testing.TestService/*" ] } } ], - "deny_rules": + "deny_rules": [ { - "name": "deny_all", - "request": - { + "name": "deny_TestServiceCalls", + "request": { + "paths": + [ + "/grpc.testing.TestService/*" + ] + } + } + ] + }`, + wantStatusCode: codes.PermissionDenied, + wantErr: "Unauthorized RPC request rejected.", + }, + "AllowsRpcRequestNoMatchInDenyMatchInAllow": { + authzPolicy: `{ + "name": "authz", + "allow_rules": + [ + { + "name": "allow_all" + } + ], + "deny_rules": + [ + { + "name": "deny_TestServiceCalls", + "request": { + "paths": + [ + "/grpc.testing.TestService/UnaryCall", + "/grpc.testing.TestService/StreamingInputCall" + ], "headers": [ { @@ -150,15 +178,144 @@ func TestSdkEnd2End(t *testing.T) { md: metadata.Pairs("key-xyz", "val-xyz"), wantStatusCode: codes.OK, }, + "AllowsRpcRequestNoMatchInDenyAndAllow": { + authzPolicy: `{ + "name": "authz", + "allow_rules": + [ + { + "name": "allow_some_user", + "source": { + "principals": + [ + "some_user" + ] + } + } + ], + "deny_rules": + [ + { + "name": "deny_StreamingOutputCall", + "request": { + "paths": + [ + "/grpc.testing.TestService/StreamingOutputCall" + ] + } + } + ] + }`, + wantStatusCode: codes.PermissionDenied, + wantErr: "Unauthorized RPC request rejected.", + }, + "AllowsRpcRequestEmptyDenyMatchInAllow": { + authzPolicy: `{ + "name": "authz", + "allow_rules": + [ + { + "name": "allow_UnaryCall", + "request": + { + "paths": + [ + "/grpc.testing.TestService/UnaryCall" + ] + } + }, + { + "name": "allow_StreamingInputCall", + "request": + { + "paths": + [ + "/grpc.testing.TestService/StreamingInputCall" + ] + } + } + ] + }`, + wantStatusCode: codes.OK, + }, + "DeniesRpcRequestEmptyDenyNoMatchInAllow": { + authzPolicy: `{ + "name": "authz", + "allow_rules": + [ + { + "name": "allow_StreamingOutputCall", + "request": + { + "paths": + [ + "/grpc.testing.TestService/StreamingOutputCall" + ] + } + } + ] + }`, + wantStatusCode: codes.PermissionDenied, + wantErr: "Unauthorized RPC request rejected.", + }, } for name, test := range tests { t.Run(name, func(t *testing.T) { - serverAddr := startServer(t, test.authzPolicy) - ctx := metadata.NewOutgoingContext(context.Background(), test.md) - _, err := runClient(ctx, t, serverAddr) - if gotStatusCode := status.Code(err); gotStatusCode != test.wantStatusCode { - t.Fatalf("unexpected authorization decision. status code want:%v got:%v", test.wantStatusCode, gotStatusCode) + + i, _ := authz.NewStatic(test.authzPolicy) + serverOpts := []grpc.ServerOption{ + grpc.ChainUnaryInterceptor(i.UnaryInterceptor), + grpc.ChainStreamInterceptor(i.StreamInterceptor), } + lis, err := net.Listen("tcp", ":0") + if err != nil { + t.Fatalf("error listening: %v", err) + } + s := grpc.NewServer(serverOpts...) + pb.RegisterTestServiceServer(s, &testServer{}) + go s.Serve(lis) + + dialOptions := []grpc.DialOption{ + grpc.WithInsecure(), + grpc.WithBlock(), + } + clientConn, err := grpc.Dial(lis.Addr().String(), dialOptions...) + if err != nil { + t.Fatalf("grpc.Dial(%v, %v) failed: %v", lis.Addr().String(), dialOptions, err) + } + defer clientConn.Close() + client := pb.NewTestServiceClient(clientConn) + + // Verifying Unary RPC + unaryCtx := metadata.NewOutgoingContext(context.Background(), test.md) + _, err = client.UnaryCall(unaryCtx, &pb.SimpleRequest{}, grpc.WaitForReady(true)) + gotStatus, _ := status.FromError(err) + if gotStatus.Code() != test.wantStatusCode { + t.Fatalf("[UnaryCall] status code want:%v got:%v", test.wantStatusCode, gotStatus.Code()) + } + if gotStatus.Message() != test.wantErr { + t.Fatalf("[UnaryCall] error message want:%v got:%v", test.wantErr, gotStatus.Message()) + } + + // Verifying Streaming RPC + streamCtx := metadata.NewOutgoingContext(context.Background(), test.md) + stream, err := client.StreamingInputCall(streamCtx, grpc.WaitForReady(true)) + if err != nil { + t.Fatalf("Failed StreamingInputCall err:%v", err) + } + req := &pb.StreamingInputCallRequest{} + if err := stream.Send(req); err != nil { + t.Fatalf("stream.Send failed err: %v", err) + } + _, err = stream.CloseAndRecv() + gotStatus, _ = status.FromError(err) + if gotStatus.Code() != test.wantStatusCode { + t.Fatalf("[StreamingCall] status code want:%v got:%v", test.wantStatusCode, gotStatus.Code()) + } + if gotStatus.Message() != test.wantErr { + t.Fatalf("[StreamingCall] error message want:%v got:%v", test.wantErr, gotStatus.Message()) + } + }) } } diff --git a/authz/sdk_server_interceptors.go b/authz/sdk_server_interceptors.go index 8eef3557d90..ab1359127b5 100644 --- a/authz/sdk_server_interceptors.go +++ b/authz/sdk_server_interceptors.go @@ -18,7 +18,6 @@ package authz import ( "context" - "fmt" "google.golang.org/grpc" "google.golang.org/grpc/codes" @@ -36,36 +35,41 @@ type StaticInterceptor struct { // NewStatic returns a new StaticInterceptor from a static authorization policy // JSON string. func NewStatic(authzPolicy string) (*StaticInterceptor, error) { - RBACPolicies, err := translatePolicy(authzPolicy) + rbacs, err := translatePolicy(authzPolicy) if err != nil { return nil, err } - chainEngine, err := rbac.NewChainEngine(RBACPolicies) + chainEngine, err := rbac.NewChainEngine(rbacs) if err != nil { return nil, err } - if chainEngine.IsEmpty() { - return nil, fmt.Errorf("failed to initialize RBAC engines") - } return &StaticInterceptor{*chainEngine}, nil } -// UnaryInterceptor intercepts incoming Unary RPC request. -// Only authorized requests are allowed to pass. Otherwise, unauthorized error -// is returned to client. -func (i *StaticInterceptor) UnaryInterceptor(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { - if err := i.engines.IsAuthorized(ctx); status.Code(err) != codes.OK { +// UnaryInterceptor intercepts incoming Unary RPC requests. +// Only authorized requests are allowed to pass. Otherwise, an unauthorized +// error is returned to the client. +func (i *StaticInterceptor) UnaryInterceptor(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + err := i.engines.IsAuthorized(ctx) + if status.Code(err) == codes.InvalidArgument { return nil, err } + if status.Code(err) != codes.OK { + return nil, status.Errorf(codes.PermissionDenied, "Unauthorized RPC request rejected.") + } return handler(ctx, req) } -// StreamInterceptor intercepts incoming Stream RPC request. -// Only authorized requests are allowed to pass. Otherwise, unauthorized error -// is returned to client. -func (i *StaticInterceptor) StreamInterceptor(srv interface{}, ss grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) (err error) { - if err := i.engines.IsAuthorized(ss.Context()); status.Code(err) != codes.OK { +// StreamInterceptor intercepts incoming Stream RPC requests. +// Only authorized requests are allowed to pass. Otherwise, an unauthorized +// error is returned to the client. +func (i *StaticInterceptor) StreamInterceptor(srv interface{}, ss grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + err := i.engines.IsAuthorized(ss.Context()) + if status.Code(err) == codes.InvalidArgument { return err } + if status.Code(err) != codes.OK { + return status.Errorf(codes.PermissionDenied, "Unauthorized RPC request rejected.") + } return handler(srv, ss) } diff --git a/authz/sdk_server_interceptors_test.go b/authz/sdk_server_interceptors_test.go index 2e11792cf66..e2c1072e7d8 100644 --- a/authz/sdk_server_interceptors_test.go +++ b/authz/sdk_server_interceptors_test.go @@ -16,60 +16,14 @@ * */ -package authz +package authz_test import ( - "context" - "net" "testing" - "google.golang.org/grpc" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/internal/xds/rbac" - "google.golang.org/grpc/metadata" - p "google.golang.org/grpc/peer" - "google.golang.org/grpc/status" + "google.golang.org/grpc/authz" ) -const ( - message = "Hi" -) - -type addr struct { - ipAddress string -} - -func (addr) Network() string { return "" } -func (a *addr) String() string { return a.ipAddress } - -type ServerTransportStreamWithMethod struct { - method string -} - -func (sts *ServerTransportStreamWithMethod) Method() string { - return sts.method -} -func (sts *ServerTransportStreamWithMethod) SetHeader(md metadata.MD) error { - return nil -} -func (sts *ServerTransportStreamWithMethod) SendHeader(md metadata.MD) error { - return nil -} -func (sts *ServerTransportStreamWithMethod) SetTrailer(md metadata.MD) error { - return nil -} - -type fakeStream struct { - ctx context.Context -} - -func (f *fakeStream) SetHeader(metadata.MD) error { return nil } -func (f *fakeStream) SendHeader(metadata.MD) error { return nil } -func (f *fakeStream) SetTrailer(metadata.MD) {} -func (f *fakeStream) Context() context.Context { return f.ctx } -func (f *fakeStream) SendMsg(m interface{}) error { return nil } -func (f *fakeStream) RecvMsg(m interface{}) error { return nil } - func TestNewStatic(t *testing.T) { tests := map[string]struct { authzPolicy string @@ -94,250 +48,9 @@ func TestNewStatic(t *testing.T) { } for name, test := range tests { t.Run(name, func(t *testing.T) { - if _, err := NewStatic(test.authzPolicy); (err != nil) != test.wantErr { + if _, err := authz.NewStatic(test.authzPolicy); (err != nil) != test.wantErr { t.Fatalf("NewStatic(%v) returned err: %v, want err: %v", test.authzPolicy, err, test.wantErr) } }) } } - -func TestStaticInterceptors(t *testing.T) { - tests := map[string]struct { - authzPolicy string - md metadata.MD - fullMethod string - wantStatusCode codes.Code - }{ - "DeniesRpcRequestMatchInDenyNoMatchInAllow": { - authzPolicy: `{ - "name": "authz", - "allow_rules": [ - { - "name": "allow_bar", - "request": { - "paths": [ - "*/bar" - ] - } - } - ], - "deny_rules": [ - { - "name": "deny_foo", - "request": { - "paths": [ - "*/foo" - ], - "headers": [ - { - "key": "key-abc", - "values": [ - "val-abc", - "val-def" - ] - } - ] - } - } - ] - }`, - md: metadata.Pairs("key-abc", "val-abc"), - fullMethod: "/package.service/foo", - wantStatusCode: codes.PermissionDenied, - }, - "DeniesRpcRequestMatchInDenyAndAllow": { - authzPolicy: `{ - "name": "authz", - "allow_rules": [ - { - "name": "allow_foo", - "request": { - "paths": [ - "*/foo" - ] - } - } - ], - "deny_rules": [ - { - "name": "deny_foo", - "request": { - "paths": [ - "*/foo" - ] - } - } - ] - }`, - fullMethod: "/package.service/foo", - wantStatusCode: codes.PermissionDenied, - }, - "AllowsRpcRequestNoMatchInDenyMatchInAllow": { - authzPolicy: `{ - "name": "authz", - "allow_rules": [ - { - "name": "allow_foo", - "request": { - "paths": [ - "*/foo" - ] - } - } - ], - "deny_rules": [ - { - "name": "deny_foo", - "request": { - "paths": [ - "*/foo" - ], - "headers": [ - { - "key": "key-abc", - "values": [ - "val-abc", - "val-def" - ] - } - ] - } - } - ] - }`, - md: metadata.Pairs("key-xyz", "val-xyz"), - fullMethod: "/package.service/foo", - wantStatusCode: codes.OK, - }, - "DeniesRpcRequestNoMatchInDenyAndAllow": { - authzPolicy: `{ - "name": "authz", - "allow_rules": [ - { - "name": "allow_baz_user", - "source": { - "principals": [ - "baz" - ] - } - } - ], - "deny_rules": [ - { - "name": "deny_bar", - "request": { - "paths": [ - "*/bar" - ] - } - } - ] - }`, - fullMethod: "/package.service/foo", - wantStatusCode: codes.PermissionDenied, - }, - "AllowsRpcRequestEmptyDenyMatchInAllow": { - authzPolicy: `{ - "name": "authz", - "allow_rules": [ - { - "name": "allow_foo", - "request": { - "paths": [ - "*/foo" - ], - "headers": [ - { - "key": "key-abc", - "values": [ - "val-abc", - "val-def" - ] - } - ] - } - } - ] - }`, - md: metadata.Pairs("key-abc", "val-abc", "key-xyz", "val-xyz"), - fullMethod: "/package.service/foo", - wantStatusCode: codes.OK, - }, - "DeniesRpcRequestEmptyDenyNoMatchInAllow": { - authzPolicy: `{ - "name": "authz", - "allow_rules": [ - { - "name": "allow_bar", - "request": { - "paths": [ - "*/bar" - ] - } - } - ] - }`, - fullMethod: "/package.service/foo", - wantStatusCode: codes.PermissionDenied, - }, - } - for name, test := range tests { - t.Run(name, func(t *testing.T) { - i, err := NewStatic(test.authzPolicy) - if err != nil { - t.Fatalf("NewStatic(%v) failed to create interceptor. err: %v", test.authzPolicy, err) - } - - ctx := metadata.NewIncomingContext(context.Background(), test.md) - lis, err := net.Listen("tcp", "localhost:0") - if err != nil { - t.Fatalf("Error listening: %v", err) - } - defer lis.Close() - connCh := make(chan net.Conn, 1) - go func() { - conn, err := lis.Accept() - if err != nil { - t.Errorf("Error accepting connection: %v", err) - return - } - connCh <- conn - }() - _, err = net.Dial("tcp", lis.Addr().String()) - if err != nil { - t.Fatalf("Error dialing: %v", err) - } - conn := <-connCh - defer conn.Close() - rbac.GetConnection = func(context.Context) net.Conn { - return conn - } - ctx = p.NewContext(ctx, &p.Peer{Addr: &addr{}}) - stream := &ServerTransportStreamWithMethod{ - method: test.fullMethod, - } - ctx = grpc.NewContextWithServerTransportStream(ctx, stream) - - // Testing UnaryInterceptor - unaryHandler := func(_ context.Context, _ interface{}) (interface{}, error) { - return message, nil - } - resp, err := i.UnaryInterceptor(ctx, nil, &grpc.UnaryServerInfo{}, unaryHandler) - if gotStatusCode := status.Code(err); gotStatusCode != test.wantStatusCode { - t.Fatalf("UnaryInterceptor returned unexpected error code want:%v got:%v", test.wantStatusCode, gotStatusCode) - } - if resp != nil && resp != message { - t.Fatalf("UnaryInterceptor returned unexpected response want:%v got:%v", message, resp) - } - - // Testing StreamInterceptor - streamHandler := func(_ interface{}, _ grpc.ServerStream) error { - return nil - } - err = i.StreamInterceptor(nil, &fakeStream{ctx: ctx}, &grpc.StreamServerInfo{}, streamHandler) - if gotStatusCode := status.Code(err); gotStatusCode != test.wantStatusCode { - t.Fatalf("StreamInterceptor returned unexpected error code want:%v got:%v", test.wantStatusCode, gotStatusCode) - } - }) - } -} diff --git a/internal/xds/rbac/rbac_engine.go b/internal/xds/rbac/rbac_engine.go index 48a760f5ed0..f08228f7791 100644 --- a/internal/xds/rbac/rbac_engine.go +++ b/internal/xds/rbac/rbac_engine.go @@ -38,8 +38,7 @@ import ( "google.golang.org/grpc/status" ) -// GetConnection allows tests to inject fake connection. -var GetConnection = transport.GetConnection +var getConnection = transport.GetConnection // ChainEngine represents a chain of RBAC Engines, used to make authorization // decisions on incoming RPCs. @@ -61,11 +60,6 @@ func NewChainEngine(policies []*v3rbacpb.RBAC) (*ChainEngine, error) { return &ChainEngine{chainedEngines: engines}, nil } -// IsEmpty returns true if ChainEngine contains zero engine. -func (cre *ChainEngine) IsEmpty() bool { - return len(cre.chainedEngines) == 0 -} - // IsAuthorized determines if an incoming RPC is authorized based on the chain of RBAC // engines and their associated actions. // @@ -166,7 +160,7 @@ func newRPCData(ctx context.Context) (*rpcData, error) { // The connection is needed in order to find the destination address and // port of the incoming RPC Call. - conn := GetConnection(ctx) + conn := getConnection(ctx) if conn == nil { return nil, errors.New("missing connection in incoming context") } diff --git a/internal/xds/rbac/rbac_engine_test.go b/internal/xds/rbac/rbac_engine_test.go index 91c5bc15c8f..807df9b87a8 100644 --- a/internal/xds/rbac/rbac_engine_test.go +++ b/internal/xds/rbac/rbac_engine_test.go @@ -441,8 +441,8 @@ func (s) TestNewChainEngine(t *testing.T) { // and verifies that it works as expected. func (s) TestChainEngine(t *testing.T) { defer func(gc func(ctx context.Context) net.Conn) { - GetConnection = gc - }(GetConnection) + getConnection = gc + }(getConnection) tests := []struct { name string rbacConfigs []*v3rbacpb.RBAC @@ -885,7 +885,7 @@ func (s) TestChainEngine(t *testing.T) { } conn := <-connCh defer conn.Close() - GetConnection = func(context.Context) net.Conn { + getConnection = func(context.Context) net.Conn { return conn } ctx = peer.NewContext(ctx, data.rpcData.peerInfo) From 1332c85997864d8218e8adda93a6ff90a32c9d5c Mon Sep 17 00:00:00 2001 From: Ashitha Santhosh Date: Tue, 24 Aug 2021 18:35:55 -0700 Subject: [PATCH 05/14] Debugging Streaming call error --- authz/sdk_end2end_test.go | 15 ++++++++------- authz/sdk_server_interceptors.go | 12 ++++++------ 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/authz/sdk_end2end_test.go b/authz/sdk_end2end_test.go index ad312ccc0a6..86fb2b836d1 100644 --- a/authz/sdk_end2end_test.go +++ b/authz/sdk_end2end_test.go @@ -29,6 +29,7 @@ import ( "io" "net" "testing" + "time" "google.golang.org/grpc" "google.golang.org/grpc/authz" @@ -71,11 +72,10 @@ func TestSDKEnd2End(t *testing.T) { "allow_rules": [ { - "name": "allow_TestServiceCalls", + "name": "allow_StreamingOutputCall", "request": { "paths": [ - "/grpc.testing.TestService/UnaryCall", "/grpc.testing.TestService/StreamingOutputCall" ] } @@ -286,7 +286,7 @@ func TestSDKEnd2End(t *testing.T) { defer clientConn.Close() client := pb.NewTestServiceClient(clientConn) - // Verifying Unary RPC + // Verifying Unary RPC. unaryCtx := metadata.NewOutgoingContext(context.Background(), test.md) _, err = client.UnaryCall(unaryCtx, &pb.SimpleRequest{}, grpc.WaitForReady(true)) gotStatus, _ := status.FromError(err) @@ -297,14 +297,15 @@ func TestSDKEnd2End(t *testing.T) { t.Fatalf("[UnaryCall] error message want:%v got:%v", test.wantErr, gotStatus.Message()) } - // Verifying Streaming RPC - streamCtx := metadata.NewOutgoingContext(context.Background(), test.md) + // Verifying Streaming RPC. + streamCtx, cancel := context.WithTimeout(context.Background(), time.Second*2) + defer cancel() + streamCtx = metadata.NewOutgoingContext(streamCtx, test.md) stream, err := client.StreamingInputCall(streamCtx, grpc.WaitForReady(true)) if err != nil { t.Fatalf("Failed StreamingInputCall err:%v", err) } - req := &pb.StreamingInputCallRequest{} - if err := stream.Send(req); err != nil { + if err := stream.Send(&pb.StreamingInputCallRequest{Payload: &pb.Payload{Body: []byte("hi")}}); err != nil { t.Fatalf("stream.Send failed err: %v", err) } _, err = stream.CloseAndRecv() diff --git a/authz/sdk_server_interceptors.go b/authz/sdk_server_interceptors.go index ab1359127b5..e8794cd072d 100644 --- a/authz/sdk_server_interceptors.go +++ b/authz/sdk_server_interceptors.go @@ -51,11 +51,11 @@ func NewStatic(authzPolicy string) (*StaticInterceptor, error) { // error is returned to the client. func (i *StaticInterceptor) UnaryInterceptor(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { err := i.engines.IsAuthorized(ctx) - if status.Code(err) == codes.InvalidArgument { - return nil, err + if status.Code(err) == codes.PermissionDenied { + return nil, status.Errorf(codes.PermissionDenied, "Unauthorized RPC request rejected.") } if status.Code(err) != codes.OK { - return nil, status.Errorf(codes.PermissionDenied, "Unauthorized RPC request rejected.") + return nil, err } return handler(ctx, req) } @@ -65,11 +65,11 @@ func (i *StaticInterceptor) UnaryInterceptor(ctx context.Context, req interface{ // error is returned to the client. func (i *StaticInterceptor) StreamInterceptor(srv interface{}, ss grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error { err := i.engines.IsAuthorized(ss.Context()) - if status.Code(err) == codes.InvalidArgument { - return err + if status.Code(err) == codes.PermissionDenied { + return status.Errorf(codes.PermissionDenied, "Unauthorized RPC request rejected.") } if status.Code(err) != codes.OK { - return status.Errorf(codes.PermissionDenied, "Unauthorized RPC request rejected.") + return err } return handler(srv, ss) } From 89848a1e9ed2a7753211a8a25dd02ee357988e90 Mon Sep 17 00:00:00 2001 From: Ashitha Santhosh Date: Tue, 24 Aug 2021 21:29:53 -0700 Subject: [PATCH 06/14] Using single context. --- authz/sdk_end2end_test.go | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/authz/sdk_end2end_test.go b/authz/sdk_end2end_test.go index 86fb2b836d1..dd57e63ae5e 100644 --- a/authz/sdk_end2end_test.go +++ b/authz/sdk_end2end_test.go @@ -29,7 +29,6 @@ import ( "io" "net" "testing" - "time" "google.golang.org/grpc" "google.golang.org/grpc/authz" @@ -286,9 +285,10 @@ func TestSDKEnd2End(t *testing.T) { defer clientConn.Close() client := pb.NewTestServiceClient(clientConn) + ctx := metadata.NewOutgoingContext(context.Background(), test.md) + // Verifying Unary RPC. - unaryCtx := metadata.NewOutgoingContext(context.Background(), test.md) - _, err = client.UnaryCall(unaryCtx, &pb.SimpleRequest{}, grpc.WaitForReady(true)) + _, err = client.UnaryCall(ctx, &pb.SimpleRequest{}, grpc.WaitForReady(true)) gotStatus, _ := status.FromError(err) if gotStatus.Code() != test.wantStatusCode { t.Fatalf("[UnaryCall] status code want:%v got:%v", test.wantStatusCode, gotStatus.Code()) @@ -298,10 +298,7 @@ func TestSDKEnd2End(t *testing.T) { } // Verifying Streaming RPC. - streamCtx, cancel := context.WithTimeout(context.Background(), time.Second*2) - defer cancel() - streamCtx = metadata.NewOutgoingContext(streamCtx, test.md) - stream, err := client.StreamingInputCall(streamCtx, grpc.WaitForReady(true)) + stream, err := client.StreamingInputCall(ctx, grpc.WaitForReady(true)) if err != nil { t.Fatalf("Failed StreamingInputCall err:%v", err) } From 3ca35f8b02948be9f1f5d8729efcafc703958af9 Mon Sep 17 00:00:00 2001 From: Ashitha Santhosh Date: Tue, 24 Aug 2021 22:25:36 -0700 Subject: [PATCH 07/14] Adding comments to test. --- authz/sdk_end2end_test.go | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/authz/sdk_end2end_test.go b/authz/sdk_end2end_test.go index dd57e63ae5e..7964a438a08 100644 --- a/authz/sdk_end2end_test.go +++ b/authz/sdk_end2end_test.go @@ -260,7 +260,7 @@ func TestSDKEnd2End(t *testing.T) { } for name, test := range tests { t.Run(name, func(t *testing.T) { - + // Start a gRPC server with SDK unary and stream server interceptors. i, _ := authz.NewStatic(test.authzPolicy) serverOpts := []grpc.ServerOption{ grpc.ChainUnaryInterceptor(i.UnaryInterceptor), @@ -274,6 +274,7 @@ func TestSDKEnd2End(t *testing.T) { pb.RegisterTestServiceServer(s, &testServer{}) go s.Serve(lis) + // Establish a connection to the server. dialOptions := []grpc.DialOption{ grpc.WithInsecure(), grpc.WithBlock(), @@ -287,7 +288,7 @@ func TestSDKEnd2End(t *testing.T) { ctx := metadata.NewOutgoingContext(context.Background(), test.md) - // Verifying Unary RPC. + // Verifying authorization decision for Unary RPC. _, err = client.UnaryCall(ctx, &pb.SimpleRequest{}, grpc.WaitForReady(true)) gotStatus, _ := status.FromError(err) if gotStatus.Code() != test.wantStatusCode { @@ -297,13 +298,18 @@ func TestSDKEnd2End(t *testing.T) { t.Fatalf("[UnaryCall] error message want:%v got:%v", test.wantErr, gotStatus.Message()) } - // Verifying Streaming RPC. + // Verifying authorization decision for Streaming RPC. stream, err := client.StreamingInputCall(ctx, grpc.WaitForReady(true)) if err != nil { - t.Fatalf("Failed StreamingInputCall err:%v", err) + t.Fatalf("failed StreamingInputCall err: %v", err) + } + req := &pb.StreamingInputCallRequest{ + Payload: &pb.Payload{ + Body: []byte("hi"), + }, } - if err := stream.Send(&pb.StreamingInputCallRequest{Payload: &pb.Payload{Body: []byte("hi")}}); err != nil { - t.Fatalf("stream.Send failed err: %v", err) + if err := stream.Send(req); err != nil { + t.Fatalf("failed stream.Send err: %v", err) } _, err = stream.CloseAndRecv() gotStatus, _ = status.FromError(err) From 49de13b77cf46da89a66f39b9fea33bad9af2b46 Mon Sep 17 00:00:00 2001 From: Ashitha Santhosh Date: Wed, 25 Aug 2021 10:29:39 -0700 Subject: [PATCH 08/14] Adds check for nil error --- authz/sdk_server_interceptors.go | 16 ++++++++-------- internal/xds/rbac/rbac_engine.go | 2 +- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/authz/sdk_server_interceptors.go b/authz/sdk_server_interceptors.go index e8794cd072d..9c3142574ea 100644 --- a/authz/sdk_server_interceptors.go +++ b/authz/sdk_server_interceptors.go @@ -51,13 +51,13 @@ func NewStatic(authzPolicy string) (*StaticInterceptor, error) { // error is returned to the client. func (i *StaticInterceptor) UnaryInterceptor(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { err := i.engines.IsAuthorized(ctx) + if err == nil { + return handler(ctx, req) + } if status.Code(err) == codes.PermissionDenied { return nil, status.Errorf(codes.PermissionDenied, "Unauthorized RPC request rejected.") } - if status.Code(err) != codes.OK { - return nil, err - } - return handler(ctx, req) + return nil, err } // StreamInterceptor intercepts incoming Stream RPC requests. @@ -65,11 +65,11 @@ func (i *StaticInterceptor) UnaryInterceptor(ctx context.Context, req interface{ // error is returned to the client. func (i *StaticInterceptor) StreamInterceptor(srv interface{}, ss grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error { err := i.engines.IsAuthorized(ss.Context()) + if err == nil { + return handler(srv, ss) + } if status.Code(err) == codes.PermissionDenied { return status.Errorf(codes.PermissionDenied, "Unauthorized RPC request rejected.") } - if status.Code(err) != codes.OK { - return err - } - return handler(srv, ss) + return err } diff --git a/internal/xds/rbac/rbac_engine.go b/internal/xds/rbac/rbac_engine.go index f08228f7791..59904525243 100644 --- a/internal/xds/rbac/rbac_engine.go +++ b/internal/xds/rbac/rbac_engine.go @@ -86,7 +86,7 @@ func (cre *ChainEngine) IsAuthorized(ctx context.Context) error { // If the incoming RPC gets through all of the engines successfully (i.e. // doesn't not match an allow or match a deny engine), the RPC is authorized // to proceed. - return status.Error(codes.OK, "") + return nil } // engine is used for matching incoming RPCs to policies. From 6a2839e17ca61e85a9c398296563632400698804 Mon Sep 17 00:00:00 2001 From: Ashitha Santhosh Date: Thu, 26 Aug 2021 15:19:04 -0700 Subject: [PATCH 09/14] Resolving comments. --- authz/rbac_translator.go | 5 ++- authz/sdk_end2end_test.go | 53 ++++++++++---------------------- authz/sdk_server_interceptors.go | 4 +-- 3 files changed, 20 insertions(+), 42 deletions(-) diff --git a/authz/rbac_translator.go b/authz/rbac_translator.go index 821a28d0262..27bd658e36f 100644 --- a/authz/rbac_translator.go +++ b/authz/rbac_translator.go @@ -285,7 +285,7 @@ func translatePolicy(policyStr string) ([]*v3rbacpb.RBAC, error) { if len(policy.AllowRules) == 0 { return nil, fmt.Errorf(`"allow_rules" is not present`) } - var rbacs []*v3rbacpb.RBAC + rbacs := make([]*v3rbacpb.RBAC, 0, 2) if len(policy.DenyRules) > 0 { denyPolicies, err := parseRules(policy.DenyRules, policy.Name) if err != nil { @@ -302,6 +302,5 @@ func translatePolicy(policyStr string) ([]*v3rbacpb.RBAC, error) { return nil, fmt.Errorf(`"allow_rules" %v`, err) } allowRBAC := &v3rbacpb.RBAC{Action: v3rbacpb.RBAC_ALLOW, Policies: allowPolicies} - rbacs = append(rbacs, allowRBAC) - return rbacs, nil + return append(rbacs, allowRBAC), nil } diff --git a/authz/sdk_end2end_test.go b/authz/sdk_end2end_test.go index 7964a438a08..ea25870a328 100644 --- a/authz/sdk_end2end_test.go +++ b/authz/sdk_end2end_test.go @@ -16,12 +16,6 @@ * */ -// Package authz_test contains tests for authz. -// -// Experimental -// -// Notice: This package is EXPERIMENTAL and may be changed or removed -// in a later release. package authz_test import ( @@ -107,7 +101,7 @@ func TestSDKEnd2End(t *testing.T) { }`, md: metadata.Pairs("key-abc", "val-abc"), wantStatusCode: codes.PermissionDenied, - wantErr: "Unauthorized RPC request rejected.", + wantErr: "unauthorized RPC request rejected", }, "DeniesRpcRequestMatchInDenyAndAllow": { authzPolicy: `{ @@ -138,7 +132,7 @@ func TestSDKEnd2End(t *testing.T) { ] }`, wantStatusCode: codes.PermissionDenied, - wantErr: "Unauthorized RPC request rejected.", + wantErr: "unauthorized RPC request rejected", }, "AllowsRpcRequestNoMatchInDenyMatchInAllow": { authzPolicy: `{ @@ -206,7 +200,7 @@ func TestSDKEnd2End(t *testing.T) { ] }`, wantStatusCode: codes.PermissionDenied, - wantErr: "Unauthorized RPC request rejected.", + wantErr: "unauthorized RPC request rejected", }, "AllowsRpcRequestEmptyDenyMatchInAllow": { authzPolicy: `{ @@ -255,33 +249,27 @@ func TestSDKEnd2End(t *testing.T) { ] }`, wantStatusCode: codes.PermissionDenied, - wantErr: "Unauthorized RPC request rejected.", + wantErr: "unauthorized RPC request rejected", }, } for name, test := range tests { t.Run(name, func(t *testing.T) { // Start a gRPC server with SDK unary and stream server interceptors. i, _ := authz.NewStatic(test.authzPolicy) - serverOpts := []grpc.ServerOption{ - grpc.ChainUnaryInterceptor(i.UnaryInterceptor), - grpc.ChainStreamInterceptor(i.StreamInterceptor), - } - lis, err := net.Listen("tcp", ":0") + lis, err := net.Listen("tcp", "localhost:0") if err != nil { t.Fatalf("error listening: %v", err) } - s := grpc.NewServer(serverOpts...) + s := grpc.NewServer( + grpc.ChainUnaryInterceptor(i.UnaryInterceptor), + grpc.ChainStreamInterceptor(i.StreamInterceptor)) pb.RegisterTestServiceServer(s, &testServer{}) go s.Serve(lis) // Establish a connection to the server. - dialOptions := []grpc.DialOption{ - grpc.WithInsecure(), - grpc.WithBlock(), - } - clientConn, err := grpc.Dial(lis.Addr().String(), dialOptions...) + clientConn, err := grpc.Dial(lis.Addr().String(), grpc.WithInsecure()) if err != nil { - t.Fatalf("grpc.Dial(%v, %v) failed: %v", lis.Addr().String(), dialOptions, err) + t.Fatalf("grpc.Dial(%v) failed: %v", lis.Addr().String(), err) } defer clientConn.Close() client := pb.NewTestServiceClient(clientConn) @@ -289,17 +277,13 @@ func TestSDKEnd2End(t *testing.T) { ctx := metadata.NewOutgoingContext(context.Background(), test.md) // Verifying authorization decision for Unary RPC. - _, err = client.UnaryCall(ctx, &pb.SimpleRequest{}, grpc.WaitForReady(true)) - gotStatus, _ := status.FromError(err) - if gotStatus.Code() != test.wantStatusCode { - t.Fatalf("[UnaryCall] status code want:%v got:%v", test.wantStatusCode, gotStatus.Code()) - } - if gotStatus.Message() != test.wantErr { - t.Fatalf("[UnaryCall] error message want:%v got:%v", test.wantErr, gotStatus.Message()) + _, err = client.UnaryCall(ctx, &pb.SimpleRequest{}) + if got := status.Convert(err); got.Code() != test.wantStatusCode || got.Message() != test.wantErr { + t.Fatalf("[UnaryCall] error want:{%v %v} got:{%v %v}", test.wantStatusCode, test.wantErr, got.Code(), got.Message()) } // Verifying authorization decision for Streaming RPC. - stream, err := client.StreamingInputCall(ctx, grpc.WaitForReady(true)) + stream, err := client.StreamingInputCall(ctx) if err != nil { t.Fatalf("failed StreamingInputCall err: %v", err) } @@ -312,14 +296,9 @@ func TestSDKEnd2End(t *testing.T) { t.Fatalf("failed stream.Send err: %v", err) } _, err = stream.CloseAndRecv() - gotStatus, _ = status.FromError(err) - if gotStatus.Code() != test.wantStatusCode { - t.Fatalf("[StreamingCall] status code want:%v got:%v", test.wantStatusCode, gotStatus.Code()) - } - if gotStatus.Message() != test.wantErr { - t.Fatalf("[StreamingCall] error message want:%v got:%v", test.wantErr, gotStatus.Message()) + if got := status.Convert(err); got.Code() != test.wantStatusCode || got.Message() != test.wantErr { + t.Fatalf("[StreamingCall] error want:{%v %v} got:{%v %v}", test.wantStatusCode, test.wantErr, got.Code(), got.Message()) } - }) } } diff --git a/authz/sdk_server_interceptors.go b/authz/sdk_server_interceptors.go index 9c3142574ea..a408fdf3e8d 100644 --- a/authz/sdk_server_interceptors.go +++ b/authz/sdk_server_interceptors.go @@ -55,7 +55,7 @@ func (i *StaticInterceptor) UnaryInterceptor(ctx context.Context, req interface{ return handler(ctx, req) } if status.Code(err) == codes.PermissionDenied { - return nil, status.Errorf(codes.PermissionDenied, "Unauthorized RPC request rejected.") + return nil, status.Errorf(codes.PermissionDenied, "unauthorized RPC request rejected") } return nil, err } @@ -69,7 +69,7 @@ func (i *StaticInterceptor) StreamInterceptor(srv interface{}, ss grpc.ServerStr return handler(srv, ss) } if status.Code(err) == codes.PermissionDenied { - return status.Errorf(codes.PermissionDenied, "Unauthorized RPC request rejected.") + return status.Errorf(codes.PermissionDenied, "unauthorized RPC request rejected") } return err } From 61a51d7b7a99085cc08d8a01ea25eefca5f2024c Mon Sep 17 00:00:00 2001 From: Ashitha Santhosh Date: Fri, 27 Aug 2021 00:54:28 -0700 Subject: [PATCH 10/14] Change status code to internal. --- internal/xds/rbac/rbac_engine.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/xds/rbac/rbac_engine.go b/internal/xds/rbac/rbac_engine.go index 59904525243..a513fd3be1d 100644 --- a/internal/xds/rbac/rbac_engine.go +++ b/internal/xds/rbac/rbac_engine.go @@ -69,7 +69,7 @@ func (cre *ChainEngine) IsAuthorized(ctx context.Context) error { // and then be used for the whole chain of RBAC Engines. rpcData, err := newRPCData(ctx) if err != nil { - return status.Errorf(codes.InvalidArgument, "missing fields in ctx %+v: %v", ctx, err) + return status.Errorf(codes.Internal, "missing fields in ctx %+v: %v", ctx, err) } for _, engine := range cre.chainedEngines { matchingPolicyName, ok := engine.findMatchingPolicy(rpcData) From 15829b282e13c31285439d2c8c552d93cb23d93f Mon Sep 17 00:00:00 2001 From: Ashitha Santhosh Date: Fri, 27 Aug 2021 10:01:20 -0700 Subject: [PATCH 11/14] Add deadline to context. --- authz/sdk_end2end_test.go | 5 ++++- authz/sdk_server_interceptors.go | 24 ++++++++++++------------ 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/authz/sdk_end2end_test.go b/authz/sdk_end2end_test.go index ea25870a328..96ec20b01f7 100644 --- a/authz/sdk_end2end_test.go +++ b/authz/sdk_end2end_test.go @@ -23,6 +23,7 @@ import ( "io" "net" "testing" + "time" "google.golang.org/grpc" "google.golang.org/grpc/authz" @@ -274,7 +275,9 @@ func TestSDKEnd2End(t *testing.T) { defer clientConn.Close() client := pb.NewTestServiceClient(clientConn) - ctx := metadata.NewOutgoingContext(context.Background(), test.md) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + ctx = metadata.NewOutgoingContext(ctx, test.md) // Verifying authorization decision for Unary RPC. _, err = client.UnaryCall(ctx, &pb.SimpleRequest{}) diff --git a/authz/sdk_server_interceptors.go b/authz/sdk_server_interceptors.go index a408fdf3e8d..a2f992b5f26 100644 --- a/authz/sdk_server_interceptors.go +++ b/authz/sdk_server_interceptors.go @@ -51,13 +51,13 @@ func NewStatic(authzPolicy string) (*StaticInterceptor, error) { // error is returned to the client. func (i *StaticInterceptor) UnaryInterceptor(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { err := i.engines.IsAuthorized(ctx) - if err == nil { - return handler(ctx, req) - } - if status.Code(err) == codes.PermissionDenied { - return nil, status.Errorf(codes.PermissionDenied, "unauthorized RPC request rejected") + if err != nil { + if status.Code(err) == codes.PermissionDenied { + return nil, status.Errorf(codes.PermissionDenied, "unauthorized RPC request rejected") + } + return nil, err } - return nil, err + return handler(ctx, req) } // StreamInterceptor intercepts incoming Stream RPC requests. @@ -65,11 +65,11 @@ func (i *StaticInterceptor) UnaryInterceptor(ctx context.Context, req interface{ // error is returned to the client. func (i *StaticInterceptor) StreamInterceptor(srv interface{}, ss grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error { err := i.engines.IsAuthorized(ss.Context()) - if err == nil { - return handler(srv, ss) - } - if status.Code(err) == codes.PermissionDenied { - return status.Errorf(codes.PermissionDenied, "unauthorized RPC request rejected") + if err != nil { + if status.Code(err) == codes.PermissionDenied { + return status.Errorf(codes.PermissionDenied, "unauthorized RPC request rejected") + } + return err } - return err + return handler(srv, ss) } From 6385ef241f8fcf1da8af7da3a12a49acf80ab5f7 Mon Sep 17 00:00:00 2001 From: Ashitha Santhosh Date: Wed, 1 Sep 2021 11:09:52 -0700 Subject: [PATCH 12/14] Log error on missing fields in context. --- internal/xds/rbac/rbac_engine.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/internal/xds/rbac/rbac_engine.go b/internal/xds/rbac/rbac_engine.go index a513fd3be1d..8732e146dbf 100644 --- a/internal/xds/rbac/rbac_engine.go +++ b/internal/xds/rbac/rbac_engine.go @@ -32,12 +32,15 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" + "google.golang.org/grpc/grpclog" "google.golang.org/grpc/internal/transport" "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" "google.golang.org/grpc/status" ) +var logger = grpclog.Component("rbac") + var getConnection = transport.GetConnection // ChainEngine represents a chain of RBAC Engines, used to make authorization @@ -69,6 +72,7 @@ func (cre *ChainEngine) IsAuthorized(ctx context.Context) error { // and then be used for the whole chain of RBAC Engines. rpcData, err := newRPCData(ctx) if err != nil { + logger.Errorf("missing fields in ctx %+v: %v", ctx, err) return status.Errorf(codes.Internal, "missing fields in ctx %+v: %v", ctx, err) } for _, engine := range cre.chainedEngines { From 6da894d6c4f478f52092869f1da4b74554df54d8 Mon Sep 17 00:00:00 2001 From: Ashitha Santhosh Date: Wed, 1 Sep 2021 13:06:29 -0700 Subject: [PATCH 13/14] Check for eof in stream.Send. Expected when server returns error. --- authz/sdk_end2end_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/authz/sdk_end2end_test.go b/authz/sdk_end2end_test.go index 96ec20b01f7..92a5e4f4b21 100644 --- a/authz/sdk_end2end_test.go +++ b/authz/sdk_end2end_test.go @@ -295,7 +295,7 @@ func TestSDKEnd2End(t *testing.T) { Body: []byte("hi"), }, } - if err := stream.Send(req); err != nil { + if err := stream.Send(req); err != nil && err != io.EOF { t.Fatalf("failed stream.Send err: %v", err) } _, err = stream.CloseAndRecv() From 5b5db7ccfb123ebbd993df18938956eea33ea7c6 Mon Sep 17 00:00:00 2001 From: Ashitha Santhosh Date: Wed, 1 Sep 2021 16:22:34 -0700 Subject: [PATCH 14/14] Update logs. --- internal/xds/rbac/rbac_engine.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/xds/rbac/rbac_engine.go b/internal/xds/rbac/rbac_engine.go index 8732e146dbf..9e1eb16077e 100644 --- a/internal/xds/rbac/rbac_engine.go +++ b/internal/xds/rbac/rbac_engine.go @@ -72,8 +72,8 @@ func (cre *ChainEngine) IsAuthorized(ctx context.Context) error { // and then be used for the whole chain of RBAC Engines. rpcData, err := newRPCData(ctx) if err != nil { - logger.Errorf("missing fields in ctx %+v: %v", ctx, err) - return status.Errorf(codes.Internal, "missing fields in ctx %+v: %v", ctx, err) + logger.Errorf("newRPCData: %v", err) + return status.Errorf(codes.Internal, "gRPC RBAC: %v", err) } for _, engine := range cre.chainedEngines { matchingPolicyName, ok := engine.findMatchingPolicy(rpcData)