Skip to content

Commit

Permalink
ast/compile: fix print rewriting for arrays in unification
Browse files Browse the repository at this point in the history
outputVarsForExprEq, which is used in the rewriting of `print()` calls,
wasn't able to find `x` in

    [_, x, _] = f(1)

and thus certain print expressions failed during the rewriting stage of
print() calls in the compiler.

Now, the "collection (array/object) = func call" is handled.

Furthermore, this adds calls to isRefSafe in a few places: as it turned
out, the test cases

    {"array/ref", "[1,2,x] = a[_]", "[a]", "[x]"},
    {"object/ref", `{"x": x} = a[_]`, "[a]", "[x]"},

would have, without the check, passed without "a" in "safe" (3rd field).
The fact that these cases are including "a" is evidence for this being
the correct behaviour.

Fixes open-policy-agent#4078.

Signed-off-by: Stephan Renatus <stephan.renatus@gmail.com>
  • Loading branch information
srenatus committed Dec 8, 2021
1 parent e8d1b5f commit 6f830dc
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 12 deletions.
30 changes: 30 additions & 0 deletions ast/compile_test.go
Expand Up @@ -3169,6 +3169,36 @@ func TestCompilerRewritePrintCalls(t *testing.T) {
f(__local0__) = __local2__ { true; __local2__ = {1 | __local0__[x]; __local3__ = {__local1__ | __local1__ = x}; internal.print([__local3__])} }
`,
},
{
note: "print call of var in head key",
module: `package test
f(_) = [1, 2, 3]
p[x] { [_, x, _] := f(true); print(x) }`,
exp: `package test
f(__local0__) = [1, 2, 3] { true }
p[__local2__] { data.test.f(true, __local5__); [__local1__, __local2__, __local3__] = __local5__; __local6__ = {__local4__ | __local4__ = __local2__}; internal.print([__local6__]) }
`,
},
{
note: "print call of var in head value",
module: `package test
f(_) = [1, 2, 3]
p = x { [_, x, _] := f(true); print(x) }`,
exp: `package test
f(__local0__) = [1, 2, 3] { true }
p = __local2__ { data.test.f(true, __local5__); [__local1__, __local2__, __local3__] = __local5__; __local6__ = {__local4__ | __local4__ = __local2__}; internal.print([__local6__]) }
`,
},
{
note: "print call of vars in head key and value",
module: `package test
f(_) = [1, 2, 3]
p[x] = y { [_, x, y] := f(true); print(x) }`,
exp: `package test
f(__local0__) = [1, 2, 3] { true }
p[__local2__] = __local3__ { data.test.f(true, __local5__); [__local1__, __local2__, __local3__] = __local5__; __local6__ = {__local4__ | __local4__ = __local2__}; internal.print([__local6__]) }
`,
},
}

for _, tc := range cases {
Expand Down
44 changes: 38 additions & 6 deletions ast/unify.go
Expand Up @@ -9,10 +9,7 @@ func isRefSafe(ref Ref, safe VarSet) bool {
case Var:
return safe.Contains(head)
case Call:
vis := NewVarVisitor().WithParams(SafetyCheckVisitorParams)
vis.Walk(head)
unsafe := vis.Vars().Diff(safe)
return len(unsafe) == 0
return isCallSafe(head, safe)
default:
for v := range ref[0].Vars() {
if !safe.Contains(v) {
Expand All @@ -23,6 +20,13 @@ func isRefSafe(ref Ref, safe VarSet) bool {
}
}

func isCallSafe(call Call, safe VarSet) bool {
vis := NewVarVisitor().WithParams(SafetyCheckVisitorParams)
vis.Walk(call)
unsafe := vis.Vars().Diff(safe)
return len(unsafe) == 0
}

// Unify returns a set of variables that will be unified when the equality expression defined by
// terms a and b is evaluated. The unifier assumes that variables in the VarSet safe are already
// unified.
Expand Down Expand Up @@ -67,6 +71,10 @@ func (u *unifier) unify(a *Term, b *Term) {
if isRefSafe(b, u.safe) {
u.markSafe(a)
}
case Call:
if isCallSafe(b, u.safe) {
u.markSafe(a)
}
default:
u.markSafe(a)
}
Expand All @@ -81,6 +89,16 @@ func (u *unifier) unify(a *Term, b *Term) {
}
}

case Call:
if isCallSafe(a, u.safe) {
switch b := b.Value.(type) {
case Var:
u.markSafe(b)
case *Array, Object:
u.markAllSafe(b)
}
}

case *ArrayComprehension:
switch b := b.Value.(type) {
case Var:
Expand All @@ -105,8 +123,16 @@ func (u *unifier) unify(a *Term, b *Term) {
switch b := b.Value.(type) {
case Var:
u.unifyAll(b, a)
case Ref, *ArrayComprehension, *ObjectComprehension, *SetComprehension:
case *ArrayComprehension, *ObjectComprehension, *SetComprehension:
u.markAllSafe(a)
case Ref:
if isRefSafe(b, u.safe) {
u.markAllSafe(a)
}
case Call:
if isCallSafe(b, u.safe) {
u.markAllSafe(a)
}
case *Array:
if a.Len() == b.Len() {
for i := 0; i < a.Len(); i++ {
Expand All @@ -120,7 +146,13 @@ func (u *unifier) unify(a *Term, b *Term) {
case Var:
u.unifyAll(b, a)
case Ref:
u.markAllSafe(a)
if isRefSafe(b, u.safe) {
u.markAllSafe(a)
}
case Call:
if isCallSafe(b, u.safe) {
u.markAllSafe(a)
}
case *object:
if a.Len() == b.Len() {
_ = a.Iter(func(k, v *Term) error {
Expand Down
55 changes: 49 additions & 6 deletions ast/unify_test.go
Expand Up @@ -4,7 +4,10 @@

package ast

import "testing"
import (
"fmt"
"testing"
)

func TestUnify(t *testing.T) {

Expand All @@ -28,10 +31,16 @@ func TestUnify(t *testing.T) {
{"object/var", `{"x": 1, "y": x} = y`, "[x]", "[y]"},
{"object/var (reversed)", `y = {"x": 1, "y": x}`, "[x]", "[y]"},
{"object/var-2", `{"x": 1, "y": x} = y`, "[y]", "[x]"},
{"object/var-3", `{"x": 1, "y": x} = y`, "[]", "[]"},
{"object/uneven", `{"x": x, "y": 1} = {"x": y}`, "[]", "[]"},
{"object/uneven", `{"x": x, "y": 1} = {"x": y}`, "[x]", "[]"},
{"call", "x = f(y)[z]", "[y]", "[x]"},
{"var/call-ref", "x = f(y)[z]", "[y]", "[x]"},
{"var/call-ref (reversed)", "f(y)[z] = x", "[y]", "[x]"},
{"var/call", "x = f(z)", "[z]", "[x]"},
{"var/call (reversed)", "f(z) = x", "[z]", "[x]"},
{"array/call", "[x, y] = f(z)", "[z]", "[x,y]"},
{"array/call (reversed)", "f(z) = [x, y]", "[z]", "[x,y]"},
{"object/call", `{"a": x} = f(z)`, "[z]", "[x]"},
{"object/call (reversed)", `f(z) = {"a": x}`, "[z]", "[x]"},

// transitive cases
{"trans/redundant", "[x, x] = [x, 0]", "[]", "[x]"},
Expand All @@ -43,10 +52,44 @@ func TestUnify(t *testing.T) {
{"trans/redundant-nested", "[x, z, z] = [1, [y, x], [2, 1]]", "[]", "[x, y, z]"},
{"trans/bidirectional", "[x, z, y] = [[z,y], [1,y], 2]", "[]", "[x, y, z]"},
{"trans/occurs", "[x, z, y] = [[y,z], [y, 1], [2, x]]", "[]", "[]"},

// unsafe refs
{note: "array/ref", expr: "[1,2,x] = a[_]"},
{note: "array/ref (reversed)", expr: "a[_] = [1,2,x]"},
{note: "object/ref", expr: `{"x": x} = a[_]`},
{note: "object/ref (reversed)", expr: `a[_] = {"x": x}`},
{note: "var/call-ref", expr: "x = f(y)[z]"},
{note: "var/call-ref (reversed)", expr: "f(y)[z] = x"},

// unsafe vars
{note: "array/var", expr: "[1,2,x] = y"},
{note: "array/var (reversed)", expr: "y = [1,2,x]"},
{note: "object/var", expr: `{"x": 1, "y": x} = y`},
{note: "object/var (reversed)", expr: `y = {"x": 1, "y": x}`},
{note: "var/call", expr: "x = f(z)"},
{note: "var/call (reversed)", expr: "f(z) = x"},

// unsafe call args
{note: "var/call-2", expr: "x = f(z)", safe: "[x]"},
{note: "var/call-2 (reversed)", expr: "f(z) = x", safe: "[x]"},
{note: "array/call", expr: "[x, y] = f(z)", safe: "[x,y]"},
{note: "array/call (reversed)", expr: "f(z) = [x, y]", safe: "[x,y]"},
{note: "object/call", expr: `{"a": x} = f(z)`, safe: "[x]"},
{note: "object/call (reversed)", expr: `f(z) = {"a": x}`, safe: "[x]"},

// partial cases
{note: "trans/ref", expr: "[x, y, [x, y, i]] = [1, a[i], z]", safe: "[a]", expected: "[x, y]"},
{note: "trans/ref", expr: "[x, y, [x, y, i]] = [1, a[i], z]", expected: "[x]"},
}

for i, tc := range tests {
t.Run(tc.note, func(t *testing.T) {
for _, tc := range tests {
if tc.expected == "" {
tc.expected = "[]"
}
if tc.safe == "" {
tc.safe = "[]"
}
t.Run(fmt.Sprintf("%s/%s/%s", tc.note, tc.safe, tc.expected), func(t *testing.T) {

expr := MustParseBody(tc.expr)[0]
safe := VarSet{}
Expand Down Expand Up @@ -74,7 +117,7 @@ func TestUnify(t *testing.T) {
missing := expected.Diff(result)
extra := result.Diff(expected)
if len(missing) != 0 || len(extra) != 0 {
t.Fatalf("%s (%d): Missing vars: %v, extra vars: %v", tc.note, i, missing, extra)
t.Fatalf("missing vars: %v, extra vars: %v", missing, extra)
}
})
}
Expand Down

0 comments on commit 6f830dc

Please sign in to comment.