diff --git a/ast/policy.go b/ast/policy.go index d1f1c660a0..fdb3d897ed 100644 --- a/ast/policy.go +++ b/ast/policy.go @@ -1053,10 +1053,22 @@ func (expr *Expr) sortOrder() int { return -1 } +// CopyWithoutTerms returns a deep copy of expr without its Terms +func (expr *Expr) CopyWithoutTerms() *Expr { + cpy := *expr + + cpy.With = make([]*With, len(expr.With)) + for i := range expr.With { + cpy.With[i] = expr.With[i].Copy() + } + + return &cpy +} + // Copy returns a deep copy of expr. func (expr *Expr) Copy() *Expr { - cpy := *expr + cpy := expr.CopyWithoutTerms() switch ts := expr.Terms.(type) { case *SomeDecl: @@ -1073,12 +1085,7 @@ func (expr *Expr) Copy() *Expr { cpy.Terms = ts.Copy() } - cpy.With = make([]*With, len(expr.With)) - for i := range expr.With { - cpy.With[i] = expr.With[i].Copy() - } - - return &cpy + return cpy } // Hash returns the hash code of the Expr. diff --git a/test/e2e/concurrency/concurrency_test.go b/test/e2e/concurrency/concurrency_test.go index e302ca3766..5aea394ba8 100644 --- a/test/e2e/concurrency/concurrency_test.go +++ b/test/e2e/concurrency/concurrency_test.go @@ -1,4 +1,4 @@ -package metrics +package concurrency import ( "flag" @@ -8,6 +8,7 @@ import ( "sync" "testing" + "github.com/open-policy-agent/opa/server/types" "github.com/open-policy-agent/opa/test/e2e" ) @@ -26,7 +27,7 @@ func TestMain(m *testing.M) { os.Exit(testRuntime.RunTests(m)) } -func TestConcurrency(t *testing.T) { +func TestConcurrencyGetV1Data(t *testing.T) { policy := ` package test @@ -39,10 +40,12 @@ func TestConcurrency(t *testing.T) { } var wg sync.WaitGroup + num := runtime.NumCPU() + wg.Add(num) - for i := 0; i < runtime.NumCPU(); i++ { - wg.Add(1) + for i := 0; i < num; i++ { go func() { + defer wg.Done() for n := 0; n < 1000; n++ { dr := struct { Result bool `json:"result"` @@ -56,10 +59,50 @@ func TestConcurrency(t *testing.T) { return } } - wg.Done() }() } wg.Wait() } + +func TestConcurrencyCompile(t *testing.T) { + + policy := ` + package test + f(_) + p { + not q + } + q { + not f(input.foo) + } + ` + + err := testRuntime.UploadPolicy(t.Name(), strings.NewReader(policy)) + if err != nil { + t.Fatal(err) + } + + req := types.CompileRequestV1{ + Query: "data.test.p", + } + + var wg sync.WaitGroup + num := runtime.NumCPU() + wg.Add(num) + + for i := 0; i < num; i++ { + go func() { + defer wg.Done() + for n := 0; n < 1000; n++ { + if _, err := testRuntime.CompileRequest(req); err != nil { + t.Error(err) + return + } + } + }() + } + + wg.Wait() +} diff --git a/test/e2e/metrics/metrics_test.go b/test/e2e/metrics/metrics_test.go index 15741b39ca..d4656ab626 100644 --- a/test/e2e/metrics/metrics_test.go +++ b/test/e2e/metrics/metrics_test.go @@ -162,7 +162,7 @@ func TestRequestWithInstrumentationV1CompileAPI(t *testing.T) { Unknowns: &[]string{"data.y"}, } - resp, err := testRuntime.CompileRequestWitInstrumentation(req) + resp, err := testRuntime.CompileRequestWithInstrumentation(req) if err != nil { t.Fatal(err) } diff --git a/test/e2e/testing.go b/test/e2e/testing.go index c8098dec34..ecc6c38cac 100644 --- a/test/e2e/testing.go +++ b/test/e2e/testing.go @@ -383,11 +383,24 @@ func (t *TestRuntime) GetData(url string) (io.ReadCloser, error) { return t.request("GET", url, nil) } -// CompileRequestWitInstrumentation will use the v1 compile API and POST with the given request and instrumentation enabled. -func (t *TestRuntime) CompileRequestWitInstrumentation(req types.CompileRequestV1) (*types.CompileResponseV1, error) { +// CompileRequestWithInstrumentation will use the v1 compile API and POST with the given request and instrumentation enabled. +func (t *TestRuntime) CompileRequestWithInstrumentation(req types.CompileRequestV1) (*types.CompileResponseV1, error) { + return t.compileRequest(req, true) +} + +// CompileRequest will use the v1 compile API and POST with the given request. +func (t *TestRuntime) CompileRequest(req types.CompileRequestV1) (*types.CompileResponseV1, error) { + return t.compileRequest(req, false) +} + +func (t *TestRuntime) compileRequest(req types.CompileRequestV1, instrument bool) (*types.CompileResponseV1, error) { inputPayload := util.MustMarshalJSON(req) - resp, err := t.request("POST", t.URL()+"/v1/compile?instrument", bytes.NewReader(inputPayload)) + url := t.URL() + "/v1/compile" + if instrument { + url += "?instrument" + } + resp, err := t.request("POST", url, bytes.NewReader(inputPayload)) if err != nil { return nil, err } diff --git a/topdown/eval.go b/topdown/eval.go index 9930f42104..ab90230141 100644 --- a/topdown/eval.go +++ b/topdown/eval.go @@ -664,18 +664,12 @@ func (e *eval) evalNotPartialSupport(negationID uint64, expr *ast.Expr, unknowns } // Save expression that refers to support rule set. - - terms := expr.Terms - expr.Terms = nil // Prevent unnecessary copying the terms. - cpy := expr.Copy() - expr.Terms = terms + cpy := expr.CopyWithoutTerms() if len(args) > 0 { terms := make([]*ast.Term, len(args)+1) terms[0] = term - for i := 0; i < len(args); i++ { - terms[i+1] = args[i] - } + copy(terms[1:], args) cpy.Terms = terms } else { cpy.Terms = term