From 2465f3498bb42bf8fa7bb54c32fe6bf788f690d7 Mon Sep 17 00:00:00 2001 From: Stephan Renatus Date: Mon, 25 Jul 2022 14:27:28 +0200 Subject: [PATCH] topdown/copypropagation: keep refs into livevars Before, a query of input.a == input.a would not survive copypropagation. With this change, it'll be recorded as removedEq, and subsequent processing steps ensure that it's kept in the body. Changing the sort order in sortBindings allows us to limit the unnecessary variable bindings: with the previous ordering, we'd get __local0__1 = input; __localcp0__ = input.a for the query `x := input; input.a == input.a`. Sorting the other way, we'll process `__localcp0__ = input.a` first, add it to the body, and when we check `__local0__1 = input`, we find that `input` is already contained in the body, and is thus not needed. Fixes #4848. Signed-off-by: Stephan Renatus --- rego/rego.go | 7 +- topdown/copypropagation/copypropagation.go | 58 +++++++-- topdown/query.go | 8 ++ topdown/topdown_partial_test.go | 129 ++++++++++++++++----- 4 files changed, 159 insertions(+), 43 deletions(-) diff --git a/rego/rego.go b/rego/rego.go index 4d82bab3c8..f6edcafa29 100644 --- a/rego/rego.go +++ b/rego/rego.go @@ -2147,9 +2147,10 @@ func (r *Rego) partial(ctx context.Context, ectx *EvalContext) (*PartialQueries, var unknowns []*ast.Term - if ectx.parsedUnknowns != nil { + switch { + case ectx.parsedUnknowns != nil: unknowns = ectx.parsedUnknowns - } else if ectx.unknowns != nil { + case ectx.unknowns != nil: unknowns = make([]*ast.Term, len(ectx.unknowns)) for i := range ectx.unknowns { var err error @@ -2158,7 +2159,7 @@ func (r *Rego) partial(ctx context.Context, ectx *EvalContext) (*PartialQueries, return nil, err } } - } else { + default: // Use input document as unknown if caller has not specified any. unknowns = []*ast.Term{ast.NewTerm(ast.InputRootRef)} } diff --git a/topdown/copypropagation/copypropagation.go b/topdown/copypropagation/copypropagation.go index debae03dda..daf3bac4df 100644 --- a/topdown/copypropagation/copypropagation.go +++ b/topdown/copypropagation/copypropagation.go @@ -5,6 +5,7 @@ package copypropagation import ( + "fmt" "sort" "github.com/open-policy-agent/opa/ast" @@ -31,6 +32,18 @@ type CopyPropagator struct { sorted []ast.Var // sorted copy of vars to ensure deterministic result ensureNonEmptyBody bool compiler *ast.Compiler + localvargen *localVarGenerator +} + +type localVarGenerator struct { + next int +} + +func (l *localVarGenerator) Generate() ast.Var { + result := ast.Var(fmt.Sprintf("__localcp%d__", l.next)) + l.next++ + return result + } // New returns a new CopyPropagator that optimizes queries while preserving vars @@ -46,7 +59,7 @@ func New(livevars ast.VarSet) *CopyPropagator { return sorted[i].Compare(sorted[j]) < 0 }) - return &CopyPropagator{livevars: livevars, sorted: sorted} + return &CopyPropagator{livevars: livevars, sorted: sorted, localvargen: &localVarGenerator{}} } // WithEnsureNonEmptyBody configures p to ensure that results are always non-empty. @@ -282,12 +295,16 @@ func (t bindingPlugTransform) plugBindingsRef(pctx *plugContext, v ast.Ref) ast. // updateBindings returns false if the expression can be killed. If the // expression is killed, the binding list is updated to map a var to value. func (p *CopyPropagator) updateBindings(pctx *plugContext, expr *ast.Expr) bool { - if pctx.negated || len(expr.With) > 0 { + switch { + case pctx.negated || len(expr.With) > 0: return true - } - if expr.IsEquality() { + + case expr.IsEquality(): a, b := expr.Operand(0), expr.Operand(1) if a.Equal(b) { + if p.livevarRef(a) { + pctx.removedEqs.Put(p.localvargen.Generate(), a.Value) + } return false } k, v, keep := p.updateBindingsEq(a, b) @@ -297,7 +314,8 @@ func (p *CopyPropagator) updateBindings(pctx *plugContext, expr *ast.Expr) bool } return false } - } else if expr.IsCall() { + + case expr.IsCall(): terms := expr.Terms.([]*ast.Term) if p.compiler.GetArity(expr.Operator()) == len(terms)-2 { // with captured output output := terms[len(terms)-1] @@ -310,6 +328,21 @@ func (p *CopyPropagator) updateBindings(pctx *plugContext, expr *ast.Expr) bool return !isNoop(expr) } +func (p *CopyPropagator) livevarRef(a *ast.Term) bool { + ref, ok := a.Value.(ast.Ref) + if !ok { + return false + } + + for _, v := range p.sorted { + if ref[0].Value.Compare(v) == 0 { + return true + } + } + + return false +} + func (p *CopyPropagator) updateBindingsEq(a, b *ast.Term) (ast.Var, ast.Value, bool) { k, v, keep := p.updateBindingsEqAsymmetric(a, b) if !keep { @@ -340,8 +373,7 @@ type plugContext struct { } type binding struct { - k ast.Value - v ast.Value + k, v ast.Value } func containedIn(value ast.Value, x interface{}) bool { @@ -374,7 +406,7 @@ func sortbindings(bindings *ast.ValueMap) []*binding { return false }) sort.Slice(sorted, func(i, j int) bool { - return sorted[i].k.Compare(sorted[j].k) < 0 + return sorted[i].k.Compare(sorted[j].k) > 0 }) return sorted } @@ -397,17 +429,21 @@ func makeDisjointSets(livevars ast.VarSet, query ast.Body) (*unionFind, bool) { a, b := expr.Operand(0), expr.Operand(1) varA, ok1 := a.Value.(ast.Var) varB, ok2 := b.Value.(ast.Var) - if ok1 && ok2 { + + switch { + case ok1 && ok2: if _, ok := uf.Merge(varA, varB); !ok { return nil, false } - } else if ok1 && ast.IsConstant(b.Value) { + + case ok1 && ast.IsConstant(b.Value): root := uf.MakeSet(varA) if root.constant != nil && !root.constant.Equal(b) { return nil, false } root.constant = b - } else if ok2 && ast.IsConstant(a.Value) { + + case ok2 && ast.IsConstant(a.Value): root := uf.MakeSet(varB) if root.constant != nil && !root.constant.Equal(a) { return nil, false diff --git a/topdown/query.go b/topdown/query.go index 93f2ee3fe2..544c752e86 100644 --- a/topdown/query.go +++ b/topdown/query.go @@ -340,6 +340,14 @@ func (q *Query) PartialRun(ctx context.Context) (partials []ast.Body, support [] defer q.metrics.Timer(metrics.RegoPartialEval).Stop() livevars := ast.NewVarSet() + for _, t := range q.unknowns { + switch v := t.Value.(type) { + case ast.Var: + livevars.Add(v) + case ast.Ref: + livevars.Add(v[0].Value.(ast.Var)) + } + } ast.WalkVars(q.query, func(x ast.Var) bool { if !x.IsGenerated() { diff --git a/topdown/topdown_partial_test.go b/topdown/topdown_partial_test.go index bef79be8a9..435cf00e10 100644 --- a/topdown/topdown_partial_test.go +++ b/topdown/topdown_partial_test.go @@ -1826,6 +1826,100 @@ func TestTopDownPartialEval(t *testing.T) { ), }, }, + { + note: "copy propagation: circular reference (bug 3559)", + query: "data.test.p", + modules: []string{`package test + p { + q[_] + } + q[x] { + x = input[x] + }`, + }, + wantQueries: []string{`x_term_1_01; x_term_1_01 = input[x_term_1_01]`}, + }, + { + note: "copy propagation: circular reference (bug 3071)", + query: "data.test.p", + modules: []string{`package test + p[y] { + s := { i | input[i] } + s & set() != s + y := sprintf("%v", [s]) + }`, + }, + wantQueries: []string{`data.partial.test.p`}, + wantSupport: []string{`package partial.test + p[__local1__1] { __local0__1 = {i1 | input[i1]}; neq(and(__local0__1, set()), __local0__1); sprintf("%v", [__local0__1], __local1__1) } + `}, + }, + { + note: "copy propagation: tautology in query, input ref", + query: "input.a == input.a", + wantQueries: []string{`__localq1__ = input.a`}, + }, + { + note: "copy propagation: tautology in query, var ref, var is input", + query: "x := input; x.a == x.a", + wantQueries: []string{`__localq2__ = input.a`}, + }, + { + note: "copy propagation: tautology, input ref", + query: "data.test.p", + modules: []string{`package test + p { + input.a == input.a + }`, + }, + wantQueries: []string{`__localcp0__ = input.a`}, + }, + { + note: "copy propagation: tautology, var ref, ref is input", + query: "data.test.p", + modules: []string{`package test + p { + x := input + x.a == x.a + }`, + }, + wantQueries: []string{`__localcp0__ = input.a`}, + }, + { + note: "copy propagation: tautology, var ref, ref is unknown data", + query: "data.test.p", + unknowns: []string{"data.bar.foo"}, + modules: []string{`package test + p { + data.bar.foo.a == data.bar.foo.a + }`, + }, + wantQueries: []string{`__localcp0__ = data.bar.foo.a`}, + }, + { + note: "copy propagation: tautology, var ref, ref is input, via unknown", + // NOTE(sr): If we were having unkowns: [input.foo] and the rule body was + // input.a == input.a, we'd never reach copy-propagation -- partial eval would + // have failed before. + query: "data.test.p", + unknowns: []string{"input"}, + modules: []string{`package test + p { + input.foo.a == input.foo.a + }`, + }, + wantQueries: []string{`__localcp0__ = input.foo.a`}, + }, + { + note: "copy propagation: tautology, var ref, ref is head var", + query: "data.test.p(input)", + modules: []string{`package test + p(x) { + x.a == x.a + }`, + }, + wantQueries: []string{`__localcp1__ = input.a`}, + }, { note: "save set vars are namespaced", query: "input = x; data.test.f(1)", @@ -2985,7 +3079,12 @@ func TestTopDownPartialEval(t *testing.T) { x = true }`, }, - wantQueries: []string{"a1 = input.foo1; b1 = input.foo2; c1 = input.foo3; d1 = input.foo4; e1 = input.foo5"}, + wantQueries: []string{` + e1 = input.foo5 + d1 = input.foo4 + c1 = input.foo3 + b1 = input.foo2 + a1 = input.foo1`}, }, { note: "partial object rules not memoized", @@ -3054,34 +3153,6 @@ func TestTopDownPartialEval(t *testing.T) { shallow: true, skipPartialNamespace: true, }, - { - note: "copypropagation: circular reference (bug 3559)", - query: "data.test.p", - modules: []string{`package test - p { - q[_] - } - q[x] { - x = input[x] - }`, - }, - wantQueries: []string{`x_term_1_01; x_term_1_01 = input[x_term_1_01]`}, - }, - { - note: "copypropagation: circular reference (bug 3071)", - query: "data.test.p", - modules: []string{`package test - p[y] { - s := { i | input[i] } - s & set() != s - y := sprintf("%v", [s]) - }`, - }, - wantQueries: []string{`data.partial.test.p`}, - wantSupport: []string{`package partial.test - p[__local1__1] { __local0__1 = {i1 | input[i1]}; neq(and(__local0__1, set()), __local0__1); sprintf("%v", [__local0__1], __local1__1) } - `}, - }, { note: "every: empty domain, no unknowns", query: "data.test.p",