Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

topdown/eval: fix 'every' term plugging on save #4775

Merged
merged 2 commits into from Jun 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
16 changes: 8 additions & 8 deletions topdown/eval.go
Expand Up @@ -2978,9 +2978,7 @@ func (e evalTerm) save(iter unifyIterator) error {
suffix := e.ref[e.pos:]
ref := make(ast.Ref, len(suffix)+1)
ref[0] = v
for i := 0; i < len(suffix); i++ {
ref[i+1] = suffix[i]
}
copy(ref[1:], suffix)

return e.e.biunify(ast.NewTerm(ref), e.rterm, e.bindings, e.rbindings, iter)
})
Expand Down Expand Up @@ -3044,18 +3042,20 @@ func (e *evalEvery) save(iter unifyIterator) error {
cpy := e.expr.Copy()
every := cpy.Terms.(*ast.Every)

for i := range every.Body { // TODO(sr): is there an easier way?
for i := range every.Body {
switch t := every.Body[i].Terms.(type) {
case *ast.Term:
every.Body[i].Terms = e.e.bindings.Plug(t)
every.Body[i].Terms = e.e.bindings.PlugNamespaced(t, e.e.caller.bindings)
case []*ast.Term:
for j := range t {
t[j] = e.e.bindings.Plug(t[j])
for j := 1; j < len(t); j++ { // don't plug operator, t[0]
t[j] = e.e.bindings.PlugNamespaced(t[j], e.e.caller.bindings)
}
}
}
Comment on lines +3045 to 3054
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unrelated to these changes, but it looks like we're only dealing with single terms and calls here. What happens if the body contains an ast.Every? Is there some rewriting that I'm not aware of?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good find. Probably an oversight. I'll look into it later.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


every.Domain = e.e.bindings.plugNamespaced(every.Domain, e.e.caller.bindings)
every.Key = e.e.bindings.PlugNamespaced(every.Key, e.e.caller.bindings)
every.Value = e.e.bindings.PlugNamespaced(every.Value, e.e.caller.bindings)
every.Domain = e.e.bindings.PlugNamespaced(every.Domain, e.e.caller.bindings)
cpy.Terms = every

return e.e.saveExpr(cpy, e.e.bindings, iter)
Expand Down
44 changes: 37 additions & 7 deletions topdown/topdown_partial_test.go
Expand Up @@ -3107,7 +3107,10 @@ func TestTopDownPartialEval(t *testing.T) {
p {
every x in [] { x > input }
}`},
wantQueries: []string{`every __local0__, __local1__ in [] { __local3__ = input; __local1__ > __local3__ }`},
wantQueries: []string{`every __local0__1, __local1__1 in [] {
__local3__1 = input
__local1__1 > __local3__1
}`},
},
{
note: "every: known domain, unknowns in body",
Expand All @@ -3116,7 +3119,10 @@ func TestTopDownPartialEval(t *testing.T) {
p {
every x in [1, 2, 3] { x > input }
}`},
wantQueries: []string{`every __local0__, __local1__ in [1, 2, 3] { __local3__ = input; __local1__ > __local3__ }`},
wantQueries: []string{`every __local0__1, __local1__1 in [1, 2, 3] {
__local3__1 = input
__local1__1 > __local3__1
}`},
},
{
note: "every: known domain, unknowns in body (with call+assignment)",
Expand All @@ -3125,7 +3131,12 @@ func TestTopDownPartialEval(t *testing.T) {
p {
every x in [1, 2, 3] { y := x+10; y > input }
}`},
wantQueries: []string{`every __local0__, __local1__ in [1, 2, 3] { plus(__local1__, 10, __local4__); __local2__ = __local4__; __local5__ = input; __local2__ > __local5__ }`},
wantQueries: []string{`every __local0__1, __local1__1 in [1, 2, 3] {
plus(__local1__1, 10, __local4__1)
__local2__1 = __local4__1
__local5__1 = input
__local2__1 > __local5__1
}`},
},
{
note: "every: known domain, unknowns in body, body impossible",
Expand All @@ -3134,7 +3145,11 @@ func TestTopDownPartialEval(t *testing.T) {
p {
every x in [1, 2, 3] { false; x > input }
}`},
wantQueries: []string{`every __local0__, __local1__ in [1, 2, 3] { false; __local3__ = input; __local1__ > __local3__ }`},
wantQueries: []string{`every __local0__1, __local1__1 in [1, 2, 3] {
false
__local3__1 = input
__local1__1 > __local3__1
}`},
},
{
note: "every: unknown domain",
Expand All @@ -3143,7 +3158,7 @@ func TestTopDownPartialEval(t *testing.T) {
p {
every x in input { x > 1 }
}`},
wantQueries: []string{`every __local0__, __local1__ in input { __local1__ > 1 }`},
wantQueries: []string{`every __local0__1, __local1__1 in input { __local1__1 > 1 }`},
},
{
note: "every: in-scope var in body",
Expand All @@ -3153,7 +3168,7 @@ func TestTopDownPartialEval(t *testing.T) {
y := 3
every x in [1, 2] { x != 0; input > y }
}`},
wantQueries: []string{`every __local1__, __local2__ in [1, 2] { __local2__ != 0; __local4__ = input; __local4__ > 3 }`},
wantQueries: []string{`every __local1__1, __local2__1 in [1, 2] { __local2__1 != 0; __local4__1 = input; __local4__1 > 3 }`},
},
{
note: "every: unknown domain, call in body",
Expand All @@ -3164,7 +3179,22 @@ func TestTopDownPartialEval(t *testing.T) {
y = concat(",", [x])
}
}`},
wantQueries: []string{`every __local0__, __local1__ in input { concat(",", [__local1__], __local3__); y = __local3__ }`},
wantQueries: []string{`every __local0__1, __local1__1 in input { concat(",", [__local1__1], __local3__1); y1 = __local3__1 }`},
},
{
note: "every: closing over function args",
query: "data.test.p",
modules: []string{`package test
p {
f(input)
}
f(x) {
every y in [1] {
a = x
1 == y
}
}`},
wantQueries: []string{`every __local1__2, __local2__2 in [1] { a2 = input; 1 = __local2__2 }`},
},
}

Expand Down