Skip to content

Commit

Permalink
topdown/eval: fix 'every' term plugging on save (#4775)
Browse files Browse the repository at this point in the history
Previously missing plugging could cause unsafe variables in the
PE output.

Now, all terms in the 'every' body should be plugged properly.
The approach taken here is to plug them all, and then fix the
key and val var names of the copied every expression. Those vars
are fresh after the compiler is done with the expression, so
plugging them should never have any effect outside of the rename.

Signed-off-by: Stephan Renatus <stephan.renatus@gmail.com>
  • Loading branch information
srenatus committed Jun 29, 2022
1 parent 8a3bf90 commit c0c902c
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 15 deletions.
16 changes: 8 additions & 8 deletions topdown/eval.go
Expand Up @@ -2978,9 +2978,7 @@ func (e evalTerm) save(iter unifyIterator) error {
suffix := e.ref[e.pos:]
ref := make(ast.Ref, len(suffix)+1)
ref[0] = v
for i := 0; i < len(suffix); i++ {
ref[i+1] = suffix[i]
}
copy(ref[1:], suffix)

return e.e.biunify(ast.NewTerm(ref), e.rterm, e.bindings, e.rbindings, iter)
})
Expand Down Expand Up @@ -3044,18 +3042,20 @@ func (e *evalEvery) save(iter unifyIterator) error {
cpy := e.expr.Copy()
every := cpy.Terms.(*ast.Every)

for i := range every.Body { // TODO(sr): is there an easier way?
for i := range every.Body {
switch t := every.Body[i].Terms.(type) {
case *ast.Term:
every.Body[i].Terms = e.e.bindings.Plug(t)
every.Body[i].Terms = e.e.bindings.PlugNamespaced(t, e.e.caller.bindings)
case []*ast.Term:
for j := range t {
t[j] = e.e.bindings.Plug(t[j])
for j := 1; j < len(t); j++ { // don't plug operator, t[0]
t[j] = e.e.bindings.PlugNamespaced(t[j], e.e.caller.bindings)
}
}
}

every.Domain = e.e.bindings.plugNamespaced(every.Domain, e.e.caller.bindings)
every.Key = e.e.bindings.PlugNamespaced(every.Key, e.e.caller.bindings)
every.Value = e.e.bindings.PlugNamespaced(every.Value, e.e.caller.bindings)
every.Domain = e.e.bindings.PlugNamespaced(every.Domain, e.e.caller.bindings)
cpy.Terms = every

return e.e.saveExpr(cpy, e.e.bindings, iter)
Expand Down
44 changes: 37 additions & 7 deletions topdown/topdown_partial_test.go
Expand Up @@ -3107,7 +3107,10 @@ func TestTopDownPartialEval(t *testing.T) {
p {
every x in [] { x > input }
}`},
wantQueries: []string{`every __local0__, __local1__ in [] { __local3__ = input; __local1__ > __local3__ }`},
wantQueries: []string{`every __local0__1, __local1__1 in [] {
__local3__1 = input
__local1__1 > __local3__1
}`},
},
{
note: "every: known domain, unknowns in body",
Expand All @@ -3116,7 +3119,10 @@ func TestTopDownPartialEval(t *testing.T) {
p {
every x in [1, 2, 3] { x > input }
}`},
wantQueries: []string{`every __local0__, __local1__ in [1, 2, 3] { __local3__ = input; __local1__ > __local3__ }`},
wantQueries: []string{`every __local0__1, __local1__1 in [1, 2, 3] {
__local3__1 = input
__local1__1 > __local3__1
}`},
},
{
note: "every: known domain, unknowns in body (with call+assignment)",
Expand All @@ -3125,7 +3131,12 @@ func TestTopDownPartialEval(t *testing.T) {
p {
every x in [1, 2, 3] { y := x+10; y > input }
}`},
wantQueries: []string{`every __local0__, __local1__ in [1, 2, 3] { plus(__local1__, 10, __local4__); __local2__ = __local4__; __local5__ = input; __local2__ > __local5__ }`},
wantQueries: []string{`every __local0__1, __local1__1 in [1, 2, 3] {
plus(__local1__1, 10, __local4__1)
__local2__1 = __local4__1
__local5__1 = input
__local2__1 > __local5__1
}`},
},
{
note: "every: known domain, unknowns in body, body impossible",
Expand All @@ -3134,7 +3145,11 @@ func TestTopDownPartialEval(t *testing.T) {
p {
every x in [1, 2, 3] { false; x > input }
}`},
wantQueries: []string{`every __local0__, __local1__ in [1, 2, 3] { false; __local3__ = input; __local1__ > __local3__ }`},
wantQueries: []string{`every __local0__1, __local1__1 in [1, 2, 3] {
false
__local3__1 = input
__local1__1 > __local3__1
}`},
},
{
note: "every: unknown domain",
Expand All @@ -3143,7 +3158,7 @@ func TestTopDownPartialEval(t *testing.T) {
p {
every x in input { x > 1 }
}`},
wantQueries: []string{`every __local0__, __local1__ in input { __local1__ > 1 }`},
wantQueries: []string{`every __local0__1, __local1__1 in input { __local1__1 > 1 }`},
},
{
note: "every: in-scope var in body",
Expand All @@ -3153,7 +3168,7 @@ func TestTopDownPartialEval(t *testing.T) {
y := 3
every x in [1, 2] { x != 0; input > y }
}`},
wantQueries: []string{`every __local1__, __local2__ in [1, 2] { __local2__ != 0; __local4__ = input; __local4__ > 3 }`},
wantQueries: []string{`every __local1__1, __local2__1 in [1, 2] { __local2__1 != 0; __local4__1 = input; __local4__1 > 3 }`},
},
{
note: "every: unknown domain, call in body",
Expand All @@ -3164,7 +3179,22 @@ func TestTopDownPartialEval(t *testing.T) {
y = concat(",", [x])
}
}`},
wantQueries: []string{`every __local0__, __local1__ in input { concat(",", [__local1__], __local3__); y = __local3__ }`},
wantQueries: []string{`every __local0__1, __local1__1 in input { concat(",", [__local1__1], __local3__1); y1 = __local3__1 }`},
},
{
note: "every: closing over function args",
query: "data.test.p",
modules: []string{`package test
p {
f(input)
}
f(x) {
every y in [1] {
a = x
1 == y
}
}`},
wantQueries: []string{`every __local1__2, __local2__2 in [1] { a2 = input; 1 = __local2__2 }`},
},
}

Expand Down

0 comments on commit c0c902c

Please sign in to comment.