From 1dfa6bd723762ef26a73d9963452a28411ca1634 Mon Sep 17 00:00:00 2001 From: Stephan Renatus Date: Mon, 21 Feb 2022 10:40:12 +0100 Subject: [PATCH] ast: hash containers on insert/update (#4367) * ast: hash containers at insertion/creation time This allows us to use the Hashes for comparisons, too, since they're cheaply available all the time. * rego_test: add test for concurrent eval of PreparedEvalResult Fixes #4345. Signed-off-by: Stephan Renatus --- ast/term.go | 68 ++++++++++++++++++++++++------------------- ast/term_test.go | 73 ++++++++++++++++++++++++++++++++++++++++++++++- rego/rego_test.go | 56 ++++++++++++++++++++++++++++++++++++ 3 files changed, 166 insertions(+), 31 deletions(-) diff --git a/ast/term.go b/ast/term.go index f56f16794f..a6db2e8d9b 100644 --- a/ast/term.go +++ b/ast/term.go @@ -1090,13 +1090,19 @@ type QueryIterator func(map[Var]Value, Value) error // ArrayTerm creates a new Term with an Array value. func ArrayTerm(a ...*Term) *Term { - return &Term{Value: &Array{elems: a, hash: 0, ground: termSliceIsGround(a)}} + return NewTerm(NewArray(a...)) } // NewArray creates an Array with the terms provided. The array will // use the provided term slice. func NewArray(a ...*Term) *Array { - return &Array{elems: a, hash: 0, ground: termSliceIsGround(a)} + hs := make([]int, len(a)) + for i, e := range a { + hs[i] = e.Value.Hash() + } + arr := &Array{elems: a, hashs: hs, ground: termSliceIsGround(a)} + arr.rehash() + return arr } // Array represents an array as defined by the language. Arrays are similar to the @@ -1104,14 +1110,18 @@ func NewArray(a ...*Term) *Array { // and References. type Array struct { elems []*Term + hashs []int // element hashes hash int ground bool } // Copy returns a deep copy of arr. func (arr *Array) Copy() *Array { + cpy := make([]int, len(arr.elems)) + copy(cpy, arr.hashs) return &Array{ elems: termSliceCopy(arr.elems), + hashs: cpy, hash: arr.hash, ground: arr.IsGround()} } @@ -1173,16 +1183,12 @@ func (arr *Array) Sorted() *Array { } sort.Sort(termSlice(cpy)) a := NewArray(cpy...) - a.hash = arr.hash + a.hashs = arr.hashs return a } // Hash returns the hash code for the Value. func (arr *Array) Hash() int { - if arr.hash == 0 { - arr.hash = termSliceHash(arr.elems) - } - return arr.hash } @@ -1222,11 +1228,19 @@ func (arr *Array) Elem(i int) *Term { return arr.elems[i] } +// rehash updates the cached hash of arr. +func (arr *Array) rehash() { + arr.hash = 0 + for _, h := range arr.hashs { + arr.hash += h + } +} + // set sets the element i of arr. func (arr *Array) set(i int, v *Term) { arr.ground = arr.ground && v.IsGround() arr.elems[i] = v - arr.hash = 0 + arr.hashs[i] = v.Value.Hash() } // Slice returns a slice of arr starting from i index to j. -1 @@ -1235,15 +1249,21 @@ func (arr *Array) set(i int, v *Term) { // the other. func (arr *Array) Slice(i, j int) *Array { var elems []*Term + var hashs []int if j == -1 { elems = arr.elems[i:] + hashs = arr.hashs[i:] } else { elems = arr.elems[i:j] + hashs = arr.hashs[i:j] } // If arr is ground, the slice is, too. // If it's not, the slice could still be. gr := arr.ground || termSliceIsGround(elems) - return &Array{elems: elems, ground: gr} + + s := &Array{elems: elems, hashs: hashs, ground: gr} + s.rehash() + return s } // Iter calls f on each element in arr. If f returns an error, @@ -1280,7 +1300,8 @@ func (arr *Array) Foreach(f func(*Term)) { func (arr *Array) Append(v *Term) *Array { cpy := *arr cpy.elems = append(arr.elems, v) - cpy.hash = 0 + cpy.hashs = append(arr.hashs, v.Value.Hash()) + cpy.hash = arr.hash + v.Value.Hash() cpy.ground = arr.ground && v.IsGround() return &cpy } @@ -1359,11 +1380,6 @@ func (s *set) IsGround() bool { // Hash returns a hash code for s. func (s *set) Hash() int { - if s.hash == 0 { - s.Foreach(func(x *Term) { - s.hash += x.Hash() - }) - } return s.hash } @@ -1556,6 +1572,7 @@ func (s *set) Slice() []*Term { func (s *set) insert(x *Term) { hash := x.Hash() + insertHash := hash // This `equal` utility is duplicated and manually inlined a number of // time in this file. Inlining it avoids heap allocations, so it makes // a big performance difference: some operations like lookup become twice @@ -1633,16 +1650,16 @@ func (s *set) insert(x *Term) { equal = func(y Value) bool { return Compare(x, y) == 0 } } - for curr, ok := s.elems[hash]; ok; { + for curr, ok := s.elems[insertHash]; ok; { if equal(curr.Value) { return } - hash++ - curr, ok = s.elems[hash] + insertHash++ + curr, ok = s.elems[insertHash] } - s.elems[hash] = x + s.elems[insertHash] = x i := sort.Search(len(s.keys), func(i int) bool { return Compare(x, s.keys[i]) < 0 }) if i < len(s.keys) { // insert at position `i`: @@ -1653,7 +1670,7 @@ func (s *set) insert(x *Term) { s.keys = append(s.keys, x) } - s.hash = 0 + s.hash += hash s.ground = s.ground && x.IsGround() } @@ -1888,14 +1905,6 @@ func (obj *object) Get(k *Term) *Term { // Hash returns the hash code for the Value. func (obj *object) Hash() int { - if obj.hash == 0 { - for h, curr := range obj.elems { - for ; curr != nil; curr = curr.next { - obj.hash += h - obj.hash += curr.value.Hash() - } - } - } return obj.hash } @@ -2280,7 +2289,6 @@ func (obj *object) insert(k, v *Term) { } curr.value = v - obj.hash = 0 return } } @@ -2299,7 +2307,7 @@ func (obj *object) insert(k, v *Term) { } else { obj.keys = append(obj.keys, elem) } - obj.hash = 0 + obj.hash += hash + v.Hash() if k.IsGround() { obj.ground++ diff --git a/ast/term_test.go b/ast/term_test.go index ca2d482b48..ec069aca09 100644 --- a/ast/term_test.go +++ b/ast/term_test.go @@ -418,7 +418,7 @@ func TestFind(t *testing.T) { } } -func TestHash(t *testing.T) { +func TestHashObject(t *testing.T) { doc := `{"a": [[true, {"b": [null]}, {"c": "d"}]], "e": {100: a[i].b}, "k": ["foo" | true], "o": {"foo": "bar" | true}, "sc": {"foo" | true}, "s": {1, 2, {3, 4}}, "big": 1e+1000}` @@ -431,6 +431,77 @@ func TestHash(t *testing.T) { if obj1.Hash() != obj2.Hash() { t.Errorf("Expected hash codes to be equal") } + + // Calculate hash like we did before moving the caching to create/update: + obj := obj1.(*object) + exp := 0 + for h, curr := range obj.elems { + for ; curr != nil; curr = curr.next { + exp += h + exp += curr.value.Hash() + } + } + + if act := obj1.Hash(); exp != act { + t.Errorf("expected %v, got %v", exp, act) + } +} + +func TestHashArray(t *testing.T) { + + doc := `[{"a": [[true, {"b": [null]}, {"c": "d"}]]}, 100, true, [a[i].b], {100: a[i].b}, ["foo" | true], {"foo": "bar" | true}, {"foo" | true}, {1, 2, {3, 4}}, 1e+1000]` + + stmt1 := MustParseStatement(doc) + stmt2 := MustParseStatement(doc) + + arr1 := stmt1.(Body)[0].Terms.(*Term).Value.(*Array) + arr2 := stmt2.(Body)[0].Terms.(*Term).Value.(*Array) + + if arr1.Hash() != arr2.Hash() { + t.Errorf("Expected hash codes to be equal") + } + + // Calculate hash like we did before moving the caching to create/update: + exp := termSliceHash(arr1.elems) + + if act := arr1.Hash(); exp != act { + t.Errorf("expected %v, got %v", exp, act) + } + + for j := 0; j < arr1.Len(); j++ { + for i := 0; i <= j; i++ { + slice := arr1.Slice(i, j) + exp := termSliceHash(slice.elems) + if act := slice.Hash(); exp != act { + t.Errorf("arr1[%d:%d]: expected %v, got %v", i, j, exp, act) + } + } + } +} + +func TestHashSet(t *testing.T) { + + doc := `{{"a": [[true, {"b": [null]}, {"c": "d"}]]}, 100, 100, 100, true, [a[i].b], {100: a[i].b}, ["foo" | true], {"foo": "bar" | true}, {"foo" | true}, {1, 2, {3, 4}}, 1e+1000}` + + stmt1 := MustParseStatement(doc) + stmt2 := MustParseStatement(doc) + + set1 := stmt1.(Body)[0].Terms.(*Term).Value.(Set) + set2 := stmt2.(Body)[0].Terms.(*Term).Value.(Set) + + if set1.Hash() != set2.Hash() { + t.Errorf("Expected hash codes to be equal") + } + + // Calculate hash like we did before moving the caching to create/update: + exp := 0 + set1.Foreach(func(x *Term) { + exp += x.Hash() + }) + + if act := set1.Hash(); exp != act { + t.Errorf("expected %v, got %v", exp, act) + } } func TestTermIsGround(t *testing.T) { diff --git a/rego/rego_test.go b/rego/rego_test.go index 0b251ec5d1..46e86dbaee 100644 --- a/rego/rego_test.go +++ b/rego/rego_test.go @@ -18,6 +18,7 @@ import ( "reflect" "strconv" "strings" + "sync" "testing" "time" @@ -658,6 +659,61 @@ func TestPartialRewriteEquals(t *testing.T) { } } +// NOTE(sr): https://github.com/open-policy-agent/opa/issues/4345 +func TestPrepareAndEvalRaceConditions(t *testing.T) { + tests := []struct { + note string + module string + exp string + }{ + { + note: "object", + module: `package test + p[{"x":"y"}]`, + exp: `[[[{"x":"y"}]]]`, + }, + { + note: "set", + module: `package test + p[{"x"}]`, + exp: `[[[["x"]]]]`, + }, + { + note: "array", + module: `package test + p[["x"]]`, + exp: `[[[["x"]]]]`, + }, + } + + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + r := New( + Query("data.test.p"), + Module("", tc.module), + Package("foo"), + ) + + pq, err := r.PrepareForEval(context.Background()) + if err != nil { + t.Fatalf("Unexpected error: %s", err.Error()) + } + + // run this 1000 times concurrently + var wg sync.WaitGroup + wg.Add(1000) + for i := 0; i < 1000; i++ { + go func(t *testing.T) { + t.Helper() + assertPreparedEvalQueryEval(t, pq, []EvalOption{}, tc.exp) + wg.Done() + }(t) + } + wg.Wait() + }) + } +} + func TestPrepareAndEvalNewInput(t *testing.T) { module := ` package test