Skip to content

Commit

Permalink
ast: Fix object corruption during safety reordering
Browse files Browse the repository at this point in the history
The safety check was corrupting object and set values that contained
comprehension as object keys or set elements because the comprehension
values themselves were mutated in place. This change fixes the issue
by copying object/set values like we do in other places.

This change also removes the setExprIndices function which was also
mutating values inside of a visitor.

Signed-off-by: Torin Sandall <torinsandall@gmail.com>
  • Loading branch information
tsandall committed Sep 22, 2020
1 parent fda417c commit 428219c
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 61 deletions.
94 changes: 45 additions & 49 deletions ast/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -2187,7 +2187,7 @@ func reorderBodyForSafety(builtins map[string]*Builtin, arity func(Ref) int, glo

if len(unsafe[e]) == 0 {
delete(unsafe, e)
reordered = append(reordered, e)
reordered.Append(e)
}
}

Expand All @@ -2204,95 +2204,91 @@ func reorderBodyForSafety(builtins map[string]*Builtin, arity func(Ref) int, glo
if i > 0 {
g.Update(reordered[i-1].Vars(SafetyCheckVisitorParams))
}
vis := &bodySafetyVisitor{
xform := &bodySafetyTransformer{
builtins: builtins,
arity: arity,
current: e,
globals: g,
unsafe: unsafe,
}
NewGenericVisitor(vis.Visit).Walk(e)
NewGenericVisitor(xform.Visit).Walk(e)
}

// Need to reset expression indices as re-ordering may have
// changed them.
setExprIndices(reordered)

return reordered, unsafe
}

type bodySafetyVisitor struct {
type bodySafetyTransformer struct {
builtins map[string]*Builtin
arity func(Ref) int
current *Expr
globals VarSet
unsafe unsafeVars
}

func (vis *bodySafetyVisitor) Visit(x interface{}) bool {
switch x := x.(type) {
case *Expr:
cpy := *vis
cpy.current = x

switch ts := x.Terms.(type) {
case *SomeDecl:
NewGenericVisitor(cpy.Visit).Walk(ts)
case []*Term:
for _, t := range ts {
NewGenericVisitor(cpy.Visit).Walk(t)
}
case *Term:
NewGenericVisitor(cpy.Visit).Walk(ts)
}
for i := range x.With {
NewGenericVisitor(cpy.Visit).Walk(x.With[i])
func (xform *bodySafetyTransformer) Visit(x interface{}) bool {
if term, ok := x.(*Term); ok {
switch x := term.Value.(type) {
case *object:
cpy, _ := x.Map(func(k, v *Term) (*Term, *Term, error) {
kcpy := k.Copy()
NewGenericVisitor(xform.Visit).Walk(kcpy)
vcpy := v.Copy()
NewGenericVisitor(xform.Visit).Walk(vcpy)
return kcpy, vcpy, nil
})
term.Value = cpy
return true
case *set:
cpy, _ := x.Map(func(v *Term) (*Term, error) {
vcpy := v.Copy()
NewGenericVisitor(xform.Visit).Walk(vcpy)
return vcpy, nil
})
term.Value = cpy
return true
case *ArrayComprehension:
xform.reorderArrayComprehensionSafety(x)
return true
case *ObjectComprehension:
xform.reorderObjectComprehensionSafety(x)
return true
case *SetComprehension:
xform.reorderSetComprehensionSafety(x)
return true
}
return true
case *ArrayComprehension:
vis.checkArrayComprehensionSafety(x)
return true
case *ObjectComprehension:
vis.checkObjectComprehensionSafety(x)
return true
case *SetComprehension:
vis.checkSetComprehensionSafety(x)
return true
}
return false
}

// Check term for safety. This is analogous to the rule head safety check.
func (vis *bodySafetyVisitor) checkComprehensionSafety(tv VarSet, body Body) Body {
func (xform *bodySafetyTransformer) reorderComprehensionSafety(tv VarSet, body Body) Body {
bv := body.Vars(SafetyCheckVisitorParams)
bv.Update(vis.globals)
bv.Update(xform.globals)
uv := tv.Diff(bv)
for v := range uv {
vis.unsafe.Add(vis.current, v)
xform.unsafe.Add(xform.current, v)
}

// Check body for safety, reordering as necessary.
r, u := reorderBodyForSafety(vis.builtins, vis.arity, vis.globals, body)
r, u := reorderBodyForSafety(xform.builtins, xform.arity, xform.globals, body)
if len(u) == 0 {
return r
}

vis.unsafe.Update(u)
xform.unsafe.Update(u)
return body
}

func (vis *bodySafetyVisitor) checkArrayComprehensionSafety(ac *ArrayComprehension) {
ac.Body = vis.checkComprehensionSafety(ac.Term.Vars(), ac.Body)
func (xform *bodySafetyTransformer) reorderArrayComprehensionSafety(ac *ArrayComprehension) {
ac.Body = xform.reorderComprehensionSafety(ac.Term.Vars(), ac.Body)
}

func (vis *bodySafetyVisitor) checkObjectComprehensionSafety(oc *ObjectComprehension) {
func (xform *bodySafetyTransformer) reorderObjectComprehensionSafety(oc *ObjectComprehension) {
tv := oc.Key.Vars()
tv.Update(oc.Value.Vars())
oc.Body = vis.checkComprehensionSafety(tv, oc.Body)
oc.Body = xform.reorderComprehensionSafety(tv, oc.Body)
}

func (vis *bodySafetyVisitor) checkSetComprehensionSafety(sc *SetComprehension) {
sc.Body = vis.checkComprehensionSafety(sc.Term.Vars(), sc.Body)
func (xform *bodySafetyTransformer) reorderSetComprehensionSafety(sc *SetComprehension) {
sc.Body = xform.reorderComprehensionSafety(sc.Term.Vars(), sc.Body)
}

// reorderBodyForClosures returns a copy of the body ordered such that
Expand Down
2 changes: 2 additions & 0 deletions ast/compile_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -776,6 +776,8 @@ func TestCompilerCheckSafetyBodyErrors(t *testing.T) {
{"call-vars-input", "p { f(x, x) } f(x) = x { true }", `{x,}`},
{"call-no-output", "p { f(x) } f(x) = x { true }", `{x,}`},
{"call-too-few", "p { f(1,x) } f(x,y) { true }", "{x,}"},
{"object-key-comprehension", "p { { {p|x}: 0 } }", "{x,}"},
{"set-value-comprehension", "p { {1, {p|x}} }", "{x,}"},
}

makeErrMsg := func(varName string) string {
Expand Down
15 changes: 3 additions & 12 deletions ast/parser_ext.go
Original file line number Diff line number Diff line change
Expand Up @@ -430,16 +430,16 @@ func ParseBody(input string) (Body, error) {
for _, stmt := range stmts {
switch stmt := stmt.(type) {
case Body:
result = append(result, stmt...)
for i := range stmt {
result.Append(stmt[i])
}
case *Comment:
// skip
default:
return nil, fmt.Errorf("expected body but got %T", stmt)
}
}

setExprIndices(result)

return result, nil
}

Expand Down Expand Up @@ -618,15 +618,6 @@ func parseModule(filename string, stmts []Statement, comments []*Comment) (*Modu
return nil, errs
}

func setExprIndices(x interface{}) {
WalkBodies(x, func(b Body) bool {
for i, expr := range b {
expr.Index = i
}
return false
})
}

func setRuleModule(rule *Rule, module *Module) {
rule.Module = module
if rule.Else != nil {
Expand Down
13 changes: 13 additions & 0 deletions ast/transform.go
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,7 @@ func transformHead(t Transformer, head *Head) (*Head, error) {
}
return h, nil
}

func transformArgs(t Transformer, args Args) (Args, error) {
y, err := Transform(t, args)
if err != nil {
Expand All @@ -347,6 +348,18 @@ func transformBody(t Transformer, body Body) (Body, error) {
return r, nil
}

func transformExpr(t Transformer, expr *Expr) (*Expr, error) {
y, err := Transform(t, expr)
if err != nil {
return nil, err
}
h, ok := y.(*Expr)
if !ok {
return nil, fmt.Errorf("illegal transform: %T != %T", expr, y)
}
return h, nil
}

func transformTerm(t Transformer, term *Term) (*Term, error) {
v, err := transformValue(t, term.Value)
if err != nil {
Expand Down

0 comments on commit 428219c

Please sign in to comment.