Skip to content

Commit

Permalink
ast: Fix panic during local var rewriting
Browse files Browse the repository at this point in the history
This commit fixes an issue similar to
e88579b: when a comprehension is
nested inside of a set or used as an object key, the rewriting needs
to be careful to make a copy of the set/object to avoid mutating the
elemenet/key in-place.

Fixes open-policy-agent#2720

Signed-off-by: Torin Sandall <torinsandall@gmail.com>
  • Loading branch information
tsandall committed Sep 24, 2020
1 parent 186ef99 commit 120634e
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 27 deletions.
90 changes: 63 additions & 27 deletions ast/compile.go
Expand Up @@ -1102,8 +1102,6 @@ func (c *Compiler) rewriteLocalVars() {

WalkRules(mod, func(rule *Rule) bool {

var errs Errors

// Rewrite assignments contained in head of rule. Assignments can
// occur in rule head if they're inside a comprehension. Note,
// assigned vars in comprehensions in the head will be rewritten
Expand All @@ -1114,29 +1112,14 @@ func (c *Compiler) rewriteLocalVars() {
// This behaviour is consistent scoping inside the body. For example:
//
// p = xs { x := 2; xs = [x | x := 1] } becomes p = xs { __local0__ = 2; xs = [__local1__ | __local1__ = 1] }
WalkTerms(rule.Head, func(term *Term) bool {
stop := false
stack := newLocalDeclaredVars()
switch v := term.Value.(type) {
case *ArrayComprehension:
errs = rewriteDeclaredVarsInArrayComprehension(gen, stack, v, errs)
stop = true
case *SetComprehension:
errs = rewriteDeclaredVarsInSetComprehension(gen, stack, v, errs)
stop = true
case *ObjectComprehension:
errs = rewriteDeclaredVarsInObjectComprehension(gen, stack, v, errs)
stop = true
}

for k, v := range stack.rewritten {
c.RewrittenVars[k] = v
}
nestedXform := &rewriteNestedHeadVarLocalTransform{
gen: gen,
RewrittenVars: c.RewrittenVars,
}

return stop
})
NewGenericVisitor(nestedXform.Visit).Walk(rule.Head)

for _, err := range errs {
for _, err := range nestedXform.errs {
c.err(err)
}

Expand Down Expand Up @@ -1169,25 +1152,78 @@ func (c *Compiler) rewriteLocalVars() {
rule.Body = body

// Rewrite vars in head that refer to locally declared vars in the body.
xform := rewriteHeadVarLocalTransform{declared: declared}
localXform := rewriteHeadVarLocalTransform{declared: declared}

for i := range rule.Head.Args {
rule.Head.Args[i], _ = transformTerm(xform, rule.Head.Args[i])
rule.Head.Args[i], _ = transformTerm(localXform, rule.Head.Args[i])
}

if rule.Head.Key != nil {
rule.Head.Key, _ = transformTerm(xform, rule.Head.Key)
rule.Head.Key, _ = transformTerm(localXform, rule.Head.Key)
}

if rule.Head.Value != nil {
rule.Head.Value, _ = transformTerm(xform, rule.Head.Value)
rule.Head.Value, _ = transformTerm(localXform, rule.Head.Value)
}

return false
})
}
}

type rewriteNestedHeadVarLocalTransform struct {
gen *localVarGenerator
errs Errors
RewrittenVars map[Var]Var
}

func (xform *rewriteNestedHeadVarLocalTransform) Visit(x interface{}) bool {

if term, ok := x.(*Term); ok {

stop := false
stack := newLocalDeclaredVars()

switch x := term.Value.(type) {
case *object:
cpy, _ := x.Map(func(k, v *Term) (*Term, *Term, error) {
kcpy := k.Copy()
NewGenericVisitor(xform.Visit).Walk(kcpy)
vcpy := v.Copy()
NewGenericVisitor(xform.Visit).Walk(vcpy)
return kcpy, vcpy, nil
})
term.Value = cpy
stop = true
case *set:
cpy, _ := x.Map(func(v *Term) (*Term, error) {
vcpy := v.Copy()
NewGenericVisitor(xform.Visit).Walk(vcpy)
return vcpy, nil
})
term.Value = cpy
stop = true
case *ArrayComprehension:
xform.errs = rewriteDeclaredVarsInArrayComprehension(xform.gen, stack, x, xform.errs)
stop = true
case *SetComprehension:
xform.errs = rewriteDeclaredVarsInSetComprehension(xform.gen, stack, x, xform.errs)
stop = true
case *ObjectComprehension:
xform.errs = rewriteDeclaredVarsInObjectComprehension(xform.gen, stack, x, xform.errs)
stop = true
}

for k, v := range stack.rewritten {
xform.RewrittenVars[k] = v
}

return stop
}

return false
}

type rewriteHeadVarLocalTransform struct {
declared map[Var]Var
}
Expand Down
37 changes: 37 additions & 0 deletions ast/compile_test.go
Expand Up @@ -1941,6 +1941,40 @@ func TestCompilerRewriteLocalAssignments(t *testing.T) {
Var("__local0__"): Var("x"),
},
},
{
module: `
package test
f({{t | t := 0}: 1}) {
true
}
`,
exp: `
package test
f({{__local0__ | __local0__ = 0}: 1}) { true }
`,
expRewrittenMap: map[Var]Var{
Var("__local0__"): Var("t"),
},
},
{
module: `
package test
f({{t | t := 0}}) {
true
}
`,
exp: `
package test
f({{__local0__ | __local0__ = 0}}) { true }
`,
expRewrittenMap: map[Var]Var{
Var("__local0__"): Var("t"),
},
},
}

for i, tc := range tests {
Expand Down Expand Up @@ -2007,6 +2041,8 @@ func TestRewriteLocalVarDeclarationErrors(t *testing.T) {
arg_redeclared(arg1) {
arg1 := 1
}
arg_nested_redeclared({{arg_nested| arg_nested := 1; arg_nested := 2}}) { true }
`)

compileStages(c, c.rewriteLocalVars)
Expand All @@ -2017,6 +2053,7 @@ func TestRewriteLocalVarDeclarationErrors(t *testing.T) {
"var input referenced above",
"var nested assigned above",
"arg arg1 redeclared",
"var arg_nested assigned above",
"cannot assign vars inside negated expression",
"cannot assign to ref",
"cannot assign to arraycomprehension",
Expand Down

0 comments on commit 120634e

Please sign in to comment.