Skip to content

Commit

Permalink
ast/term: Make Object key sorting lazy.
Browse files Browse the repository at this point in the history
This commit delays the sorting of keys until just-before-use. This
is a net win on asymptotics as Objects get larger, even with Quicksort
as the sorting algorithm.

This commit also adjusts the evaluator to use the new ObjectKeysIterator
interface, instead of the raw keys array.

Fixes open-policy-agent#4625.

Signed-off-by: Philip Conrad <philipaconrad@gmail.com>
  • Loading branch information
philipaconrad committed Jul 8, 2022
1 parent 3822ce1 commit 092150b
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 66 deletions.
96 changes: 67 additions & 29 deletions ast/term.go
Expand Up @@ -1792,7 +1792,7 @@ type Object interface {
MergeWith(other Object, conflictResolver func(v1, v2 *Term) (*Term, bool)) (Object, bool)
Filter(filter Object) (Object, error)
Keys() []*Term
Elem(i int) (*Term, *Term)
KeysIterator() ObjectKeysIterator
get(k *Term) *objectElem // To prevent external implementations
}

Expand All @@ -1815,7 +1815,8 @@ type object struct {
keys objectElemSlice
ground int // number of key and value grounds. Counting is
// required to support insert's key-value replace.
hash int
hash int
numInserts int // number of inserts since last sorting.
}

func newobject(n int) *object {
Expand All @@ -1824,10 +1825,11 @@ func newobject(n int) *object {
keys = make(objectElemSlice, 0, n)
}
return &object{
elems: make(map[int]*objectElem, n),
keys: keys,
ground: 0,
hash: 0,
elems: make(map[int]*objectElem, n),
keys: keys,
ground: 0,
hash: 0,
numInserts: 0,
}
}

Expand All @@ -1849,6 +1851,14 @@ func Item(key, value *Term) [2]*Term {
return [2]*Term{key, value}
}

func (obj *object) sortedKeys() objectElemSlice {
if obj.numInserts > 0 {
sort.Sort(obj.keys)
obj.numInserts = 0
}
return obj.keys
}

// Compare compares obj to other, return <0, 0, or >0 if it is less than, equal to,
// or greater than other.
func (obj *object) Compare(other Value) int {
Expand All @@ -1861,29 +1871,32 @@ func (obj *object) Compare(other Value) int {
}
a := obj
b := other.(*object)
minLen := len(a.keys)
if len(b.keys) < len(a.keys) {
minLen = len(b.keys)
// Ensure that keys are in canonical sorted order before use!
akeys := a.sortedKeys()
bkeys := b.sortedKeys()
minLen := len(akeys)
if len(b.keys) < len(akeys) {
minLen = len(bkeys)
}
for i := 0; i < minLen; i++ {
keysCmp := Compare(a.keys[i].key, b.keys[i].key)
keysCmp := Compare(akeys[i].key, bkeys[i].key)
if keysCmp < 0 {
return -1
}
if keysCmp > 0 {
return 1
}
valA := a.keys[i].value
valB := b.keys[i].value
valA := akeys[i].value
valB := bkeys[i].value
valCmp := Compare(valA, valB)
if valCmp != 0 {
return valCmp
}
}
if len(a.keys) < len(b.keys) {
if len(akeys) < len(bkeys) {
return -1
}
if len(b.keys) < len(a.keys) {
if len(bkeys) < len(akeys) {
return 1
}
return 0
Expand Down Expand Up @@ -1959,7 +1972,7 @@ func (obj *object) Intersect(other Object) [][3]*Term {
// Iter calls the function f for each key-value pair in the object. If f
// returns an error, iteration stops and the error is returned.
func (obj *object) Iter(f func(*Term, *Term) error) error {
for _, node := range obj.keys {
for _, node := range obj.sortedKeys() {
if err := f(node.key, node.value); err != nil {
return err
}
Expand Down Expand Up @@ -2011,21 +2024,22 @@ func (obj *object) Map(f func(*Term, *Term) (*Term, *Term, error)) (Object, erro
func (obj *object) Keys() []*Term {
keys := make([]*Term, len(obj.keys))

for i, elem := range obj.keys {
for i, elem := range obj.sortedKeys() {
keys[i] = elem.key
}

return keys
}

func (obj *object) Elem(i int) (*Term, *Term) {
return obj.keys[i].key, obj.keys[i].value
// Returns an iterator over the obj's keys.
func (obj *object) KeysIterator() ObjectKeysIterator {
return newobjectKeysIterator(obj)
}

// MarshalJSON returns JSON encoded bytes representing obj.
func (obj *object) MarshalJSON() ([]byte, error) {
sl := make([][2]*Term, obj.Len())
for i, node := range obj.keys {
for i, node := range obj.sortedKeys() {
sl[i] = Item(node.key, node.value)
}
return json.Marshal(sl)
Expand Down Expand Up @@ -2105,7 +2119,7 @@ func (obj object) String() string {
var b strings.Builder
b.WriteRune('{')

for i, elem := range obj.keys {
for i, elem := range obj.sortedKeys() {
if i > 0 {
b.WriteString(", ")
}
Expand Down Expand Up @@ -2308,15 +2322,9 @@ func (obj *object) insert(k, v *Term) {
next: head,
}
obj.elems[hash] = elem
i := sort.Search(len(obj.keys), func(i int) bool { return Compare(elem.key, obj.keys[i].key) < 0 })
if i < len(obj.keys) {
// insert at position `i`:
obj.keys = append(obj.keys, nil) // add some space
copy(obj.keys[i+1:], obj.keys[i:]) // move things over
obj.keys[i] = elem // drop it in position
} else {
obj.keys = append(obj.keys, elem)
}
// O(1) insertion, but we'll have to re-sort the keys later.
obj.keys = append(obj.keys, elem)
obj.numInserts++ // Track insertions since the last re-sorting.
obj.hash += hash + v.Hash()

if k.IsGround() {
Expand Down Expand Up @@ -2392,6 +2400,36 @@ func filterObject(o Value, filter Value) (Value, error) {
}
}

// NOTE(philipc): The only way to get an ObjectKeyIterator should be
// from an Object. This ensures that the iterator can have implementation-
// specific details internally, with no contracts except to the very
// limited interface.
type ObjectKeysIterator interface {
Next() (*Term, bool)
}

type objectKeysIterator struct {
obj *object
numKeys int
index int
}

func newobjectKeysIterator(o *object) ObjectKeysIterator {
return &objectKeysIterator{
obj: o,
numKeys: o.Len(),
index: 0,
}
}

func (oki *objectKeysIterator) Next() (*Term, bool) {
if oki.index == oki.numKeys || oki.numKeys == 0 {
return nil, false
}
oki.index++
return oki.obj.sortedKeys()[oki.index-1].key, true
}

// ArrayComprehension represents an array comprehension as defined in the language.
type ArrayComprehension struct {
Term *Term `json:"term"`
Expand Down
58 changes: 57 additions & 1 deletion ast/term_bench_test.go
Expand Up @@ -32,6 +32,25 @@ func BenchmarkObjectLookup(b *testing.B) {
}
}

func BenchmarkObjectCreationAndLookup(b *testing.B) {
sizes := []int{5, 50, 500, 5000, 50000, 500000}
for _, n := range sizes {
b.Run(fmt.Sprint(n), func(b *testing.B) {
obj := NewObject()
for i := 0; i < n; i++ {
obj.Insert(StringTerm(fmt.Sprint(i)), IntNumberTerm(i))
}
key := StringTerm(fmt.Sprint(n - 1))
for i := 0; i < b.N; i++ {
value := obj.Get(key)
if value == nil {
b.Fatal("expected hit")
}
}
})
}
}

func BenchmarkSetIntersection(b *testing.B) {
sizes := []int{5, 50, 500, 5000}
for _, n := range sizes {
Expand Down Expand Up @@ -154,8 +173,45 @@ func BenchmarkObjectString(b *testing.B) {
}
}

func BenchmarkObjectConstruction(b *testing.B) {
// This benchmark works similarly to BenchmarkObjectString, but with a key
// difference: it benchmarks the String and MarshalJSON interface functions
// for the Objec, instead of the underlying data structure. This ensures
// that we catch the full performance properties of Object's implementation.
func BenchmarkObjectStringInterfaces(b *testing.B) {
var err error
sizes := []int{5, 50, 500, 5000, 50000}

for _, n := range sizes {
b.Run(fmt.Sprint(n), func(b *testing.B) {

obj := map[string]int{}
for i := 0; i < n; i++ {
obj[fmt.Sprint(i)] = i
}
valString := MustInterfaceToValue(obj)
valJSON := MustInterfaceToValue(obj)

b.Run("String()", func(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
str = valString.String()
}
})
b.Run("json.Marshal", func(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
bs, err = json.Marshal(valJSON)
if err != nil {
b.Fatal(err)
}
}
})
})
}
}

func BenchmarkObjectConstruction(b *testing.B) {
sizes := []int{5, 50, 500, 5000, 50000, 500000}
seed := time.Now().UnixNano()

b.Run("shuffled keys", func(b *testing.B) {
Expand Down
31 changes: 0 additions & 31 deletions ast/term_test.go
Expand Up @@ -263,37 +263,6 @@ func TestObjectFilter(t *testing.T) {
}
}

func TestObjectInsertKeepsSorting(t *testing.T) {
keysSorted := func(o *object) func(int, int) bool {
return func(i, j int) bool {
return Compare(o.keys[i].key, o.keys[j].key) < 0
}
}

obj := NewObject(
[2]*Term{StringTerm("d"), IntNumberTerm(4)},
[2]*Term{StringTerm("b"), IntNumberTerm(2)},
[2]*Term{StringTerm("a"), IntNumberTerm(1)},
)
o := obj.(*object)
act := sort.SliceIsSorted(o.keys, keysSorted(o))
if exp := true; act != exp {
t.Errorf("SliceIsSorted: expected %v, got %v", exp, act)
for i := range o.keys {
t.Logf("elem[%d]: %v", i, o.keys[i].key)
}
}

obj.Insert(StringTerm("c"), IntNumberTerm(3))
act = sort.SliceIsSorted(o.keys, keysSorted(o))
if exp := true; act != exp {
t.Errorf("SliceIsSorted: expected %v, got %v", exp, act)
for i := range o.keys {
t.Logf("elem[%d]: %v", i, o.keys[i].key)
}
}
}

func TestSetInsertKeepsKeysSorting(t *testing.T) {
keysSorted := func(s *set) func(int, int) bool {
return func(i, j int) bool {
Expand Down
10 changes: 5 additions & 5 deletions topdown/eval.go
Expand Up @@ -909,20 +909,20 @@ func (e *eval) biunifyObjects(a, b ast.Object, b1, b2 *bindings, iter unifyItera
b = plugKeys(b, b2)
}

return e.biunifyObjectsRec(a, b, b1, b2, iter, a, 0)
return e.biunifyObjectsRec(a, b, b1, b2, iter, a, a.KeysIterator())
}

func (e *eval) biunifyObjectsRec(a, b ast.Object, b1, b2 *bindings, iter unifyIterator, keys ast.Object, idx int) error {
if idx == keys.Len() {
func (e *eval) biunifyObjectsRec(a, b ast.Object, b1, b2 *bindings, iter unifyIterator, keys ast.Object, oki ast.ObjectKeysIterator) error {
key, more := oki.Next() // Get next key from iterator.
if !more {
return iter()
}
key, _ := keys.Elem(idx)
v2 := b.Get(key)
if v2 == nil {
return nil
}
return e.biunify(a.Get(key), v2, b1, b2, func() error {
return e.biunifyObjectsRec(a, b, b1, b2, iter, keys, idx+1)
return e.biunifyObjectsRec(a, b, b1, b2, iter, keys, oki)
})
}

Expand Down

0 comments on commit 092150b

Please sign in to comment.