Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

topdown/copypropagation: keep refs into livevars #4936

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 4 additions & 3 deletions rego/rego.go
Expand Up @@ -2147,9 +2147,10 @@ func (r *Rego) partial(ctx context.Context, ectx *EvalContext) (*PartialQueries,

var unknowns []*ast.Term

if ectx.parsedUnknowns != nil {
switch {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style preference: I like condition-less switch statements.

case ectx.parsedUnknowns != nil:
unknowns = ectx.parsedUnknowns
} else if ectx.unknowns != nil {
case ectx.unknowns != nil:
unknowns = make([]*ast.Term, len(ectx.unknowns))
for i := range ectx.unknowns {
var err error
Expand All @@ -2158,7 +2159,7 @@ func (r *Rego) partial(ctx context.Context, ectx *EvalContext) (*PartialQueries,
return nil, err
}
}
} else {
default:
// Use input document as unknown if caller has not specified any.
unknowns = []*ast.Term{ast.NewTerm(ast.InputRootRef)}
}
Expand Down
58 changes: 47 additions & 11 deletions topdown/copypropagation/copypropagation.go
Expand Up @@ -5,6 +5,7 @@
package copypropagation

import (
"fmt"
"sort"

"github.com/open-policy-agent/opa/ast"
Expand All @@ -31,6 +32,18 @@ type CopyPropagator struct {
sorted []ast.Var // sorted copy of vars to ensure deterministic result
ensureNonEmptyBody bool
compiler *ast.Compiler
localvargen *localVarGenerator
}

type localVarGenerator struct {
next int
}

func (l *localVarGenerator) Generate() ast.Var {
result := ast.Var(fmt.Sprintf("__localcp%d__", l.next))
l.next++
return result

}

// New returns a new CopyPropagator that optimizes queries while preserving vars
Expand All @@ -46,7 +59,7 @@ func New(livevars ast.VarSet) *CopyPropagator {
return sorted[i].Compare(sorted[j]) < 0
})

return &CopyPropagator{livevars: livevars, sorted: sorted}
return &CopyPropagator{livevars: livevars, sorted: sorted, localvargen: &localVarGenerator{}}
}

// WithEnsureNonEmptyBody configures p to ensure that results are always non-empty.
Expand Down Expand Up @@ -282,12 +295,16 @@ func (t bindingPlugTransform) plugBindingsRef(pctx *plugContext, v ast.Ref) ast.
// updateBindings returns false if the expression can be killed. If the
// expression is killed, the binding list is updated to map a var to value.
func (p *CopyPropagator) updateBindings(pctx *plugContext, expr *ast.Expr) bool {
if pctx.negated || len(expr.With) > 0 {
switch {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(just style pref)

case pctx.negated || len(expr.With) > 0:
return true
}
if expr.IsEquality() {

case expr.IsEquality():
a, b := expr.Operand(0), expr.Operand(1)
if a.Equal(b) {
if p.livevarRef(a) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚡ This is the new thing: they might be equal, but they can still be relevant!

pctx.removedEqs.Put(p.localvargen.Generate(), a.Value)
}
return false
}
k, v, keep := p.updateBindingsEq(a, b)
Expand All @@ -297,7 +314,8 @@ func (p *CopyPropagator) updateBindings(pctx *plugContext, expr *ast.Expr) bool
}
return false
}
} else if expr.IsCall() {

case expr.IsCall():
terms := expr.Terms.([]*ast.Term)
if p.compiler.GetArity(expr.Operator()) == len(terms)-2 { // with captured output
output := terms[len(terms)-1]
Expand All @@ -310,6 +328,21 @@ func (p *CopyPropagator) updateBindings(pctx *plugContext, expr *ast.Expr) bool
return !isNoop(expr)
}

func (p *CopyPropagator) livevarRef(a *ast.Term) bool {
ref, ok := a.Value.(ast.Ref)
if !ok {
return false
}

for _, v := range p.sorted {
if ref[0].Value.Compare(v) == 0 {
return true
}
}

return false
}

func (p *CopyPropagator) updateBindingsEq(a, b *ast.Term) (ast.Var, ast.Value, bool) {
k, v, keep := p.updateBindingsEqAsymmetric(a, b)
if !keep {
Expand Down Expand Up @@ -340,8 +373,7 @@ type plugContext struct {
}

type binding struct {
k ast.Value
v ast.Value
k, v ast.Value
}

func containedIn(value ast.Value, x interface{}) bool {
Expand Down Expand Up @@ -374,7 +406,7 @@ func sortbindings(bindings *ast.ValueMap) []*binding {
return false
})
sort.Slice(sorted, func(i, j int) bool {
return sorted[i].k.Compare(sorted[j].k) < 0
return sorted[i].k.Compare(sorted[j].k) > 0
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔍 changing the ordering here let's us avoid adding __local0__ = input to the body when we already have __localcp0__ = input.a.

})
return sorted
}
Expand All @@ -397,17 +429,21 @@ func makeDisjointSets(livevars ast.VarSet, query ast.Body) (*unionFind, bool) {
a, b := expr.Operand(0), expr.Operand(1)
varA, ok1 := a.Value.(ast.Var)
varB, ok2 := b.Value.(ast.Var)
if ok1 && ok2 {

switch {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(just style pref)

case ok1 && ok2:
if _, ok := uf.Merge(varA, varB); !ok {
return nil, false
}
} else if ok1 && ast.IsConstant(b.Value) {

case ok1 && ast.IsConstant(b.Value):
root := uf.MakeSet(varA)
if root.constant != nil && !root.constant.Equal(b) {
return nil, false
}
root.constant = b
} else if ok2 && ast.IsConstant(a.Value) {

case ok2 && ast.IsConstant(a.Value):
root := uf.MakeSet(varB)
if root.constant != nil && !root.constant.Equal(a) {
return nil, false
Expand Down
8 changes: 8 additions & 0 deletions topdown/query.go
Expand Up @@ -340,6 +340,14 @@ func (q *Query) PartialRun(ctx context.Context) (partials []ast.Body, support []
defer q.metrics.Timer(metrics.RegoPartialEval).Stop()

livevars := ast.NewVarSet()
for _, t := range q.unknowns {
switch v := t.Value.(type) {
case ast.Var:
livevars.Add(v)
case ast.Ref:
livevars.Add(v[0].Value.(ast.Var))
}
}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Before, we hadn't added input or data if they're part of the unknowns to the livevars. This felt a bit odd to me -- what if input.foo is unknown, but input.bar == input.bar is in the query to-be-copy-propagated? Turns out that's not possible -- we're only running CP after a successful PE run, as a final step. So, if we've made it to CP, and there's a tautology to check (input.x == input.x), checking its first term is good enough.


ast.WalkVars(q.query, func(x ast.Var) bool {
if !x.IsGenerated() {
Expand Down
129 changes: 100 additions & 29 deletions topdown/topdown_partial_test.go
Expand Up @@ -1826,6 +1826,100 @@ func TestTopDownPartialEval(t *testing.T) {
),
},
},
{
note: "copy propagation: circular reference (bug 3559)",
query: "data.test.p",
modules: []string{`package test
p {
q[_]
}
q[x] {
x = input[x]
}`,
},
wantQueries: []string{`x_term_1_01; x_term_1_01 = input[x_term_1_01]`},
},
{
note: "copy propagation: circular reference (bug 3071)",
query: "data.test.p",
modules: []string{`package test
p[y] {
s := { i | input[i] }
s & set() != s
y := sprintf("%v", [s])
}`,
},
wantQueries: []string{`data.partial.test.p`},
wantSupport: []string{`package partial.test
p[__local1__1] { __local0__1 = {i1 | input[i1]}; neq(and(__local0__1, set()), __local0__1); sprintf("%v", [__local0__1], __local1__1) }
`},
},
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

☝️ these two are just moved to keep the CP-related tests together.

{
note: "copy propagation: tautology in query, input ref",
query: "input.a == input.a",
wantQueries: []string{`__localq1__ = input.a`},
},
{
note: "copy propagation: tautology in query, var ref, var is input",
query: "x := input; x.a == x.a",
wantQueries: []string{`__localq2__ = input.a`},
},
{
note: "copy propagation: tautology, input ref",
query: "data.test.p",
modules: []string{`package test
p {
input.a == input.a
}`,
},
wantQueries: []string{`__localcp0__ = input.a`},
},
{
note: "copy propagation: tautology, var ref, ref is input",
query: "data.test.p",
modules: []string{`package test
p {
x := input
x.a == x.a
}`,
},
wantQueries: []string{`__localcp0__ = input.a`},
},
{
note: "copy propagation: tautology, var ref, ref is unknown data",
query: "data.test.p",
unknowns: []string{"data.bar.foo"},
modules: []string{`package test
p {
data.bar.foo.a == data.bar.foo.a
}`,
},
wantQueries: []string{`__localcp0__ = data.bar.foo.a`},
},
{
note: "copy propagation: tautology, var ref, ref is input, via unknown",
// NOTE(sr): If we were having unkowns: [input.foo] and the rule body was
// input.a == input.a, we'd never reach copy-propagation -- partial eval would
// have failed before.
query: "data.test.p",
unknowns: []string{"input"},
modules: []string{`package test
p {
input.foo.a == input.foo.a
}`,
},
wantQueries: []string{`__localcp0__ = input.foo.a`},
},
{
note: "copy propagation: tautology, var ref, ref is head var",
query: "data.test.p(input)",
modules: []string{`package test
p(x) {
x.a == x.a
}`,
},
wantQueries: []string{`__localcp1__ = input.a`},
},
{
note: "save set vars are namespaced",
query: "input = x; data.test.f(1)",
Expand Down Expand Up @@ -2985,7 +3079,12 @@ func TestTopDownPartialEval(t *testing.T) {
x = true
}`,
},
wantQueries: []string{"a1 = input.foo1; b1 = input.foo2; c1 = input.foo3; d1 = input.foo4; e1 = input.foo5"},
wantQueries: []string{`
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's the only impact of the ordering change in our test suite. 🙃

e1 = input.foo5
d1 = input.foo4
c1 = input.foo3
b1 = input.foo2
a1 = input.foo1`},
},
{
note: "partial object rules not memoized",
Expand Down Expand Up @@ -3054,34 +3153,6 @@ func TestTopDownPartialEval(t *testing.T) {
shallow: true,
skipPartialNamespace: true,
},
{
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've moved these to the other block of copy-prop related tests.

note: "copypropagation: circular reference (bug 3559)",
query: "data.test.p",
modules: []string{`package test
p {
q[_]
}
q[x] {
x = input[x]
}`,
},
wantQueries: []string{`x_term_1_01; x_term_1_01 = input[x_term_1_01]`},
},
{
note: "copypropagation: circular reference (bug 3071)",
query: "data.test.p",
modules: []string{`package test
p[y] {
s := { i | input[i] }
s & set() != s
y := sprintf("%v", [s])
}`,
},
wantQueries: []string{`data.partial.test.p`},
wantSupport: []string{`package partial.test
p[__local1__1] { __local0__1 = {i1 | input[i1]}; neq(and(__local0__1, set()), __local0__1); sprintf("%v", [__local0__1], __local1__1) }
`},
},
{
note: "every: empty domain, no unknowns",
query: "data.test.p",
Expand Down