Skip to content

Commit

Permalink
Merge branch 'main' into 2437/strict_any_all_deprecated
Browse files Browse the repository at this point in the history
  • Loading branch information
johanfylling committed Jan 30, 2022
2 parents 2a9b8c0 + 59810d0 commit 1aac9de
Show file tree
Hide file tree
Showing 39 changed files with 3,330 additions and 609 deletions.
175 changes: 155 additions & 20 deletions ast/compile.go
Expand Up @@ -258,6 +258,7 @@ func NewCompiler() *Compiler {
f func()
}{
{"CheckDuplicateImports", "compile_stage_check_duplicate_imports", c.checkDuplicateImports},
{"CheckKeywordOverrides", "compile_stage_check_keyword_overrides", c.checkKeywordOverrides},
// Reference resolution should run first as it may be used to lazily
// load additional modules. If any stages run before resolution, they
// need to be re-run after resolution.
Expand Down Expand Up @@ -1310,6 +1311,44 @@ func (c *Compiler) checkDuplicateImports() {
}
}

func (c *Compiler) checkKeywordOverrides() {
for _, name := range c.sorted {
mod := c.Modules[name]
errs := checkKeywordOverrides(mod, c.strict)
for _, err := range errs {
c.err(err)
}
}
}

func checkKeywordOverrides(node interface{}, strict bool) Errors {
if !strict {
return nil
}

errors := Errors{}

WalkRules(node, func(rule *Rule) bool {
name := rule.Head.Name.String()
if RootDocumentRefs.Contains(RefTerm(VarTerm(name))) {
errors = append(errors, NewError(CompileErr, rule.Location, "rules must not shadow %v (use a different rule name)", name))
}
return true
})

WalkExprs(node, func(expr *Expr) bool {
if expr.IsAssignment() {
name := expr.Operand(0).String()
if RootDocumentRefs.Contains(RefTerm(VarTerm(name))) {
errors = append(errors, NewError(CompileErr, expr.Location, "variables must not shadow %v (use a different variable name)", name))
}
}
return false
})

return errors
}

// resolveAllRefs resolves references in expressions to their fully qualified values.
//
// For instance, given the following module:
Expand Down Expand Up @@ -1457,7 +1496,7 @@ func rewritePrintCalls(gen *localVarGenerator, getArity func(Ref) int, globals V
// Visit comprehension bodies recursively to ensure print statements inside
// those bodies only close over variables that are safe.
for i := range body {
if ContainsComprehensions(body[i]) {
if ContainsClosures(body[i]) {
safe := outputVarsForBody(body[:i], getArity, globals)
safe.Update(globals)
WalkClosures(body[i], func(x interface{}) bool {
Expand All @@ -1468,6 +1507,9 @@ func rewritePrintCalls(gen *localVarGenerator, getArity func(Ref) int, globals V
errs = rewritePrintCalls(gen, getArity, safe, x.Body)
case *ObjectComprehension:
errs = rewritePrintCalls(gen, getArity, safe, x.Body)
case *Every:
safe.Update(x.KeyValueVars())
errs = rewritePrintCalls(gen, getArity, safe, x.Body)
}
return true
})
Expand Down Expand Up @@ -1529,6 +1571,8 @@ func erasePrintCalls(node interface{}) {
x.Body = erasePrintCallsInBody(x.Body)
case *ObjectComprehension:
x.Body = erasePrintCallsInBody(x.Body)
case *Every:
x.Body = erasePrintCallsInBody(x.Body)
}
return false
}).Walk(node)
Expand Down Expand Up @@ -1963,6 +2007,7 @@ func (qc *queryCompiler) Compile(query Body) (Body, error) {
metricName string
f func(*QueryContext, Body) (Body, error)
}{
{"CheckKeywordOverrides", "query_compile_stage_check_keyword_overrides", qc.checkKeywordOverrides},
{"ResolveRefs", "query_compile_stage_resolve_refs", qc.resolveRefs},
{"RewriteLocalVars", "query_compile_stage_rewrite_local_vars", qc.rewriteLocalVars},
{"CheckVoidCalls", "query_compile_stage_check_void_calls", qc.checkVoidCalls},
Expand Down Expand Up @@ -2010,6 +2055,13 @@ func (qc *queryCompiler) applyErrorLimit(err error) error {
return err
}

func (qc *queryCompiler) checkKeywordOverrides(_ *QueryContext, body Body) (Body, error) {
if errs := checkKeywordOverrides(body, qc.compiler.strict); len(errs) > 0 {
return nil, errs
}
return body, nil
}

func (qc *queryCompiler) resolveRefs(qctx *QueryContext, body Body) (Body, error) {

var globals map[Var]Ref
Expand Down Expand Up @@ -2869,7 +2921,8 @@ type bodySafetyTransformer struct {
}

func (xform *bodySafetyTransformer) Visit(x interface{}) bool {
if term, ok := x.(*Term); ok {
switch term := x.(type) {
case *Term:
switch x := term.Value.(type) {
case *object:
cpy, _ := x.Map(func(k, v *Term) (*Term, *Term, error) {
Expand Down Expand Up @@ -2899,6 +2952,12 @@ func (xform *bodySafetyTransformer) Visit(x interface{}) bool {
xform.reorderSetComprehensionSafety(x)
return true
}
case *Expr:
if ev, ok := term.Terms.(*Every); ok {
xform.globals.Update(ev.KeyValueVars())
ev.Body = xform.reorderComprehensionSafety(NewVarSet(), ev.Body)
return true
}
}
return false
}
Expand Down Expand Up @@ -2955,6 +3014,10 @@ func reorderBodyForClosures(arity func(Ref) int, globals VarSet, body Body) (Bod
vs := VarSet{}
WalkClosures(e, func(x interface{}) bool {
vis := &VarVisitor{vars: vs}
if ev, ok := x.(*Every); ok {
vis.Walk(ev.Body)
return true
}
vis.Walk(x)
return true
})
Expand Down Expand Up @@ -3040,6 +3103,8 @@ func outputVarsForExpr(expr *Expr, arity func(Ref) int, safe VarSet) VarSet {
}

return outputVarsForExprCall(expr, ar, safe, terms)
case *Every:
return outputVarsForTerms(terms.Domain, safe)
default:
panic("illegal expression")
}
Expand Down Expand Up @@ -3067,33 +3132,27 @@ func outputVarsForExprCall(expr *Expr, arity int, safe VarSet, terms []*Term) Va
return output
}

vis := NewVarVisitor().WithParams(VarVisitorParams{
params := VarVisitorParams{
SkipClosures: true,
SkipSets: true,
SkipObjectKeys: true,
SkipRefHead: true,
})

}
vis := NewVarVisitor().WithParams(params)
vis.Walk(Args(terms[:numInputTerms]))
unsafe := vis.Vars().Diff(output).Diff(safe)

if len(unsafe) > 0 {
return VarSet{}
}

vis = NewVarVisitor().WithParams(VarVisitorParams{
SkipRefHead: true,
SkipSets: true,
SkipObjectKeys: true,
SkipClosures: true,
})

vis = NewVarVisitor().WithParams(params)
vis.Walk(Args(terms[numInputTerms:]))
output.Update(vis.vars)
return output
}

func outputVarsForTerms(expr *Expr, safe VarSet) VarSet {
func outputVarsForTerms(expr interface{}, safe VarSet) VarSet {
output := VarSet{}
WalkTerms(expr, func(x *Term) bool {
switch r := x.Value.(type) {
Expand Down Expand Up @@ -3289,7 +3348,7 @@ func resolveRefsInRule(globals map[Var]Ref, rule *Rule) error {
}

func resolveRefsInBody(globals map[Var]Ref, ignore *declaredVarStack, body Body) Body {
r := Body{}
r := make([]*Expr, 0, len(body))
for _, expr := range body {
r = append(r, resolveRefsInExpr(globals, ignore, expr))
}
Expand All @@ -3311,6 +3370,20 @@ func resolveRefsInExpr(globals map[Var]Ref, ignore *declaredVarStack, expr *Expr
if val, ok := ts.Symbols[0].Value.(Call); ok {
cpy.Terms = &SomeDecl{Symbols: []*Term{CallTerm(resolveRefsInTermSlice(globals, ignore, val)...)}}
}
case *Every:
locals := NewVarSet()
if ts.Key != nil {
locals.Update(ts.Key.Vars())
}
locals.Update(ts.Value.Vars())
ignore.Push(locals)
cpy.Terms = &Every{
Key: ts.Key.Copy(), // TODO(sr): do more?
Value: ts.Value.Copy(), // TODO(sr): do more?
Domain: resolveRefsInTerm(globals, ignore, ts.Domain),
Body: resolveRefsInBody(globals, ignore, ts.Body),
}
ignore.Pop()
}
for _, w := range cpy.With {
w.Target = resolveRefsInTerm(globals, ignore, w.Target)
Expand Down Expand Up @@ -3558,11 +3631,14 @@ func rewriteEquals(x interface{}) {
func rewriteDynamics(f *equalityFactory, body Body) Body {
result := make(Body, 0, len(body))
for _, expr := range body {
if expr.IsEquality() {
switch {
case expr.IsEquality():
result = rewriteDynamicsEqExpr(f, expr, result)
} else if expr.IsCall() {
case expr.IsCall():
result = rewriteDynamicsCallExpr(f, expr, result)
} else {
case expr.IsEvery():
result = rewriteDynamicsEveryExpr(f, expr, result)
default:
result = rewriteDynamicsTermExpr(f, expr, result)
}
}
Expand Down Expand Up @@ -3592,6 +3668,13 @@ func rewriteDynamicsCallExpr(f *equalityFactory, expr *Expr, result Body) Body {
return appendExpr(result, expr)
}

func rewriteDynamicsEveryExpr(f *equalityFactory, expr *Expr, result Body) Body {
ev := expr.Terms.(*Every)
result, ev.Domain = rewriteDynamicsOne(expr, f, ev.Domain, result)
ev.Body = rewriteDynamics(f, ev.Body)
return appendExpr(result, expr)
}

func rewriteDynamicsTermExpr(f *equalityFactory, expr *Expr, result Body) Body {
term := expr.Terms.(*Term)
result, expr.Terms = rewriteDynamicsInTerm(expr, f, term, result)
Expand Down Expand Up @@ -3738,6 +3821,21 @@ func expandExpr(gen *localVarGenerator, expr *Expr) (result []*Expr) {
result = append(result, extras...)
}
result = append(result, expr)
case *Every:
var extras []*Expr
if _, ok := terms.Domain.Value.(Call); ok {
extras, terms.Domain = expandExprTerm(gen, terms.Domain)
} else {
term := NewTerm(gen.Generate()).SetLocation(terms.Domain.Location)
eq := Equality.Expr(term, terms.Domain)
eq.Generated = true
eq.Location = terms.Domain.Location
extras = append(extras, eq)
terms.Domain = term
}
terms.Body = rewriteExprTermsInBody(gen, terms.Body)
result = append(result, extras...)
result = append(result, expr)
}
return
}
Expand Down Expand Up @@ -3996,11 +4094,14 @@ func rewriteDeclaredVarsInBody(g *localVarGenerator, stack *localDeclaredVars, u

for i := range body {
var expr *Expr
if body[i].IsAssignment() {
switch {
case body[i].IsAssignment():
expr, errs = rewriteDeclaredAssignment(g, stack, body[i], errs, strict)
} else if _, ok := body[i].Terms.(*SomeDecl); ok {
case body[i].IsSome():
expr, errs = rewriteSomeDeclStatement(g, stack, body[i], errs, strict)
} else {
case body[i].IsEvery():
expr, errs = rewriteEveryStatement(g, stack, body[i], errs, strict)
default:
expr, errs = rewriteDeclaredVarsInExpr(g, stack, body[i], errs, strict)
}
if expr != nil {
Expand Down Expand Up @@ -4090,6 +4191,40 @@ func checkUnusedDeclaredVars(loc *Location, stack *localDeclaredVars, used VarSe
return errs
}

func rewriteEveryStatement(g *localVarGenerator, stack *localDeclaredVars, expr *Expr, errs Errors, strict bool) (*Expr, Errors) {
e := expr.Copy()
every := e.Terms.(*Every)

errs = rewriteDeclaredVarsInTermRecursive(g, stack, every.Domain, errs, strict)

stack.Push()
defer stack.Pop()

// optionally rewrite the key
if every.Key != nil {
if v := every.Key.Value.(Var); !v.IsWildcard() {
gv, err := rewriteDeclaredVar(g, stack, v, declaredVar)
if err != nil {
return nil, append(errs, NewError(CompileErr, every.Loc(), err.Error()))
}
every.Key.Value = gv
}
}

// value is always present
if v := every.Value.Value.(Var); !v.IsWildcard() {
gv, err := rewriteDeclaredVar(g, stack, v, declaredVar)
if err != nil {
return nil, append(errs, NewError(CompileErr, every.Loc(), err.Error()))
}
every.Value.Value = gv
}

used := NewVarSet()
every.Body, errs = rewriteDeclaredVarsInBody(g, stack, used, every.Body, errs, strict)
return e, errs
}

func rewriteSomeDeclStatement(g *localVarGenerator, stack *localDeclaredVars, expr *Expr, errs Errors, strict bool) (*Expr, Errors) {
e := expr.Copy()
decl := e.Terms.(*SomeDecl)
Expand Down

0 comments on commit 1aac9de

Please sign in to comment.