diff --git a/rego/rego.go b/rego/rego.go index 4d82bab3c8..f6edcafa29 100644 --- a/rego/rego.go +++ b/rego/rego.go @@ -2147,9 +2147,10 @@ func (r *Rego) partial(ctx context.Context, ectx *EvalContext) (*PartialQueries, var unknowns []*ast.Term - if ectx.parsedUnknowns != nil { + switch { + case ectx.parsedUnknowns != nil: unknowns = ectx.parsedUnknowns - } else if ectx.unknowns != nil { + case ectx.unknowns != nil: unknowns = make([]*ast.Term, len(ectx.unknowns)) for i := range ectx.unknowns { var err error @@ -2158,7 +2159,7 @@ func (r *Rego) partial(ctx context.Context, ectx *EvalContext) (*PartialQueries, return nil, err } } - } else { + default: // Use input document as unknown if caller has not specified any. unknowns = []*ast.Term{ast.NewTerm(ast.InputRootRef)} } diff --git a/topdown/copypropagation/copypropagation.go b/topdown/copypropagation/copypropagation.go index debae03dda..daf3bac4df 100644 --- a/topdown/copypropagation/copypropagation.go +++ b/topdown/copypropagation/copypropagation.go @@ -5,6 +5,7 @@ package copypropagation import ( + "fmt" "sort" "github.com/open-policy-agent/opa/ast" @@ -31,6 +32,18 @@ type CopyPropagator struct { sorted []ast.Var // sorted copy of vars to ensure deterministic result ensureNonEmptyBody bool compiler *ast.Compiler + localvargen *localVarGenerator +} + +type localVarGenerator struct { + next int +} + +func (l *localVarGenerator) Generate() ast.Var { + result := ast.Var(fmt.Sprintf("__localcp%d__", l.next)) + l.next++ + return result + } // New returns a new CopyPropagator that optimizes queries while preserving vars @@ -46,7 +59,7 @@ func New(livevars ast.VarSet) *CopyPropagator { return sorted[i].Compare(sorted[j]) < 0 }) - return &CopyPropagator{livevars: livevars, sorted: sorted} + return &CopyPropagator{livevars: livevars, sorted: sorted, localvargen: &localVarGenerator{}} } // WithEnsureNonEmptyBody configures p to ensure that results are always non-empty. @@ -282,12 +295,16 @@ func (t bindingPlugTransform) plugBindingsRef(pctx *plugContext, v ast.Ref) ast. // updateBindings returns false if the expression can be killed. If the // expression is killed, the binding list is updated to map a var to value. func (p *CopyPropagator) updateBindings(pctx *plugContext, expr *ast.Expr) bool { - if pctx.negated || len(expr.With) > 0 { + switch { + case pctx.negated || len(expr.With) > 0: return true - } - if expr.IsEquality() { + + case expr.IsEquality(): a, b := expr.Operand(0), expr.Operand(1) if a.Equal(b) { + if p.livevarRef(a) { + pctx.removedEqs.Put(p.localvargen.Generate(), a.Value) + } return false } k, v, keep := p.updateBindingsEq(a, b) @@ -297,7 +314,8 @@ func (p *CopyPropagator) updateBindings(pctx *plugContext, expr *ast.Expr) bool } return false } - } else if expr.IsCall() { + + case expr.IsCall(): terms := expr.Terms.([]*ast.Term) if p.compiler.GetArity(expr.Operator()) == len(terms)-2 { // with captured output output := terms[len(terms)-1] @@ -310,6 +328,21 @@ func (p *CopyPropagator) updateBindings(pctx *plugContext, expr *ast.Expr) bool return !isNoop(expr) } +func (p *CopyPropagator) livevarRef(a *ast.Term) bool { + ref, ok := a.Value.(ast.Ref) + if !ok { + return false + } + + for _, v := range p.sorted { + if ref[0].Value.Compare(v) == 0 { + return true + } + } + + return false +} + func (p *CopyPropagator) updateBindingsEq(a, b *ast.Term) (ast.Var, ast.Value, bool) { k, v, keep := p.updateBindingsEqAsymmetric(a, b) if !keep { @@ -340,8 +373,7 @@ type plugContext struct { } type binding struct { - k ast.Value - v ast.Value + k, v ast.Value } func containedIn(value ast.Value, x interface{}) bool { @@ -374,7 +406,7 @@ func sortbindings(bindings *ast.ValueMap) []*binding { return false }) sort.Slice(sorted, func(i, j int) bool { - return sorted[i].k.Compare(sorted[j].k) < 0 + return sorted[i].k.Compare(sorted[j].k) > 0 }) return sorted } @@ -397,17 +429,21 @@ func makeDisjointSets(livevars ast.VarSet, query ast.Body) (*unionFind, bool) { a, b := expr.Operand(0), expr.Operand(1) varA, ok1 := a.Value.(ast.Var) varB, ok2 := b.Value.(ast.Var) - if ok1 && ok2 { + + switch { + case ok1 && ok2: if _, ok := uf.Merge(varA, varB); !ok { return nil, false } - } else if ok1 && ast.IsConstant(b.Value) { + + case ok1 && ast.IsConstant(b.Value): root := uf.MakeSet(varA) if root.constant != nil && !root.constant.Equal(b) { return nil, false } root.constant = b - } else if ok2 && ast.IsConstant(a.Value) { + + case ok2 && ast.IsConstant(a.Value): root := uf.MakeSet(varB) if root.constant != nil && !root.constant.Equal(a) { return nil, false diff --git a/topdown/query.go b/topdown/query.go index 93f2ee3fe2..544c752e86 100644 --- a/topdown/query.go +++ b/topdown/query.go @@ -340,6 +340,14 @@ func (q *Query) PartialRun(ctx context.Context) (partials []ast.Body, support [] defer q.metrics.Timer(metrics.RegoPartialEval).Stop() livevars := ast.NewVarSet() + for _, t := range q.unknowns { + switch v := t.Value.(type) { + case ast.Var: + livevars.Add(v) + case ast.Ref: + livevars.Add(v[0].Value.(ast.Var)) + } + } ast.WalkVars(q.query, func(x ast.Var) bool { if !x.IsGenerated() { diff --git a/topdown/topdown_partial_test.go b/topdown/topdown_partial_test.go index bef79be8a9..435cf00e10 100644 --- a/topdown/topdown_partial_test.go +++ b/topdown/topdown_partial_test.go @@ -1826,6 +1826,100 @@ func TestTopDownPartialEval(t *testing.T) { ), }, }, + { + note: "copy propagation: circular reference (bug 3559)", + query: "data.test.p", + modules: []string{`package test + p { + q[_] + } + q[x] { + x = input[x] + }`, + }, + wantQueries: []string{`x_term_1_01; x_term_1_01 = input[x_term_1_01]`}, + }, + { + note: "copy propagation: circular reference (bug 3071)", + query: "data.test.p", + modules: []string{`package test + p[y] { + s := { i | input[i] } + s & set() != s + y := sprintf("%v", [s]) + }`, + }, + wantQueries: []string{`data.partial.test.p`}, + wantSupport: []string{`package partial.test + p[__local1__1] { __local0__1 = {i1 | input[i1]}; neq(and(__local0__1, set()), __local0__1); sprintf("%v", [__local0__1], __local1__1) } + `}, + }, + { + note: "copy propagation: tautology in query, input ref", + query: "input.a == input.a", + wantQueries: []string{`__localq1__ = input.a`}, + }, + { + note: "copy propagation: tautology in query, var ref, var is input", + query: "x := input; x.a == x.a", + wantQueries: []string{`__localq2__ = input.a`}, + }, + { + note: "copy propagation: tautology, input ref", + query: "data.test.p", + modules: []string{`package test + p { + input.a == input.a + }`, + }, + wantQueries: []string{`__localcp0__ = input.a`}, + }, + { + note: "copy propagation: tautology, var ref, ref is input", + query: "data.test.p", + modules: []string{`package test + p { + x := input + x.a == x.a + }`, + }, + wantQueries: []string{`__localcp0__ = input.a`}, + }, + { + note: "copy propagation: tautology, var ref, ref is unknown data", + query: "data.test.p", + unknowns: []string{"data.bar.foo"}, + modules: []string{`package test + p { + data.bar.foo.a == data.bar.foo.a + }`, + }, + wantQueries: []string{`__localcp0__ = data.bar.foo.a`}, + }, + { + note: "copy propagation: tautology, var ref, ref is input, via unknown", + // NOTE(sr): If we were having unkowns: [input.foo] and the rule body was + // input.a == input.a, we'd never reach copy-propagation -- partial eval would + // have failed before. + query: "data.test.p", + unknowns: []string{"input"}, + modules: []string{`package test + p { + input.foo.a == input.foo.a + }`, + }, + wantQueries: []string{`__localcp0__ = input.foo.a`}, + }, + { + note: "copy propagation: tautology, var ref, ref is head var", + query: "data.test.p(input)", + modules: []string{`package test + p(x) { + x.a == x.a + }`, + }, + wantQueries: []string{`__localcp1__ = input.a`}, + }, { note: "save set vars are namespaced", query: "input = x; data.test.f(1)", @@ -2985,7 +3079,12 @@ func TestTopDownPartialEval(t *testing.T) { x = true }`, }, - wantQueries: []string{"a1 = input.foo1; b1 = input.foo2; c1 = input.foo3; d1 = input.foo4; e1 = input.foo5"}, + wantQueries: []string{` + e1 = input.foo5 + d1 = input.foo4 + c1 = input.foo3 + b1 = input.foo2 + a1 = input.foo1`}, }, { note: "partial object rules not memoized", @@ -3054,34 +3153,6 @@ func TestTopDownPartialEval(t *testing.T) { shallow: true, skipPartialNamespace: true, }, - { - note: "copypropagation: circular reference (bug 3559)", - query: "data.test.p", - modules: []string{`package test - p { - q[_] - } - q[x] { - x = input[x] - }`, - }, - wantQueries: []string{`x_term_1_01; x_term_1_01 = input[x_term_1_01]`}, - }, - { - note: "copypropagation: circular reference (bug 3071)", - query: "data.test.p", - modules: []string{`package test - p[y] { - s := { i | input[i] } - s & set() != s - y := sprintf("%v", [s]) - }`, - }, - wantQueries: []string{`data.partial.test.p`}, - wantSupport: []string{`package partial.test - p[__local1__1] { __local0__1 = {i1 | input[i1]}; neq(and(__local0__1, set()), __local0__1); sprintf("%v", [__local0__1], __local1__1) } - `}, - }, { note: "every: empty domain, no unknowns", query: "data.test.p",