From b99d1040b71caf9b22be570edb68d85dcb6c515c Mon Sep 17 00:00:00 2001 From: Ashitha Santhosh <55257063+ashithasantosh@users.noreply.github.com> Date: Fri, 8 Oct 2021 17:09:55 -0700 Subject: [PATCH] authz: create file watcher interceptor for gRPC SDK API (#4760) * authz: create file watcher interceptor for gRPC SDK API --- authz/sdk_end2end_test.go | 327 ++++++++++++++++++++++---- authz/sdk_server_interceptors.go | 97 ++++++++ authz/sdk_server_interceptors_test.go | 75 +++++- 3 files changed, 451 insertions(+), 48 deletions(-) diff --git a/authz/sdk_end2end_test.go b/authz/sdk_end2end_test.go index 92a5e4f4b21..093b2bb437d 100644 --- a/authz/sdk_end2end_test.go +++ b/authz/sdk_end2end_test.go @@ -21,13 +21,16 @@ package authz_test import ( "context" "io" + "io/ioutil" "net" + "os" "testing" "time" "google.golang.org/grpc" "google.golang.org/grpc/authz" "google.golang.org/grpc/codes" + "google.golang.org/grpc/internal/grpctest" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" pb "google.golang.org/grpc/test/grpc_testing" @@ -53,15 +56,21 @@ func (s *testServer) StreamingInputCall(stream pb.TestService_StreamingInputCall } } -func TestSDKEnd2End(t *testing.T) { - tests := map[string]struct { - authzPolicy string - md metadata.MD - wantStatusCode codes.Code - wantErr string - }{ - "DeniesRpcRequestMatchInDenyNoMatchInAllow": { - authzPolicy: `{ +type s struct { + grpctest.Tester +} + +func Test(t *testing.T) { + grpctest.RunSubTests(t, s{}) +} + +var sdkTests = map[string]struct { + authzPolicy string + md metadata.MD + wantStatus *status.Status +}{ + "DeniesRpcMatchInDenyNoMatchInAllow": { + authzPolicy: `{ "name": "authz", "allow_rules": [ @@ -100,12 +109,11 @@ func TestSDKEnd2End(t *testing.T) { } ] }`, - md: metadata.Pairs("key-abc", "val-abc"), - wantStatusCode: codes.PermissionDenied, - wantErr: "unauthorized RPC request rejected", - }, - "DeniesRpcRequestMatchInDenyAndAllow": { - authzPolicy: `{ + md: metadata.Pairs("key-abc", "val-abc"), + wantStatus: status.New(codes.PermissionDenied, "unauthorized RPC request rejected"), + }, + "DeniesRpcMatchInDenyAndAllow": { + authzPolicy: `{ "name": "authz", "allow_rules": [ @@ -132,11 +140,10 @@ func TestSDKEnd2End(t *testing.T) { } ] }`, - wantStatusCode: codes.PermissionDenied, - wantErr: "unauthorized RPC request rejected", - }, - "AllowsRpcRequestNoMatchInDenyMatchInAllow": { - authzPolicy: `{ + wantStatus: status.New(codes.PermissionDenied, "unauthorized RPC request rejected"), + }, + "AllowsRpcNoMatchInDenyMatchInAllow": { + authzPolicy: `{ "name": "authz", "allow_rules": [ @@ -169,11 +176,11 @@ func TestSDKEnd2End(t *testing.T) { } ] }`, - md: metadata.Pairs("key-xyz", "val-xyz"), - wantStatusCode: codes.OK, - }, - "AllowsRpcRequestNoMatchInDenyAndAllow": { - authzPolicy: `{ + md: metadata.Pairs("key-xyz", "val-xyz"), + wantStatus: status.New(codes.OK, ""), + }, + "DeniesRpcNoMatchInDenyAndAllow": { + authzPolicy: `{ "name": "authz", "allow_rules": [ @@ -200,11 +207,10 @@ func TestSDKEnd2End(t *testing.T) { } ] }`, - wantStatusCode: codes.PermissionDenied, - wantErr: "unauthorized RPC request rejected", - }, - "AllowsRpcRequestEmptyDenyMatchInAllow": { - authzPolicy: `{ + wantStatus: status.New(codes.PermissionDenied, "unauthorized RPC request rejected"), + }, + "AllowsRpcEmptyDenyMatchInAllow": { + authzPolicy: `{ "name": "authz", "allow_rules": [ @@ -230,10 +236,10 @@ func TestSDKEnd2End(t *testing.T) { } ] }`, - wantStatusCode: codes.OK, - }, - "DeniesRpcRequestEmptyDenyNoMatchInAllow": { - authzPolicy: `{ + wantStatus: status.New(codes.OK, ""), + }, + "DeniesRpcEmptyDenyNoMatchInAllow": { + authzPolicy: `{ "name": "authz", "allow_rules": [ @@ -249,22 +255,85 @@ func TestSDKEnd2End(t *testing.T) { } ] }`, - wantStatusCode: codes.PermissionDenied, - wantErr: "unauthorized RPC request rejected", - }, - } - for name, test := range tests { + wantStatus: status.New(codes.PermissionDenied, "unauthorized RPC request rejected"), + }, +} + +func (s) TestSDKStaticPolicyEnd2End(t *testing.T) { + for name, test := range sdkTests { t.Run(name, func(t *testing.T) { // Start a gRPC server with SDK unary and stream server interceptors. i, _ := authz.NewStatic(test.authzPolicy) + s := grpc.NewServer( + grpc.ChainUnaryInterceptor(i.UnaryInterceptor), + grpc.ChainStreamInterceptor(i.StreamInterceptor)) + defer s.Stop() + pb.RegisterTestServiceServer(s, &testServer{}) + lis, err := net.Listen("tcp", "localhost:0") if err != nil { t.Fatalf("error listening: %v", err) } + go s.Serve(lis) + + // Establish a connection to the server. + clientConn, err := grpc.Dial(lis.Addr().String(), grpc.WithInsecure()) + if err != nil { + t.Fatalf("grpc.Dial(%v) failed: %v", lis.Addr().String(), err) + } + defer clientConn.Close() + client := pb.NewTestServiceClient(clientConn) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + ctx = metadata.NewOutgoingContext(ctx, test.md) + + // Verifying authorization decision for Unary RPC. + _, err = client.UnaryCall(ctx, &pb.SimpleRequest{}) + if got := status.Convert(err); got.Code() != test.wantStatus.Code() || got.Message() != test.wantStatus.Message() { + t.Fatalf("[UnaryCall] error want:{%v} got:{%v}", test.wantStatus.Err(), got.Err()) + } + + // Verifying authorization decision for Streaming RPC. + stream, err := client.StreamingInputCall(ctx) + if err != nil { + t.Fatalf("failed StreamingInputCall err: %v", err) + } + req := &pb.StreamingInputCallRequest{ + Payload: &pb.Payload{ + Body: []byte("hi"), + }, + } + if err := stream.Send(req); err != nil && err != io.EOF { + t.Fatalf("failed stream.Send err: %v", err) + } + _, err = stream.CloseAndRecv() + if got := status.Convert(err); got.Code() != test.wantStatus.Code() || got.Message() != test.wantStatus.Message() { + t.Fatalf("[StreamingCall] error want:{%v} got:{%v}", test.wantStatus.Err(), got.Err()) + } + }) + } +} + +func (s) TestSDKFileWatcherEnd2End(t *testing.T) { + for name, test := range sdkTests { + t.Run(name, func(t *testing.T) { + file := createTmpPolicyFile(t, name, []byte(test.authzPolicy)) + i, _ := authz.NewFileWatcher(file, 1*time.Second) + defer i.Close() + + // Start a gRPC server with SDK unary and stream server interceptors. s := grpc.NewServer( grpc.ChainUnaryInterceptor(i.UnaryInterceptor), grpc.ChainStreamInterceptor(i.StreamInterceptor)) + defer s.Stop() pb.RegisterTestServiceServer(s, &testServer{}) + + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("error listening: %v", err) + } + defer lis.Close() go s.Serve(lis) // Establish a connection to the server. @@ -281,8 +350,8 @@ func TestSDKEnd2End(t *testing.T) { // Verifying authorization decision for Unary RPC. _, err = client.UnaryCall(ctx, &pb.SimpleRequest{}) - if got := status.Convert(err); got.Code() != test.wantStatusCode || got.Message() != test.wantErr { - t.Fatalf("[UnaryCall] error want:{%v %v} got:{%v %v}", test.wantStatusCode, test.wantErr, got.Code(), got.Message()) + if got := status.Convert(err); got.Code() != test.wantStatus.Code() || got.Message() != test.wantStatus.Message() { + t.Fatalf("[UnaryCall] error want:{%v} got:{%v}", test.wantStatus.Err(), got.Err()) } // Verifying authorization decision for Streaming RPC. @@ -299,9 +368,181 @@ func TestSDKEnd2End(t *testing.T) { t.Fatalf("failed stream.Send err: %v", err) } _, err = stream.CloseAndRecv() - if got := status.Convert(err); got.Code() != test.wantStatusCode || got.Message() != test.wantErr { - t.Fatalf("[StreamingCall] error want:{%v %v} got:{%v %v}", test.wantStatusCode, test.wantErr, got.Code(), got.Message()) + if got := status.Convert(err); got.Code() != test.wantStatus.Code() || got.Message() != test.wantStatus.Message() { + t.Fatalf("[StreamingCall] error want:{%v} got:{%v}", test.wantStatus.Err(), got.Err()) } }) } } + +func retryUntil(ctx context.Context, tsc pb.TestServiceClient, want *status.Status) (lastErr error) { + for ctx.Err() == nil { + _, lastErr = tsc.UnaryCall(ctx, &pb.SimpleRequest{}) + if s := status.Convert(lastErr); s.Code() == want.Code() && s.Message() == want.Message() { + return nil + } + time.Sleep(20 * time.Millisecond) + } + return lastErr +} + +func (s) TestSDKFileWatcher_ValidPolicyRefresh(t *testing.T) { + valid1 := sdkTests["DeniesRpcMatchInDenyAndAllow"] + file := createTmpPolicyFile(t, "valid_policy_refresh", []byte(valid1.authzPolicy)) + i, _ := authz.NewFileWatcher(file, 100*time.Millisecond) + defer i.Close() + + // Start a gRPC server with SDK unary server interceptor. + s := grpc.NewServer( + grpc.ChainUnaryInterceptor(i.UnaryInterceptor)) + defer s.Stop() + pb.RegisterTestServiceServer(s, &testServer{}) + + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("error listening: %v", err) + } + defer lis.Close() + go s.Serve(lis) + + // Establish a connection to the server. + clientConn, err := grpc.Dial(lis.Addr().String(), grpc.WithInsecure()) + if err != nil { + t.Fatalf("grpc.Dial(%v) failed: %v", lis.Addr().String(), err) + } + defer clientConn.Close() + client := pb.NewTestServiceClient(clientConn) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // Verifying authorization decision. + _, err = client.UnaryCall(ctx, &pb.SimpleRequest{}) + if got := status.Convert(err); got.Code() != valid1.wantStatus.Code() || got.Message() != valid1.wantStatus.Message() { + t.Fatalf("error want:{%v} got:{%v}", valid1.wantStatus.Err(), got.Err()) + } + + // Rewrite the file with a different valid authorization policy. + valid2 := sdkTests["AllowsRpcEmptyDenyMatchInAllow"] + if err := ioutil.WriteFile(file, []byte(valid2.authzPolicy), os.ModePerm); err != nil { + t.Fatalf("ioutil.WriteFile(%q) failed: %v", file, err) + } + + // Verifying authorization decision. + if got := retryUntil(ctx, client, valid2.wantStatus); got != nil { + t.Fatalf("error want:{%v} got:{%v}", valid2.wantStatus.Err(), got) + } +} + +func (s) TestSDKFileWatcher_InvalidPolicySkipReload(t *testing.T) { + valid := sdkTests["DeniesRpcMatchInDenyAndAllow"] + file := createTmpPolicyFile(t, "invalid_policy_skip_reload", []byte(valid.authzPolicy)) + i, _ := authz.NewFileWatcher(file, 20*time.Millisecond) + defer i.Close() + + // Start a gRPC server with SDK unary server interceptors. + s := grpc.NewServer( + grpc.ChainUnaryInterceptor(i.UnaryInterceptor)) + defer s.Stop() + pb.RegisterTestServiceServer(s, &testServer{}) + + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("error listening: %v", err) + } + defer lis.Close() + go s.Serve(lis) + + // Establish a connection to the server. + clientConn, err := grpc.Dial(lis.Addr().String(), grpc.WithInsecure()) + if err != nil { + t.Fatalf("grpc.Dial(%v) failed: %v", lis.Addr().String(), err) + } + defer clientConn.Close() + client := pb.NewTestServiceClient(clientConn) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // Verifying authorization decision. + _, err = client.UnaryCall(ctx, &pb.SimpleRequest{}) + if got := status.Convert(err); got.Code() != valid.wantStatus.Code() || got.Message() != valid.wantStatus.Message() { + t.Fatalf("error want:{%v} got:{%v}", valid.wantStatus.Err(), got.Err()) + } + + // Skips the invalid policy update, and continues to use the valid policy. + if err := ioutil.WriteFile(file, []byte("{}"), os.ModePerm); err != nil { + t.Fatalf("ioutil.WriteFile(%q) failed: %v", file, err) + } + + // Wait 40 ms for background go routine to read updated files. + time.Sleep(40 * time.Millisecond) + + // Verifying authorization decision. + _, err = client.UnaryCall(ctx, &pb.SimpleRequest{}) + if got := status.Convert(err); got.Code() != valid.wantStatus.Code() || got.Message() != valid.wantStatus.Message() { + t.Fatalf("error want:{%v} got:{%v}", valid.wantStatus.Err(), got.Err()) + } +} + +func (s) TestSDKFileWatcher_RecoversFromReloadFailure(t *testing.T) { + valid1 := sdkTests["DeniesRpcMatchInDenyAndAllow"] + file := createTmpPolicyFile(t, "recovers_from_reload_failure", []byte(valid1.authzPolicy)) + i, _ := authz.NewFileWatcher(file, 100*time.Millisecond) + defer i.Close() + + // Start a gRPC server with SDK unary server interceptors. + s := grpc.NewServer( + grpc.ChainUnaryInterceptor(i.UnaryInterceptor)) + defer s.Stop() + pb.RegisterTestServiceServer(s, &testServer{}) + + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("error listening: %v", err) + } + defer lis.Close() + go s.Serve(lis) + + // Establish a connection to the server. + clientConn, err := grpc.Dial(lis.Addr().String(), grpc.WithInsecure()) + if err != nil { + t.Fatalf("grpc.Dial(%v) failed: %v", lis.Addr().String(), err) + } + defer clientConn.Close() + client := pb.NewTestServiceClient(clientConn) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // Verifying authorization decision. + _, err = client.UnaryCall(ctx, &pb.SimpleRequest{}) + if got := status.Convert(err); got.Code() != valid1.wantStatus.Code() || got.Message() != valid1.wantStatus.Message() { + t.Fatalf("error want:{%v} got:{%v}", valid1.wantStatus.Err(), got.Err()) + } + + // Skips the invalid policy update, and continues to use the valid policy. + if err := ioutil.WriteFile(file, []byte("{}"), os.ModePerm); err != nil { + t.Fatalf("ioutil.WriteFile(%q) failed: %v", file, err) + } + + // Wait 120 ms for background go routine to read updated files. + time.Sleep(120 * time.Millisecond) + + // Verifying authorization decision. + _, err = client.UnaryCall(ctx, &pb.SimpleRequest{}) + if got := status.Convert(err); got.Code() != valid1.wantStatus.Code() || got.Message() != valid1.wantStatus.Message() { + t.Fatalf("error want:{%v} got:{%v}", valid1.wantStatus.Err(), got.Err()) + } + + // Rewrite the file with a different valid authorization policy. + valid2 := sdkTests["AllowsRpcEmptyDenyMatchInAllow"] + if err := ioutil.WriteFile(file, []byte(valid2.authzPolicy), os.ModePerm); err != nil { + t.Fatalf("ioutil.WriteFile(%q) failed: %v", file, err) + } + + // Verifying authorization decision. + if got := retryUntil(ctx, client, valid2.wantStatus); got != nil { + t.Fatalf("error want:{%v} got:{%v}", valid2.wantStatus.Err(), got) + } +} diff --git a/authz/sdk_server_interceptors.go b/authz/sdk_server_interceptors.go index a2f992b5f26..72dc14ed85e 100644 --- a/authz/sdk_server_interceptors.go +++ b/authz/sdk_server_interceptors.go @@ -17,14 +17,23 @@ package authz import ( + "bytes" "context" + "fmt" + "io/ioutil" + "sync/atomic" + "time" + "unsafe" "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/grpclog" "google.golang.org/grpc/internal/xds/rbac" "google.golang.org/grpc/status" ) +var logger = grpclog.Component("authz") + // StaticInterceptor contains engines used to make authorization decisions. It // either contains two engines deny engine followed by an allow engine or only // one allow engine. @@ -73,3 +82,91 @@ func (i *StaticInterceptor) StreamInterceptor(srv interface{}, ss grpc.ServerStr } return handler(srv, ss) } + +// FileWatcherInterceptor contains details used to make authorization decisions +// by watching a file path that contains authorization policy in JSON format. +type FileWatcherInterceptor struct { + internalInterceptor unsafe.Pointer // *StaticInterceptor + policyFile string + policyContents []byte + refreshDuration time.Duration + cancel context.CancelFunc +} + +// NewFileWatcher returns a new FileWatcherInterceptor from a policy file +// that contains JSON string of authorization policy and a refresh duration to +// specify the amount of time between policy refreshes. +func NewFileWatcher(file string, duration time.Duration) (*FileWatcherInterceptor, error) { + if file == "" { + return nil, fmt.Errorf("authorization policy file path is empty") + } + if duration <= time.Duration(0) { + return nil, fmt.Errorf("requires refresh interval(%v) greater than 0s", duration) + } + i := &FileWatcherInterceptor{policyFile: file, refreshDuration: duration} + if err := i.updateInternalInterceptor(); err != nil { + return nil, err + } + ctx, cancel := context.WithCancel(context.Background()) + i.cancel = cancel + // Create a background go routine for policy refresh. + go i.run(ctx) + return i, nil +} + +func (i *FileWatcherInterceptor) run(ctx context.Context) { + ticker := time.NewTicker(i.refreshDuration) + for { + if err := i.updateInternalInterceptor(); err != nil { + logger.Warningf("authorization policy reload status err: %v", err) + } + select { + case <-ctx.Done(): + ticker.Stop() + return + case <-ticker.C: + } + } +} + +// updateInternalInterceptor checks if the policy file that is watching has changed, +// and if so, updates the internalInterceptor with the policy. Unlike the +// constructor, if there is an error in reading the file or parsing the policy, the +// previous internalInterceptors will not be replaced. +func (i *FileWatcherInterceptor) updateInternalInterceptor() error { + policyContents, err := ioutil.ReadFile(i.policyFile) + if err != nil { + return fmt.Errorf("policyFile(%s) read failed: %v", i.policyFile, err) + } + if bytes.Equal(i.policyContents, policyContents) { + return nil + } + i.policyContents = policyContents + policyContentsString := string(policyContents) + interceptor, err := NewStatic(policyContentsString) + if err != nil { + return err + } + atomic.StorePointer(&i.internalInterceptor, unsafe.Pointer(interceptor)) + logger.Infof("authorization policy reload status: successfully loaded new policy %v", policyContentsString) + return nil +} + +// Close cleans up resources allocated by the interceptor. +func (i *FileWatcherInterceptor) Close() { + i.cancel() +} + +// UnaryInterceptor intercepts incoming Unary RPC requests. +// Only authorized requests are allowed to pass. Otherwise, an unauthorized +// error is returned to the client. +func (i *FileWatcherInterceptor) UnaryInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + return ((*StaticInterceptor)(atomic.LoadPointer(&i.internalInterceptor))).UnaryInterceptor(ctx, req, info, handler) +} + +// StreamInterceptor intercepts incoming Stream RPC requests. +// Only authorized requests are allowed to pass. Otherwise, an unauthorized +// error is returned to the client. +func (i *FileWatcherInterceptor) StreamInterceptor(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + return ((*StaticInterceptor)(atomic.LoadPointer(&i.internalInterceptor))).StreamInterceptor(srv, ss, info, handler) +} diff --git a/authz/sdk_server_interceptors_test.go b/authz/sdk_server_interceptors_test.go index e2c1072e7d8..f43f9807612 100644 --- a/authz/sdk_server_interceptors_test.go +++ b/authz/sdk_server_interceptors_test.go @@ -19,19 +19,43 @@ package authz_test import ( + "fmt" + "io/ioutil" + "os" + "path" "testing" + "time" "google.golang.org/grpc/authz" ) -func TestNewStatic(t *testing.T) { +func createTmpPolicyFile(t *testing.T, dirSuffix string, policy []byte) string { + t.Helper() + + // Create a temp directory. Passing an empty string for the first argument + // uses the system temp directory. + dir, err := ioutil.TempDir("", dirSuffix) + if err != nil { + t.Fatalf("ioutil.TempDir() failed: %v", err) + } + t.Logf("Using tmpdir: %s", dir) + // Write policy into file. + filename := path.Join(dir, "policy.json") + if err := ioutil.WriteFile(filename, policy, os.ModePerm); err != nil { + t.Fatalf("ioutil.WriteFile(%q) failed: %v", filename, err) + } + t.Logf("Wrote policy %s to file at %s", string(policy), filename) + return filename +} + +func (s) TestNewStatic(t *testing.T) { tests := map[string]struct { authzPolicy string - wantErr bool + wantErr error }{ "InvalidPolicyFailsToCreateInterceptor": { authzPolicy: `{}`, - wantErr: true, + wantErr: fmt.Errorf(`"name" is not present`), }, "ValidPolicyCreatesInterceptor": { authzPolicy: `{ @@ -43,14 +67,55 @@ func TestNewStatic(t *testing.T) { } ] }`, - wantErr: false, }, } for name, test := range tests { t.Run(name, func(t *testing.T) { - if _, err := authz.NewStatic(test.authzPolicy); (err != nil) != test.wantErr { + if _, err := authz.NewStatic(test.authzPolicy); fmt.Sprint(err) != fmt.Sprint(test.wantErr) { t.Fatalf("NewStatic(%v) returned err: %v, want err: %v", test.authzPolicy, err, test.wantErr) } }) } } + +func (s) TestNewFileWatcher(t *testing.T) { + tests := map[string]struct { + authzPolicy string + refreshDuration time.Duration + wantErr error + }{ + "InvalidRefreshDurationFailsToCreateInterceptor": { + refreshDuration: time.Duration(0), + wantErr: fmt.Errorf("requires refresh interval(0s) greater than 0s"), + }, + "InvalidPolicyFailsToCreateInterceptor": { + authzPolicy: `{}`, + refreshDuration: time.Duration(1), + wantErr: fmt.Errorf(`"name" is not present`), + }, + "ValidPolicyCreatesInterceptor": { + authzPolicy: `{ + "name": "authz", + "allow_rules": + [ + { + "name": "allow_all" + } + ] + }`, + refreshDuration: time.Duration(1), + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + file := createTmpPolicyFile(t, name, []byte(test.authzPolicy)) + i, err := authz.NewFileWatcher(file, test.refreshDuration) + if fmt.Sprint(err) != fmt.Sprint(test.wantErr) { + t.Fatalf("NewFileWatcher(%v) returned err: %v, want err: %v", test.authzPolicy, err, test.wantErr) + } + if i != nil { + i.Close() + } + }) + } +}