Skip to content

Commit

Permalink
ast/compile: 'every' rewriting steps (#4231)
Browse files Browse the repository at this point in the history
This commit covers the compiler doing it's job on "every":

- resolve refs
- rewriting other declared vars in domain
- rewriting its declared vars and others in its body
- rewriting dynamics in the body
- expanding expressions
- safety body reordering and var checking

No evaluation happening yet. That's future work.

Signed-off-by: Stephan Renatus <stephan.renatus@gmail.com>
  • Loading branch information
srenatus committed Jan 29, 2022
1 parent d12fb7c commit cb867a1
Show file tree
Hide file tree
Showing 7 changed files with 477 additions and 46 deletions.
128 changes: 108 additions & 20 deletions ast/compile.go
Expand Up @@ -1452,7 +1452,7 @@ func rewritePrintCalls(gen *localVarGenerator, getArity func(Ref) int, globals V
// Visit comprehension bodies recursively to ensure print statements inside
// those bodies only close over variables that are safe.
for i := range body {
if ContainsComprehensions(body[i]) {
if ContainsClosures(body[i]) {
safe := outputVarsForBody(body[:i], getArity, globals)
safe.Update(globals)
WalkClosures(body[i], func(x interface{}) bool {
Expand All @@ -1463,6 +1463,9 @@ func rewritePrintCalls(gen *localVarGenerator, getArity func(Ref) int, globals V
errs = rewritePrintCalls(gen, getArity, safe, x.Body)
case *ObjectComprehension:
errs = rewritePrintCalls(gen, getArity, safe, x.Body)
case *Every:
safe.Update(x.KeyValueVars())
errs = rewritePrintCalls(gen, getArity, safe, x.Body)
}
return true
})
Expand Down Expand Up @@ -1524,6 +1527,8 @@ func erasePrintCalls(node interface{}) {
x.Body = erasePrintCallsInBody(x.Body)
case *ObjectComprehension:
x.Body = erasePrintCallsInBody(x.Body)
case *Every:
x.Body = erasePrintCallsInBody(x.Body)
}
return false
}).Walk(node)
Expand Down Expand Up @@ -2864,7 +2869,8 @@ type bodySafetyTransformer struct {
}

func (xform *bodySafetyTransformer) Visit(x interface{}) bool {
if term, ok := x.(*Term); ok {
switch term := x.(type) {
case *Term:
switch x := term.Value.(type) {
case *object:
cpy, _ := x.Map(func(k, v *Term) (*Term, *Term, error) {
Expand Down Expand Up @@ -2894,6 +2900,12 @@ func (xform *bodySafetyTransformer) Visit(x interface{}) bool {
xform.reorderSetComprehensionSafety(x)
return true
}
case *Expr:
if ev, ok := term.Terms.(*Every); ok {
xform.globals.Update(ev.KeyValueVars())
ev.Body = xform.reorderComprehensionSafety(NewVarSet(), ev.Body)
return true
}
}
return false
}
Expand Down Expand Up @@ -2950,6 +2962,10 @@ func reorderBodyForClosures(arity func(Ref) int, globals VarSet, body Body) (Bod
vs := VarSet{}
WalkClosures(e, func(x interface{}) bool {
vis := &VarVisitor{vars: vs}
if ev, ok := x.(*Every); ok {
vis.Walk(ev.Body)
return true
}
vis.Walk(x)
return true
})
Expand Down Expand Up @@ -3035,6 +3051,8 @@ func outputVarsForExpr(expr *Expr, arity func(Ref) int, safe VarSet) VarSet {
}

return outputVarsForExprCall(expr, ar, safe, terms)
case *Every:
return outputVarsForTerms(terms.Domain, safe)
default:
panic("illegal expression")
}
Expand Down Expand Up @@ -3062,33 +3080,27 @@ func outputVarsForExprCall(expr *Expr, arity int, safe VarSet, terms []*Term) Va
return output
}

vis := NewVarVisitor().WithParams(VarVisitorParams{
params := VarVisitorParams{
SkipClosures: true,
SkipSets: true,
SkipObjectKeys: true,
SkipRefHead: true,
})

}
vis := NewVarVisitor().WithParams(params)
vis.Walk(Args(terms[:numInputTerms]))
unsafe := vis.Vars().Diff(output).Diff(safe)

if len(unsafe) > 0 {
return VarSet{}
}

vis = NewVarVisitor().WithParams(VarVisitorParams{
SkipRefHead: true,
SkipSets: true,
SkipObjectKeys: true,
SkipClosures: true,
})

vis = NewVarVisitor().WithParams(params)
vis.Walk(Args(terms[numInputTerms:]))
output.Update(vis.vars)
return output
}

func outputVarsForTerms(expr *Expr, safe VarSet) VarSet {
func outputVarsForTerms(expr interface{}, safe VarSet) VarSet {
output := VarSet{}
WalkTerms(expr, func(x *Term) bool {
switch r := x.Value.(type) {
Expand Down Expand Up @@ -3284,7 +3296,7 @@ func resolveRefsInRule(globals map[Var]Ref, rule *Rule) error {
}

func resolveRefsInBody(globals map[Var]Ref, ignore *declaredVarStack, body Body) Body {
r := Body{}
r := make([]*Expr, 0, len(body))
for _, expr := range body {
r = append(r, resolveRefsInExpr(globals, ignore, expr))
}
Expand All @@ -3306,6 +3318,20 @@ func resolveRefsInExpr(globals map[Var]Ref, ignore *declaredVarStack, expr *Expr
if val, ok := ts.Symbols[0].Value.(Call); ok {
cpy.Terms = &SomeDecl{Symbols: []*Term{CallTerm(resolveRefsInTermSlice(globals, ignore, val)...)}}
}
case *Every:
locals := NewVarSet()
if ts.Key != nil {
locals.Update(ts.Key.Vars())
}
locals.Update(ts.Value.Vars())
ignore.Push(locals)
cpy.Terms = &Every{
Key: ts.Key.Copy(), // TODO(sr): do more?
Value: ts.Value.Copy(), // TODO(sr): do more?
Domain: resolveRefsInTerm(globals, ignore, ts.Domain),
Body: resolveRefsInBody(globals, ignore, ts.Body),
}
ignore.Pop()
}
for _, w := range cpy.With {
w.Target = resolveRefsInTerm(globals, ignore, w.Target)
Expand Down Expand Up @@ -3553,11 +3579,14 @@ func rewriteEquals(x interface{}) {
func rewriteDynamics(f *equalityFactory, body Body) Body {
result := make(Body, 0, len(body))
for _, expr := range body {
if expr.IsEquality() {
switch {
case expr.IsEquality():
result = rewriteDynamicsEqExpr(f, expr, result)
} else if expr.IsCall() {
case expr.IsCall():
result = rewriteDynamicsCallExpr(f, expr, result)
} else {
case expr.IsEvery():
result = rewriteDynamicsEveryExpr(f, expr, result)
default:
result = rewriteDynamicsTermExpr(f, expr, result)
}
}
Expand Down Expand Up @@ -3587,6 +3616,13 @@ func rewriteDynamicsCallExpr(f *equalityFactory, expr *Expr, result Body) Body {
return appendExpr(result, expr)
}

func rewriteDynamicsEveryExpr(f *equalityFactory, expr *Expr, result Body) Body {
ev := expr.Terms.(*Every)
result, ev.Domain = rewriteDynamicsOne(expr, f, ev.Domain, result)
ev.Body = rewriteDynamics(f, ev.Body)
return appendExpr(result, expr)
}

func rewriteDynamicsTermExpr(f *equalityFactory, expr *Expr, result Body) Body {
term := expr.Terms.(*Term)
result, expr.Terms = rewriteDynamicsInTerm(expr, f, term, result)
Expand Down Expand Up @@ -3733,6 +3769,21 @@ func expandExpr(gen *localVarGenerator, expr *Expr) (result []*Expr) {
result = append(result, extras...)
}
result = append(result, expr)
case *Every:
var extras []*Expr
if _, ok := terms.Domain.Value.(Call); ok {
extras, terms.Domain = expandExprTerm(gen, terms.Domain)
} else {
term := NewTerm(gen.Generate()).SetLocation(terms.Domain.Location)
eq := Equality.Expr(term, terms.Domain)
eq.Generated = true
eq.Location = terms.Domain.Location
extras = append(extras, eq)
terms.Domain = term
}
terms.Body = rewriteExprTermsInBody(gen, terms.Body)
result = append(result, extras...)
result = append(result, expr)
}
return
}
Expand Down Expand Up @@ -3991,11 +4042,14 @@ func rewriteDeclaredVarsInBody(g *localVarGenerator, stack *localDeclaredVars, u

for i := range body {
var expr *Expr
if body[i].IsAssignment() {
switch {
case body[i].IsAssignment():
expr, errs = rewriteDeclaredAssignment(g, stack, body[i], errs, strict)
} else if _, ok := body[i].Terms.(*SomeDecl); ok {
case body[i].IsSome():
expr, errs = rewriteSomeDeclStatement(g, stack, body[i], errs, strict)
} else {
case body[i].IsEvery():
expr, errs = rewriteEveryStatement(g, stack, body[i], errs, strict)
default:
expr, errs = rewriteDeclaredVarsInExpr(g, stack, body[i], errs, strict)
}
if expr != nil {
Expand Down Expand Up @@ -4085,6 +4139,40 @@ func checkUnusedDeclaredVars(loc *Location, stack *localDeclaredVars, used VarSe
return errs
}

func rewriteEveryStatement(g *localVarGenerator, stack *localDeclaredVars, expr *Expr, errs Errors, strict bool) (*Expr, Errors) {
e := expr.Copy()
every := e.Terms.(*Every)

errs = rewriteDeclaredVarsInTermRecursive(g, stack, every.Domain, errs, strict)

stack.Push()
defer stack.Pop()

// optionally rewrite the key
if every.Key != nil {
if v := every.Key.Value.(Var); !v.IsWildcard() {
gv, err := rewriteDeclaredVar(g, stack, v, declaredVar)
if err != nil {
return nil, append(errs, NewError(CompileErr, every.Loc(), err.Error()))
}
every.Key.Value = gv
}
}

// value is always present
if v := every.Value.Value.(Var); !v.IsWildcard() {
gv, err := rewriteDeclaredVar(g, stack, v, declaredVar)
if err != nil {
return nil, append(errs, NewError(CompileErr, every.Loc(), err.Error()))
}
every.Value.Value = gv
}

used := NewVarSet()
every.Body, errs = rewriteDeclaredVarsInBody(g, stack, used, every.Body, errs, strict)
return e, errs
}

func rewriteSomeDeclStatement(g *localVarGenerator, stack *localDeclaredVars, expr *Expr, errs Errors, strict bool) (*Expr, Errors) {
e := expr.Copy()
decl := e.Terms.(*SomeDecl)
Expand Down

0 comments on commit cb867a1

Please sign in to comment.