Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simple Policy Compiler #924

Merged
merged 10 commits into from
May 20, 2024
199 changes: 199 additions & 0 deletions policy/compiler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package policy provides an extensible parser and compiler for composing
// a graph of CEL expressions into a single evaluable expression.
package policy

import (
"fmt"

"github.com/google/cel-go/cel"
"github.com/google/cel-go/common"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/types"
)

type compiler struct {
env *cel.Env
info *ast.SourceInfo
src *Source
}

type compiledRule struct {
variables []*compiledVariable
matches []*compiledMatch
}

type compiledVariable struct {
name string
expr *cel.Ast
}

type compiledMatch struct {
cond *cel.Ast
output *cel.Ast
nestedRule *compiledRule
}

// Compile generates a single CEL AST from a collection of policy expressions associated with a CEL environment.
func Compile(env *cel.Env, p *Policy) (*cel.Ast, *cel.Issues) {
c := &compiler{
env: env,
info: p.SourceInfo(),
src: p.Source(),
}
errs := common.NewErrors(c.src)
iss := cel.NewIssuesWithSourceInfo(errs, c.info)
rule, ruleIss := c.compileRule(p.Rule(), c.env, iss)
iss = iss.Append(ruleIss)
if iss.Err() != nil {
return nil, iss
}
ruleRoot, _ := env.Compile("true")
opt := cel.NewStaticOptimizer(&ruleComposer{rule: rule})
ruleExprAST, iss := opt.Optimize(env, ruleRoot)
return ruleExprAST, iss.Append(iss)
TristonianJones marked this conversation as resolved.
Show resolved Hide resolved
}

func (c *compiler) compileRule(r *Rule, ruleEnv *cel.Env, iss *cel.Issues) (*compiledRule, *cel.Issues) {
var err error
compiledVars := make([]*compiledVariable, len(r.Variables()))
for i, v := range r.Variables() {
exprSrc := c.relSource(v.Expression())
varAST, exprIss := ruleEnv.CompileSource(exprSrc)
if exprIss.Err() == nil {
ruleEnv, err = ruleEnv.Extend(cel.Variable(fmt.Sprintf("variables.%s", v.Name().Value), varAST.OutputType()))
TristonianJones marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
iss.ReportErrorAtID(v.Expression().ID, "invalid variable declaration")
}
compiledVars[i] = &compiledVariable{
name: v.name.Value,
expr: varAST,
}
}
iss = iss.Append(exprIss)
}
compiledMatches := []*compiledMatch{}
for _, m := range r.Matches() {
condSrc := c.relSource(m.Condition())
condAST, condIss := ruleEnv.CompileSource(condSrc)
iss = iss.Append(condIss)
if m.HasOutput() && m.HasRule() {
iss.ReportErrorAtID(m.Condition().ID, "either output or rule may be set but not both")
TristonianJones marked this conversation as resolved.
Show resolved Hide resolved
continue
}
if m.HasOutput() {
outSrc := c.relSource(m.Output())
outAST, outIss := ruleEnv.CompileSource(outSrc)
iss = iss.Append(outIss)
compiledMatches = append(compiledMatches, &compiledMatch{
cond: condAST,
output: outAST,
})
continue
}
if m.HasRule() {
nestedRule, ruleIss := c.compileRule(m.Rule(), ruleEnv, iss)
iss = iss.Append(ruleIss)
compiledMatches = append(compiledMatches, &compiledMatch{
cond: condAST,
nestedRule: nestedRule,
})
}
}
return &compiledRule{
variables: compiledVars,
matches: compiledMatches,
}, iss
}

func (c *compiler) relSource(pstr ValueString) *RelativeSource {
line := 0
col := 1
if offset, found := c.info.GetOffsetRange(pstr.ID); found {
if loc, found := c.src.OffsetLocation(offset.Start); found {
line = loc.Line()
col = loc.Column()
}
}
return c.src.Relative(pstr.Value, line, col)
}

type ruleComposer struct {
rule *compiledRule
}

// Optimize implements an AST optimizer for CEL which composes an expression graph into a single
// expression value.
func (opt *ruleComposer) Optimize(ctx *cel.OptimizerContext, a *ast.AST) *ast.AST {
ruleExpr, _ := optimizeRule(ctx, opt.rule)
ctx.UpdateExpr(a.Expr(), ruleExpr)
TristonianJones marked this conversation as resolved.
Show resolved Hide resolved
return ctx.NewAST(ruleExpr)
}

func optimizeRule(ctx *cel.OptimizerContext, r *compiledRule) (ast.Expr, bool) {
matchExpr := ctx.NewCall("optional.none")
matches := r.matches
optionalResult := true
for i := len(matches) - 1; i >= 0; i-- {
m := matches[i]
cond := ctx.CopyASTAndMetadata(m.cond.NativeRep())
triviallyTrue := cond.Kind() == ast.LiteralKind && cond.AsLiteral() == types.True
if m.output != nil {
out := ctx.CopyASTAndMetadata(m.output.NativeRep())
if triviallyTrue {
matchExpr = out
optionalResult = false
continue
}
if optionalResult {
out = ctx.NewCall("optional.of", out)
}
matchExpr = ctx.NewCall(
operators.Conditional,
cond,
out,
matchExpr)
continue
}
nestedRule, nestedOptional := optimizeRule(ctx, m.nestedRule)
if optionalResult && !nestedOptional {
nestedRule = ctx.NewCall("optional.of", nestedRule)
}
if !optionalResult && nestedOptional {
matchExpr = ctx.NewCall("optional.of", matchExpr)
optionalResult = true
}
if !optionalResult && !nestedOptional {
ctx.ReportErrorAtID(nestedRule.ID(), "subrule early terminates policy")
continue
}
matchExpr = ctx.NewMemberCall("or", nestedRule, matchExpr)
}

vars := r.variables
for i := len(vars) - 1; i >= 0; i-- {
v := vars[i]
varAST := ctx.CopyASTAndMetadata(v.expr.NativeRep())
// Build up the bindings in reverse order, starting from root, all the way up to the outermost
// binding:
// currExpr = cel.bind(outerVar, outerExpr, currExpr)
inlined, bindMacro := ctx.NewBindMacro(matchExpr.ID(), fmt.Sprintf("variables.%s", v.name), varAST, matchExpr)
ctx.SetMacroCall(inlined.ID(), bindMacro)
matchExpr = inlined
}
return matchExpr, optionalResult
}
152 changes: 152 additions & 0 deletions policy/compiler_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package policy

import (
"fmt"
"testing"

"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/types"
)

func TestCompile(t *testing.T) {
for _, tst := range policyTests {
r := newRunner(t, tst.name, tst.envOpts...)
r.run(t)
}
}

func BenchmarkCompile(b *testing.B) {
for _, tst := range policyTests {
r := newRunner(b, tst.name, tst.envOpts...)
r.bench(b)
}
}

func newRunner(t testing.TB, name string, opts ...cel.EnvOption) *runner {
r := &runner{name: name, envOptions: opts}
r.setup(t)
return r
}

type runner struct {
name string
envOptions []cel.EnvOption
env *cel.Env
prg cel.Program
}

func (r *runner) setup(t testing.TB) {
config := readPolicyConfig(t, fmt.Sprintf("testdata/%s/config.yaml", r.name))
srcFile := readPolicy(t, fmt.Sprintf("testdata/%s/policy.yaml", r.name))
parser, err := NewParser()
if err != nil {
t.Fatalf("NewParser() failed: %v", err)
}
policy, iss := parser.Parse(srcFile)
if iss.Err() != nil {
t.Fatalf("Parse() failed: %v", iss.Err())
}
if policy.name.Value != r.name {
t.Errorf("policy name is %v, wanted %s", policy.name, r.name)
}
env, err := cel.NewEnv(
cel.OptionalTypes(),
cel.EnableMacroCallTracking(),
cel.ExtendedValidations())
if err != nil {
t.Fatalf("cel.NewEnv() failed: %v", err)
}
// Configure declarations
configOpts, err := config.AsEnvOptions(env)
if err != nil {
t.Fatalf("config.AsEnvOptions() failed: %v", err)
}
env, err = env.Extend(configOpts...)
if err != nil {
t.Fatalf("env.Extend() with config options %v, failed: %v", config, err)
}
// Configure any implementations
env, err = env.Extend(r.envOptions...)
if err != nil {
t.Fatalf("env.Extend() with config options %v, failed: %v", config, err)
}
ast, iss := Compile(env, policy)
if iss.Err() != nil {
t.Fatalf("Compile() failed: %v", iss.Err())
}
prg, err := env.Program(ast, cel.EvalOptions(cel.OptOptimize))
if err != nil {
t.Fatalf("env.Program() failed: %v", err)
}
r.env = env
r.prg = prg
}

func (r *runner) run(t *testing.T) {
tests := readTestSuite(t, fmt.Sprintf("testdata/%s/tests.yaml", r.name))
for _, s := range tests.Sections {
section := s.Name
for _, tst := range s.Tests {
tc := tst
t.Run(fmt.Sprintf("%s/%s/%s", r.name, section, tc.Name), func(t *testing.T) {
out, _, err := r.prg.Eval(tc.Input)
if err != nil {
t.Fatalf("prg.Eval(tc.Input) failed: %v", err)
}
wantExpr, iss := r.env.Compile(tc.Output)
if iss.Err() != nil {
t.Fatalf("env.Compile(%q) failed :%v", tc.Output, iss.Err())
}
testPrg, err := r.env.Program(wantExpr)
if err != nil {
t.Fatalf("env.Program(wantExpr) failed: %v", err)
}
testOut, _, err := testPrg.Eval(cel.NoVars())
if err != nil {
t.Fatalf("testPrg.Eval() failed: %v", err)
}
if optOut, ok := out.(*types.Optional); ok {
if optOut.Equal(types.OptionalNone) == types.True {
if testOut.Equal(types.OptionalNone) != types.True {
t.Errorf("policy eval got %v, wanted %v", out, testOut)
}
} else if testOut.Equal(optOut.GetValue()) != types.True {
t.Errorf("policy eval got %v, wanted %v", out, testOut)
}
}
})
}
}
}

func (r *runner) bench(b *testing.B) {
tests := readTestSuite(b, fmt.Sprintf("testdata/%s/tests.yaml", r.name))
for _, s := range tests.Sections {
section := s.Name
for _, tst := range s.Tests {
tc := tst
b.Run(fmt.Sprintf("%s/%s/%s", r.name, section, tc.Name), func(b *testing.B) {
for i := 0; i < b.N; i++ {
_, _, err := r.prg.Eval(tc.Input)
if err != nil {
b.Fatalf("policy eval failed: %v", err)
}
}
})
}
}
}