From 9d2b1ad427363e72b949e55d75c36699df749666 Mon Sep 17 00:00:00 2001 From: Philip Conrad Date: Mon, 15 Aug 2022 10:37:49 -0400 Subject: [PATCH] Set Insertion Rework (#4999) This commit introduces lazy key slice sorting for the Set type, similar to what was done for Object types in #4830. After this change, sorting of the Set type's key slice will be delayed until just-before-use, identically to how lazy key slice sorting is done for the Object type. This will move the sorting overhead from construction-time for Sets over to evaluation-time, allowing much more efficient construction and use of enormous (500k+ item) Sets. This appears to be a performance-neutral change overall, while dramatically improving performance for the "large set" edge case. Signed-off-by: Philip Conrad --- ast/term.go | 51 ++++++++++++++++++++++-------------------- ast/term_bench_test.go | 47 ++++++++++++++++++++++++++++++++++++-- ast/term_test.go | 32 -------------------------- 3 files changed, 72 insertions(+), 58 deletions(-) diff --git a/ast/term.go b/ast/term.go index 3ec4b47870..e83084cf09 100644 --- a/ast/term.go +++ b/ast/term.go @@ -1352,10 +1352,11 @@ func newset(n int) *set { keys = make([]*Term, 0, n) } return &set{ - elems: make(map[int]*Term, n), - keys: keys, - hash: 0, - ground: true, + elems: make(map[int]*Term, n), + keys: keys, + hash: 0, + ground: true, + numInserts: 0, } } @@ -1368,10 +1369,11 @@ func SetTerm(t ...*Term) *Term { } type set struct { - elems map[int]*Term - keys []*Term - hash int - ground bool + elems map[int]*Term + keys []*Term + hash int + ground bool + numInserts int // number of inserts since last sorting. } // Copy returns a deep copy of s. @@ -1401,7 +1403,7 @@ func (s *set) String() string { } var b strings.Builder b.WriteRune('{') - for i := range s.keys { + for i := range s.sortedKeys() { if i > 0 { b.WriteString(", ") } @@ -1411,6 +1413,14 @@ func (s *set) String() string { return b.String() } +func (s *set) sortedKeys() []*Term { + if s.numInserts > 0 { + sort.Sort(termSlice(s.keys)) + s.numInserts = 0 + } + return s.keys +} + // Compare compares s to other, return <0, 0, or >0 if it is less than, equal to, // or greater than other. func (s *set) Compare(other Value) int { @@ -1422,7 +1432,7 @@ func (s *set) Compare(other Value) int { return 1 } t := other.(*set) - return termSliceCompare(s.keys, t.keys) + return termSliceCompare(s.sortedKeys(), t.sortedKeys()) } // Find returns the set or dereferences the element itself. @@ -1488,7 +1498,7 @@ func (s *set) Add(t *Term) { // Iter calls f on each element in s. If f returns an error, iteration stops // and the return value is the error. func (s *set) Iter(f func(*Term) error) error { - for i := range s.keys { + for i := range s.sortedKeys() { if err := f(s.keys[i]); err != nil { return err } @@ -1564,20 +1574,19 @@ func (s *set) MarshalJSON() ([]byte, error) { if s.keys == nil { return []byte(`[]`), nil } - return json.Marshal(s.keys) + return json.Marshal(s.sortedKeys()) } // Sorted returns an Array that contains the sorted elements of s. func (s *set) Sorted() *Array { cpy := make([]*Term, len(s.keys)) - copy(cpy, s.keys) - sort.Sort(termSlice(cpy)) + copy(cpy, s.sortedKeys()) return NewArray(cpy...) } // Slice returns a slice of terms contained in the set. func (s *set) Slice() []*Term { - return s.keys + return s.sortedKeys() } func (s *set) insert(x *Term) { @@ -1670,15 +1679,9 @@ func (s *set) insert(x *Term) { } s.elems[insertHash] = x - i := sort.Search(len(s.keys), func(i int) bool { return Compare(x, s.keys[i]) < 0 }) - if i < len(s.keys) { - // insert at position `i`: - s.keys = append(s.keys, nil) // add some space - copy(s.keys[i+1:], s.keys[i:]) // move things over - s.keys[i] = x // drop it in position - } else { - s.keys = append(s.keys, x) - } + // O(1) insertion, but we'll have to re-sort the keys later. + s.keys = append(s.keys, x) + s.numInserts++ // Track insertions since the last re-sorting. s.hash += hash s.ground = s.ground && x.IsGround() diff --git a/ast/term_bench_test.go b/ast/term_bench_test.go index da1aadf22d..80f598e98a 100644 --- a/ast/term_bench_test.go +++ b/ast/term_bench_test.go @@ -51,6 +51,25 @@ func BenchmarkObjectCreationAndLookup(b *testing.B) { } } +func BenchmarkSetCreationAndLookup(b *testing.B) { + sizes := []int{5, 50, 500, 5000, 50000, 500000} + for _, n := range sizes { + b.Run(fmt.Sprint(n), func(b *testing.B) { + set := NewSet() + for i := 0; i < n; i++ { + set.Add(StringTerm(fmt.Sprint(i))) + } + key := StringTerm(fmt.Sprint(n - 1)) + for i := 0; i < b.N; i++ { + present := set.Contains(key) + if !present { + b.Fatal("expected hit") + } + } + }) + } +} + func BenchmarkSetIntersection(b *testing.B) { sizes := []int{5, 50, 500, 5000} for _, n := range sizes { @@ -288,11 +307,10 @@ func BenchmarkArrayString(b *testing.B) { } func BenchmarkSetString(b *testing.B) { - sizes := []int{5, 50, 500, 5000} + sizes := []int{5, 50, 500, 5000, 50000} for _, n := range sizes { b.Run(fmt.Sprint(n), func(b *testing.B) { - val := NewSet() for i := 0; i < n; i++ { val.Add(IntNumberTerm(i)) @@ -307,3 +325,28 @@ func BenchmarkSetString(b *testing.B) { }) } } + +func BenchmarkSetMarshalJSON(b *testing.B) { + var err error + sizes := []int{5, 50, 500, 5000, 50000} + + for _, n := range sizes { + b.Run(fmt.Sprint(n), func(b *testing.B) { + set := NewSet() + for i := 0; i < n; i++ { + set.Add(StringTerm(fmt.Sprint(i))) + } + + b.Run("json.Marshal", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + bs, err = json.Marshal(set) + if err != nil { + b.Fatal(err) + } + } + }) + }) + } + +} diff --git a/ast/term_test.go b/ast/term_test.go index 1986566521..35cf9374aa 100644 --- a/ast/term_test.go +++ b/ast/term_test.go @@ -8,7 +8,6 @@ import ( "encoding/json" "fmt" "reflect" - "sort" "strings" "testing" @@ -263,37 +262,6 @@ func TestObjectFilter(t *testing.T) { } } -func TestSetInsertKeepsKeysSorting(t *testing.T) { - keysSorted := func(s *set) func(int, int) bool { - return func(i, j int) bool { - return Compare(s.keys[i], s.keys[j]) < 0 - } - } - - s0 := NewSet( - StringTerm("d"), - StringTerm("b"), - StringTerm("a"), - ) - s := s0.(*set) - act := sort.SliceIsSorted(s.keys, keysSorted(s)) - if exp := true; act != exp { - t.Errorf("SliceIsSorted: expected %v, got %v", exp, act) - for i := range s.keys { - t.Logf("elem[%d]: %v", i, s.keys[i]) - } - } - - s0.Add(StringTerm("c")) - act = sort.SliceIsSorted(s.keys, keysSorted(s)) - if exp := true; act != exp { - t.Errorf("SliceIsSorted: expected %v, got %v", exp, act) - for i := range s.keys { - t.Logf("elem[%d]: %v", i, s.keys[i]) - } - } -} - func TestTermBadJSON(t *testing.T) { input := `{