Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

authz: create file watcher interceptor for gRPC SDK API #4760

Merged
merged 16 commits into from Oct 9, 2021
Merged
76 changes: 32 additions & 44 deletions authz/sdk_end2end_test.go
Expand Up @@ -379,6 +379,30 @@ func (s) TestSDKFileWatcherEnd2End(t *testing.T) {
}
}

func verifyValidReload(ctx context.Context, tsc pb.TestServiceClient, wantCode codes.Code, wantErr string) (lastStatus *status.Status) {
for numRetries := 0; numRetries <= 20; numRetries++ {
ashithasantosh marked this conversation as resolved.
Show resolved Hide resolved
_, err := tsc.UnaryCall(ctx, &pb.SimpleRequest{})
if lastStatus = status.Convert(err); lastStatus.Code() == wantCode && lastStatus.Message() == wantErr {
return nil
}
time.Sleep(20 * time.Millisecond)
numRetries++
}
return lastStatus
ashithasantosh marked this conversation as resolved.
Show resolved Hide resolved
}

func verifySkipReload(ctx context.Context, tsc pb.TestServiceClient, wantCode codes.Code, wantErr string) (lastStatus *status.Status) {
for numRetries := 0; numRetries <= 20; numRetries++ {
_, err := tsc.UnaryCall(ctx, &pb.SimpleRequest{})
if lastStatus := status.Convert(err); lastStatus.Code() != wantCode || lastStatus.Message() != wantErr {
ashithasantosh marked this conversation as resolved.
Show resolved Hide resolved
return lastStatus
}
time.Sleep(20 * time.Millisecond)
numRetries++
}
return nil
}

func (s) TestSDKFileWatcher_ValidPolicyRefresh(t *testing.T) {
valid1 := sdkTests["DeniesRpcMatchInDenyAndAllow"]
file := createTmpPolicyFile(t, "valid_policy_refresh", []byte(valid1.authzPolicy))
Expand Down Expand Up @@ -422,20 +446,8 @@ func (s) TestSDKFileWatcher_ValidPolicyRefresh(t *testing.T) {
}

// Verifying authorization decision.
numRetries := 0
reloadSuccess := false
var gotStatus *status.Status
for numRetries <= 20 {
_, err = client.UnaryCall(ctx, &pb.SimpleRequest{})
if gotStatus = status.Convert(err); gotStatus.Code() == valid2.wantStatusCode && gotStatus.Message() == valid2.wantErr {
reloadSuccess = true
break
}
time.Sleep(100 * time.Millisecond)
numRetries++
}
if reloadSuccess == false {
t.Fatalf("error want:{%v %v} got:{%v %v}", valid2.wantStatusCode, valid2.wantErr, gotStatus.Code(), gotStatus.Message())
if got := verifyValidReload(ctx, client, valid2.wantStatusCode, valid2.wantErr); got != nil {
t.Fatalf("error want:{%v %v} got:{%v %v}", valid2.wantStatusCode, valid2.wantErr, got.Code(), got.Message())
ashithasantosh marked this conversation as resolved.
Show resolved Hide resolved
}
}

Expand Down Expand Up @@ -481,14 +493,8 @@ func (s) TestSDKFileWatcher_InvalidPolicySkipReload(t *testing.T) {
}

// Verifying authorization decision.
numRetries := 0
for numRetries <= 20 {
_, err = client.UnaryCall(ctx, &pb.SimpleRequest{})
if got := status.Convert(err); got.Code() != valid.wantStatusCode || got.Message() != valid.wantErr {
t.Fatalf("error want:{%v %v} got:{%v %v}", valid.wantStatusCode, valid.wantErr, got.Code(), got.Message())
}
time.Sleep(100 * time.Millisecond)
numRetries++
if got := verifySkipReload(ctx, client, valid.wantStatusCode, valid.wantErr); got != nil {
t.Fatalf("error want:{%v %v} got:{%v %v}", valid.wantStatusCode, valid.wantErr, got.Code(), got.Message())
}
}

Expand Down Expand Up @@ -534,14 +540,8 @@ func (s) TestSDKFileWatcher_RecoversFromReloadFailure(t *testing.T) {
}

// Verifying authorization decision.
numRetries := 0
for numRetries <= 20 {
_, err = client.UnaryCall(ctx, &pb.SimpleRequest{})
if got := status.Convert(err); got.Code() != valid1.wantStatusCode || got.Message() != valid1.wantErr {
t.Fatalf("error want:{%v %v} got:{%v %v}", valid1.wantStatusCode, valid1.wantErr, got.Code(), got.Message())
}
time.Sleep(100 * time.Millisecond)
numRetries++
if got := verifySkipReload(ctx, client, valid1.wantStatusCode, valid1.wantErr); got != nil {
t.Fatalf("error want:{%v %v} got:{%v %v}", valid1.wantStatusCode, valid1.wantErr, got.Code(), got.Message())
}

// Rewrite the file with a different valid authorization policy.
Expand All @@ -551,19 +551,7 @@ func (s) TestSDKFileWatcher_RecoversFromReloadFailure(t *testing.T) {
}

// Verifying authorization decision.
numRetries = 0
reloadSuccess := false
var gotStatus *status.Status
for numRetries <= 20 {
_, err = client.UnaryCall(ctx, &pb.SimpleRequest{})
if gotStatus = status.Convert(err); gotStatus.Code() == valid2.wantStatusCode && gotStatus.Message() == valid2.wantErr {
reloadSuccess = true
break
}
time.Sleep(100 * time.Millisecond)
numRetries++
}
if reloadSuccess == false {
t.Fatalf("error want:{%v %v} got:{%v %v}", valid2.wantStatusCode, valid2.wantErr, gotStatus.Code(), gotStatus.Message())
if got := verifyValidReload(ctx, client, valid2.wantStatusCode, valid2.wantErr); got != nil {
t.Fatalf("error want:{%v %v} got:{%v %v}", valid2.wantStatusCode, valid2.wantErr, got.Code(), got.Message())
}
}
14 changes: 8 additions & 6 deletions authz/sdk_server_interceptors.go
Expand Up @@ -86,11 +86,12 @@ func (i *StaticInterceptor) StreamInterceptor(srv interface{}, ss grpc.ServerStr
// FileWatcherInterceptor contains details used to make authorization decisions
// by watching a file path that contains authorization policy in JSON format.
type FileWatcherInterceptor struct {
internalInterceptor unsafe.Pointer // *StaticInterceptor
policyFile string
policyContents []byte
refreshDuration time.Duration
cancel context.CancelFunc
internalInterceptor unsafe.Pointer // *StaticInterceptor
policyFile string
policyContents []byte
latestValidPolicyContents []byte
refreshDuration time.Duration
cancel context.CancelFunc
}

// NewFileWatcher returns a new FileWatcherInterceptor from a policy file
Expand Down Expand Up @@ -141,13 +142,14 @@ func (i *FileWatcherInterceptor) updateInternalInterceptor() error {
if bytes.Equal(i.policyContents, policyContents) {
return nil
}
i.policyContents = policyContents
policyContentsString := string(policyContents)
interceptor, err := NewStatic(policyContentsString)
if err != nil {
return err
ashithasantosh marked this conversation as resolved.
Show resolved Hide resolved
}
atomic.StorePointer(&i.internalInterceptor, unsafe.Pointer(interceptor))
i.policyContents = policyContents
i.latestValidPolicyContents = i.policyContents
logger.Infof("authorization policy reload status: successfully loaded new policy %v", policyContentsString)
return nil
}
Expand Down