Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ast/compile: 'every' rewriting steps #4231

Merged
merged 20 commits into from Jan 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
84c2c7c
ast: prealloc expression
srenatus Jan 10, 2022
7f3380c
ast/visit: skip 'Every' Body when skipping closures
srenatus Jan 14, 2022
b1ce68b
ast/parser_ext: add MustParseModuleWithOpts, MustParseBodyWithOpts
srenatus Jan 14, 2022
4f9e3da
ast/compile: 'every' rewriting (dynamic, declared vars)
srenatus Jan 10, 2022
6f0b910
ast/compile: simplify every rewriting (key/val are vars)
srenatus Jan 18, 2022
3227f5d
ast/compile: deal with wildcard cases
srenatus Jan 18, 2022
22ced5d
ast/compile: add missing stack.Pop()
srenatus Jan 18, 2022
b757da8
ast/compile_test: fix indentation
srenatus Jan 18, 2022
00e3896
ast/compile: fix output vars of 'every' body
srenatus Jan 18, 2022
ddd81d8
ast/compile: safety check vars in every.Body
srenatus Jan 18, 2022
8e78ae7
ast/compile_test: add nested case for every rewriting
srenatus Jan 19, 2022
8ac09b4
ast/compiler: rewrite prints in 'every' bodies
srenatus Jan 19, 2022
0b72c1c
ast/compile_test: add "rewrite dynamics" tests for "every"
srenatus Jan 19, 2022
3954eb3
ast/compile: expand "every" domain
srenatus Jan 20, 2022
0f1bfe0
ast/compile: outputVarsForExpr: don't return vars from "Every" body
srenatus Jan 21, 2022
9289cb5
ast/compile: fix safety reordering for every
srenatus Jan 21, 2022
decbdb1
ast: rename (Every).Vars() -> (Every).KeyValueVars()
srenatus Jan 21, 2022
5f41d87
ast/compile: use VarVisitor for KeyValueVars()
srenatus Jan 29, 2022
46478b6
ast/compile: add ContainsClosures
srenatus Jan 29, 2022
5f5da56
ast/compile_test: add unused assigned var in "every" body case
srenatus Jan 29, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
128 changes: 108 additions & 20 deletions ast/compile.go
Expand Up @@ -1452,7 +1452,7 @@ func rewritePrintCalls(gen *localVarGenerator, getArity func(Ref) int, globals V
// Visit comprehension bodies recursively to ensure print statements inside
// those bodies only close over variables that are safe.
for i := range body {
if ContainsComprehensions(body[i]) {
if ContainsClosures(body[i]) {
safe := outputVarsForBody(body[:i], getArity, globals)
safe.Update(globals)
WalkClosures(body[i], func(x interface{}) bool {
Expand All @@ -1463,6 +1463,9 @@ func rewritePrintCalls(gen *localVarGenerator, getArity func(Ref) int, globals V
errs = rewritePrintCalls(gen, getArity, safe, x.Body)
case *ObjectComprehension:
errs = rewritePrintCalls(gen, getArity, safe, x.Body)
case *Every:
safe.Update(x.KeyValueVars())
errs = rewritePrintCalls(gen, getArity, safe, x.Body)
}
return true
})
Expand Down Expand Up @@ -1524,6 +1527,8 @@ func erasePrintCalls(node interface{}) {
x.Body = erasePrintCallsInBody(x.Body)
case *ObjectComprehension:
x.Body = erasePrintCallsInBody(x.Body)
case *Every:
x.Body = erasePrintCallsInBody(x.Body)
}
return false
}).Walk(node)
Expand Down Expand Up @@ -2864,7 +2869,8 @@ type bodySafetyTransformer struct {
}

func (xform *bodySafetyTransformer) Visit(x interface{}) bool {
if term, ok := x.(*Term); ok {
switch term := x.(type) {
case *Term:
switch x := term.Value.(type) {
case *object:
cpy, _ := x.Map(func(k, v *Term) (*Term, *Term, error) {
Expand Down Expand Up @@ -2894,6 +2900,12 @@ func (xform *bodySafetyTransformer) Visit(x interface{}) bool {
xform.reorderSetComprehensionSafety(x)
return true
}
case *Expr:
if ev, ok := term.Terms.(*Every); ok {
xform.globals.Update(ev.KeyValueVars())
ev.Body = xform.reorderComprehensionSafety(NewVarSet(), ev.Body)
return true
}
}
return false
}
Expand Down Expand Up @@ -2950,6 +2962,10 @@ func reorderBodyForClosures(arity func(Ref) int, globals VarSet, body Body) (Bod
vs := VarSet{}
WalkClosures(e, func(x interface{}) bool {
vis := &VarVisitor{vars: vs}
if ev, ok := x.(*Every); ok {
vis.Walk(ev.Body)
return true
}
Comment on lines +2965 to +2968
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bit of a trap: WalkClosures checks that its node in question is of type *Every. So, when passing a function, you get an x of type *Every, and to walk its actual closure, the every.Body, it'll have to be done like this.

I wonder if there's a better approach here. Having WalkClosures apply f to ev.Body here would not allow us to differentiate what we've been called on when we need it, here.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is unavoidable due to the scoping rules we've implemented for every. Presumably it'll be the same for any other keywords we add that support closures and local variable declaration. This seems like the best we can do.

vis.Walk(x)
return true
})
Expand Down Expand Up @@ -3035,6 +3051,8 @@ func outputVarsForExpr(expr *Expr, arity func(Ref) int, safe VarSet) VarSet {
}

return outputVarsForExprCall(expr, ar, safe, terms)
case *Every:
return outputVarsForTerms(terms.Domain, safe)
default:
panic("illegal expression")
}
Expand Down Expand Up @@ -3062,33 +3080,27 @@ func outputVarsForExprCall(expr *Expr, arity int, safe VarSet, terms []*Term) Va
return output
}

vis := NewVarVisitor().WithParams(VarVisitorParams{
params := VarVisitorParams{
SkipClosures: true,
SkipSets: true,
SkipObjectKeys: true,
SkipRefHead: true,
})

}
vis := NewVarVisitor().WithParams(params)
vis.Walk(Args(terms[:numInputTerms]))
unsafe := vis.Vars().Diff(output).Diff(safe)

if len(unsafe) > 0 {
return VarSet{}
}

vis = NewVarVisitor().WithParams(VarVisitorParams{
SkipRefHead: true,
SkipSets: true,
SkipObjectKeys: true,
SkipClosures: true,
})

vis = NewVarVisitor().WithParams(params)
vis.Walk(Args(terms[numInputTerms:]))
output.Update(vis.vars)
return output
}

func outputVarsForTerms(expr *Expr, safe VarSet) VarSet {
func outputVarsForTerms(expr interface{}, safe VarSet) VarSet {
output := VarSet{}
WalkTerms(expr, func(x *Term) bool {
switch r := x.Value.(type) {
Expand Down Expand Up @@ -3284,7 +3296,7 @@ func resolveRefsInRule(globals map[Var]Ref, rule *Rule) error {
}

func resolveRefsInBody(globals map[Var]Ref, ignore *declaredVarStack, body Body) Body {
r := Body{}
r := make([]*Expr, 0, len(body))
for _, expr := range body {
r = append(r, resolveRefsInExpr(globals, ignore, expr))
}
Expand All @@ -3306,6 +3318,20 @@ func resolveRefsInExpr(globals map[Var]Ref, ignore *declaredVarStack, expr *Expr
if val, ok := ts.Symbols[0].Value.(Call); ok {
cpy.Terms = &SomeDecl{Symbols: []*Term{CallTerm(resolveRefsInTermSlice(globals, ignore, val)...)}}
}
case *Every:
locals := NewVarSet()
if ts.Key != nil {
locals.Update(ts.Key.Vars())
}
locals.Update(ts.Value.Vars())
ignore.Push(locals)
cpy.Terms = &Every{
Key: ts.Key.Copy(), // TODO(sr): do more?
Value: ts.Value.Copy(), // TODO(sr): do more?
Domain: resolveRefsInTerm(globals, ignore, ts.Domain),
Body: resolveRefsInBody(globals, ignore, ts.Body),
}
ignore.Pop()
}
for _, w := range cpy.With {
w.Target = resolveRefsInTerm(globals, ignore, w.Target)
Expand Down Expand Up @@ -3553,11 +3579,14 @@ func rewriteEquals(x interface{}) {
func rewriteDynamics(f *equalityFactory, body Body) Body {
result := make(Body, 0, len(body))
for _, expr := range body {
if expr.IsEquality() {
switch {
case expr.IsEquality():
result = rewriteDynamicsEqExpr(f, expr, result)
} else if expr.IsCall() {
case expr.IsCall():
result = rewriteDynamicsCallExpr(f, expr, result)
} else {
case expr.IsEvery():
result = rewriteDynamicsEveryExpr(f, expr, result)
default:
result = rewriteDynamicsTermExpr(f, expr, result)
}
}
Expand Down Expand Up @@ -3587,6 +3616,13 @@ func rewriteDynamicsCallExpr(f *equalityFactory, expr *Expr, result Body) Body {
return appendExpr(result, expr)
}

func rewriteDynamicsEveryExpr(f *equalityFactory, expr *Expr, result Body) Body {
ev := expr.Terms.(*Every)
result, ev.Domain = rewriteDynamicsOne(expr, f, ev.Domain, result)
ev.Body = rewriteDynamics(f, ev.Body)
return appendExpr(result, expr)
}

func rewriteDynamicsTermExpr(f *equalityFactory, expr *Expr, result Body) Body {
term := expr.Terms.(*Term)
result, expr.Terms = rewriteDynamicsInTerm(expr, f, term, result)
Expand Down Expand Up @@ -3733,6 +3769,21 @@ func expandExpr(gen *localVarGenerator, expr *Expr) (result []*Expr) {
result = append(result, extras...)
}
result = append(result, expr)
case *Every:
var extras []*Expr
if _, ok := terms.Domain.Value.(Call); ok {
extras, terms.Domain = expandExprTerm(gen, terms.Domain)
} else {
term := NewTerm(gen.Generate()).SetLocation(terms.Domain.Location)
eq := Equality.Expr(term, terms.Domain)
eq.Generated = true
eq.Location = terms.Domain.Location
extras = append(extras, eq)
terms.Domain = term
}
terms.Body = rewriteExprTermsInBody(gen, terms.Body)
result = append(result, extras...)
result = append(result, expr)
}
return
}
Expand Down Expand Up @@ -3991,11 +4042,14 @@ func rewriteDeclaredVarsInBody(g *localVarGenerator, stack *localDeclaredVars, u

for i := range body {
var expr *Expr
if body[i].IsAssignment() {
switch {
case body[i].IsAssignment():
expr, errs = rewriteDeclaredAssignment(g, stack, body[i], errs, strict)
} else if _, ok := body[i].Terms.(*SomeDecl); ok {
case body[i].IsSome():
expr, errs = rewriteSomeDeclStatement(g, stack, body[i], errs, strict)
} else {
case body[i].IsEvery():
expr, errs = rewriteEveryStatement(g, stack, body[i], errs, strict)
default:
expr, errs = rewriteDeclaredVarsInExpr(g, stack, body[i], errs, strict)
}
if expr != nil {
Expand Down Expand Up @@ -4085,6 +4139,40 @@ func checkUnusedDeclaredVars(loc *Location, stack *localDeclaredVars, used VarSe
return errs
}

func rewriteEveryStatement(g *localVarGenerator, stack *localDeclaredVars, expr *Expr, errs Errors, strict bool) (*Expr, Errors) {
e := expr.Copy()
every := e.Terms.(*Every)

errs = rewriteDeclaredVarsInTermRecursive(g, stack, every.Domain, errs, strict)

stack.Push()
defer stack.Pop()

// optionally rewrite the key
if every.Key != nil {
if v := every.Key.Value.(Var); !v.IsWildcard() {
gv, err := rewriteDeclaredVar(g, stack, v, declaredVar)
if err != nil {
return nil, append(errs, NewError(CompileErr, every.Loc(), err.Error()))
}
every.Key.Value = gv
}
}

// value is always present
if v := every.Value.Value.(Var); !v.IsWildcard() {
gv, err := rewriteDeclaredVar(g, stack, v, declaredVar)
if err != nil {
return nil, append(errs, NewError(CompileErr, every.Loc(), err.Error()))
}
every.Value.Value = gv
}

used := NewVarSet()
every.Body, errs = rewriteDeclaredVarsInBody(g, stack, used, every.Body, errs, strict)
return e, errs
}

func rewriteSomeDeclStatement(g *localVarGenerator, stack *localDeclaredVars, expr *Expr, errs Errors, strict bool) (*Expr, Errors) {
e := expr.Copy()
decl := e.Terms.(*SomeDecl)
Expand Down