Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ast/compile: reorder body for safety differently #4801

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
90 changes: 33 additions & 57 deletions ast/compile.go
Expand Up @@ -3134,13 +3134,10 @@ func (vs unsafeVars) Slice() (result []unsafePair) {
// contains a mapping of expressions to unsafe variables in those expressions.
func reorderBodyForSafety(builtins map[string]*Builtin, arity func(Ref) int, globals VarSet, body Body) (Body, unsafeVars) {

body, unsafe := reorderBodyForClosures(arity, globals, body)
if len(unsafe) != 0 {
return nil, unsafe
}

reordered := Body{}
bodyVars := body.Vars(SafetyCheckVisitorParams)
reordered := make(Body, 0, len(body))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Micro-optimizations of little consequence:

  1. bodyVars never changes, so we can collect it once, outside the loop (reorderBodyForClosures recomputed body.Vars on every iteration).
  2. We know how large reordered can become (and usually will, if there are no unsafe variables), so we allocate it upfront.

safe := VarSet{}
unsafe := unsafeVars{}

for _, e := range body {
for v := range e.Vars(SafetyCheckVisitorParams) {
Expand All @@ -3160,21 +3157,35 @@ func reorderBodyForSafety(builtins map[string]*Builtin, arity func(Ref) int, glo
continue
}

safe.Update(outputVarsForExpr(e, arity, safe))
ovs := outputVarsForExpr(e, arity, safe)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are the output variables we were missing in reorderBodyForClosures: it kept passing globals as the third argument instead of the accumulated safe vars, so it never saw the output variables of function calls whose inputs only become safe after reordering.


// check closures: is this expression closing over variables that
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is basically the inlining of reorderBodyForClosures.

// haven't been made safe by what's already included in `reordered`?
vs := unsafeVarsInClosures(e, arity, safe)
cv := vs.Intersect(bodyVars).Diff(globals)
uv := cv.Diff(outputVarsForBody(reordered, arity, safe))

if len(uv) > 0 {
if uv.Equal(ovs) { // special case "closure-self"
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This one is ugly, but it's what makes the "closure-self" case, x = [ x | x = 1 ], fail without affecting the other test cases 🙈

continue
}
unsafe.Set(e, uv)
}

for v := range unsafe[e] {
if safe.Contains(v) {
if ovs.Contains(v) || safe.Contains(v) {
delete(unsafe[e], v)
}
}

if len(unsafe[e]) == 0 {
delete(unsafe, e)
reordered.Append(e)
safe.Update(ovs) // this expression's outputs are safe
}
}

if len(reordered) == n {
if len(reordered) == n { // fixed point, could not add any expr of body
break
}
}
Expand Down Expand Up @@ -3281,55 +3292,20 @@ func (xform *bodySafetyTransformer) reorderSetComprehensionSafety(sc *SetCompreh
sc.Body = xform.reorderComprehensionSafety(sc.Term.Vars(), sc.Body)
}

// reorderBodyForClosures returns a copy of the body ordered such that
// expressions (such as array comprehensions) that close over variables are ordered
// after other expressions that contain the same variable in an output position.
func reorderBodyForClosures(arity func(Ref) int, globals VarSet, body Body) (Body, unsafeVars) {

reordered := Body{}
unsafe := unsafeVars{}

for {
n := len(reordered)

for _, e := range body {
if reordered.Contains(e) {
continue
}

// Collect vars that are contained in closures within this
// expression.
vs := VarSet{}
WalkClosures(e, func(x interface{}) bool {
vis := &VarVisitor{vars: vs}
if ev, ok := x.(*Every); ok {
vis.Walk(ev.Body)
return true
}
vis.Walk(x)
return true
})

// Compute vars that are closed over from the body but not yet
// contained in the output position of an expression in the reordered
// body. These vars are considered unsafe.
cv := vs.Intersect(body.Vars(SafetyCheckVisitorParams)).Diff(globals)
uv := cv.Diff(outputVarsForBody(reordered, arity, globals))

if len(uv) == 0 {
reordered = append(reordered, e)
delete(unsafe, e)
} else {
unsafe.Set(e, uv)
}
}

if len(reordered) == n {
break
// unsafeVarsInClosures collects vars that are contained in closures within
// this expression.
func unsafeVarsInClosures(e *Expr, arity func(Ref) int, safe VarSet) VarSet {
vs := VarSet{}
WalkClosures(e, func(x interface{}) bool {
vis := &VarVisitor{vars: vs}
if ev, ok := x.(*Every); ok {
vis.Walk(ev.Body)
return true
}
}

return reordered, unsafe
vis.Walk(x)
return true
})
return vs
}

// OutputVarsFromBody returns all variables which are the "output" for
Expand Down
89 changes: 62 additions & 27 deletions ast/compile_test.go
Expand Up @@ -723,43 +723,78 @@ func TestCompilerCheckSafetyBodyReordering(t *testing.T) {
}

func TestCompilerCheckSafetyBodyReorderingClosures(t *testing.T) {
c := NewCompiler()
c.Modules = map[string]*Module{
"mod": MustParseModule(
`package compr
opts := ParserOptions{AllFutureKeywords: true, unreleasedKeywords: true}

tests := []struct {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"Modernized" the test case structure here into table-driven subtests, and added the one case we care about.

note string
mod *Module
exp Body
}{
{
note: "comprehensions-1",
mod: MustParseModule(`package compr

import data.b
import data.c
p = true { v = [null | true]; xs = [x | a[i] = x; a = [y | y != 1; y = c[j]]]; xs[j] > 0; z = [true | data.a.b.d.t with input as i2; i2 = i]; b[i] = j }
`),
exp: MustParseBody(`v = [null | true]; data.b[i] = j; xs = [x | a = [y | y = data.c[j]; y != 1]; a[i] = x]; xs[j] > 0; z = [true | i2 = i; data.a.b.d.t with input as i2]`),
},
{
note: "comprehensions-2",
mod: MustParseModule(`package compr

import data.b
import data.c
q = true { _ = [x | x = b[i]]; _ = b[j]; _ = [x | x = true; x != false]; true != false; _ = [x | data.foo[_] = x]; data.foo[_] = _ }
`),
exp: MustParseBody(`_ = [x | x = data.b[i]]; _ = data.b[j]; _ = [x | x = true; x != false]; true != false; _ = [x | data.foo[_] = x]; data.foo[_] = _`),
},

{
note: "comprehensions-3",
mod: MustParseModule(`package compr

import data.b
import data.c
fn(x) = y {
trim(x, ".", y)
}
r = true { a = [x | split(y, ".", z); x = z[i]; fn("...foo.bar..", y)] }
`),
exp: MustParseBody(`a = [x | data.compr.fn("...foo.bar..", y); split(y, ".", z); x = z[i]]`),
},
{
note: "closure over function output",
mod: MustParseModule(`package test
import future.keywords

p = true { v = [null | true]; xs = [x | a[i] = x; a = [y | y != 1; y = c[j]]]; xs[j] > 0; z = [true | data.a.b.d.t with input as i2; i2 = i]; b[i] = j }
q = true { _ = [x | x = b[i]]; _ = b[j]; _ = [x | x = true; x != false]; true != false; _ = [x | data.foo[_] = x]; data.foo[_] = _ }
r = true { a = [x | split(y, ".", z); x = z[i]; fn("...foo.bar..", y)] }`,
),
}

compileStages(c, c.checkSafetyRuleBodies)
assertNotFailed(t, c)

result1 := c.Modules["mod"].Rules[1].Body
expected1 := MustParseBody(`v = [null | true]; data.b[i] = j; xs = [x | a = [y | y = data.c[j]; y != 1]; a[i] = x]; z = [true | i2 = i; data.a.b.d.t with input as i2]; xs[j] > 0`)
if !result1.Equal(expected1) {
t.Errorf("Expected reordered body to be equal to:\n%v\nBut got:\n%v", expected1, result1)
p {
object.get(input.subject.roles[_], comp, [""], output)
comp = [ 1 | true ]
every y in [2] {
y in output
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the case that broke PE (partial evaluation) for #4766.

}

result2 := c.Modules["mod"].Rules[2].Body
expected2 := MustParseBody(`_ = [x | x = data.b[i]]; _ = data.b[j]; _ = [x | x = true; x != false]; true != false; _ = [x | data.foo[_] = x]; data.foo[_] = _`)
if !result2.Equal(expected2) {
t.Errorf("Expected pre-ordered body to equal:\n%v\nBut got:\n%v", expected2, result2)
}`),
exp: MustParseBodyWithOpts(`comp = [1 | true]
__local2__ = [2]
object.get(input.subject.roles[_], comp, [""], output)
every __local0__, __local1__ in __local2__ { internal.member_2(__local1__, output) }`, opts),
},
}

result3 := c.Modules["mod"].Rules[3].Body
expected3 := MustParseBody(`a = [x | data.compr.fn("...foo.bar..", y); split(y, ".", z); x = z[i]]`)
if !result3.Equal(expected3) {
t.Errorf("Expected pre-ordered body to equal:\n%v\nBut got:\n%v", expected3, result3)
for _, tc := range tests {
t.Run(tc.note, func(t *testing.T) {
c := NewCompiler()
c.Modules = map[string]*Module{"mod": tc.mod}
compileStages(c, c.checkSafetyRuleBodies)
assertNotFailed(t, c)
last := len(c.Modules["mod"].Rules) - 1
actual := c.Modules["mod"].Rules[last].Body
if !actual.Equal(tc.exp) {
t.Errorf("Expected reordered body to be equal to:\n%v\nBut got:\n%v", tc.exp, actual)
}
})
}
}

Expand Down Expand Up @@ -797,7 +832,7 @@ func TestCompilerCheckSafetyBodyErrors(t *testing.T) {
{"array-compr-mixed", `p { _ = [x | y = [a | a = z[i]]] }`, `{a, x, z, i}`},
{"array-compr-builtin", `p { [true | eq != 2] }`, `{eq,}`},
{"closure-self", `p { x = [x | x = 1] }`, `{x,}`},
{"closure-transitive", `p { x = y; x = [y | y = 1] }`, `{y,}`},
{"closure-transitive", `p { x = y; x = [y | y = 1] }`, `{x,y}`},
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we accept that? I hope so 😅

{"nested", `p { count(baz[i].attr[bar[dead.beef]], n) }`, `{dead,}`},
{"negated-import", `p { not foo; not bar; not baz }`, `set()`},
{"rewritten", `p[{"foo": dead[i]}] { true }`, `{dead, i}`},
Expand Down