Skip to content

Commit

Permalink
ast+topdown: cache Array groundness, use for plugging
Browse files Browse the repository at this point in the history
I had been believing that we couldn't do this, since many methods
allow changing an array's elements in ways that would invalidate
the cached groundness bit. However, if we don't do that, we're OK.

To get the tests to pass, the walk builtin implementation had
implicitly depended on array plugging returning a copy; that's now
made explicit.

In topdown/bindings, this adds a few more shortcuts where applicable
(array/set). Previously, only objects had a ground?- shortcutin
plugging and namespacing vars.

Fixes open-policy-agent#3679.

Signed-off-by: Stephan Renatus <stephan.renatus@gmail.com>
  • Loading branch information
srenatus committed Dec 17, 2021
1 parent 668708d commit 1842468
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 90 deletions.
33 changes: 22 additions & 11 deletions ast/term.go
Expand Up @@ -1076,26 +1076,30 @@ type QueryIterator func(map[Var]Value, Value) error

// ArrayTerm creates a new Term with an Array value.
func ArrayTerm(a ...*Term) *Term {
return &Term{Value: &Array{a, 0}}
return &Term{Value: &Array{elems: a, hash: 0, ground: termSliceIsGround(a)}}
}

// NewArray creates an Array with the terms provided. The array will
// use the provided term slice.
func NewArray(a ...*Term) *Array {
return &Array{a, 0}
return &Array{elems: a, hash: 0, ground: termSliceIsGround(a)}
}

// Array represents an array as defined by the language. Arrays are similar to the
// same types as defined by JSON with the exception that they can contain Vars
// and References.
type Array struct {
elems []*Term
hash int
elems []*Term
hash int
ground bool
}

// Copy returns a deep copy of arr.
func (arr *Array) Copy() *Array {
return &Array{termSliceCopy(arr.elems), arr.hash}
return &Array{
elems: termSliceCopy(arr.elems),
hash: arr.hash,
ground: arr.IsGround()}
}

// Equal returns true if arr is equal to other.
Expand Down Expand Up @@ -1170,13 +1174,13 @@ func (arr *Array) Hash() int {

// IsGround returns true if all of the Array elements are ground.
func (arr *Array) IsGround() bool {
return termSliceIsGround(arr.elems)
return arr.ground
}

// MarshalJSON returns JSON encoded bytes representing arr.
func (arr *Array) MarshalJSON() ([]byte, error) {
if len(arr.elems) == 0 {
return json.Marshal([]interface{}{})
return []byte(`[]`), nil
}
return json.Marshal(arr.elems)
}
Expand Down Expand Up @@ -1206,6 +1210,7 @@ func (arr *Array) Elem(i int) *Term {

// set sets the element i of arr.
func (arr *Array) set(i int, v *Term) {
arr.ground = arr.ground && v.IsGround()
arr.elems[i] = v
arr.hash = 0
}
Expand All @@ -1215,11 +1220,16 @@ func (arr *Array) set(i int, v *Term) {
// copy and any modifications to either of arrays may be reflected to
// the other.
func (arr *Array) Slice(i, j int) *Array {
var elems []*Term
if j == -1 {
return &Array{elems: arr.elems[i:]}
elems = arr.elems[i:]
} else {
elems = arr.elems[i:j]
}

return &Array{elems: arr.elems[i:j]}
// If arr is ground, the slice is, too.
// If it's not, the slice could still be.
gr := arr.ground || termSliceIsGround(elems)
return &Array{elems: elems, ground: gr}
}

// Iter calls f on each element in arr. If f returns an error,
Expand Down Expand Up @@ -1257,6 +1267,7 @@ func (arr *Array) Append(v *Term) *Array {
cpy := *arr
cpy.elems = append(arr.elems, v)
cpy.hash = 0
cpy.ground = arr.ground && v.IsGround()
return &cpy
}

Expand Down Expand Up @@ -1509,7 +1520,7 @@ func (s *set) Len() int {
// MarshalJSON returns JSON encoded bytes representing s.
func (s *set) MarshalJSON() ([]byte, error) {
if s.keys == nil {
return json.Marshal([]interface{}{})
return []byte(`[]`), nil
}
return json.Marshal(s.keys)
}
Expand Down
77 changes: 0 additions & 77 deletions test/cases/testdata/walkbuiltin/test-walkbuiltin-0971.yaml
Expand Up @@ -5,83 +5,6 @@ cases:
- 2
- 3
- 4
b:
v1: hello
v2: goodbye
c:
- x:
- true
- false
- foo
"y":
- null
- 3.14159
z:
p: true
q: false
d:
e:
- bar
- baz
f:
- xs:
- 1
ys:
- 2
- xs:
- 2
ys:
- 3
g:
a:
- 1
- 0
- 0
- 0
b:
- 0
- 2
- 0
- 0
c:
- 0
- 0
- 0
- 4
h:
- - 1
- 2
- 3
- - 2
- 3
- 4
l:
- a: bob
b: -1
c:
- 1
- 2
- 3
- 4
- a: alice
b: 1
c:
- 2
- 3
- 4
- 5
d: null
m: []
numbers:
- "1"
- "2"
- "3"
- "4"
strings:
bar: 2
baz: 3
foo: 1
three: 3
modules:
- |
package generated
Expand Down
12 changes: 12 additions & 0 deletions topdown/bindings.go
Expand Up @@ -87,6 +87,9 @@ func (u *bindings) plugNamespaced(a *ast.Term, caller *bindings) *ast.Term {
}
return u.namespaceVar(b, caller)
case *ast.Array:
if a.IsGround() {
return a
}
cpy := *a
arr := make([]*ast.Term, v.Len())
for i := 0; i < len(arr); i++ {
Expand All @@ -104,6 +107,9 @@ func (u *bindings) plugNamespaced(a *ast.Term, caller *bindings) *ast.Term {
})
return &cpy
case ast.Set:
if a.IsGround() {
return a
}
cpy := *a
cpy.Value, _ = v.Map(func(x *ast.Term) (*ast.Term, error) {
return u.plugNamespaced(x, caller), nil
Expand Down Expand Up @@ -242,6 +248,9 @@ func (vis namespacingVisitor) namespaceTerm(a *ast.Term) *ast.Term {
case ast.Var:
return vis.b.namespaceVar(a, vis.caller)
case *ast.Array:
if a.IsGround() {
return a
}
cpy := *a
arr := make([]*ast.Term, v.Len())
for i := 0; i < len(arr); i++ {
Expand All @@ -259,6 +268,9 @@ func (vis namespacingVisitor) namespaceTerm(a *ast.Term) *ast.Term {
})
return &cpy
case ast.Set:
if a.IsGround() {
return a
}
cpy := *a
cpy.Value, _ = v.Map(func(x *ast.Term) (*ast.Term, error) {
return vis.namespaceTerm(x), nil
Expand Down
48 changes: 48 additions & 0 deletions topdown/topdown_bench_test.go
Expand Up @@ -29,6 +29,54 @@ func BenchmarkArrayIteration(b *testing.B) {
}
}

func BenchmarkArrayPlugging(b *testing.B) {
ctx := context.Background()

sizes := []int{10, 100, 1000, 10000}

for _, n := range sizes {
b.Run(fmt.Sprint(n), func(b *testing.B) {
data := make([]interface{}, n)
for i := 0; i < n; i++ {
data[i] = fmt.Sprintf("whatever%d", i)
}
store := inmem.NewFromObject(map[string]interface{}{"fixture": data})
module := `package test
fixture := data.fixture
main { x := fixture }`

query := ast.MustParseBody("data.test.main")
compiler := ast.MustCompileModules(map[string]string{
"test.rego": module,
})

b.ResetTimer()

for i := 0; i < b.N; i++ {

err := storage.Txn(ctx, store, storage.TransactionParams{}, func(txn storage.Transaction) error {

q := NewQuery(query).
WithCompiler(compiler).
WithStore(store).
WithTransaction(txn)

_, err := q.Run(ctx)
if err != nil {
return err
}

return nil
})

if err != nil {
b.Fatal(err)
}
}
})
}
}

func BenchmarkSetIteration(b *testing.B) {
sizes := []int{10, 100, 1000, 10000}
for _, n := range sizes {
Expand Down
4 changes: 2 additions & 2 deletions topdown/walk.go
Expand Up @@ -8,7 +8,7 @@ import (
"github.com/open-policy-agent/opa/ast"
)

func evalWalk(bctx BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error {
func evalWalk(_ BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error {
input := args[0]
filter := getOutputPath(args)
return walk(filter, nil, input, iter)
Expand All @@ -21,7 +21,7 @@ func walk(filter, path *ast.Array, input *ast.Term, iter func(*ast.Term) error)
path = ast.NewArray()
}

if err := iter(ast.ArrayTerm(ast.NewTerm(path), input)); err != nil {
if err := iter(ast.ArrayTerm(ast.NewTerm(path.Copy()), input)); err != nil {
return err
}
}
Expand Down

0 comments on commit 1842468

Please sign in to comment.