Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Object Insertion Rework #4830

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
96 changes: 67 additions & 29 deletions ast/term.go
Original file line number Diff line number Diff line change
Expand Up @@ -1792,7 +1792,7 @@ type Object interface {
MergeWith(other Object, conflictResolver func(v1, v2 *Term) (*Term, bool)) (Object, bool)
Filter(filter Object) (Object, error)
Keys() []*Term
Elem(i int) (*Term, *Term)
KeysIterator() ObjectKeysIterator
get(k *Term) *objectElem // To prevent external implementations
}

Expand All @@ -1815,7 +1815,8 @@ type object struct {
keys objectElemSlice
ground int // number of key and value grounds. Counting is
// required to support insert's key-value replace.
hash int
hash int
numInserts int // number of inserts since last sorting.
}

func newobject(n int) *object {
Expand All @@ -1824,10 +1825,11 @@ func newobject(n int) *object {
keys = make(objectElemSlice, 0, n)
}
return &object{
elems: make(map[int]*objectElem, n),
keys: keys,
ground: 0,
hash: 0,
elems: make(map[int]*objectElem, n),
keys: keys,
ground: 0,
hash: 0,
numInserts: 0,
}
}

Expand All @@ -1849,6 +1851,14 @@ func Item(key, value *Term) [2]*Term {
return [2]*Term{key, value}
}

func (obj *object) sortedKeys() objectElemSlice {
if obj.numInserts > 0 {
sort.Sort(obj.keys)
obj.numInserts = 0
}
return obj.keys
}

// Compare compares obj to other, return <0, 0, or >0 if it is less than, equal to,
// or greater than other.
func (obj *object) Compare(other Value) int {
Expand All @@ -1861,29 +1871,32 @@ func (obj *object) Compare(other Value) int {
}
a := obj
b := other.(*object)
minLen := len(a.keys)
if len(b.keys) < len(a.keys) {
minLen = len(b.keys)
// Ensure that keys are in canonical sorted order before use!
akeys := a.sortedKeys()
bkeys := b.sortedKeys()
minLen := len(akeys)
if len(b.keys) < len(akeys) {
minLen = len(bkeys)
}
for i := 0; i < minLen; i++ {
keysCmp := Compare(a.keys[i].key, b.keys[i].key)
keysCmp := Compare(akeys[i].key, bkeys[i].key)
if keysCmp < 0 {
return -1
}
if keysCmp > 0 {
return 1
}
valA := a.keys[i].value
valB := b.keys[i].value
valA := akeys[i].value
valB := bkeys[i].value
valCmp := Compare(valA, valB)
if valCmp != 0 {
return valCmp
}
}
if len(a.keys) < len(b.keys) {
if len(akeys) < len(bkeys) {
return -1
}
if len(b.keys) < len(a.keys) {
if len(bkeys) < len(akeys) {
return 1
}
return 0
Expand Down Expand Up @@ -1959,7 +1972,7 @@ func (obj *object) Intersect(other Object) [][3]*Term {
// Iter calls the function f for each key-value pair in the object. If f
// returns an error, iteration stops and the error is returned.
func (obj *object) Iter(f func(*Term, *Term) error) error {
for _, node := range obj.keys {
for _, node := range obj.sortedKeys() {
if err := f(node.key, node.value); err != nil {
return err
}
Expand Down Expand Up @@ -2011,21 +2024,22 @@ func (obj *object) Map(f func(*Term, *Term) (*Term, *Term, error)) (Object, erro
func (obj *object) Keys() []*Term {
keys := make([]*Term, len(obj.keys))

for i, elem := range obj.keys {
for i, elem := range obj.sortedKeys() {
keys[i] = elem.key
}

return keys
}

func (obj *object) Elem(i int) (*Term, *Term) {
return obj.keys[i].key, obj.keys[i].value
// Returns an iterator over the obj's keys.
func (obj *object) KeysIterator() ObjectKeysIterator {
return newobjectKeysIterator(obj)
}

// MarshalJSON returns JSON encoded bytes representing obj.
func (obj *object) MarshalJSON() ([]byte, error) {
sl := make([][2]*Term, obj.Len())
for i, node := range obj.keys {
for i, node := range obj.sortedKeys() {
sl[i] = Item(node.key, node.value)
}
return json.Marshal(sl)
Expand Down Expand Up @@ -2105,7 +2119,7 @@ func (obj object) String() string {
var b strings.Builder
b.WriteRune('{')

for i, elem := range obj.keys {
for i, elem := range obj.sortedKeys() {
if i > 0 {
b.WriteString(", ")
}
Expand Down Expand Up @@ -2308,15 +2322,9 @@ func (obj *object) insert(k, v *Term) {
next: head,
}
obj.elems[hash] = elem
i := sort.Search(len(obj.keys), func(i int) bool { return Compare(elem.key, obj.keys[i].key) < 0 })
if i < len(obj.keys) {
// insert at position `i`:
obj.keys = append(obj.keys, nil) // add some space
copy(obj.keys[i+1:], obj.keys[i:]) // move things over
obj.keys[i] = elem // drop it in position
} else {
obj.keys = append(obj.keys, elem)
}
// O(1) insertion, but we'll have to re-sort the keys later.
obj.keys = append(obj.keys, elem)
obj.numInserts++ // Track insertions since the last re-sorting.
obj.hash += hash + v.Hash()

if k.IsGround() {
Expand Down Expand Up @@ -2392,6 +2400,36 @@ func filterObject(o Value, filter Value) (Value, error) {
}
}

// NOTE(philipc): The only way to get an ObjectKeyIterator should be
// from an Object. This ensures that the iterator can have implementation-
// specific details internally, with no contracts except to the very
// limited interface.
type ObjectKeysIterator interface {
philipaconrad marked this conversation as resolved.
Show resolved Hide resolved
srenatus marked this conversation as resolved.
Show resolved Hide resolved
Next() (*Term, bool)
}

type objectKeysIterator struct {
obj *object
numKeys int
index int
}

func newobjectKeysIterator(o *object) ObjectKeysIterator {
return &objectKeysIterator{
obj: o,
numKeys: o.Len(),
index: 0,
}
}

func (oki *objectKeysIterator) Next() (*Term, bool) {
if oki.index == oki.numKeys || oki.numKeys == 0 {
return nil, false
}
oki.index++
return oki.obj.sortedKeys()[oki.index-1].key, true
}

// ArrayComprehension represents an array comprehension as defined in the language.
type ArrayComprehension struct {
Term *Term `json:"term"`
Expand Down
58 changes: 57 additions & 1 deletion ast/term_bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,25 @@ func BenchmarkObjectLookup(b *testing.B) {
}
}

func BenchmarkObjectCreationAndLookup(b *testing.B) {
sizes := []int{5, 50, 500, 5000, 50000, 500000}
for _, n := range sizes {
b.Run(fmt.Sprint(n), func(b *testing.B) {
obj := NewObject()
for i := 0; i < n; i++ {
obj.Insert(StringTerm(fmt.Sprint(i)), IntNumberTerm(i))
}
key := StringTerm(fmt.Sprint(n - 1))
for i := 0; i < b.N; i++ {
value := obj.Get(key)
if value == nil {
b.Fatal("expected hit")
}
}
})
}
}

func BenchmarkSetIntersection(b *testing.B) {
sizes := []int{5, 50, 500, 5000}
for _, n := range sizes {
Expand Down Expand Up @@ -154,8 +173,45 @@ func BenchmarkObjectString(b *testing.B) {
}
}

func BenchmarkObjectConstruction(b *testing.B) {
// This benchmark works similarly to BenchmarkObjectString, but with a key
// difference: it benchmarks the String and MarshalJSON interface functions
// for the Objec, instead of the underlying data structure. This ensures
// that we catch the full performance properties of Object's implementation.
func BenchmarkObjectStringInterfaces(b *testing.B) {
var err error
sizes := []int{5, 50, 500, 5000, 50000}

for _, n := range sizes {
b.Run(fmt.Sprint(n), func(b *testing.B) {

obj := map[string]int{}
for i := 0; i < n; i++ {
obj[fmt.Sprint(i)] = i
}
valString := MustInterfaceToValue(obj)
valJSON := MustInterfaceToValue(obj)

b.Run("String()", func(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
str = valString.String()
}
})
b.Run("json.Marshal", func(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
bs, err = json.Marshal(valJSON)
if err != nil {
b.Fatal(err)
}
}
})
})
}
}

func BenchmarkObjectConstruction(b *testing.B) {
sizes := []int{5, 50, 500, 5000, 50000, 500000}
seed := time.Now().UnixNano()

b.Run("shuffled keys", func(b *testing.B) {
Expand Down
31 changes: 0 additions & 31 deletions ast/term_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,37 +263,6 @@ func TestObjectFilter(t *testing.T) {
}
}

func TestObjectInsertKeepsSorting(t *testing.T) {
keysSorted := func(o *object) func(int, int) bool {
return func(i, j int) bool {
return Compare(o.keys[i].key, o.keys[j].key) < 0
}
}

obj := NewObject(
[2]*Term{StringTerm("d"), IntNumberTerm(4)},
[2]*Term{StringTerm("b"), IntNumberTerm(2)},
[2]*Term{StringTerm("a"), IntNumberTerm(1)},
)
o := obj.(*object)
act := sort.SliceIsSorted(o.keys, keysSorted(o))
if exp := true; act != exp {
t.Errorf("SliceIsSorted: expected %v, got %v", exp, act)
for i := range o.keys {
t.Logf("elem[%d]: %v", i, o.keys[i].key)
}
}

obj.Insert(StringTerm("c"), IntNumberTerm(3))
act = sort.SliceIsSorted(o.keys, keysSorted(o))
if exp := true; act != exp {
t.Errorf("SliceIsSorted: expected %v, got %v", exp, act)
for i := range o.keys {
t.Logf("elem[%d]: %v", i, o.keys[i].key)
}
}
}

func TestSetInsertKeepsKeysSorting(t *testing.T) {
keysSorted := func(s *set) func(int, int) bool {
return func(i, j int) bool {
Expand Down
10 changes: 5 additions & 5 deletions topdown/eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -909,20 +909,20 @@ func (e *eval) biunifyObjects(a, b ast.Object, b1, b2 *bindings, iter unifyItera
b = plugKeys(b, b2)
}

return e.biunifyObjectsRec(a, b, b1, b2, iter, a, 0)
return e.biunifyObjectsRec(a, b, b1, b2, iter, a, a.KeysIterator())
}

func (e *eval) biunifyObjectsRec(a, b ast.Object, b1, b2 *bindings, iter unifyIterator, keys ast.Object, idx int) error {
if idx == keys.Len() {
func (e *eval) biunifyObjectsRec(a, b ast.Object, b1, b2 *bindings, iter unifyIterator, keys ast.Object, oki ast.ObjectKeysIterator) error {
key, more := oki.Next() // Get next key from iterator.
if !more {
return iter()
}
key, _ := keys.Elem(idx)
v2 := b.Get(key)
if v2 == nil {
return nil
}
return e.biunify(a.Get(key), v2, b1, b2, func() error {
return e.biunifyObjectsRec(a, b, b1, b2, iter, keys, idx+1)
return e.biunifyObjectsRec(a, b, b1, b2, iter, keys, oki)
})
}

Expand Down