diff --git a/ast/term.go b/ast/term.go index 2f5db614a1..3ec4b47870 100644 --- a/ast/term.go +++ b/ast/term.go @@ -1792,7 +1792,7 @@ type Object interface { MergeWith(other Object, conflictResolver func(v1, v2 *Term) (*Term, bool)) (Object, bool) Filter(filter Object) (Object, error) Keys() []*Term - Elem(i int) (*Term, *Term) + KeysIterator() ObjectKeysIterator get(k *Term) *objectElem // To prevent external implementations } @@ -1815,7 +1815,8 @@ type object struct { keys objectElemSlice ground int // number of key and value grounds. Counting is // required to support insert's key-value replace. - hash int + hash int + numInserts int // number of inserts since last sorting. } func newobject(n int) *object { @@ -1824,10 +1825,11 @@ func newobject(n int) *object { keys = make(objectElemSlice, 0, n) } return &object{ - elems: make(map[int]*objectElem, n), - keys: keys, - ground: 0, - hash: 0, + elems: make(map[int]*objectElem, n), + keys: keys, + ground: 0, + hash: 0, + numInserts: 0, } } @@ -1849,6 +1851,14 @@ func Item(key, value *Term) [2]*Term { return [2]*Term{key, value} } +func (obj *object) sortedKeys() objectElemSlice { + if obj.numInserts > 0 { + sort.Sort(obj.keys) + obj.numInserts = 0 + } + return obj.keys +} + // Compare compares obj to other, return <0, 0, or >0 if it is less than, equal to, // or greater than other. func (obj *object) Compare(other Value) int { @@ -1861,29 +1871,32 @@ func (obj *object) Compare(other Value) int { } a := obj b := other.(*object) - minLen := len(a.keys) - if len(b.keys) < len(a.keys) { - minLen = len(b.keys) + // Ensure that keys are in canonical sorted order before use! 
+ akeys := a.sortedKeys() + bkeys := b.sortedKeys() + minLen := len(akeys) + if len(b.keys) < len(akeys) { + minLen = len(bkeys) } for i := 0; i < minLen; i++ { - keysCmp := Compare(a.keys[i].key, b.keys[i].key) + keysCmp := Compare(akeys[i].key, bkeys[i].key) if keysCmp < 0 { return -1 } if keysCmp > 0 { return 1 } - valA := a.keys[i].value - valB := b.keys[i].value + valA := akeys[i].value + valB := bkeys[i].value valCmp := Compare(valA, valB) if valCmp != 0 { return valCmp } } - if len(a.keys) < len(b.keys) { + if len(akeys) < len(bkeys) { return -1 } - if len(b.keys) < len(a.keys) { + if len(bkeys) < len(akeys) { return 1 } return 0 @@ -1959,7 +1972,7 @@ func (obj *object) Intersect(other Object) [][3]*Term { // Iter calls the function f for each key-value pair in the object. If f // returns an error, iteration stops and the error is returned. func (obj *object) Iter(f func(*Term, *Term) error) error { - for _, node := range obj.keys { + for _, node := range obj.sortedKeys() { if err := f(node.key, node.value); err != nil { return err } @@ -2011,21 +2024,22 @@ func (obj *object) Map(f func(*Term, *Term) (*Term, *Term, error)) (Object, erro func (obj *object) Keys() []*Term { keys := make([]*Term, len(obj.keys)) - for i, elem := range obj.keys { + for i, elem := range obj.sortedKeys() { keys[i] = elem.key } return keys } -func (obj *object) Elem(i int) (*Term, *Term) { - return obj.keys[i].key, obj.keys[i].value +// Returns an iterator over the obj's keys. +func (obj *object) KeysIterator() ObjectKeysIterator { + return newobjectKeysIterator(obj) } // MarshalJSON returns JSON encoded bytes representing obj. 
func (obj *object) MarshalJSON() ([]byte, error) { sl := make([][2]*Term, obj.Len()) - for i, node := range obj.keys { + for i, node := range obj.sortedKeys() { sl[i] = Item(node.key, node.value) } return json.Marshal(sl) @@ -2105,7 +2119,7 @@ func (obj object) String() string { var b strings.Builder b.WriteRune('{') - for i, elem := range obj.keys { + for i, elem := range obj.sortedKeys() { if i > 0 { b.WriteString(", ") } @@ -2308,15 +2322,9 @@ func (obj *object) insert(k, v *Term) { next: head, } obj.elems[hash] = elem - i := sort.Search(len(obj.keys), func(i int) bool { return Compare(elem.key, obj.keys[i].key) < 0 }) - if i < len(obj.keys) { - // insert at position `i`: - obj.keys = append(obj.keys, nil) // add some space - copy(obj.keys[i+1:], obj.keys[i:]) // move things over - obj.keys[i] = elem // drop it in position - } else { - obj.keys = append(obj.keys, elem) - } + // O(1) insertion, but we'll have to re-sort the keys later. + obj.keys = append(obj.keys, elem) + obj.numInserts++ // Track insertions since the last re-sorting. obj.hash += hash + v.Hash() if k.IsGround() { @@ -2392,6 +2400,36 @@ func filterObject(o Value, filter Value) (Value, error) { } } +// NOTE(philipc): The only way to get an ObjectKeysIterator should be +// from an Object. This ensures that the iterator can have implementation- +// specific details internally, with no contracts except to the very +// limited interface. +type ObjectKeysIterator interface { + Next() (*Term, bool) +} + +type objectKeysIterator struct { + obj *object + numKeys int + index int +} + +func newobjectKeysIterator(o *object) ObjectKeysIterator { + return &objectKeysIterator{ + obj: o, + numKeys: o.Len(), + index: 0, + } +} + +func (oki *objectKeysIterator) Next() (*Term, bool) { + if oki.index == oki.numKeys || oki.numKeys == 0 { + return nil, false + } + oki.index++ + return oki.obj.sortedKeys()[oki.index-1].key, true +} + // ArrayComprehension represents an array comprehension as defined in the language.
type ArrayComprehension struct { Term *Term `json:"term"` diff --git a/ast/term_bench_test.go b/ast/term_bench_test.go index 48892898a4..da1aadf22d 100644 --- a/ast/term_bench_test.go +++ b/ast/term_bench_test.go @@ -32,6 +32,25 @@ func BenchmarkObjectLookup(b *testing.B) { } } +func BenchmarkObjectCreationAndLookup(b *testing.B) { + sizes := []int{5, 50, 500, 5000, 50000, 500000} + for _, n := range sizes { + b.Run(fmt.Sprint(n), func(b *testing.B) { + obj := NewObject() + for i := 0; i < n; i++ { + obj.Insert(StringTerm(fmt.Sprint(i)), IntNumberTerm(i)) + } + key := StringTerm(fmt.Sprint(n - 1)) + for i := 0; i < b.N; i++ { + value := obj.Get(key) + if value == nil { + b.Fatal("expected hit") + } + } + }) + } +} + func BenchmarkSetIntersection(b *testing.B) { sizes := []int{5, 50, 500, 5000} for _, n := range sizes { @@ -154,8 +173,45 @@ func BenchmarkObjectString(b *testing.B) { } } -func BenchmarkObjectConstruction(b *testing.B) { +// This benchmark works similarly to BenchmarkObjectString, but with a key +// difference: it benchmarks the String and MarshalJSON interface functions +// for the Object, instead of the underlying data structure. This ensures +// that we catch the full performance properties of Object's implementation.
+func BenchmarkObjectStringInterfaces(b *testing.B) { + var err error sizes := []int{5, 50, 500, 5000, 50000} + + for _, n := range sizes { + b.Run(fmt.Sprint(n), func(b *testing.B) { + + obj := map[string]int{} + for i := 0; i < n; i++ { + obj[fmt.Sprint(i)] = i + } + valString := MustInterfaceToValue(obj) + valJSON := MustInterfaceToValue(obj) + + b.Run("String()", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + str = valString.String() + } + }) + b.Run("json.Marshal", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + bs, err = json.Marshal(valJSON) + if err != nil { + b.Fatal(err) + } + } + }) + }) + } +} + +func BenchmarkObjectConstruction(b *testing.B) { + sizes := []int{5, 50, 500, 5000, 50000, 500000} seed := time.Now().UnixNano() b.Run("shuffled keys", func(b *testing.B) { diff --git a/ast/term_test.go b/ast/term_test.go index ec069aca09..1986566521 100644 --- a/ast/term_test.go +++ b/ast/term_test.go @@ -263,37 +263,6 @@ func TestObjectFilter(t *testing.T) { } } -func TestObjectInsertKeepsSorting(t *testing.T) { - keysSorted := func(o *object) func(int, int) bool { - return func(i, j int) bool { - return Compare(o.keys[i].key, o.keys[j].key) < 0 - } - } - - obj := NewObject( - [2]*Term{StringTerm("d"), IntNumberTerm(4)}, - [2]*Term{StringTerm("b"), IntNumberTerm(2)}, - [2]*Term{StringTerm("a"), IntNumberTerm(1)}, - ) - o := obj.(*object) - act := sort.SliceIsSorted(o.keys, keysSorted(o)) - if exp := true; act != exp { - t.Errorf("SliceIsSorted: expected %v, got %v", exp, act) - for i := range o.keys { - t.Logf("elem[%d]: %v", i, o.keys[i].key) - } - } - - obj.Insert(StringTerm("c"), IntNumberTerm(3)) - act = sort.SliceIsSorted(o.keys, keysSorted(o)) - if exp := true; act != exp { - t.Errorf("SliceIsSorted: expected %v, got %v", exp, act) - for i := range o.keys { - t.Logf("elem[%d]: %v", i, o.keys[i].key) - } - } -} - func TestSetInsertKeepsKeysSorting(t *testing.T) { keysSorted := func(s *set) func(int, int) bool { 
return func(i, j int) bool { diff --git a/topdown/eval.go b/topdown/eval.go index dc0f484600..2fb87204f5 100644 --- a/topdown/eval.go +++ b/topdown/eval.go @@ -909,20 +909,20 @@ func (e *eval) biunifyObjects(a, b ast.Object, b1, b2 *bindings, iter unifyItera b = plugKeys(b, b2) } - return e.biunifyObjectsRec(a, b, b1, b2, iter, a, 0) + return e.biunifyObjectsRec(a, b, b1, b2, iter, a, a.KeysIterator()) } -func (e *eval) biunifyObjectsRec(a, b ast.Object, b1, b2 *bindings, iter unifyIterator, keys ast.Object, idx int) error { - if idx == keys.Len() { +func (e *eval) biunifyObjectsRec(a, b ast.Object, b1, b2 *bindings, iter unifyIterator, keys ast.Object, oki ast.ObjectKeysIterator) error { + key, more := oki.Next() // Get next key from iterator. + if !more { return iter() } - key, _ := keys.Elem(idx) v2 := b.Get(key) if v2 == nil { return nil } return e.biunify(a.Get(key), v2, b1, b2, func() error { - return e.biunifyObjectsRec(a, b, b1, b2, iter, keys, idx+1) + return e.biunifyObjectsRec(a, b, b1, b2, iter, keys, oki) }) }