From 84c2c7cd6f45aae31e71882d3446a0a84ba89784 Mon Sep 17 00:00:00 2001 From: Stephan Renatus Date: Mon, 10 Jan 2022 14:22:38 +0100 Subject: [PATCH 01/20] ast: prealloc expression Signed-off-by: Stephan Renatus --- ast/compile.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ast/compile.go b/ast/compile.go index 9d89ea5628..7cc4285d9f 100644 --- a/ast/compile.go +++ b/ast/compile.go @@ -3284,7 +3284,7 @@ func resolveRefsInRule(globals map[Var]Ref, rule *Rule) error { } func resolveRefsInBody(globals map[Var]Ref, ignore *declaredVarStack, body Body) Body { - r := Body{} + r := make([]*Expr, 0, len(body)) for _, expr := range body { r = append(r, resolveRefsInExpr(globals, ignore, expr)) } From 7f3380c2ef9b689d72e2acf2f6b27d1a2413055c Mon Sep 17 00:00:00 2001 From: Stephan Renatus Date: Fri, 14 Jan 2022 11:51:47 +0100 Subject: [PATCH 02/20] ast/visit: skip 'Every' Body when skipping closures Signed-off-by: Stephan Renatus --- ast/visit.go | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/ast/visit.go b/ast/visit.go index 893d867bc0..f1950a54b4 100644 --- a/ast/visit.go +++ b/ast/visit.go @@ -542,6 +542,8 @@ func (vis *VarVisitor) Vars() VarSet { return vis.vars } +// visit determines if the VarVisitor will recurse into x: if it returns `true`, +// the visitor will _skip_ that branch of the AST func (vis *VarVisitor) visit(v interface{}) bool { if vis.params.SkipObjectKeys { if o, ok := v.(Object); ok { @@ -560,9 +562,13 @@ func (vis *VarVisitor) visit(v interface{}) bool { } } if vis.params.SkipClosures { - switch v.(type) { - case *ArrayComprehension, *ObjectComprehension, *SetComprehension, *Every: + switch v := v.(type) { + case *ArrayComprehension, *ObjectComprehension, *SetComprehension: return true + case *Expr: + if _, ok := v.Terms.(*Every); ok { + return true + } } } if vis.params.SkipWithTarget { From b1ce68b830fad8318fa894e07406b2b9f48b6b4d Mon Sep 17 00:00:00 2001 From: Stephan Renatus Date: Fri, 14 Jan 2022 12:16:56 +0100 Subject: [PATCH 03/20] ast/parser_ext: add MustParseModuleWithOpts, MustParseBodyWithOpts Signed-off-by: Stephan Renatus --- ast/parser_ext.go | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/ast/parser_ext.go b/ast/parser_ext.go index 6fa3b611c5..a1f2fbb959 100644 --- a/ast/parser_ext.go +++ b/ast/parser_ext.go @@ -22,7 +22,13 @@ import ( // MustParseBody returns a parsed body. // If an error occurs during parsing, panic. func MustParseBody(input string) Body { - parsed, err := ParseBody(input) + return MustParseBodyWithOpts(input, ParserOptions{}) +} + +// MustParseBodyWithOpts returns a parsed body. +// If an error occurs during parsing, panic. +func MustParseBodyWithOpts(input string, opts ParserOptions) Body { + parsed, err := ParseBodyWithOpts(input, opts) if err != nil { panic(err) } @@ -52,7 +58,13 @@ func MustParseImports(input string) []*Import { // MustParseModule returns a parsed module. // If an error occurs during parsing, panic. func MustParseModule(input string) *Module { - parsed, err := ParseModule("", input) + return MustParseModuleWithOpts(input, ParserOptions{}) +} + +// MustParseModuleWithOpts returns a parsed module. +// If an error occurs during parsing, panic. +func MustParseModuleWithOpts(input string, opts ParserOptions) *Module { + parsed, err := ParseModuleWithOpts("", input, opts) if err != nil { panic(err) } From 4f9e3da90f9b1bc98d691cbd01b8556561644e3c Mon Sep 17 00:00:00 2001 From: Stephan Renatus Date: Mon, 10 Jan 2022 13:38:02 +0100 Subject: [PATCH 04/20] ast/compile: 'every' rewriting (dynamic, declared vars) Signed-off-by: Stephan Renatus --- ast/compile.go | 72 +++++++++++++++++++++++++++++++++++++++----- ast/compile_test.go | 72 +++++++++++++++++++++++++++++++++++++++++--- ast/compilehelper.go | 3 +- ast/policy.go | 12 ++++++++ 4 files changed, 146 insertions(+), 13 deletions(-) diff --git a/ast/compile.go b/ast/compile.go index 7cc4285d9f..ca39937738 100644 --- a/ast/compile.go +++ b/ast/compile.go @@ -3035,6 +3035,8 @@ func outputVarsForExpr(expr *Expr, arity func(Ref) int, safe VarSet) VarSet { } return outputVarsForExprCall(expr, ar, safe, terms) + case *Every: + return outputVarsForTerms(terms.Domain, safe) default: panic("illegal expression") } @@ -3088,7 +3090,7 @@ func outputVarsForExprCall(expr *Expr, arity int, safe VarSet, terms []*Term) Va return output } -func outputVarsForTerms(expr *Expr, safe VarSet) VarSet { +func outputVarsForTerms(expr interface{}, safe VarSet) VarSet { output := VarSet{} WalkTerms(expr, func(x *Term) bool { switch r := x.Value.(type) { @@ -3306,6 +3308,20 @@ func resolveRefsInExpr(globals map[Var]Ref, ignore *declaredVarStack, expr *Expr if val, ok := ts.Symbols[0].Value.(Call); ok { cpy.Terms = &SomeDecl{Symbols: []*Term{CallTerm(resolveRefsInTermSlice(globals, ignore, val)...)}} } + case *Every: + locals := NewVarSet() + if ts.Key != nil { + locals.Update(ts.Key.Vars()) + } + locals.Update(ts.Value.Vars()) + ignore.Push(locals) + cpy.Terms = &Every{ + Key: ts.Key.Copy(), // TODO(sr): do more? + Value: ts.Value.Copy(), // TODO(sr): do more? + Domain: resolveRefsInTerm(globals, ignore, ts.Domain), + Body: resolveRefsInBody(globals, ignore, ts.Body), + } + ignore.Pop() } for _, w := range cpy.With { w.Target = resolveRefsInTerm(globals, ignore, w.Target) @@ -3553,11 +3569,14 @@ func rewriteEquals(x interface{}) { func rewriteDynamics(f *equalityFactory, body Body) Body { result := make(Body, 0, len(body)) for _, expr := range body { - if expr.IsEquality() { + switch { + case expr.IsEquality(): result = rewriteDynamicsEqExpr(f, expr, result) - } else if expr.IsCall() { + case expr.IsCall(): result = rewriteDynamicsCallExpr(f, expr, result) - } else { + case expr.IsEvery(): + result = rewriteDynamicsEveryExpr(f, expr, result) + default: result = rewriteDynamicsTermExpr(f, expr, result) } } @@ -3587,6 +3606,13 @@ func rewriteDynamicsCallExpr(f *equalityFactory, expr *Expr, result Body) Body { return appendExpr(result, expr) } +func rewriteDynamicsEveryExpr(f *equalityFactory, expr *Expr, result Body) Body { + ev := expr.Terms.(*Every) + result, ev.Domain = rewriteDynamicsOne(expr, f, ev.Domain, result) + ev.Body = rewriteDynamics(f, ev.Body) + return appendExpr(result, expr) +} + func rewriteDynamicsTermExpr(f *equalityFactory, expr *Expr, result Body) Body { term := expr.Terms.(*Term) result, expr.Terms = rewriteDynamicsInTerm(expr, f, term, result) @@ -3733,6 +3759,12 @@ func expandExpr(gen *localVarGenerator, expr *Expr) (result []*Expr) { result = append(result, extras...) } result = append(result, expr) + case *Every: + var extras []*Expr + extras, terms.Domain = expandExprTerm(gen, terms.Domain) + terms.Body = rewriteExprTermsInBody(gen, terms.Body) + result = append(result, extras...) + result = append(result, expr) } return } @@ -3991,11 +4023,14 @@ func rewriteDeclaredVarsInBody(g *localVarGenerator, stack *localDeclaredVars, u for i := range body { var expr *Expr - if body[i].IsAssignment() { + switch { + case body[i].IsAssignment(): expr, errs = rewriteDeclaredAssignment(g, stack, body[i], errs, strict) - } else if _, ok := body[i].Terms.(*SomeDecl); ok { + case body[i].IsSome(): expr, errs = rewriteSomeDeclStatement(g, stack, body[i], errs, strict) - } else { + case body[i].IsEvery(): + expr, errs = rewriteEveryStatement(g, stack, body[i], errs, strict) + default: expr, errs = rewriteDeclaredVarsInExpr(g, stack, body[i], errs, strict) } if expr != nil { @@ -4085,6 +4120,29 @@ func checkUnusedDeclaredVars(loc *Location, stack *localDeclaredVars, used VarSe return errs } +func rewriteEveryStatement(g *localVarGenerator, stack *localDeclaredVars, expr *Expr, errs Errors, strict bool) (*Expr, Errors) { + e := expr.Copy() + every := e.Terms.(*Every) + declared := NewVarSet() + if every.Key != nil { + declared.Update(every.Key.Vars()) + } + declared.Update(every.Value.Vars()) + for _, v := range declared.Sorted() { + if _, err := rewriteDeclaredVar(g, stack, v, declaredVar); err != nil { + return nil, append(errs, NewError(CompileErr, every.Loc(), err.Error())) + } + } + used := NewVarSet() + if every.Key != nil { + errs = rewriteDeclaredVarsInTermRecursive(g, stack, every.Key, errs, strict) + } + errs = rewriteDeclaredVarsInTermRecursive(g, stack, every.Value, errs, strict) + errs = rewriteDeclaredVarsInTermRecursive(g, stack, every.Domain, errs, strict) + every.Body, errs = rewriteDeclaredVarsInBody(g, stack, used, every.Body, errs, strict) + return e, errs +} + func rewriteSomeDeclStatement(g *localVarGenerator, stack *localDeclaredVars, expr *Expr, errs Errors, strict bool) (*Expr, Errors) { e := expr.Copy() decl := e.Terms.(*SomeDecl) diff --git a/ast/compile_test.go b/ast/compile_test.go index 1262e16e24..673e14982c 100644 --- a/ast/compile_test.go +++ b/ast/compile_test.go @@ -203,12 +203,26 @@ func TestOutputVarsForNode(t *testing.T) { query: `z = "abc"; x = split(z, a)[y]`, exp: `{z}`, }, + { + note: "every: simple: no output vars", + query: `every k, v in [1, 2] { k < v }`, + exp: `set()`, + }, + { + note: "every: output vars in domain", + query: `xs = []; every k, v in xs[i] { k < v }`, + exp: `{xs, i}`, + }, } for _, tc := range tests { t.Run(tc.note, func(t *testing.T) { - body := MustParseBody(tc.query) + opts := ParserOptions{AllFutureKeywords: true, unreleasedKeywords: true} + body, err := ParseBodyWithOpts(tc.query, opts) + if err != nil { + t.Fatal(err) + } arity := func(r Ref) int { a, ok := tc.arities[r.String()] if !ok { @@ -1654,6 +1668,19 @@ p[foo[bar[i]]] = {"baz": baz} { true }`) r = [y | y = f(1)[0]] `) + c.Modules["everykw"] = MustParseModuleWithOpts(`package everykw + + nums = {1, 2, 3} + f(_) = true + x = 100 + xs = [1, 2, 3] + p { + every x in xs { + nums[x] + x > 10 + } + }`, ParserOptions{unreleasedKeywords: true, FutureKeywords: []string{"every", "in"}}) + compileStages(c, c.resolveAllRefs) assertNotFailed(t, c) @@ -1780,6 +1807,17 @@ p[foo[bar[i]]] = {"baz": baz} { true }`) assertTermEqual(t, someInAssignCall[2], VarTerm("v")) collectionLastElem = someInAssignCall[3].Value.(*Array).Get(IntNumberTerm(2)) assertTermEqual(t, collectionLastElem, MustParseTerm("data.someinassignwithkey.y")) + + mod16 := c.Modules["everykw"] + everyExpr := mod16.Rules[len(mod16.Rules)-1].Body[0].Terms.(*Every) + assertTermEqual(t, everyExpr.Body[0].Terms.(*Term), MustParseTerm("data.everykw.nums[x]")) + assertTermEqual(t, everyExpr.Domain, MustParseTerm("data.everykw.xs")) + + // 'x' is not resolved + assertTermEqual(t, everyExpr.Value, VarTerm("x")) + gt10 := MustParseExpr("x > 10") + gt10.Index++ // TODO(sr): why? + assertExprEqual(t, everyExpr.Body[1], gt10) } func TestCompilerResolveErrors(t *testing.T) { @@ -2614,6 +2652,29 @@ func TestRewriteDeclaredVars(t *testing.T) { p = true { __local2__ = data.test.i; __local1__ = data.test.xs[__local2__][__local0__]; __local1__ = "a"; __local0__ = 2 } `, }, + { + note: "rewrite every", + module: ` + package test + # import future.keywords.in + # import future.keywords.every + i = 0 + xs = [1, 2] + k = "foo" + v = "bar" + p { + every k, v in xs { k + v > i } + } + `, + exp: ` + package test + i = 0 + xs = [1, 2] + k = "foo" + v = "bar" + p = true { __local3__ = data.test.xs; every __local0__, __local1__ in __local3__ { plus(__local0__, __local1__, __local2__); __local4__ = data.test.i; gt(__local2__, __local4__) } } + `, + }, { note: "rewrite closures", module: ` @@ -2757,7 +2818,8 @@ func TestRewriteDeclaredVars(t *testing.T) { for _, tc := range tests { t.Run(tc.note, func(t *testing.T) { - compiler, err := CompileModules(map[string]string{"test.rego": tc.module}) + opts := CompileOpts{ParserOptions: ParserOptions{FutureKeywords: []string{"in", "every"}, unreleasedKeywords: true}} + compiler, err := CompileModulesWithOpt(map[string]string{"test.rego": tc.module}, opts) if tc.wantErr != nil { if err == nil { t.Fatal("Expected error but got success") @@ -2768,7 +2830,7 @@ func TestRewriteDeclaredVars(t *testing.T) { } else if err != nil { t.Fatal(err) } else { - exp := MustParseModule(tc.exp) + exp := MustParseModuleWithOpts(tc.exp, opts.ParserOptions) result := compiler.Modules["test.rego"] if exp.Compare(result) != 0 { t.Fatalf("Expected:\n\n%v\n\nGot:\n\n%v", exp, result) @@ -2938,10 +3000,10 @@ func TestCompilerRewriteDynamicTerms(t *testing.T) { t.Run(tc.input, func(t *testing.T) { c := NewCompiler() module := fixture + tc.input - c.Modules["test"] = MustParseModule(module) + c.Modules["test"] = MustParseModuleWithOpts(module, ParserOptions{AllFutureKeywords: true}) compileStages(c, c.rewriteDynamicTerms) assertNotFailed(t, c) - expected := MustParseBody(tc.expected) + expected := MustParseBodyWithOpts(tc.expected, ParserOptions{AllFutureKeywords: true}) result := c.Modules["test"].Rules[1].Body if result.Compare(expected) != 0 { t.Fatalf("\nExp: %v\nGot: %v", expected, result) diff --git a/ast/compilehelper.go b/ast/compilehelper.go index ca75dfabae..dd48884f9d 100644 --- a/ast/compilehelper.go +++ b/ast/compilehelper.go @@ -13,6 +13,7 @@ func CompileModules(modules map[string]string) (*Compiler, error) { // CompileOpts defines a set of options for the compiler. type CompileOpts struct { EnablePrintStatements bool + ParserOptions ParserOptions } // CompileModulesWithOpt takes a set of Rego modules represented as strings and @@ -24,7 +25,7 @@ func CompileModulesWithOpt(modules map[string]string, opts CompileOpts) (*Compil for f, module := range modules { var pm *Module var err error - if pm, err = ParseModule(f, module); err != nil { + if pm, err = ParseModuleWithOpts(f, module, opts.ParserOptions); err != nil { return nil, err } parsed[f] = pm diff --git a/ast/policy.go b/ast/policy.go index ef997638e4..5265f42f5b 100644 --- a/ast/policy.go +++ b/ast/policy.go @@ -1262,6 +1262,18 @@ func (expr *Expr) IsCall() bool { return ok } +// IsEvery returns true if this expression is an 'every' expression. +func (expr *Expr) IsEvery() bool { + _, ok := expr.Terms.(*Every) + return ok +} + +// IsSome returns true if this expression is a 'some' expression. +func (expr *Expr) IsSome() bool { + _, ok := expr.Terms.(*SomeDecl) + return ok +} + // Operator returns the name of the function or built-in this expression refers // to. If this expression is not a function call, returns nil. func (expr *Expr) Operator() Ref { From 6f0b910943c6cb9622c3fb032960bb7944e0d795 Mon Sep 17 00:00:00 2001 From: Stephan Renatus Date: Tue, 18 Jan 2022 10:59:08 +0100 Subject: [PATCH 05/20] ast/compile: simplify every rewriting (key/val are vars) Signed-off-by: Stephan Renatus --- ast/compile.go | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/ast/compile.go b/ast/compile.go index ca39937738..190d183148 100644 --- a/ast/compile.go +++ b/ast/compile.go @@ -4123,22 +4123,28 @@ func checkUnusedDeclaredVars(loc *Location, stack *localDeclaredVars, used VarSe func rewriteEveryStatement(g *localVarGenerator, stack *localDeclaredVars, expr *Expr, errs Errors, strict bool) (*Expr, Errors) { e := expr.Copy() every := e.Terms.(*Every) - declared := NewVarSet() + + errs = rewriteDeclaredVarsInTermRecursive(g, stack, every.Domain, errs, strict) + + stack.Push() + + // optionally rewrite the key if every.Key != nil { - declared.Update(every.Key.Vars()) - } - declared.Update(every.Value.Vars()) - for _, v := range declared.Sorted() { - if _, err := rewriteDeclaredVar(g, stack, v, declaredVar); err != nil { + gv, err := rewriteDeclaredVar(g, stack, every.Key.Value.(Var), declaredVar) + if err != nil { return nil, append(errs, NewError(CompileErr, every.Loc(), err.Error())) } + every.Key.Value = gv } - used := NewVarSet() - if every.Key != nil { - errs = rewriteDeclaredVarsInTermRecursive(g, stack, every.Key, errs, strict) + + // value is always present + gv, err := rewriteDeclaredVar(g, stack, every.Value.Value.(Var), declaredVar) + if err != nil { + return nil, append(errs, NewError(CompileErr, every.Loc(), err.Error())) } - errs = rewriteDeclaredVarsInTermRecursive(g, stack, every.Value, errs, strict) - errs = rewriteDeclaredVarsInTermRecursive(g, stack, every.Domain, errs, strict) + every.Value.Value = gv + + used := NewVarSet() every.Body, errs = rewriteDeclaredVarsInBody(g, stack, used, every.Body, errs, strict) return e, errs } From 3227f5de3c9e42e484a0745fa1e29a4e3dc0a00f Mon Sep 17 00:00:00 2001 From: Stephan Renatus Date: Tue, 18 Jan 2022 11:32:27 +0100 Subject: [PATCH 06/20] ast/compile: deal with wildcard cases Signed-off-by: Stephan Renatus --- ast/compile.go | 20 ++++++----- ast/compile_test.go | 88 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 100 insertions(+), 8 deletions(-) diff --git a/ast/compile.go b/ast/compile.go index 190d183148..452b292806 100644 --- a/ast/compile.go +++ b/ast/compile.go @@ -4130,19 +4130,23 @@ func rewriteEveryStatement(g *localVarGenerator, stack *localDeclaredVars, expr // optionally rewrite the key if every.Key != nil { - gv, err := rewriteDeclaredVar(g, stack, every.Key.Value.(Var), declaredVar) - if err != nil { - return nil, append(errs, NewError(CompileErr, every.Loc(), err.Error())) + if v := every.Key.Value.(Var); !v.IsWildcard() { // TODO + gv, err := rewriteDeclaredVar(g, stack, v, declaredVar) + if err != nil { + return nil, append(errs, NewError(CompileErr, every.Loc(), err.Error())) + } + every.Key.Value = gv } - every.Key.Value = gv } // value is always present - gv, err := rewriteDeclaredVar(g, stack, every.Value.Value.(Var), declaredVar) - if err != nil { - return nil, append(errs, NewError(CompileErr, every.Loc(), err.Error())) + if v := every.Value.Value.(Var); !v.IsWildcard() { + gv, err := rewriteDeclaredVar(g, stack, v, declaredVar) + if err != nil { + return nil, append(errs, NewError(CompileErr, every.Loc(), err.Error())) + } + every.Value.Value = gv } - every.Value.Value = gv used := NewVarSet() every.Body, errs = rewriteDeclaredVarsInBody(g, stack, used, every.Body, errs, strict) diff --git a/ast/compile_test.go b/ast/compile_test.go index 673e14982c..787db9c384 100644 --- a/ast/compile_test.go +++ b/ast/compile_test.go @@ -2675,6 +2675,94 @@ func TestRewriteDeclaredVars(t *testing.T) { p = true { __local3__ = data.test.xs; every __local0__, __local1__ in __local3__ { plus(__local0__, __local1__, __local2__); __local4__ = data.test.i; gt(__local2__, __local4__) } } `, }, + { + note: "rewrite every: unused key var", + module: ` + package test + # import future.keywords.in + # import future.keywords.every + p { + every k, v in [1] { v >= i } + } + `, + wantErr: errors.New("declared var k unused"), + }, + { + note: "rewrite every: unused value var", + module: ` + package test + # import future.keywords.in + # import future.keywords.every + p { + every v in [1] { true } + } + `, + wantErr: errors.New("declared var v unused"), + }, + { + note: "rewrite every: wildcard value var, used key", + module: ` + package test + # import future.keywords.in + # import future.keywords.every + p { + every k, _ in [1] { k >= 0 } + } + `, + exp: ` + package test + p = true { every __local0__, _ in [1] { gte(__local0__, 0) } } + `, + }, + { + note: "rewrite every: wildcard key+value var", // NOTE(sr): may be silly, but valid + module: ` + package test + # import future.keywords.in + # import future.keywords.every + p { + every _, _ in [1] { true } + } + `, + exp: ` + package test + p = true { every _, _ in [1] { true } } + `, + }, + { + note: "rewrite every: declared vars with different scopes", + module: ` + package test + # import future.keywords.in + # import future.keywords.every + p { + some x + x = 10 + every x in [1] { x == 1 } + } + `, + exp: ` + package test + p = true { __local0__ = 10; every __local1__ in [1] { equal(__local1__, 1) } } + `, + }, + { + note: "rewrite every: declared vars used in body", + module: ` + package test + # import future.keywords.in + # import future.keywords.every + p { + some y + y = 10 + every x in [1] { x == y } + } + `, + exp: ` + package test + p = true { __local0__ = 10; every __local1__ in [1] { equal(__local1__, __local0__) } } + `, + }, { note: "rewrite closures", module: ` From 22ced5d7de9a6441bb26ea5c87081dfbadce7231 Mon Sep 17 00:00:00 2001 From: Stephan Renatus Date: Tue, 18 Jan 2022 11:59:18 +0100 Subject: [PATCH 07/20] ast/compile: add missing stack.Pop() Signed-off-by: Stephan Renatus --- ast/compile.go | 1 + ast/compile_test.go | 17 +++++++++++++++++ 2 files changed, 18 insertions(+) diff --git a/ast/compile.go b/ast/compile.go index 452b292806..1207c144aa 100644 --- a/ast/compile.go +++ b/ast/compile.go @@ -4127,6 +4127,7 @@ func rewriteEveryStatement(g *localVarGenerator, stack *localDeclaredVars, expr errs = rewriteDeclaredVarsInTermRecursive(g, stack, every.Domain, errs, strict) stack.Push() + defer stack.Pop() // optionally rewrite the key if every.Key != nil { diff --git a/ast/compile_test.go b/ast/compile_test.go index 787db9c384..a9e57a698d 100644 --- a/ast/compile_test.go +++ b/ast/compile_test.go @@ -2763,6 +2763,23 @@ func TestRewriteDeclaredVars(t *testing.T) { p = true { __local0__ = 10; every __local1__ in [1] { equal(__local1__, __local0__) } } `, }, + { + note: "rewrite every: pops declared var stack", + module: ` + package test + # import future.keywords.in + # import future.keywords.every + p[x] { + some x + x = 10 + every _ in [1] { true } + } + `, + exp: ` + package test + p[__local0__] { __local0__ = 10; every _ in [1] { true } } + `, + }, { note: "rewrite closures", module: ` From b757da814bcd0908b8eed80f25672b62c7eb0992 Mon Sep 17 00:00:00 2001 From: Stephan Renatus Date: Tue, 18 Jan 2022 11:59:42 +0100 Subject: [PATCH 08/20] ast/compile_test: fix indentation Signed-off-by: Stephan Renatus --- ast/compile_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ast/compile_test.go b/ast/compile_test.go index a9e57a698d..686f63204b 100644 --- a/ast/compile_test.go +++ b/ast/compile_test.go @@ -2848,7 +2848,7 @@ func TestRewriteDeclaredVars(t *testing.T) { data.test.f(__local2__, "bar") } - f(__local0__, __local1__) = true { + f(__local0__, __local1__) = true { __local0__[__local1__] } `, From 00e3896e68f5c6f7d27841be061f218049750e67 Mon Sep 17 00:00:00 2001 From: Stephan Renatus Date: Tue, 18 Jan 2022 13:46:08 +0100 Subject: [PATCH 09/20] ast/compile: fix output vars of 'every' body Signed-off-by: Stephan Renatus --- ast/compile.go | 25 ++++++++++++++----------- ast/compile_test.go | 5 +++++ 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/ast/compile.go b/ast/compile.go index 1207c144aa..3a40b677f0 100644 --- a/ast/compile.go +++ b/ast/compile.go @@ -3036,7 +3036,16 @@ func outputVarsForExpr(expr *Expr, arity func(Ref) int, safe VarSet) VarSet { return outputVarsForExprCall(expr, ar, safe, terms) case *Every: - return outputVarsForTerms(terms.Domain, safe) + s := outputVarsForTerms(terms.Domain, safe) + + cpy := safe.Copy() + if terms.Key != nil { + cpy.Add(terms.Key.Value.(Var)) + } + cpy.Add(terms.Value.Value.(Var)) + + s.Update(outputVarsForBody(terms.Body, arity, cpy)) + return s default: panic("illegal expression") } @@ -3064,13 +3073,13 @@ func outputVarsForExprCall(expr *Expr, arity int, safe VarSet, terms []*Term) Va return output } - vis := NewVarVisitor().WithParams(VarVisitorParams{ + params := VarVisitorParams{ SkipClosures: true, SkipSets: true, SkipObjectKeys: true, SkipRefHead: true, - }) - + } + vis := NewVarVisitor().WithParams(params) vis.Walk(Args(terms[:numInputTerms])) unsafe := vis.Vars().Diff(output).Diff(safe) @@ -3078,13 +3087,7 @@ func outputVarsForExprCall(expr *Expr, arity int, safe VarSet, terms []*Term) Va return VarSet{} } - vis = NewVarVisitor().WithParams(VarVisitorParams{ - SkipRefHead: true, - SkipSets: true, - SkipObjectKeys: true, - SkipClosures: true, - }) - + vis = NewVarVisitor().WithParams(params) vis.Walk(Args(terms[numInputTerms:])) output.Update(vis.vars) return output diff --git a/ast/compile_test.go b/ast/compile_test.go index 686f63204b..53980958f0 100644 --- a/ast/compile_test.go +++ b/ast/compile_test.go @@ -213,6 +213,11 @@ func TestOutputVarsForNode(t *testing.T) { query: `xs = []; every k, v in xs[i] { k < v }`, exp: `{xs, i}`, }, + { + note: "every: output vars in body", + query: `every k, v in [] { k < v; i = 1 }`, + exp: `{i}`, + }, } for _, tc := range tests { From ddd81d8028bb51e6ba48ee3fd219983eff22ba9e Mon Sep 17 00:00:00 2001 From: Stephan Renatus Date: Tue, 18 Jan 2022 14:19:33 +0100 Subject: [PATCH 10/20] ast/compile: safety check vars in every.Body Signed-off-by: Stephan Renatus --- ast/compile.go | 11 +++++++++-- ast/compile_test.go | 15 ++++++++++----- ast/policy.go | 15 +++++++++++++++ ast/visit.go | 8 ++++++-- 4 files changed, 40 insertions(+), 9 deletions(-) diff --git a/ast/compile.go b/ast/compile.go index 3a40b677f0..afdbf89681 100644 --- a/ast/compile.go +++ b/ast/compile.go @@ -2864,7 +2864,8 @@ type bodySafetyTransformer struct { } func (xform *bodySafetyTransformer) Visit(x interface{}) bool { - if term, ok := x.(*Term); ok { + switch term := x.(type) { + case *Term: switch x := term.Value.(type) { case *object: cpy, _ := x.Map(func(k, v *Term) (*Term, *Term, error) { @@ -2894,6 +2895,12 @@ func (xform *bodySafetyTransformer) Visit(x interface{}) bool { xform.reorderSetComprehensionSafety(x) return true } + case *Expr: + if ev, ok := term.Terms.(*Every); ok { + xform.globals.Update(ev.Vars()) + ev.Body = xform.reorderComprehensionSafety(NewVarSet(), ev.Body) + return true + } } return false } @@ -4134,7 +4141,7 @@ func rewriteEveryStatement(g *localVarGenerator, stack *localDeclaredVars, expr // optionally rewrite the key if every.Key != nil { - if v := every.Key.Value.(Var); !v.IsWildcard() { // TODO + if v := every.Key.Value.(Var); !v.IsWildcard() { gv, err := rewriteDeclaredVar(g, stack, v, declaredVar) if err != nil { return nil, append(errs, NewError(CompileErr, every.Loc(), err.Error())) diff --git a/ast/compile_test.go b/ast/compile_test.go index 53980958f0..6ca1c3673a 100644 --- a/ast/compile_test.go +++ b/ast/compile_test.go @@ -692,15 +692,18 @@ func TestCompilerCheckSafetyBodyReordering(t *testing.T) { contains(x, "oo") `}, {"userfunc", `split(y, ".", z); data.a.b.funcs.fn("...foo.bar..", y)`, `data.a.b.funcs.fn("...foo.bar..", y); split(y, ".", z)`}, + {"every", `every _ in [] { x != 1 }; x = 1`, `x = 1; every _ in [] { x != 1}`}, + {"every-domain", `every _ in xs { true }; xs = [1]`, `xs = [1]; every _ in xs { true }`}, } for i, tc := range tests { t.Run(tc.note, func(t *testing.T) { + opts := ParserOptions{AllFutureKeywords: true, unreleasedKeywords: true} c := NewCompiler() c.Modules = getCompilerTestModules() - c.Modules["reordering"] = MustParseModule(fmt.Sprintf( + c.Modules["reordering"] = MustParseModuleWithOpts(fmt.Sprintf( `package test - p { %s }`, tc.body)) + p { %s }`, tc.body), opts) compileStages(c, c.checkSafetyRuleBodies) @@ -709,7 +712,7 @@ func TestCompilerCheckSafetyBodyReordering(t *testing.T) { return } - expected := MustParseBody(tc.expected) + expected := MustParseBodyWithOpts(tc.expected, opts) result := c.Modules["reordering"].Rules[0].Body if !expected.Equal(result) { @@ -807,6 +810,7 @@ func TestCompilerCheckSafetyBodyErrors(t *testing.T) { {"call-too-few", "p { f(1,x) } f(x,y) { true }", "{x,}"}, {"object-key-comprehension", "p { { {p|x}: 0 } }", "{x,}"}, {"set-value-comprehension", "p { {1, {p|x}} }", "{x,}"}, + {"every", "p { every y in [10] { x > y } }", "{x,}"}, } makeErrMsg := func(varName string) string { @@ -827,15 +831,16 @@ func TestCompilerCheckSafetyBodyErrors(t *testing.T) { sort.Strings(expected) // Compile test module. + popts := ParserOptions{AllFutureKeywords: true, unreleasedKeywords: true} c := NewCompiler() c.Modules = map[string]*Module{ - "newMod": MustParseModule(fmt.Sprintf(` + "newMod": MustParseModuleWithOpts(fmt.Sprintf(` %v %v - `, moduleBegin, tc.moduleContent)), + `, moduleBegin, tc.moduleContent), popts), } compileStages(c, c.checkSafetyRuleBodies) diff --git a/ast/policy.go b/ast/policy.go index 5265f42f5b..412ddf560e 100644 --- a/ast/policy.go +++ b/ast/policy.go @@ -1491,6 +1491,21 @@ func (q *Every) Compare(other *Every) int { return q.Body.Compare(other.Body) } +// Vars returns the key and val arguments of an every expression, +// if they are non-nil and not wildcards. +func (q *Every) Vars() VarSet { + r := NewVarSet() + if v := q.Value.Value.(Var); !v.IsWildcard() { + r.Add(v) + } + if q.Key != nil { + if v := q.Key.Value.(Var); !v.IsWildcard() { + r.Add(v) + } + } + return r +} + func (w *With) String() string { return "with " + w.Target.String() + " as " + w.Value.String() } diff --git a/ast/visit.go b/ast/visit.go index f1950a54b4..06255b4b9a 100644 --- a/ast/visit.go +++ b/ast/visit.go @@ -160,9 +160,11 @@ func WalkVars(x interface{}, f func(Var) bool) { // returns true, AST nodes under the last node will not be visited. func WalkClosures(x interface{}, f func(interface{}) bool) { vis := &GenericVisitor{func(x interface{}) bool { - switch x.(type) { + switch x := x.(type) { case *ArrayComprehension, *ObjectComprehension, *SetComprehension: return f(x) + case *Every: + return f(x.Body) } return false }} @@ -566,7 +568,9 @@ func (vis *VarVisitor) visit(v interface{}) bool { case *ArrayComprehension, *ObjectComprehension, *SetComprehension: return true case *Expr: - if _, ok := v.Terms.(*Every); ok { + if ev, ok := v.Terms.(*Every); ok { + vis.Walk(ev.Domain) + // We're _not_ walking ev.Body -- that's the closure here return true } } From 8e78ae71e1c56dbb02150164f7281234769d09d0 Mon Sep 17 00:00:00 2001 From: Stephan Renatus Date: Wed, 19 Jan 2022 10:22:47 +0100 Subject: [PATCH 11/20] ast/compile_test: add nested case for every rewriting Signed-off-by: Stephan Renatus --- ast/compile_test.go | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/ast/compile_test.go b/ast/compile_test.go index 6ca1c3673a..3c3303e745 100644 --- a/ast/compile_test.go +++ b/ast/compile_test.go @@ -2790,6 +2790,33 @@ func TestRewriteDeclaredVars(t *testing.T) { p[__local0__] { __local0__ = 10; every _ in [1] { true } } `, }, + { + note: "rewrite every: nested", + module: ` + package test + # import future.keywords.in + # import future.keywords.every + p { + xs := [[1], [2]] + every v in [1] { + every w in xs[v] { + w == 2 + } + } + } + `, + exp: ` + package test + p = true { + __local0__ = [[1], [2]] + every __local1__ in [1] { + __local3__ = __local0__[__local1__] + every __local2__ in __local3__ { + equal(__local2__, 2) + } + } + } + `}, { note: "rewrite closures", module: ` From 8ac09b4c7894c250ecd6f6d62243bec39c14f68d Mon Sep 17 00:00:00 2001 From: Stephan Renatus Date: Wed, 19 Jan 2022 10:52:22 +0100 Subject: [PATCH 12/20] ast/compiler: rewrite prints in 'every' bodies Signed-off-by: Stephan Renatus --- ast/compile.go | 7 ++++++- ast/compile_test.go | 34 ++++++++++++++++++++++++++++++---- ast/visit.go | 4 +--- 3 files changed, 37 insertions(+), 8 deletions(-) diff --git a/ast/compile.go b/ast/compile.go index afdbf89681..db3b04b8e8 100644 --- a/ast/compile.go +++ b/ast/compile.go @@ -1452,7 +1452,7 @@ func rewritePrintCalls(gen *localVarGenerator, getArity func(Ref) int, globals V // Visit comprehension bodies recursively to ensure print statements inside // those bodies only close over variables that are safe. for i := range body { - if ContainsComprehensions(body[i]) { + if ContainsComprehensions(body[i]) || body[i].IsEvery() { safe := outputVarsForBody(body[:i], getArity, globals) safe.Update(globals) WalkClosures(body[i], func(x interface{}) bool { @@ -1463,6 +1463,9 @@ func rewritePrintCalls(gen *localVarGenerator, getArity func(Ref) int, globals V errs = rewritePrintCalls(gen, getArity, safe, x.Body) case *ObjectComprehension: errs = rewritePrintCalls(gen, getArity, safe, x.Body) + case *Every: + safe.Update(x.Vars()) + errs = rewritePrintCalls(gen, getArity, safe, x.Body) } return true }) @@ -1524,6 +1527,8 @@ func erasePrintCalls(node interface{}) { x.Body = erasePrintCallsInBody(x.Body) case *ObjectComprehension: x.Body = erasePrintCallsInBody(x.Body) + case *Every: + x.Body = erasePrintCallsInBody(x.Body) } return false }).Walk(node) diff --git a/ast/compile_test.go b/ast/compile_test.go index 3c3303e745..eb7328a533 100644 --- a/ast/compile_test.go +++ b/ast/compile_test.go @@ -3293,6 +3293,16 @@ func TestCompilerRewritePrintCallsErasure(t *testing.T) { p { {"x": 1 | false} } `, }, + { + note: "every body", + module: `package test + + p { every _ in [] { false; print(1) } } + `, + exp: `package test + + p { every _ in [] { false } }`, + }, { note: "in head", module: `package test @@ -3307,13 +3317,14 @@ func TestCompilerRewritePrintCallsErasure(t *testing.T) { for _, tc := range cases { t.Run(tc.note, func(t *testing.T) { c := NewCompiler().WithEnablePrintStatements(false) + opts := ParserOptions{AllFutureKeywords: true, unreleasedKeywords: true} c.Compile(map[string]*Module{ - "test.rego": MustParseModule(tc.module), + "test.rego": MustParseModuleWithOpts(tc.module, opts), }) if c.Failed() { t.Fatal(c.Errors) } - exp := MustParseModule(tc.exp) + exp := MustParseModuleWithOpts(tc.exp, opts) if !exp.Equal(c.Modules["test.rego"]) { t.Fatalf("Expected:\n\n%v\n\nGot:\n\n%v", exp, c.Modules["test.rego"]) } @@ -3417,6 +3428,20 @@ func TestCompilerRewritePrintCalls(t *testing.T) { p = true { x = 1; {"x": 2 | __local1__ = {__local0__ | __local0__ = x}; internal.print([__local1__])} }`, }, + { + note: "print inside every", + module: `package test + + p { every x in [1,2] { print(x) } }`, + exp: `package test + + p = true { + every __local0__ in [1, 2] { + __local2__ = {__local1__ | __local1__ = __local0__} + internal.print([__local2__]) + } + }`, + }, { note: "print output of nested call", module: `package test @@ -3503,13 +3528,14 @@ func TestCompilerRewritePrintCalls(t *testing.T) { for _, tc := range cases { t.Run(tc.note, func(t *testing.T) { c := NewCompiler().WithEnablePrintStatements(true) + opts := ParserOptions{AllFutureKeywords: true, unreleasedKeywords: true} c.Compile(map[string]*Module{ - "test.rego": MustParseModule(tc.module), + "test.rego": MustParseModuleWithOpts(tc.module, opts), }) if c.Failed() { t.Fatal(c.Errors) } - exp := MustParseModule(tc.exp) + exp := MustParseModuleWithOpts(tc.exp, opts) if !exp.Equal(c.Modules["test.rego"]) { t.Fatalf("Expected:\n\n%v\n\nGot:\n\n%v", exp, c.Modules["test.rego"]) } diff --git a/ast/visit.go b/ast/visit.go index 06255b4b9a..521523e591 100644 --- a/ast/visit.go +++ b/ast/visit.go @@ -161,10 +161,8 @@ func WalkVars(x interface{}, f func(Var) bool) { func WalkClosures(x interface{}, f func(interface{}) bool) { vis := &GenericVisitor{func(x interface{}) bool { switch x := x.(type) { - case *ArrayComprehension, *ObjectComprehension, *SetComprehension: + case *ArrayComprehension, *ObjectComprehension, *SetComprehension, *Every: return f(x) - case *Every: - return f(x.Body) } return false }} From 0b72c1cd6919b57cd225305f0e49947a18bf2982 Mon Sep 17 00:00:00 2001 From: Stephan Renatus Date: Wed, 19 Jan 2022 10:59:36 +0100 Subject: [PATCH 13/20] ast/compile_test: add "rewrite dynamics" tests for "every" Signed-off-by: Stephan Renatus --- ast/compile_test.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/ast/compile_test.go b/ast/compile_test.go index eb7328a533..f7cf835abe 100644 --- a/ast/compile_test.go +++ b/ast/compile_test.go @@ -3136,16 +3136,20 @@ func TestCompilerRewriteDynamicTerms(t *testing.T) { {`call_with { count(str) with input as 1 }`, `__local0__ = data.test.str with input as 1; count(__local0__) with input as 1`}, {`call_func { f(input, "foo") } f(x,y) { x[y] }`, `__local2__ = input; data.test.f(__local2__, "foo")`}, {`call_func2 { f(input.foo, "foo") } f(x,y) { x[y] }`, `__local2__ = input.foo; data.test.f(__local2__, "foo")`}, + {`every_domain { every _ in str { true } }`, `__local0__ = data.test.str; every _ in __local0__ { true }`}, + {`every_domain_call { every _ in numbers.range(1, 10) { true } }`, `numbers.range(1, 10, __local0__); every _ in __local0__ { true }`}, + {`every_body { every _ in [] { [str] } }`, `every _ in [] { __local0__ = data.test.str; [__local0__] }`}, } for _, tc := range tests { t.Run(tc.input, func(t *testing.T) { c := NewCompiler() + opts := ParserOptions{AllFutureKeywords: true, unreleasedKeywords: true} module := fixture + tc.input - c.Modules["test"] = MustParseModuleWithOpts(module, ParserOptions{AllFutureKeywords: true}) + c.Modules["test"] = MustParseModuleWithOpts(module, opts) compileStages(c, c.rewriteDynamicTerms) assertNotFailed(t, c) - expected := MustParseBodyWithOpts(tc.expected, ParserOptions{AllFutureKeywords: true}) + expected := MustParseBodyWithOpts(tc.expected, opts) result := c.Modules["test"].Rules[1].Body if result.Compare(expected) != 0 { t.Fatalf("\nExp: %v\nGot: %v", expected, result) From 3954eb31b7d0693a5458b8036e225cbe5a2deebc Mon Sep 17 00:00:00 2001 From: Stephan Renatus Date: Thu, 20 Jan 2022 13:56:38 +0100 Subject: [PATCH 14/20] ast/compile: expand "every" domain There's one broken test that I haven't figured out how to fix yet. From the perspective of topdown, this change felt right, it will simplify evaluation. Signed-off-by: Stephan Renatus --- ast/compile.go | 11 ++++++- ast/compile_test.go | 73 +++++++++++++++++++++++++++++++++------------ 2 files changed, 64 insertions(+), 20 deletions(-) diff --git a/ast/compile.go b/ast/compile.go index db3b04b8e8..41adb271a9 100644 --- a/ast/compile.go +++ b/ast/compile.go @@ -3776,7 +3776,16 @@ func expandExpr(gen *localVarGenerator, expr *Expr) (result []*Expr) { result = append(result, expr) case *Every: var extras []*Expr - extras, terms.Domain = expandExprTerm(gen, terms.Domain) + if _, ok := terms.Domain.Value.(Call); ok { + extras, terms.Domain = expandExprTerm(gen, terms.Domain) + } else { + term := NewTerm(gen.Generate()).SetLocation(terms.Domain.Location) + eq := Equality.Expr(term, terms.Domain) + eq.Generated = true + eq.Location = terms.Domain.Location + extras = append(extras, eq) + terms.Domain = term + } terms.Body = rewriteExprTermsInBody(gen, terms.Body) result = append(result, extras...) result = append(result, expr) diff --git a/ast/compile_test.go b/ast/compile_test.go index f7cf835abe..4d088b60c3 100644 --- a/ast/compile_test.go +++ b/ast/compile_test.go @@ -692,7 +692,7 @@ func TestCompilerCheckSafetyBodyReordering(t *testing.T) { contains(x, "oo") `}, {"userfunc", `split(y, ".", z); data.a.b.funcs.fn("...foo.bar..", y)`, `data.a.b.funcs.fn("...foo.bar..", y); split(y, ".", z)`}, - {"every", `every _ in [] { x != 1 }; x = 1`, `x = 1; every _ in [] { x != 1}`}, + {"every", `every _ in [] { x != 1 }; x = 1`, `__local3__ = []; x = 1; every _ in __local3__ { x != 1}`}, {"every-domain", `every _ in xs { true }; xs = [1]`, `xs = [1]; every _ in xs { true }`}, } @@ -1458,13 +1458,26 @@ func TestCompilerRewriteExprTerms(t *testing.T) { f(__local0__[0]) { true; __local0__ = [1] }`, }, + { + note: "every: domain", + module: ` + package test + + p { every x in [1,2] { x } }`, + expected: ` + package test + + p { __local1__ = [1, 2]; every __local0__ in __local1__ { __local0__ } }`, + }, } for _, tc := range cases { t.Run(tc.note, func(t *testing.T) { compiler := NewCompiler() + opts := ParserOptions{AllFutureKeywords: true, unreleasedKeywords: true} + compiler.Modules = map[string]*Module{ - "test": MustParseModule(tc.module), + "test": MustParseModuleWithOpts(tc.module, opts), } compileStages(compiler, compiler.rewriteExprTerms) @@ -1472,7 +1485,7 @@ func TestCompilerRewriteExprTerms(t *testing.T) { case string: assertNotFailed(t, compiler) - expected := MustParseModule(exp) + expected := MustParseModuleWithOpts(exp, opts) if !expected.Equal(compiler.Modules["test"]) { t.Fatalf("Expected modules to be equal. Expected:\n\n%v\n\nGot:\n\n%v", expected, compiler.Modules["test"]) @@ -2682,8 +2695,14 @@ func TestRewriteDeclaredVars(t *testing.T) { xs = [1, 2] k = "foo" v = "bar" - p = true { __local3__ = data.test.xs; every __local0__, __local1__ in __local3__ { plus(__local0__, __local1__, __local2__); __local4__ = data.test.i; gt(__local2__, __local4__) } } - `, + p = true { + __local2__ = data.test.xs + every __local0__, __local1__ in __local2__ { + plus(__local0__, __local1__, __local3__) + __local4__ = data.test.i + gt(__local3__, __local4__) + } + } `, }, { note: "rewrite every: unused key var", @@ -2721,7 +2740,10 @@ func TestRewriteDeclaredVars(t *testing.T) { `, exp: ` package test - p = true { every __local0__, _ in [1] { gte(__local0__, 0) } } + p = true { + __local1__ = [1] + every __local0__, _ in __local1__ { gte(__local0__, 0) } + } `, }, { @@ -2736,7 +2758,7 @@ func TestRewriteDeclaredVars(t *testing.T) { `, exp: ` package test - p = true { every _, _ in [1] { true } } + p = true { __local0__ = [1]; every _, _ in __local0__ { true } } `, }, { @@ -2753,7 +2775,11 @@ func TestRewriteDeclaredVars(t *testing.T) { `, exp: ` package test - p = true { __local0__ = 10; every __local1__ in [1] { equal(__local1__, 1) } } + p = true { + __local0__ = 10 + __local2__ = [1] + every __local1__ in __local2__ { __local1__ == 1 } + } `, }, { @@ -2770,7 +2796,13 @@ func TestRewriteDeclaredVars(t *testing.T) { `, exp: ` package test - p = true { __local0__ = 10; every __local1__ in [1] { equal(__local1__, __local0__) } } + p = true { + __local0__ = 10 + __local2__ = [1] + every __local1__ in __local2__ { + __local1__ == __local0__ + } + } `, }, { @@ -2787,7 +2819,7 @@ func TestRewriteDeclaredVars(t *testing.T) { `, exp: ` package test - p[__local0__] { __local0__ = 10; every _ in [1] { true } } + p[__local0__] { __local0__ = 10; __local1__ = [1]; every _ in __local1__ { true } } `, }, { @@ -2809,10 +2841,11 @@ func TestRewriteDeclaredVars(t *testing.T) { package test p = true { __local0__ = [[1], [2]] - every __local1__ in [1] { - __local3__ = __local0__[__local1__] - every __local2__ in __local3__ { - equal(__local2__, 2) + __local3__ = [1] + every __local1__ in __local3__ { + __local4__ = __local0__[__local1__] + every __local2__ in __local4__ { + __local2__ == 2 } } } @@ -3138,7 +3171,8 @@ func TestCompilerRewriteDynamicTerms(t *testing.T) { {`call_func2 { f(input.foo, "foo") } f(x,y) { x[y] }`, `__local2__ = input.foo; data.test.f(__local2__, "foo")`}, {`every_domain { every _ in str { true } }`, `__local0__ = data.test.str; every _ in __local0__ { true }`}, {`every_domain_call { every _ in numbers.range(1, 10) { true } }`, `numbers.range(1, 10, __local0__); every _ in __local0__ { true }`}, - {`every_body { every _ in [] { [str] } }`, `every _ in [] { __local0__ = data.test.str; [__local0__] }`}, + {`every_body { every _ in [] { [str] } }`, + `__local0__ = []; every _ in __local0__ { __local1__ = data.test.str; [__local1__] }`}, } for _, tc := range tests { @@ -3305,7 +3339,7 @@ func TestCompilerRewritePrintCallsErasure(t *testing.T) { `, exp: `package test - p { every _ in [] { false } }`, + p = true { __local0__ = []; every _ in __local0__ { false } }`, }, { note: "in head", @@ -3440,9 +3474,10 @@ func TestCompilerRewritePrintCalls(t *testing.T) { exp: `package test p = true { - every __local0__ in [1, 2] { - __local2__ = {__local1__ | __local1__ = __local0__} - internal.print([__local2__]) + __local2__ = [1, 2] + every __local0__ in __local2__ { + __local3__ = {__local1__ | __local1__ = __local0__} + internal.print([__local3__]) } }`, }, From 0f1bfe06cd0d58ba422d1fa6f9de37c925b7d3c9 Mon Sep 17 00:00:00 2001 From: Stephan Renatus Date: Fri, 21 Jan 2022 13:40:52 +0100 Subject: [PATCH 15/20] ast/compile: outputVarsForExpr: don't return vars from "Every" body Signed-off-by: Stephan Renatus --- ast/compile.go | 11 +---------- ast/compile_test.go | 2 +- 2 files changed, 2 insertions(+), 11 deletions(-) diff --git a/ast/compile.go b/ast/compile.go index 41adb271a9..1dfafbac75 100644 --- a/ast/compile.go +++ b/ast/compile.go @@ -3048,16 +3048,7 @@ func outputVarsForExpr(expr *Expr, arity func(Ref) int, safe VarSet) VarSet { return outputVarsForExprCall(expr, ar, safe, terms) case *Every: - s := outputVarsForTerms(terms.Domain, safe) - - cpy := safe.Copy() - if terms.Key != nil { - cpy.Add(terms.Key.Value.(Var)) - } - cpy.Add(terms.Value.Value.(Var)) - - s.Update(outputVarsForBody(terms.Body, arity, cpy)) - return s + return outputVarsForTerms(terms.Domain, safe) default: panic("illegal expression") } diff --git a/ast/compile_test.go b/ast/compile_test.go index 4d088b60c3..d529e65fee 100644 --- a/ast/compile_test.go +++ b/ast/compile_test.go @@ -216,7 +216,7 @@ func TestOutputVarsForNode(t *testing.T) { { note: "every: output vars in body", query: `every k, v in [] { k < v; i = 1 }`, - exp: `{i}`, + exp: `set()`, }, } From 9289cb594163c7c0980d7f987ba5793e572f6825 Mon Sep 17 00:00:00 2001 From: Stephan Renatus Date: Fri, 21 Jan 2022 14:22:42 +0100 Subject: [PATCH 16/20] ast/compile: fix safety reordering for every Signed-off-by: Stephan Renatus --- ast/compile.go | 4 ++++ ast/compile_test.go | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/ast/compile.go b/ast/compile.go index 1dfafbac75..f01d08cc72 100644 --- a/ast/compile.go +++ b/ast/compile.go @@ -2962,6 +2962,10 @@ func reorderBodyForClosures(arity func(Ref) int, globals VarSet, body Body) (Bod vs := VarSet{} WalkClosures(e, func(x interface{}) bool { vis := &VarVisitor{vars: vs} + if ev, ok := x.(*Every); ok { + vis.Walk(ev.Body) + return true + } vis.Walk(x) return true }) diff --git a/ast/compile_test.go b/ast/compile_test.go index d529e65fee..23a5f0d719 100644 --- a/ast/compile_test.go +++ b/ast/compile_test.go @@ -693,7 +693,7 @@ func TestCompilerCheckSafetyBodyReordering(t *testing.T) { `}, {"userfunc", `split(y, ".", z); data.a.b.funcs.fn("...foo.bar..", y)`, `data.a.b.funcs.fn("...foo.bar..", y); split(y, ".", z)`}, {"every", `every _ in [] { x != 1 }; x = 1`, `__local3__ = []; x = 1; every _ in __local3__ { x != 1}`}, - {"every-domain", `every _ in xs { true }; xs = [1]`, `xs = [1]; every _ in xs { true }`}, + {"every-domain", `every _ in xs { true }; xs = [1]`, `xs = [1]; __local3__ = xs; every _ in __local3__ { true }`}, } for i, tc := range tests { From decbdb10f55a271cc95516cd104af8e71e4cd538 Mon Sep 17 00:00:00 2001 From: Stephan Renatus Date: Fri, 21 Jan 2022 14:28:43 +0100 Subject: [PATCH 17/20] ast: rename (Every).Vars() -> (Every).KeyValueVars() Signed-off-by: Stephan Renatus --- ast/compile.go | 4 ++-- ast/policy.go | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/ast/compile.go b/ast/compile.go index f01d08cc72..43e8e77444 100644 --- a/ast/compile.go +++ b/ast/compile.go @@ -1464,7 +1464,7 @@ func rewritePrintCalls(gen *localVarGenerator, getArity func(Ref) int, globals V case *ObjectComprehension: errs = rewritePrintCalls(gen, getArity, safe, x.Body) case *Every: - safe.Update(x.Vars()) + safe.Update(x.KeyValueVars()) errs = rewritePrintCalls(gen, getArity, safe, x.Body) } return true @@ -2902,7 +2902,7 @@ func (xform *bodySafetyTransformer) Visit(x interface{}) bool { } case *Expr: if ev, ok := term.Terms.(*Every); ok { - xform.globals.Update(ev.Vars()) + xform.globals.Update(ev.KeyValueVars()) ev.Body = xform.reorderComprehensionSafety(NewVarSet(), ev.Body) return true } diff --git a/ast/policy.go b/ast/policy.go index 412ddf560e..ad0dfd60fa 100644 --- a/ast/policy.go +++ b/ast/policy.go @@ -1491,9 +1491,9 @@ func (q *Every) Compare(other *Every) int { return q.Body.Compare(other.Body) } -// Vars returns the key and val arguments of an every expression, -// if they are non-nil and not wildcards. -func (q *Every) Vars() VarSet { +// KeyValueVars returns the key and val arguments of an `every` +// expression, if they are non-nil and not wildcards. +func (q *Every) KeyValueVars() VarSet { r := NewVarSet() if v := q.Value.Value.(Var); !v.IsWildcard() { r.Add(v) From 5f41d872f4558380fa6efe57ff75a41e3a2d5cd6 Mon Sep 17 00:00:00 2001 From: Stephan Renatus Date: Sat, 29 Jan 2022 17:34:44 +0100 Subject: [PATCH 18/20] ast/compile: use VarVisitor for KeyValueVars() Signed-off-by: Stephan Renatus --- ast/policy.go | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/ast/policy.go b/ast/policy.go index ad0dfd60fa..fab9939a32 100644 --- a/ast/policy.go +++ b/ast/policy.go @@ -1494,16 +1494,12 @@ func (q *Every) Compare(other *Every) int { // KeyValueVars returns the key and val arguments of an `every` // expression, if they are non-nil and not wildcards. func (q *Every) KeyValueVars() VarSet { - r := NewVarSet() - if v := q.Value.Value.(Var); !v.IsWildcard() { - r.Add(v) - } + vis := &VarVisitor{vars: VarSet{}} if q.Key != nil { - if v := q.Key.Value.(Var); !v.IsWildcard() { - r.Add(v) - } + vis.Walk(q.Key) } - return r + vis.Walk(q.Value) + return vis.vars } func (w *With) String() string { From 46478b6e15b12100e64886b4c7a9e3cfa43d8a63 Mon Sep 17 00:00:00 2001 From: Stephan Renatus Date: Sat, 29 Jan 2022 17:36:03 +0100 Subject: [PATCH 19/20] ast/compile: add ContainsClosures Signed-off-by: Stephan Renatus --- ast/compile.go | 2 +- ast/term.go | 14 ++++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/ast/compile.go b/ast/compile.go index 43e8e77444..e468947d5e 100644 --- a/ast/compile.go +++ b/ast/compile.go @@ -1452,7 +1452,7 @@ func rewritePrintCalls(gen *localVarGenerator, getArity func(Ref) int, globals V // Visit comprehension bodies recursively to ensure print statements inside // those bodies only close over variables that are safe. for i := range body { - if ContainsComprehensions(body[i]) || body[i].IsEvery() { + if ContainsClosures(body[i]) { safe := outputVarsForBody(body[:i], getArity, globals) safe.Update(globals) WalkClosures(body[i], func(x interface{}) bool { diff --git a/ast/term.go b/ast/term.go index ce254685b1..f56f16794f 100644 --- a/ast/term.go +++ b/ast/term.go @@ -493,6 +493,20 @@ func ContainsComprehensions(v interface{}) bool { return found } +// ContainsClosures returns true if the Value v contains closures. +func ContainsClosures(v interface{}) bool { + found := false + WalkClosures(v, func(x interface{}) bool { + switch x.(type) { + case *ArrayComprehension, *ObjectComprehension, *SetComprehension, *Every: + found = true + return found + } + return found + }) + return found +} + // IsScalar returns true if the AST value is a scalar. func IsScalar(v Value) bool { switch v.(type) { From 5f5da56262ff86686c1af3441b2de80fb24beeb9 Mon Sep 17 00:00:00 2001 From: Stephan Renatus Date: Sat, 29 Jan 2022 17:41:42 +0100 Subject: [PATCH 20/20] ast/compile_test: add unused assigned var in "every" body case Signed-off-by: Stephan Renatus --- ast/compile_test.go | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/ast/compile_test.go b/ast/compile_test.go index 23a5f0d719..c7a4a1aeea 100644 --- a/ast/compile_test.go +++ b/ast/compile_test.go @@ -4019,13 +4019,23 @@ func TestCompilerCheckUnusedAssignedVar(t *testing.T) { &Error{Message: "assigned var y unused"}, }, }, + { + note: "every: unused assigned var in body", + module: `package test + p { every i in [1] { y := 10; i == 1 } } + `, + expectedErrors: Errors{ + &Error{Message: "assigned var y unused"}, + }, + }, } makeTestRunner := func(tc testCase, strict bool) func(t *testing.T) { return func(t *testing.T) { compiler := NewCompiler().WithStrict(strict) + opts := ParserOptions{AllFutureKeywords: true, unreleasedKeywords: true} compiler.Modules = map[string]*Module{ - "test": MustParseModule(tc.module), + "test": MustParseModuleWithOpts(tc.module, opts), } compileStages(compiler, compiler.rewriteLocalVars) @@ -5509,12 +5519,18 @@ func TestQueryCompilerWithUnusedAssignedVar(t *testing.T) { query: "{1: 2 | x := 2}", expectedErrors: fmt.Errorf("1 error occurred: 1:9: rego_compile_error: assigned var x unused"), }, + { + note: "every: unused var in body", + query: "every _ in [] { x := 10 }", + expectedErrors: fmt.Errorf("1 error occurred: 1:17: rego_compile_error: assigned var x unused"), + }, } makeTestRunner := func(tc testCase, strict bool) func(t *testing.T) { return func(t *testing.T) { c := NewCompiler().WithStrict(strict) - result, err := c.QueryCompiler().Compile(MustParseBody(tc.query)) + opts := ParserOptions{AllFutureKeywords: true, unreleasedKeywords: true} + result, err := c.QueryCompiler().Compile(MustParseBodyWithOpts(tc.query, opts)) if strict { if err == nil {