From 8a3bf905e451cce5da537d70dccde30e99e5ab00 Mon Sep 17 00:00:00 2001 From: Stephan Renatus Date: Wed, 29 Jun 2022 08:20:15 +0200 Subject: [PATCH] ast/compile: reorder body for safety differently (#4801) Our previous approach, ordering for closures first, and taking the growing set of output variables of the reordered body into account in a second step, did not work out for some examples: object.get(input.subject.roles[_], comp, [""], output) comp = [ 1 | true ] every y in [2] { y in output } Here, the closure of `every` would not have checked out because we've never registered the output variable `output` -- since the first call to object.get is unsafe without `comp`, too. Now, the two stages have been merged. It's got surprisingly little fallout: one test case had to be adjusted, which I believe to be a minor case, too. Fixes the second part of #4766. Signed-off-by: Stephan Renatus --- ast/compile.go | 90 +++++++++++++++++---------------------------- ast/compile_test.go | 89 ++++++++++++++++++++++++++++++-------------- 2 files changed, 95 insertions(+), 84 deletions(-) diff --git a/ast/compile.go b/ast/compile.go index 62c843be44..5417119a2f 100644 --- a/ast/compile.go +++ b/ast/compile.go @@ -3134,13 +3134,10 @@ func (vs unsafeVars) Slice() (result []unsafePair) { // contains a mapping of expressions to unsafe variables in those expressions. 
func reorderBodyForSafety(builtins map[string]*Builtin, arity func(Ref) int, globals VarSet, body Body) (Body, unsafeVars) { - body, unsafe := reorderBodyForClosures(arity, globals, body) - if len(unsafe) != 0 { - return nil, unsafe - } - - reordered := Body{} + bodyVars := body.Vars(SafetyCheckVisitorParams) + reordered := make(Body, 0, len(body)) safe := VarSet{} + unsafe := unsafeVars{} for _, e := range body { for v := range e.Vars(SafetyCheckVisitorParams) { @@ -3160,10 +3157,23 @@ func reorderBodyForSafety(builtins map[string]*Builtin, arity func(Ref) int, glo continue } - safe.Update(outputVarsForExpr(e, arity, safe)) + ovs := outputVarsForExpr(e, arity, safe) + + // check closures: is this expression closing over variables that + // haven't been made safe by what's already included in `reordered`? + vs := unsafeVarsInClosures(e, arity, safe) + cv := vs.Intersect(bodyVars).Diff(globals) + uv := cv.Diff(outputVarsForBody(reordered, arity, safe)) + + if len(uv) > 0 { + if uv.Equal(ovs) { // special case "closure-self" + continue + } + unsafe.Set(e, uv) + } for v := range unsafe[e] { - if safe.Contains(v) { + if ovs.Contains(v) || safe.Contains(v) { delete(unsafe[e], v) } } @@ -3171,10 +3181,11 @@ func reorderBodyForSafety(builtins map[string]*Builtin, arity func(Ref) int, glo if len(unsafe[e]) == 0 { delete(unsafe, e) reordered.Append(e) + safe.Update(ovs) // this expression's outputs are safe } } - if len(reordered) == n { + if len(reordered) == n { // fixed point, could not add any expr of body break } } @@ -3281,55 +3292,20 @@ func (xform *bodySafetyTransformer) reorderSetComprehensionSafety(sc *SetCompreh sc.Body = xform.reorderComprehensionSafety(sc.Term.Vars(), sc.Body) } -// reorderBodyForClosures returns a copy of the body ordered such that -// expressions (such as array comprehensions) that close over variables are ordered -// after other expressions that contain the same variable in an output position. 
-func reorderBodyForClosures(arity func(Ref) int, globals VarSet, body Body) (Body, unsafeVars) { - - reordered := Body{} - unsafe := unsafeVars{} - - for { - n := len(reordered) - - for _, e := range body { - if reordered.Contains(e) { - continue - } - - // Collect vars that are contained in closures within this - // expression. - vs := VarSet{} - WalkClosures(e, func(x interface{}) bool { - vis := &VarVisitor{vars: vs} - if ev, ok := x.(*Every); ok { - vis.Walk(ev.Body) - return true - } - vis.Walk(x) - return true - }) - - // Compute vars that are closed over from the body but not yet - // contained in the output position of an expression in the reordered - // body. These vars are considered unsafe. - cv := vs.Intersect(body.Vars(SafetyCheckVisitorParams)).Diff(globals) - uv := cv.Diff(outputVarsForBody(reordered, arity, globals)) - - if len(uv) == 0 { - reordered = append(reordered, e) - delete(unsafe, e) - } else { - unsafe.Set(e, uv) - } - } - - if len(reordered) == n { - break +// unsafeVarsInClosures collects vars that are contained in closures within +// this expression. 
+func unsafeVarsInClosures(e *Expr, arity func(Ref) int, safe VarSet) VarSet { + vs := VarSet{} + WalkClosures(e, func(x interface{}) bool { + vis := &VarVisitor{vars: vs} + if ev, ok := x.(*Every); ok { + vis.Walk(ev.Body) + return true } - } - - return reordered, unsafe + vis.Walk(x) + return true + }) + return vs } // OutputVarsFromBody returns all variables which are the "output" for diff --git a/ast/compile_test.go b/ast/compile_test.go index 0c6130b177..7ba920f03c 100644 --- a/ast/compile_test.go +++ b/ast/compile_test.go @@ -723,43 +723,78 @@ func TestCompilerCheckSafetyBodyReordering(t *testing.T) { } func TestCompilerCheckSafetyBodyReorderingClosures(t *testing.T) { - c := NewCompiler() - c.Modules = map[string]*Module{ - "mod": MustParseModule( - `package compr + opts := ParserOptions{AllFutureKeywords: true, unreleasedKeywords: true} + + tests := []struct { + note string + mod *Module + exp Body + }{ + { + note: "comprehensions-1", + mod: MustParseModule(`package compr import data.b import data.c +p = true { v = [null | true]; xs = [x | a[i] = x; a = [y | y != 1; y = c[j]]]; xs[j] > 0; z = [true | data.a.b.d.t with input as i2; i2 = i]; b[i] = j } +`), + exp: MustParseBody(`v = [null | true]; data.b[i] = j; xs = [x | a = [y | y = data.c[j]; y != 1]; a[i] = x]; xs[j] > 0; z = [true | i2 = i; data.a.b.d.t with input as i2]`), + }, + { + note: "comprehensions-2", + mod: MustParseModule(`package compr + +import data.b +import data.c +q = true { _ = [x | x = b[i]]; _ = b[j]; _ = [x | x = true; x != false]; true != false; _ = [x | data.foo[_] = x]; data.foo[_] = _ } +`), + exp: MustParseBody(`_ = [x | x = data.b[i]]; _ = data.b[j]; _ = [x | x = true; x != false]; true != false; _ = [x | data.foo[_] = x]; data.foo[_] = _`), + }, + { + note: "comprehensions-3", + mod: MustParseModule(`package compr + +import data.b +import data.c fn(x) = y { trim(x, ".", y) } +r = true { a = [x | split(y, ".", z); x = z[i]; fn("...foo.bar..", y)] } +`), + exp: MustParseBody(`a = 
[x | data.compr.fn("...foo.bar..", y); split(y, ".", z); x = z[i]]`), + }, + { + note: "closure over function output", + mod: MustParseModule(`package test +import future.keywords -p = true { v = [null | true]; xs = [x | a[i] = x; a = [y | y != 1; y = c[j]]]; xs[j] > 0; z = [true | data.a.b.d.t with input as i2; i2 = i]; b[i] = j } -q = true { _ = [x | x = b[i]]; _ = b[j]; _ = [x | x = true; x != false]; true != false; _ = [x | data.foo[_] = x]; data.foo[_] = _ } -r = true { a = [x | split(y, ".", z); x = z[i]; fn("...foo.bar..", y)] }`, - ), - } - - compileStages(c, c.checkSafetyRuleBodies) - assertNotFailed(t, c) - - result1 := c.Modules["mod"].Rules[1].Body - expected1 := MustParseBody(`v = [null | true]; data.b[i] = j; xs = [x | a = [y | y = data.c[j]; y != 1]; a[i] = x]; z = [true | i2 = i; data.a.b.d.t with input as i2]; xs[j] > 0`) - if !result1.Equal(expected1) { - t.Errorf("Expected reordered body to be equal to:\n%v\nBut got:\n%v", expected1, result1) +p { + object.get(input.subject.roles[_], comp, [""], output) + comp = [ 1 | true ] + every y in [2] { + y in output } - - result2 := c.Modules["mod"].Rules[2].Body - expected2 := MustParseBody(`_ = [x | x = data.b[i]]; _ = data.b[j]; _ = [x | x = true; x != false]; true != false; _ = [x | data.foo[_] = x]; data.foo[_] = _`) - if !result2.Equal(expected2) { - t.Errorf("Expected pre-ordered body to equal:\n%v\nBut got:\n%v", expected2, result2) +}`), + exp: MustParseBodyWithOpts(`comp = [1 | true] + __local2__ = [2] + object.get(input.subject.roles[_], comp, [""], output) + every __local0__, __local1__ in __local2__ { internal.member_2(__local1__, output) }`, opts), + }, } - result3 := c.Modules["mod"].Rules[3].Body - expected3 := MustParseBody(`a = [x | data.compr.fn("...foo.bar..", y); split(y, ".", z); x = z[i]]`) - if !result3.Equal(expected3) { - t.Errorf("Expected pre-ordered body to equal:\n%v\nBut got:\n%v", expected3, result3) + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + c := 
NewCompiler() + c.Modules = map[string]*Module{"mod": tc.mod} + compileStages(c, c.checkSafetyRuleBodies) + assertNotFailed(t, c) + last := len(c.Modules["mod"].Rules) - 1 + actual := c.Modules["mod"].Rules[last].Body + if !actual.Equal(tc.exp) { + t.Errorf("Expected reordered body to be equal to:\n%v\nBut got:\n%v", tc.exp, actual) + } + }) } } @@ -797,7 +832,7 @@ func TestCompilerCheckSafetyBodyErrors(t *testing.T) { {"array-compr-mixed", `p { _ = [x | y = [a | a = z[i]]] }`, `{a, x, z, i}`}, {"array-compr-builtin", `p { [true | eq != 2] }`, `{eq,}`}, {"closure-self", `p { x = [x | x = 1] }`, `{x,}`}, - {"closure-transitive", `p { x = y; x = [y | y = 1] }`, `{y,}`}, + {"closure-transitive", `p { x = y; x = [y | y = 1] }`, `{x,y}`}, {"nested", `p { count(baz[i].attr[bar[dead.beef]], n) }`, `{dead,}`}, {"negated-import", `p { not foo; not bar; not baz }`, `set()`}, {"rewritten", `p[{"foo": dead[i]}] { true }`, `{dead, i}`},