Skip to content

Commit

Permalink
rego: Setting query Rego-version from configured imports
Browse files Browse the repository at this point in the history
When `rego.v1` is in the list of imports directly applied on the `rego.Rego` SDK struct, this import, and it's effects, is applied to the query when parsed.
This change affects the `eval` and `bench` commands when the `--imports` flag is used.

Fixes: #6701

Signed-off-by: Johan Fylling <johan.dev@fylling.se>
  • Loading branch information
johanfylling committed Apr 22, 2024
1 parent 3954ba0 commit 091286b
Show file tree
Hide file tree
Showing 5 changed files with 145 additions and 10 deletions.
2 changes: 1 addition & 1 deletion ast/parser.go
Expand Up @@ -2630,7 +2630,7 @@ func (p *Parser) regoV1Import(imp *Import) {
path := imp.Path.Value.(Ref)

if len(path) == 1 || !path[1].Equal(RegoV1CompatibleRef[1]) || len(path) > 2 {
p.errorf(imp.Path.Location, "invalid import, must be `%s`", RegoV1CompatibleRef)
p.errorf(imp.Path.Location, "invalid import `%s`, must be `%s`", path, RegoV1CompatibleRef)
return
}

Expand Down
7 changes: 4 additions & 3 deletions ast/parser_test.go
Expand Up @@ -1337,9 +1337,10 @@ func TestFutureAndRegoV1ImportsExtraction(t *testing.T) {
}

func TestRegoV1Import(t *testing.T) {
assertParseErrorContains(t, "rego", "import rego", "invalid import, must be `rego.v1`")
assertParseErrorContains(t, "rego.foo", "import rego.foo", "invalid import, must be `rego.v1`")
assertParseErrorContains(t, "rego.foo.bar", "import rego.foo.bar", "invalid import, must be `rego.v1`")
assertParseErrorContains(t, "rego", "import rego", "invalid import `rego`, must be `rego.v1`")
assertParseErrorContains(t, "rego.foo", "import rego.foo", "invalid import `rego.foo`, must be `rego.v1`")
assertParseErrorContains(t, "rego.foo.bar", "import rego.foo.bar", "invalid import `rego.foo.bar`, must be `rego.v1`")
assertParseErrorContains(t, "rego.v1.bar", "import rego.v1.bar", "invalid import `rego.v1.bar`, must be `rego.v1`")
assertParseErrorContains(t, "rego.v1 + alias", "import rego.v1 as xyz", "`rego` imports cannot be aliased")

assertParseImport(t, "import rego.v1",
Expand Down
38 changes: 38 additions & 0 deletions cmd/bench_test.go
Expand Up @@ -60,6 +60,44 @@ func TestRunBenchmark(t *testing.T) {
}
}

func TestRunBenchmarkWithQueryImport(t *testing.T) {
params := testBenchParams()
// We add the rego.v1 import ..
params.imports = newrepeatedStringFlag([]string{"rego.v1"})

// .. which provides the 'in' keyword
args := []string{`"a" in ["a", "b", "c"]`}
var buf bytes.Buffer

rc, err := benchMain(args, params, &buf, &goBenchRunner{})
if err != nil {
t.Fatalf("Unexpected error: %s", err)
}

if rc != 0 {
t.Fatalf("Unexpected return code %d, expected 0", rc)
}

// Expect a json serialized benchmark result with histogram fields
var br testing.BenchmarkResult
err = util.UnmarshalJSON(buf.Bytes(), &br)
if err != nil {
t.Fatalf("Unexpected error unmarshalling output: %s", err)
}

if br.N == 0 || br.T == 0 || br.MemAllocs == 0 || br.MemBytes == 0 {
t.Fatalf("Expected benchmark results to be non-zero, got: %+v", br)
}

if _, ok := br.Extra["histogram_timer_rego_query_eval_ns_count"]; !ok {
t.Fatalf("Expected benchmark results to contain histogram_timer_rego_query_eval_ns_count, got: %+v", br)
}

if float64(br.N) != br.Extra["histogram_timer_rego_query_eval_ns_count"] {
t.Fatalf("Expected 'histogram_timer_rego_query_eval_ns_count' to be equal to N")
}
}

func TestRunBenchmarkE2E(t *testing.T) {
params := testBenchParams()
params.e2e = true
Expand Down
80 changes: 80 additions & 0 deletions cmd/eval_test.go
Expand Up @@ -2294,3 +2294,83 @@ p contains 2 if {
}
}
}

func TestWithQueryImports(t *testing.T) {
tests := []struct {
note string
query string
imports []string
exp string
expErrs []string
}{
{
note: "no imports, none required",
query: "1 + 2",
exp: "3\n",
},
{
note: "future keyword used, future.keywords imported",
query: `"b" in ["a", "b", "c"]`,
imports: []string{"future.keywords.in"},
exp: "true\n",
},
{
note: "future keyword used, rego.v1 imported",
query: `"b" in ["a", "b", "c"]`,
imports: []string{"rego.v1"},
exp: "true\n",
},
{
note: "future keyword used, invalid rego.v2 imported",
query: `"b" in ["a", "b", "c"]`,
imports: []string{"rego.v2"},
expErrs: []string{
"1:8: rego_parse_error: invalid import `rego.v2`, must be `rego.v1`",
},
},
{
note: "future keyword used, no imports",
query: `"b" in ["a", "b", "c"]`,
expErrs: []string{
"1:5: rego_unsafe_var_error: var in is unsafe (hint: `import future.keywords.in` to import a future keyword)",
},
},
}

for _, tc := range tests {
t.Run(tc.note, func(t *testing.T) {
params := newEvalCommandParams()
_ = params.outputFormat.Set(evalPrettyOutput)
params.imports = newrepeatedStringFlag(tc.imports)

var buf bytes.Buffer

defined, err := eval([]string{tc.query}, params, &buf)

if len(tc.expErrs) == 0 {
if err != nil {
t.Fatalf("Unexpected error: %v, buf: %s", err, buf.String())
}

if !defined {
t.Fatal("expected result to be defined")
}

if buf.String() != tc.exp {
t.Fatalf("expected:\n\n%s\n\ngot:\n\n%s", tc.exp, buf.String())
}
} else {
if err == nil {
t.Fatal("expected error, got none")
}

actual := buf.String()
for _, expErr := range tc.expErrs {
if !strings.Contains(actual, expErr) {
t.Fatalf("expected error:\n\n%v\n\ngot\n\n%v", expErr, actual)
}
}
}
})
}
}
28 changes: 22 additions & 6 deletions rego/rego.go
Expand Up @@ -1761,14 +1761,15 @@ func (r *Rego) prepare(ctx context.Context, qType queryType, extras []extraStage
return err
}

futureImports := []*ast.Import{}
queryImports := []*ast.Import{}
for _, imp := range imports {
if imp.Path.Value.(ast.Ref).HasPrefix(ast.Ref([]*ast.Term{ast.FutureRootDocument})) {
futureImports = append(futureImports, imp)
path := imp.Path.Value.(ast.Ref)
if path.HasPrefix([]*ast.Term{ast.FutureRootDocument}) || path.HasPrefix([]*ast.Term{ast.RegoRootDocument}) {
queryImports = append(queryImports, imp)
}
}

r.parsedQuery, err = r.parseQuery(futureImports, r.metrics)
r.parsedQuery, err = r.parseQuery(queryImports, r.metrics)
if err != nil {
return err
}
Expand Down Expand Up @@ -1921,22 +1922,37 @@ func (r *Rego) parseRawInput(rawInput *interface{}, m metrics.Metrics) (ast.Valu
return ast.InterfaceToValue(*rawPtr)
}

func (r *Rego) parseQuery(futureImports []*ast.Import, m metrics.Metrics) (ast.Body, error) {
func (r *Rego) parseQuery(queryImports []*ast.Import, m metrics.Metrics) (ast.Body, error) {
if r.parsedQuery != nil {
return r.parsedQuery, nil
}

m.Timer(metrics.RegoQueryParse).Start()
defer m.Timer(metrics.RegoQueryParse).Stop()

popts, err := future.ParserOptionsFromFutureImports(futureImports)
popts, err := future.ParserOptionsFromFutureImports(queryImports)
if err != nil {
return nil, err
}
popts, err = parserOptionsFromRegoVersionImport(queryImports, popts)
if err != nil {
return nil, err
}
popts.SkipRules = true
return ast.ParseBodyWithOpts(r.query, popts)
}

func parserOptionsFromRegoVersionImport(imports []*ast.Import, popts ast.ParserOptions) (ast.ParserOptions, error) {
for _, imp := range imports {
path := imp.Path.Value.(ast.Ref)
if ast.Compare(path, ast.RegoV1CompatibleRef) == 0 {
popts.RegoVersion = ast.RegoV1
return popts, nil
}
}
return popts, nil
}

func (r *Rego) compileModules(ctx context.Context, txn storage.Transaction, m metrics.Metrics) error {

// Only compile again if there are new modules.
Expand Down

0 comments on commit 091286b

Please sign in to comment.