diff --git a/topdown/sets.go b/topdown/sets.go index a9c5ad86c1..cb0b1d7d37 100644 --- a/topdown/sets.go +++ b/topdown/sets.go @@ -58,22 +58,26 @@ func builtinSetIntersection(a ast.Value) (ast.Value, error) { // builtinSetUnion returns the union of the given input sets func builtinSetUnion(a ast.Value) (ast.Value, error) { + // The set union logic here is duplicated and manually inlined on + // purpose. By lifting this logic up a level, and not doing pairwise + // set unions, we avoid a number of heap allocations. This improves + // performance dramatically over the naive approach. + result := ast.NewSet() inputSet, err := builtins.SetOperand(a, 1) if err != nil { return nil, err } - result := ast.NewSet() - err = inputSet.Iter(func(x *ast.Term) error { - n, err := builtins.SetOperand(x.Value, 1) + item, err := builtins.SetOperand(x.Value, 1) if err != nil { return err } - result = result.Union(n) + item.Foreach(result.Add) return nil }) + return result, err } diff --git a/topdown/sets_bench_test.go b/topdown/sets_bench_test.go new file mode 100644 index 0000000000..231106695c --- /dev/null +++ b/topdown/sets_bench_test.go @@ -0,0 +1,132 @@ +// Copyright 2022 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package topdown + +import ( + "context" + "fmt" + "testing" + + "github.com/open-policy-agent/opa/ast" + "github.com/open-policy-agent/opa/storage" + "github.com/open-policy-agent/opa/storage/inmem" +) + +func BenchmarkSetUnion(b *testing.B) { + ctx := context.Background() + + sizes := []int{10, 100, 250} + + for _, n := range sizes { + for _, m := range sizes { + b.Run(fmt.Sprintf("%dx%d", n, m), func(b *testing.B) { + store := inmem.NewFromObject(map[string]interface{}{"nsets": n, "nsize": m}) + + // Code is lifted from here: + // https://github.com/open-policy-agent/opa/issues/4979#issue-1332019382 + + module := `package test + + nums := numbers.range(0, data.nsets) + sizes := numbers.range(0, data.nsize) + + sets[n] = x { + nums[n] + x := {sprintf("%d,%d", [n, i]) | sizes[i]} + } + combined := union({s | s := sets[_]})` + + query := ast.MustParseBody("data.test.combined") + compiler := ast.MustCompileModules(map[string]string{ + "test.rego": module, + }) + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + + err := storage.Txn(ctx, store, storage.TransactionParams{}, func(txn storage.Transaction) error { + + q := NewQuery(query). + WithCompiler(compiler). + WithStore(store). + WithTransaction(txn) + + _, err := q.Run(ctx) + if err != nil { + return err + } + + return nil + }) + + if err != nil { + b.Fatal(err) + } + } + }) + } + } +} + +func BenchmarkSetUnionSlow(b *testing.B) { + // This benchmarks the suggested means to implement union + // without using the builtin, to give us an idea of whether or not + // the builtin is actually making things any faster. + ctx := context.Background() + + sizes := []int{10, 100, 250} + + for _, n := range sizes { + for _, m := range sizes { + b.Run(fmt.Sprintf("%dx%d", n, m), func(b *testing.B) { + store := inmem.NewFromObject(map[string]interface{}{"nsets": n, "nsize": m}) + + // Code is lifted from here: + // https://github.com/open-policy-agent/opa/issues/4979#issue-1332019382 + + module := `package test + + nums := numbers.range(0, data.nsets) + sizes := numbers.range(0, data.nsize) + + sets[n] = x { + nums[n] + x := {sprintf("%d,%d", [n, i]) | sizes[i]} + } + combined := {t | s := sets[_]; s[t]}` + + query := ast.MustParseBody("data.test.combined") + compiler := ast.MustCompileModules(map[string]string{ + "test.rego": module, + }) + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + + err := storage.Txn(ctx, store, storage.TransactionParams{}, func(txn storage.Transaction) error { + + q := NewQuery(query). + WithCompiler(compiler). + WithStore(store). + WithTransaction(txn) + + _, err := q.Run(ctx) + if err != nil { + return err + } + + return nil + }) + + if err != nil { + b.Fatal(err) + } + } + }) + } + } +} diff --git a/topdown/sets_test.go b/topdown/sets_test.go new file mode 100644 index 0000000000..2b93302661 --- /dev/null +++ b/topdown/sets_test.go @@ -0,0 +1,75 @@ +// Copyright 2022 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package topdown + +import ( + "testing" + + "github.com/open-policy-agent/opa/ast" +) + +func TestSetUnionBuiltin(t *testing.T) { + tests := []struct { + note string + query string + input string + expected string + }{ + // NOTE(philipc): These tests assume that erroneous types are + // checked elsewhere, and focus only on functional correctness. + { + note: "Empty", + input: `{set()}`, + expected: `set()`, + }, + { + note: "Singletons", + input: `{{1}, {2}, {3}, {4}, {5}}`, + expected: `{1, 2, 3, 4, 5}`, + }, + { + note: "One set", + input: `{{1, 2, 3, 4, 5}}`, + expected: `{1, 2, 3, 4, 5}`, + }, + { + note: "One set + empty", + input: `{{1, 2, 3, 4, 5}, set()}`, + expected: `{1, 2, 3, 4, 5}`, + }, + { + note: "Multiple sets, with duplicates", + input: `{{1, 2, 3}, {1, 2}, {3}, {4, 5}}`, + expected: `{1, 2, 3, 4, 5}`, + }, + } + + for _, tc := range tests { + inputs := ast.MustParseTerm(tc.input) + result, err := getResult(functionalWrapper1("union", builtinSetUnion), inputs) + if err != nil { + t.Fatal(err) + } + + expected := ast.MustParseTerm(tc.expected) + if !result.Equal(expected) { + t.Fatalf("Expected %v but got %v", expected, result) + } + } +} + +// Used to get older-style (ast.Term, error) tuples out of newer functions. +func getResult(fn BuiltinFunc, operands ...*ast.Term) (*ast.Term, error) { + var result *ast.Term + extractionFn := func(r *ast.Term) error { + result = r + return nil + } + err := fn(BuiltinContext{}, operands, extractionFn) + if err != nil { + return nil, err + } + return result, nil +}