Skip to content

Commit

Permalink
topdown/eval: copy without modifying expr, update test/e2e/concurrency (
Browse files Browse the repository at this point in the history
open-policy-agent#4503)

What we previously did turned into a race condition with multiple
concurrent calls to /v1/compile.

With a change introduced with 0.38.0 (the `every` keyword), the
`nil` Terms of an `ast.Expr` node was surfaced: previously, it would
go unnoticed, but could potentially have yielded bad results.

The effect of this change is proven using a new e2e test that would
fail on the code we had previous.

Signed-off-by: Stephan Renatus <stephan.renatus@gmail.com>
  • Loading branch information
srenatus authored and rokkiter committed Apr 18, 2022
1 parent e932755 commit ea94481
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 24 deletions.
21 changes: 14 additions & 7 deletions ast/policy.go
Expand Up @@ -1053,10 +1053,22 @@ func (expr *Expr) sortOrder() int {
return -1
}

// CopyWithoutTerms returns a deep copy of expr without its Terms
func (expr *Expr) CopyWithoutTerms() *Expr {
cpy := *expr

cpy.With = make([]*With, len(expr.With))
for i := range expr.With {
cpy.With[i] = expr.With[i].Copy()
}

return &cpy
}

// Copy returns a deep copy of expr.
func (expr *Expr) Copy() *Expr {

cpy := *expr
cpy := expr.CopyWithoutTerms()

switch ts := expr.Terms.(type) {
case *SomeDecl:
Expand All @@ -1073,12 +1085,7 @@ func (expr *Expr) Copy() *Expr {
cpy.Terms = ts.Copy()
}

cpy.With = make([]*With, len(expr.With))
for i := range expr.With {
cpy.With[i] = expr.With[i].Copy()
}

return &cpy
return cpy
}

// Hash returns the hash code of the Expr.
Expand Down
53 changes: 48 additions & 5 deletions test/e2e/concurrency/concurrency_test.go
@@ -1,4 +1,4 @@
package metrics
package concurrency

import (
"flag"
Expand All @@ -8,6 +8,7 @@ import (
"sync"
"testing"

"github.com/open-policy-agent/opa/server/types"
"github.com/open-policy-agent/opa/test/e2e"
)

Expand All @@ -26,7 +27,7 @@ func TestMain(m *testing.M) {
os.Exit(testRuntime.RunTests(m))
}

func TestConcurrency(t *testing.T) {
func TestConcurrencyGetV1Data(t *testing.T) {

policy := `
package test
Expand All @@ -39,10 +40,12 @@ func TestConcurrency(t *testing.T) {
}

var wg sync.WaitGroup
num := runtime.NumCPU()
wg.Add(num)

for i := 0; i < runtime.NumCPU(); i++ {
wg.Add(1)
for i := 0; i < num; i++ {
go func() {
defer wg.Done()
for n := 0; n < 1000; n++ {
dr := struct {
Result bool `json:"result"`
Expand All @@ -56,10 +59,50 @@ func TestConcurrency(t *testing.T) {
return
}
}
wg.Done()
}()
}

wg.Wait()

}

func TestConcurrencyCompile(t *testing.T) {

policy := `
package test
f(_)
p {
not q
}
q {
not f(input.foo)
}
`

err := testRuntime.UploadPolicy(t.Name(), strings.NewReader(policy))
if err != nil {
t.Fatal(err)
}

req := types.CompileRequestV1{
Query: "data.test.p",
}

var wg sync.WaitGroup
num := runtime.NumCPU()
wg.Add(num)

for i := 0; i < num; i++ {
go func() {
defer wg.Done()
for n := 0; n < 1000; n++ {
if _, err := testRuntime.CompileRequest(req); err != nil {
t.Error(err)
return
}
}
}()
}

wg.Wait()
}
2 changes: 1 addition & 1 deletion test/e2e/metrics/metrics_test.go
Expand Up @@ -162,7 +162,7 @@ func TestRequestWithInstrumentationV1CompileAPI(t *testing.T) {
Unknowns: &[]string{"data.y"},
}

resp, err := testRuntime.CompileRequestWitInstrumentation(req)
resp, err := testRuntime.CompileRequestWithInstrumentation(req)
if err != nil {
t.Fatal(err)
}
Expand Down
19 changes: 16 additions & 3 deletions test/e2e/testing.go
Expand Up @@ -383,11 +383,24 @@ func (t *TestRuntime) GetData(url string) (io.ReadCloser, error) {
return t.request("GET", url, nil)
}

// CompileRequestWitInstrumentation will use the v1 compile API and POST with the given request and instrumentation enabled.
func (t *TestRuntime) CompileRequestWitInstrumentation(req types.CompileRequestV1) (*types.CompileResponseV1, error) {
// CompileRequestWithInstrumentation will use the v1 compile API and POST with the given request and instrumentation enabled.
func (t *TestRuntime) CompileRequestWithInstrumentation(req types.CompileRequestV1) (*types.CompileResponseV1, error) {
return t.compileRequest(req, true)
}

// CompileRequest will use the v1 compile API and POST with the given request.
func (t *TestRuntime) CompileRequest(req types.CompileRequestV1) (*types.CompileResponseV1, error) {
return t.compileRequest(req, false)
}

func (t *TestRuntime) compileRequest(req types.CompileRequestV1, instrument bool) (*types.CompileResponseV1, error) {
inputPayload := util.MustMarshalJSON(req)

resp, err := t.request("POST", t.URL()+"/v1/compile?instrument", bytes.NewReader(inputPayload))
url := t.URL() + "/v1/compile"
if instrument {
url += "?instrument"
}
resp, err := t.request("POST", url, bytes.NewReader(inputPayload))
if err != nil {
return nil, err
}
Expand Down
10 changes: 2 additions & 8 deletions topdown/eval.go
Expand Up @@ -664,18 +664,12 @@ func (e *eval) evalNotPartialSupport(negationID uint64, expr *ast.Expr, unknowns
}

// Save expression that refers to support rule set.

terms := expr.Terms
expr.Terms = nil // Prevent unnecessary copying the terms.
cpy := expr.Copy()
expr.Terms = terms
cpy := expr.CopyWithoutTerms()

if len(args) > 0 {
terms := make([]*ast.Term, len(args)+1)
terms[0] = term
for i := 0; i < len(args); i++ {
terms[i+1] = args[i]
}
copy(terms[1:], args)
cpy.Terms = terms
} else {
cpy.Terms = term
Expand Down

0 comments on commit ea94481

Please sign in to comment.