From d2ca577123f70f62af211addbbf349b8e412542e Mon Sep 17 00:00:00 2001 From: Philip Conrad Date: Fri, 8 Jul 2022 10:52:54 -0400 Subject: [PATCH] ast/term: Make Object key sorting lazy. (#4830) This commit delays the sorting of keys until just-before-use. This is a net win on asymptotics as Objects get larger, even with Quicksort as the sorting algorithm. This commit also adjusts the evaluator to use the new ObjectKeysIterator interface, instead of the raw keys array. Fixes #4625. Signed-off-by: Philip Conrad --- ast/term.go | 96 +++++++++++++++++++++++++++++------------- ast/term_bench_test.go | 58 ++++++++++++++++++++++++- ast/term_test.go | 31 -------------- topdown/eval.go | 10 ++--- 4 files changed, 129 insertions(+), 66 deletions(-) diff --git a/ast/term.go b/ast/term.go index 2f5db614a1..3ec4b47870 100644 --- a/ast/term.go +++ b/ast/term.go @@ -1792,7 +1792,7 @@ type Object interface { MergeWith(other Object, conflictResolver func(v1, v2 *Term) (*Term, bool)) (Object, bool) Filter(filter Object) (Object, error) Keys() []*Term - Elem(i int) (*Term, *Term) + KeysIterator() ObjectKeysIterator get(k *Term) *objectElem // To prevent external implementations } @@ -1815,7 +1815,8 @@ type object struct { keys objectElemSlice ground int // number of key and value grounds. Counting is // required to support insert's key-value replace. - hash int + hash int + numInserts int // number of inserts since last sorting. 
} func newobject(n int) *object { @@ -1824,10 +1825,11 @@ func newobject(n int) *object { keys = make(objectElemSlice, 0, n) } return &object{ - elems: make(map[int]*objectElem, n), - keys: keys, - ground: 0, - hash: 0, + elems: make(map[int]*objectElem, n), + keys: keys, + ground: 0, + hash: 0, + numInserts: 0, } } @@ -1849,6 +1851,14 @@ func Item(key, value *Term) [2]*Term { return [2]*Term{key, value} } +func (obj *object) sortedKeys() objectElemSlice { + if obj.numInserts > 0 { + sort.Sort(obj.keys) + obj.numInserts = 0 + } + return obj.keys +} + // Compare compares obj to other, return <0, 0, or >0 if it is less than, equal to, // or greater than other. func (obj *object) Compare(other Value) int { @@ -1861,29 +1871,32 @@ func (obj *object) Compare(other Value) int { } a := obj b := other.(*object) - minLen := len(a.keys) - if len(b.keys) < len(a.keys) { - minLen = len(b.keys) + // Ensure that keys are in canonical sorted order before use! + akeys := a.sortedKeys() + bkeys := b.sortedKeys() + minLen := len(akeys) + if len(b.keys) < len(akeys) { + minLen = len(bkeys) } for i := 0; i < minLen; i++ { - keysCmp := Compare(a.keys[i].key, b.keys[i].key) + keysCmp := Compare(akeys[i].key, bkeys[i].key) if keysCmp < 0 { return -1 } if keysCmp > 0 { return 1 } - valA := a.keys[i].value - valB := b.keys[i].value + valA := akeys[i].value + valB := bkeys[i].value valCmp := Compare(valA, valB) if valCmp != 0 { return valCmp } } - if len(a.keys) < len(b.keys) { + if len(akeys) < len(bkeys) { return -1 } - if len(b.keys) < len(a.keys) { + if len(bkeys) < len(akeys) { return 1 } return 0 @@ -1959,7 +1972,7 @@ func (obj *object) Intersect(other Object) [][3]*Term { // Iter calls the function f for each key-value pair in the object. If f // returns an error, iteration stops and the error is returned. 
func (obj *object) Iter(f func(*Term, *Term) error) error { - for _, node := range obj.keys { + for _, node := range obj.sortedKeys() { if err := f(node.key, node.value); err != nil { return err } @@ -2011,21 +2024,22 @@ func (obj *object) Map(f func(*Term, *Term) (*Term, *Term, error)) (Object, erro func (obj *object) Keys() []*Term { keys := make([]*Term, len(obj.keys)) - for i, elem := range obj.keys { + for i, elem := range obj.sortedKeys() { keys[i] = elem.key } return keys } -func (obj *object) Elem(i int) (*Term, *Term) { - return obj.keys[i].key, obj.keys[i].value +// Returns an iterator over the obj's keys. +func (obj *object) KeysIterator() ObjectKeysIterator { + return newobjectKeysIterator(obj) } // MarshalJSON returns JSON encoded bytes representing obj. func (obj *object) MarshalJSON() ([]byte, error) { sl := make([][2]*Term, obj.Len()) - for i, node := range obj.keys { + for i, node := range obj.sortedKeys() { sl[i] = Item(node.key, node.value) } return json.Marshal(sl) @@ -2105,7 +2119,7 @@ func (obj object) String() string { var b strings.Builder b.WriteRune('{') - for i, elem := range obj.keys { + for i, elem := range obj.sortedKeys() { if i > 0 { b.WriteString(", ") } @@ -2308,15 +2322,9 @@ func (obj *object) insert(k, v *Term) { next: head, } obj.elems[hash] = elem - i := sort.Search(len(obj.keys), func(i int) bool { return Compare(elem.key, obj.keys[i].key) < 0 }) - if i < len(obj.keys) { - // insert at position `i`: - obj.keys = append(obj.keys, nil) // add some space - copy(obj.keys[i+1:], obj.keys[i:]) // move things over - obj.keys[i] = elem // drop it in position - } else { - obj.keys = append(obj.keys, elem) - } + // O(1) insertion, but we'll have to re-sort the keys later. + obj.keys = append(obj.keys, elem) + obj.numInserts++ // Track insertions since the last re-sorting. 
obj.hash += hash + v.Hash() if k.IsGround() { @@ -2392,6 +2400,36 @@ func filterObject(o Value, filter Value) (Value, error) { } } +// NOTE(philipc): The only way to get an ObjectKeysIterator should be +// from an Object. This ensures that the iterator can have implementation- +// specific details internally, with no contracts except to the very +// limited interface. +type ObjectKeysIterator interface { + Next() (*Term, bool) +} + +type objectKeysIterator struct { + obj *object + numKeys int + index int +} + +func newobjectKeysIterator(o *object) ObjectKeysIterator { + return &objectKeysIterator{ + obj: o, + numKeys: o.Len(), + index: 0, + } +} + +func (oki *objectKeysIterator) Next() (*Term, bool) { + if oki.index == oki.numKeys || oki.numKeys == 0 { + return nil, false + } + oki.index++ + return oki.obj.sortedKeys()[oki.index-1].key, true +} + // ArrayComprehension represents an array comprehension as defined in the language. type ArrayComprehension struct { Term *Term `json:"term"` diff --git a/ast/term_bench_test.go b/ast/term_bench_test.go index 48892898a4..da1aadf22d 100644 --- a/ast/term_bench_test.go +++ b/ast/term_bench_test.go @@ -32,6 +32,25 @@ func BenchmarkObjectLookup(b *testing.B) { } } +func BenchmarkObjectCreationAndLookup(b *testing.B) { + sizes := []int{5, 50, 500, 5000, 50000, 500000} + for _, n := range sizes { + b.Run(fmt.Sprint(n), func(b *testing.B) { + obj := NewObject() + for i := 0; i < n; i++ { + obj.Insert(StringTerm(fmt.Sprint(i)), IntNumberTerm(i)) + } + key := StringTerm(fmt.Sprint(n - 1)) + for i := 0; i < b.N; i++ { + value := obj.Get(key) + if value == nil { + b.Fatal("expected hit") + } + } + }) + } +} + func BenchmarkSetIntersection(b *testing.B) { sizes := []int{5, 50, 500, 5000} for _, n := range sizes { @@ -154,8 +173,45 @@ func BenchmarkObjectString(b *testing.B) { } } -func BenchmarkObjectConstruction(b *testing.B) { +// This benchmark works similarly to BenchmarkObjectString, but with a key +// difference: it benchmarks 
the String and MarshalJSON interface functions +// for the Object, instead of the underlying data structure. This ensures +// that we catch the full performance properties of Object's implementation. +func BenchmarkObjectStringInterfaces(b *testing.B) { + var err error sizes := []int{5, 50, 500, 5000, 50000} + + for _, n := range sizes { + b.Run(fmt.Sprint(n), func(b *testing.B) { + + obj := map[string]int{} + for i := 0; i < n; i++ { + obj[fmt.Sprint(i)] = i + } + valString := MustInterfaceToValue(obj) + valJSON := MustInterfaceToValue(obj) + + b.Run("String()", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + str = valString.String() + } + }) + b.Run("json.Marshal", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + bs, err = json.Marshal(valJSON) + if err != nil { + b.Fatal(err) + } + } + }) + }) + } +} + +func BenchmarkObjectConstruction(b *testing.B) { + sizes := []int{5, 50, 500, 5000, 50000, 500000} seed := time.Now().UnixNano() b.Run("shuffled keys", func(b *testing.B) { diff --git a/ast/term_test.go b/ast/term_test.go index ec069aca09..1986566521 100644 --- a/ast/term_test.go +++ b/ast/term_test.go @@ -263,37 +263,6 @@ func TestObjectFilter(t *testing.T) { } } -func TestObjectInsertKeepsSorting(t *testing.T) { - keysSorted := func(o *object) func(int, int) bool { - return func(i, j int) bool { - return Compare(o.keys[i].key, o.keys[j].key) < 0 - } - } - - obj := NewObject( - [2]*Term{StringTerm("d"), IntNumberTerm(4)}, - [2]*Term{StringTerm("b"), IntNumberTerm(2)}, - [2]*Term{StringTerm("a"), IntNumberTerm(1)}, - ) - o := obj.(*object) - act := sort.SliceIsSorted(o.keys, keysSorted(o)) - if exp := true; act != exp { - t.Errorf("SliceIsSorted: expected %v, 
got %v", exp, act) - for i := range o.keys { - t.Logf("elem[%d]: %v", i, o.keys[i].key) - } - } -} - func TestSetInsertKeepsKeysSorting(t *testing.T) { keysSorted := func(s *set) func(int, int) bool { return func(i, j int) bool { diff --git a/topdown/eval.go b/topdown/eval.go index dc0f484600..2fb87204f5 100644 --- a/topdown/eval.go +++ b/topdown/eval.go @@ -909,20 +909,20 @@ func (e *eval) biunifyObjects(a, b ast.Object, b1, b2 *bindings, iter unifyItera b = plugKeys(b, b2) } - return e.biunifyObjectsRec(a, b, b1, b2, iter, a, 0) + return e.biunifyObjectsRec(a, b, b1, b2, iter, a, a.KeysIterator()) } -func (e *eval) biunifyObjectsRec(a, b ast.Object, b1, b2 *bindings, iter unifyIterator, keys ast.Object, idx int) error { - if idx == keys.Len() { +func (e *eval) biunifyObjectsRec(a, b ast.Object, b1, b2 *bindings, iter unifyIterator, keys ast.Object, oki ast.ObjectKeysIterator) error { + key, more := oki.Next() // Get next key from iterator. + if !more { return iter() } - key, _ := keys.Elem(idx) v2 := b.Get(key) if v2 == nil { return nil } return e.biunify(a.Get(key), v2, b1, b2, func() error { - return e.biunifyObjectsRec(a, b, b1, b2, iter, keys, idx+1) + return e.biunifyObjectsRec(a, b, b1, b2, iter, keys, oki) }) }